In [None]:
# Reload edited project modules (e.g. scmg) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Render inline figures at high resolution on HiDPI displays.
%config InlineBackend.figure_format='retina'

In [None]:
import os

from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import anndata
import scanpy as sc
import umap
import scipy.stats

import torch

from scmg.model.contrastive_embedding import (CellEmbedder,  embed_adata)

from scmg.preprocessing.data_standardization import GeneNameMapper
# Gene identifier/name mapper from the scmg package (presumably maps gene IDs <-> symbols).
# NOTE(review): not referenced by any cell visible here — remove if it is unused.
gene_name_mapper = GeneNameMapper()


In [None]:
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

# Publication figure defaults: embed TrueType (Type 42) fonts in PDFs so text stays
# editable, use FreeSans, and leave layout manual (no figure.autolayout).
plt.rcParams.update({
    "figure.autolayout": False,
    "pdf.fonttype": 42,
    "font.family": "FreeSans",
})
# scanpy style on top: vector_friendly keeps dense scanpy scatter layers rasterized
# inside vector output; dpi_save sets the resolution used when saving figures.
sc.set_figure_params(vector_friendly=True, dpi_save=300)
# Must come after sc.set_figure_params — presumably scanpy's style enables the grid.
plt.rcParams["axes.grid"] = False

In [None]:
# Load the trained contrastive cell embedder: the pickled model object from
# model.pt, then the best checkpoint weights from best_state_dict.pth.
model_ce_path = '../../contrastive_embedding/trained_embedder/'

# Fall back to CPU instead of crashing on .to('cuda:0') when no GPU is present.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# model.pt holds a whole pickled model object, so it needs weights_only=False on
# torch >= 2.6 (where weights_only defaults to True). Unpickling can execute
# arbitrary code — only ever do this for trusted, locally produced files.
# map_location='cpu' lets checkpoints saved on a GPU load on any machine; the
# model is moved to `device` below.
model_ce = torch.load(os.path.join(model_ce_path, 'model.pt'),
                      map_location='cpu', weights_only=False)
model_ce.load_state_dict(torch.load(os.path.join(model_ce_path, 'best_state_dict.pth'),
                                    map_location='cpu'))

model_ce.to(device)
model_ce.eval()

In [None]:
# Perturbation datasets (.h5ad) pooled into one AnnData. All paths are built from
# a single base directory so the data can be relocated by editing one constant.
PERT_DATA_DIR = '/GPUData_xingjie/SCMG/perturbation_data'

pert_data_file_names = [
    'AdamsonWeissman2016_GSM2406681_10X010.h5ad',
    'FrangiehIzar2021_RNA.h5ad',
    'hESC_TF_screen.h5ad',
    'JiangSatija2024_IFNB.h5ad',
    'JiangSatija2024_IFNG.h5ad',
    'JiangSatija2024_INS.h5ad',
    'JiangSatija2024_TGFB.h5ad',
    'JiangSatija2024_TNFA.h5ad',
    'Joung_TFScreen_HS_2023.h5ad',
    'knockTF_human.h5ad',
    'knockTF_mouse.h5ad',
    # 'omnipath.h5ad',  # excluded
    'PertOrg.h5ad',
    'ReplogleWeissman2022_K562_essential.h5ad',
    'ReplogleWeissman2022_K562_gwps.h5ad',
    'ReplogleWeissman2022_rpe1.h5ad',
    'TianKampmann2021_CRISPRa.h5ad',
    'TianKampmann2021_CRISPRi.h5ad',
]
pert_data_files = [os.path.join(PERT_DATA_DIR, name) for name in pert_data_file_names]
# Held-out test set, kept out of this pool (it lives in a different directory):
# '/GPUData_xingjie/SCMG/hESC_perturb_seq/pseudo_bulk.h5ad'

adata_pert_list = []
for pert_file in pert_data_files:
    adata_pert_list.append(sc.read_h5ad(pert_file))
    print(os.path.basename(pert_file), adata_pert_list[-1].shape[0])

adata_pert = anndata.concat(adata_pert_list, axis=0)
# anndata.concat drops var columns by default (merge=None), so restore gene names
# from the first dataset; pandas aligns the assignment on the var index.
adata_pert.var['gene_name'] = adata_pert_list[0].var['gene_name']

# anndata.concat defaults to an inner join on genes: genes missing from any one
# dataset are silently dropped. Report it so a gene-set mismatch is noticed.
max_n_vars = max(a.n_vars for a in adata_pert_list)
if adata_pert.n_vars < max_n_vars:
    print(f'Warning: inner join kept {adata_pert.n_vars} of up to {max_n_vars} genes')

