In [None]:
import os
import numpy as np
import pandas as pd
import seaborn as sns
import scanpy as sc
import anndata as ad
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from tqdm.notebook import tqdm

from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score

## Input

In [None]:
# Project directories.
# NOTE(review): base_path is an absolute, machine-specific location; adjust it
# when running on another machine.
base_path = 'Z:/Data/Analyzed/2024-02-23-Hongyu-Covid_Spleen_replicate_2/'
output_path = os.path.join(base_path, 'output')
expr_path = os.path.join(base_path, 'expr')
test_path = os.path.join(output_path, "kmeans_cell_typing")
# makedirs(exist_ok=True) creates missing parents and is a no-op when the
# directory already exists, so the cell is safe to re-run.
os.makedirs(test_path, exist_ok=True)

# Load the combined raw expression object and display its summary.
cdata = sc.read_h5ad(os.path.join(expr_path, 'combined-raw.h5ad'))
cdata

In [None]:
# Compute QC metrics and store them on cdata itself (inplace=True).
# Presumably the source of the 'total_counts' / 'n_genes_by_counts' obs columns
# plotted further down — confirm.
sc.pp.calculate_qc_metrics(cdata, inplace=True, percent_top=None)

In [None]:
# Region assignment table. Rows are re-keyed by the 'unique_index' column (which
# is kept as a regular column too) so the table can index cells directly.
region_df = pd.read_csv(os.path.join(output_path, 'km-region-4.csv'), index_col=0)
region_df = region_df.set_index('unique_index', drop=False)

In [None]:
# Keep only the cells listed in the region table, in the table's order.
# Slicing an AnnData yields a view; take an explicit copy because the next line
# writes to .obs (otherwise the view is copied implicitly, with a warning).
cdata = cdata[region_df.index, :].copy()
# Positional assignment is safe here: cdata rows now follow region_df's row order.
cdata.obs['region'] = region_df['region'].values
cdata

In [None]:
# Cell-level filtering: drop cells below the minimum total counts, then cells
# below the minimum number of detected genes.
# NOTE(review): both thresholds are 2, which is very permissive — confirm intended.
sc.pp.filter_cells(cdata, min_counts=2)
sc.pp.filter_cells(cdata, min_genes=2)
cdata

In [None]:
# Preprocessing, in order: total-count normalisation, log1p transform, snapshot
# into .raw, then scaling of .X.
sc.pp.normalize_total(cdata)
sc.pp.log1p(cdata)
# .raw is assigned before scaling, so it holds the log-normalised (unscaled)
# values; plots called with use_raw=True read from it.
cdata.raw = cdata
sc.pp.scale(cdata)
# Keep a copy of the scaled matrix as a named layer.
cdata.layers['scaled'] = cdata.X.copy()

In [None]:
# One boxplot per QC metric, split by sample.
qc_columns = ['total_counts', 'n_genes_by_counts']
fig, axs = plt.subplots(1, len(qc_columns), figsize=(10, 5))
for ax, qc_column in zip(axs, qc_columns):
    sns.boxplot(data=cdata.obs, x='sample', y=qc_column, ax=ax)
plt.show()

In [None]:
# One boxplot per QC metric, split by condition.
qc_columns = ['total_counts', 'n_genes_by_counts']
fig, axs = plt.subplots(1, len(qc_columns), figsize=(10, 5))
for ax, qc_column in zip(axs, qc_columns):
    sns.boxplot(data=cdata.obs, x='condition', y=qc_column, ax=ax)
plt.show()

In [None]:
# Restrict to cells in regions 1-3.
# Explicit copy: rdata's .obs is written to later in the notebook, and writing to
# an AnnData view triggers an implicit copy with a warning.
rdata = cdata[cdata.obs['region'].isin([1, 2, 3]), :].copy()
rdata

In [None]:
# Matrix plot of every gene in rdata grouped by condition, drawn from .X
# (use_raw=False) with a diverging colour map clipped to [-1, 1].
sc.pl.matrixplot(rdata, rdata.var.index, groupby=f'condition', dendrogram=False, use_raw=False, cmap='bwr', vmin=-1, vmax=1, swap_axes=False)

## Level 1

In [None]:
# Gene annotation table, indexed by gene name (the 'Gene' column is kept as well).
gene_annotation = pd.read_csv(
    os.path.join(base_path, 'documents', 'gene_annotation.csv')
).set_index('Gene', drop=False)

# Rows flagged as level-1 markers give both the gene panel and the cell types.
level_1_rows = gene_annotation.loc[gene_annotation['Level_1_binary'] == True]
selected_genes = level_1_rows['Gene'].to_list()
level_1_order = level_1_rows['Level_1_annotation'].unique()

print(f"Selected genes: {len(selected_genes)}")
print(level_1_order)

In [None]:
# Display the annotation rows flagged as level-1 markers.
gene_annotation.loc[gene_annotation['Level_1_binary'] == True, :]

In [None]:
# Map each level-1 cell type to its marker genes (used to group genes in plots).
is_level_1_marker = gene_annotation['Level_1_binary'] == True
selected_gene_dict = {
    cell_type: gene_annotation.loc[
        is_level_1_marker & (gene_annotation['Level_1_annotation'] == cell_type),
        'Gene',
    ].to_list()
    for cell_type in level_1_order
}

selected_gene_dict

In [None]:
# Restrict the regional object to the level-1 marker genes.
# Explicit copy because cluster labels are written to sdata.obs below.
sdata = rdata[:, selected_genes].copy()
sdata

In [None]:
# Clustering input. A PCA embedding was tried (kept commented for reference);
# the clustering actually runs on sdata.X, and the variable keeps the name X_pca.
# sc.pp.pca(sdata)
# X_pca = sdata.obsm['X_pca']

# Use the preprocessed expression profile
X_pca = sdata.X

# kmeans
# NOTE(review): the cluster ids produced here index the hand-written label list
# in the "assign cell types" cell, so changing k, random_state or the input
# invalidates that list.
k = 34
kmeans = KMeans(n_clusters=k, random_state=5).fit(X_pca)
# Labels are stored as strings so they behave as categories in the plots.
sdata.obs[f'kmeans{k}'] = kmeans.labels_.astype(str)

# sc.pl.pca(sdata, color=[f'kmeans{k}'])
# Marker expression per cluster, used to decide each cluster's cell type.
sc.pl.heatmap(sdata, selected_gene_dict, groupby=f'kmeans{k}', dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4)
sc.pl.matrixplot(sdata, selected_gene_dict, groupby=f'kmeans{k}', dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4, swap_axes=True)

In [None]:
# sc.pl.pca(sdata, color=[f'kmeans{k}'])

### assign cell types

In [None]:
# Seed 'level_1' with the k-means cluster ids; the kmeans{k} column itself stays
# untouched and serves as the backup of the raw cluster labels.
sdata.obs['level_1'] = sdata.obs[f'kmeans{k}'].values

In [None]:
# Cell type assigned to each k-means cluster; the list position is the cluster id.
level_1_list = [
    'NK cells',           # 0
    'NA',                 # 1
    'Dendritic cells',    # 2
    'NK cells',           # 3
    'Macrophages',        # 4
    'NK cells',           # 5
    'B cells',            # 6
    'B cells',            # 7
    'Dendritic cells',    # 8
    'Endothelial cells',  # 9
    'T cells',            # 10
    'B cells',            # 11
    'Dendritic cells',    # 12
    'T cells',            # 13
    'NK cells',           # 14
    'Endothelial cells',  # 15
    'B cells',            # 16
    'Endothelial cells',  # 17
    'Macrophages',        # 18
    'Macrophages',        # 19
    'Dendritic cells',    # 20
    'B cells',            # 21
    'Macrophages',        # 22
    'Dendritic cells',    # 23
    'Dendritic cells',    # 24
    'Endothelial cells',  # 25
    'Macrophages',        # 26
    'T cells',            # 27
    'Dendritic cells',    # 28
    'Macrophages',        # 29
    'NK cells',           # 30
    'Dendritic cells',    # 31
    'B cells',            # 32
    'NK cells',           # 33
]

# Cluster label (string) -> cell type name, for every cluster present in sdata.
transfer_dict_l1 = {
    cluster_label: level_1_list[int(cluster_label)]
    for cluster_label in sorted(sdata.obs[f'kmeans{k}'].unique())
}

In [None]:
# Translate the cluster ids stored in 'level_1' into cell type names.
# The nested dict limits the replacement to the 'level_1' column.
sdata.obs = sdata.obs.replace({'level_1': transfer_dict_l1})

In [None]:
# Display order of the level-1 cell types; reorder_categories requires this list
# to match the labels present in the column exactly.
level_1_order = ['T cells', 'B cells', 'Macrophages', 'Dendritic cells', 'NK cells', 'Endothelial cells', 'NA']
sdata.obs['level_1'] = (
    sdata.obs['level_1']
    .astype('category')
    .cat.reorder_categories(level_1_order)
)

In [None]:
# One colour per level-1 cell type, in the same order as level_1_order.
level_1_hex = ['#00A651', '#FBB040', '#92278F', '#03a5fc', '#386363', '#d12852', '#dbdbdb']
level_1_pl = sns.color_palette(level_1_hex)
level_1_cmap = ListedColormap(level_1_pl.as_hex())

# Preview the palette with the cell type names as tick labels.
sns.palplot(level_1_pl)
plt.xticks(ticks=range(len(level_1_order)), labels=level_1_order, size=10, rotation=45)
plt.tight_layout()
# plt.savefig(os.path.join(fig_path, 'level_2_palette.pdf'))
plt.show()

In [None]:
# Cells per condition: the denominator of the percentages below.
# The first column is taken positionally because its name depends on the pandas
# version ('count' in pandas >= 2, the original column name before that).
total_counts = pd.DataFrame(sdata.obs['condition'].value_counts())
total_counts_dict = total_counts.iloc[:, 0].to_dict()

# Cells per (level_1, condition) pair; computed once and reused.
pair_counts = sdata.obs.groupby('level_1')['condition'].value_counts()

# Builtin float replaces the np.float alias, which was removed in NumPy 1.24.
leiden_df = pd.DataFrame({'counts': pair_counts.values.astype(float)})
leiden_df['level_1'] = [pair[0] for pair in pair_counts.index]
leiden_df['condition'] = pd.Categorical([pair[1] for pair in pair_counts.index])
leiden_df['total_counts'] = leiden_df['condition'].astype(object).map(total_counts_dict).astype(float)
# Share of each condition's cells that falls into each cell type.
leiden_df['percentage'] = leiden_df['counts'] / leiden_df['total_counts'] * 100

