In [None]:
import pandas as pd
import numpy as np
import hickle as hkl
import matplotlib.pyplot as plt
from tqdm import tqdm
import random

In [None]:
# Load the three preprocessed inputs, in the same order as before.
# NOTE(review): `embedding` is not referenced again in the visible cells, while the
# Figure 4A cell reads `cell_embedding`, which no visible cell defines - confirm
# whether the two are meant to be the same object.
embedding, gene_effects_df, feature_importance_df = map(
    hkl.load,
    (
        "datasets/final_X_tcga_processed.hkl",
        "datasets/CRISPRGeneEffect_processed.hkl",
        "datasets/feature_importances.hkl",
    ),
)

# Figure 3A

In [None]:
def make_big_sl_plot(feature_importances=None, score_cutoff=0.007101906647612656, verbose=True):
    """Rank knockouts by how much their top feature stands out and list the top pairs.

    Despite its name this function does not draw anything; it only returns the
    data that the Figure 3A cell plots.

    For every column (knockout) of ``feature_importances`` the score is
    ``column.max() - column.mean()`` and the paired feature is the row label of
    the column maximum. Knockouts are visited in descending score order.

    Parameters
    ----------
    feature_importances : pandas.DataFrame, optional
        Rows are features, columns are knockouts. Defaults to the notebook-level
        ``feature_importance_df``.
    score_cutoff : float, optional
        Iteration stops after the first pair whose score is below this value.
        That first below-cutoff pair is still included in the output, exactly
        as in the original loop.
    verbose : bool, optional
        Print each ``(knockout, feature)`` pair as it is collected.

    Returns
    -------
    (list[tuple[str, str]], list[float])
        The ``(knockout, feature)`` pairs and their scores, in descending
        score order.
    """
    if feature_importances is None:
        feature_importances = feature_importance_df

    top_features = feature_importances.idxmax().to_numpy()
    deviations = (feature_importances.max() - feature_importances.mean()).to_numpy()

    # Order purely by position. Indexing a label-indexed Series with integers
    # (as the previous version did) relies on a deprecated pandas fallback.
    order = np.argsort(deviations)[::-1]

    pairs = []
    scores = []
    for ko, my_feature, top_dev in zip(
        feature_importances.columns[order], top_features[order], deviations[order]
    ):
        pairs.append((ko, my_feature))
        scores.append(top_dev)
        if verbose:
            print((ko, my_feature))
        if top_dev < score_cutoff:
            break

    return pairs, scores

pairs, scores = make_big_sl_plot()

# Sort every proposed pair into one of three groups:
#   * driver: the feature's gene prefix equals the knockout itself,
#   * pos:    pcc value for the pair is > 0,
#   * neg:    everything else.
# NOTE(review): `pcc` is not defined in any visible cell; it is indexed as
# pcc[knockout].loc[feature] and only its sign is used - confirm its origin.
pos_pairs, pos_scores = [], []
neg_pairs, neg_scores = [], []
driver_pairs, driver_scores = [], []
for (p1, p2), score in zip(pairs, scores):
    val = pcc[p1].loc[p2]
    if p2.split("_")[0] == p1:
        group_pairs, group_scores = driver_pairs, driver_scores
    elif val > 0:
        group_pairs, group_scores = pos_pairs, pos_scores
    else:
        group_pairs, group_scores = neg_pairs, neg_scores
    group_pairs.append((p1, p2))
    group_scores.append(score)


def _pair_label(pair):
    """Format a (knockout, feature) pair as 'KNOCKOUT / FEATUREGENE'."""
    ko, feature = pair
    return ko + " / " + feature.split("_")[0]


str_pos_pairs = list(map(_pair_label, pos_pairs))
str_neg_pairs = list(map(_pair_label, neg_pairs))
str_pairs = list(map(_pair_label, pairs))
str_driver_pairs = list(map(_pair_label, driver_pairs))

fig = plt.figure(figsize=(16,4))
# Bars are layered: every pair is first drawn in green so the x-axis keeps the
# score ordering, then the neg and driver groups are drawn over it in their colours.
bar_layers = (
    (str_pairs, scores, "Dependency on Underexpression", "seagreen"),
    (str_neg_pairs, neg_scores, "Dependency on Overexpression", "darkorange"),
    (str_driver_pairs, driver_scores, "Self Pairs", "tab:pink"),
)
for bar_names, bar_heights, legend_text, bar_color in bar_layers:
    plt.bar(bar_names, bar_heights, label=legend_text, color=bar_color)

