In [None]:
import os
#os.chdir('../../TCGA_SKCM/')

In [None]:
import scanpy as sc
import glob
from tqdm import tqdm
import pandas as pd
import anndata as ad
import numpy as np
import sys
sys.path.append('../')
from src.utils import compute_pearson_top_n
from src.utils import compute_area_under_pearson_top_n
from src.utils import bootstrapping
import yaml

In [None]:
# Root folder of the benchmark outputs: the per-sample h5ad inputs, the model
# predictions and the exported evaluation table are all addressed relative to it.
out_folder = "out_benchmark"

In [None]:
# Load the dataset configuration. The encoding is pinned to UTF-8 (the usual
# encoding of YAML files) so that parsing does not depend on the platform's
# default locale encoding.
with open("config_dataset.yaml", "r", encoding="utf-8") as stream:
    config_dataset = yaml.safe_load(stream)

# Settings consumed by the cells below.
models = config_dataset['MODEL']  # iterated below: one prediction folder per entry
source_data_path = config_dataset['source_data_path']  # used to locate the highly-variable-genes table
metadata_path = config_dataset['metadata_path']  # CSV that provides the `id_pair` column

In [None]:
# Sample identifiers to evaluate, taken from the `id_pair` column of the
# metadata table and normalised to strings.
metadata = pd.read_csv(metadata_path)
id_pair = metadata["id_pair"].values.astype(str)
len(id_pair)

In [None]:
# Sanity check: for every model, count the prediction files whose sample id is
# listed in the metadata.
wanted_ids = set(map(str, id_pair))
for model in models:
    files = glob.glob(f'{out_folder}/prediction/{model}/data/h5ad/*')
    # Strip only the directory and the final extension, so the id matches the
    # `{sample_id}.h5ad` naming used when the files are read back, even when
    # the id itself contains a dot or the path uses another separator.
    files = [f for f in files if os.path.splitext(os.path.basename(f))[0] in wanted_ids]
    print(model, len(files))

In [None]:
# Table of highly variable genes shipped with the source dataset; only the rows
# flagged in `isPredicted` are kept for the evaluation.
hvg_table_path = f"../{source_data_path}/out_benchmark/info_highly_variable_genes.csv"
genes = pd.read_csv(hvg_table_path)
selected_genes_bool = genes["isPredicted"].values
genes_predict = genes[selected_genes_bool]
genes_predict

In [None]:
import torch
# Number of worker processes for the pool below, taken from torch's configured
# CPU thread count.
# NOTE(review): torch is imported only to obtain this number — confirm that is intended.
num_workers = torch.get_num_threads()
num_workers

In [None]:
import glob
from tqdm import tqdm
import scanpy as sc
from concurrent.futures import ProcessPoolExecutor, as_completed

# Function to process a single file
def process_file(sample_id):
    """Assemble observed and predicted bulk expression for one sample.

    Reads ``{out_folder}/data/h5ad/{sample_id}.h5ad`` for the observed values
    (``adata.var.bulk_norm_tpm_unstranded``) and, for every entry of the
    module-level ``models``,
    ``{out_folder}/prediction/{model}/data/h5ad/{sample_id}.h5ad`` for the
    predictions. Negative predicted values are set to 0 and the matrix is then
    averaged over axis 0, giving one value per ``adata.var`` entry.

    Parameters
    ----------
    sample_id : str
        File stem shared by the sample's ``.h5ad`` files.

    Returns
    -------
    pandas.DataFrame
        Indexed by the ``var`` index, with the columns
        ``bulk_norm_tpm_unstranded``, one column per model, and ``sample_id``.
    """
    observed = sc.read_h5ad(f"{out_folder}/data/h5ad/{sample_id}.h5ad")
    columns = [observed.var.bulk_norm_tpm_unstranded]

    for model in models:
        predicted = sc.read_h5ad(f"{out_folder}/prediction/{model}/data/h5ad/{sample_id}.h5ad")
        # Clip negative predictions to zero before averaging.
        predicted.X[predicted.X < 0] = 0

        # `mean(axis=0)` yields a 1-D array for a dense ndarray but a (1, n)
        # np.matrix for scipy sparse / np.matrix inputs; flatten so that
        # pd.Series always receives 1-D data.
        expr = np.asarray(predicted.X.mean(axis=0)).ravel()

        columns.append(pd.Series(expr, index=predicted.var.index, name=model))

    # Align observed and predicted values on the var index.
    bulk_data = pd.concat(columns, axis=1)
    bulk_data["sample_id"] = sample_id
    return bulk_data
