In [None]:
# Reload edited modules automatically before each cell runs (autoreload mode 2).
%load_ext autoreload
%autoreload 2
# Render inline matplotlib figures in 'retina' format.
%config InlineBackend.figure_format='retina'

In [None]:
import os

from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import anndata
import scanpy as sc
import umap
import scipy.stats

import torch

from scmg.model.contrastive_embedding import (CellEmbedder,  embed_adata)

from scmg.preprocessing.data_standardization import GeneNameMapper
# NOTE(review): gene_name_mapper is not referenced by any visible cell below —
# confirm it is still needed before keeping this instantiation.
gene_name_mapper = GeneNameMapper()


In [None]:
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

# Global plotting configuration for every figure in this notebook.
# Do not apply matplotlib's automatic tight layout to figures.
plt.rcParams["figure.autolayout"] = False
# Use font type 42 (TrueType) in PDF output.
matplotlib.rc('pdf', fonttype=42)
plt.rcParams['font.family'] = 'FreeSans'
sc.set_figure_params(vector_friendly=True, dpi_save=300)
# Set after sc.set_figure_params — NOTE(review): presumably because that call
# adjusts rcParams and would otherwise re-enable the grid; confirm.
plt.rcParams['axes.grid'] = False

In [None]:
# Folder that receives every figure exported by this notebook; created on
# demand and left untouched when it already exists.
plot_output_path = 'pert_pred_plots'
os.makedirs(name=plot_output_path, exist_ok=True)

In [None]:
# Perturbation datasets combined for this analysis. Entries that are commented
# out are not loaded.
pert_data_files = [
    '/GPUData_xingjie/SCMG/perturbation_data/AdamsonWeissman2016_GSM2406681_10X010.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/FrangiehIzar2021_RNA.h5ad',
    #'/GPUData_xingjie/SCMG/perturbation_data/hESC_TF_screen.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/JiangSatija2024_IFNB.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/JiangSatija2024_IFNG.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/JiangSatija2024_INS.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/JiangSatija2024_TGFB.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/JiangSatija2024_TNFA.h5ad',
    #'/GPUData_xingjie/SCMG/perturbation_data/Joung_TFScreen_HS_2023.h5ad',
    #'/GPUData_xingjie/SCMG/perturbation_data/knockTF_human.h5ad',
    #'/GPUData_xingjie/SCMG/perturbation_data/knockTF_mouse.h5ad',
    #'/GPUData_xingjie/SCMG/perturbation_data/omnipath.h5ad',
    #'/GPUData_xingjie/SCMG/perturbation_data/PertOrg.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/ReplogleWeissman2022_K562_essential.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/ReplogleWeissman2022_K562_gwps.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/ReplogleWeissman2022_rpe1.h5ad',
    #'/GPUData_xingjie/SCMG/perturbation_data/TianKampmann2021_CRISPRa.h5ad',
    '/GPUData_xingjie/SCMG/perturbation_data/TianKampmann2021_CRISPRi.h5ad',
    #'/GPUData_xingjie/SCMG/hESC_perturb_seq/pseudo_bulk.h5ad', # Test
]

# Read each file in turn, printing its base name and row count as a quick
# sanity check of what was loaded.
adata_pert_list = []
for h5ad_path in pert_data_files:
    adata_loaded = sc.read_h5ad(h5ad_path)
    adata_pert_list.append(adata_loaded)
    print(os.path.basename(h5ad_path), adata_loaded.shape[0])

# Stack all datasets along the observation axis, then copy the gene_name
# annotation over from the first dataset's var table.
adata_pert = anndata.concat(adata_pert_list, axis=0)
first_dataset_var = adata_pert_list[0].var
adata_pert.var['gene_name'] = first_dataset_var['gene_name']

adata_pert

In [None]:
# Mask out the direct target genes: for every row, zero the column of the gene
# that was perturbed (when that gene is one of the measured variables) before
# the per-row effect statistics are computed.
var_names = adata_pert.var_names
for row_idx, target_gene in enumerate(adata_pert.obs['perturbed_gene']):
    if target_gene not in var_names:
        continue
    adata_pert.X[row_idx, var_names.get_loc(target_gene)] = 0

