In [None]:
import os
# Set before the analysis libraries are imported in the cells below.
# NOTE(review): presumably silences OpenMP (KMP) runtime warnings — confirm.
os.environ["KMP_WARNINGS"] = "off"
import warnings
# Suppress all Python warnings for the rest of the session; this also hides
# deprecation and view-modification warnings raised by later cells.
warnings.filterwarnings("ignore")

In [None]:
from pathlib import Path
import sys

# Add the parent of the current working directory to the module search path
# (position 1, i.e. right after the first entry).
# NOTE(review): presumably this is where common_config lives; it only works
# when the kernel's working directory is this notebook's folder — confirm.
sys.path.insert(1, str(Path().cwd().parent))
import common_config

In [None]:
# Apply the shared figure settings from common_config.
# NOTE(review): the exact settings are defined in that module, not here.
common_config.figure_journal_basic()

In [None]:
import pertpy as pt
import scanpy as sc
import matplotlib.pyplot as plt
from sklearn.metrics import silhouette_samples
import pandas as pd
import seaborn as sns

In [None]:
# Gene-programme annotation tables.
#
# Each list holds the perturbations assigned to one programme. An entry is
# either a single gene symbol (str) for a single-gene perturbation, or a set of
# two gene symbols for a combinatorial perturbation. Sets are used so that the
# order of the two genes does not matter when a "GENE_A+GENE_B" perturbation
# name is matched against the table.
#
# Every perturbation is listed exactly once across all programmes.

G1_CYCLE = [
    "CDKN1A",
    {"CDKN1B", "CDKN1A"},
    "CDKN1B",
    {"CDKN1C", "CDKN1A"},
    {"CDKN1C", "CDKN1B"},
    "CDKN1C",
]

ERYTHROID = [
    {"CBL", "CNN1"},
    {"CBL", "PTPN12"},
    {"CBL", "PTPN9"},
    {"CBL", "UBASH3B"},
    {"SAMD1", "PTPN12"},
    {"SAMD1", "UBASH3B"},
    {"UBASH3B", "CNN1"},
    {"UBASH3B", "PTPN12"},
    {"UBASH3B", "PTPN9"},
    {"UBASH3B", "UBASH3A"},
    {"UBASH3B", "ZBTB25"},
    {"BPGM", "SAMD1"},
    "PTPN1",
    {"PTPN12", "PTPN9"},
    {"PTPN12", "UBASH3A"},
    {"PTPN12", "ZBTB25"},
    {"UBASH3A", "CNN1"},
]

# NOTE(review): "HOXC13" used to be listed twice here; the second copy was
# dropped. If it was a typo for a different gene, add that gene instead.
PIONEER_FACTORS = [
    {"FOXA1", "FOXF1"},
    {"FOXA1", "FOXL2"},
    {"FOXA1", "HOXB9"},
    {"FOXA3", "FOXA1"},
    {"FOXA3", "FOXF1"},
    {"FOXA3", "FOXL2"},
    {"FOXA3", "HOXB9"},
    "FOXA3",
    {"FOXF1", "FOXL2"},
    {"FOXF1", "HOXB9"},
    {"FOXL2", "MEIS1"},
    "HOXA13",
    "HOXC13",
    {"POU3F2", "FOXL2"},
    "TP73",
    "MIDN",
    {"LYL1", "IER5L"},
    {"DUSP9", "SNAI1"},
    {"ZBTB10", "SNAI1"},
]

GRANULOCYTE_APOPTOSIS = [
    "SPI1",
    "CEBPA",
    {"CEBPB", "CEBPA"},
    "CEBPB",
    {"CEBPE", "CEBPA"},
    {"CEBPE", "CEBPB"},
    {"CEBPE", "RUNX1T1"},
    {"CEBPE", "SPI1"},
    "CEBPE",
    {"ETS2", "CEBPE"},
    {"KLF1", "CEBPA"},
    {"FOSB", "CEBPB"},
    {"FOSB", "CEBPE"},
    {"ZC3HAV1", "CEBPA"},
    {"JUN", "CEBPA"},
]

# NOTE(review): {"MAP2K6", "KLF1"} used to be listed in addition to
# {"KLF1", "MAP2K6"}; both are the same set, so only one is kept.
PRO_GROWTH = [
    {"CEBPE", "KLF1"},
    "KLF1",
    {"KLF1", "BAK1"},
    {"KLF1", "MAP2K6"},
    {"KLF1", "TGFBR2"},
    "ELMSAN1",
    {"MAP2K3", "SLC38A2"},
    {"MAP2K3", "ELMSAN1"},
    "MAP2K3",
    {"MAP2K3", "MAP2K6"},
    {"MAP2K6", "ELMSAN1"},
    "MAP2K6",
]

