In [None]:
import pandas as pd
import numpy as np
import hickle as hkl
import matplotlib.pyplot as plt
from tqdm import tqdm
import random

In [None]:
# Load the processed inputs used by the figure cells below.
# Feature importances: one column per knockout, one row per feature (see make_big_sl_plot).
feature_importance_df = hkl.load("datasets/feature_importances.hkl")
gene_effects_df = hkl.load("datasets/CRISPRGeneEffect_processed.hkl")
# NOTE(review): `embedding` is not used by any later cell; the Figure 4A cell reloads
# gene effects from "datasets/2023/..." and the embedding from "embeddings/...".
# Confirm which files the figures are meant to use.
embedding = hkl.load("datasets/final_X_tcga_processed.hkl")

# Figure 3A

In [None]:
def make_big_sl_plot(feature_importances=None, n_top=100, verbose=True):
    """Rank knockouts by how strongly one feature dominates their importances.

    Despite its name, this function does not plot; it returns the data behind the
    Figure 3A bar chart.

    For every knockout column of ``feature_importances`` the score is
    ``max(importance) - mean(importance)`` over the rows (features). The paired
    feature is the row holding that maximum (``idxmax``). Knockouts are ranked by
    score, highest first.

    Args:
        feature_importances: DataFrame with one column per knockout and one row per
            feature. Defaults to the notebook-level ``feature_importance_df``.
        n_top: maximum number of top-ranked knockouts to return (default 100).
        verbose: if True, print each ``(knockout, feature)`` pair as it is selected.

    Returns:
        ``(pairs, scores)``: ``pairs`` is a list of ``(knockout, feature)`` tuples and
        ``scores`` is the matching list of scores, both in descending score order.
        Knockout columns that are entirely NaN have no top feature and are skipped.
    """
    if feature_importances is None:
        feature_importances = feature_importance_df

    # An all-NaN column has no idxmax; drop such columns up front instead of
    # letting them surface as NaN features or exceptions later.
    feature_importances = feature_importances.dropna(axis=1, how="all")

    top_features = feature_importances.idxmax()
    top_deviations = feature_importances.max() - feature_importances.mean()
    # Label-based descending sort (NaN last). This avoids positional access through
    # Series.__getitem__, which pandas has deprecated for non-integer indexes.
    ranked = top_deviations.sort_values(ascending=False)

    pairs = []
    scores = []
    # head() also copes with fewer than n_top knockouts.
    for ko, top_dev in ranked.head(n_top).items():
        my_feature = top_features[ko]
        pairs.append((ko, my_feature))
        scores.append(top_dev)
        if verbose:
            print((ko, my_feature))

    return pairs, scores

# Top-ranked (knockout, feature) pairs and their scores for Figure 3A.
pairs, scores = make_big_sl_plot()

# Split the pairs into three colour groups:
#   - self pairs ("driver_*"): the feature's prefix is the knocked-out gene itself;
#   - positive / negative: sign of pcc[knockout].loc[feature].
# NOTE(review): `pcc` is not defined anywhere in this notebook, so this cell only runs
# with leftover kernel state. TODO define it explicitly (presumably a correlation
# table with knockouts as columns and features as rows) so Restart & Run All works.
pos_pairs, pos_scores = [], []
neg_pairs, neg_scores = [], []
driver_pairs, driver_scores = [], []
for (p1, p2), score in zip(pairs, scores):
    if p2.split("_")[0] == p1:
        # Classify self pairs first so they never need a `pcc` entry.
        driver_pairs.append((p1, p2))
        driver_scores.append(score)
    elif pcc[p1].loc[p2] > 0:
        pos_pairs.append((p1, p2))
        pos_scores.append(score)
    else:
        # Non-positive correlations land here, and so does NaN (it fails `> 0`).
        neg_pairs.append((p1, p2))
        neg_scores.append(score)


def _pair_label(pair):
    """Format a (knockout, feature) pair as 'KNOCKOUT / GENE' (feature text after '_' dropped)."""
    ko, feature = pair
    return ko + " / " + feature.split("_")[0]