# Per-row summaries of the masked profile: Euclidean norm and largest
# absolute entry.
masked_X = adata_pert.X
adata_pert.obs['effect_size'] = np.linalg.norm(masked_X, axis=1)
adata_pert.obs['max_effect'] = np.max(np.abs(masked_X), axis=1)

In [None]:
# Load the gene functional embeddings, keep only perturbations whose gene name
# has an embedding, and store the matching embedding row for each observation.
gene_func_emb_df = pd.read_parquet('gene_func_emb_no_rpe1.parquet')
has_embedding = adata_pert.obs['perturbed_gene_name'].isin(gene_func_emb_df.index)
adata_pert = adata_pert[has_embedding].copy()
embedding_rows = gene_func_emb_df.loc[adata_pert.obs['perturbed_gene_name']]
adata_pert.obsm['gene_func_emb'] = embedding_rows.values

In [None]:
# Show how many rows of adata_pert belong to each value of 'condition'.
adata_pert.obs['condition'].value_counts()

In [None]:
# Independent copies of the three ReplogleWeissman2022 screens, selected by
# their 'condition' label.
condition = adata_pert.obs['condition']
adata_k562_w = adata_pert[condition == 'ReplogleWeissman2022_K562_gwps'].copy()
adata_k562_e = adata_pert[condition == 'ReplogleWeissman2022_K562_essential'].copy()
adata_rpe1 = adata_pert[condition == 'ReplogleWeissman2022_rpe1'].copy()

In [None]:
# Build one averaged perturbation profile per gene from every dataset except
# the RPE1 screen (the "no target" set).
adata_no_target = adata_pert[~adata_pert.obs['condition'].isin(['ReplogleWeissman2022_rpe1'])].copy()

# Gene label of every row, as plain strings so grouping and duplicate
# detection below treat the labels identically.
gene_labels = adata_no_target.obs['perturbed_gene'].astype(str).to_numpy()

# Mean profile of each gene, computed in a single grouped pass instead of
# re-subsetting the AnnData once per row. sort=False keeps the groups in order
# of first appearance, which is the row order of the first-occurrence mask.
X_dtype = adata_no_target.X.dtype
gene_mean_X = (
    pd.DataFrame(np.asarray(adata_no_target.X))
    .groupby(gene_labels, sort=False)
    .mean()
    .to_numpy()
)

# Keep only the first row of each gene; that row supplies the obs/obsm
# annotations of the averaged profile.
i_to_keep = (~pd.Series(gene_labels).duplicated()).tolist()

adata_no_target = adata_no_target[i_to_keep].copy()
adata_no_target.X = gene_mean_X.astype(X_dtype, copy=False)

In [None]:
# Accumulator for the K562-vs-K562 comparison: one entry per shared gene.
k562_intra_comp_dict = {'gene_name': [], 'sim': []}

# Compare the K562 "essential" screen (source) with the K562 "gwps" screen
# (target) on the genes perturbed in both.
adata_source = adata_k562_e.copy()
adata_target = adata_k562_w.copy()
common_genes = np.intersect1d(adata_source.obs['perturbed_gene'], adata_target.obs['perturbed_gene'])

# Index both objects by perturbed gene, then select the shared genes in the
# same order so that row i refers to the same gene in source and target.
adata_source.obs.index = list(adata_source.obs['perturbed_gene'])
adata_target.obs.index = list(adata_target.obs['perturbed_gene'])
adata_source = adata_source[common_genes, :].copy()
adata_target = adata_target[common_genes, :].copy()

def sim_func(v1, v2):
    """Return the correlation-based similarity of two 1-D vectors.

    The value is one minus SciPy's correlation distance between ``v1`` and
    ``v2``.
    """
    # Alternative metric (disabled): 1 - scipy.spatial.distance.cosine(v1, v2)
    correlation_distance = scipy.spatial.distance.correlation(v1, v2)
    return 1 - correlation_distance


