In [None]:
import glob
import os

import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotnine as p9
import scanpy as sc
import squidpy as sq
import yaml
from tqdm import tqdm

In [None]:
# Display names for the pathology annotations. The embedded "\n" splits the
# longer names over two lines wherever they are used as plot labels.
translate = {
    "Tumor": "Tumor",
    "Stroma": "Stroma",
    "Normal lymphoid tissue": "Normal\nlymphoid",
    "Blood and necrosis": "Blood/\nnecrosis",
    "Pigment": "Pigment",
}

In [None]:
# Dataset directory name; all paths below are resolved as ../{dataset}/...
dataset: str = "10x_TuPro"

In [None]:
# Dataset-level configuration: model names, sample ids, known genes and the
# number of genes to predict.
config_path = f"../{dataset}/config_dataset.yaml"
with open(config_path, "r") as config_file:
    config_dataset = yaml.safe_load(config_file)

models = config_dataset["MODEL"]
all_samples = set(config_dataset["SAMPLE"])
known_genes = np.array(config_dataset["known_genes"])
top_n_genes_to_predict = int(config_dataset["top_n_genes_to_predict"])

top_n_genes_to_predict

In [None]:
out_folder = "out_benchmark"
genes = pd.read_csv(f"../{dataset}/{out_folder}/info_highly_variable_genes.csv")

# Keep only the genes flagged (isPredicted) as prediction targets.
selected_genes_bool = genes["isPredicted"].to_numpy()
genes_to_predict = genes.loc[selected_genes_bool]

genes_to_predict

In [None]:
# Load measured (ground-truth) and predicted expression for one showcase sample.
# Layout: adata_true_list[sample] -> AnnData, adata_pred_list[sample][model] -> AnnData.
adata_pred_list = {s: {} for s in all_samples}
adata_true_list = {}
sample = "MELIPIT-1-2"
for sample in [sample]:
    # Measured Visium expression restricted to the predicted genes, then
    # total-count normalised and log1p-transformed.
    adata_true = sc.read_h5ad(f"../{dataset}/{out_folder}/data/h5ad/{sample}.h5ad")
    adata_true = adata_true[:, adata_true.var.index.isin(genes_to_predict.gene_name)]
    sc.pp.normalize_total(adata_true)
    sc.pp.log1p(adata_true)

    adata_true.var["method"] = "Visium, 10x Genomics"
    adata_true.obs["method"] = "Visium, 10x Genomics"
    adata_true.obs["sample_id"] = sample
    adata_true_list[sample] = adata_true

    for model in models:
        # Use out_folder consistently (this path used to hard-code "out_benchmark").
        top_model = pd.read_csv(f"../{dataset}/{out_folder}/evaluation/{model}/top_model_per_test_sample.csv")
        # test_sample holds "_"-joined sample ids; take the first row whose ids
        # contain this sample exactly once.
        in_test_fold = top_model.test_sample.apply(lambda x: x.split("_").count(sample) == 1)
        if not in_test_fold.any():
            raise ValueError(f"No row of top_model_per_test_sample.csv for {model!r} contains sample {sample!r}")
        row = top_model[in_test_fold].iloc[0]

        # Run/sample names may contain '[' or ']', which glob treats as a character
        # class: escape the literal parts and keep '*' only for the unknown directory.
        path = (
            glob.escape(f"../{dataset}/{out_folder}/evaluation/{row.test_sample}")
            + "/*/"
            + glob.escape(f"{model}/prediction/{row.model}_test.pkl")
        )
        matches = sorted(glob.glob(path))  # sorted: deterministic pick if several files match
        if not matches:
            raise FileNotFoundError(f"No prediction file matches {path!r}")
        expression_predicted_file = matches[0]
        expression_predicted = pd.read_pickle(expression_predicted_file)

        # Index entries are "_"-separated: field 1 is the sample id, field 0 the spot id.
        is_this_sample = expression_predicted.index.to_series().apply(lambda x: sample == x.split("_")[1])
        expression_predicted = expression_predicted[is_this_sample]
        expression_predicted.index = [i.split("_")[0] for i in expression_predicted.index]
        # Align genes and spots with the ground truth.
        expression_predicted = expression_predicted[adata_true.var.index]
        counts = expression_predicted.loc[adata_true.obs_names].values
        counts[counts < 0] = 0  # clip negative predictions to zero
        # NOTE(review): constant +1 offset on all predicted values; its purpose is not documented here — confirm.
        counts = counts + 1
        adata_predicted = ad.AnnData(counts, obs=adata_true.obs)
        adata_predicted.obs_names = adata_true.obs_names
        adata_predicted.obs["ground_truth"] = adata_true.obs.ground_truth
        adata_predicted.var_names = expression_predicted.columns
        sc.pp.highly_variable_genes(adata_predicted, n_top_genes=3000)
        adata_predicted.obs["method"] = model
        adata_predicted.obs["sample_id"] = sample
        adata_predicted.var["method"] = model
        adata_pred_list[sample][model] = adata_predicted

