In [None]:
import pandas as pd
import numpy as np
import scanpy as sc

from statsmodels.stats.multitest import multipletests

from tqdm.notebook import tqdm

from sklearn.metrics import confusion_matrix, balanced_accuracy_score, classification_report

import seaborn as sns
import matplotlib.pyplot as plt

import palettable

In [None]:
def pretty_ax_wlabels(ax):
    """Strip the top/right frame of ``ax`` and thicken the remaining spines.

    Tick marks are kept on the bottom edge only (top and left are switched off),
    while tick labels stay enabled on both the bottom and the left.
    The axes object is modified in place; nothing is returned.
    """
    for hidden_side in ("right", "top"):
        ax.spines[hidden_side].set_visible(False)

    tick_settings = {
        "axis": "both",
        "which": "both",
        "bottom": True,
        "top": False,
        "left": False,
        "labelbottom": True,
        "labelleft": True,
    }
    ax.tick_params(**tick_settings)

    for kept_side in ("bottom", "left"):
        ax.spines[kept_side].set_linewidth(1.5)

In [None]:
# Qualitative Set1 palette (7 colours) used for the malignant cNMF states.
colorlist = palettable.colorbrewer.qualitative.Set1_7.mpl_colors

# Palette positions 2 and 5 are skipped on purpose; the two extra categories
# get neutral greys.
colormapping_mal = {
    f"cNMF_{state_no}": colorlist[palette_pos]
    for state_no, palette_pos in zip(range(1, 6), (0, 1, 3, 4, 6))
}
colormapping_mal.update({"Mixed": "lightgrey", "Outlier": "grey"})

In [None]:
import math
def softmax_w_temp(logits, temp: float=1):
    """Temperature-scaled softmax of a 1-D collection of scores.

    Parameters
    ----------
    logits : sequence of numbers (list, tuple, ndarray, pandas Series, ...)
        Raw scores. Each value is divided by ``temp`` before the softmax.
    temp : float, default 1
        Temperature. Values below 1 sharpen the distribution, values above 1
        flatten it.

    Returns
    -------
    numpy.ndarray
        Probabilities with the same length as ``logits``; they sum to 1.
        An empty input yields an empty array.

    Raises
    ------
    ZeroDivisionError
        If ``temp`` is zero and ``logits`` is not empty.
    """
    values = np.asarray(logits, dtype=float)
    if values.size == 0:
        return values
    if temp == 0:
        raise ZeroDivisionError("temp must be non-zero")

    scaled = values / temp
    # Subtracting the maximum does not change the softmax but keeps exp() from
    # overflowing when logits/temp is large (e.g. at small temperatures).
    weights = np.exp(scaled - np.max(scaled))
    return weights / weights.sum()

