# Cell type classification for STARmap PLUS data

2024-09-29

In [None]:
# load libraries 
import os
import numpy as np
import pandas as pd
import seaborn as sns
import scanpy as sc
import anndata as ad
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from tqdm.notebook import tqdm

from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score

## Input

In [None]:
# Dataset layout: raw expression objects are read from <base>/expr and
# derived tables live in <base>/output.
base_path = './path/to/dataset'
expr_path = os.path.join(base_path, 'expr')
output_path = os.path.join(base_path, 'output')

# Load the combined raw AnnData object; the bare name echoes its summary.
raw_h5ad_path = os.path.join(expr_path, 'combined-raw.h5ad')
cdata = sc.read_h5ad(raw_h5ad_path)
cdata

In [None]:
# Key every cell as "<sample>_<fov_id>_<seg_label>" and use that key as the
# obs index so other tables can be joined on it.
sample_part, fov_part, seg_part = (
    cdata.obs[column].astype(str) for column in ('sample', 'fov_id', 'seg_label')
)
cdata.obs['unique_index'] = sample_part + '_' + fov_part + '_' + seg_part
cdata.obs.index = cdata.obs['unique_index']

In [None]:
# Load the region assignment table generated by SPIN and key it exactly like
# cdata.obs: "<sample>_<fov_id>_<seg_label>".
region_df = pd.read_csv(os.path.join(output_path, 'km-region-4.csv'), index_col=0)
region_df['unique_index'] = (
    region_df['sample'].astype(str)
    + '_' + region_df['fov_id'].astype(str)
    + '_' + region_df['seg_label'].astype(str)
)
region_df.index = region_df['unique_index']

# Keep only the cells listed in region_df, in region_df order. Indexing an
# AnnData returns a view; .copy() materialises it so the .obs assignment below
# writes to a real object instead of triggering an implicit copy of the view.
cdata = cdata[region_df.index, :].copy()

# Positional assignment is aligned because cdata was just reordered to
# region_df.index.
cdata.obs['region'] = region_df['region'].values

In [None]:
# Experimental condition of each sample. Samples missing from this table end
# up with a missing (NaN) condition.
condition_dict = {
    'sample1': 'WT',
    'sample2': "WT",
    'sample3': "99R",
    'sample4': '99R',
    'sample5': '33NM',
    'sample6': '33NM',
}

# Derive the per-cell condition from the sample label; store both columns as
# categoricals.
cdata.obs['condition'] = cdata.obs['sample'].map(condition_dict).astype('category')
cdata.obs['sample'] = cdata.obs['sample'].astype('category')

# Echo the number of cells per sample.
cdata.obs['sample'].value_counts()

## Preprocessing

In [None]:
# Add scanpy's standard per-cell / per-gene QC columns to cdata in place;
# percent_top=None skips the "top-N genes" percentage columns.
sc.pp.calculate_qc_metrics(cdata, percent_top=None, inplace=True)

In [None]:
# Drop cells with fewer than 2 total counts, then cells with fewer than
# 2 detected genes (applied one after the other, in this order).
for threshold in ({'min_counts': 2}, {'min_genes': 2}):
    sc.pp.filter_cells(cdata, **threshold)

# Echo the filtered object.
cdata

In [None]:
# Normalisation and scaling. Statement order matters here:
#   1. per-cell total-count normalisation,
#   2. log1p transform,
#   3. snapshot of the log-normalised values into .raw (what use_raw=True reads),
#   4. per-gene scaling of .X (what use_raw=False reads).
sc.pp.normalize_total(cdata)
sc.pp.log1p(cdata)
# Freeze the log-normalised matrix before .X is overwritten by scaling.
cdata.raw = cdata
sc.pp.scale(cdata)
# Keep an explicit copy of the scaled matrix as a named layer.
cdata.layers['scaled'] = cdata.X.copy()

In [None]:
# Restrict the analysis to cells whose region label is 2 or 3. The copy makes
# rdata an independent object rather than a view of cdata.
in_selected_regions = cdata.obs['region'].isin([2, 3])
rdata = cdata[in_selected_regions, :].copy()
rdata

## Level 1

In [None]:
# Read the gene annotation table and index it by gene name (the 'Gene' column
# is kept as a regular column as well).
annotation_csv = os.path.join(base_path, 'documents', 'gene_annotation.csv')
gene_annotation = pd.read_csv(annotation_csv).set_index('Gene', drop=False)

# Genes flagged for level-1 typing, and the level-1 labels they mark, in
# order of first appearance in the table.
level_1_rows = gene_annotation['Level_1_binary'] == True
selected_genes = gene_annotation.loc[level_1_rows, 'Gene'].tolist()
level_1_order = gene_annotation.loc[level_1_rows, 'Level_1_annotation'].unique()

print(f"Selected genes: {len(selected_genes)}")
print(level_1_order)

In [None]:
# Marker genes grouped by level-1 label: {label: [genes flagged for level 1]}.
selected_gene_dict = {
    current_type: gene_annotation.loc[
        (gene_annotation['Level_1_binary'] == True)
        & (gene_annotation['Level_1_annotation'] == current_type),
        'Gene',
    ].to_list()
    for current_type in level_1_order
}

selected_gene_dict

In [None]:
# Working subset for level-1 typing: all cells of rdata, marker genes only.
# Indexing returns a view of rdata; .copy() materialises it because later
# cells add columns to sdata.obs, which a view can only do through an
# implicit (warned-about) copy.
sdata = rdata[:, selected_genes].copy()
sdata

In [None]:
# Cluster on the preprocessed expression matrix (.X) of the marker subset.
X_expr = sdata.X

# k-means with a fixed seed. The cluster-to-cell-type table further down is
# tied to this exact k / random_state.
k = 24
kmeans = KMeans(n_clusters=k, random_state=5)
kmeans.fit(X_expr)
cluster_key = f'kmeans{k}'
sdata.obs[cluster_key] = kmeans.labels_.astype(str)

# Inspect marker expression (.X, not .raw) per cluster.
cluster_plot_kwargs = dict(groupby=cluster_key, dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4)
sc.pl.heatmap(sdata, selected_gene_dict, **cluster_plot_kwargs)
sc.pl.matrixplot(sdata, selected_gene_dict, swap_axes=True, **cluster_plot_kwargs)

### assign cell types

In [None]:
# Seed the level_1 column with the k-means labels; the kmeans column itself
# stays untouched and level_1 is relabelled to cell types afterwards.
cluster_labels = sdata.obs[f'kmeans{k}'].values
sdata.obs['level_1'] = cluster_labels