In [None]:
# Stack, per sample, the predictions of every model into one AnnData.
# NOTE(review): despite the variable name, the keys are sample ids, not models.
# Samples with no loaded predictions are skipped.
adata_pred_list_model = {
    sample_id: ad.concat(list(preds_by_model.values()))
    for sample_id, preds_by_model in adata_pred_list.items()
    if preds_by_model
}

In [None]:
model = 'DeepSpot'


def _with_annotation_without_pigment(adata):
    """Return a copy of ``adata`` with a display-friendly "Pathology annotation"
    column (mapped through ``translate``) and the Pigment spots removed."""
    labelled = adata.copy()
    labelled.obs["Pathology annotation"] = labelled.obs.ground_truth.apply(lambda x: translate[x])
    not_pigment = labelled.obs.ground_truth != "Pigment"
    return labelled[not_pigment].copy()


adata_sample_true = _with_annotation_without_pigment(adata_true_list[sample])
adata_sample_predicted = _with_annotation_without_pigment(adata_pred_list[sample][model])

In [None]:
# PCA -> kNN graph -> UMAP on the predicted expression; the embedding in
# obsm["X_umap"] is used by the latent-space figure.
N_NEIGHBORS = 50
sc.pp.pca(adata_sample_predicted)
sc.pp.neighbors(adata_sample_predicted, n_neighbors=N_NEIGHBORS)
sc.tl.umap(adata_sample_predicted)

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt

# Publication settings for the figures below: 300-dpi rendering, 20-pt base font.
plt.rcParams.update({'figure.dpi': 300, 'font.size': 20})

import pyvips
import umap
from matplotlib.offsetbox import OffsetImage, AnnotationBbox 
import matplotlib as mpl

def get_umap2(emb):
    """Project ``emb`` with a default-configured ``umap.UMAP`` and return the embedding."""
    return umap.UMAP().fit_transform(emb)

# libvips band-format name (pyvips ``Image.format``) -> numpy scalar type, used
# to view the raw bytes returned by ``write_to_memory`` as an array.
format_to_dtype = dict(
    uchar=np.uint8,
    char=np.int8,
    ushort=np.uint16,
    short=np.int16,
    uint=np.uint32,
    int=np.int32,
    float=np.float32,
    double=np.float64,
    complex=np.complex64,
    dpcomplex=np.complex128,
)

