In [None]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
import pathlib as pl

In [None]:
from statannotations.Annotator import Annotator

In [None]:
from tqdm.notebook import tqdm

In [None]:
from scipy.stats import pearsonr

In [None]:
from statsmodels.stats.multitest import multipletests

In [None]:
def pretty_ax(ax):
    """Apply the notebook's minimal axis styling in place.

    Hides the top and right spines and thickens the bottom and left spines
    to a linewidth of 1.5. Bottom tick marks are enabled, top and left tick
    marks are disabled, and bottom/left tick labels are shown.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to restyle. Nothing is returned.
    """
    for hidden_side in ('right', 'top'):
        ax.spines[hidden_side].set_visible(False)
    tick_options = dict(
        axis='both',
        which='both',
        bottom=True,
        top=False,
        left=False,
        labelbottom=True,
        labelleft=True,
    )
    ax.tick_params(**tick_options)
    for kept_side in ('bottom', 'left'):
        ax.spines[kept_side].set_linewidth(1.5)

In [None]:
import matplotlib
import matplotlib.cm as cm
import matplotlib.colors as mcolors

from sklearn.decomposition import PCA

from adjustText import adjust_text

def plot_pcs_color(ax, state, annotate=False, mapping=None, data=None, n_top=5):
    """Scatter samples on PC1/PC2 coloured by one score column, with a colorbar.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes. The colorbar is added to the figure that owns ``ax``.
    state : str
        Column of ``data`` used for colouring and as the panel title.
    annotate : bool, default False
        If True, label the ``n_top`` points with the highest ``state`` values.
    mapping : dict-like, optional
        Index label -> display name for the annotations. Labels it does not
        contain (or every label, when ``mapping`` is None) are shown as the
        raw index label.
    data : pandas.DataFrame, optional
        Frame with "PC1", "PC2" and ``state`` columns. Defaults to the
        notebook-level ``X_pca`` so existing calls behave as before.
    n_top : int, default 5
        Number of points to annotate when ``annotate`` is True.
    """
    if data is None:
        data = X_pca  # backward-compatible fallback to the notebook global
    scores = data[state]

    vmin, vmax = scores.min(), scores.max()
    # TwoSlopeNorm raises ValueError unless vmin < vcenter < vmax, so only centre
    # the diverging colormap on 0 when the scores actually straddle 0.
    if vmin < 0 < vmax:
        normalize = mcolors.TwoSlopeNorm(vcenter=0, vmin=vmin, vmax=vmax)
    else:
        normalize = mcolors.Normalize(vmin=vmin, vmax=vmax)
    colormap = matplotlib.colormaps['RdBu_r']

    sns.scatterplot(
        y=data["PC2"],
        x=data["PC1"],
        c=scores,
        norm=normalize,
        cmap=colormap,
        ax=ax
    )
    scalar_mappable = cm.ScalarMappable(norm=normalize, cmap=colormap)
    scalar_mappable.set_array(scores)
    # Use the figure that owns `ax`, not a global `fig` that may be a different figure.
    ax.figure.colorbar(scalar_mappable, ax=ax)
    ax.set_title(state)
    pretty_ax(ax)

    if annotate:
        top_hits = scores.sort_values(ascending=False).head(n_top).index
        texts = [
            ax.text(data.loc[cl, "PC1"] + 0.01, data.loc[cl, "PC2"],
                    mapping[cl] if (mapping is not None and cl in mapping) else cl,
                    ha='center', va='center',
                    size=5, color='black', weight='semibold')
            for cl in top_hits
        ]
        # adjustText repositions overlapping labels; arrowprops draws thin connector lines.
        adjust_text(texts, arrowprops=dict(arrowstyle="-", color='black', lw=0.5), ax=ax)

# Load cell-line metadata and expression data

In [None]:
# Folder holding the cell-line CSV exports (model annotations + expression matrix).
# The literal is a placeholder: set CELLLINE_DIR in the environment (or edit it) before running.
celline_dir = pl.Path(os.environ.get("CELLLINE_DIR", "/add/path/here/"))