MEGAKARYOCYTE = [
    {"MAPK1", "TGFBR2"},
    "MAPK1",
    {"ETS2", "MAPK1"},
    "ETS2",
    {"CEBPB", "MAPK1"},
]

# Display name of each programme -> its perturbation table.
programmes = {
    "G1 cell cycle": G1_CYCLE,
    "Erythroid": ERYTHROID,
    "Pioneer factors": PIONEER_FACTORS,
    "Granulocyte apoptosis": GRANULOCYTE_APOPTOSIS,
    "Pro-growth": PRO_GROWTH,
    "Megakaryocyte": MEGAKARYOCYTE,
}

In [None]:
# Load the preprocessed dataset from disk.
adata = sc.read_h5ad("data/norman_preprocessed.h5ad")
# Rename the empty-string guide category to "control" so it matches the
# control label used by the Mixscape calls below.
adata.obs["guide_ids"] = adata.obs["guide_ids"].cat.rename_categories({"": "control"})
# Bare expression: display the AnnData summary as the cell output.
adata

In [None]:
# Annotate every cell with the gene programme of its perturbation.
#
# Labels: "Control" for control cells, the programme name when the
# perturbation is listed in `programmes`, otherwise "Unknown".
# Each distinct perturbation label is resolved once and then broadcast to the
# cells, which guarantees exactly one label per cell. (A per-cell nested loop
# whose `break` only leaves the innermost loop would append one label per
# matching programme and could produce a list longer than adata.obs.)
programme_by_perturbation = {}
for target_pert in adata.obs["perturbation_name"].unique():
    if target_pert == "control":
        programme_by_perturbation[target_pert] = "Control"
        continue

    # Combinatorial perturbations are named "GENE_A+GENE_B"; compare them as
    # sets so that gene order is irrelevant.
    target_genes = set(target_pert.split("+"))
    matched_programme = "Unknown"
    for programme, pert_list in programmes.items():
        if any(
            pert == target_genes if isinstance(pert, set) else pert == target_pert
            for pert in pert_list
        ):
            # First matching programme wins; stop searching the others.
            matched_programme = programme
            break
    programme_by_perturbation[target_pert] = matched_programme

gene_programme = [
    programme_by_perturbation[target_pert]
    for target_pert in adata.obs["perturbation_name"]
]

adata.obs["gene_programme"] = gene_programme
adata.obs["gene_programme"] = adata.obs["gene_programme"].astype("category")

# Nearest neighbor Mixscape

In [None]:
# Independent copy for the nearest-neighbour Mixscape section; `adata` itself
# is reused unchanged by the later sections.
nn_adata = adata.copy()

In [1]:
# Compute the perturbation signature of every cell relative to "control".
# No reference-selection arguments are given, so pertpy's defaults apply.
# The following cells read layers["X_pert"], which this call is expected to
# populate.
ms_pt = pt.tl.Mixscape()
ms_pt.perturbation_signature(nn_adata, pert_key="perturbation_name", control="control")

In [None]:
# Embed the perturbation signature: on a copy, use the X_pert layer as X, then
# run PCA -> cosine-metric neighbour graph -> UMAP.
adata_pert = nn_adata.copy()
adata_pert.X = adata_pert.layers["X_pert"]
sc.pp.pca(adata_pert)
sc.pp.neighbors(adata_pert, metric="cosine")
sc.tl.umap(adata_pert)
# show=False keeps the figure open so that pyplot can save it on the next line.
sc.pl.umap(adata_pert, color="phase", palette="Set3", show=False)
plt.savefig("figures/nn_mixscape_phase_corrected_umap.png", bbox_inches="tight")

In [None]:
# Run the Mixscape classification per guide label on the X_pert layer.
# Later cells read obs["mixscape_class_global"] from nn_adata.
ms_pt.mixscape(adata=nn_adata, control="control", labels="guide_ids", layer="X_pert")