def get_probs_and_class(Xz, y, n_regions, temp):
    """Score every row of ``Xz`` against each state's regions and classify it.

    For each state, the score of a row is the mean of ``Xz`` over the columns
    listed in ``n_regions[state].index``. The predicted class is the state with
    the highest score; a temperature-scaled softmax over the scores gives
    per-state probabilities, and their Shannon entropy is reported as the
    "Plasticity score".

    Parameters
    ----------
    Xz : pandas.DataFrame
        Rows are observations; columns must include every label found in the
        ``.index`` of the ``n_regions`` values.
    y : pandas.Series
        Reference labels. Its index is used to align the returned prediction
        frame, and it is appended as the last column of ``prob_w_class``.
    n_regions : dict
        Maps a state name to an object whose ``.index`` holds column labels of
        ``Xz``. States are processed in sorted-name order.
    temp : float
        Softmax temperature (scores are divided by it).

    Returns
    -------
    y_pred : pandas.DataFrame
        Single column ``"Pred"``, indexed like ``y``. States named ``cNMF_1`` ..
        ``cNMF_5`` are encoded as ``0`` .. ``4``; any other state name is kept
        unchanged.
    plasticity : pandas.Series
        Entropy of the per-state probabilities, named ``"Plasticity score"``.
    prob_w_class : pandas.DataFrame
        Per-state probabilities with ``y`` concatenated as an extra column.

    Raises
    ------
    ZeroDivisionError
        If ``temp`` is zero.
    """
    if temp == 0:
        raise ZeroDivisionError("temp must be non-zero")

    states = sorted(n_regions.keys())

    # One column per state: mean signal over that state's regions.
    scores = pd.concat([Xz[n_regions[state].index].mean(axis=1) for state in states], axis=1)
    scores.columns = states

    # Explicit mapping instead of Series.replace, whose silent object -> int
    # downcast is deprecated in recent pandas.
    state_to_label = {f"cNMF_{i}": i - 1 for i in range(1, 6)}
    y_pred = scores.idxmax(axis=1).map(lambda state: state_to_label.get(state, state))

    # Row-wise softmax, vectorised. Shifting each row by its maximum leaves the
    # result unchanged but prevents exp() overflow at small temperatures.
    scaled = scores.to_numpy(dtype=float) / temp
    scaled = scaled - scaled.max(axis=1, keepdims=True)
    weights = np.exp(scaled)
    m = pd.DataFrame(weights / weights.sum(axis=1, keepdims=True),
                     index=scores.index, columns=scores.columns)

    # Shannon entropy; terms with zero probability contribute 0 (limit of p*log p).
    probs = m.to_numpy()
    with np.errstate(divide="ignore", invalid="ignore"):
        entropy_terms = np.where(probs > 0, -probs * np.log(probs), 0.0)
    plasticity = pd.Series(entropy_terms.sum(axis=1), index=m.index, name="Plasticity score")

    prob_w_class = pd.concat([m, y], axis=1)
    y_pred = y_pred.reindex(y.index).to_frame(name="Pred")

    return y_pred, plasticity, prob_w_class

In [None]:
# Peak annotation table, re-indexed by its "query_region" column.
peak_info = pd.read_csv("/add/path/here/peaks_closestfeatures.csv")
peak_info = peak_info.set_index("query_region")

# ATAC AnnData object.
atac = sc.read_h5ad("/add/path/here/combined_atac.h5ad")

# cNMF score table; the first CSV column becomes the index.
scores = pd.read_csv("/add/path/here/adata_cNMF_scores_wtop.csv", index_col=0)

In [None]:
# `pl` was used below without ever being imported, which raises NameError on a
# fresh kernel; import it next to its only use, like the other cell-local imports.
import pathlib as pl

# Directory holding one "<state>_region_correlation.csv" and one
# "<state>_region_pval.csv" file per cNMF state.
most_corr_dir = pl.Path("/add/path/here/")

# Both tables are keyed by state name; the first CSV column is used as index.
all_corrs = {}
all_ps = {}
for state in [f"cNMF_{i}" for i in range(1,6)]:
    all_corrs[state] = pd.read_csv(most_corr_dir / f"{state}_region_correlation.csv",index_col=0)
    all_ps[state] = pd.read_csv(most_corr_dir / f"{state}_region_pval.csv",index_col=0)

In [None]:
# Keep only the cells whose `nCount_ATAC` is strictly above 2000.
atac = atac[atac.obs["nCount_ATAC"].gt(2000)].copy()

In [None]:
# Add the cNMF score columns and the `highlevel_wtop` label from `scores`
# to `atac.obs`, matching cells sample by sample.
# NOTE(review): samples are enumerated from `atac.obs.dataset` but rows are
# selected with `sample_id` — confirm both columns hold the same values.

new_annot = []
for sample in atac.obs.dataset.unique():

    # Rows of the score table for this sample; the last two characters of each
    # index label are dropped (presumably a suffix — TODO confirm).
    df1 = scores[scores.sample_id==sample].copy()
    df1.index = df1.index.str[:-2]
    
    # ATAC cells of this sample. The index is rewritten to whatever follows the
    # first "_"-separated piece; np.hstack flattens the per-label lists, so this
    # assumes exactly one "_" per label — TODO confirm.
    df2 = atac.obs[atac.obs.sample_id==sample].copy()
    raw_idx = df2.index.copy()
    df2.index = np.hstack(df2.index.str.split("_").str[1:])
    
    # Stripped label -> original ATAC cell label, used to restore the index below.
    dict_map = {df2.index[i]: raw_idx[i] for i in range(len(raw_idx))}
    
    # Concatenating with `refined_annotation` and then dropping that last column
    # leaves the score columns indexed by the union of both indexes, so ATAC
    # cells without a score entry are kept as NaN rows.
    df = pd.concat([df1.loc[df2.index.intersection(df1.index),['cNMF_1_score', 'cNMF_2_score',
       'cNMF_3_score', 'cNMF_4_score', 'cNMF_5_score', 'highlevel_wtop']],df2.refined_annotation],axis=1).iloc[:,:-1]
    
    df = df.rename(index=dict_map)
    new_annot.append(df)