In [None]:
# Manual level-1 annotation: position i holds the cell type assigned to
# k-means cluster i.
level_1_list = [
    'T cells',          # cluster 0
    'Dendritic cells',  # cluster 1
    'B cells',          # cluster 2
    'NA',               # cluster 3
    'Dendritic cells',  # cluster 4
    'T cells',          # cluster 5
    'T cells',          # cluster 6
    'Dendritic cells',  # cluster 7
    'T cells',          # cluster 8
    'B cells',          # cluster 9
    'Macrophages',      # cluster 10
    'Macrophages',      # cluster 11
    'Dendritic cells',  # cluster 12
    'T cells',          # cluster 13
    'Macrophages',      # cluster 14
    'T cells',          # cluster 15
    'B cells',          # cluster 16
    'B cells',          # cluster 17
    'Dendritic cells',  # cluster 18
    'Macrophages',      # cluster 19
    'B cells',          # cluster 20
    'Dendritic cells',  # cluster 21
    'B cells',          # cluster 22
    'Macrophages',      # cluster 23
]

# Map every cluster label present in the data (stored as a string) to its
# cell type.
transfer_dict_l1 = {
    cluster: level_1_list[int(cluster)]
    for cluster in sorted(sdata.obs[f'kmeans{k}'].unique())
}

In [None]:
# Translate the cluster ids in level_1 into cell-type names.
# Work on the single column (not the whole obs table) and go through plain
# strings: the column may be categorical by now, and a many-to-one replace on
# a categorical is deprecated in pandas 2.2.
cluster_ids = sdata.obs['level_1'].astype(str)
# Labels without an entry in the table are left as they are.
sdata.obs['level_1'] = cluster_ids.map(transfer_dict_l1).fillna(cluster_ids)

In [None]:
# Fix the display order of the level-1 categories. reorder_categories raises
# if the labels present differ from this list.
level_1_order = ['T cells', 'B cells', 'Macrophages', 'Dendritic cells', 'NA']
sdata.obs['level_1'] = (
    sdata.obs['level_1']
    .astype('category')
    .cat.reorder_categories(level_1_order)
)

In [None]:
# One colour per level-1 category, in level_1_order order (grey for 'NA').
level_1_hex = ['#00A651', '#FBB040', '#92278F', '#03a5fc', '#dbdbdb']
level_1_pl = sns.color_palette(level_1_hex)
level_1_cmap = ListedColormap(level_1_pl.as_hex())

# Preview the palette with the category names as tick labels.
sns.palplot(level_1_pl)
tick_positions = range(len(level_1_order))
plt.xticks(tick_positions, level_1_order, size=10, rotation=45)
plt.tight_layout()
plt.show()

In [None]:
# Marker genes per level-1 label for plotting, in display order; the 'NA'
# group has no markers and is skipped.
selected_gene_dict = {
    current_type: gene_annotation.loc[
        (gene_annotation['Level_1_binary'] == True)
        & (gene_annotation['Level_1_annotation'] == current_type),
        'Gene',
    ].to_list()
    for current_type in level_1_order
    if current_type != 'NA'
}

selected_gene_dict

In [None]:
# Marker expression (.X, not .raw) grouped by the assigned level-1 cell type.
level_1_plot_kwargs = dict(groupby='level_1', dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4)
sc.pl.heatmap(sdata, selected_gene_dict, **level_1_plot_kwargs)
sc.pl.matrixplot(sdata, selected_gene_dict, swap_axes=True, **level_1_plot_kwargs)

In [None]:
# One spatial map per sample: every cell of the sample in grey as background,
# the typed cells on top coloured by level_1.
for current_sample in sdata.obs['sample'].unique():
    print(current_sample)
    background_rows = cdata.obs['sample'] == current_sample
    typed_rows = sdata.obs['sample'] == current_sample
    current_complete_obs = cdata.obs.loc[background_rows, :]
    current_obs = sdata.obs.loc[typed_rows, :]

    # Figure size follows the sample extent: max coordinate / 1000 per axis.
    fig_size = np.array([
        current_complete_obs['global_x'].max(),
        current_complete_obs['global_y'].max(),
    ]) / 1000
    fig, ax = plt.subplots(figsize=fig_size)

    point_kwargs = dict(x='global_x', y='global_y', s=1, linewidth=0, ax=ax)
    sns.scatterplot(data=current_complete_obs, color='#dbdbdb', **point_kwargs)
    sns.scatterplot(data=current_obs, hue='level_1', palette=level_1_pl, **point_kwargs)
    # plt.savefig(os.path.join(output_path, f"sct_{current_sample}.png"))
    plt.show()

In [None]:
# Copy the level-1 labels onto the regional object: start from an all-'NA'
# object column, then fill in the cells present in sdata.
rdata.obs['level_1'] = 'NA'
rdata.obs['level_1'] = rdata.obs['level_1'].astype(object)

typed_cells = sdata.obs.index
rdata.obs.loc[typed_cells, 'level_1'] = sdata.obs['level_1'].values

# Echo the labels now present.
rdata.obs['level_1'].unique()

In [None]:
# Copy the level-1 labels onto the complete object: cells outside sdata keep
# the 'NA' placeholder.
cdata.obs['level_1'] = 'NA'
cdata.obs['level_1'] = cdata.obs['level_1'].astype(object)

typed_cells = sdata.obs.index
cdata.obs.loc[typed_cells, 'level_1'] = sdata.obs['level_1'].values

# Echo the labels now present.
cdata.obs['level_1'].unique()

In [None]:
# # backup
# from datetime import datetime
# date = datetime.today().strftime('%Y-%m-%d')
# rdata.write_h5ad(f"{expr_path}/{date}-combined-level1-region23.h5ad")
# cdata.write_h5ad(f"{expr_path}/{date}-combined-level1.h5ad")

## Level 2 

In [None]:
# Initialise level_2 from level_1 on both objects. Object dtype lets the
# per-lineage sections below write new label strings into it freely.
rdata.obs['level_2'] = rdata.obs['level_1'].values
cdata.obs['level_2'] = cdata.obs['level_1'].values

rdata.obs['level_2'] = rdata.obs['level_2'].astype(object)
cdata.obs['level_2'] = cdata.obs['level_2'].astype(object)

### T cells

In [None]:
# Hand-picked genes used to cluster the T-cell compartment, and the two
# groups they are meant to separate.
current_order = [
    'CD3+',
    'CD3-',
]
selected_genes = [
    'Cd3e',
    'Cd3d',
    'Cd3g',
    'Cd4',
    'Cd8a',
    'Ccr7',
]
print(f"Selected genes: {len(selected_genes)}")
print(current_order)

In [None]:
# All selected genes are plotted under a single 'CD3+' group.
selected_gene_dict = dict([('CD3+', selected_genes)])
selected_gene_dict