In [None]:
# Perturbation-score plot for one example guide pair, saved to disk.
# The figure size is changed globally and stays in effect for later cells.
plt.rcParams.update({"figure.figsize": (5, 4)})
ms_pt.plot_perturbscore(
    target_gene="IGDCC3,TGFBR2",
    labels="guide_ids",
    adata=nn_adata,
    palette=dict(
        zip(
            ("control", "IGDCC3,TGFBR2 NP", "IGDCC3,TGFBR2 KO"),
            (common_config.pt_blue, common_config.pt_orange, common_config.pt_red),
        )
    ),
)
plt.savefig("figures/nn_mixscape_perturb_score_example.png", bbox_inches="tight")

In [None]:
# Carry the global Mixscape class over to the embedded object and draw the
# UMAP coloured by it (control / KO / NP).
adata_pert.obs["mixscape_class_global"] = nn_adata.obs["mixscape_class_global"]
sc.pl.umap(
    adata_pert,
    show=False,
    title="Mixscape class",
    color="mixscape_class_global",
    palette=dict(
        zip(
            ("control", "KO", "NP"),
            (common_config.pt_blue, common_config.pt_red, common_config.pt_orange),
        )
    ),
)
plt.savefig("figures/nn_mixscape_pert_status_mixscape_corrected_umap.png", bbox_inches="tight")

In [None]:
# Bar chart of how many cells fall into each global Mixscape class.
classification_counts = nn_adata.obs["mixscape_class_global"].value_counts()

# One fixed colour per class, shared with the UMAP plots.
colors = dict(
    zip(
        ("control", "KO", "NP"),
        (common_config.pt_blue, common_config.pt_red, common_config.pt_orange),
    )
)
plt.figure(figsize=(3.5, 3))
# Bars follow the order of value_counts(), so colours are looked up per label.
classification_counts.plot.bar(
    color=list(map(colors.__getitem__, classification_counts.index))
)
plt.ylabel("Cell Count")
plt.xlabel("Mixscape Classification")
plt.savefig("figures/nn_mixscape_corrected_class_barplot.png", bbox_inches="tight")
plt.show()

In [None]:
# Keep every cell whose global Mixscape class is not "NP"; the bare name on
# the last line displays the subset's summary.
nn_adata_mixscape_cleaned = nn_adata[nn_adata.obs["mixscape_class_global"].ne("NP")]
nn_adata_mixscape_cleaned

In [None]:
# Train pertpy's MLP classifier space on the Mixscape-filtered cells, using
# perturbation_name as the classification target.
# NOTE(review): no random seed is set in the visible cells, so the resulting
# embedding may differ between runs — confirm whether that matters here.
ps = pt.tl.MLPClassifierSpace()
pert_embeddings = ps.compute(
    nn_adata_mixscape_cleaned,
    target_col="perturbation_name",
    hidden_dim=[512, 256],
    dropout=0.05,
    batch_size=64,
    batch_norm=True,
    max_epochs=5,
)

In [None]:
# Collapse the per-cell embeddings into one mean profile per perturbation.
# `ps` is rebound from the classifier space to a pseudobulk space here.
# min_cells=0 / min_counts=0: no minimum thresholds are requested.
ps = pt.tl.PseudobulkSpace()
nn_psadata_classifier = ps.compute(
    pert_embeddings,
    target_col="perturbations",
    groups_col="perturbations",
    mode="mean",
    min_cells=0,
    min_counts=0,
)

In [None]:
# Neighbour graph directly on X (use_rep="X", no PCA step), then UMAP of the
# pseudobulk profiles coloured by perturbation.
sc.pp.neighbors(nn_psadata_classifier, use_rep="X")
sc.tl.umap(nn_psadata_classifier)
sc.pl.umap(nn_psadata_classifier, color="perturbations", show=False)
plt.savefig("figures/nn_mixscape_discriminator_perturbation_name_umap.png", bbox_inches="tight")

In [None]:
# Same UMAP, coloured by the annotated gene programme.
sc.pl.umap(nn_psadata_classifier, color="gene_programme", show=False)
plt.savefig("figures/nn_mixscape_discriminator_gene_programme.png", bbox_inches="tight")

In [None]:
# Cast every obs column to categorical before saving.
# NOTE(review): presumably a workaround for obs dtypes that do not serialise
# cleanly to h5ad; note that numeric columns become categorical too — confirm.
for key in nn_psadata_classifier.obs.keys():
    nn_psadata_classifier.obs[key] = nn_psadata_classifier.obs[key].astype("category")

# Make sure the output folder exists so the write cannot fail on a fresh
# checkout (the input data is read from "data/", not "data/norman/").
Path("data/norman").mkdir(parents=True, exist_ok=True)
sc.write("data/norman/nn_mixscape_psadata_classifier.h5ad", nn_psadata_classifier)