new_annot = pd.concat(new_annot)

# Column-wise append: re-running this cell without reloading `atac` would add
# the same columns a second time.
atac.obs = pd.concat([atac.obs,new_annot],axis=1)

In [None]:
# Keep only cells labelled with one of the five cNMF states.
subatac = atac[atac.obs.highlevel_wtop.isin([f"cNMF_{i}" for i in range(1,6)])].copy()

# For every state: significant regions (adjusted p < 0.05), ranked by
# correlation, top 200 at most.
n_regions = {}
for state in all_corrs:
    statedf = pd.concat([all_corrs[state],all_ps[state]],axis=1)
    statedf.columns = ["Correlation", "p"]
    # "hs" (Holm-Sidak) is the statsmodels default, spelled out here so the
    # correction method is visible.
    # NOTE(review): the column is called "q"; if an FDR correction was intended,
    # method="fdr_bh" would be needed — confirm.
    statedf["q"] = multipletests(all_ps[state].values.ravel(), method="hs")[1]
    
    n_regions[state] = statedf[(statedf["q"]<0.05)].sort_values(by="Correlation",ascending=False).head(200)

# Shuffle the cells with a fixed seed so that Restart & Run All reproduces the
# same ordering (the original draw was unseeded).
shuffled_cells = np.random.default_rng(0).permutation(subatac.shape[0])
# Sorted union of the regions selected for any state.
selected_regions = np.unique(np.hstack([n_regions[state].index for state in n_regions]))
predatac = subatac[shuffled_cells, selected_regions].copy()

In [None]:
# Dense copy of the accessibility matrix.
X = predatac.X.toarray().copy()

# Integer class label: last character of the state name, shifted to start at 0.
state_digit = predatac.obs["highlevel_wtop"].str[-1].astype(int)
y = state_digit - 1

# Per-region z-score (column-wise mean and standard deviation).
Xz = predatac.to_df().copy()
region_mean, region_sd = Xz.mean(), Xz.std()
Xz = (Xz - region_mean) / region_sd

In [None]:
# Probability tables (third return value) for each softmax temperature tried.
alltemp_probs = [
    get_probs_and_class(Xz, y, n_regions, temp)[2]
    for temp in [0.05, 0.1, 0.25, 0.5, 1, 2]
]

In [None]:
# Imported once here instead of on every pass of the inner loop.
from sklearn.calibration import calibration_curve

# One colour per entry of `alltemp_probs` (indexed by position).
clrs = ["red", "blue", "purple", "pink", "green", "yellow"]

# One panel per class.
fig, ax = plt.subplots(1,5, figsize=(13,2))
flatax = ax.flatten()

# Diagonal reference line: a perfectly calibrated classifier.
for panel in flatax:
    panel.plot([0, 1], [0, 1], linestyle = '--', label = 'Ideally Calibrated')

for i, prob_w_class in enumerate(alltemp_probs):
    true_labels = prob_w_class["highlevel_wtop"]
    for cl in true_labels.unique():
        # One-vs-rest view of class `cl`: its probability column (selected by
        # position) against a 0/1 membership indicator. Vectorised instead of
        # one scalar `.iloc` lookup per cell.
        binary_probs = prob_w_class.iloc[:, cl].to_numpy()
        binary_class = (true_labels == cl).to_numpy().astype(int)

        # calibration_curve returns (fraction of positives, mean predicted prob).
        prob_true, prob_pred = calibration_curve(binary_class, binary_probs, n_bins = 15)

        # Predicted probability on x, observed frequency on y.
        flatax[cl].plot(prob_pred, prob_true, marker = '.', label = 'Softmax w/temp', c=clrs[i])
        flatax[cl].set_xlabel('Avg Pred. Prob in each bin')
        flatax[cl].set_ylabel('Ratio of positives')
        flatax[cl].set_title(f"Calibration for class cNMF_{cl+1}")