In [None]:
# Working subset for T-cell sub-typing: level-1 T cells, selected genes only.
# .copy() turns the view into an independent object because later cells add
# columns to sdata.obs.
sdata = rdata[rdata.obs['level_1'] == 'T cells', selected_genes].copy()
sdata

In [None]:
# Expression matrix of the current subset; reused by the clustering cell.
X_expr = sdata.X

# Elbow analysis: fit k-means for 2..19 clusters and record the inertia.
# A dedicated loop variable keeps the notebook-level `k` and `kmeans` from
# being overwritten, and a fixed seed makes the curve reproducible.
candidate_ks = range(2, 20)
distorsions = [
    KMeans(n_clusters=n_clusters, random_state=0).fit(X_expr).inertia_
    for n_clusters in candidate_ks
]

fig = plt.figure(figsize=(15, 5))
plt.plot(candidate_ks, distorsions)
plt.xlabel('Number of clusters')
plt.ylabel('Inertia')
plt.grid(True)
plt.title('Elbow curve')

In [None]:
# k-means on the T-cell subset with a fixed seed. The cluster-to-label table
# further down is tied to this exact k / random_state.
k = 5
kmeans = KMeans(n_clusters=k, random_state=10)
kmeans.fit(X_expr)
cluster_key = f'kmeans{k}'
sdata.obs[cluster_key] = kmeans.labels_.astype(str)

# Per-cluster marker expression: .X first (use_raw=False), then .raw.
cluster_plot_kwargs = dict(groupby=cluster_key, dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4)
sc.pl.heatmap(sdata, selected_gene_dict, **cluster_plot_kwargs)
sc.pl.matrixplot(sdata, selected_gene_dict, swap_axes=True, **cluster_plot_kwargs)
sc.pl.matrixplot(sdata, selected_gene_dict, groupby=cluster_key, dendrogram=False, use_raw=True, vmax=2, swap_axes=True)

#### assign cell types

In [None]:
# Seed level_2 with the k-means labels; the kmeans column stays untouched and
# level_2 is relabelled to cell types afterwards.
cluster_labels = sdata.obs[f'kmeans{k}'].values
sdata.obs['level_2'] = cluster_labels

In [None]:
# Manual level-2 annotation of the T-cell clusters: position i holds the
# label assigned to k-means cluster i.
level_2_list = [
    'Undefined T cells',  # cluster 0
    'NA',                 # cluster 1
    'Synthetic T cells',  # cluster 2
    'CD4+ T cells',       # cluster 3
    'CD8+ T cells',       # cluster 4
]

# Map every cluster label present in the data (stored as a string) to its
# cell type.
transfer_dict_l2 = {
    cluster: level_2_list[int(cluster)]
    for cluster in sorted(sdata.obs[f'kmeans{k}'].unique())
}

In [None]:
# Translate the cluster ids in level_2 into cell-type names.
# Work on the single column (not the whole obs table) and go through plain
# strings: the column may be categorical by now, and a many-to-one replace on
# a categorical is deprecated in pandas 2.2.
cluster_ids = sdata.obs['level_2'].astype(str)
# Labels without an entry in the table are left as they are.
sdata.obs['level_2'] = cluster_ids.map(transfer_dict_l2).fillna(cluster_ids)

In [None]:
# Display order of the level-2 T-cell categories. reorder_categories raises if
# the labels present differ from this list.
current_order = ['Undefined T cells', 'Synthetic T cells', 'CD4+ T cells', 'CD8+ T cells', 'NA']
sdata.obs['level_2'] = (
    sdata.obs['level_2']
    .astype('category')
    .cat.reorder_categories(current_order)
)

In [None]:
# One tab10 colour per level-2 category, in current_order order.
n_categories = len(current_order)
current_pl = sns.color_palette('tab10', n_categories)
current_cmap = ListedColormap(current_pl.as_hex())

# Preview the palette with the category names as tick labels.
sns.palplot(current_pl)
plt.xticks(range(n_categories), current_order, size=10, rotation=45)
plt.tight_layout()
# plt.savefig(os.path.join(fig_path, 'level_2_palette.pdf'))
plt.show()

In [None]:
# Marker expression grouped by the assigned level-2 T-cell label:
# heatmap and matrix plot on .X, dot plot on .raw.
level_2_plot_kwargs = dict(groupby='level_2', dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4)
sc.pl.heatmap(sdata, selected_gene_dict, figsize=(3, 15), **level_2_plot_kwargs)
sc.pl.matrixplot(sdata, selected_gene_dict, swap_axes=True, figsize=(5, 6), **level_2_plot_kwargs)
sc.pl.dotplot(sdata, selected_gene_dict, groupby='level_2', dendrogram=False, use_raw=True, cmap='Reds', swap_axes=True)

In [None]:
# One spatial map per sample: every cell of the sample in grey as background,
# the T-cell subset on top coloured by level_2.
for current_sample in sdata.obs['sample'].unique():
    print(current_sample)
    current_complete_obs = cdata.obs.loc[cdata.obs['sample'] == current_sample, :]
    current_obs = sdata.obs.loc[(sdata.obs['sample'] == current_sample), :]

    # Figure size follows the sample extent: max coordinate / 1000 per axis.
    fig_size = np.array([current_complete_obs['global_x'].max(), current_complete_obs['global_y'].max()]) / 1000
    fig, ax = plt.subplots(figsize=fig_size)
    sns.scatterplot(x='global_x', y='global_y', data=current_complete_obs, color='#dbdbdb', s=1, linewidth=0, ax=ax)
    # Marker area is set with `s`. seaborn's `size` is a semantic variable
    # (a data mapping with its own legend entry), not a point size.
    sns.scatterplot(x='global_x', y='global_y', hue='level_2', data=current_obs, palette='tab10', s=10, linewidth=0, ax=ax)
    # plt.savefig(os.path.join(output_path, f"sct_{current_sample}.png"))
    plt.show()

In [None]:
# Write the T-cell level-2 labels back to the complete and regional objects.
typed_cells = sdata.obs.index
typed_labels = sdata.obs['level_2'].values
for target in (cdata, rdata):
    target.obs.loc[typed_cells, 'level_2'] = typed_labels

In [None]:
# Cells whose T-cell sub-cluster was labelled 'NA' also lose their level-1
# T-cell call on both parent objects.
is_na = (sdata.obs['level_2'] == 'NA').to_numpy()
na_index = sdata.obs.index[is_na]
for target in (cdata, rdata):
    target.obs.loc[na_index, 'level_1'] = 'NA'

### Macrophages