In [None]:
# Cell-line annotation table; its first CSV column becomes the index
# (presumably the model ID shared with the expression table — verify).
celllines_metadata = pd.read_csv(
    celline_dir.joinpath("internal-23q2_v98-model.csv"),
    index_col=0,
)

In [None]:
# Expression matrix (rows indexed by the first CSV column; file name suggests
# log(TPM+1) of protein-coding genes — confirm).
celllines_gex = pd.read_csv(
    celline_dir.joinpath("internal-23q2_v98-omicsexpressionproteincodinggenestpmlogp1.csv"),
    index_col=0,
)

In [None]:
# Keep models with DepmapModelType == "ESCA", and the expression rows whose
# index also appears in that metadata subset.
celllines_metadata_eac = celllines_metadata[celllines_metadata["DepmapModelType"] == "ESCA"]

celllines_gex_eac = celllines_gex.loc[celllines_gex.index.intersection(celllines_metadata_eac.index)].copy()

# Column headers presumably look like "GENE (id)"; keep only the text before " (".
# Raw string: "\(" in a plain literal is an invalid escape (SyntaxWarning on Python 3.12+).
celllines_gex_eac.columns = celllines_gex_eac.columns.str.split(r" \(").str[0]

In [None]:
# Show a positional window of metadata columns for a hand-picked set of cell lines.
# NOTE(review): -25:-15 is positional; confirm which metadata fields it selects.
celllines_metadata_eac.loc[
    celllines_metadata_eac["CellLineName"].isin(
        ["CCLF_UPGI_0034_T", "CCLF_NEURO_0046_T", "OANC1", "IS076A", "SK-GT-4"]
    )
].iloc[:, -25:-15]

In [None]:
# Same positional metadata window for a second hand-picked set of cell lines.
# NOTE(review): -25:-15 is positional; confirm which metadata fields it selects.
celllines_metadata_eac.loc[
    celllines_metadata_eac["CellLineName"].isin(
        [
            "OE33",
            "CCLF_UPGI_0081_T",
            "CCLF_UPGI_0070_T",
            "CCLF_UPGI_0012_T",
            "CCLF_UPGI_0086_T",
        ]
    )
].iloc[:, -25:-15]

# Score signatures

In [None]:
# TF panel: these genes are removed from every signature before scoring and are
# later used as the TF columns of the signature-vs-TF correlation heatmap.
mTFs = [
    "KLF5",
    "ELF3",
    "SMAD3",
    "TCF7L2",
    "HMGA2",
    "BNC2",
]

In [None]:
# Gene signatures: one CSV per signature in `signature_dir`; the file stem is the
# signature name and the index holds gene symbols (matched against expression columns).
# The path literal is a placeholder: set SIGNATURE_DIR in the environment (or edit it).
signature_dir = pl.Path(os.environ.get("SIGNATURE_DIR", "/add/path/here"))
# Keep only the first rows of each file (assumes rows are ranked by importance — confirm).
N_TOP_SIGNATURE_GENES = 100

full_sigs = {}
# sorted(): iterdir() order is filesystem-dependent, and this dict's order becomes the
# column order of `state_score`. Sub-directories and hidden files (e.g. .DS_Store) are skipped.
for sig_path in sorted(signature_dir.iterdir()):
    if not sig_path.is_file() or sig_path.name.startswith("."):
        continue
    sig_table = pd.read_csv(sig_path, index_col=0)
    full_sigs[sig_path.stem] = sig_table.head(N_TOP_SIGNATURE_GENES).index.to_numpy()

In [None]:
# Column-wise z-score: each gene centred and scaled across the selected cell lines
# (pandas sample std, ddof=1). Constant genes become NaN (0/0); later cells either
# fillna(0) or rely on NaN-skipping means.
std_gex = celllines_gex_eac.pipe(lambda expr: (expr - expr.mean()) / expr.std())