fig.tight_layout()
fig.savefig("figures/calibration_curves_cnmf_class.svg", dpi=200, bbox_inches="tight")

In [None]:
# Predictions, plasticity scores and probabilities at softmax temperature 0.25.
y_pred, plasticity, all_probs = get_probs_and_class(Xz, y, n_regions, temp=0.25)

In [None]:
# Flat arrays of true and predicted labels. `Series.ravel()` is deprecated in
# recent pandas, so the Series is converted with `to_numpy()` instead.
y_true_flat = y.to_numpy()
y_pred_flat = y_pred.to_numpy().ravel()

print("Confusion matrix")
print(confusion_matrix(y_true_flat, y_pred_flat))

print(f"BAC={balanced_accuracy_score(y_true_flat, y_pred_flat):.2f}")

print("Classification report")
print(classification_report(y_true_flat, y_pred_flat))

In [None]:
# Plasticity score next to each cell's state label, aligned on the cell index.
plasticity_df = pd.concat(
    [plasticity, predatac.obs["highlevel_wtop"]],
    axis=1,
)

In [None]:
# Horizontal boxplot of the plasticity score for each cNMF state.
fig, ax = plt.subplots(1, 1, figsize=(1.5, 2))
state_order = ["cNMF_1","cNMF_2","cNMF_3","cNMF_4","cNMF_5"]
sns.boxplot(
    data=plasticity_df,
    x="Plasticity score",
    y="highlevel_wtop",
    order=state_order,
    palette=colormapping_mal,
    ax=ax,  # the axes just created, i.e. the current axes
)
pretty_ax_wlabels(ax)
ax.set_ylabel("")
# Swap the plain state names for mathtext labels with a subscripted index.
pretty_state_labels = ["cNMF$_{1}$","cNMF$_{2}$","cNMF$_{3}$","cNMF$_{4}$","cNMF$_{5}$"]
ax.set_yticks(ax.get_yticks(), pretty_state_labels)
fig.savefig("figures/malignant/plasticity_score.svg", dpi=200, bbox_inches="tight")

In [None]:
# Confusion matrix (rows: true class, columns: predicted class).
# `Series.ravel()` is deprecated in recent pandas, hence `to_numpy()`; the
# prediction frame is flattened to 1-D like in the metrics cell.
confs = confusion_matrix(y.to_numpy(), y_pred.to_numpy().ravel())

# Row-normalise: fraction of each true class assigned to every predicted class.
perc = confs / confs.sum(axis=1, keepdims=True)

In [None]:
# Draw the row-normalised confusion matrix `perc` as a grid of bubbles and save
# it as an SVG.
# NOTE(review): `cm` is imported but not referenced in this cell.
import itertools
import matplotlib
import matplotlib.cm as cm
import matplotlib.colors as mcolors

# Every (true, predicted) index pair for 5 classes, in row-major order so that
# it lines up with the row-wise flattening `np.hstack(perc)` below.
# Assumes `perc` is 5x5 — TODO confirm all five classes are present.
coordinates = pd.DataFrame(np.array(list(itertools.product(np.arange(0,5), np.arange(0,5)))), columns=["True","Pred."])

# NOTE(review): this column is stored but the plot call below passes
# `np.hstack(perc)` directly rather than reading "Conf".
coordinates["Conf"] = np.hstack(perc)

# Mirror the true-class index (0<->4, 1<->3) so the first class ends up at the
# top of the y axis; the y ticks below are labelled accordingly.
coordinates["True"] = coordinates["True"].replace({0:4, 1:3, 2:2, 3:1, 4:0})

# Diverging colour scale over [0, 1] with its centre at 0.1.
vcenter = 0.1
vmin, vmax = 0,1
normalize = mcolors.TwoSlopeNorm(vcenter=vcenter, vmin=vmin, vmax=vmax)
colormap = matplotlib.colormaps['RdBu_r']