In [None]:
# Re-read the gene annotation table, indexed by gene name ('Gene' is kept as a
# column too).
annotation_csv = os.path.join(base_path, 'documents', 'gene_annotation.csv')
gene_annotation = pd.read_csv(annotation_csv).set_index('Gene', drop=False)

# Genes flagged for macrophage sub-typing and the labels they mark, in order
# of first appearance.
macrophage_rows = gene_annotation['Level_2_binary_macrophages'] == True
selected_genes = gene_annotation.loc[macrophage_rows, 'Gene'].tolist()
current_order = gene_annotation.loc[macrophage_rows, 'Level_2_annotation_macrophages'].unique()

print(f"Selected genes: {len(selected_genes)}")
print(current_order)

In [None]:
# Marker genes grouped by macrophage sub-type label.
selected_gene_dict = {
    current_type: gene_annotation.loc[
        (gene_annotation['Level_2_binary_macrophages'] == True)
        & (gene_annotation['Level_2_annotation_macrophages'] == current_type),
        'Gene',
    ].to_list()
    for current_type in current_order
}

selected_gene_dict

In [None]:
# Working subset for macrophage sub-typing: level-1 macrophages, selected
# genes only. .copy() turns the view into an independent object because later
# cells add columns to sdata.obs.
sdata = rdata[rdata.obs['level_1'] == 'Macrophages', selected_genes].copy()
sdata

In [None]:
# Expression matrix of the current subset; reused by the clustering cell.
X_expr = sdata.X

# Elbow analysis: fit k-means for 2..19 clusters and record the inertia.
# A dedicated loop variable keeps the notebook-level `k` and `kmeans` from
# being overwritten, and a fixed seed makes the curve reproducible.
candidate_ks = range(2, 20)
distorsions = [
    KMeans(n_clusters=n_clusters, random_state=0).fit(X_expr).inertia_
    for n_clusters in candidate_ks
]

fig = plt.figure(figsize=(15, 5))
plt.plot(candidate_ks, distorsions)
plt.xlabel('Number of clusters')
plt.ylabel('Inertia')
plt.grid(True)
plt.title('Elbow curve')

In [None]:
# k-means on the macrophage subset with a fixed seed. The cluster-to-label
# table further down is tied to this exact k / random_state.
k = 5
kmeans = KMeans(n_clusters=k, random_state=5)
kmeans.fit(X_expr)
cluster_key = f'kmeans{k}'
sdata.obs[cluster_key] = kmeans.labels_.astype(str)

# Per-cluster marker expression: .X first (use_raw=False), then .raw.
cluster_plot_kwargs = dict(groupby=cluster_key, dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4)
sc.pl.heatmap(sdata, selected_gene_dict, **cluster_plot_kwargs)
sc.pl.matrixplot(sdata, selected_gene_dict, swap_axes=True, **cluster_plot_kwargs)
sc.pl.matrixplot(sdata, selected_gene_dict, groupby=cluster_key, dendrogram=False, use_raw=True, vmax=2, swap_axes=True)

#### assign cell types

In [None]:
# Seed level_2 with the k-means labels; the kmeans column stays untouched and
# level_2 is relabelled to cell types afterwards.
cluster_labels = sdata.obs[f'kmeans{k}'].values
sdata.obs['level_2'] = cluster_labels

In [None]:
# Manual level-2 annotation of the macrophage clusters: position i holds the
# label assigned to k-means cluster i.
level_2_list = [
    'Monocytes',              # cluster 0
    'Macrophages',            # cluster 1
    'Activated Macrophages',  # cluster 2
    'Macrophages',            # cluster 3
    'Monocytes',              # cluster 4
]

# Map every cluster label present in the data (stored as a string) to its
# cell type.
transfer_dict_l2 = {
    cluster: level_2_list[int(cluster)]
    for cluster in sorted(sdata.obs[f'kmeans{k}'].unique())
}

In [None]:
# Translate the cluster ids in level_2 into cell-type names.
# Work on the single column (not the whole obs table) and go through plain
# strings: the column may be categorical by now, and a many-to-one replace on
# a categorical is deprecated in pandas 2.2.
cluster_ids = sdata.obs['level_2'].astype(str)
# Labels without an entry in the table are left as they are.
sdata.obs['level_2'] = cluster_ids.map(transfer_dict_l2).fillna(cluster_ids)

In [None]:
# Display order of the level-2 macrophage categories. reorder_categories
# raises if the labels present differ from this list.
current_order = ['Macrophages', 'Activated Macrophages', 'Monocytes']
sdata.obs['level_2'] = (
    sdata.obs['level_2']
    .astype('category')
    .cat.reorder_categories(current_order)
)

In [None]:
# One tab10 colour per level-2 category, in current_order order.
n_categories = len(current_order)
current_pl = sns.color_palette('tab10', n_categories)
current_cmap = ListedColormap(current_pl.as_hex())

# Preview the palette with the category names as tick labels.
sns.palplot(current_pl)
plt.xticks(range(n_categories), current_order, size=10, rotation=45)
plt.tight_layout()
# plt.savefig(os.path.join(fig_path, 'level_2_palette.pdf'))
plt.show()

In [None]:
# Marker genes per macrophage label for visualization, in display order;
# an 'NA' group (if any) is skipped.
selected_gene_dict = {
    current_type: gene_annotation.loc[
        (gene_annotation['Level_2_binary_macrophages'] == True)
        & (gene_annotation['Level_2_annotation_macrophages'] == current_type),
        'Gene',
    ].to_list()
    for current_type in current_order
    if current_type != 'NA'
}

selected_gene_dict

In [None]:
# Marker expression (.X, not .raw) grouped by the assigned macrophage label.
level_2_plot_kwargs = dict(groupby='level_2', dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4)
sc.pl.heatmap(sdata, selected_gene_dict, figsize=(3, 15), **level_2_plot_kwargs)
sc.pl.matrixplot(sdata, selected_gene_dict, swap_axes=True, **level_2_plot_kwargs)

In [None]:
# One spatial map per sample: every cell of the sample in grey as background,
# the macrophage subset on top coloured by level_2.
for current_sample in sdata.obs['sample'].unique():
    print(current_sample)
    background_rows = cdata.obs['sample'] == current_sample
    typed_rows = sdata.obs['sample'] == current_sample
    current_complete_obs = cdata.obs.loc[background_rows, :]
    current_obs = sdata.obs.loc[typed_rows, :]

    # Figure size follows the sample extent: max coordinate / 1000 per axis.
    fig_size = np.array([
        current_complete_obs['global_x'].max(),
        current_complete_obs['global_y'].max(),
    ]) / 1000
    fig, ax = plt.subplots(figsize=fig_size)

    point_kwargs = dict(x='global_x', y='global_y', linewidth=0, ax=ax)
    sns.scatterplot(data=current_complete_obs, color='#dbdbdb', s=1, **point_kwargs)
    sns.scatterplot(data=current_obs, hue='level_2', palette='tab10', s=5, **point_kwargs)
    # plt.savefig(os.path.join(output_path, f"sct_{current_sample}.png"))
    plt.show()