fig.canvas.draw()
plt.xticks(rotation=90)
plt.ylabel("Score")
plt.legend(loc="best")
plt.title(f"Top {len(pairs)} Proposed SL Pairs")
locs, labels = plt.xticks()

plt.show()

# Figure 3B

In [None]:
def get_cancer_gene_stats(feature_importances=None, depmap_df=None, cancer_genes=None):
    """Running counts of self-pairs and of self-pairs that are listed cancer genes.

    Knockouts (columns of ``feature_importances``) are visited in descending
    order of ``column.max() - column.mean()``. A knockout is counted as a
    self-pair when

    * ``depmap_df`` has at least one row flagged ``best`` whose ``gene`` value
      (first whitespace-separated token) equals the knockout, and
    * the row label of the knockout's maximum importance, cut at the first
      ``"_"``, equals the knockout name.

    Parameters
    ----------
    feature_importances : pandas.DataFrame, optional
        Rows are features, columns are knockouts. Loaded from
        ``datasets/feature_importances.hkl`` when omitted.
    depmap_df : pandas.DataFrame, optional
        Needs a ``gene`` column and a boolean ``best`` column. Loaded from
        ``datasets/Chronos_Combined_predictability_results.csv`` when omitted.
    cancer_genes : iterable of str, optional
        Gene symbols treated as cancer genes. Loaded from the ``Hugo Symbol``
        column of ``datasets/cancerGeneList.tsv`` when omitted.

    Returns
    -------
    (list[int], list[int], list[float])
        ``counts[k]`` is the number of self-pairs seen so far, ``drivers[k]``
        how many of those are in ``cancer_genes``, and ``scores[k]`` the score
        of the k-th self-pair.
    """
    if feature_importances is None:
        feature_importances = hkl.load("datasets/feature_importances.hkl")
    if depmap_df is None:
        depmap_df = pd.read_csv("datasets/Chronos_Combined_predictability_results.csv")
    if cancer_genes is None:
        cancer_genes = pd.read_csv("datasets/cancerGeneList.tsv", sep="\t")["Hugo Symbol"]

    # Genes that have a "best" DepMap row. The previous version detected a
    # missing row through an exception swallowed by a bare `except`.
    best_rows = depmap_df[depmap_df["best"]]
    depmap_genes = {g.split()[0] for g in best_rows["gene"]}
    cancer_gene_set = set(cancer_genes)

    top_features = feature_importances.idxmax().to_numpy()
    deviations = (feature_importances.max() - feature_importances.mean()).to_numpy()
    # Positional ordering; avoids integer-key lookups on a label-indexed Series.
    order = np.argsort(deviations)[::-1]

    driver = 0
    count = 0
    drivers = []
    counts = []
    scores = []
    for pos in order:
        ko = feature_importances.columns[pos]
        my_feature = top_features[pos]

        if ko not in depmap_genes:
            continue
        # A non-string top feature cannot name a gene, so it is skipped.
        if not isinstance(my_feature, str) or ko != my_feature.split("_")[0]:
            continue

        if ko in cancer_gene_set:
            driver += 1
        count += 1
        drivers.append(driver)
        counts.append(count)
        scores.append(deviations[pos])

    return counts, drivers, scores

counts, drivers, scores = get_cancer_gene_stats()

# Fraction of the self-pairs seen so far whose knockout is on the cancer gene
# list, evaluated at every self-pair in score order.
data1 = np.asarray(drivers) / np.asarray(counts)

fig, ax1 = plt.subplots()
ax1.plot(counts, data1 * 100, label="SL-RFM")
ax1.set_xlabel('Number of Self-Pairs, Sorted by Score')
ax1.set_ylabel('Percent', color="black")
ax1.set_title("Percent of Self-Pairs that are OncoKB Oncogenes")
ax1.tick_params(axis='y', labelcolor="black")
plt.show()

# Figure 4A

In [None]:
import matplotlib.colors as colors