In [None]:
# One score per signature and cell line: mean z-score over the signature genes present
# in `std_gex`, after dropping the `mTFs` genes (np.setdiff1d also sorts and dedupes).
state_score = pd.concat(
    [
        std_gex[np.setdiff1d(std_gex.columns.intersection(genes), mTFs)]
        .mean(axis=1)
        .rename(sig)
        for sig, genes in full_sigs.items()
    ],
    axis=1,
)

In [None]:
# Pearson correlation of each TF present in std_gex with the cNMF_4 signature score.
# NaNs in a TF column are set to 0 (the gene mean after z-scoring) before testing.
# NOTE(review): `list_tfs` is not defined anywhere in this notebook; it must be created
# before this cell, otherwise Restart & Run All fails here with NameError.
all_rs, all_ps = [], []
for tf_gene in tqdm(std_gex.columns.intersection(list_tfs)):
    r_value, p_value = pearsonr(std_gex[tf_gene].fillna(0), state_score["cNMF_4"])
    all_rs.append(r_value)
    all_ps.append(p_value)

In [None]:
# Multiple-testing adjustment of the TF p-values. method="hs" (Holm-Sidak) is the
# statsmodels default, written out; element 1 of the result is the adjusted p-values.
# NOTE(review): these are FWER-adjusted p-values, not BH q-values; use method="fdr_bh" if FDR was intended.
all_qs = multipletests(pvals=all_ps, method="hs")[1]

In [None]:
# TFs with adjusted p-value < 0.01. Relies on `intersection` yielding the same order as in
# the correlation loop (deterministic for unchanged std_gex / list_tfs).
std_gex.columns.intersection(list_tfs)[np.asarray(all_qs) < 0.01]

In [None]:
# Composite TF score: per cell line, the mean z-score of the listed TFs found in std_gex.
# NOTE(review): this set differs from `mTFs` (adds HNF4G, omits HMGA2/BNC2), and HNF4G is
# not removed from the signature genes before scoring — confirm this is intended.
mTF_score = (
    std_gex[std_gex.columns.intersection(["KLF5", "ELF3", "SMAD3", "TCF7L2", "HNF4G"])]
    .mean(axis=1)
    .rename("mTF_score")
)

In [None]:
# Reduced TF score: per cell line, the mean z-score of KLF5 and ELF3 (those found in std_gex).
red_mTF_score = (
    std_gex[std_gex.columns.intersection(["KLF5", "ELF3"])]
    .mean(axis=1)
    .rename("red_mTF_score")
)

### Relationship between states

In [None]:
# Pairwise view of three signature scores: KDE below the diagonal, histograms on it.
df = state_score[["cNMF_1", "cNMF_3", "cNMF_4"]]
g = (
    sns.PairGrid(df, diag_sharey=False, corner=True)
    .map_lower(sns.kdeplot)
    .map_diag(sns.histplot)
)
g.fig.savefig(
    "figures/celllines/cNMF_pairgrid_relbetweenscore.png",
    dpi=300,
    bbox_inches="tight",
)

In [None]:
# 2-component PCA of the z-scored expression (NaNs, e.g. from constant genes, set to 0).
# random_state pins the result when scikit-learn's 'auto' solver picks randomized SVD.
pca = PCA(n_components=2, random_state=0)
X_pca = pd.DataFrame(pca.fit_transform(std_gex.fillna(0)), index=std_gex.index, columns=["PC1", "PC2"])

# Append the signature scores so each panel can be coloured by one of them.
X_pca = pd.concat([X_pca, state_score], axis=1)

# Index label -> CellLineName, used for point labels when annotating.
_MAPPING = celllines_metadata_eac.loc[X_pca.index, "CellLineName"].to_dict()