In [None]:
# Write the macrophage level-2 labels back to the complete and regional
# objects; object dtype ensures new label strings can be stored.
typed_cells = sdata.obs.index
typed_labels = sdata.obs['level_2'].values
for target in (cdata, rdata):
    target.obs['level_2'] = target.obs['level_2'].astype(object)
    target.obs.loc[typed_cells, 'level_2'] = typed_labels

# Echo the labels now present on the complete object.
cdata.obs['level_2'].unique()

### Dendritic cells

In [None]:
# Re-read the gene annotation table, indexed by gene name ('Gene' is kept as a
# column too).
annotation_csv = os.path.join(base_path, 'documents', 'gene_annotation.csv')
gene_annotation = pd.read_csv(annotation_csv).set_index('Gene', drop=False)

# Genes flagged for dendritic-cell sub-typing and the labels they mark, in
# order of first appearance.
dendritic_rows = gene_annotation['Level_2_binary_dendritic_cells'] == True
selected_genes = gene_annotation.loc[dendritic_rows, 'Gene'].tolist()
current_order = gene_annotation.loc[dendritic_rows, 'Level_2_annotation_dendritic_cells'].unique()

print(f"Selected genes: {len(selected_genes)}")
print(current_order)

In [None]:
# Marker genes grouped by dendritic-cell sub-type label.
selected_gene_dict = {
    current_type: gene_annotation.loc[
        (gene_annotation['Level_2_binary_dendritic_cells'] == True)
        & (gene_annotation['Level_2_annotation_dendritic_cells'] == current_type),
        'Gene',
    ].to_list()
    for current_type in current_order
}

selected_gene_dict

In [None]:
# Working subset for dendritic-cell sub-typing: level-1 dendritic cells,
# selected genes only. .copy() turns the view into an independent object
# because later cells add columns to sdata.obs.
sdata = rdata[rdata.obs['level_1'] == 'Dendritic cells', selected_genes].copy()
sdata

In [None]:
# Expression matrix of the current subset; reused by the clustering cell.
X_expr = sdata.X

# Elbow analysis: fit k-means for 2..19 clusters and record the inertia.
# A dedicated loop variable keeps the notebook-level `k` and `kmeans` from
# being overwritten, and a fixed seed makes the curve reproducible.
candidate_ks = range(2, 20)
distorsions = [
    KMeans(n_clusters=n_clusters, random_state=0).fit(X_expr).inertia_
    for n_clusters in candidate_ks
]

fig = plt.figure(figsize=(15, 5))
plt.plot(candidate_ks, distorsions)
plt.xlabel('Number of clusters')
plt.ylabel('Inertia')
plt.grid(True)
plt.title('Elbow curve')

In [None]:
# k-means on the dendritic-cell subset with a fixed seed. The cluster-to-label
# table further down is tied to this exact k / random_state.
k = 4
kmeans = KMeans(n_clusters=k, random_state=5)
kmeans.fit(X_expr)
cluster_key = f'kmeans{k}'
sdata.obs[cluster_key] = kmeans.labels_.astype(str)

# Per-cluster marker expression: .X first (use_raw=False), then .raw.
cluster_plot_kwargs = dict(groupby=cluster_key, dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4)
sc.pl.heatmap(sdata, selected_gene_dict, **cluster_plot_kwargs)
sc.pl.matrixplot(sdata, selected_gene_dict, swap_axes=True, **cluster_plot_kwargs)
sc.pl.matrixplot(sdata, selected_gene_dict, groupby=cluster_key, dendrogram=False, use_raw=True, vmax=2, swap_axes=True)

#### assign cell types

In [None]:
# Seed level_2 with the k-means labels; the kmeans column stays untouched and
# level_2 is relabelled to cell types afterwards.
cluster_labels = sdata.obs[f'kmeans{k}'].values
sdata.obs['level_2'] = cluster_labels

In [None]:
# Manual level-2 annotation of the dendritic-cell clusters: position i holds
# the label assigned to k-means cluster i.
level_2_list = [
    'cDC2',                   # cluster 0
    'Other Dendritic cells',  # cluster 1
    'cDC1',                   # cluster 2
    'Other Dendritic cells',  # cluster 3
]

# Map every cluster label present in the data (stored as a string) to its
# cell type.
transfer_dict_l2 = {
    cluster: level_2_list[int(cluster)]
    for cluster in sorted(sdata.obs[f'kmeans{k}'].unique())
}

In [None]:
# Translate the cluster ids in level_2 into cell-type names.
# Work on the single column (not the whole obs table) and go through plain
# strings: the column may be categorical by now, and a many-to-one replace on
# a categorical is deprecated in pandas 2.2.
cluster_ids = sdata.obs['level_2'].astype(str)
# Labels without an entry in the table are left as they are.
sdata.obs['level_2'] = cluster_ids.map(transfer_dict_l2).fillna(cluster_ids)

In [None]:
# Display order of the level-2 dendritic-cell categories. reorder_categories
# raises if the labels present differ from this list.
current_order = ['Other Dendritic cells', 'cDC1', 'cDC2']
sdata.obs['level_2'] = (
    sdata.obs['level_2']
    .astype('category')
    .cat.reorder_categories(current_order)
)

In [None]:
# One tab10 colour per level-2 category, in current_order order.
n_categories = len(current_order)
current_pl = sns.color_palette('tab10', n_categories)
current_cmap = ListedColormap(current_pl.as_hex())

# Preview the palette with the category names as tick labels.
sns.palplot(current_pl)
plt.xticks(range(n_categories), current_order, size=10, rotation=45)
plt.tight_layout()
# plt.savefig(os.path.join(fig_path, 'level_2_palette.pdf'))
plt.show()

In [None]:
# Marker genes per dendritic-cell label for visualization, in display order;
# an 'NA' group (if any) is skipped.
selected_gene_dict = {
    current_type: gene_annotation.loc[
        (gene_annotation['Level_2_binary_dendritic_cells'] == True)
        & (gene_annotation['Level_2_annotation_dendritic_cells'] == current_type),
        'Gene',
    ].to_list()
    for current_type in current_order
    if current_type != 'NA'
}

selected_gene_dict