sns.barplot(x='level_1', y='percentage', hue='condition', data=leiden_df)
plt.xticks(rotation=45)
# plt.savefig(os.path.join(output_path, f"cell_type_composition_condition.png"))
plt.show()

In [None]:
# Per-sample normalisation constants, listed in the order of the 'sample' categories.
total_areas = [429, 305, 362, 401, 305, 372]
sample_categories = sdata.obs['sample'].cat.categories
# zip() would silently truncate, and pair values with the wrong samples, when the
# number of categories differs from the number of constants; fail loudly instead.
if len(sample_categories) != len(total_areas):
    raise ValueError(
        f"Expected {len(total_areas)} sample categories, "
        f"found {len(sample_categories)}: {list(sample_categories)}"
    )
total_areas_dict = dict(zip(sample_categories, total_areas))

# Cells per (level_1, sample) pair; computed once and reused.
pair_counts = sdata.obs.groupby('level_1')['sample'].value_counts()

# Builtin float replaces the np.float alias, which was removed in NumPy 1.24.
leiden_df = pd.DataFrame({'counts': pair_counts.values.astype(float)})
leiden_df['level_1'] = [pair[0] for pair in pair_counts.index]
leiden_df['sample'] = pd.Categorical([pair[1] for pair in pair_counts.index])
leiden_df['n_fovs'] = leiden_df['sample'].astype(object).map(total_areas_dict).astype(float)
leiden_df['counts_per_fov'] = leiden_df['counts'] / leiden_df['n_fovs']

sns.barplot(x='level_1', y='counts_per_fov', hue='sample', data=leiden_df)
plt.xticks(rotation=45)
plt.show()

In [None]:
# Raw cell counts per (level_1, condition) pair.
pair_counts = sdata.obs.groupby('level_1')['condition'].value_counts()
leiden_df = pd.DataFrame({'counts': pair_counts.values})
leiden_df['level_1'] = [level for level, _ in pair_counts.index]
leiden_df['condition'] = pd.Categorical([condition for _, condition in pair_counts.index])

sns.barplot(data=leiden_df, x='level_1', y='counts', hue='condition')
plt.xticks(rotation=45)
plt.show()

In [None]:
# Marker genes per assigned cell type; the 'NA' label has no markers and is skipped.
is_level_1_marker = gene_annotation['Level_1_binary'] == True
selected_gene_dict = {
    cell_type: gene_annotation.loc[
        is_level_1_marker & (gene_annotation['Level_1_annotation'] == cell_type),
        'Gene',
    ].to_list()
    for cell_type in level_1_order
    if cell_type != 'NA'
}

selected_gene_dict

In [None]:
# Marker expression per assigned level-1 cell type, from .X (use_raw=False).
shared_plot_kwargs = dict(groupby='level_1', dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4)
sc.pl.heatmap(sdata, selected_gene_dict, **shared_plot_kwargs)
sc.pl.matrixplot(sdata, selected_gene_dict, swap_axes=True, **shared_plot_kwargs)

In [None]:
# Spatial map per sample: all cells in grey, typed cells coloured by level_1 on top.
for current_sample in sdata.obs['sample'].unique():
    print(current_sample)
    current_complete_obs = cdata.obs[cdata.obs['sample'] == current_sample]
    current_obs = sdata.obs[sdata.obs['sample'] == current_sample]

    # Figure size follows the sample's coordinate extent (max / 1000).
    fig_size = (
        current_complete_obs['global_x'].max() / 1000,
        current_complete_obs['global_y'].max() / 1000,
    )
    fig, ax = plt.subplots(figsize=fig_size)
    sns.scatterplot(data=current_complete_obs, x='global_x', y='global_y', color='#dbdbdb', s=1, linewidth=0, ax=ax)
    sns.scatterplot(data=current_obs, x='global_x', y='global_y', hue='level_1', palette=level_1_pl, s=1, linewidth=0, ax=ax)
    # plt.savefig(os.path.join(output_path, f"sct_{current_sample}.png"))
    plt.show()

In [None]:
# Map to the regional object: initialise every cell as 'NA', then copy the labels
# of the cells present in sdata (matched by obs index).
rdata.obs['level_1'] = 'NA'
# Plain object dtype so arbitrary label strings can be assigned.
rdata.obs['level_1'] = rdata.obs['level_1'].astype(object)
rdata.obs.loc[sdata.obs.index, 'level_1'] = sdata.obs['level_1'].values
rdata.obs['level_1'].unique()

In [None]:
# Map to the complete object: cells absent from sdata keep the 'NA' label.
cdata.obs['level_1'] = 'NA'
# Plain object dtype so arbitrary label strings can be assigned.
cdata.obs['level_1'] = cdata.obs['level_1'].astype(object)
cdata.obs.loc[sdata.obs.index, 'level_1'] = sdata.obs['level_1'].values
cdata.obs['level_1'].unique()

In [None]:
# Backup: write both annotated objects to expr_path, prefixed with today's date.
# NOTE(review): re-running on the same day overwrites the earlier files.
from datetime import datetime
date = datetime.today().strftime('%Y-%m-%d')
rdata.write_h5ad(f"{expr_path}/{date}-combined-level1-region123.h5ad")
cdata.write_h5ad(f"{expr_path}/{date}-combined-level1.h5ad")

## Level 2 

In [None]:
# Start level_2 as a plain-object copy of level_1; sub-typing overwrites parts of it.
rdata.obs['level_2'] = rdata.obs['level_1'].astype(object).values

In [None]:
# Start level_2 as a plain-object copy of level_1; sub-typing overwrites parts of it.
cdata.obs['level_2'] = cdata.obs['level_1'].astype(object).values

### T cells

In [None]:
# The annotation-table selection is disabled for T cells (kept for reference);
# a fixed CD3/CD4/CD8 marker panel is used instead.
# gene_annotation = pd.read_csv(os.path.join(base_path, 'documents', 'gene_annotation.csv'))
# gene_annotation.index = gene_annotation['Gene']
# selected_genes = gene_annotation.loc[gene_annotation['Level_2_binary_t_cells'] == True, 'Gene'].to_list()
# current_order = gene_annotation.loc[gene_annotation['Level_2_binary_t_cells'] == True, 'Level_2_annotation_t_cells'].unique()
current_order = ['CD3+', 'CD3-']
# selected_genes = ['Cd3e',]
selected_genes = ['Cd3e', 'Cd3d', 'Cd3g', 'Cd4', 'Cd8a']
print(f"Selected genes: {len(selected_genes)}")
print(current_order)

In [None]:
# Annotation-driven grouping disabled (kept for reference); all panel genes are
# shown under a single 'CD3+' group in the plots.
# selected_gene_dict = {}
# for current_type in current_order:
#     selected_gene_dict[current_type] = gene_annotation.loc[(gene_annotation['Level_2_binary_t_cells'] == True) & (gene_annotation['Level_2_annotation_t_cells'] == current_type), 'Gene'].to_list()

selected_gene_dict = {'CD3+': selected_genes}
selected_gene_dict

In [None]:
# Level-1 T cells restricted to the T-cell marker panel.
# Explicit copy because cluster labels are written to sdata.obs below.
sdata = rdata[rdata.obs['level_1'] == 'T cells', selected_genes].copy()
sdata

In [None]:
# Clustering input: expression matrix of the selected genes (PCA variant kept
# commented for reference).
# sc.pp.pca(sdata)
X_pca = sdata.X
# X_pca = sdata.obsm['X_pca']

# K-means elbow scan: inertia for each candidate number of clusters.
# A fixed random_state makes the curve reproducible between runs.
k_values = range(2, 20)
distorsions = []
for k in k_values:
    kmeans = KMeans(n_clusters=k, random_state=0)
    kmeans.fit(X_pca)
    distorsions.append(kmeans.inertia_)

fig = plt.figure(figsize=(15, 5))
plt.plot(k_values, distorsions)
plt.grid(True)
plt.title('Elbow curve')
plt.xlabel('Number of clusters (k)')
plt.ylabel('Inertia')
plt.show()

In [None]:
# kmeans
# NOTE(review): the cluster ids produced here index the hand-written label list
# in the following "assign cell types" cell; changing k or random_state
# invalidates that list.
k = 5
kmeans = KMeans(n_clusters=k, random_state=10).fit(X_pca)
# Labels are stored as strings so they behave as categories in the plots.
sdata.obs[f'kmeans{k}'] = kmeans.labels_.astype(str)

# sc.pl.pca(sdata, color=[f'kmeans{k}'])
# Marker expression per cluster: first two plots from .X, the last from .raw.
sc.pl.heatmap(sdata, selected_gene_dict, groupby=f'kmeans{k}', dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4)
sc.pl.matrixplot(sdata, selected_gene_dict, groupby=f'kmeans{k}', dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4, swap_axes=True)
sc.pl.matrixplot(sdata, selected_gene_dict, groupby=f'kmeans{k}', dendrogram=False, use_raw=True, vmax=2, swap_axes=True)

#### assign cell types

In [None]:
# Seed 'level_2' with the k-means cluster ids; the kmeans{k} column itself stays
# untouched and serves as the backup of the raw cluster labels.
sdata.obs['level_2'] = sdata.obs[f'kmeans{k}'].values

In [None]:
# T-cell sub-type assigned to each k-means cluster; list position is the cluster id.
level_2_list = [
    'NA',            # 0
    'CD4+ T cells',  # 1
    'T cells',       # 2
    'T cells',       # 3
    'CD8+ T cells',  # 4
]

# Cluster label (string) -> sub-type name, for every cluster present in sdata.
transfer_dict_l2 = {
    cluster_label: level_2_list[int(cluster_label)]
    for cluster_label in sorted(sdata.obs[f'kmeans{k}'].unique())
}

In [None]:
# Translate the cluster ids stored in 'level_2' into sub-type names.
# The nested dict limits the replacement to the 'level_2' column.
sdata.obs = sdata.obs.replace({'level_2': transfer_dict_l2})

In [None]:
# Display order of the T-cell labels; reorder_categories requires this list to
# match the labels present in the column exactly.
current_order = ['T cells', 'CD4+ T cells', 'CD8+ T cells', 'NA']
sdata.obs['level_2'] = (
    sdata.obs['level_2']
    .astype('category')
    .cat.reorder_categories(current_order)
)

