In [1]:
import warnings
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import itertools
from skimage import io
from scipy.sparse import csr_matrix # type: ignore
from anndata import AnnData
import cv2
import seaborn as sns
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
sc.settings.verbosity = 3
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import shapiro, norm
from scipy.signal import find_peaks
import os

%config InlineBackend.figure_format='retina'
%matplotlib inline
plt.rcParams['figure.figsize']=(7,4)


## Load Stereo-seq data

The functions below read a BGI (Stereo-seq) GEM file together with its tissue image and return an AnnData object.

In [2]:
def load_bin(
    gem_file: str,
    image_file: str,
    bin_size: int,
    library_id: str,
    delimiter: str = "\t",
) -> AnnData:
    """
    Read a BGI (Stereo-seq) GEM file and its image, aggregate the transcripts
    into square bins, and return an AnnData object.

    Parameters
    ----------
    gem_file
        The path of the BGI GEM data file. It must contain the columns
        ``geneID``, ``x``, ``y`` and a count column named ``MIDCount`` or
        ``MIDCounts``; lines starting with ``#`` are skipped.
    image_file
        The path of the image file (read with OpenCV).
    bin_size
        Edge length of a square bin, in GEM coordinate units.
    library_id
        The library id, used as key under ``adata.uns['spatial']``.
    delimiter
        Column separator of the GEM file. Defaults to tab; some GEM exports
        (e.g. the COAD data) use a space instead.

    Returns
    -------
    Annotated data object with one observation per non-empty bin (barcode =
    integer bin id as str) and the following keys:

        - :attr:`anndata.AnnData.obsm` ``['spatial']`` - per-bin ``(y, x)`` coordinates: floor of the mean of the distinct GEM y / x values in the bin, after shifting both axes to start at 0.
        - :attr:`anndata.AnnData.uns` ``['spatial']['{library_id}']['images']`` - *hires* image (RGB channel order).
        - :attr:`anndata.AnnData.uns` ``['spatial']['{library_id}']['scalefactors']`` - scale factors for the spots.

    Raises
    ------
    OSError
        If OpenCV cannot read ``image_file``.
    """ # noqa: E501
    dat = pd.read_csv(gem_file, delimiter=delimiter, comment="#")

    image = cv2.imread(image_file)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise OSError(f"cv2.imread could not read image file: {image_file}")
    # OpenCV returns channels in BGR order; matplotlib-based plotting
    # (e.g. sc.pl.spatial) interprets 3-channel arrays as RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Shift GEM coordinates so both axes start at 0.
    dat['x'] -= dat['x'].min()
    dat['y'] -= dat['y'].min()
    width = dat['x'].max() + 1
    height = dat['y'].max() + 1

    # 1-based bin index along each axis.
    dat['xb'] = (dat['x'] // bin_size + 1).astype(int)
    dat['yb'] = (dat['y'] // bin_size + 1).astype(int)

    # One row stride shared by every bin id, so ids can never disagree.
    n_xb = int(dat['xb'].max())

    def _bin_id(xb, yb):
        # Row-major linear bin id; unique because 1 <= xb <= n_xb.
        return n_xb * (yb - 1) + xb

    # Representative coordinate of each bin column / row: floor of the mean
    # of the distinct x (resp. y) values occurring in it.
    x_of_bin = np.floor(
        dat[['x', 'xb']].drop_duplicates().groupby('xb')['x'].mean()
    ).astype(int)
    y_of_bin = np.floor(
        dat[['y', 'yb']].drop_duplicates().groupby('yb')['y'].mean()
    ).astype(int)

    # Sum counts per (gene, bin).
    count_col = 'MIDCount' if 'MIDCount' in dat.columns else 'MIDCounts'
    binned = dat.groupby(['geneID', 'xb', 'yb'])[count_col].sum().reset_index()
    bin_ids = _bin_id(binned['xb'], binned['yb'])

    # Integer codes in order of first appearance -> matrix row/column indices.
    gene_codes, genes = pd.factorize(binned['geneID'])
    bin_codes, bins = pd.factorize(bin_ids)
    counts = csr_matrix(
        (binned[count_col].to_numpy(), (bin_codes, gene_codes)),
        shape=(len(bins), len(genes)),
    )

    adata = AnnData(counts)
    adata.var_names = list(genes)
    # barcode is str, not number
    adata.obs_names = [str(b) for b in bins]

    # First row of each bin, in the same first-appearance order as `bins`.
    first_rows = binned.loc[~bin_ids.duplicated(), ['xb', 'yb']]
    adata.obsm['spatial'] = np.column_stack([
        first_rows['yb'].map(y_of_bin).to_numpy(),
        first_rows['xb'].map(x_of_bin).to_numpy(),
    ])

    # tissue image / RNA shape
    # NOTE(review): image.shape[0] is the number of image rows but is divided
    # by the GEM x extent (width); confirm the intended axis pairing.
    tissue_hires_scalef = max(image.shape[0] / width, image.shape[1] / height)
    # Diameter of one bin; can be adjusted outside via size= in scatter functions.
    spot_diameter = bin_size / tissue_hires_scalef

    adata.uns['spatial'] = {
        library_id: {
            "images": {"hires": image},
            "scalefactors": {
                "tissue_hires_scalef": tissue_hires_scalef,
                "spot_diameter_fullres": spot_diameter,
            },
        }
    }
    return adata
    

In [None]:

def load_cell(
    gem_file: str,
    image_file: str,
    mask_file: str,
    library_id: str,
    delimiter: str = "\t",
    spot_diameter_fullres: float = 250,
) -> AnnData:
    """
    Read a BGI (Stereo-seq) GEM file, its image and a cell-segmentation label
    mask, aggregate transcripts per labelled cell, and return an AnnData object.

    Parameters
    ----------
    gem_file
        The path of the BGI GEM data file. It must contain the columns
        ``geneID``, ``x``, ``y`` and a count column named ``MIDCount`` or
        ``MIDCounts``; lines starting with ``#`` are skipped.
    image_file
        The path of the image file (read with ``skimage.io.imread``).
    mask_file
        Path of a label image loaded with ``np.load``. Every non-zero value is
        treated as a cell label; zero pixels are ignored.
    library_id
        The library id, used as key under ``adata.uns['spatial']``.
    delimiter
        Column separator of the GEM file (default: tab).
    spot_diameter_fullres
        Value stored as ``scalefactors['spot_diameter_fullres']``; it can also
        be overridden at plot time via ``size=``.

    Returns
    -------
    Annotated data object with one observation per mask label that overlaps at
    least one transcript (barcode = label as str) and the following keys:

        - :attr:`anndata.AnnData.obsm` ``['spatial']`` - per-cell ``(row, col)``: floor of the mean labelled-pixel position, after shifting the mask so its labelled pixels start at 0.
        - :attr:`anndata.AnnData.uns` ``['spatial']['{library_id}']['images']`` - *hires* images.
        - :attr:`anndata.AnnData.uns` ``['spatial']['{library_id}']['scalefactors']`` - scale factors for the spots.

    Raises
    ------
    ValueError
        If the mask has no labelled pixels, or no transcript falls on one.
    """ # noqa: E501
    dat = pd.read_csv(gem_file, delimiter=delimiter, comment="#")
    image = io.imread(image_file)

    labels = np.load(mask_file)
    rows, cols = np.nonzero(labels)
    if rows.size == 0:
        raise ValueError(f"Mask {mask_file} contains no labelled (non-zero) pixels")
    # GEM x is matched against the mask column index and GEM y against the row
    # index. NOTE(review): the original author noted that GEM and image x/y
    # conventions differ and left an explicit GEM x/y swap disabled; verify the
    # alignment on real data.
    mask = pd.DataFrame({'x': cols, 'y': rows, 'barcodes': labels[rows, cols]})

    # Shift both coordinate systems so each starts at 0 before matching.
    dat['x'] -= dat['x'].min()
    dat['y'] -= dat['y'].min()
    mask['x'] -= mask['x'].min()
    mask['y'] -= mask['y'].min()

    # Keep only transcripts lying on a labelled pixel (mask_data: RNA per pixel).
    mask_data = mask.merge(dat, on=['x', 'y'], how='inner')
    if mask_data.empty:
        raise ValueError("No transcript falls on a labelled mask pixel")

    # Sum counts per (gene, cell).
    count_col = 'MIDCount' if 'MIDCount' in mask_data.columns else 'MIDCounts'
    exp = mask_data.groupby(['geneID', 'barcodes'])[count_col].sum().reset_index()

    # Integer codes in order of first appearance -> matrix row/column indices.
    gene_codes, genes = pd.factorize(exp['geneID'])
    cell_codes, cells = pd.factorize(exp['barcodes'])
    counts = csr_matrix(
        (exp[count_col].to_numpy(), (cell_codes, gene_codes)),
        shape=(len(cells), len(genes)),
    )

    adata = AnnData(counts)
    adata.var_names = list(genes)
    adata.obs_names = [str(c) for c in cells]

    # Cell centre = floor of the mean labelled-pixel position, computed in one
    # vectorised groupby instead of appending rows one at a time.
    centres = np.floor(mask.groupby('barcodes')[['x', 'y']].mean()).astype(int)
    centres.index = centres.index.astype(str)
    centres = centres.reindex(adata.obs_names)
    adata.obsm['spatial'] = centres[['y', 'x']].to_numpy()

    # NOTE(review): mask extent divided by image size, pairing mask rows with
    # image.shape[1] and mask columns with image.shape[0]; confirm the intended
    # ratio direction and axis pairing before relying on it.
    tissue_hires_scalef = max(
        (mask['y'].max() + 1) / image.shape[1],
        (mask['x'].max() + 1) / image.shape[0],
    )

    adata.uns['spatial'] = {
        library_id: {
            "images": {"hires": image},
            "scalefactors": {
                "tissue_hires_scalef": tissue_hires_scalef,
                "spot_diameter_fullres": spot_diameter_fullres,
            },
        }
    }
    return adata

## Define functions

In [None]:
def calculate_composition(adata, groupby_key, category_key):
    """
    Compute, for every group, the fraction of observations in each category.

    Parameters:
    adata: object exposing an ``obs`` table (e.g. an AnnData object)
    groupby_key: column of adata.obs defining the groups (e.g. 'louvain')
    category_key: column of adata.obs holding the categories (e.g. 'annotations')

    Returns:
    DataFrame indexed by group with one column per category; each row sums
    to 1 and categories absent from a group are filled with 0.
    """
    obs_table = pd.DataFrame(adata.obs)
    per_group = obs_table.groupby(groupby_key)[category_key]
    fractions = per_group.value_counts(normalize=True)
    return fractions.unstack(fill_value=0)


In [None]:

def plot_composition(composition_df, title, x, y, palette='viridis'):
    """
    Plot composition data as a stacked bar chart and show it.

    Parameters:
    composition_df: DataFrame with composition data; each row becomes a bar and
        each column a stacked segment (e.g. the output of calculate_composition)
    title: Title for the plot
    x: Label for the x axis (the grouping shown as bars)
    y: Title for the legend (the stacked categories)
    palette: Name of a matplotlib colormap, sampled evenly once per column, or
        an explicit sequence (list, tuple or numpy array) of colors

    Raises:
    ValueError: if palette is neither a colormap name nor a sequence of colors
    """
    if isinstance(palette, str):
        cmap = plt.get_cmap(palette)
        colors = cmap(np.linspace(0, 1, composition_df.shape[1]))
    elif isinstance(palette, (list, tuple)):
        # Pass a list: a bare tuple may be read as one RGB(A) color.
        colors = list(palette)
    elif isinstance(palette, np.ndarray):
        colors = palette
    else:
        raise ValueError(
            "Palette should be a string (colormap name) or a sequence (list, tuple or array) of colors"
        )

    ax = composition_df.plot(kind='bar', stacked=True, figsize=(6, 4), color=colors, fontsize=8)

    ax.set_title(title)
    ax.set_ylabel('Proportion')
    ax.set_xlabel(x)
    # Legend outside the axes so it does not cover the bars.
    ax.legend(title=y, bbox_to_anchor=(1.05, 1), loc='upper left')

    # Remove the top and right spines and all grid lines.
    sns.despine(ax=ax)
    ax.grid(False)

    ax.figure.tight_layout()
    plt.show()
    

In [None]:
from skimage.morphology import dilation, square
from tqdm import tqdm


def compute_relative_abundance(adata, dict_use):
    """
    Compute one weight per cell type from the expression of its markers.

    For each cell type, every marker present in ``adata.var_names`` contributes
    the mean of its strictly positive values in ``adata.X`` (0 if it has none).
    The weight is ``1 / log10(sum_of_those_means + 1e-10)``.

    Parameters
    ----------
    adata
        AnnData-like object supporting ``adata[:, gene].X`` and ``var_names``.
    dict_use
        Mapping cell type -> iterable of marker gene names; markers missing
        from ``adata.var_names`` are skipped.

    Returns
    -------
    dict mapping each cell type to its weight.
    """
    from scipy.sparse import issparse

    def _mean_positive(column):
        # Densify the single-gene column so sparse and dense X share one path.
        # (For scipy spmatrix, X[X > 0] returns a 1 x k np.matrix whose len() is
        # always 1, which previously turned "no expression" into mean([]) = NaN.)
        values = column.toarray() if issparse(column) else np.asarray(column)
        positive = values[values > 0]
        return np.mean(positive) if positive.size else 0

    weights = {}
    for cell_type, markers in tqdm(dict_use.items()):
        marker_means = [
            _mean_positive(adata[:, marker].X)
            for marker in markers
            if marker in adata.var_names
        ]
        # NOTE(review): log10(total) is 0 when total == 1 (weight -> inf) and
        # negative when total < 1 (negative weight); confirm this is intended.
        weights[cell_type] = 1 / np.log10(sum(marker_means) + 1e-10)
    return weights


def annotate_cells_stage(marker_dict_use, adata, cells_to_annotate=None):
    """
    Label each cell with the marker-defined cell type that scores highest.

    A cell's score for a cell type is the number of that type's markers with a
    value > 0 in the cell, multiplied by the type's weight from
    ``compute_relative_abundance`` (computed on the whole ``adata``). The cell
    gets the best-scoring type (first one on ties), or ``"Others"`` when the
    best score is not > 0.

    Parameters
    ----------
    marker_dict_use
        Mapping cell type -> iterable of marker gene names.
    adata
        AnnData-like object with ``X``, ``obs_names`` and ``var_names``.
    cells_to_annotate
        Optional subset of observation names; defaults to all of them.

    Returns
    -------
    (dict mapping cell name -> label, list of cell names labelled "Others")
    """
    weights = compute_relative_abundance(adata, marker_dict_use)

    if cells_to_annotate is None:
        cells_to_annotate = adata.obs_names
        expr = adata.X
    else:
        expr = adata[cells_to_annotate, :].X

    cell_types = list(marker_dict_use.keys())
    all_scores = np.zeros((len(cells_to_annotate), len(cell_types)))
    for col, cell_type in enumerate(cell_types):
        marker_idx = [
            adata.var_names.get_loc(marker)
            for marker in marker_dict_use[cell_type]
            if marker in adata.var_names
        ]
        n_detected = (expr[:, marker_idx] > 0).astype(int).sum(axis=1)
        # Flatten both the dense (n,) result and the (n, 1) np.matrix that a
        # sparse row-sum returns.
        all_scores[:, col] = np.asarray(n_detected).ravel() * weights[cell_type]

    best_scores = np.max(all_scores, axis=1)
    best_idx = np.argmax(all_scores, axis=1)
    annotations = np.array(
        [
            cell_types[idx] if score > 0 else "Others"
            for idx, score in zip(best_idx, best_scores)
        ]
    )
    unannotated = [cells_to_annotate[i] for i in np.flatnonzero(annotations == "Others")]
    return dict(zip(cells_to_annotate, annotations)), unannotated

def visualize_results(adata, key="annotations"):
    """Show a bar chart of how many cells carry each label in ``adata.obs[key]``,
    with the count written above every bar."""
    labels, n_cells = np.unique(adata.obs[key], return_counts=True)

    ax = plt.gca()
    ax.bar(labels, n_cells)
    for pos, n in enumerate(n_cells):
        ax.text(pos, n, str(n), ha="center", va="bottom", fontsize=8)
    for tick_label in ax.get_xticklabels():
        tick_label.set_rotation(90)

    ax.set_ylabel("Number of Cells")
    ax.set_xlabel("Cell Types")
    ax.set_title("Cell Type Annotations")
    plt.show()

## Import cell markers

In [None]:
import pickle
# Marker dictionary: cell type -> list of marker genes.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open('./Files/markers.pkl', mode='rb') as markers_fh:
    marker_dict = pickle.load(markers_fh)

## Import data and annotate cell types

In [None]:
import os

# please input the sample names
samples = ['', '', '' ]


In [None]:

for sample in samples:
    print(sample)

    # Locate this sample's tissue GEM file and image.
    results_dir = './CRLM/' + sample + "/results/"
    gem_candidates = [name for name in os.listdir(results_dir) if "tissue.gem.gz" in name]
    if not gem_candidates:
        raise FileNotFoundError(f"No file containing 'tissue.gem.gz' found in {results_dir}")
    gem_file = results_dir + gem_candidates[0]
    image_file = "./gem_data/" + sample + "_tissue.tif"
    library_id = "st"

    # load_bin is defined earlier in this notebook; no `ld` module is imported.
    adata = load_bin(gem_file, image_file, 20, library_id)

    adata_copy = adata.copy()  # unfiltered copy; not used inside this loop
    sc.pp.filter_cells(adata, min_genes=1)
    adata.var["mt"] = adata.var_names.str.startswith("MT-")
    sc.pp.calculate_qc_metrics(
        adata, qc_vars=["mt"], percent_top=None, log1p=False, inplace=True
    )
    # Copy so later in-place transforms of adata.X cannot alter the stored raw counts.
    adata.layers['counts'] = adata.X.copy()
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    print(adata)

    # Stage 1: annotate every bin with the full marker dictionary.
    annotations, unannotated_cells = annotate_cells_stage(marker_dict, adata)
    print("There are {} cells remaining unannotated.".format(len(unannotated_cells)))
    adata.obs["annotations"] = pd.Series(annotations).astype("category")

    # Stage 2: re-annotate the epithelial bins; those left as "Others" get a
    # second pass restricted to the Malignant / Epithelial markers.
    adata_epi = adata[adata.obs["annotations"] == "Epithelial", :].copy()
    annotations, unannotated_cells_stage_1 = annotate_cells_stage(marker_dict, adata_epi)
    # TODO(review): `marker_stage1` is not defined anywhere in this notebook, so
    # this line raises NameError; define it (e.g. in the marker-loading cell).
    epi_markers = {key: val for key, val in marker_stage1.items() if key in ["Malignant", "Epithelial"]}
    annotations_stage_2, unannotated = annotate_cells_stage(
        epi_markers, adata_epi, cells_to_annotate=unannotated_cells_stage_1
    )
    annotations.update(annotations_stage_2)
    adata_epi.obs["anno_1"] = pd.Series(annotations).astype("category")
    adata_epi.obs['anno_1'] = np.array(adata_epi.obs['anno_1'])
    # Plain strings so .loc can write labels that are not existing categories.
    adata.obs["annotations"] = np.array(adata.obs["annotations"])
    adata.obs.loc[adata_epi.obs.index, "annotations"] = adata_epi.obs['anno_1']
    # Restore the categorical dtype (as set in stage 1) before grouping/plotting.
    adata.obs["annotations"] = adata.obs["annotations"].astype("category")

    visualize_results(adata)
    plt.rcParams['figure.figsize'] = (4, 4)
    try:
        sc.tl.rank_genes_groups(adata, groupby='annotations')
        sc.pl.rank_genes_groups(adata)
    except Exception as err:
        # Best effort: report why marker ranking failed and continue.
        print(f'skip: {err}')

    # Only request B2M when the gene was detected in this sample.
    spatial_colors = ["annotations"] + (["B2M"] if "B2M" in adata.var_names else [])
    sc.pl.spatial(adata, color=spatial_colors, size=0.3, img_key=None)
    sc.pl.spatial(adata)

    adata.obsm['spatial'] = np.array(adata.obsm['spatial']).astype(int)

    ##### please complete the output path #####
    adata.write_h5ad('./' + sample + ".h5ad")