In [None]:
# Marker expression (.X, not .raw) grouped by the assigned dendritic-cell label.
level_2_plot_kwargs = dict(groupby='level_2', dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4)
sc.pl.heatmap(sdata, selected_gene_dict, figsize=(3, 15), **level_2_plot_kwargs)
sc.pl.matrixplot(sdata, selected_gene_dict, swap_axes=True, **level_2_plot_kwargs)

In [None]:
# One spatial map per sample: every cell of the sample in grey as background,
# the dendritic-cell subset on top coloured by level_2.
for current_sample in sdata.obs['sample'].unique():
    print(current_sample)
    background_rows = cdata.obs['sample'] == current_sample
    typed_rows = sdata.obs['sample'] == current_sample
    current_complete_obs = cdata.obs.loc[background_rows, :]
    current_obs = sdata.obs.loc[typed_rows, :]

    # Figure size follows the sample extent: max coordinate / 1000 per axis.
    fig_size = np.array([
        current_complete_obs['global_x'].max(),
        current_complete_obs['global_y'].max(),
    ]) / 1000
    fig, ax = plt.subplots(figsize=fig_size)

    point_kwargs = dict(x='global_x', y='global_y', linewidth=0, ax=ax)
    sns.scatterplot(data=current_complete_obs, color='#dbdbdb', s=1, **point_kwargs)
    sns.scatterplot(data=current_obs, hue='level_2', palette='tab10', s=5, **point_kwargs)
    # plt.savefig(os.path.join(output_path, f"sct_{current_sample}.png"))
    plt.show()

In [None]:
# Write the dendritic-cell level-2 labels back to the complete and regional
# objects; object dtype ensures new label strings can be stored.
typed_cells = sdata.obs.index
typed_labels = sdata.obs['level_2'].values
for target in (cdata, rdata):
    target.obs['level_2'] = target.obs['level_2'].astype(object)
    target.obs.loc[typed_cells, 'level_2'] = typed_labels

# Echo the labels now present on the complete object.
cdata.obs['level_2'].unique()

### B cells

In [None]:
# Re-read the gene annotation table, indexed by gene name ('Gene' is kept as a
# column too).
annotation_csv = os.path.join(base_path, 'documents', 'gene_annotation.csv')
gene_annotation = pd.read_csv(annotation_csv).set_index('Gene', drop=False)

# Genes flagged for B-cell sub-typing and the labels they mark, in order of
# first appearance.
b_cell_rows = gene_annotation['Level_2_binary_b_cells'] == True
selected_genes = gene_annotation.loc[b_cell_rows, 'Gene'].tolist()
current_order = gene_annotation.loc[b_cell_rows, 'Level_2_annotation_b_cells'].unique()

print(f"Selected genes: {len(selected_genes)}")
print(current_order)

In [None]:
# Marker genes grouped by B-cell sub-type label.
selected_gene_dict = {
    current_type: gene_annotation.loc[
        (gene_annotation['Level_2_binary_b_cells'] == True)
        & (gene_annotation['Level_2_annotation_b_cells'] == current_type),
        'Gene',
    ].to_list()
    for current_type in current_order
}

selected_gene_dict

In [None]:
# Working subset for B-cell sub-typing: level-1 B cells, selected genes only.
# .copy() turns the view into an independent object because later cells add
# columns to sdata.obs.
sdata = rdata[rdata.obs['level_1'] == 'B cells', selected_genes].copy()
sdata

In [None]:
# Expression matrix of the current subset; reused by the clustering cell.
X_expr = sdata.X

# Elbow analysis: fit k-means for 2..19 clusters and record the inertia.
# A dedicated loop variable keeps the notebook-level `k` and `kmeans` from
# being overwritten, and a fixed seed makes the curve reproducible.
candidate_ks = range(2, 20)
distorsions = [
    KMeans(n_clusters=n_clusters, random_state=0).fit(X_expr).inertia_
    for n_clusters in candidate_ks
]

fig = plt.figure(figsize=(15, 5))
plt.plot(candidate_ks, distorsions)
plt.xlabel('Number of clusters')
plt.ylabel('Inertia')
plt.grid(True)
plt.title('Elbow curve')

In [None]:
# k-means on the B-cell subset with a fixed seed. The cluster-to-label table
# further down is tied to this exact k / random_state.
k = 7
kmeans = KMeans(n_clusters=k, random_state=5)
kmeans.fit(X_expr)
cluster_key = f'kmeans{k}'
sdata.obs[cluster_key] = kmeans.labels_.astype(str)

# Per-cluster marker expression: .X first (use_raw=False), then .raw.
cluster_plot_kwargs = dict(groupby=cluster_key, dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4)
sc.pl.heatmap(sdata, selected_gene_dict, **cluster_plot_kwargs)
sc.pl.matrixplot(sdata, selected_gene_dict, swap_axes=True, **cluster_plot_kwargs)
sc.pl.matrixplot(sdata, selected_gene_dict, groupby=cluster_key, dendrogram=False, use_raw=True, vmax=2, swap_axes=True)

#### assign cell types

In [None]:
# Seed level_2 with the k-means labels; the kmeans column stays untouched and
# level_2 is relabelled to cell types afterwards.
cluster_labels = sdata.obs[f'kmeans{k}'].values
sdata.obs['level_2'] = cluster_labels

In [None]:
# Manual level-2 annotation of the B-cell clusters: position i holds the label
# assigned to k-means cluster i.
level_2_list = [
    'Activated B cells',       # cluster 0
    'B cells',                 # cluster 1
    'Activated B cells',       # cluster 2
    'Activated B cells',       # cluster 3
    'Age-associated B cells',  # cluster 4
    'Age-associated B cells',  # cluster 5
    'Activated B cells',       # cluster 6
]

# Map every cluster label present in the data (stored as a string) to its
# cell type.
transfer_dict_l2 = {
    cluster: level_2_list[int(cluster)]
    for cluster in sorted(sdata.obs[f'kmeans{k}'].unique())
}

In [None]:
# Translate the cluster ids in level_2 into cell-type names.
# Work on the single column (not the whole obs table) and go through plain
# strings: the column may be categorical by now, and a many-to-one replace on
# a categorical is deprecated in pandas 2.2.
cluster_ids = sdata.obs['level_2'].astype(str)
# Labels without an entry in the table are left as they are.
sdata.obs['level_2'] = cluster_ids.map(transfer_dict_l2).fillna(cluster_ids)

In [None]:
# Display order of the level-2 B-cell categories. reorder_categories raises if
# the labels present differ from this list.
current_order = ['B cells', 'Activated B cells', 'Age-associated B cells']
sdata.obs['level_2'] = (
    sdata.obs['level_2']
    .astype('category')
    .cat.reorder_categories(current_order)
)