str_pos_pairs = [_pair_label(p) for p in pos_pairs]
str_neg_pairs = [_pair_label(p) for p in neg_pairs]
str_pairs = [_pair_label(p) for p in pairs]
str_driver_pairs = [_pair_label(p) for p in driver_pairs]

# Knockouts whose tick labels are emphasised in Figure 3A.
HIGHLIGHT_KOS = ["RPS4X", "CHMP4B", "PELO", "MED1"]

fig = plt.figure(figsize=(16, 4))
# Every pair is drawn first, in ranked order, which fixes the x-axis ordering.
# The negative and self pairs are then redrawn on top in their own colours, so the
# green bars left visible are the positive pairs.
plt.bar(str_pairs, scores, label="Dependency on Underexpression", color="seagreen")
plt.bar(str_neg_pairs, neg_scores, label="Dependency on Overexpression", color="darkorange")
plt.bar(str_driver_pairs, driver_scores, label="Self Pairs", color="tab:pink")

# Draw once so the categorical tick labels carry their text before we inspect them.
fig.canvas.draw()
plt.xlabel("Knockout / Dependency")
plt.xticks(rotation=90)
plt.ylabel("Score")
plt.legend(loc="best")
# Use the real number of pairs rather than assuming exactly 100.
plt.title(f"Top {len(pairs)} Proposed SL Pairs")

# Bold the labels whose knockout (text before the first space) is highlighted.
_, tick_labels = plt.xticks()
for tick_label in tick_labels:
    words = tick_label.get_text().split()
    if words and words[0] in HIGHLIGHT_KOS:
        tick_label.set_weight("bold")
plt.show()

# Supplementary Figure 3C (Similar to Figure 3B)

In [None]:
def get_our_self_pairs(feature_importances=None, cancer_genes=None, verbose=True):
    """Cumulative cancer-gene counts over our model's self pairs, in score order.

    Knockouts are scored as in ``make_big_sl_plot`` (score = max - mean importance)
    and walked from highest to lowest score. A knockout is a *self pair* when its top
    feature's prefix (text before the first ``_``) equals the knockout name. Each self
    pair increments ``count``. If the knockout is also in ``cancer_genes``, it
    increments ``driver`` as well.

    Args:
        feature_importances: DataFrame with one column per knockout and one row per
            feature. Defaults to the notebook-level ``feature_importance_df``.
        cancer_genes: iterable of gene symbols. Defaults to the ``Hugo Symbol``
            column of ``datasets/cancerGeneList.tsv``.
        verbose: if True, print ``driver, count, driver/count, score`` per self pair.

    Returns:
        ``(counts, drivers, scores)``. After the k-th self pair:
        ``counts[k-1] == k``; ``drivers[k-1]`` is how many of the first k are cancer
        genes; ``scores[k-1]`` is the k-th self pair's score.
    """
    if feature_importances is None:
        feature_importances = feature_importance_df
    if cancer_genes is None:
        cancer_genes = pd.read_csv("datasets/cancerGeneList.tsv", sep="\t")["Hugo Symbol"]
    # Build the set once; the original rebuilt list(cancer_genes) on every iteration.
    cancer_gene_set = set(cancer_genes)

    # All-NaN columns have no top feature.
    feature_importances = feature_importances.dropna(axis=1, how="all")
    top_features = feature_importances.idxmax()
    top_deviations = feature_importances.max() - feature_importances.mean()
    # Label-based descending sort; avoids deprecated positional Series.__getitem__.
    ranked = top_deviations.sort_values(ascending=False)

    driver = 0
    count = 0
    drivers, counts, scores = [], [], []
    for ko, top_dev in ranked.items():
        if ko != str(top_features[ko]).split("_")[0]:
            continue  # not a self pair
        if ko in cancer_gene_set:
            driver += 1
        count += 1
        if verbose:
            print(driver, count, driver / count, top_dev)
        drivers.append(driver)
        counts.append(count)
        scores.append(top_dev)

    return counts, drivers, scores