In [None]:
# tab10 colours, one per label in current_order, with a labelled preview.
n_types = len(current_order)
current_pl = sns.color_palette('tab10', n_colors=n_types)
current_cmap = ListedColormap(current_pl.as_hex())
sns.palplot(current_pl)
plt.xticks(ticks=range(n_types), labels=current_order, size=10, rotation=45)
plt.tight_layout()
# plt.savefig(os.path.join(fig_path, 'level_2_palette.pdf'))
plt.show()

In [None]:
# Per-sample normalisation constants, listed in the order of the 'sample' categories.
total_areas = [429, 305, 362, 401, 305, 372]
sample_categories = sdata.obs['sample'].cat.categories
# zip() would silently truncate, and pair values with the wrong samples, when the
# number of categories differs from the number of constants (for example if this
# subset contains no cells from one sample); fail loudly instead.
if len(sample_categories) != len(total_areas):
    raise ValueError(
        f"Expected {len(total_areas)} sample categories, "
        f"found {len(sample_categories)}: {list(sample_categories)}"
    )
total_areas_dict = dict(zip(sample_categories, total_areas))

# Cells per (level_2, sample) pair; computed once and reused.
pair_counts = sdata.obs.groupby('level_2')['sample'].value_counts()

# Builtin float replaces the np.float alias, which was removed in NumPy 1.24.
leiden_df = pd.DataFrame({'counts': pair_counts.values.astype(float)})
leiden_df['level_2'] = [pair[0] for pair in pair_counts.index]
leiden_df['sample'] = pd.Categorical([pair[1] for pair in pair_counts.index])
leiden_df['n_fovs'] = leiden_df['sample'].astype(object).map(total_areas_dict).astype(float)
leiden_df['counts_per_fov'] = leiden_df['counts'] / leiden_df['n_fovs']

sns.barplot(x='level_2', y='counts_per_fov', hue='sample', data=leiden_df)
plt.xticks(rotation=45)
plt.show()

In [None]:
# Cells per condition within this subset: the denominator of the percentages.
# The first column is taken positionally because its name depends on the pandas
# version ('count' in pandas >= 2, the original column name before that).
total_counts = pd.DataFrame(sdata.obs['condition'].value_counts())
total_counts_dict = total_counts.iloc[:, 0].to_dict()

# Cells per (level_2, condition) pair; computed once and reused.
pair_counts = sdata.obs.groupby('level_2')['condition'].value_counts()

# Builtin float replaces the np.float alias, which was removed in NumPy 1.24.
leiden_df = pd.DataFrame({'counts': pair_counts.values.astype(float)})
leiden_df['level_2'] = [pair[0] for pair in pair_counts.index]
leiden_df['condition'] = pd.Categorical([pair[1] for pair in pair_counts.index])
leiden_df['total_counts'] = leiden_df['condition'].astype(object).map(total_counts_dict).astype(float)
leiden_df['percentage'] = leiden_df['counts'] / leiden_df['total_counts'] * 100

sns.barplot(x='level_2', y='percentage', hue='condition', data=leiden_df)
plt.xticks(rotation=45)
plt.show()

In [None]:
# Raw cell counts per (level_2, condition) pair.
pair_counts = sdata.obs.groupby('level_2')['condition'].value_counts()
leiden_df = pd.DataFrame({'counts': pair_counts.values})
leiden_df['level_2'] = [level for level, _ in pair_counts.index]
leiden_df['condition'] = pd.Categorical([condition for _, condition in pair_counts.index])

sns.barplot(data=leiden_df, x='level_2', y='counts', hue='condition')
plt.xticks(rotation=45)
plt.show()

In [None]:
# Marker expression per assigned T-cell label: heatmap and matrix plot from .X
# (use_raw=False), dot plot from .raw (use_raw=True).
sc.pl.heatmap(sdata, selected_gene_dict, groupby=f'level_2', dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4, figsize=(3, 15))
sc.pl.matrixplot(sdata, selected_gene_dict, groupby=f'level_2', dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4, swap_axes=True, figsize=(5, 6))
sc.pl.dotplot(sdata, selected_gene_dict, groupby=f'level_2', dendrogram=False, use_raw=True, cmap='Reds', swap_axes=True,)

In [None]:
# Spatial map per sample: all cells in grey, T cells coloured by level_2 on top.
for current_sample in sdata.obs['sample'].unique():
    print(current_sample)
    current_complete_obs = cdata.obs.loc[cdata.obs['sample'] == current_sample, :]
    current_obs = sdata.obs.loc[(sdata.obs['sample'] == current_sample), :]
    # current_obs = sdata.obs.loc[(sdata.obs['sample'] == current_sample) & (sdata.obs['level_2'] != "T cells"), :]

    # Figure size follows the sample's coordinate extent (max / 1000).
    fig_size = np.array([current_complete_obs['global_x'].max(), current_complete_obs['global_y'].max()]) / 1000
    fig, ax = plt.subplots(figsize=fig_size)
    sns.scatterplot(x='global_x', y='global_y', data=current_complete_obs, color='#dbdbdb', s=1, linewidth=0, ax=ax)
    # `s` sets the marker size. `size` is seaborn's semantic mapping, so a scalar
    # there adds a constant size level to the legend instead of resizing markers.
    sns.scatterplot(x='global_x', y='global_y', hue='level_2', data=current_obs, palette='tab10', s=10, linewidth=0, ax=ax)
    # plt.savefig(os.path.join(output_path, f"sct_{current_sample}.png"))
    plt.show()

In [None]:
# Propagate the T-cell sub-type labels to both parent objects, matched by cell index.
t_cell_labels = sdata.obs['level_2'].values
for parent in (cdata, rdata):
    parent.obs.loc[sdata.obs.index, 'level_2'] = t_cell_labels

In [None]:
# Cells whose T-cell sub-typing came out as 'NA' also lose their level-1 label
# in both parent objects.
na_index = sdata.obs.index[sdata.obs['level_2'] == 'NA']
for parent in (cdata, rdata):
    parent.obs.loc[na_index, 'level_1'] = 'NA'

In [None]:
# Backup after T-cell sub-typing, prefixed with today's date.
# NOTE(review): the later level-2 backup cells use the same file names, so each
# one overwrites this file when run on the same day.
from datetime import datetime
date = datetime.today().strftime('%Y-%m-%d')
rdata.write_h5ad(f"{expr_path}/{date}-combined-level2-region123-bk.h5ad")
cdata.write_h5ad(f"{expr_path}/{date}-combined-level2-bk.h5ad")

### Macrophages

In [None]:
# Reload the annotation table, indexed by gene name (the 'Gene' column is kept).
gene_annotation = pd.read_csv(
    os.path.join(base_path, 'documents', 'gene_annotation.csv')
).set_index('Gene', drop=False)

# Rows flagged as macrophage markers give the gene panel and the sub-types.
macrophage_rows = gene_annotation.loc[gene_annotation['Level_2_binary_macrophages'] == True]
selected_genes = macrophage_rows['Gene'].to_list()
current_order = macrophage_rows['Level_2_annotation_macrophages'].unique()

print(f"Selected genes: {len(selected_genes)}")
print(current_order)

In [None]:
# Map each macrophage sub-type to its marker genes (used to group genes in plots).
is_macrophage_marker = gene_annotation['Level_2_binary_macrophages'] == True
selected_gene_dict = {
    cell_type: gene_annotation.loc[
        is_macrophage_marker & (gene_annotation['Level_2_annotation_macrophages'] == cell_type),
        'Gene',
    ].to_list()
    for cell_type in current_order
}

selected_gene_dict

In [None]:
# Level-1 macrophages restricted to the macrophage marker panel.
# Explicit copy because cluster labels are written to sdata.obs below.
sdata = rdata[rdata.obs['level_1'] == 'Macrophages', selected_genes].copy()
sdata

In [None]:
# Clustering input: expression matrix of the selected genes.
X_pca = sdata.X

# K-means elbow scan: inertia for each candidate number of clusters.
# A fixed random_state makes the curve reproducible between runs.
k_values = range(2, 20)
distorsions = []
for k in k_values:
    kmeans = KMeans(n_clusters=k, random_state=0)
    kmeans.fit(X_pca)
    distorsions.append(kmeans.inertia_)

fig = plt.figure(figsize=(15, 5))
plt.plot(k_values, distorsions)
plt.grid(True)
plt.title('Elbow curve')
plt.xlabel('Number of clusters (k)')
plt.ylabel('Inertia')
plt.show()

In [None]:
# kmeans
# NOTE(review): the cluster ids produced here index the hand-written label list
# in the following "assign cell types" cell; changing k or random_state
# invalidates that list.
k = 5
kmeans = KMeans(n_clusters=k, random_state=5).fit(X_pca)
# Labels are stored as strings so they behave as categories in the plots.
sdata.obs[f'kmeans{k}'] = kmeans.labels_.astype(str)

# sc.pl.pca(sdata, color=[f'kmeans{k}'])
# Marker expression per cluster: first two plots from .X, the last from .raw.
sc.pl.heatmap(sdata, selected_gene_dict, groupby=f'kmeans{k}', dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4)
sc.pl.matrixplot(sdata, selected_gene_dict, groupby=f'kmeans{k}', dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4, swap_axes=True)
sc.pl.matrixplot(sdata, selected_gene_dict, groupby=f'kmeans{k}', dendrogram=False, use_raw=True, vmax=2, swap_axes=True)

#### assign cell types

In [None]:
# Seed 'level_2' with the k-means cluster ids; the kmeans{k} column itself stays
# untouched and serves as the backup of the raw cluster labels.
sdata.obs['level_2'] = sdata.obs[f'kmeans{k}'].values

In [None]:
# Macrophage sub-type assigned to each k-means cluster; list position is the cluster id.
level_2_list = [
    'Macrophages',            # 0
    'Monocytes',              # 1
    'Macrophages',            # 2
    'Monocytes',              # 3
    'Activated Macrophages',  # 4
]

# Cluster label (string) -> sub-type name, for every cluster present in sdata.
transfer_dict_l2 = {
    cluster_label: level_2_list[int(cluster_label)]
    for cluster_label in sorted(sdata.obs[f'kmeans{k}'].unique())
}

In [None]:
# Translate the cluster ids stored in 'level_2' into sub-type names.
# The nested dict limits the replacement to the 'level_2' column.
sdata.obs = sdata.obs.replace({'level_2': transfer_dict_l2})