def get_spot(image, x, y, spot_diameter_fullres, width_ratio=1.3):
    """Crop a patch around a spot from a pyvips image and return it as a numpy array.

    Parameters
    ----------
    image : pyvips.Image
        Full-resolution slide image.
    x, y : float
        Spot position in full-resolution pixels. After subtracting half the
        diameter they become the crop's left and top edges.
    spot_diameter_fullres : float
        Spot diameter in pixels; sets the crop height and the offset.
    width_ratio : float, default 1.3
        Crop width is ``spot_diameter_fullres // width_ratio`` (1.3 was the
        previously hard-coded value).

    Returns
    -------
    numpy.ndarray
        Array of shape ``(height, width, bands)`` with the dtype given by
        ``format_to_dtype[spot.format]``.

    Notes
    -----
    NOTE(review): the left edge is shifted by diameter/2 but the width is only
    diameter/width_ratio, so the patch is not horizontally centred on the
    spot. Confirm this is intended.
    """
    half = int(spot_diameter_fullres // 2)
    # Crop geometry is in whole pixels; cast explicitly rather than passing
    # floats (e.g. 320 // 1.3 == 246.0) and relying on pyvips' coercion.
    left = int(x - half)
    top = int(y - half)
    width = int(spot_diameter_fullres // width_ratio)
    height = int(spot_diameter_fullres)
    spot = image.crop(left, top, width, height)
    return np.ndarray(
        buffer=spot.write_to_memory(),
        dtype=format_to_dtype[spot.format],
        shape=[spot.height, spot.width, spot.bands],
    )

def create_pc_qcut(pc1, labels, n=10):
    """Tag each element with ``"<bin>_<label>"`` from per-label quantile binning.

    ``pc1`` is split into ``n`` equal-frequency bins separately within each
    label, so bin numbers are only comparable inside one label.

    Parameters
    ----------
    pc1 : array-like of float
        Values to bin (e.g. the first UMAP coordinate).
    labels : array-like
        Group label per element, same length as ``pc1``.
    n : int, default 10
        Number of quantile bins per label.

    Returns
    -------
    numpy.ndarray of object
        Strings ``f"{bin}_{label}"`` with ``bin`` in ``0..n-1``. There are fewer
        bins when quantile edges coincide.
    """
    pc1 = np.asarray(pc1)
    labels = np.asarray(labels)

    # Use object dtype. The previous np.empty(...).astype(str) produced a
    # fixed-width '<U32' array that silently truncated tags for long labels.
    pc_q = np.empty(len(labels), dtype=object)

    for label in np.unique(labels):
        mask = labels == label
        # duplicates="drop": heavily tied values would otherwise make qcut raise.
        bins = pd.qcut(pc1[mask], n, labels=False, duplicates="drop")
        pc_q[mask] = [f"{b}_{label}" for b in bins]

    return pc_q
    

def select_idx(pc1, labels, n=3, quantile_bins=(1, 5, 8), rng=None):
    """Randomly pick up to ``n`` positional indices per (quantile bin, label) group.

    ``pc1`` is binned into deciles separately within each label (via
    ``create_pc_qcut``). Only the bins listed in ``quantile_bins`` are sampled,
    which spreads the picks along the ``pc1`` axis.

    Parameters
    ----------
    pc1 : array-like of float
        1-D coordinate used for binning (e.g. UMAP_1).
    labels : array-like
        Group label per element, same length as ``pc1``.
    n : int, default 3
        Maximum number of indices drawn per group, without replacement.
    quantile_bins : iterable of int, default (1, 5, 8)
        Zero-based decile numbers to sample from.
    rng : numpy.random.Generator or None
        Source of randomness. ``None`` uses the global ``np.random`` state.

    Returns
    -------
    numpy.ndarray of int
        Selected positional indices (empty int array if nothing matched).
    """
    import itertools

    chooser = np.random if rng is None else rng
    pc_q = create_pc_qcut(pc1, labels)
    labels = np.asarray(labels)

    wanted_bins = set(quantile_bins)
    pc_q_unique = [qc for qc in np.unique(pc_q) if int(qc.split("_")[0]) in wanted_bins]
    print(pc_q_unique)

    indxs = []
    for pc_q_group, label_group in itertools.product(pc_q_unique, np.unique(labels)):
        candidates = np.flatnonzero((pc_q == pc_q_group) & (labels == label_group))
        if candidates.size > 0:
            # min(n, count), not count - 1: a singleton group must still
            # contribute, and a group of exactly n must yield all n.
            size_n = min(n, candidates.size)
            indxs.extend(chooser.choice(candidates, size_n, replace=False))
    return np.array(indxs, dtype=int)

def getImage(spot, zoom=1):
    """Wrap an image array in a matplotlib ``OffsetImage`` scaled by ``zoom``."""
    image_box = OffsetImage(spot, zoom=zoom)
    return image_box

# Figure 4A: UMAP of the predicted expression with H&E thumbnails of a few
# spots sampled along UMAP_1 within each pathology annotation.
SEED = 0
np.random.seed(SEED)  # select_idx samples from the global NumPy RNG; fix it for a reproducible figure

img_path = f"../{dataset}/{out_folder}/data/image/{sample}.tif"
image = pyvips.Image.new_from_file(img_path)

labels = adata_sample_predicted.obs["Pathology annotation"]
emb = adata_sample_predicted.obsm["X_umap"]
# Positional indices of a few spots per (UMAP_1 quantile bin, annotation) group.
indeces_to_plot = select_idx(emb[:, 0], labels)

colormap = {
    'Tumor': 'r',
    'Blood/\nnecrosis': 'b',
    'Normal\nlymphoid': 'g',
    'Stroma': 'purple'
}

spot_diameter_fullres = 320  # thumbnail crop size in full-resolution pixels
tab = adata_sample_predicted.obs.iloc[indeces_to_plot]
x, y = emb[indeces_to_plot, 0], emb[indeces_to_plot, 1]

fig, ax = plt.subplots(figsize=(9, 8))
# Iterate over values instead of labels[i]: obs_names are strings, so integer
# keys relied on pandas' deprecated positional fallback in Series.__getitem__.
ax.scatter(emb[:, 0], emb[:, 1], c=[colormap[label] for label in labels], alpha=1, s=10)
for x_i, y_i, (_, row) in zip(x, y, tab.iterrows()):
    # NOTE(review): y_pixel is passed as get_spot's horizontal coordinate —
    # presumably the obs columns use (row, col) naming; confirm.
    img = get_spot(image, row.y_pixel, row.x_pixel, spot_diameter_fullres)
    edge_color = colormap[row["Pathology annotation"]]
    bbox_props = dict(boxstyle="square,pad=0.02", fc="none", ec=edge_color, lw=3)
    ab = AnnotationBbox(getImage(img, zoom=0.18), (x_i, y_i), frameon=True, bboxprops=bbox_props)
    ax.add_artist(ab)

# Built once, outside the loop, so it also exists when no thumbnails were selected.
legend_patches = [mpl.patches.Patch(color=color, label=label) for label, color in colormap.items()]
ax.legend(handles=legend_patches, fontsize=20)
ax.set_xlabel("UMAP_1")
ax.set_ylabel("UMAP_2")
os.makedirs("figures", exist_ok=True)
fig.savefig(f"figures/Figure4A-{sample}_{dataset}_spots_in_latent_space.png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Wilcoxon ranking of *all* genes (n_genes = number of variables) per
# annotation; the Tumor table below is filtered from this result.
sc.tl.rank_genes_groups(
    adata_sample_predicted,
    groupby='Pathology annotation',
    method='wilcoxon',
    n_genes=adata_sample_predicted.n_vars,
    key_added="wilcoxon",
)

In [None]:
plt.rcParams.update({'font.size': 20})

# Figure 4B: top-10 marker genes per annotation from the Wilcoxon ranking of
# the predicted expression, coloured by log-fold-change.
# scanpy appends a `save=` string to its own default filename under
# sc.settings.figdir, so save="figures/..." would target
# "figures/matrixplot_figures/...". Render with show=False and save explicitly.
sc.pl.rank_genes_groups_matrixplot(
    adata_sample_predicted,
    n_genes=10,
    min_logfoldchange=1,
    values_to_plot="logfoldchanges",
    colorbar_title="logFC",
    key="wilcoxon",
    vmin=-2,
    vmax=2,
    figsize=(15, 3),
    groupby="Pathology annotation",
    show=False,
)
os.makedirs("figures", exist_ok=True)
plt.savefig(f"figures/Figure4B-{sample}_{model}_matrixplot_genes.png", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
# Tumor-vs-rest genes: keep adjusted p < 0.01, strongest log-fold-change first.
df = (
    sc.get.rank_genes_groups_df(adata_sample_predicted, group='Tumor', key="wilcoxon")
    .loc[lambda table: table.pvals_adj < 0.01]
    .sort_values("logfoldchanges", ascending=False)
)

In [None]:
import gseapy as gp
from gseapy import barplot, dotplot

# Enrichr query with the 100 Tumor genes of largest logFC (df is sorted by
# logFC, descending). Bind only gene_list: `genes` still holds the gene table
# loaded at the top of the notebook.
gene_list = df.names.iloc[:100].tolist()

gene_sets = [
    "Cancer_Cell_Line_Encyclopedia"
]

enr = gp.enrichr(gene_list=gene_list,
                 gene_sets=gene_sets,
                 organism='human',
                 outdir=None,  # don't write to disk
                 )

# One colour per gene-set library. pyplot.get_cmap replaces plt.cm.get_cmap,
# which was deprecated in Matplotlib 3.7 and removed in 3.9.
cmap = plt.get_cmap("tab20", len(gene_sets))

In [None]:
# Figure 4C1: top-15 enriched terms ranked by adjusted p-value.
# (Passing group='Gene_set' allows a multi-library comparison.)
library_colors = {library: cmap(i) for i, library in enumerate(gene_sets)}
ax = barplot(
    enr.results,
    column="Adjusted P-value",
    size=10,
    top_term=15,
    figsize=(5, 6),
    title="Cancer Cell Line Encyclopedia",
    color=library_colors,
)
plt.savefig("figures/Figure4C1-Cancer_Cell_Line_Encyclopedia_100.png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Ground-truth markers: Wilcoxon ranking of the measured expression per
# annotation (top n_genes, pts=True), shown as a ranking plot and a heatmap.
n_genes = 200
sc.tl.rank_genes_groups(
    adata_sample_true,
    groupby='Pathology annotation',
    method='wilcoxon',
    n_genes=n_genes,
    key_added="wilcoxon",
    pts=True,
)
sc.pl.rank_genes_groups(adata_sample_true, key="wilcoxon")
sc.pl.rank_genes_groups_heatmap(
    adata_sample_true,
    groupby="Pathology annotation",
    key="wilcoxon",
    n_genes=15,
    show_gene_labels=True,
)

In [None]:
plt.rcParams.update({'font.size': 20})

# Marker genes of the predicted expression (top n_genes per annotation), with
# the same settings as for the ground truth.
# NOTE: this overwrites the full "wilcoxon" ranking computed earlier; the Tumor
# table (df) was already extracted from it.
sc.tl.rank_genes_groups(adata_sample_predicted, 'Pathology annotation', method='wilcoxon',
                        n_genes=n_genes, key_added="wilcoxon", pts=True)
sc.pl.rank_genes_groups(adata_sample_predicted, key="wilcoxon")

# Figure 4D heatmap. Saved explicitly: scanpy appends a `save=` string to its
# own default filename, so a "figures/..." value would land in a
# non-existent "figures/heatmapfigures/..." path.
sc.pl.rank_genes_groups_heatmap(adata_sample_predicted, n_genes=15,
                                key="wilcoxon", groupby="Pathology annotation",
                                show_gene_labels=True, show=False)
os.makedirs("figures", exist_ok=True)
plt.savefig(f"figures/Figure4D_{sample}_{model}_heatmap_genes.png", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
import decoupler as dc

# PROGENy pathway -> target weights (human, top=500).
progeny = dc.get_progeny(organism='human', top=500)

# Multivariate linear model of pathway activity on the predicted expression;
# the next cell reads the result from obsm["mlm_estimate"].
mlm_options = dict(
    source='source',
    target='target',
    weight='weight',
    min_n=5,
    use_raw=False,
    verbose=True,
)
dc.run_mlm(mat=adata_sample_predicted, net=progeny, **mlm_options)

In [None]:
# Copy the MLM activity estimates under a separate obsm key and build an
# AnnData from it with dc.get_acts (plotted below).
# NOTE(review): despite the "_norm" suffix no normalisation is applied here;
# scaling only happens later via standard_scale="var" in the matrixplot.
mlm_estimate = adata_sample_predicted.obsm["mlm_estimate"]
adata_sample_predicted.obsm["mlm_estimate_norm"] = pd.DataFrame(
    mlm_estimate,
    index=mlm_estimate.index,
    columns=mlm_estimate.columns,
)
acts = dc.get_acts(adata_sample_predicted, obsm_key='mlm_estimate_norm')

acts

In [None]:
plt.rcParams.update({'font.size': 20})

# Pathway activity (PROGENy / MLM) per pathology annotation, scaled per
# pathway (standard_scale="var"). Saved explicitly: scanpy appends a `save=`
# string to its own default filename ("matrixplot_..." under
# sc.settings.figdir), so a "figures/..." value would not land where intended.
sc.pl.matrixplot(
    acts,
    var_names=acts.var_names,
    groupby="Pathology annotation",
    dendrogram=False,
    colorbar_title='Activity\n',
    cmap='viridis',
    figsize=(10, 3),
    title="Pathway",
    standard_scale="var",
    show=False,
)
os.makedirs("figures", exist_ok=True)
plt.savefig(f"figures/Figure4D-{sample}_{model}_pathway.png", dpi=300, bbox_inches="tight")
plt.show()