# Gemgroup Mixscape

In [None]:
# Independent copy for the gemgroup-based Mixscape section; `adata` itself is
# reused unchanged by the later "No Mixscape" section.
gemgroup_adata = adata.copy()

In [None]:
# Compute the perturbation signature relative to "control", this time with
# reference selection set to "split_by" on the "gemgroup" column.
# `ms_pt` is rebound to a fresh Mixscape object for this section.
ms_pt = pt.tl.Mixscape()
ms_pt.perturbation_signature(gemgroup_adata,
                             pert_key="perturbation_name",
                             control="control",
                             ref_selection_mode="split_by",
                             split_by="gemgroup"
                             )

In [None]:
# Embed the perturbation signature: on a copy, use the X_pert layer as X, then
# run PCA -> cosine-metric neighbour graph -> UMAP.
# `adata_pert` is rebound here and no longer refers to the previous section's object.
adata_pert = gemgroup_adata.copy()
adata_pert.X = adata_pert.layers["X_pert"]
sc.pp.pca(adata_pert)
sc.pp.neighbors(adata_pert, metric="cosine")
sc.tl.umap(adata_pert)
# show=False keeps the figure open so that pyplot can save it on the next line.
sc.pl.umap(adata_pert, color="phase", palette="Set3", show=False)
plt.savefig("figures/gemgroup_mixscape_phase_corrected_umap.png", bbox_inches="tight")

In [None]:
# Run the Mixscape classification per guide label on the X_pert layer.
# Later cells read obs["mixscape_class_global"] from gemgroup_adata.
ms_pt.mixscape(adata=gemgroup_adata, labels="guide_ids", control="control", layer="X_pert")

In [None]:
# Perturbation-score plot for one example guide pair, saved to disk.
# The figure size is changed globally and stays in effect for later cells.
plt.rcParams.update({"figure.figsize": (5, 4)})
ms_pt.plot_perturbscore(
    target_gene="IGDCC3,TGFBR2",
    labels="guide_ids",
    adata=gemgroup_adata,
    palette=dict(
        zip(
            ("control", "IGDCC3,TGFBR2 NP", "IGDCC3,TGFBR2 KO"),
            (common_config.pt_blue, common_config.pt_orange, common_config.pt_red),
        )
    ),
)
plt.savefig("figures/gemgroup_mixscape_perturb_score_example.png", bbox_inches="tight")

In [None]:
# Carry the global Mixscape class over to the embedded object and draw the
# UMAP coloured by it (control / KO / NP).
adata_pert.obs["mixscape_class_global"] = gemgroup_adata.obs["mixscape_class_global"]
sc.pl.umap(
    adata_pert,
    show=False,
    title="Mixscape class",
    color="mixscape_class_global",
    palette=dict(
        zip(
            ("control", "KO", "NP"),
            (common_config.pt_blue, common_config.pt_red, common_config.pt_orange),
        )
    ),
)
plt.savefig("figures/gemgroup_mixscape_pert_status_mixscape_corrected_umap.png", bbox_inches="tight")

In [None]:
# Bar chart of how many cells fall into each global Mixscape class.
classification_counts = gemgroup_adata.obs["mixscape_class_global"].value_counts()

# One fixed colour per class, shared with the UMAP plots.
colors = dict(
    zip(
        ("control", "KO", "NP"),
        (common_config.pt_blue, common_config.pt_red, common_config.pt_orange),
    )
)
plt.figure(figsize=(3.5, 3))
# Bars follow the order of value_counts(), so colours are looked up per label.
classification_counts.plot.bar(
    color=list(map(colors.__getitem__, classification_counts.index))
)
plt.ylabel("Cell Count")
plt.xlabel("Mixscape Classification")
plt.savefig("figures/gemgroup_mixscape_corrected_class_barplot.png", bbox_inches="tight")
plt.show()

In [None]:
# Keep every cell whose global Mixscape class is not "NP"; the bare name on
# the last line displays the subset's summary.
gemgroup_adata_mixscape_cleaned = gemgroup_adata[
    gemgroup_adata.obs["mixscape_class_global"].ne("NP")
]
gemgroup_adata_mixscape_cleaned