In [None]:
# Display order of the macrophage labels; reorder_categories requires this list
# to match the labels present in the column exactly.
current_order = ['Macrophages', 'Activated Macrophages', 'Monocytes']
sdata.obs['level_2'] = (
    sdata.obs['level_2']
    .astype('category')
    .cat.reorder_categories(current_order)
)

In [None]:
# tab10 colours, one per label in current_order, with a labelled preview.
n_types = len(current_order)
current_pl = sns.color_palette('tab10', n_colors=n_types)
current_cmap = ListedColormap(current_pl.as_hex())
sns.palplot(current_pl)
plt.xticks(ticks=range(n_types), labels=current_order, size=10, rotation=45)
plt.tight_layout()
# plt.savefig(os.path.join(fig_path, 'level_2_palette.pdf'))
plt.show()

In [None]:
# Per-sample normalisation constants, listed in the order of the 'sample' categories.
total_areas = [429, 305, 362, 401, 305, 372]
sample_categories = sdata.obs['sample'].cat.categories
# zip() would silently truncate, and pair values with the wrong samples, when the
# number of categories differs from the number of constants (for example if this
# subset contains no cells from one sample); fail loudly instead.
if len(sample_categories) != len(total_areas):
    raise ValueError(
        f"Expected {len(total_areas)} sample categories, "
        f"found {len(sample_categories)}: {list(sample_categories)}"
    )
total_areas_dict = dict(zip(sample_categories, total_areas))

# Cells per (level_2, sample) pair; computed once and reused.
pair_counts = sdata.obs.groupby('level_2')['sample'].value_counts()

# Builtin float replaces the np.float alias, which was removed in NumPy 1.24.
leiden_df = pd.DataFrame({'counts': pair_counts.values.astype(float)})
leiden_df['level_2'] = [pair[0] for pair in pair_counts.index]
leiden_df['sample'] = pd.Categorical([pair[1] for pair in pair_counts.index])
leiden_df['n_fovs'] = leiden_df['sample'].astype(object).map(total_areas_dict).astype(float)
leiden_df['counts_per_fov'] = leiden_df['counts'] / leiden_df['n_fovs']

sns.barplot(x='level_2', y='counts_per_fov', hue='sample', data=leiden_df)
plt.xticks(rotation=45)
plt.show()

In [None]:
# Cells per condition within this subset: the denominator of the percentages.
# The first column is taken positionally because its name depends on the pandas
# version ('count' in pandas >= 2, the original column name before that).
total_counts = pd.DataFrame(sdata.obs['condition'].value_counts())
total_counts_dict = total_counts.iloc[:, 0].to_dict()

# Cells per (level_2, condition) pair; computed once and reused.
pair_counts = sdata.obs.groupby('level_2')['condition'].value_counts()

# Builtin float replaces the np.float alias, which was removed in NumPy 1.24.
leiden_df = pd.DataFrame({'counts': pair_counts.values.astype(float)})
leiden_df['level_2'] = [pair[0] for pair in pair_counts.index]
leiden_df['condition'] = pd.Categorical([pair[1] for pair in pair_counts.index])
leiden_df['total_counts'] = leiden_df['condition'].astype(object).map(total_counts_dict).astype(float)
leiden_df['percentage'] = leiden_df['counts'] / leiden_df['total_counts'] * 100

sns.barplot(x='level_2', y='percentage', hue='condition', data=leiden_df)
plt.xticks(rotation=45)
plt.show()

In [None]:
# Raw cell counts per (level_2, condition) pair.
pair_counts = sdata.obs.groupby('level_2')['condition'].value_counts()
leiden_df = pd.DataFrame({'counts': pair_counts.values})
leiden_df['level_2'] = [level for level, _ in pair_counts.index]
leiden_df['condition'] = pd.Categorical([condition for _, condition in pair_counts.index])

sns.barplot(data=leiden_df, x='level_2', y='counts', hue='condition')
plt.xticks(rotation=45)
plt.show()

In [None]:
# Marker genes per assigned label ('NA' skipped). A label with no matching rows in
# the annotation table maps to an empty list.
is_macrophage_marker = gene_annotation['Level_2_binary_macrophages'] == True
selected_gene_dict = {
    cell_type: gene_annotation.loc[
        is_macrophage_marker & (gene_annotation['Level_2_annotation_macrophages'] == cell_type),
        'Gene',
    ].to_list()
    for cell_type in current_order
    if cell_type != 'NA'
}

selected_gene_dict

In [None]:
# Marker expression per assigned macrophage label, from .X (use_raw=False).
shared_plot_kwargs = dict(groupby='level_2', dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4)
sc.pl.heatmap(sdata, selected_gene_dict, figsize=(3, 15), **shared_plot_kwargs)
sc.pl.dotplot(sdata, selected_gene_dict, swap_axes=True, **shared_plot_kwargs)

In [None]:
# Spatial map per sample: all cells in grey, macrophage sub-types coloured on top.
for current_sample in sdata.obs['sample'].unique():
    print(current_sample)
    current_complete_obs = cdata.obs[cdata.obs['sample'] == current_sample]
    current_obs = sdata.obs[sdata.obs['sample'] == current_sample]

    # Figure size follows the sample's coordinate extent (max / 1000).
    fig_size = (
        current_complete_obs['global_x'].max() / 1000,
        current_complete_obs['global_y'].max() / 1000,
    )
    fig, ax = plt.subplots(figsize=fig_size)
    sns.scatterplot(data=current_complete_obs, x='global_x', y='global_y', color='#dbdbdb', s=1, linewidth=0, ax=ax)
    sns.scatterplot(data=current_obs, x='global_x', y='global_y', hue='level_2', palette='tab10', s=5, linewidth=0, ax=ax)
    # plt.savefig(os.path.join(output_path, f"sct_{current_sample}.png"))
    plt.show()

In [None]:
# Map to original obj.
# Cast back to plain object dtype first so new label strings can be assigned.
# Presumably needed because write_h5ad converted the column to categorical — confirm.
cdata.obs['level_2'] = cdata.obs['level_2'].astype(object)
rdata.obs['level_2'] = rdata.obs['level_2'].astype(object)

# Copy the macrophage sub-type labels to both parent objects, matched by cell index.
cdata.obs.loc[sdata.obs.index, 'level_2'] = sdata.obs['level_2'].values
rdata.obs.loc[sdata.obs.index, 'level_2'] = sdata.obs['level_2'].values
cdata.obs['level_2'].unique()

In [None]:
# Backup after macrophage sub-typing, prefixed with today's date.
# NOTE(review): same file names as the other level-2 backup cells, so files
# written earlier on the same day are overwritten.
from datetime import datetime
date = datetime.today().strftime('%Y-%m-%d')
rdata.write_h5ad(f"{expr_path}/{date}-combined-level2-region123-bk.h5ad")
cdata.write_h5ad(f"{expr_path}/{date}-combined-level2-bk.h5ad")

### Dendritic cells

In [None]:
# Reload the annotation table, indexed by gene name (the 'Gene' column is kept).
gene_annotation = pd.read_csv(
    os.path.join(base_path, 'documents', 'gene_annotation.csv')
).set_index('Gene', drop=False)

# Rows flagged as dendritic-cell markers give the gene panel and the sub-types.
dc_rows = gene_annotation.loc[gene_annotation['Level_2_binary_dendritic_cells'] == True]
selected_genes = dc_rows['Gene'].to_list()
current_order = dc_rows['Level_2_annotation_dendritic_cells'].unique()

print(f"Selected genes: {len(selected_genes)}")
print(current_order)

In [None]:
# Map each dendritic-cell sub-type to its marker genes (used to group genes in plots).
is_dc_marker = gene_annotation['Level_2_binary_dendritic_cells'] == True
selected_gene_dict = {
    cell_type: gene_annotation.loc[
        is_dc_marker & (gene_annotation['Level_2_annotation_dendritic_cells'] == cell_type),
        'Gene',
    ].to_list()
    for cell_type in current_order
}

selected_gene_dict

In [None]:
# Level-1 dendritic cells restricted to the dendritic-cell marker panel.
# Explicit copy because cluster labels are written to sdata.obs below.
sdata = rdata[rdata.obs['level_1'] == 'Dendritic cells', selected_genes].copy()
sdata

In [None]:
# Clustering input: expression matrix of the selected genes.
X_pca = sdata.X

# K-means elbow scan: inertia for each candidate number of clusters.
# A fixed random_state makes the curve reproducible between runs.
k_values = range(2, 20)
distorsions = []
for k in k_values:
    kmeans = KMeans(n_clusters=k, random_state=0)
    kmeans.fit(X_pca)
    distorsions.append(kmeans.inertia_)

fig = plt.figure(figsize=(15, 5))
plt.plot(k_values, distorsions)
plt.grid(True)
plt.title('Elbow curve')
plt.xlabel('Number of clusters (k)')
plt.ylabel('Inertia')
plt.show()

In [None]:
# kmeans
# NOTE(review): the cluster ids produced here index the hand-written label list
# in the following "assign cell types" cell; changing k or random_state
# invalidates that list.
k = 4
kmeans = KMeans(n_clusters=k, random_state=5).fit(X_pca)
# Labels are stored as strings so they behave as categories in the plots.
sdata.obs[f'kmeans{k}'] = kmeans.labels_.astype(str)

# sc.pl.pca(sdata, color=[f'kmeans{k}'])
# Marker expression per cluster: first two plots from .X, the last from .raw.
sc.pl.heatmap(sdata, selected_gene_dict, groupby=f'kmeans{k}', dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4)
sc.pl.matrixplot(sdata, selected_gene_dict, groupby=f'kmeans{k}', dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4, swap_axes=True)
sc.pl.matrixplot(sdata, selected_gene_dict, groupby=f'kmeans{k}', dendrogram=False, use_raw=True, vmax=2, swap_axes=True)

#### assign cell types

In [None]:
# Seed 'level_2' with the k-means cluster ids; the kmeans{k} column itself stays
# untouched and serves as the backup of the raw cluster labels.
sdata.obs['level_2'] = sdata.obs[f'kmeans{k}'].values

In [None]:
# Dendritic-cell sub-type assigned to each k-means cluster; list position is the cluster id.
level_2_list = [
    'Other Dendritic cells',  # 0
    'cDC1',                   # 1
    'cDC2',                   # 2
    'Other Dendritic cells',  # 3
]