# Keep the name `fig` for the current figure: plot_pcs_color may rely on the notebook-level `fig`.
fig, ax = plt.subplots(1, 3, figsize=(10, 2))
flatax = ax.flatten()
for panel_ax, state in zip(flatax, ["cNMF_3", "cNMF_1", "cNMF_4"]):
    # Add annotate=True, mapping=_MAPPING to label the top-scoring cell lines.
    plot_pcs_color(panel_ax, state)
fig.tight_layout()
# Relative path like every other figure in this notebook (was "/figures/...", the filesystem root).
fig.savefig("figures/celllines/PC_wCNMF_score.svg", dpi=200, bbox_inches="tight")

# Remaining signatures: displayed only, not saved.
fig, ax = plt.subplots(1, 2, figsize=(6, 2))
flatax = ax.flatten()

plot_pcs_color(flatax[0], "cNMF_2")
plot_pcs_color(flatax[1], "cNMF_5")

fig.tight_layout()

In [None]:
# Pearson correlations among PC1, PC2 and every signature-score column of X_pca.
X_pca.corr(method="pearson")

### Correlation between signatures and TFs

In [None]:
# Signature scores next to the raw (unstandardised) expression of the `mTFs` genes.
# Pearson r is unchanged by per-column z-scoring, so raw expression is fine here.
corr_df = pd.concat(
    objs=[
        state_score[["cNMF_1", "cNMF_3", "cNMF_4"]],
        celllines_gex_eac.loc[:, mTFs],
    ],
    axis=1,
)

In [None]:
# Rows: signatures in display order; columns: the `mTFs` genes.
heatmap_df = (
    corr_df
    .corr()
    .loc[["cNMF_3", "cNMF_1", "cNMF_4"], mTFs]
)

In [None]:
# Annotated heatmap of signature-vs-TF Pearson r, diverging palette centred at 0.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4, 1.5))
sns.heatmap(
    heatmap_df,
    annot=heatmap_df,
    cmap="vlag",
    center=0,
    ax=ax,
)
fig.savefig(
    "figures/celllines/heatmap_cNMF_TF_corr.png",
    dpi=300,
    bbox_inches="tight",
)

In [None]:
# Correlation of the composite mTF_score with each signature score (one-row table).
corr_df = pd.concat(
    objs=[state_score[["cNMF_3", "cNMF_1", "cNMF_4"]], mTF_score],
    axis=1,
)

heatmap_df = (
    corr_df
    .corr()
    .loc[["mTF_score"], ["cNMF_3", "cNMF_1", "cNMF_4"]]
)

In [None]:
# One-row heatmap: Pearson r between mTF_score and each signature score.
fig, ax = plt.subplots(1, 1, figsize=(2, 0.5))
sns.heatmap(data=heatmap_df, annot=heatmap_df, cmap="vlag", center=0, ax=ax)
ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
# Reuse the labels seaborn drew rather than a hard-coded list, so the labels can
# never drift from the plotted column order.
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
fig.savefig("figures/celllines/heatmap_cNMF_mTFscore_corr.png", dpi=300, bbox_inches="tight")

In [None]:
# Correlation of the reduced (KLF5/ELF3) score with each signature score (one-row table).
corr_df = pd.concat(
    objs=[state_score[["cNMF_3", "cNMF_1", "cNMF_4"]], red_mTF_score],
    axis=1,
)

heatmap_df = (
    corr_df
    .corr()
    .loc[["red_mTF_score"], ["cNMF_3", "cNMF_1", "cNMF_4"]]
)

In [None]:
# One-row heatmap: Pearson r between red_mTF_score and each signature score.
fig, ax = plt.subplots(1, 1, figsize=(2, 0.5))
sns.heatmap(data=heatmap_df, annot=heatmap_df, cmap="vlag", center=0, ax=ax)
ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
# Reuse the labels seaborn drew rather than a hard-coded list, so the labels can
# never drift from the plotted column order.
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
fig.savefig("figures/celllines/heatmap_cNMF_red_mTFscore_corr.png", dpi=300, bbox_inches="tight")