In [None]:
# Train pertpy's MLP classifier space on the Mixscape-filtered cells, using
# perturbation_name as the classification target.
# `pert_embeddings` is overwritten here and no longer holds the previous section's result.
# NOTE(review): no random seed is set in the visible cells, so the resulting
# embedding may differ between runs — confirm whether that matters here.
ps = pt.tl.MLPClassifierSpace()
pert_embeddings = ps.compute(
    gemgroup_adata_mixscape_cleaned,
    target_col="perturbation_name",
    hidden_dim=[512, 256],
    dropout=0.05,
    batch_size=64,
    batch_norm=True,
    max_epochs=5,
)

In [None]:
# Collapse the per-cell embeddings into one mean profile per perturbation.
# `ps` is rebound from the classifier space to a pseudobulk space here.
# min_cells=0 / min_counts=0: no minimum thresholds are requested.
ps = pt.tl.PseudobulkSpace()
gemgroup_psadata_classifier = ps.compute(
    pert_embeddings,
    target_col="perturbations",
    groups_col="perturbations",
    mode="mean",
    min_cells=0,
    min_counts=0,
)

In [None]:
# Neighbour graph directly on X (use_rep="X", no PCA step), then UMAP of the
# pseudobulk profiles coloured by perturbation.
sc.pp.neighbors(gemgroup_psadata_classifier, use_rep="X")
sc.tl.umap(gemgroup_psadata_classifier)
sc.pl.umap(gemgroup_psadata_classifier, color="perturbations", show=False)
plt.savefig("figures/gemgroup_mixscape_discriminator_perturbation_name_umap.png", bbox_inches="tight")

In [None]:
# Same UMAP, coloured by the annotated gene programme.
sc.pl.umap(gemgroup_psadata_classifier, color="gene_programme", show=False)
plt.savefig("figures/gemgroup_mixscape_discriminator_gene_programme.png", bbox_inches="tight")

In [None]:
# Cast every obs column to categorical before saving.
# NOTE(review): presumably a workaround for obs dtypes that do not serialise
# cleanly to h5ad; note that numeric columns become categorical too — confirm.
for key in gemgroup_psadata_classifier.obs.keys():
    gemgroup_psadata_classifier.obs[key] = gemgroup_psadata_classifier.obs[key].astype("category")

# Make sure the output folder exists so the write cannot fail on a fresh
# checkout (the input data is read from "data/", not "data/norman/").
Path("data/norman").mkdir(parents=True, exist_ok=True)
sc.write("data/norman/gemgroup_mixscape_psadata_classifier.h5ad", gemgroup_psadata_classifier)

# No Mixscape

In [None]:
# Baseline without Mixscape: same classifier settings as the two sections
# above, but trained on the full `adata` (no NP cells removed).
# `pert_embeddings` is overwritten here and no longer holds the previous section's result.
# NOTE(review): no random seed is set in the visible cells, so the resulting
# embedding may differ between runs — confirm whether that matters here.
ps = pt.tl.MLPClassifierSpace()
pert_embeddings = ps.compute(
    adata,
    target_col="perturbation_name",
    hidden_dim=[512, 256],
    dropout=0.05,
    batch_size=64,
    batch_norm=True,
    max_epochs=5,
)

In [None]:
# Collapse the per-cell embeddings into one mean profile per perturbation.
# `ps` is rebound from the classifier space to a pseudobulk space here.
# min_cells=0 / min_counts=0: no minimum thresholds are requested.
ps = pt.tl.PseudobulkSpace()
no_mixscape_psadata_classifier = ps.compute(
    pert_embeddings,
    target_col="perturbations",
    groups_col="perturbations",
    mode="mean",
    min_cells=0,
    min_counts=0,
)

In [None]:
# Neighbour graph directly on X (use_rep="X", no PCA step), then UMAP of the
# pseudobulk profiles coloured by perturbation.
sc.pp.neighbors(no_mixscape_psadata_classifier, use_rep="X")
sc.tl.umap(no_mixscape_psadata_classifier)
sc.pl.umap(no_mixscape_psadata_classifier, color="perturbations", show=False)
plt.savefig("figures/no_mixscape_discriminator_perturbation_name_umap.png", bbox_inches="tight")

In [None]:
# Same UMAP, coloured by the annotated gene programme.
sc.pl.umap(no_mixscape_psadata_classifier, color="gene_programme", show=False)
plt.savefig("figures/no_mixscape_discriminator_gene_programme.png", bbox_inches="tight")

In [None]:
# Cast every obs column to categorical before saving.
# NOTE(review): presumably a workaround for obs dtypes that do not serialise
# cleanly to h5ad; note that numeric columns become categorical too — confirm.
for key in no_mixscape_psadata_classifier.obs.keys():
    no_mixscape_psadata_classifier.obs[key] = no_mixscape_psadata_classifier.obs[key].astype("category")