adata_pert

In [None]:
# Embed each entry's control (unperturbed) expression profile with the contrastive
# embedder. layers['control'] appears to be log1p-transformed, since it is
# inverted below — TODO confirm against the dataset preprocessing.
adata_pert_ctl = adata_pert.copy()
# np.expm1 is the exact inverse of np.log1p and, unlike np.exp(x) - 1, keeps full
# precision for values close to zero.
adata_pert_ctl.X = np.expm1(adata_pert_ctl.layers['control'])
embed_adata(model_ce, adata_pert_ctl, batch_size=8192)

# embed_adata stores its output in obsm['X_ce_latent']; copy it back under a key
# that records it was computed from the control profiles.
adata_pert.obsm['X_ctl_ce_latent'] = adata_pert_ctl.obsm['X_ce_latent']
adata_pert

In [None]:
import scipy.spatial

# Compare every pair of perturbation entries (i < j) that target the same gene:
#   - cosine similarity of their effect vectors (rows of X), with the perturbed
#     gene itself masked out;
#   - Euclidean distance between their control-state embeddings.
pert_compare_dict = {
    'pert_id1' : [],
    'pert_id2' : [],
    'perturbed_gene' : [],
    'perturbed_gene_name' : [],
    'perturbation_sign1' : [],
    'perturbation_sign2' : [],
    'mag1' : [],
    'mag2' : [],
    'cosine_similarity' : [],
    'emb_distance' : [],
}

# Pull columns/arrays out once; per-element .iloc lookups inside the pair loop
# were a dominant cost.
obs_names = adata_pert.obs.index
perturbed_genes = adata_pert.obs['perturbed_gene'].to_numpy()
perturbed_gene_names = adata_pert.obs['perturbed_gene_name'].to_numpy()
perturbation_signs = adata_pert.obs['perturbation_sign'].to_numpy()
pert_X = adata_pert.X
ctl_emb = adata_pert.obsm['X_ctl_ce_latent']
var_index = adata_pert.var.index
n_obs, n_vars = adata_pert.shape

# Index row positions by perturbed gene (positions stay ascending), so each row is
# paired only with the later rows of its own gene instead of scanning all O(n^2)
# pairs. NaN genes are left out: NaN != NaN, so the pairwise scan never paired them.
positions_by_gene = {}
rank_in_group = np.full(n_obs, -1, dtype=int)
for pos, gene in enumerate(perturbed_genes):
    if gene != gene:  # only true for NaN
        continue
    group = positions_by_gene.setdefault(gene, [])
    rank_in_group[pos] = len(group)
    group.append(pos)

# Mask over genes that hides the perturbed gene itself; it depends only on the
# gene, so it is built once per gene and cached.
gene_masks = {}

for i in tqdm(range(n_obs - 1)):
    perturbed_gene1 = perturbed_genes[i]
    if perturbed_gene1 != perturbed_gene1:  # NaN gene: never paired
        continue
    # Later rows (j > i) sharing this gene, in ascending order.
    partners = positions_by_gene[perturbed_gene1][rank_in_group[i] + 1:]
    if not partners:
        continue

    mask = gene_masks.get(perturbed_gene1)
    if mask is None:
        mask = np.ones(n_vars, dtype=bool)
        if perturbed_gene1 in var_index:
            mask[var_index.get_loc(perturbed_gene1)] = False
        gene_masks[perturbed_gene1] = mask

    perturbation_sign1 = perturbation_signs[i]
    v1 = pert_X[i]
    v1_masked = v1[mask]
    mag1 = np.linalg.norm(v1)
    emb1 = ctl_emb[i]

    for j in partners:
        v2 = pert_X[j]
        mag2 = np.linalg.norm(v2)
        cos_sim = 1 - scipy.spatial.distance.cosine(v1_masked, v2[mask])
        emb_dist = np.linalg.norm(emb1 - ctl_emb[j])

        pert_compare_dict['pert_id1'].append(obs_names[i])
        pert_compare_dict['pert_id2'].append(obs_names[j])
        pert_compare_dict['perturbed_gene'].append(perturbed_gene1)
        pert_compare_dict['perturbed_gene_name'].append(perturbed_gene_names[i])
        pert_compare_dict['perturbation_sign1'].append(perturbation_sign1)
        pert_compare_dict['perturbation_sign2'].append(perturbation_signs[j])
        pert_compare_dict['mag1'].append(mag1)
        pert_compare_dict['mag2'].append(mag2)
        pert_compare_dict['cosine_similarity'].append(cos_sim)
        pert_compare_dict['emb_distance'].append(emb_dist)