# Score every shared gene: similarity between its row in the source matrix and
# the row at the same position in the target matrix.
source_X = adata_source.X
target_X = adata_target.X
for row_idx, gene in enumerate(adata_source.obs['perturbed_gene']):
    k562_intra_comp_dict['gene_name'].append(gene)
    k562_intra_comp_dict['sim'].append(
        sim_func(source_X[row_idx, :], target_X[row_idx, :])
    )


k562_intra_comp_df = pd.DataFrame(k562_intra_comp_dict)
k562_intra_comp_df

In [None]:
# Load the stored zero-shot result and make its 'zero_shot_pred' layer the
# main matrix, so later cells can read predictions from .X.
adata_zero_shot = sc.read_h5ad('zero_shot_pred_RPE1.h5ad')
zero_shot_pred = adata_zero_shot.layers['zero_shot_pred']
adata_zero_shot.X = zero_shot_pred

In [None]:
# Source side of the comparison: the per-gene averaged profiles built from
# every dataset except RPE1. The commented lines are alternative sources that
# are currently disabled (adata_hesc is not defined in the visible cells).
#adata_source = adata_k562_w.copy()
#adata_source = adata_rpe1.copy()
#adata_source = adata_hesc.copy()
adata_source = adata_no_target.copy()

# Target side of the comparison: the RPE1 screen. The commented lines are
# disabled alternatives.
#adata_target = adata_k562_w.copy()
adata_target = adata_rpe1.copy()
#adata_target = adata_hesc.copy()

In [None]:
# Genes perturbed in all three objects: source, target and zero-shot
# prediction. The result is a sorted array of unique gene labels.
common_genes = np.intersect1d(
    np.intersect1d(adata_source.obs['perturbed_gene'], adata_target.obs['perturbed_gene']),
    adata_zero_shot.obs['perturbed_gene'],
)
common_genes.shape

In [None]:
# Index all three objects by perturbed gene ...
adata_source.obs.index = adata_source.obs['perturbed_gene'].tolist()
adata_target.obs.index = adata_target.obs['perturbed_gene'].tolist()
adata_zero_shot.obs.index = adata_zero_shot.obs['perturbed_gene'].tolist()

# ... then select the shared genes in one common order, so that row i refers
# to the same gene in source, target and zero-shot prediction.
adata_source = adata_source[common_genes, :].copy()
adata_target = adata_target[common_genes, :].copy()
adata_zero_shot = adata_zero_shot[common_genes, :].copy()

In [None]:
from sklearn.cluster import KMeans

def kmeans_find_centroid_indices(X, k):
    """Pick representative rows of ``X`` via k-means clustering.

    Fits k-means with ``k`` clusters (``random_state=0``) and, for each cluster
    centre, finds the row of ``X`` closest to it in Euclidean distance.

    Args:
        X: 2-D array with one sample per row.
        k: Number of clusters to fit.

    Returns:
        Sorted array of unique row indices. It can hold fewer than ``k``
        entries when several centres share the same nearest row.
    """
    fitted = KMeans(n_clusters=k, random_state=0).fit(X)
    nearest_rows = [
        np.argmin(np.linalg.norm(X - center, axis=1))
        for center in fitted.cluster_centers_
    ]
    return np.unique(nearest_rows)

from sklearn.linear_model import Ridge

def few_shot_prediction(adata_ref, adata_query, adata_few_shot, alpha=3):
    """Fit a ridge regression on the few-shot examples and predict the query rows.

    The model maps ``adata_ref.obsm['gene_func_emb']`` to ``adata_few_shot.X``;
    the two are fitted row against row, so they must have the same number of
    rows (assumes row i of both refers to the same gene — confirm at the call
    site).

    Args:
        adata_ref: Provides the training features (``obsm['gene_func_emb']``).
        adata_query: Provides the features to predict from. Modified in place:
            predictions are written to ``layers['predicted_X']``.
        adata_few_shot: Provides the training targets (``X``).
        alpha: Ridge regularisation strength.

    Returns:
        None.
    """
    train_features = adata_ref.obsm['gene_func_emb']
    train_targets = adata_few_shot.X
    regressor = Ridge(alpha=alpha).fit(train_features, train_targets)

    query_features = adata_query.obsm['gene_func_emb']
    adata_query.layers['predicted_X'] = regressor.predict(query_features)