# Figure 4A: for every non-self pair, multiply the knockout's gene effect by the
# paired feature's value per cell line, sum the products per primary disease and
# show the result as a diverging heatmap.
#
# The derived frames get their own names, so `gene_effects_df` and
# `cell_embedding` are left untouched and the cell can be re-run. The previous
# version overwrote both and relabelled columns, so a second run failed.
# NOTE(review): `cell_embedding` is not defined in any visible cell - confirm
# where it comes from (possibly the `embedding` loaded at the top).
EXCLUDED_DISEASES = ("Unknown", "Teratoma", "Non-Cancerous")
HIGHLIGHTED_KOS = ("RPS4X", "CHMP4B", "PELO", "MED1", "SCAP", "SOX10")

non_self_pairs = [(p1, p2) for p1, p2 in pairs if p1 != p2.split("_")[0]]
tcga_kos = [p1 for p1, _ in non_self_pairs]
tcga_features = [p2 for _, p2 in non_self_pairs]
neg_kos = [p1 for p1, p2 in neg_pairs if p1 != p2.split("_")[0]]
neg_features = [p2 for p1, p2 in neg_pairs if p1 != p2.split("_")[0]]

sample_info = pd.read_csv("datasets/sample_info.csv")

# Restrict both tables to the cell lines they share and to the columns of interest.
common_cells = gene_effects_df.index.intersection(cell_embedding.index)
gene_effects_sl = gene_effects_df.loc[common_cells, tcga_kos]
feature_values_sl = cell_embedding.loc[common_cells, tcga_features]

# One indexed lookup instead of scanning `sample_info` once per cell line.
# The first row per DepMap_ID wins, as before; an unknown cell line still raises.
disease_by_cell = (
    sample_info.drop_duplicates("DepMap_ID").set_index("DepMap_ID")["primary_disease"]
)
all_diseases = disease_by_cell.loc[common_cells].tolist()
diseases = sorted(set(all_diseases) - set(EXCLUDED_DISEASES))

# Column k of both operands belongs to the k-th non-self pair, so the product is
# taken position-wise and labelled with the knockout names.
pair_products = pd.DataFrame(
    gene_effects_sl.to_numpy() * feature_values_sl.to_numpy(),
    index=common_cells,
    columns=tcga_kos,
)
pair_products["disease"] = all_diseases
data = (
    pair_products[~pair_products["disease"].isin(EXCLUDED_DISEASES)]
    .groupby("disease")
    .sum()
)

xticks = np.array([p1 + " / " + p2.split("_")[0] for p1, p2 in non_self_pairs])

# Diverging colour scale centred on zero. TwoSlopeNorm needs vmin < 0 < vmax.
vmin = data.min().min()
vmax = data.max().max()
norm = colors.TwoSlopeNorm(vmin=vmin, vcenter=0, vmax=vmax)

fig, ax = plt.subplots(figsize=(20,50))
c = ax.imshow(data, cmap="RdBu", norm=norm)
ax.set_xticks(np.arange(data.shape[1]))
ax.set_yticks(np.arange(data.shape[0]))
ax.set_xticklabels(xticks)
ax.set_yticklabels(data.index)
fig.colorbar(c, fraction=0.01, pad=0.005)
plt.xticks(rotation=90)

# Bold the hand-picked knockouts; colour the knockouts of "neg" pairs orange.
neg_ko_set = set(neg_kos)
for tick_label in ax.get_xticklabels():
    label_ko = tick_label.get_text().split()[0]
    if label_ko in HIGHLIGHTED_KOS:
        tick_label.set_weight("bold")
    if label_ko in neg_ko_set:
        tick_label.set_color("darkorange")

plt.title("Selectivity of SL Pairs by Cancer Type")

plt.show()

# Figure 4B

In [None]:
# TCGA expression table read by the Figure 4B and Figure 4C cells below.
# NOTE(review): presumably a samples x genes DataFrame whose columns are named
# "<GENE>_exp" - confirm against the preprocessing step.
tcga = hkl.load("datasets/tcga_data_processed_figures.hkl")