In [None]:
# One tab10 colour per level-2 category, in current_order order.
n_categories = len(current_order)
current_pl = sns.color_palette('tab10', n_categories)
current_cmap = ListedColormap(current_pl.as_hex())

# Preview the palette with the category names as tick labels.
sns.palplot(current_pl)
plt.xticks(range(n_categories), current_order, size=10, rotation=45)
plt.tight_layout()
# plt.savefig(os.path.join(fig_path, 'level_2_palette.pdf'))
plt.show()

In [None]:
# Marker genes per B-cell label for visualization, in display order; an 'NA'
# group (if any) is skipped.
selected_gene_dict = {
    current_type: gene_annotation.loc[
        (gene_annotation['Level_2_binary_b_cells'] == True)
        & (gene_annotation['Level_2_annotation_b_cells'] == current_type),
        'Gene',
    ].to_list()
    for current_type in current_order
    if current_type != 'NA'
}

selected_gene_dict

In [None]:
# Marker expression (.X, not .raw) grouped by the assigned B-cell label.
level_2_plot_kwargs = dict(groupby='level_2', dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4)
sc.pl.heatmap(sdata, selected_gene_dict, figsize=(3, 15), **level_2_plot_kwargs)
sc.pl.matrixplot(sdata, selected_gene_dict, swap_axes=True, **level_2_plot_kwargs)

In [None]:
# One spatial map per sample: every cell of the sample in grey as background,
# the B-cell subset on top coloured by level_2.
for current_sample in sdata.obs['sample'].unique():
    print(current_sample)
    background_rows = cdata.obs['sample'] == current_sample
    typed_rows = sdata.obs['sample'] == current_sample
    current_complete_obs = cdata.obs.loc[background_rows, :]
    current_obs = sdata.obs.loc[typed_rows, :]

    # Figure size follows the sample extent: max coordinate / 1000 per axis.
    fig_size = np.array([
        current_complete_obs['global_x'].max(),
        current_complete_obs['global_y'].max(),
    ]) / 1000
    fig, ax = plt.subplots(figsize=fig_size)

    point_kwargs = dict(x='global_x', y='global_y', linewidth=0, ax=ax)
    sns.scatterplot(data=current_complete_obs, color='#dbdbdb', s=1, **point_kwargs)
    sns.scatterplot(data=current_obs, hue='level_2', palette='tab10', s=5, **point_kwargs)
    # plt.savefig(os.path.join(output_path, f"sct_{current_sample}.png"))
    plt.show()

In [None]:
# Write the B-cell level-2 labels back to the complete and regional objects;
# object dtype ensures new label strings can be stored.
typed_cells = sdata.obs.index
typed_labels = sdata.obs['level_2'].values
for target in (cdata, rdata):
    target.obs['level_2'] = target.obs['level_2'].astype(object)
    target.obs.loc[typed_cells, 'level_2'] = typed_labels

# Echo the labels now present on the complete object.
cdata.obs['level_2'].unique()

## Level 3

In [None]:
# Initialise level_3 from level_2 on both objects. Object dtype lets the
# T-cell section below write new label strings into it freely.
rdata.obs['level_3'] = rdata.obs['level_2'].values
cdata.obs['level_3'] = cdata.obs['level_2'].values

rdata.obs['level_3'] = rdata.obs['level_3'].astype(object)
cdata.obs['level_3'] = cdata.obs['level_3'].astype(object)

### CD4+/CD8+ T cells

In [None]:
# Re-read the gene annotation table, indexed by gene name ('Gene' is kept as a
# column too).
annotation_csv = os.path.join(base_path, 'documents', 'gene_annotation.csv')
gene_annotation = pd.read_csv(annotation_csv).set_index('Gene', drop=False)

# Genes flagged for T-cell sub-typing and the labels they mark, in order of
# first appearance.
# NOTE(review): this level-3 step reads the 'Level_2_*_t_cells' columns -
# confirm that is the intended annotation set.
t_cell_rows = gene_annotation['Level_2_binary_t_cells'] == True
selected_genes = gene_annotation.loc[t_cell_rows, 'Gene'].tolist()
current_order = gene_annotation.loc[t_cell_rows, 'Level_2_annotation_t_cells'].unique()

print(f"Selected genes: {len(selected_genes)}")
print(current_order)

In [None]:
# Marker genes grouped by T-cell sub-type label.
selected_gene_dict = {
    current_type: gene_annotation.loc[
        (gene_annotation['Level_2_binary_t_cells'] == True)
        & (gene_annotation['Level_2_annotation_t_cells'] == current_type),
        'Gene',
    ].to_list()
    for current_type in current_order
}

selected_gene_dict

In [None]:
# Working subset for level-3 typing: cells called CD4+ or CD8+ T cells at
# level 2, selected genes only. .copy() turns the view into an independent
# object because later cells add columns to sdata.obs.
sdata = rdata[rdata.obs['level_2'].isin(['CD4+ T cells', 'CD8+ T cells']), selected_genes].copy()
sdata

In [None]:
# Expression matrix of the current subset; reused by the clustering cell.
X_expr = sdata.X

# Elbow analysis: fit k-means for 2..19 clusters and record the inertia.
# A dedicated loop variable keeps the notebook-level `k` and `kmeans` from
# being overwritten, and a fixed seed makes the curve reproducible.
candidate_ks = range(2, 20)
distorsions = [
    KMeans(n_clusters=n_clusters, random_state=0).fit(X_expr).inertia_
    for n_clusters in candidate_ks
]

fig = plt.figure(figsize=(15, 5))
plt.plot(candidate_ks, distorsions)
plt.xlabel('Number of clusters')
plt.ylabel('Inertia')
plt.grid(True)
plt.title('Elbow curve')

In [None]:
# k-means on the CD4+/CD8+ T-cell subset with a fixed seed. The
# cluster-to-label table further down is tied to this exact k / random_state.
k = 25
kmeans = KMeans(n_clusters=k, random_state=24)
kmeans.fit(X_expr)
cluster_key = f'kmeans{k}'
sdata.obs[cluster_key] = kmeans.labels_.astype(str)

# Per-cluster marker expression: .X first (use_raw=False), then .raw.
cluster_plot_kwargs = dict(groupby=cluster_key, dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4)
sc.pl.heatmap(sdata, selected_gene_dict, **cluster_plot_kwargs)
sc.pl.matrixplot(sdata, selected_gene_dict, swap_axes=True, **cluster_plot_kwargs)
sc.pl.matrixplot(sdata, selected_gene_dict, groupby=cluster_key, dendrogram=False, use_raw=True, vmax=2, swap_axes=True)
sc.pl.dotplot(sdata, selected_gene_dict, groupby=cluster_key, dendrogram=False, use_raw=True, cmap='Reds', swap_axes=True)