fig, ax = plt.subplots(1,1,figsize=(2,2))
# Bubble size and colour both encode the row-normalised fraction. No `ax=` is
# passed, so seaborn picks the axes itself — presumably the one created above.
sns.scatterplot(data=coordinates, x="Pred.", y="True", size=np.hstack(perc), c=np.hstack(perc), norm=normalize,
        cmap=colormap,)
plt.legend(bbox_to_anchor=(1,1,0,0), frameon=False)
# First pass: tick positions, plain labels and label rotation/alignment.
ax.set_xticks([0,1,2,3,4], [f"cNMF_{i}" for i in range(1,6)], rotation=45, ha="right")
ax.set_yticks([4,3,2,1,0], [f"cNMF_{i}" for i in range(1,6)], ha="right")
pretty_ax_wlabels(ax)
# Second pass: same tick positions, label text replaced by mathtext versions.
ax.set_yticks(ax.get_yticks(), ["cNMF$_{1}$","cNMF$_{2}$","cNMF$_{3}$","cNMF$_{4}$","cNMF$_{5}$"])
ax.set_xticks(ax.get_xticks(), ["cNMF$_{1}$","cNMF$_{2}$","cNMF$_{3}$","cNMF$_{4}$","cNMF$_{5}$"])
# x carries the predicted class ("Pred."), y the true class ("True").
ax.set_xlabel("ATAC identity")
ax.set_ylabel("RNA identity")
fig.savefig("figures/malignant/confusion_matrix_prediction_atac_from_rna.svg", dpi=200, bbox_inches="tight")

# Plot

In [None]:
# Cells in shuffled order; columns are each state's selected regions laid out
# state after state (not de-duplicated). The shuffle uses a fixed seed so that
# Restart & Run All reproduces the same object (the original draw was unseeded).
heatmap_cell_order = np.random.default_rng(0).permutation(subatac.shape[0])
redatac = subatac[heatmap_cell_order,
    np.hstack([n_regions[st].index for st in n_regions])].copy()

In [None]:
# Cell labels ordered by their state label.
cell_idx = redatac.obs["highlevel_wtop"].sort_values().index.to_numpy()

# State label of every cell in that order, and the matching colour per row.
mispred = redatac.obs.loc[cell_idx,"highlevel_wtop"]
row_colors = [colormapping_mal[mispred.loc[cell]] for cell in cell_idx]

In [None]:
# Accessibility table with rows in state-sorted order, z-scored per column.
df = redatac.to_df().loc[cell_idx].copy()
col_mean, col_sd = df.mean(), df.std()
df = (df - col_mean) / col_sd

In [None]:
# Heatmap of z-scored accessibility: rows are cells (grouped by state, coloured
# via `row_colors`), columns are the per-state region blocks. No clustering.
clmap = sns.clustermap(data=df, row_cluster=False, 
                       col_cluster=False, row_colors=row_colors, cmap="vlag", center=0, vmax=2, vmin=-2)
heat_ax = clmap.ax_heatmap
heat_ax.set_xticklabels([])
heat_ax.set_xticks([])
heat_ax.set_yticklabels([])
heat_ax.set_yticks([])
clmap.ax_cbar.set_position((0.82, .32, .03, .4))
clmap.ax_cbar.set_title('ATAC\nZ-score')

# Vertical separators between the per-state region blocks. The widths come from
# `n_regions` (same iteration order as the columns of `redatac`), which gives
# 200, 400, 600, 800 when every state has 200 regions but stays correct when a
# state has fewer significant regions.
region_block_ends = np.cumsum([len(n_regions[st]) for st in n_regions])[:-1]
heat_ax.vlines(region_block_ends, 0, redatac.shape[0], linewidth=2, color="gray")

# Horizontal separators between the state groups of cells.
# `Series.ravel()` is deprecated in recent pandas, hence `to_numpy()`.
state_order = ["cNMF_1","cNMF_2","cNMF_3","cNMF_4","cNMF_5"]
cell_block_ends = np.cumsum(mispred.value_counts().loc[state_order].to_numpy())[:-1]
heat_ax.hlines(cell_block_ends, 0, redatac.shape[1], linewidth=2, color="gray")
clmap.savefig("figures/malignant/heatmap_ATAC_to_RNA_openness.png", dpi=300, bbox_inches="tight")