pert_compare_df = pd.DataFrame(pert_compare_dict)
# Zero distance means both entries have an identical control embedding; drop those pairs.
pert_compare_df = pert_compare_df[pert_compare_df['emb_distance'] != 0]
pert_compare_df.to_parquet('pert_compare_df.parquet')
pert_compare_df

In [None]:
pert_compare_df = pd.read_parquet('pert_compare_df.parquet')

# Drop every pair in which either member comes from the hESC perturb-seq test data.
involves_hesc = (pert_compare_df['pert_id1'].str.startswith('hESC_perturb_seq')
                 | pert_compare_df['pert_id2'].str.startswith('hESC_perturb_seq'))
pert_compare_df = pert_compare_df[~involves_hesc]

# Remove potential duplicates: pairs whose effect vectors are nearly collinear
# (|cosine similarity| >= 0.95). NaN similarities also fail this test and are dropped.
pert_compare_df = pert_compare_df[pert_compare_df['cosine_similarity'].abs() < 0.95]

# Split into pairs perturbed in the same direction vs. in different directions.
same_sign = pert_compare_df['perturbation_sign1'] == pert_compare_df['perturbation_sign2']
sim_sign_pert_compare_df = pert_compare_df[same_sign].copy()
diff_sign_pert_compare_df = pert_compare_df[~same_sign].copy()

In [None]:
# Mean and population std (ddof=0, as np.std uses) of cosine similarity, same-sign pairs.
sim_sign_cos = sim_sign_pert_compare_df['cosine_similarity']
sim_sign_cos.mean(), sim_sign_cos.std(ddof=0)

In [None]:
# Mean and population std (ddof=0, as np.std uses) of cosine similarity, different-sign pairs.
diff_sign_cos = diff_sign_pert_compare_df['cosine_similarity']
diff_sign_cos.mean(), diff_sign_cos.std(ddof=0)

In [None]:
from sklearn.linear_model import LinearRegression
from sklearn.metrics import explained_variance_score

# Linear model for same-sign pairs:
#   cosine_similarity ~ emb_distance + min(mag1, mag2)
# reporting the in-sample explained variance.
# Alternative feature set: mag1 and mag2 as two separate columns instead of their minimum.
min_mag = np.minimum(sim_sign_pert_compare_df['mag1'], sim_sign_pert_compare_df['mag2'])
X = np.vstack([sim_sign_pert_compare_df['emb_distance'].values, min_mag]).T
y = sim_sign_pert_compare_df['cosine_similarity'].values

reg = LinearRegression()
reg.fit(X, y)
y_pred = reg.predict(X)
explained_variance_score(y, y_pred)

In [None]:
import scipy.stats

# Same-sign pairs: cosine similarity of effects vs. control-embedding distance,
# colored by the smaller of the two perturbation magnitudes.
fig, ax = plt.subplots()
sct = ax.scatter(sim_sign_pert_compare_df['emb_distance'], sim_sign_pert_compare_df['cosine_similarity'],
                 c=np.minimum(sim_sign_pert_compare_df['mag1'], sim_sign_pert_compare_df['mag2']),
                 vmin=0, vmax=18, s=0.1, cmap='gnuplot')
# Without a colorbar the fixed vmin/vmax color scale cannot be read off the plot.
fig.colorbar(sct, ax=ax, label='Min perturbation magnitude')
ax.set_xlabel('Embedding distance')
ax.set_ylabel('Cosine similarity')
ax.set_title('Same-sign perturbation pairs')

scipy.stats.pearsonr(sim_sign_pert_compare_df['emb_distance'], sim_sign_pert_compare_df['cosine_similarity'])

In [None]:
import scipy.stats

# Different-sign pairs: cosine similarity of effects vs. control-embedding distance,
# colored by the smaller of the two perturbation magnitudes.
fig, ax = plt.subplots()
sct = ax.scatter(diff_sign_pert_compare_df['emb_distance'], diff_sign_pert_compare_df['cosine_similarity'],
                 c=np.minimum(diff_sign_pert_compare_df['mag1'], diff_sign_pert_compare_df['mag2']),
                 vmin=0, vmax=18, s=0.1, cmap='gnuplot')
# Without a colorbar the fixed vmin/vmax color scale cannot be read off the plot.
fig.colorbar(sct, ax=ax, label='Min perturbation magnitude')
ax.set_xlabel('Embedding distance')
ax.set_ylabel('Cosine similarity')
ax.set_title('Different-sign perturbation pairs')