# Make sure the output folder exists so the write cannot fail on a fresh
# checkout (the input data is read from "data/", not "data/norman/").
Path("data/norman").mkdir(parents=True, exist_ok=True)
sc.write("data/norman/no_mixscape_psadata_classifier.h5ad", no_mixscape_psadata_classifier)

# Compare Perturbation Spaces

In [None]:
# Drop pseudobulk profiles whose perturbation has no annotated gene programme.
# .copy() turns the subset into a standalone AnnData, so the new obs column is
# written to a real object instead of relying on an implicit copy of a view
# (whose warning is silenced at the top of the notebook).
nn_psadata_classifier = nn_psadata_classifier[
    nn_psadata_classifier.obs["gene_programme"] != "Unknown"
].copy()
# Per-profile silhouette value of the gene-programme labels, measured on the
# UMAP coordinates.
nn_psadata_classifier.obs["silhouette"] = silhouette_samples(
    nn_psadata_classifier.obsm["X_umap"],
    nn_psadata_classifier.obs["gene_programme"],
)

In [None]:
# Drop pseudobulk profiles whose perturbation has no annotated gene programme.
# .copy() turns the subset into a standalone AnnData, so the new obs column is
# written to a real object instead of relying on an implicit copy of a view
# (whose warning is silenced at the top of the notebook).
gemgroup_psadata_classifier = gemgroup_psadata_classifier[
    gemgroup_psadata_classifier.obs["gene_programme"] != "Unknown"
].copy()
# Per-profile silhouette value of the gene-programme labels, measured on the
# UMAP coordinates.
gemgroup_psadata_classifier.obs["silhouette"] = silhouette_samples(
    gemgroup_psadata_classifier.obsm["X_umap"],
    gemgroup_psadata_classifier.obs["gene_programme"],
)

In [None]:
# Drop pseudobulk profiles whose perturbation has no annotated gene programme.
# .copy() turns the subset into a standalone AnnData, so the new obs column is
# written to a real object instead of relying on an implicit copy of a view
# (whose warning is silenced at the top of the notebook).
no_mixscape_psadata_classifier = no_mixscape_psadata_classifier[
    no_mixscape_psadata_classifier.obs["gene_programme"] != "Unknown"
].copy()
# Per-profile silhouette value of the gene-programme labels, measured on the
# UMAP coordinates.
no_mixscape_psadata_classifier.obs["silhouette"] = silhouette_samples(
    no_mixscape_psadata_classifier.obsm["X_umap"],
    no_mixscape_psadata_classifier.obs["gene_programme"],
)

In [None]:
# Mean silhouette score per gene programme for each method.
# observed=True restricts the result to programmes that actually occur, so
# unused categorical levels cannot show up as empty (NaN) bars.
nn_means = nn_psadata_classifier.obs.groupby('gene_programme', observed=True)['silhouette'].mean()
gemgroup_means = gemgroup_psadata_classifier.obs.groupby('gene_programme', observed=True)['silhouette'].mean()
no_mixscape_means = no_mixscape_psadata_classifier.obs.groupby('gene_programme', observed=True)['silhouette'].mean()

# One column per method, one row per gene programme.
# "nn" is the nearest-neighbour Mixscape variant (all three variants use the
# same MLP classifier), hence the 'Nearest Neighbor' label.
# rename_axis pins the index name so reset_index always yields a
# 'gene_programme' column, which melt() needs below.
plot_data = pd.DataFrame({
    'Nearest Neighbor': nn_means,
    'GemGroup': gemgroup_means,
    'No Mixscape': no_mixscape_means
}).rename_axis('gene_programme').reset_index()

# Long format: one row per (gene programme, method) pair.
plot_data_melted = plot_data.melt(
    id_vars=['gene_programme'],
    var_name='Method',
    value_name='Silhouette Score'
)

# Grouped barplot: programmes on the x axis, one bar per method.
plt.figure(figsize=(12, 6))
sns.barplot(
    data=plot_data_melted,
    x='gene_programme',
    y='Silhouette Score',
    hue='Method'
)

# Rotate the programme names so they do not overlap.
plt.xticks(rotation=45, ha='right')
plt.title('Mean Silhouette Scores by Gene Programme and Method')
plt.tight_layout()

# Save the figure
plt.savefig("figures/silhouette_scores_comparison.png", bbox_inches="tight")