def get_depmap_self_pairs(depmap_rankings=None, cancer_genes=None, verbose=True):
    """Cumulative cancer-gene counts over DepMap's self pairs, in table row order.

    A row is a *self pair* when the first whitespace-separated token of its ``gene``
    column equals the prefix (text before the first ``_``) of its ``feature0``
    column. Rows with a missing ``gene`` or ``feature0`` are never self pairs.

    Args:
        depmap_rankings: DataFrame with ``gene`` and ``feature0`` columns. Defaults to
            ``datasets/Chronos_Combined_predictability_results.csv`` (publicly
            available on DepMap).
        cancer_genes: iterable of gene symbols. Defaults to the ``Hugo Symbol``
            column of ``datasets/cancerGeneList.tsv``.
        verbose: if True, print ``driver, count, driver/count`` for each self pair.

    Returns:
        ``(counts, drivers)``: ``counts[k-1] == k`` and ``drivers[k-1]`` is how many
        of the first k self pairs have a knockout in ``cancer_genes``.
    """
    if depmap_rankings is None:
        # the following dataset is publicly available on DepMap:
        depmap_rankings = pd.read_csv("datasets/Chronos_Combined_predictability_results.csv")
    if cancer_genes is None:
        # Load the list here; the `cancer_genes` local of get_our_self_pairs is not
        # visible from this function.
        cancer_genes = pd.read_csv("datasets/cancerGeneList.tsv", sep="\t")["Hugo Symbol"]

    # Vectorised replacement for the row-by-row iterrows loop. Missing values stay
    # NaN and never compare equal, so they cannot count as self pairs.
    kos = depmap_rankings["gene"].str.split().str[0]
    top_depmap_features = depmap_rankings["feature0"].str.split("_").str[0]
    self_kos = kos[kos == top_depmap_features]

    drivers = self_kos.isin(set(cancer_genes)).cumsum().tolist()
    counts = list(range(1, len(self_kos) + 1))
    if verbose:
        for driver, count in zip(drivers, counts):
            print(driver, count, driver / count)
    return counts, drivers

# `self_pair_scores` is deliberately not named `scores`: that name holds the
# Figure 3A pair scores, and reusing it would break a re-run of the 3A plot.
counts, drivers, self_pair_scores = get_our_self_pairs()
# DepMap baseline (the previous call named an undefined `get_ko_self_pairs`).
counts2, drivers2 = get_depmap_self_pairs()

fig, ax1 = plt.subplots()
# Running fraction of self-pair knockouts that appear in the cancer gene list.
data1 = np.array(drivers) / np.array(counts)
data2 = np.array(drivers2) / np.array(counts2)

ax1.plot(counts, data1, label="Our Model")
ax1.plot(counts2, data2, label="DepMap")
ax1.set_xlabel("Number of self pairs considered")
ax1.set_ylabel("Fraction in cancer gene list")
ax1.legend(loc="best")  # without this the line labels are never shown
fig.tight_layout()
plt.show()

# Figure 4A

In [None]:
# Figure 4A: for every non-self pair, the per-disease median of
# (knockout gene effect * paired feature value) across DepMap cell lines.
# Self pairs (feature prefix == knockout) are excluded. Column j of the product
# below pairs tcga_kos[j] with tcga_features[j].
tcga_kos = [p1 for p1, p2 in pairs if p1 != p2.split("_")[0]]
tcga_features = [p2 for p1, p2 in pairs if p1 != p2.split("_")[0]]
neg_kos = [p1 for p1, p2 in neg_pairs if p1 != p2.split("_")[0]]
neg_features = [p2 for p1, p2 in neg_pairs if p1 != p2.split("_")[0]]