# Cluster label (string) -> sub-type name, for every cluster present in sdata.
transfer_dict_l2 = {
    cluster_label: level_2_list[int(cluster_label)]
    for cluster_label in sorted(sdata.obs[f'kmeans{k}'].unique())
}

In [None]:
# Translate the cluster ids stored in 'level_2' into sub-type names.
# The nested dict limits the replacement to the 'level_2' column.
sdata.obs = sdata.obs.replace({'level_2': transfer_dict_l2})

In [None]:
# Display order of the dendritic-cell labels; reorder_categories requires this
# list to match the labels present in the column exactly.
current_order = ['Other Dendritic cells', 'cDC1', 'cDC2']
sdata.obs['level_2'] = (
    sdata.obs['level_2']
    .astype('category')
    .cat.reorder_categories(current_order)
)

In [None]:
# tab10 colours, one per label in current_order, with a labelled preview.
n_types = len(current_order)
current_pl = sns.color_palette('tab10', n_colors=n_types)
current_cmap = ListedColormap(current_pl.as_hex())
sns.palplot(current_pl)
plt.xticks(ticks=range(n_types), labels=current_order, size=10, rotation=45)
plt.tight_layout()
# plt.savefig(os.path.join(fig_path, 'level_2_palette.pdf'))
plt.show()

In [None]:
# Per-sample normalisation constants, listed in the order of the 'sample' categories.
total_areas = [429, 305, 362, 401, 305, 372]
sample_categories = sdata.obs['sample'].cat.categories
# zip() would silently truncate, and pair values with the wrong samples, when the
# number of categories differs from the number of constants (for example if this
# subset contains no cells from one sample); fail loudly instead.
if len(sample_categories) != len(total_areas):
    raise ValueError(
        f"Expected {len(total_areas)} sample categories, "
        f"found {len(sample_categories)}: {list(sample_categories)}"
    )
total_areas_dict = dict(zip(sample_categories, total_areas))

# Cells per (level_2, sample) pair; computed once and reused.
pair_counts = sdata.obs.groupby('level_2')['sample'].value_counts()

# Builtin float replaces the np.float alias, which was removed in NumPy 1.24.
leiden_df = pd.DataFrame({'counts': pair_counts.values.astype(float)})
leiden_df['level_2'] = [pair[0] for pair in pair_counts.index]
leiden_df['sample'] = pd.Categorical([pair[1] for pair in pair_counts.index])
leiden_df['n_fovs'] = leiden_df['sample'].astype(object).map(total_areas_dict).astype(float)
leiden_df['counts_per_fov'] = leiden_df['counts'] / leiden_df['n_fovs']

sns.barplot(x='level_2', y='counts_per_fov', hue='sample', data=leiden_df)
plt.xticks(rotation=45)
plt.show()

In [None]:
# Cells per condition within this subset: the denominator of the percentages.
# The first column is taken positionally because its name depends on the pandas
# version ('count' in pandas >= 2, the original column name before that).
total_counts = pd.DataFrame(sdata.obs['condition'].value_counts())
total_counts_dict = total_counts.iloc[:, 0].to_dict()

# Cells per (level_2, condition) pair; computed once and reused.
pair_counts = sdata.obs.groupby('level_2')['condition'].value_counts()

# Builtin float replaces the np.float alias, which was removed in NumPy 1.24.
leiden_df = pd.DataFrame({'counts': pair_counts.values.astype(float)})
leiden_df['level_2'] = [pair[0] for pair in pair_counts.index]
leiden_df['condition'] = pd.Categorical([pair[1] for pair in pair_counts.index])
leiden_df['total_counts'] = leiden_df['condition'].astype(object).map(total_counts_dict).astype(float)
leiden_df['percentage'] = leiden_df['counts'] / leiden_df['total_counts'] * 100

sns.barplot(x='level_2', y='percentage', hue='condition', data=leiden_df)
plt.xticks(rotation=45)
plt.show()

In [None]:
# Raw cell counts per (level_2, condition) pair.
pair_counts = sdata.obs.groupby('level_2')['condition'].value_counts()
leiden_df = pd.DataFrame({'counts': pair_counts.values})
leiden_df['level_2'] = [level for level, _ in pair_counts.index]
leiden_df['condition'] = pd.Categorical([condition for _, condition in pair_counts.index])

sns.barplot(data=leiden_df, x='level_2', y='counts', hue='condition')
plt.xticks(rotation=45)
plt.show()

In [None]:
# Marker genes per assigned label ('NA' skipped). A label with no matching rows in
# the annotation table maps to an empty list.
is_dc_marker = gene_annotation['Level_2_binary_dendritic_cells'] == True
selected_gene_dict = {
    cell_type: gene_annotation.loc[
        is_dc_marker & (gene_annotation['Level_2_annotation_dendritic_cells'] == cell_type),
        'Gene',
    ].to_list()
    for cell_type in current_order
    if cell_type != 'NA'
}

selected_gene_dict

In [None]:
# Marker expression per assigned dendritic-cell label, from .X (use_raw=False).
shared_plot_kwargs = dict(groupby='level_2', dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4)
sc.pl.heatmap(sdata, selected_gene_dict, figsize=(3, 15), **shared_plot_kwargs)
sc.pl.matrixplot(sdata, selected_gene_dict, swap_axes=True, **shared_plot_kwargs)

In [None]:
# Spatial map per sample: all cells in grey, dendritic-cell sub-types coloured on top.
for current_sample in sdata.obs['sample'].unique():
    print(current_sample)
    current_complete_obs = cdata.obs[cdata.obs['sample'] == current_sample]
    current_obs = sdata.obs[sdata.obs['sample'] == current_sample]

    # Figure size follows the sample's coordinate extent (max / 1000).
    fig_size = (
        current_complete_obs['global_x'].max() / 1000,
        current_complete_obs['global_y'].max() / 1000,
    )
    fig, ax = plt.subplots(figsize=fig_size)
    sns.scatterplot(data=current_complete_obs, x='global_x', y='global_y', color='#dbdbdb', s=1, linewidth=0, ax=ax)
    sns.scatterplot(data=current_obs, x='global_x', y='global_y', hue='level_2', palette='tab10', s=5, linewidth=0, ax=ax)
    # plt.savefig(os.path.join(output_path, f"sct_{current_sample}.png"))
    plt.show()

In [None]:
# Map to original obj.
# Cast back to plain object dtype first so new label strings can be assigned.
# Presumably needed because write_h5ad converted the column to categorical — confirm.
cdata.obs['level_2'] = cdata.obs['level_2'].astype(object)
rdata.obs['level_2'] = rdata.obs['level_2'].astype(object)

# Copy the dendritic-cell sub-type labels to both parent objects, matched by cell index.
cdata.obs.loc[sdata.obs.index, 'level_2'] = sdata.obs['level_2'].values
rdata.obs.loc[sdata.obs.index, 'level_2'] = sdata.obs['level_2'].values
cdata.obs['level_2'].unique()

In [None]:
# Backup after dendritic-cell sub-typing, prefixed with today's date.
# NOTE(review): same file names as the other level-2 backup cells, so files
# written earlier on the same day are overwritten.
from datetime import datetime
date = datetime.today().strftime('%Y-%m-%d')
rdata.write_h5ad(f"{expr_path}/{date}-combined-level2-region123-bk.h5ad")
cdata.write_h5ad(f"{expr_path}/{date}-combined-level2-bk.h5ad")

## Level 3

In [None]:
# cdata = sc.read_h5ad(os.path.join(expr_path, f'2024-03-04-combined-level2-bk.h5ad'))
# rdata = sc.read_h5ad(os.path.join(expr_path, f'2024-03-04-combined-level2-region23-bk.h5ad'))

In [None]:
# Start level_3 in rdata as an object-dtype copy of level_2; subtypes are filled in below.
rdata.obs['level_3'] = rdata.obs['level_2'].astype(object)

In [None]:
# Start level_3 in cdata as an object-dtype copy of level_2; subtypes are filled in below.
cdata.obs['level_3'] = cdata.obs['level_2'].astype(object)

### CD4+/CD8+ T cells

In [None]:
# Load the gene annotation table, indexed by gene name (the 'Gene' column is kept as well).
gene_annotation = pd.read_csv(os.path.join(base_path, 'documents', 'gene_annotation.csv'))
gene_annotation = gene_annotation.set_index('Gene', drop=False)

# Rows flagged for the T-cell panel: the genes to cluster on and the distinct annotations.
t_cell_marker_rows = gene_annotation.loc[gene_annotation['Level_2_binary_t_cells'] == True, :]
selected_genes = t_cell_marker_rows['Gene'].to_list()
current_order = t_cell_marker_rows['Level_2_annotation_t_cells'].unique()

print(f"Selected genes: {len(selected_genes)}")
print(current_order)

In [None]:
# T-cell panel marker genes grouped by their annotation, in the order of current_order.
is_t_cell_marker = gene_annotation['Level_2_binary_t_cells'] == True
selected_gene_dict = {
    cell_type: gene_annotation.loc[
        is_t_cell_marker & (gene_annotation['Level_2_annotation_t_cells'] == cell_type),
        'Gene',
    ].to_list()
    for cell_type in current_order
}

selected_gene_dict

In [None]:
# Display the genes flagged for the T-cell panel next to their annotation.
gene_annotation.loc[gene_annotation['Level_2_binary_t_cells'] == True, ['Gene', 'Level_2_annotation_t_cells']]

In [None]:
# Working object for level 3: CD4+/CD8+ T cells of rdata, restricted to the selected marker genes.
# An explicit copy is taken because later cells write new obs columns into sdata; writing
# into an AnnData view would otherwise trigger an implicit copy plus a warning.
sdata = rdata[rdata.obs['level_2'].isin(['CD4+ T cells', 'CD8+ T cells']), selected_genes].copy()
sdata

In [None]:
# Clustering input: the expression matrix of sdata itself.
# NOTE: despite its name, X_pca is sdata.X and not a PCA embedding (the PCA alternative,
# sc.pp.pca(sdata) followed by sdata.obsm['X_pca'], is not used). The name is kept because
# later cells refer to it.
X_pca = sdata.X

# K-means elbow: inertia for k = 2..19. A fixed seed makes the curve reproducible.
k_values = range(2, 20)
inertias = []
for k in k_values:
    kmeans = KMeans(n_clusters=k, random_state=0)
    kmeans.fit(X_pca)
    inertias.append(kmeans.inertia_)

fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(k_values, inertias)
ax.grid(True)
ax.set_title('Elbow curve')
ax.set_xlabel('k (number of clusters)')
ax.set_ylabel('inertia')
plt.show()