In [None]:
# Disabled experiment, kept for reference: sweep over the ridge `alpha` with
# 100 k-means-selected few-shot genes, printing the median / mean / std of the
# cosine similarity between true and predicted profiles for each alpha.
# Nothing in this cell executes.
#
#alphas = [0.5, 1, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0, 2.1, 2.2, 2.3, 2.4, 2.5, 3, 4, 5, 6]
#
#few_shot_ids = kmeans_find_centroid_indices(adata_source.obsm['gene_func_emb'], k=100)
#adata_ref = adata_source[few_shot_ids, :].copy()
#adata_few_shot = adata_target[few_shot_ids, :].copy()
#adata_query = adata_source[~adata_source.obs.index.isin(adata_few_shot.obs.index)].copy()
#adata_true = adata_target[~adata_target.obs.index.isin(adata_few_shot.obs.index)].copy()
#
#
#for alpha in alphas:
#    few_shot_prediction(adata_ref, adata_query, adata_few_shot, alpha=alpha)
#
#    sims = []
#    for i in range(adata_true.shape[0]):
#    
#        v1 = adata_true.X[i, :]
#        v2 = adata_query.layers['predicted_X'][i, :]
#        sims.append(1 - scipy.spatial.distance.cosine(v1, v2))
#
#    print(f'Alpha: {alpha}, Median similarity: {np.median(sims):.4f}, Mean similarity: {np.mean(sims):.4f}, Std: {np.std(sims):.4f}')

In [None]:
# Settings to evaluate. -1 and 0 are sentinels for the two reference
# predictions (source profile used as-is, and the stored zero-shot prediction);
# positive values are the number of k-means clusters used to pick few-shot genes.
k_values = [-1, 0, 5, 10, 20, 50, 100, 200]

# Long-format accumulator: one entry per (k, gene) evaluation.
comp_dict = {
    column: []
    for column in ('k', 'gene_name', 'sim', 'true_effect_size')
}

def sim_func(v1, v2):
    """Return one minus SciPy's correlation distance between ``v1`` and ``v2``.

    NOTE(review): this repeats the ``sim_func`` already defined in the K562
    intra-comparison cell; consider keeping a single definition.
    """
    # Alternative metric (disabled): 1 - scipy.spatial.distance.cosine(v1, v2)
    correlation_distance = scipy.spatial.distance.correlation(v1, v2)
    return 1 - correlation_distance

for k in tqdm(k_values):

    if k in (-1, 0):
        # Reference settings, scored on every common gene: k == -1 uses the
        # source profile itself as the prediction, k == 0 uses the stored
        # zero-shot prediction.
        adata_true = adata_target.copy()
        adata_query = (adata_source if k == -1 else adata_zero_shot).copy()
        adata_query.layers['predicted_X'] = adata_query.X.copy()

    else:
        # Few-shot setting: pick representative genes by k-means on the
        # source embeddings, train on those, and score the remaining genes.
        few_shot_ids = kmeans_find_centroid_indices(adata_source.obsm['gene_func_emb'], k)
        adata_ref = adata_source[few_shot_ids, :].copy()
        adata_few_shot = adata_target[few_shot_ids, :].copy()

        held_out_source = ~adata_source.obs.index.isin(adata_few_shot.obs.index)
        held_out_target = ~adata_target.obs.index.isin(adata_few_shot.obs.index)
        adata_query = adata_source[held_out_source].copy()
        adata_true = adata_target[held_out_target].copy()

        few_shot_prediction(adata_ref, adata_query, adata_few_shot)

    # Score each evaluated gene: similarity between true and predicted row,
    # plus the norm of the true row as its effect size.
    true_X = adata_true.X
    predicted_X = adata_query.layers['predicted_X']
    for row_idx, gene in enumerate(adata_true.obs['perturbed_gene']):
        true_vec = true_X[row_idx, :]
        predicted_vec = predicted_X[row_idx, :]

        comp_dict['k'].append(k)
        comp_dict['gene_name'].append(gene)
        comp_dict['sim'].append(sim_func(true_vec, predicted_vec))
        comp_dict['true_effect_size'].append(np.linalg.norm(true_vec))