scipy.stats.pearsonr(diff_sign_pert_compare_df['emb_distance'], diff_sign_pert_compare_df['cosine_similarity'])

In [None]:
# Density of same-sign pairs over (embedding distance, cosine similarity):
# hexbin counts on a log color scale, with marginal distributions.
sim_sign_joint = sns.jointplot(
    data=sim_sign_pert_compare_df,
    x='emb_distance',
    y='cosine_similarity',
    kind='hex',
    bins='log',
    xlim=(-0.5, 9),
    ylim=(-0.7, 1),
)
sim_sign_joint

In [None]:
# Density of different-sign pairs over (embedding distance, cosine similarity):
# hexbin counts on a log color scale, with marginal distributions.
diff_sign_joint = sns.jointplot(
    data=diff_sign_pert_compare_df,
    x='emb_distance',
    y='cosine_similarity',
    kind='hex',
    bins='log',
    xlim=(-0.5, 9),
    ylim=(-0.7, 1),
)
diff_sign_joint

In [None]:
# Same-sign pairs: cosine similarity vs. the smaller of the two perturbation
# magnitudes, colored by control-embedding distance. Saved as a PDF figure.
x = np.minimum(sim_sign_pert_compare_df['mag1'], sim_sign_pert_compare_df['mag2']).values
y = sim_sign_pert_compare_df['cosine_similarity'].values
z = sim_sign_pert_compare_df['emb_distance'].values

# Random draw order so no subset systematically overplots another; the fixed
# seed makes the saved PDF identical on every run.
order = np.random.default_rng(0).permutation(len(x))

fig, ax = plt.subplots(figsize=(5.3, 4), dpi=300)
sct = ax.scatter(x[order], y[order], c=z[order],
                 vmin=0, vmax=8, s=0.6, alpha=1, cmap='cool_r', edgecolors='none',
                 rasterized=True)
fig.colorbar(sct, label='embedding distance')
ax.set_xlim(0, 20)
ax.set_ylim(-0.5, 1)
ax.set_xlabel('Min perturbation magnitude')
ax.set_ylabel('Cosine similarity')

# savefig does not create missing directories.
os.makedirs('pert_dataset_stats', exist_ok=True)
fig.savefig('pert_dataset_stats/pert_condition_compare_scatter.pdf')

# Last expression so the correlation is actually displayed (mid-cell it was discarded).
scipy.stats.pearsonr(x, y)

In [None]:
# Same-sign pairs: control-embedding distance vs. the smaller of the two
# perturbation magnitudes, colored by cosine similarity of the effects.
x = np.minimum(sim_sign_pert_compare_df['mag1'], sim_sign_pert_compare_df['mag2']).values
y = sim_sign_pert_compare_df['cosine_similarity'].values
z = sim_sign_pert_compare_df['emb_distance'].values

# Random but reproducible draw order (fixed seed) so no subset systematically overplots.
order = np.random.default_rng(0).permutation(len(x))

fig, ax = plt.subplots(figsize=(5.3, 4), dpi=300)
# rasterized=True keeps the point cloud as an image if this figure is saved to PDF.
sct = ax.scatter(x[order], z[order], c=y[order],
                 cmap='coolwarm', s=1, vmin=-0.7, vmax=0.7, alpha=1, edgecolors='none',
                 rasterized=True)
fig.colorbar(sct, label='cosine similarity')
ax.set_xlim(0, 20)

ax.set_xlabel('Min perturbation magnitude')
# Trailing ';' suppresses the Text(...) repr returned by the last call.
ax.set_ylabel('Embedding distance');

In [None]:
# Same-sign pairs in which exactly one pert_id starts with 'ReplogleWeissman2022',
# i.e. comparisons crossing the boundary of those datasets.
first_is_replogle = sim_sign_pert_compare_df['pert_id1'].str.startswith('ReplogleWeissman2022')
second_is_replogle = sim_sign_pert_compare_df['pert_id2'].str.startswith('ReplogleWeissman2022')
selected_pert_compare_df = sim_sign_pert_compare_df[first_is_replogle ^ second_is_replogle]

selected_emb_dist = selected_pert_compare_df['emb_distance']
selected_cos_sim = selected_pert_compare_df['cosine_similarity']
plt.scatter(selected_emb_dist, selected_cos_sim, s=1)
scipy.stats.pearsonr(selected_emb_dist, selected_cos_sim)