In [None]:
# Seed sweep: rerun k-means with 50 different random states and show the marker matrixplot
# of each run, to see how much the clusters depend on the initialisation.
# k is set here explicitly. Previously the obs column name used whatever value k had been
# left at by the elbow loop, while the model was always fitted with 25 clusters.
k = 25
for seed in range(50):
    print(seed)
    kmeans = KMeans(n_clusters=k, random_state=seed).fit(X_pca)
    sdata.obs[f'kmeans{k}'] = kmeans.labels_.astype(str)
    sc.pl.matrixplot(sdata, selected_gene_dict, groupby=f'kmeans{k}', dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4, swap_axes=True)


In [None]:
# Final clustering used for annotation: k-means with 25 clusters and a fixed seed.
k = 25
cluster_key = f'kmeans{k}'
kmeans = KMeans(n_clusters=k, random_state=43)
kmeans.fit(X_pca)
sdata.obs[cluster_key] = kmeans.labels_.astype(str)

# Marker overview per cluster: scaled values (bwr, -4..4) first, then raw values.
per_cluster = dict(groupby=cluster_key, dendrogram=False)
scaled_bwr = dict(use_raw=False, cmap='bwr', vmin=-4, vmax=4)
# sc.pl.pca(sdata, color=[f'kmeans{k}'])
sc.pl.heatmap(sdata, selected_gene_dict, **per_cluster, **scaled_bwr)
sc.pl.matrixplot(sdata, selected_gene_dict, swap_axes=True, **per_cluster, **scaled_bwr)
sc.pl.matrixplot(sdata, selected_gene_dict, use_raw=True, vmax=2, swap_axes=True, **per_cluster)
sc.pl.dotplot(sdata, selected_gene_dict, use_raw=True, cmap='Reds', swap_axes=True, **per_cluster)

#### Assign cell types

In [None]:
# Initialise level_3 with the k-means cluster labels; the kmeans{k} column itself is left
# untouched, so the raw cluster ids stay available after level_3 is relabelled below.
sdata.obs['level_3'] = sdata.obs[f'kmeans{k}'].values

In [None]:
# Cell type chosen for each k-means cluster; the position in the list is the cluster id.
# The mapping only makes sense for the clustering stored in sdata.obs[f'kmeans{k}'].
level_3_list = [
    'Treg',                # 0
    'Naive CD4+ T cells',  # 1
    'CD8+ T cells',        # 2
    'CD4+ T cells',        # 3
    'Naive CD4+ T cells',  # 4
    'CD4+ T cells',        # 5
    'CD8+ T cells',        # 6
    'CD4+ T cells',        # 7
    'CD8+ T cells',        # 8
    'Naive CD4+ T cells',  # 9
    'CD4+ T cells',        # 10
    'Naive CD4+ T cells',  # 11
    'PD-1+ T cells',       # 12
    'CD8+ T cells',        # 13
    'CD8+ T cells',        # 14
    'CD8+ T cells',        # 15
    'Treg',                # 16
    'Naive CD8+ T cells',  # 17
    'Naive CD4+ T cells',  # 18
    'CD8+ T cells',        # 19
    'Th1',                 # 20
    'Th17',                # 21
    'CD4+ T cells',        # 22
    'Th2',                 # 23
    'Naive CD8+ T cells',  # 24
]

# cluster label (stored as a string) -> cell type, for every cluster present in sdata
transfer_dict_l3 = {
    cluster_label: level_3_list[int(cluster_label)]
    for cluster_label in sorted(sdata.obs[f'kmeans{k}'].unique())
}

In [None]:
# Assign cell type to sdata: replace the cluster labels in the level_3 column with the
# cell types from transfer_dict_l3 (only that column is targeted by the nested dict).
sdata.obs = sdata.obs.replace({'level_3': transfer_dict_l3})

In [None]:
# Display order of the level-3 T-cell types. reorder_categories raises if this list does
# not match the labels present in level_3 exactly.
current_order = [
    'CD8+ T cells',
    'CD4+ T cells',
    'Treg',
    'Th1',
    'Th2',
    'Th17',
    'Naive CD8+ T cells',
    'Naive CD4+ T cells',
    'PD-1+ T cells',
]
sdata.obs['level_3'] = sdata.obs['level_3'].astype('category').cat.reorder_categories(current_order)

In [None]:
# One tab20 colour per level-3 type, shown as a labelled swatch.
n_types = len(current_order)
current_pl = sns.color_palette('tab20', n_types)
current_cmap = ListedColormap(current_pl.as_hex())

sns.palplot(current_pl)
plt.xticks(range(n_types), current_order, size=10, rotation=45)
plt.tight_layout()
# plt.savefig(os.path.join(fig_path, 'level_2_palette.pdf'))
plt.show()

In [None]:
# Per-sample normalisation factors, paired positionally with the `sample` categories.
# NOTE(review): presumably the number of FOVs per sample (the column is named n_fovs);
# confirm the values and their order against the acquisition metadata.
total_areas = [429, 305, 362, 401, 305, 372]
total_areas_dict = dict(zip(sdata.obs['sample'].cat.categories, total_areas))

# Cells per (level_3, sample), computed once and reused for all three columns.
sample_counts = sdata.obs.groupby('level_3')['sample'].value_counts()
leiden_df = pd.DataFrame(sample_counts.values)
leiden_df.columns = ['counts']
leiden_df['level_3'] = [key[0] for key in sample_counts.index]
leiden_df['sample'] = [key[1] for key in sample_counts.index]
leiden_df['sample'] = leiden_df['sample'].astype('category')
leiden_df['n_fovs'] = leiden_df['sample'].values
leiden_df['n_fovs'] = leiden_df['n_fovs'].map(total_areas_dict)
# builtin float instead of np.float: the alias was removed in NumPy 1.24
leiden_df['counts'] = leiden_df['counts'].astype(float)
leiden_df['n_fovs'] = leiden_df['n_fovs'].astype(float)
leiden_df['counts_per_fov'] = leiden_df['counts'] / leiden_df['n_fovs']

sns.barplot(x='level_3', y='counts_per_fov', hue='sample', data=leiden_df, )
plt.xticks(rotation=45)
plt.legend(loc='upper right')
# plt.savefig(os.path.join(output_path, f"cell_type_composition_condition.png"))
plt.show()

In [None]:
# Total number of cells in sdata per condition, the denominator for the percentages.
# The dict is built from the Series directly, so it does not depend on the name pandas
# gives the value_counts column.
condition_totals = sdata.obs['condition'].value_counts()
total_counts = pd.DataFrame(condition_totals)
total_counts_dict = condition_totals.to_dict()

# Cells per (level_3, condition), computed once and reused for all three columns.
condition_counts = sdata.obs.groupby('level_3')['condition'].value_counts()
leiden_df = pd.DataFrame(condition_counts.values)
leiden_df.columns = ['counts']
leiden_df['level_3'] = [key[0] for key in condition_counts.index]
leiden_df['condition'] = [key[1] for key in condition_counts.index]
leiden_df['condition'] = leiden_df['condition'].astype('category')
leiden_df['total_counts'] = leiden_df['condition'].values
leiden_df['total_counts'] = leiden_df['total_counts'].map(total_counts_dict)
# builtin float instead of np.float: the alias was removed in NumPy 1.24
leiden_df['counts'] = leiden_df['counts'].astype(float)
leiden_df['total_counts'] = leiden_df['total_counts'].astype(float)
# share of each level-3 type among all sdata cells of the same condition
leiden_df['percentage'] = leiden_df['counts'] / leiden_df['total_counts'] * 100

sns.barplot(x='level_3', y='percentage', hue='condition', data=leiden_df)
plt.xticks(rotation=45)
# plt.savefig(os.path.join(output_path, f"cell_type_composition_condition.png"))
plt.show()

In [None]:
# Raw (unnormalised) number of cells per level-3 type and condition.
condition_counts = sdata.obs.groupby('level_3')['condition'].value_counts()
leiden_df = pd.DataFrame({
    'counts': condition_counts.values,
    'level_3': [key[0] for key in condition_counts.index],
    'condition': [key[1] for key in condition_counts.index],
}).astype({'condition': 'category'})

sns.barplot(data=leiden_df, x='level_3', y='counts', hue='condition')
plt.xticks(rotation=45)
plt.show()

In [None]:
# Hand-picked marker genes per T-cell group; this replaces the annotation-derived
# selected_gene_dict and is what the level_3 plots below display.
selected_gene_dict = {'T cells': ['Cd3d', 'Cd3e', 'Cd3g'],
 'CD8 T cells': ['Cd8a'],
 'CD4 T cells': ['Cd4'],
 'Treg': ['Foxp3', 'Il2ra'],
 'Th1': ['Ifng', 'Tbx21'],
 'Th2': ['Il4'],
 'Th17': ['Il17a'],
 'Naive T cells': ["Sell", "Ccr7", "Lef1"],
 'PD-1+ T cells': ['Pdcd1']}

In [None]:
# Marker expression per level-3 type: scaled heatmap and matrixplot, then a raw-value dotplot.
per_level_3 = dict(groupby='level_3', dendrogram=False)
scaled_bwr = dict(use_raw=False, cmap='bwr', vmin=-4, vmax=4)
sc.pl.heatmap(sdata, selected_gene_dict, figsize=(3, 15), **per_level_3, **scaled_bwr)
sc.pl.matrixplot(sdata, selected_gene_dict, swap_axes=True, figsize=(5, 6), **per_level_3, **scaled_bwr)
sc.pl.dotplot(sdata, selected_gene_dict, use_raw=True, cmap='Reds', swap_axes=True, **per_level_3)

In [None]:
# Spatial overview per sample: every cell of cdata in grey as background, the cells of
# sdata drawn on top and coloured by their level-3 label.
for current_sample in sdata.obs['sample'].unique():
    print(current_sample)
    current_complete_obs = cdata.obs.loc[cdata.obs['sample'] == current_sample, :]
    current_obs = sdata.obs.loc[(sdata.obs['sample'] == current_sample), :]

    # figure size follows the sample extent (largest global coordinate / 1000)
    fig_size = np.array([current_complete_obs['global_x'].max(), current_complete_obs['global_y'].max()]) / 1000
    fig, ax = plt.subplots(figsize=fig_size)
    sns.scatterplot(x='global_x', y='global_y', data=current_complete_obs, color='#dbdbdb', s=1, linewidth=0, ax=ax)
    # `s` sets the marker size. `size` is seaborn's semantic size mapping: passing a
    # constant there does not set the marker area and adds a spurious legend entry.
    sns.scatterplot(x='global_x', y='global_y', hue='level_3', data=current_obs, palette='tab20', s=10, linewidth=0, ax=ax)
    # plt.savefig(os.path.join(output_path, f"sct_{current_sample}.png"))
    plt.show()