#### assign cell types

In [None]:
# Seed level_3 with the k-means labels; the kmeans column stays untouched and
# level_3 is relabelled to cell types afterwards.
cluster_labels = sdata.obs[f'kmeans{k}'].values
sdata.obs['level_3'] = cluster_labels

In [None]:
# Manual level-3 annotation of the CD4+/CD8+ T-cell clusters: position i holds
# the label assigned to k-means cluster i.
level_3_list = [
    'Th1',                 # cluster 0
    'CD4+ T cells',        # cluster 1
    'CD8+ T cells',        # cluster 2
    'Naive CD4+ T cells',  # cluster 3
    'Naive CD4+ T cells',  # cluster 4
    'Naive CD8+ T cells',  # cluster 5
    'CD8+ T cells',        # cluster 6
    'Naive CD4+ T cells',  # cluster 7
    'Exhausted T cells',   # cluster 8
    'CD4+ T cells',        # cluster 9
    'Naive CD4+ T cells',  # cluster 10
    'CD8+ T cells',        # cluster 11
    'Th1',                 # cluster 12
    'Naive CD8+ T cells',  # cluster 13
    'Treg',                # cluster 14
    'CD8+ T cells',        # cluster 15
    'Th2',                 # cluster 16
    'Th17',                # cluster 17
    'CD4+ T cells',        # cluster 18
    'CD8+ T cells',        # cluster 19
    'Naive CD4+ T cells',  # cluster 20
    'Naive CD8+ T cells',  # cluster 21
    'CD4+ T cells',        # cluster 22
    'Naive CD4+ T cells',  # cluster 23
    'Treg',                # cluster 24
]

# Map every cluster label present in the data (stored as a string) to its
# cell type.
transfer_dict_l3 = {
    cluster: level_3_list[int(cluster)]
    for cluster in sorted(sdata.obs[f'kmeans{k}'].unique())
}

In [None]:
# Translate the cluster ids in level_3 into cell-type names.
# Work on the single column (not the whole obs table) and go through plain
# strings: the column may be categorical by now, and a many-to-one replace on
# a categorical is deprecated in pandas 2.2.
cluster_ids = sdata.obs['level_3'].astype(str)
# Labels without an entry in the table are left as they are.
sdata.obs['level_3'] = cluster_ids.map(transfer_dict_l3).fillna(cluster_ids)

In [None]:
# Display order of the level-3 T-cell categories. reorder_categories raises if
# the labels present differ from this list.
current_order = [
    'CD8+ T cells',
    'CD4+ T cells',
    'Treg',
    'Th1',
    'Th2',
    'Th17',
    'Naive CD8+ T cells',
    'Naive CD4+ T cells',
    'Exhausted T cells',
]
sdata.obs['level_3'] = (
    sdata.obs['level_3']
    .astype('category')
    .cat.reorder_categories(current_order)
)

In [None]:
# One tab20 colour per level-3 category, in current_order order.
n_categories = len(current_order)
current_pl = sns.color_palette('tab20', n_categories)
current_cmap = ListedColormap(current_pl.as_hex())

# Preview the palette with the category names as tick labels.
sns.palplot(current_pl)
plt.xticks(range(n_categories), current_order, size=10, rotation=45)
plt.tight_layout()
# plt.savefig(os.path.join(fig_path, 'level_3_palette.pdf'))
plt.show()

In [None]:
# Hand-curated marker panel used to visualise the level-3 T-cell labels:
# {panel group: [marker genes]}.
selected_gene_dict = {
    'T cells': ['Cd3d', 'Cd3e', 'Cd3g'],
    'CD8 T cells': ['Cd8a'],
    'CD4 T cells': ['Cd4'],
    'Treg': ['Foxp3', 'Il2ra'],
    'Th1': ['Ifng', 'Tbx21'],
    'Th2': ['Il4'],
    'Th17': ['Il17a'],
    'Naive T cells': ["Sell", "Ccr7", "Lef1"],
    'Exhausted T cells': ['Pdcd1'],
}

In [None]:
# Marker expression grouped by the assigned level-3 T-cell label:
# heatmap and matrix plot on .X, dot plot on .raw.
level_3_plot_kwargs = dict(groupby='level_3', dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4)
sc.pl.heatmap(sdata, selected_gene_dict, figsize=(3, 15), **level_3_plot_kwargs)
sc.pl.matrixplot(sdata, selected_gene_dict, swap_axes=True, figsize=(5, 6), **level_3_plot_kwargs)
sc.pl.dotplot(sdata, selected_gene_dict, groupby='level_3', dendrogram=False, use_raw=True, cmap='Reds', swap_axes=True)

In [None]:
# One spatial map per sample: every cell of the sample in grey as background,
# the CD4+/CD8+ T-cell subset on top coloured by level_3.
for current_sample in sdata.obs['sample'].unique():
    print(current_sample)
    current_complete_obs = cdata.obs.loc[cdata.obs['sample'] == current_sample, :]
    current_obs = sdata.obs.loc[(sdata.obs['sample'] == current_sample), :]

    # Figure size follows the sample extent: max coordinate / 1000 per axis.
    fig_size = np.array([current_complete_obs['global_x'].max(), current_complete_obs['global_y'].max()]) / 1000
    fig, ax = plt.subplots(figsize=fig_size)
    sns.scatterplot(x='global_x', y='global_y', data=current_complete_obs, color='#dbdbdb', s=1, linewidth=0, ax=ax)
    # Marker area is set with `s`. seaborn's `size` is a semantic variable
    # (a data mapping with its own legend entry), not a point size.
    sns.scatterplot(x='global_x', y='global_y', hue='level_3', data=current_obs, palette='tab20', s=10, linewidth=0, ax=ax)
    # plt.savefig(os.path.join(output_path, f"sct_{current_sample}.png"))
    plt.show()

In [None]:
# Write the level-3 T-cell labels back to the complete and regional objects.
typed_cells = sdata.obs.index
typed_labels = sdata.obs['level_3'].values
for target in (cdata, rdata):
    target.obs.loc[typed_cells, 'level_3'] = typed_labels

In [None]:
# Save date-stamped backups of both annotated objects (regional subset and
# complete dataset) next to the raw expression file.
from datetime import datetime

date = datetime.today().strftime('%Y-%m-%d')
regional_backup_path = f"{expr_path}/{date}-combined-level3-region23-bk.h5ad"
complete_backup_path = f"{expr_path}/{date}-combined-level3-bk.h5ad"
rdata.write_h5ad(regional_backup_path)
cdata.write_h5ad(complete_backup_path)