# Main code

In [None]:
# Per-sample tables, gathered in completion order (not in `id_pair` order).
bulk = []

# Fan the samples out over a process pool; each task runs `process_file` for one id.
with ProcessPoolExecutor(max_workers=num_workers) as executor:
    futures = [executor.submit(process_file, sample_id) for sample_id in id_pair]

    # `result()` re-raises any exception raised inside a worker.
    finished = tqdm(as_completed(futures), total=len(id_pair))
    bulk.extend(future.result() for future in finished)

In [None]:
# Stack the per-sample tables into one long frame indexed by gene name.
bulk_counts = pd.concat(bulk, axis=0)

# Model columns are derived under their own name instead of rebinding the
# global `models`, which `process_file` reads at call time.
model_columns = [c for c in bulk_counts.columns if c not in ["bulk_norm_tpm_unstranded", "sample_id"]]
bulk_counts["gene_name"] = bulk_counts.index

# Samples x genes matrix of the observed bulk expression.
ref_bulk = bulk_counts.pivot(index="sample_id", columns="gene_name", values="bulk_norm_tpm_unstranded")

per_model_scores = []
for model in tqdm(model_columns):
    # Samples x genes matrix of this model's predicted bulk expression.
    pred_bulk = bulk_counts.pivot(index="sample_id", columns="gene_name", values=model)

    # Column-wise (per gene) Pearson correlation across samples; undefined
    # correlations (NaN) are scored as 0.
    score = ref_bulk.corrwith(pred_bulk, method="pearson").fillna(0).reset_index()
    score.columns = ["gene", "pearson"]
    score["model"] = model
    per_model_scores.append(score)

scores = pd.concat(per_model_scores)
scores = compute_pearson_top_n(scores, "model", genes_predict)
scores

In [None]:
# Mean Pearson score per (gene, model, top_n) combination.
per_gene = (
    scores
    .groupby(["gene", "model", "top_n"])["pearson"]
    .mean()
    .reset_index()
)

# Summarise each (model, top_n) group with `bootstrapping`.
# NOTE(review): the result is unpacked below as a (median, std) pair — confirm
# against `src.utils.bootstrapping`.
tab = (
    per_gene
    .groupby(["model", "top_n"])["pearson"]
    .apply(lambda x: bootstrapping(x))
    .reset_index()
)

# Split the per-group result into two columns and re-attach the group keys.
df_plot = pd.DataFrame(tab["pearson"].to_list(), columns=['pearson_median', 'pearson_std'])
df_plot = df_plot.assign(model=tab["model"], top_n=tab["top_n"])
df_plot.head()

In [None]:
import plotnine as p9
# Horizontal offset applied to every layer so that models sharing a `top_n`
# level are drawn side by side instead of on top of each other.
position_dodge_width = 0.5

# Turn `top_n` into an ordered categorical: levels are sorted on the original
# values first and only then converted to strings.
top_n_levels = df_plot["top_n"].drop_duplicates().sort_values().astype(str)
df_plot["top_n"] = pd.Categorical(df_plot["top_n"].astype(str), top_n_levels)

base_plot = p9.ggplot(df_plot, p9.aes("top_n", "pearson_median", color="model", group='model'))
line_layer = p9.geom_line(
    linetype="dashed",
    alpha=0.8,
    position=p9.position_dodge(width=position_dodge_width),
)
point_layer = p9.geom_point(position=p9.position_dodge(width=position_dodge_width))
# Error bars span one `pearson_std` on either side of `pearson_median`.
errorbar_layer = p9.geom_errorbar(
    p9.aes(x="top_n", ymin="pearson_median-pearson_std", ymax="pearson_median+pearson_std"),
    alpha=0.5,
    size=0.3,
    width=1,
    position=p9.position_dodge(width=position_dodge_width),
)

g = (
    base_plot
    + line_layer
    + point_layer
    + p9.theme_bw()
    + errorbar_layer
    + p9.ylab("Pearson correlation")
    + p9.xlab("Top highly variable genes")
    + p9.theme(figure_size=(16, 8))
)
# Optional export of the figure:
#g.save(f"{out_folder}/evaluation/pearson_score_per_top_n.png", dpi=300)
g

In [None]:
# Persist the summary table (one row per model / top_n pair) next to the
# predictions; the positional index is dropped from the CSV.
df_plot.to_csv(f"{out_folder}/prediction/model_evaluation_table.csv", index=False)