In [None]:
# Copy the level-3 labels of sdata back into the parent objects (cdata first, then rdata).
refined_labels = sdata.obs['level_3'].values
for parent in (cdata, rdata):
    parent.obs.loc[sdata.obs.index, 'level_3'] = refined_labels

In [None]:
# backup: write dated snapshots of both objects after the level-3 annotation.
# rdata goes to the "region123" file, cdata to the plain "combined-level3" file.
from datetime import datetime
date = datetime.today().strftime('%Y-%m-%d')  # YYYY-MM-DD, used as the filename prefix
rdata.write_h5ad(f"{expr_path}/{date}-combined-level3-region123-bk.h5ad")
cdata.write_h5ad(f"{expr_path}/{date}-combined-level3-bk.h5ad")

In [None]:
# Reset level_3 in sdata to the raw k-means cluster labels ahead of a second annotation pass.
# NOTE(review): this overwrites the T-cell subtype labels assigned to sdata above (they were
# already copied into cdata/rdata and written to disk). Confirm that re-annotating the same
# sdata with the ETP/DN labels below is intended for this dataset.
sdata.obs['level_3'] = sdata.obs[f'kmeans{k}'].values

In [None]:
# Label chosen for each k-means cluster in the second pass; the position in the list is
# the cluster id. 'Undefined type' clusters are resolved later by label transfer.
level_3_list = [
    'ETP',             # 0
    'ETP',             # 1
    'ETP',             # 2
    'ETP',             # 3
    'ETP',             # 4
    'DN2',             # 5
    'DN3',             # 6
    'DN3',             # 7
    'Undefined type',  # 8
    'ETP',             # 9
    'DN3',             # 10
    'ETP',             # 11
    'Undefined type',  # 12
    'DN3',             # 13
    'ETP',             # 14
    'Undefined type',  # 15
    'DN3',             # 16
    'Bcl-2+ DN',       # 17
    'Undefined type',  # 18
    'DN2',             # 19
    'DN3',             # 20
    'ETP',             # 21
    'ETP',             # 22
    'Undefined type',  # 23
    'ETP',             # 24
]

# cluster label (stored as a string) -> label, for every cluster present in sdata
transfer_dict_l3 = {
    cluster_label: level_3_list[int(cluster_label)]
    for cluster_label in sorted(sdata.obs[f'kmeans{k}'].unique())
}

In [None]:
# Assign cell type to sdata: replace the cluster labels in the level_3 column with the
# labels from transfer_dict_l3 (only that column is targeted by the nested dict).
sdata.obs = sdata.obs.replace({'level_3': transfer_dict_l3})

In [None]:
# Category order for the second-pass labels, still including the 'Undefined type' placeholder.
current_order = [
    'ETP',
    'DN2',
    'DN3',
    'Bcl-2+ DN',
    'Undefined type',
]
sdata.obs['level_3'] = sdata.obs['level_3'].astype('category').cat.reorder_categories(current_order)

In [None]:
# Convert level_3 from categorical back to plain object dtype.
# NOTE(review): presumably so the label-transfer cells below can compare and overwrite
# labels without category constraints; confirm.
sdata.obs['level_3'] = sdata.obs['level_3'].astype(object)

In [None]:
# Compute a PCA of sdata in place; label_transfer(embedding='pca') below reads the
# resulting sdata.obsm['X_pca'].
sc.pp.pca(sdata)

In [None]:
def label_transfer(adata, embedding='pca', field='level_2_code', metric='cosine', n_neighbors=100,
                   label_col='level_3', undefined_label='Undefined type'):
    """Predict labels for unlabelled cells from their nearest labelled neighbours.

    Cells whose ``adata.obs[label_col]`` equals ``undefined_label`` form the query set;
    all remaining cells form the reference set. A k-nearest-neighbour classifier is
    fitted on the reference cells' coordinates in ``adata.obsm[f'X_{embedding}']`` with
    ``adata.obs[field]`` as target, and is then used to predict the query cells.
    ``adata`` itself is not modified.

    Parameters
    ----------
    adata : AnnData
        Object holding the embedding in ``obsm`` and the label columns in ``obs``.
    embedding : str
        Embedding name; the coordinates are read from ``obsm[f'X_{embedding}']``.
    field : str
        ``obs`` column whose reference values are transferred to the query cells.
    metric : str
        Distance metric passed to ``KNeighborsClassifier``.
    n_neighbors : int
        Number of neighbours passed to ``KNeighborsClassifier``.
    label_col : str
        ``obs`` column used to split cells into reference and query
        (default ``'level_3'``, the column that used to be hard-coded).
    undefined_label : str
        Value of ``label_col`` that marks a query cell
        (default ``'Undefined type'``, the value that used to be hard-coded).

    Returns
    -------
    numpy.ndarray
        One predicted label per query cell, in the order in which the query cells
        appear in ``adata.obs``. Empty if there are no query cells.

    Raises
    ------
    ValueError
        If there are query cells but no reference cells to learn from.
    """
    # local import: sklearn.neighbors is not part of the notebook's top-level imports
    from sklearn.neighbors import KNeighborsClassifier

    # split cells into labelled reference and unlabelled query
    ref_cells = adata.obs.loc[adata.obs[label_col] != undefined_label, :].index
    query_cells = adata.obs.loc[adata.obs[label_col] == undefined_label, :].index

    # nothing to predict: return an empty result instead of failing inside sklearn
    if len(query_cells) == 0:
        return np.array([], dtype=object)
    if len(ref_cells) == 0:
        raise ValueError(
            f"label_transfer: every cell has {label_col!r} == {undefined_label!r}; "
            "no reference cells to transfer labels from"
        )

    # coordinates of both groups in the chosen embedding
    ref_cell_loc = adata[ref_cells, :].obsm[f'X_{embedding}']
    query_cell_loc = adata[query_cells, :].obsm[f'X_{embedding}']

    # reference annotation to transfer
    ref_cell_annot = adata.obs.loc[ref_cells, field].values

    neigh = KNeighborsClassifier(n_neighbors=n_neighbors, metric=metric)
    neigh.fit(ref_cell_loc, ref_cell_annot)
    query_cell_predicted = neigh.predict(query_cell_loc)

    return query_cell_predicted

In [None]:
# parameters: number of neighbours passed to label_transfer below
n_neighbors = 50

In [None]:
# create new label columns: level_3_cdhp starts as a copy of level_3; the cells that are
# 'Undefined type' are overwritten with the transferred labels further below.
sdata.obs['level_3_cdhp'] = sdata.obs['level_3'].values

In [None]:
# conduct label transfer: predict a level_3 label for every 'Undefined type' cell from
# its n_neighbors nearest labelled cells in PCA space (cosine distance).
predicted_label_cdhp = label_transfer(sdata, embedding='pca', field='level_3', metric='cosine', n_neighbors=n_neighbors)

In [None]:
# Write the predicted labels into sdata (column level_3_cdhp), not into cdata.
# The query selection repeats the one used inside label_transfer, so the predictions
# line up with query_cells row for row.
query_cells = sdata.obs.loc[sdata.obs['level_3'] == 'Undefined type', :].index

sdata.obs.loc[query_cells, 'level_3_cdhp'] = predicted_label_cdhp

In [None]:
# Original label (rows) against label after transfer (columns); the 'Undefined type' row
# shows where the previously unlabelled cells ended up.
label_transfer_counts = pd.crosstab(sdata.obs['level_3'], sdata.obs['level_3_cdhp'])
fig, ax = plt.subplots(figsize=(10, 5))
sns.heatmap(label_transfer_counts, annot=True, fmt='g')

In [None]:
# Adopt the transferred labels as the final level_3 annotation of sdata.
sdata.obs['level_3'] = sdata.obs['level_3_cdhp']

In [None]:
# Category order after label transfer. 'Undefined type' is no longer listed, so
# reorder_categories raises if any cell still carries it.
current_order = [
    'ETP',
    'DN2',
    'DN3',
    'Bcl-2+ DN',
]
sdata.obs['level_3'] = sdata.obs['level_3'].astype('category').cat.reorder_categories(current_order)

In [None]:
# One tab20 colour per level-3 type, shown as a labelled swatch.
n_types = len(current_order)
current_pl = sns.color_palette('tab20', n_types)
current_cmap = ListedColormap(current_pl.as_hex())

sns.palplot(current_pl)
plt.xticks(range(n_types), current_order, size=10, rotation=45)
plt.tight_layout()
# plt.savefig(os.path.join(fig_path, 'level_2_palette.pdf'))
plt.show()

In [None]:
# Per-sample normalisation factors, paired positionally with the `sample` categories.
# NOTE(review): presumably the number of FOVs per sample (the column is named n_fovs).
# These values differ from the ones used for the same samples in the T-cell section
# above; confirm which set applies to this dataset.
total_areas = [262, 243, 271, 261, 273, 250]
total_areas_dict = dict(zip(sdata.obs['sample'].cat.categories, total_areas))

# Cells per (level_3, sample), computed once and reused for all three columns.
sample_counts = sdata.obs.groupby('level_3')['sample'].value_counts()
leiden_df = pd.DataFrame(sample_counts.values)
leiden_df.columns = ['counts']
leiden_df['level_3'] = [key[0] for key in sample_counts.index]
leiden_df['sample'] = [key[1] for key in sample_counts.index]
leiden_df['sample'] = leiden_df['sample'].astype('category')
leiden_df['n_fovs'] = leiden_df['sample'].values
leiden_df['n_fovs'] = leiden_df['n_fovs'].map(total_areas_dict)
# builtin float instead of np.float: the alias was removed in NumPy 1.24
leiden_df['counts'] = leiden_df['counts'].astype(float)
leiden_df['n_fovs'] = leiden_df['n_fovs'].astype(float)
leiden_df['counts_per_fov'] = leiden_df['counts'] / leiden_df['n_fovs']

sns.barplot(x='level_3', y='counts_per_fov', hue='sample', data=leiden_df)
plt.xticks(rotation=45)
# plt.savefig(os.path.join(output_path, f"cell_type_composition_condition.png"))
plt.show()