In [None]:
def get_underexpression_percentages(pairs, expression=None, n_cutoffs=22):
    """Fraction of gene pairs in which both genes dip below each expression cutoff.

    For every integer cutoff ``c`` in ``0 .. n_cutoffs - 1`` a pair counts as
    "under-expressed" when each of its two genes has at least one sample with
    expression strictly below ``c``. The two genes need not be low in the same
    sample. Pairs with a gene that is not a column of ``expression`` are left
    out of both numerator and denominator.

    Parameters
    ----------
    pairs : iterable of (str, str)
        Column-name pairs to evaluate.
    expression : pandas.DataFrame, optional
        Samples in rows, genes in columns. Defaults to the notebook-level
        ``tcga`` table.
    n_cutoffs : int, optional
        Number of integer cutoffs to scan, starting at 0.

    Returns
    -------
    numpy.ndarray
        Array of length ``n_cutoffs`` with the fraction of usable pairs that
        are under-expressed at each cutoff. All NaN when no pair is usable.
    """
    if expression is None:
        expression = tcga

    # "Some sample is below c" is the same as "the column minimum is below c",
    # so one reduction replaces a full-frame comparison per cutoff.
    gene_min = expression.min(axis=0)
    # Duplicated column labels are resolved to their first occurrence. The
    # previous version handled them through a bare `except` fallback.
    gene_min = gene_min[~gene_min.index.duplicated(keep="first")]
    known_genes = set(gene_min.index)

    usable = [(g1, g2) for g1, g2 in pairs if g1 in known_genes and g2 in known_genes]
    if not usable:
        return np.full(n_cutoffs, np.nan)

    first_min = gene_min.loc[[g1 for g1, _ in usable]].to_numpy(dtype=float)
    second_min = gene_min.loc[[g2 for _, g2 in usable]].to_numpy(dtype=float)
    # Both genes are below c exactly when the larger of the two minima is below c.
    # np.maximum propagates NaN, so an all-NaN gene never counts.
    pair_floor = np.maximum(first_min, second_min)

    cutoffs = np.arange(n_cutoffs)
    hits = (pair_floor[None, :] < cutoffs[:, None]).sum(axis=1)
    return hits / len(usable)

# Figure 4B: compare the proposed "positive" SL pairs against random gene pairs.
tcga_arr_sl = get_underexpression_percentages([(x + "_exp", y) for x, y in pos_pairs])

number_samples = 10     # number of random pair sets drawn for the baseline
N_RANDOM_PAIRS = 67     # NOTE(review): presumably the number of SL pairs being compared - confirm
RANDOM_SEED = 0         # fixed so the random baseline is the same on every run

# NOTE(review): `exp_cols` is not defined in any visible cell; it is sampled
# from as the pool of gene column names - confirm where it comes from.
rng = np.random.default_rng(RANDOM_SEED)
random_pairs_tcga = np.array(
    [
        get_underexpression_percentages(
            [tuple(pair) for pair in rng.choice(exp_cols, size=(N_RANDOM_PAIRS, 2))]
        )
        for _ in tqdm(range(number_samples))
    ]
)

# One x position per cutoff returned by get_underexpression_percentages.
cutoffs = np.arange(len(tcga_arr_sl))
y = random_pairs_tcga.mean(axis=0) * 100
error = random_pairs_tcga.std(axis=0) * 100

fig, ax = plt.subplots(figsize=(9,7))
ax.plot(cutoffs, tcga_arr_sl * 100, label="Proposed SL Pairs with Dependency on Underexpression")
ax.plot(cutoffs, y, label="Random Pairs")
# Shaded band: mean +/- one standard deviation across the random draws.
ax.fill_between(cutoffs, y - error, y + error, color="peachpuff")

ax.legend(loc="best")
ax.set_xlabel("Expression Cutoff")
ax.set_ylabel("Percent of Gene Pairs with Expression Below Cutoff")
ax.set_title("Joint Expression of Random Gene Pairs and Proposed SL Pairs")
plt.show()

# Figure 4C

In [None]:
# Figure 4C: one scatter plot per selected (knockout, feature) pair, showing the
# feature column of `tcga` on x against the knockout's "<KO>_exp" column on y.
for ko, gene in [("PELO","KLHL9_exp"),("SCAP","MVK_exp"),("SOX10","CDH19_exp"),("MED1","MED31_exp")]:
    fig, ax = plt.subplots(figsize=(5,5))

    feature_expression = tcga[gene]
    knockout_expression = tcga[ko + "_exp"]
    ax.scatter(
        feature_expression,
        knockout_expression,
        zorder=-10,
        color="chocolate",
        alpha=0.6,
        label="TCGA",
    )

    # Same integer tick positions on both axes.
    tick_positions = np.arange(19)
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set(
        xlabel=f"Expression of {gene}",
        ylabel=f"Expression of {ko}",
        title=f"Expression of {gene} and {ko}",
    )
    plt.show()