# NOTE(review): these paths differ from the first data cell
# ("datasets/CRISPRGeneEffect_processed.hkl", "datasets/final_X_tcga_processed.hkl");
# confirm which versions this figure is meant to use.
gene_effects_df = hkl.load("datasets/2023/CRISPRGeneEffect_processed.hkl").fillna(0)
# the following dataset is publicly available on DepMap:
sample_info = pd.read_csv("datasets/sample_info.csv")
cell_embedding = hkl.load("embeddings/final_X_tcga_processed.hkl")
common_cells = gene_effects_df.index.intersection(cell_embedding.index)
gene_effects_df = gene_effects_df.loc[common_cells]
cell_embedding = cell_embedding.loc[common_cells]

# Build the DepMap_ID -> primary_disease lookup once instead of scanning
# sample_info for every cell line. For duplicated IDs the first row wins, as in
# the original scan.
disease_by_cell = (
    sample_info.drop_duplicates("DepMap_ID").set_index("DepMap_ID")["primary_disease"]
)
unknown_cells = gene_effects_df.index.difference(disease_by_cell.index)
if len(unknown_cells):
    raise KeyError(
        f"{len(unknown_cells)} cell lines missing from sample_info.csv, "
        f"e.g. {list(unknown_cells[:5])}"
    )
diseases = disease_by_cell.reindex(gene_effects_df.index).tolist()

gene_effects_df = gene_effects_df[tcga_kos]
cell_embedding = cell_embedding[tcga_features]
data = pd.DataFrame(gene_effects_df.values * cell_embedding.values, index=gene_effects_df.index)
data["disease"] = diseases
data = data.groupby("disease").median()

xticks = np.array([ko + " / " + feature.split("_")[0] for ko, feature in zip(tcga_kos, tcga_features)])
fig, ax = plt.subplots(figsize=(20, 50))
c = ax.imshow(data, cmap="RdBu")
ax.set_xticks(np.arange(data.shape[1]))
ax.set_yticks(np.arange(data.shape[0]))
ax.set_xticklabels(xticks, rotation=90)
# After the groupby, the rows of `data` are the primary-disease groups.
ax.set_yticklabels(data.index)
fig.colorbar(c, fraction=0.01, pad=0.005)

# Bold the highlighted knockouts; additionally colour the negative pairs orange.
highlight_kos = {"RPS4X", "CHMP4B", "PELO", "MED1"}
neg_ko_set = set(neg_kos)
for tick_label in ax.get_xticklabels():
    words = tick_label.get_text().split()
    ko = words[0] if words else ""
    if ko in highlight_kos:
        tick_label.set_weight("bold")
    if ko in neg_ko_set:
        tick_label.set_color("darkorange")
        tick_label.set_weight("bold")

# Tick labels read "knockout / dependency gene", matching Figure 3A's axis label.
ax.set_xlabel("Knockout / Dependency")
ax.set_ylabel("DepMap Cell Line Primary Disease")
ax.set_title("Strength of SL Interaction by Primary Disease")

plt.show()

# Figure 4B

In [None]:
tcga = hkl.load("datasets/tcga_data_processed_figures.hkl")

# (feature gene, knockout) pairs; the knockouts are those bolded in Figure 3A.
for gene, ko in [("ZFY", "RPS4X"), ("CHMP4A", "CHMP4B"), ("KLHL9", "PELO"), ("MED31", "MED1")]:
    fig, ax = plt.subplots(figsize=(5, 5))

    x, y = tcga[gene + "_exp"], tcga[ko + "_exp"]
    ax.scatter(x, y, zorder=-10, color="mediumaquamarine", alpha=0.6, label="TCGA")

    # Same integer ticks (0-18) on both axes of every panel.
    ax.set_xticks(np.arange(19))
    ax.set_yticks(np.arange(19))
    ax.set_xlabel(f"log2(TPM + 1) of {gene}")
    ax.set_ylabel(f"log2(TPM + 1) of {ko}")
    ax.set_title(f"{gene} vs {ko} expression")
    ax.legend(loc="best")  # the "TCGA" label was previously never rendered
    plt.show()