comp_df = pd.DataFrame(comp_dict)
    

In [None]:
show_df = comp_df.copy()

# The first two entries of k_values are the reference settings and get
# hand-written labels; the remaining K values label themselves.
tick_labels = ['baseline \ntrain set mean', 'zero shot'] + k_values[2:]

fig, ax = plt.subplots(figsize=(8,4))
sns.violinplot(
    x='k', y='sim', data=show_df,
    order=k_values,
    # optional styling (disabled): inner='quart', fill=False, color='tab:green'
    width=1,
    ax=ax,
)

# Dashed reference line at zero correlation.
ax.axhline(0, color='black', linestyle='--', linewidth=1)
ax.set_ylim(-0.2, 1.0)
ax.set_title('RPE1 prediction', fontsize=16)
ax.set_xlabel('Number of few shots (K)', fontsize=14)
ax.set_ylabel('Correlation', fontsize=14)
ax.set_xticklabels(tick_labels, fontsize=12)

violin_pdf = os.path.join(plot_output_path, 'few_shot_prediction_rpe1.pdf')
fig.savefig(violin_pdf, bbox_inches='tight')

In [None]:
# Compare two settings point by point: k1 (the baseline) and k2 (few-shot).
k1 = -1
k2 = 100
show_df = comp_df[comp_df['k'].isin([k1, k2])].copy()
show_df['color'] = np.where(show_df['k'] == k2, 'red', 'deepskyblue')
# Shuffle the rows (fixed seed) so neither colour is systematically drawn on
# top of the other.
shuffled_df = show_df.sample(frac=1, random_state=42, replace=False)

fig, ax = plt.subplots(figsize=(4, 4))
ax.scatter(
    shuffled_df['true_effect_size'],
    shuffled_df['sim'],
    s=1,
    alpha=1,
    rasterized=True,
    color=shuffled_df['color'],
)

ax.set_xlim(1, 24)
ax.set_ylim(-0.2, 1)

# Add legend to the plot, using one proxy marker per colour.
handles = [
    plt.Line2D([0], [0], marker='o', linestyle='', color=marker_color, markersize=2)
    for marker_color in ('deepskyblue', 'red')
]
ax.legend(handles=handles, labels=['baseline', f'K = {k2}'], title='Categories')

ax.set_xlabel('True Effect Size')
ax.set_ylabel('Similarity')
ax.set_title('RPE1 prediction', fontsize=16)
scatter_pdf = os.path.join(plot_output_path, 'few_shot_prediction_scatter_rpe1.pdf')
fig.savefig(scatter_pdf, bbox_inches='tight')

In [None]:
# Compare, on the same genes, how well two K562 screens agree with each other
# against how well the few-shot prediction agrees with the measured RPE1 profile.
k_shown = 50  # few-shot setting to display; drives both the filter and the legend label

# Restrict both distributions to genes present in both result tables.
genes_to_show = np.intersect1d(k562_intra_comp_df['gene_name'], comp_df['gene_name'])

k562_sims = k562_intra_comp_df.loc[
    k562_intra_comp_df['gene_name'].isin(genes_to_show), 'sim'
]
few_shot_sims = comp_df.loc[
    comp_df['gene_name'].isin(genes_to_show) & (comp_df['k'] == k_shown), 'sim'
]

fig, ax = plt.subplots(figsize=(6,6))
# Draw on `ax` explicitly rather than relying on the implicit current axes.
k562_sims.hist(
    ax=ax, bins=30, alpha=0.5, label='K562 day 6 vs K562 day 8', color='tab:blue'
)
# The prediction target in this notebook is RPE1 (no hESC data is loaded), so
# the label names RPE1 rather than hESC.
few_shot_sims.hist(
    ax=ax, bins=30, alpha=0.5, label=f'RPE1 few-shot prediction (K={k_shown})', color='tab:orange'
)
# A zero-shot (k == 0) histogram can be overlaid the same way if needed.
ax.grid(False)
ax.legend()
ax.set_xlabel('Correlation')
ax.set_ylabel('Count')