In [None]:
import datatable as dt
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
from matplotlib.patches import Patch, RegularPolygon
from matplotlib import cm, colors
import matplotlib
import warnings

def read_file(file_path):
    """Read a delimited table whose first column holds the row names.

    The body is parsed with ``datatable.fread``; the header row is read a
    second time with pandas (``index_col=0``) so the column names are taken
    from the file exactly as written.

    Parameters
    ----------
    file_path : str or os.PathLike
        Path to the table. Files ending in ``.csv`` use ',' as the delimiter
        for the pandas header read; any other extension (``.txt``/``.tsv``)
        uses a tab.

    Returns
    -------
    pandas.DataFrame
        The values from the second file column onward, indexed by the first
        file column.

    Raises
    ------
    IOError
        If reading or assembling the table fails for any reason. The
        underlying exception is chained as ``__cause__`` so the real failure
        (missing file, parse error, length mismatch, ...) stays visible.
    """
    try:
        # Accept pathlib.Path and other path-like objects as well as str.
        file_path = os.fspath(file_path)
        file_delim = "," if file_path.endswith(".csv") else "\t"
        with warnings.catch_warnings():
            # pandas ParserWarnings are silenced for the duration of both reads.
            warnings.simplefilter("ignore", category=pd.errors.ParserWarning)
            table = dt.fread(file_path, header=True)
            colnames = pd.read_csv(file_path, sep=file_delim, nrows=1, index_col=0).columns
            rownames = table[:, 0].to_pandas().values.flatten()
            file_data = table[:, 1:].to_pandas()
            file_data.index = rownames
            file_data.columns = colnames
    except Exception as e:
        raise IOError("Make sure you provided the correct path to input files. "
                      "The following input file formats are supported: .csv with comma ',' as "
                      "delimiter, .txt or .tsv with tab '\\t' as delimiter.") from e

    return file_data

def format_label(label,max_length=14,max_lines=3):
    """Wrap a label onto at most ``max_lines`` short lines for plot titles and legends.

    Words (split on single spaces) are packed greedily into lines of at most
    ``max_length`` characters. Any word longer than ``max_length`` is put on a
    line of its own and truncated to ``max_length`` characters plus ``'...'``.
    If more than ``max_lines`` lines result, the extra lines are dropped and
    ``'...'`` is appended to the last kept line.

    Parameters
    ----------
    label : object
        Label to format; converted with ``str()`` so non-string labels work.
    max_length : int
        Target maximum characters per line. Truncated lines may reach
        ``max_length + 3`` characters because of the ellipsis.
    max_lines : int
        Maximum number of lines kept.

    Returns
    -------
    str
        The lines joined with newline characters.
    """
    label_chunks = []
    chunk = ''
    for word in str(label).split(' '):
        if len(word) > max_length:
            # Over-long word: flush the current line, then emit the word
            # truncated on its own line. Previously this truncation only
            # happened when the word started a line.
            if chunk:
                label_chunks.append(chunk)
                chunk = ''
            label_chunks.append(word[:max_length] + '...')
        elif not chunk:
            chunk = word
        elif len(chunk) + len(word) <= max_length - 1:
            # The -1 reserves room for the joining space.
            chunk = chunk + ' ' + word
        else:
            label_chunks.append(chunk)
            chunk = word
    if chunk:
        label_chunks.append(chunk)

    if len(label_chunks) > max_lines:
        label_chunks = label_chunks[:max_lines]
        # Cap at max_length before adding the ellipsis. For a line that is
        # already a truncated word ('xxx...') this avoids a doubled '......'.
        label_chunks[-1] = label_chunks[-1][:max_length] + '...'

    return '\n'.join(label_chunks)

def rand_jitter(arr,interval):
    """Return ``arr`` plus uniform noise in [-interval/2, interval/2), one draw per element.

    The noise comes from NumPy's global random state (``np.random.uniform``),
    so results are reproducible only if ``np.random.seed`` is set beforehand.
    """
    half_width = interval / 2
    noise = np.random.uniform(-half_width, half_width, len(arr))
    return arr + noise

def _cell_counts_per_spot(assigned_locations, output_prefix):
    """Return an int table of cell counts: one row per spot, one column per cell type.

    Rows and columns follow the order in which spots / cell types first
    appear in ``assigned_locations``.
    """
    # Predicted assignments keep the spot ID in 'Predict'; ground truth in 'SpotID'.
    spot_columns = {'Predict': 'Predict', 'gt': 'SpotID'}
    if output_prefix not in spot_columns:
        raise ValueError(f"output_prefix must be 'Predict' or 'gt', got {output_prefix!r}")
    spot_col = spot_columns[output_prefix]
    counts = (
        assigned_locations.loc[:, [spot_col, 'CellType']]
        .value_counts()
        .unstack(fill_value=0)
        .reindex(index=assigned_locations[spot_col].unique(),
                 columns=assigned_locations['CellType'].unique())
    )
    return counts.astype(int)


def _spot_marker_shape(geometry, scale, x_int, y_int):
    """Return ``(num_vertices, rotation_degrees, radius)`` of the polygon drawn for each spot."""
    if geometry == 'honeycomb':
        # Hexagon: 75% of the row spacing for row/col indices, else the typical x spacing.
        return 6, 0, (0.75 * y_int if scale else x_int)
    if geometry == 'square':
        # Square = 4-vertex polygon turned 45 degrees, circumradius sqrt(2) * y spacing.
        return 4, 45, np.sqrt(2 * y_int ** 2)
    raise ValueError(f"Unknown geometry {geometry!r}; expected 'honeycomb' or 'square'.")


def my_plot_results_bulk_ST_by_spot(assigned_locations, coordinates_data, dir_out, output_prefix, geometry='honeycomb', num_cols=3):
    """Plot, for every cell type, the number of cells assigned to each spot.

    One panel per cell type (sorted by name) is laid out on a grid with
    ``num_cols`` columns. Each spot is drawn as a filled polygon coloured with
    viridis. The figure is saved to
    ``{dir_out}/{output_prefix}_cell_type_assignments_by_spot.{png,pdf}``.

    Parameters
    ----------
    assigned_locations : pandas.DataFrame
        One row per cell, with a 'CellType' column and a spot-ID column:
        'Predict' when ``output_prefix == 'Predict'``, 'SpotID' when
        ``output_prefix == 'gt'``.
    coordinates_data : pandas.DataFrame
        Indexed by spot ID; its first two columns are used as x and y.
    dir_out : str
        Output directory.
    output_prefix : {'Predict', 'gt'}
        Selects the spot-ID column and prefixes the output file names.
    geometry : {'honeycomb', 'square'}
        Spot layout; selects the polygon drawn for each spot.
    num_cols : int
        Number of panel columns.

    Raises
    ------
    ValueError
        If ``output_prefix`` or ``geometry`` is not a supported value.
    """
    fout_png_all = os.path.join(dir_out, f"{output_prefix}_cell_type_assignments_by_spot.png")
    fout_pdf_all = os.path.join(dir_out, f"{output_prefix}_cell_type_assignments_by_spot.pdf")

    metadata = _cell_counts_per_spot(assigned_locations, output_prefix)
    # Plot every cell-type column, sorted by name. Slicing off the last column
    # here would discard a real cell type.
    cell_types = sorted(metadata.columns)

    coordinates = coordinates_data.loc[metadata.index, :]
    X = coordinates.iloc[:, 0]
    Y = coordinates.iloc[:, 1]

    # Heuristic: integer-valued Y below 500 is treated as row/column indices
    # rather than physical coordinates.
    scale = Y.max() < 500 and ((Y - Y.round()).abs() < 1e-5).all()

    # Representative spacing between adjacent spots along each axis.
    y_int = 1 if scale else np.median(np.unique(np.diff(np.sort(np.unique(Y)))))
    x_int = 1 if scale else np.median(np.unique(np.diff(np.sort(np.unique(X)))))

    if geometry == 'honeycomb' and scale:
        # NOTE(review): the rotation/rescaling this message announces is
        # currently disabled; coordinates are used unchanged.
        print('Detecting row and column indexing of Visium data; rescaling for coordinates')
    elif geometry == 'square' and scale:
        # NOTE(review): rotation is currently disabled; coordinates are used unchanged.
        print('Detecting row and column indexing of legacy ST data; rotating for coordinates')
    else:
        # Flip the vertical axis.
        Y = 1 - Y

    hex_vert, hex_rot, hex_rad = _spot_marker_shape(geometry, scale, x_int, y_int)

    num_rows = -(-len(cell_types) // num_cols)  # ceiling division
    width = max(X) - min(X)
    height = max(Y) - min(Y)

    plt.rc('font', **{'family': 'sans-serif', 'sans-serif': ['Arial'], 'size':'12'})
    plt.rcParams['figure.dpi'] = 450
    plt.rcParams['savefig.dpi'] = 450

    # squeeze=False keeps `axes` 2-D even with a single row or column of panels.
    fig, axes = plt.subplots(num_rows, num_cols,
                             figsize=(width / height * 3 * num_cols, 3 * num_rows),
                             squeeze=False)

    # Loop-invariant colour mapping; plt.get_cmap replaces cm.get_cmap, which
    # was removed in newer matplotlib releases.
    viridis = plt.get_cmap('viridis')
    norm = matplotlib.colors.Normalize(0, 5)

    # axes.flat walks the grid row by row, matching the panel order.
    for k, ax in enumerate(axes.flat):
        if k >= len(cell_types):
            ax.axis('off')
            continue

        ct = cell_types[k]
        ax.set_aspect('equal')

        node_assignment = metadata.loc[:, ct]
        # NOTE(review): colours use count / max(count) * 1.25 (values above 1
        # saturate), while the colour bar below shows a fixed 0-5 range — confirm intent.
        node_assignment = (1 / max(node_assignment)) * node_assignment * 1.25
        spot_colors = viridis(node_assignment)

        for x, y, c in zip(X, Y, spot_colors):
            ax.add_patch(RegularPolygon((x, y), numVertices=hex_vert, radius=hex_rad,
                                        orientation=np.radians(hex_rot),
                                        facecolor=c, edgecolor=None))

        # Invisible scatter at the spot centres, presumably needed so the axis
        # limits autoscale to cover the patches (kept from the original code).
        ax.scatter(X, Y, c=[c[0] for c in spot_colors], alpha=0)

        ct_label = str(ct)
        # Wrap labels that are too long for a panel title.
        if len(ct_label) > 14:
            ct_label = format_label(ct_label)
        ax.set_title(ct_label)
        fig.colorbar(cm.ScalarMappable(norm=norm, cmap=viridis), ax=ax, label='', fraction=0.036, pad=0.04)
        ax.axis('off')

    fig.tight_layout()
    fig.subplots_adjust(top=0.92)
    # NOTE(review): the title names CytoSPACE although this notebook plots
    # Celloc results and ground truth — confirm the intended wording.
    fig.suptitle('Number of cells per spot mapped by CytoSPACE')
    fig.savefig(fout_png_all, facecolor="w", bbox_inches='tight')
    fig.savefig(fout_pdf_all, facecolor="w", bbox_inches='tight')


def my_plot_results_bulk_ST_jitter(assigned_locations, dir_out, output_prefix, geometry="honeycomb",max_num_cells=50000):
    """Scatter-plot individual cells at their spot positions, jittered and coloured by cell type.

    Each cell is placed at its spot coordinates plus uniform noise (via
    ``rand_jitter``) so cells sharing a spot spread out. The figure is saved to
    ``{dir_out}/{output_prefix}cell_type_assignments_by_spot_jitter.{png,pdf}``.

    Parameters
    ----------
    assigned_locations : pandas.DataFrame
        One row per cell with a 'CellType' column and coordinate columns
        'predict_x'/'predict_y' (``output_prefix == 'Predict'``) or
        'gt_x'/'gt_y' (``output_prefix == 'gt'``).
    dir_out : str
        Output directory.
    output_prefix : {'Predict', 'gt'}
        Selects the coordinate columns and prefixes the output file names.
    geometry : str
        'honeycomb' or 'square' selects the jitter/rotation scheme when the
        coordinates look like row/column indices; anything else flips Y.
    max_num_cells : int
        If there are more rows than this, a random subsample of this size is
        plotted (no ``random_state`` is passed, so it is unseeded here).

    Raises
    ------
    ValueError
        If ``output_prefix`` is not 'Predict' or 'gt'.
    """
    # NOTE(review): unlike my_plot_results_bulk_ST_by_spot there is no '_'
    # after the prefix (e.g. 'gtcell_type_...'); kept so existing outputs match.
    fout_png_jitter = os.path.join(dir_out, f"{output_prefix}cell_type_assignments_by_spot_jitter.png")
    fout_pdf_jitter = os.path.join(dir_out, f"{output_prefix}cell_type_assignments_by_spot_jitter.pdf")

    # Coordinate columns for each result type.
    coordinate_columns = {'Predict': ('predict_x', 'predict_y'), 'gt': ('gt_x', 'gt_y')}
    if output_prefix not in coordinate_columns:
        raise ValueError(f"output_prefix must be 'Predict' or 'gt', got {output_prefix!r}")

    if assigned_locations.shape[0] > max_num_cells:
        assigned_locations = assigned_locations.sample(max_num_cells)

    x_col, y_col = coordinate_columns[output_prefix]
    X = assigned_locations[x_col]
    Y = assigned_locations[y_col]
    cell_types = assigned_locations['CellType'].values

    # Heuristic: integer-valued Y below 500 is treated as row/column indices
    # rather than physical coordinates.
    scale = Y.max() < 500 and ((Y - Y.round()).abs() < 1e-5).all()

    # Representative spacing between adjacent spots along each axis.
    y_int = 1 if scale else np.median(np.unique(np.diff(np.sort(np.unique(Y)))))
    x_int = 1 if scale else np.median(np.unique(np.diff(np.sort(np.unique(X)))))

    if geometry == 'honeycomb' and scale:
        print('Detecting row and column indexing of Visium data; rescaling for coordinates')
        # Rotation/rescaling is disabled; only the jitter width is narrowed.
        y_interval = 0.75*x_int
        x_interval = 0.75*y_int
    elif geometry == 'square' and scale:
        print('Detecting row and column indexing of legacy ST data; rotating for coordinates')
        # Quarter turn: new x = old y, new y = 1 - old x.
        X, Y = Y, 1 - X
        y_interval = 1.5*x_int
        x_interval = y_int
    else:
        # Flip the vertical axis.
        Y = 1 - Y
        y_interval = y_int
        x_interval = x_int

    X = rand_jitter(X.values, x_interval)
    Y = rand_jitter(Y.values, y_interval)

    plt.rc('font', **{'family': 'sans-serif', 'sans-serif': ['Arial'], 'size':'12'})
    plt.rcParams['figure.dpi'] = 450
    plt.rcParams['savefig.dpi'] = 450

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_aspect('equal')

    palette = ["#222222", "#F3C300", "#875692", "#F38400", "#A1CAF1", "#BE0032", "#C2B280",
               "#848482", "#008856", "#E68FAC", "#0067A5", "#F99379", "#604E97", "#F6A600", "#B3446C",
               "#DCD300", "#882D17", "#8DB600", "#654522", "#E25822", "#2B3D26", "#5A5156", "#E4E1E3",
               "#F6222E", "#FE00FA", "#16FF32", "#3283FE", "#FEAF16", "#B00068", "#1CFFCE", "#90AD1C",
               "#2ED9FF", "#DEA0FD", "#AA0DFE", "#F8A19F", "#325A9B", "#C4451C", "#1C8356", "#85660D",
               "#B10DA1", "#FBE426", "#1CBE4F", "#FA0087", "#FC1CBF", "#F7E1A0", "#C075A6", "#782AB6",
               "#AAF400", "#BDCDFF", "#822E1C", "#B5EFB5", "#7ED7D1", "#1C7F93", "#D85FF7", "#683B79",
               "#66B0FF", "#3B00FB"]

    unique_cell_types = np.unique(cell_types)
    # Cycle through the palette so more cell types than colours no longer
    # raises KeyError; the mapping is unchanged when there are enough colours.
    color_per_ct = {ct: palette[i % len(palette)] for i, ct in enumerate(unique_cell_types)}
    cell_type_colors = [color_per_ct[ct] for ct in cell_types]
    ax.scatter(X, Y, c=cell_type_colors, s=2.5)
    ax.axis('off')
    # NOTE(review): the title names CytoSPACE although this notebook plots
    # Celloc results and ground truth — confirm the intended wording.
    ax.set_title('Single cells mapped to tissue by CytoSPACE')
    fig.tight_layout()

    legend_elements = [Patch(facecolor=color_per_ct[ct], label=format_label(str(ct)))
                       for ct in unique_cell_types]
    ax.legend(bbox_to_anchor=(1.05, 1.0), handles=legend_elements)
    fig.savefig(fout_png_jitter, facecolor="w", bbox_inches='tight')
    fig.savefig(fout_pdf_jitter, facecolor="w", bbox_inches='tight')

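Figure S3a: Cerebellum data — individual cells plotted at their spot positions (with random jitter) and coloured by cell type, for the ground truth and for Celloc's regular single-cell-to-spot mapping.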

In [None]:
# Load Celloc's regular-mapping output, the spot coordinates and the ground truth (Cerebellum).
assigned_locations_path = "./data/regular_mapping_results/Cerebellum/regular_mapping_5cell_5%noise.csv"
assigned_locations = read_file(assigned_locations_path)
coordinates_path="./data/regular_mapping_results/Cerebellum/Coordinates.csv"
coordinates_data = read_file(coordinates_path)
ground_truth_path="./data/regular_mapping_results/Cerebellum/ground_truth.csv"
ground_truth = read_file(ground_truth_path)

# Attach ground-truth coordinates by looking each cell's 'Lable' value up in the
# coordinate table (presumably the true spot ID — column spelled 'Lable' in the input).
gt_coordinates = coordinates_data.reindex(assigned_locations['Lable'])
assigned_locations['gt_x'] = gt_coordinates['row'].values
assigned_locations['gt_y'] = gt_coordinates['col'].values

# Seed NumPy's global RNG (used by rand_jitter and by the pandas subsampling)
# so the jittered figures are identical on every Restart & Run All.
np.random.seed(0)

# Ground truth ("gt") and regular single-cell-to-spot mapping by Celloc ("Predict").
dir_out = "./data/regular_mapping_results/Cerebellum/"
for output_prefix in ("gt", "Predict"):
    my_plot_results_bulk_ST_jitter(assigned_locations=assigned_locations, dir_out=dir_out, output_prefix=output_prefix)

Figure S3b: Cerebellum data — number of cells of each cell type per spot. The ground truth and Celloc results are shown as examples; for other methods, change the corresponding data paths.

In [None]:
# Per-spot cell counts for the ground truth ("gt") and for Celloc ("Predict"), Cerebellum.
dir_out = "./data/regular_mapping_results/Cerebellum/"
for locations, output_prefix in ((ground_truth, "gt"), (assigned_locations, "Predict")):
    my_plot_results_bulk_ST_by_spot(assigned_locations=locations, coordinates_data=coordinates_data, dir_out=dir_out, output_prefix=output_prefix)

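Figure S4a: Hippocampus data — individual cells plotted at their spot positions (with random jitter) and coloured by cell type, for the ground truth and for Celloc's regular single-cell-to-spot mapping.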

In [None]:
# Load Celloc's regular-mapping output, the spot coordinates and the ground truth (Hippocampus).
assigned_locations_path = "./data/regular_mapping_results/Hippocampus/regular_mapping_5cell_5%noise.csv"
assigned_locations = read_file(assigned_locations_path)
coordinates_path="./data/regular_mapping_results/Hippocampus/Coordinates.csv"
coordinates_data = read_file(coordinates_path)
ground_truth_path="./data/regular_mapping_results/Hippocampus/ground_truth.csv"
ground_truth = read_file(ground_truth_path)

# Attach ground-truth coordinates by looking each cell's 'Lable' value up in the
# coordinate table (presumably the true spot ID — column spelled 'Lable' in the input).
gt_coordinates = coordinates_data.reindex(assigned_locations['Lable'])
assigned_locations['gt_x'] = gt_coordinates['row'].values
assigned_locations['gt_y'] = gt_coordinates['col'].values

# Seed NumPy's global RNG (used by rand_jitter and by the pandas subsampling)
# so the jittered figures are identical on every Restart & Run All.
np.random.seed(0)

# Ground truth ("gt") and regular single-cell-to-spot mapping by Celloc ("Predict").
dir_out = "./data/regular_mapping_results/Hippocampus/"
for output_prefix in ("gt", "Predict"):
    my_plot_results_bulk_ST_jitter(assigned_locations=assigned_locations, dir_out=dir_out, output_prefix=output_prefix)

Figure S4b: Hippocampus data — number of cells of each cell type per spot. The ground truth and Celloc results are shown as examples; for other methods, change the corresponding data paths.

In [None]:
# Per-spot cell counts for the ground truth ("gt") and for Celloc ("Predict"), Hippocampus.
dir_out = "./data/regular_mapping_results/Hippocampus/"
for locations, output_prefix in ((ground_truth, "gt"), (assigned_locations, "Predict")):
    my_plot_results_bulk_ST_by_spot(assigned_locations=locations, coordinates_data=coordinates_data, dir_out=dir_out, output_prefix=output_prefix)