In [None]:
# Total number of cells in sdata per condition, the denominator for the percentages.
# The dict is built from the Series directly, so it does not depend on the name pandas
# gives the value_counts column.
condition_totals = sdata.obs['condition'].value_counts()
total_counts = pd.DataFrame(condition_totals)
total_counts_dict = condition_totals.to_dict()

# Cells per (level_3, condition), computed once and reused for all three columns.
condition_counts = sdata.obs.groupby('level_3')['condition'].value_counts()
leiden_df = pd.DataFrame(condition_counts.values)
leiden_df.columns = ['counts']
leiden_df['level_3'] = [key[0] for key in condition_counts.index]
leiden_df['condition'] = [key[1] for key in condition_counts.index]
leiden_df['condition'] = leiden_df['condition'].astype('category')
leiden_df['total_counts'] = leiden_df['condition'].values
leiden_df['total_counts'] = leiden_df['total_counts'].map(total_counts_dict)
# builtin float instead of np.float: the alias was removed in NumPy 1.24
leiden_df['counts'] = leiden_df['counts'].astype(float)
leiden_df['total_counts'] = leiden_df['total_counts'].astype(float)
# share of each level-3 type among all sdata cells of the same condition
leiden_df['percentage'] = leiden_df['counts'] / leiden_df['total_counts'] * 100

sns.barplot(x='level_3', y='percentage', hue='condition', data=leiden_df)
plt.xticks(rotation=45)
# plt.savefig(os.path.join(output_path, f"cell_type_composition_condition.png"))
plt.show()

In [None]:
# Raw (unnormalised) number of cells per level-3 type and condition.
condition_counts = sdata.obs.groupby('level_3')['condition'].value_counts()
leiden_df = pd.DataFrame({
    'counts': condition_counts.values,
    'level_3': [key[0] for key in condition_counts.index],
    'condition': [key[1] for key in condition_counts.index],
}).astype({'condition': 'category'})

sns.barplot(data=leiden_df, x='level_3', y='counts', hue='condition')
plt.xticks(rotation=45)
plt.show()

In [None]:
# Marker expression per level-3 type: scaled heatmap and matrixplot, then a raw-value dotplot.
per_level_3 = dict(groupby='level_3', dendrogram=False)
scaled_bwr = dict(use_raw=False, cmap='bwr', vmin=-4, vmax=4)
sc.pl.heatmap(sdata, selected_gene_dict, figsize=(3, 15), **per_level_3, **scaled_bwr)
sc.pl.matrixplot(sdata, selected_gene_dict, swap_axes=True, figsize=(5, 6), **per_level_3, **scaled_bwr)
sc.pl.dotplot(sdata, selected_gene_dict, use_raw=True, cmap='Reds', swap_axes=True, **per_level_3)

In [None]:
# Same marker matrixplot with per-gene rescaling (standard_scale='var').
# NOTE(review): standard_scale maps each gene to the 0..1 range, so with vmin=-1 / vmax=1
# only the upper half of the 'bwr' colormap would be used; confirm the limits are intended.
sc.pl.matrixplot(sdata, selected_gene_dict, groupby=f'level_3', dendrogram=False, use_raw=False, cmap='bwr', vmin=-1, vmax=1, swap_axes=True, standard_scale='var', figsize=(5,6))

In [None]:
# Spatial overview per sample: every cell of cdata in grey as background, the cells of
# sdata drawn on top and coloured by their level-3 label.
for current_sample in sdata.obs['sample'].unique():
    print(current_sample)
    current_complete_obs = cdata.obs.loc[cdata.obs['sample'] == current_sample, :]
    current_obs = sdata.obs.loc[(sdata.obs['sample'] == current_sample), :]

    # figure size follows the sample extent (largest global coordinate / 1000)
    fig_size = np.array([current_complete_obs['global_x'].max(), current_complete_obs['global_y'].max()]) / 1000
    fig, ax = plt.subplots(figsize=fig_size)
    sns.scatterplot(x='global_x', y='global_y', data=current_complete_obs, color='#dbdbdb', s=1, linewidth=0, ax=ax)
    # `s` sets the marker size. `size` is seaborn's semantic size mapping: passing a
    # constant there does not set the marker area and adds a spurious legend entry.
    sns.scatterplot(x='global_x', y='global_y', hue='level_3', data=current_obs, palette='tab20', s=10, linewidth=0, ax=ax)
    # plt.savefig(os.path.join(output_path, f"sct_{current_sample}.png"))
    plt.show()

In [None]:
# Copy the level-3 labels of sdata back into the parent objects.
# The column is cast to plain object dtype first so that labels which are not existing
# categories can be written.
for parent in (cdata, rdata):
    parent.obs['level_3'] = parent.obs['level_3'].astype(object)

refined_labels = sdata.obs['level_3'].values
for parent in (cdata, rdata):
    parent.obs.loc[sdata.obs.index, 'level_3'] = refined_labels

In [None]:
# backup: write dated snapshots of both objects after the second level-3 pass.
# NOTE(review): the rdata file is tagged "region23" here, whereas the earlier backups in
# this notebook use "region123"; confirm which regions rdata actually contains.
# NOTE(review): the cdata path is the same as in the previous level-3 backup cell, so
# running both cells on the same day overwrites that file.
from datetime import datetime
date = datetime.today().strftime('%Y-%m-%d')  # YYYY-MM-DD, used as the filename prefix
rdata.write_h5ad(f"{expr_path}/{date}-combined-level3-region23-bk.h5ad")
cdata.write_h5ad(f"{expr_path}/{date}-combined-level3-bk.h5ad")

In [None]:
# Contingency table of level-2 labels (rows) against level-3 labels (columns) in sdata;
# displayed and plotted in the following cells.
a = pd.crosstab(sdata.obs['level_2'], sdata.obs['level_3'])

In [None]:
a

In [None]:
# Annotated heatmap of the level-2 / level-3 contingency table.
fig, ax = plt.subplots(figsize=(10, 5))
sns.heatmap(a, annot=True, fmt='g', ax=ax)
plt.show()

## Test

In [None]:
# Level-2 labels (rows) against level-3 labels (columns) as an annotated heatmap.
level_2_vs_level_3 = pd.crosstab(sdata.obs['level_2'], sdata.obs['level_3'])
fig, ax = plt.subplots(figsize=(10, 5))
sns.heatmap(level_2_vs_level_3, annot=True, fmt='g')

In [None]:
rdata

In [None]:
# Genes shown in the dendritic-cell plots below.
dc_genes = ['H2-Aa', 'Cd40', 'Cd83', 'Cd86', 'H2-K1', 'Ccl17', 'Ccl22', 'Ccl25']

In [None]:
rdata

In [None]:
# Working object for the dendritic-cell comparison: all cells of rdata labelled
# 'Dendritic cells' at level 1, with all genes kept.
# An explicit copy is taken because a later cell adds an obs column to sdata; writing
# into an AnnData view would otherwise trigger an implicit copy plus a warning.
sdata = rdata[rdata.obs['level_1'] == 'Dendritic cells', :].copy()
sdata

In [None]:
# Flat list of CD3+ marker genes. The grouped form {'CD3+': [...]} that used to be
# assigned just before this was overwritten immediately and never read, so only the
# effective assignment is kept.
# NOTE: the variable is a list despite its name; the name is kept for compatibility.
selected_gene_dict = ['Cd3d', 'Cd3e', 'Cd3g', 'Ccr7', 'Cd4', 'Cd8a']

In [None]:
# Display order of the dendritic-cell level-2 labels ...
level_2_order = [
    'Other Dendritic cells',
    'cDC1',
    'cDC2',
]

# ... and of the conditions; both are combined into "<level_2>_<condition>" groups below.
condition_order = ['WT', '99R', '33NM']

In [None]:
# Every "<level_2>_<condition>" combination, cell type varying slowest.
level_2_condition_order = [
    f"{cell_type}_{condition}"
    for cell_type in level_2_order
    for condition in condition_order
]

# Combined grouping column, ordered as above. reorder_categories raises if the observed
# combinations do not match the list exactly.
sdata.obs['level_2_condition'] = (
    (sdata.obs['level_2'].astype(str) + '_' + sdata.obs['condition'].astype(str))
    .astype('category')
    .cat.reorder_categories(level_2_condition_order)
)

In [None]:
# Dotplot of the DC genes per level_2 x condition group, using the values stored in .raw.
sc.pl.dotplot(sdata, dc_genes, groupby=f'level_2_condition', dendrogram=False, use_raw=True, cmap='viridis', swap_axes=True)

In [None]:
# Matrixplot of the DC genes per level_2 x condition group, using .X (use_raw=False)
# with colour limits clipped to [-1, 1].
sc.pl.matrixplot(sdata, dc_genes, groupby=f'level_2_condition', dendrogram=False, use_raw=False, cmap='bwr', vmin=-1, vmax=1, swap_axes=True)

In [None]:
# For each level-2 type separately: DC gene expression by condition, as a raw-value
# dotplot and as a matrixplot of .X.
by_condition = dict(groupby='condition', dendrogram=False, swap_axes=True)
for cell_type in sdata.obs['level_2'].cat.categories:
    print(cell_type)
    pdata = sdata[sdata.obs['level_2'] == cell_type, :]
    sc.pl.dotplot(pdata, dc_genes, use_raw=True, cmap='viridis', **by_condition)
    sc.pl.matrixplot(pdata, dc_genes, use_raw=False, cmap='bwr', vmin=-1, vmax=1, **by_condition)

In [None]:
# Matrixplot of the DC genes per level-2 type (the same call is repeated in the next cell).
sc.pl.matrixplot(sdata, dc_genes, groupby=f'level_2', dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4, swap_axes=True, figsize=(5, 6))

In [None]:
# DC genes per level-2 type in three views (heatmap, matrixplot, dotplot), all on .X
# with the same colour scale.
dc_level_2_kwargs = dict(groupby='level_2', dendrogram=False, use_raw=False, cmap='bwr', vmin=-4, vmax=4)
sc.pl.heatmap(sdata, dc_genes, figsize=(3, 15), **dc_level_2_kwargs)
sc.pl.matrixplot(sdata, dc_genes, swap_axes=True, figsize=(5, 6), **dc_level_2_kwargs)
sc.pl.dotplot(sdata, dc_genes, swap_axes=True, figsize=(5, 6), **dc_level_2_kwargs)