In [None]:
import os
#os.chdir('../../TCGA_SKCM/') #
#os.chdir('../../TCGA_KIRC/') #
#os.chdir('../../TCGA_COAD/') #

In [None]:
import pandas as pd
import numpy as np
import plotnine as p9
import scanpy as sc
import glob
from tqdm import tqdm
import pandas as pd
import anndata as ad
import numpy as np
import sys
sys.path.append('../')
from src.utils import bootstrapping
import yaml

import glob
from tqdm import tqdm
import scanpy as sc
from concurrent.futures import ProcessPoolExecutor, as_completed

from sklearn.model_selection import train_test_split
import numpy as np
import random
from plotnine_prism import *

In [None]:
# Root folder of this benchmark run; the prediction/data h5ad paths read below and the
# final CSV written at the end are all built relative to it.
out_folder = "out_benchmark_ffpe"

In [None]:
# Read the dataset configuration and unpack the entries this notebook uses.
with open("config_dataset.yaml", "r") as config_file:
    config_dataset = yaml.safe_load(config_file)

models, source_data_path, metadata_path, metadata_path_other = (
    config_dataset[key]
    for key in ("MODEL", "source_data_path", "metadata_path", "metadata_path_other")
)
# Show the two metadata locations as the cell output.
metadata_path, metadata_path_other

In [None]:
# Load the metadata table and keep only sample types that occur often enough.
MIN_SAMPLES_PER_TYPE = 5  # a sample type is kept only if it has MORE than this many rows
metadata = pd.read_csv(metadata_path)
sample_type_counts = metadata["sample_type"].value_counts()
classes = sample_type_counts[sample_type_counts > MIN_SAMPLES_PER_TYPE].index
# .copy() detaches the filtered frame from the raw one so that adding the label column
# below writes to a real frame instead of a slice (avoids SettingWithCopyWarning).
metadata = metadata[metadata["sample_type"].isin(classes)].copy()
# Binary target used later: True for rows whose sample_type is exactly "Metastatic".
metadata["isMetastatic"] = metadata["sample_type"] == "Metastatic"
metadata

In [None]:
# Keep only the rows whose case_id also appears in the second metadata table.
metadata_other = pd.read_csv(metadata_path_other)
shared_case_mask = metadata["case_id"].isin(metadata_other["case_id"])
metadata = metadata[shared_case_mask]
metadata

In [None]:
import matplotlib.pyplot as plt

# Histogram of how many metadata rows each case_id has (log option enabled).
rows_per_case = metadata["case_id"].value_counts()
rows_per_case.plot.hist(log=1)
plt.show()

In [None]:
# Number of metadata rows per case_id (value_counts sorts by count, largest first, by default).
metadata.case_id.value_counts()

In [None]:
# Label table: the first metadata row of every case_id, indexed by case_id.
metadata_labels = metadata.drop_duplicates("case_id").set_index("case_id")
# The full table keeps all rows (a case may appear several times) but is indexed by case_id too.
metadata = metadata.set_index("case_id")
metadata

In [None]:
# Count of distinct id_pair values in the metadata.
len(pd.unique(metadata["id_pair"]))

In [None]:
# Map every case_id (the metadata index) to the list of its id_pair values, in row order.
# Walking the index and the column together avoids iterrows(), which builds a full
# Series object for every row.
case_ids_to_id_pair = {}
for case_id, id_pair in zip(metadata.index, metadata["id_pair"]):
    case_ids_to_id_pair.setdefault(case_id, []).append(id_pair)

# Preview one case. .get() is used because this particular case_id may be absent
# (e.g. when the notebook is pointed at a different dataset); plain indexing would
# raise KeyError and abort a full run.
case_ids_to_id_pair.get('TCGA-ER-A2NF')

In [None]:
# Gene table from the source dataset's out_benchmark folder; rows flagged by the
# isPredicted column are kept in genes_predict.
gene_info_path = f"../{source_data_path}/out_benchmark/info_highly_variable_genes.csv"
genes = pd.read_csv(gene_info_path)
selected_genes_bool = genes["isPredicted"].values
genes_predict = genes[selected_genes_bool]
genes_predict

In [None]:
import torch
# Worker count for the process pool used further down, taken from torch's
# current thread-count setting.
num_workers = torch.get_num_threads()
num_workers

In [None]:
# Function to process a single case (all sample files that belong to one case_id)
def process_file(case_id):
    """Build the predicted and observed pseudo-bulk profiles of one case.

    For every model in the module-level ``models`` list, the prediction files
    of all samples of the case are read and the sample with the largest
    ``len()`` (number of observations) is kept; ties resolve to the first
    sample. Negative predicted values are clipped to zero and the per-gene
    median over observations becomes that model's profile.

    The observed profile is the ``bulk_norm_tpm_unstranded`` column of
    ``adata.var`` read from ``{out_folder}/data/h5ad`` for the sample that was
    selected for the last model (same choice as before this rewrite).

    Parameters
    ----------
    case_id : hashable
        Key of the module-level ``case_ids_to_id_pair`` mapping; its value is
        the list of sample ids of the case.

    Returns
    -------
    pandas.DataFrame
        Indexed by gene, with one column per model, a
        ``bulk_norm_tpm_unstranded`` column and a constant ``case_id`` column.

    Raises
    ------
    ValueError
        If the case has no sample ids or ``models`` is empty, because no
        sample can be selected in either situation.
    """
    sample_ids = case_ids_to_id_pair[case_id]
    if len(sample_ids) == 0:
        raise ValueError(f"case {case_id!r} has no sample ids")
    if len(models) == 0:
        raise ValueError("`models` is empty; at least one model is required to select a sample")

    data = []
    selected = 0  # position (within sample_ids) of the sample chosen for the current model

    for model in models:
        predictions = [
            sc.read_h5ad(f"{out_folder}/prediction/{model}/data/h5ad/{sample_id}.h5ad")
            for sample_id in sample_ids
        ]
        # Keep the sample with the most observations.
        selected = int(np.argmax([len(prediction) for prediction in predictions]))
        adata = predictions[selected]

        # Clip negative predictions to zero, then summarise each gene by its median.
        expression = np.maximum(adata.X, 0)
        predicted_all_bulk = pd.Series(
            np.median(expression, axis=0), index=adata.var.index, name=f"{model}"
        )
        data.append(predicted_all_bulk)

    # Only the selected sample's observed file is needed. Values and gene index are taken
    # from that same file, so they cannot get out of step with each other.
    observed = sc.read_h5ad(f"{out_folder}/data/h5ad/{sample_ids[selected]}.h5ad")
    observed_bulk = observed.var["bulk_norm_tpm_unstranded"].rename("bulk_norm_tpm_unstranded")
    data.append(observed_bulk)

    # Columns are aligned on the gene index.
    bulk_data = pd.concat(data, axis=1)
    bulk_data["case_id"] = case_id
    return bulk_data


In [None]:
# Run process_file for every case in worker processes and gather the resulting frames.
bulk = []
with ProcessPoolExecutor(max_workers=num_workers) as pool:
    # One task per case_id.
    futures = [pool.submit(process_file, case_id) for case_id in case_ids_to_id_pair]

    # Frames are appended in completion order, not submission order.
    for finished in tqdm(as_completed(futures), total=len(futures)):
        bulk.append(finished.result())

In [None]:
# Stack the per-case frames into one long table.
bulk_counts = pd.concat(bulk)
# Every column except the observed profile and the case identifier is a model's prediction.
non_model_columns = ("bulk_norm_tpm_unstranded", "case_id")
models = [column for column in bulk_counts.columns if column not in non_model_columns]
# Expose the gene index as a regular column as well.
bulk_counts = bulk_counts.assign(gene_name=bulk_counts.index)
bulk_counts

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from scipy.stats import bootstrap
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline


def bootstrapping(x_list, random_state=None):
    """Bootstrap the median of ``x_list``.

    Note: from this cell onward this definition replaces the ``bootstrapping``
    imported from ``src.utils`` at the top of the notebook.

    Parameters
    ----------
    x_list : sequence of float
        Observations to resample.
    random_state : int, numpy.random.Generator, numpy.random.RandomState or None, optional
        Forwarded to ``scipy.stats.bootstrap`` when given; pass a fixed seed to
        get reproducible numbers. With the default ``None`` the call is
        identical to the previous, unseeded behaviour.

    Returns
    -------
    list
        ``[median, standard_error]``: the median of the bootstrap distribution
        of the median, and the bootstrap standard error reported by SciPy.
    """
    # Only forward the seed when one was supplied so the default call stays unchanged.
    seed_kwargs = {} if random_state is None else {"random_state": random_state}
    res = bootstrap((x_list,), np.median, **seed_kwargs)
    standard_error = res.standard_error
    median = np.median(res.bootstrap_distribution)
    return [median, standard_error]

def get_c_index(bulk, method_name, metadata, random_state=2024, n_repeats=1000,
                n_patients_grid=(75, 100, 125)):
    """Score how well expression profiles predict the ``isMetastatic`` label.

    Despite its name, this function reports F1 scores, not a concordance index.
    For every training-set size in ``n_patients_grid`` it repeats a stratified
    train/test split ``n_repeats`` times, fits a standardised logistic
    regression on the training part and records the F1 score on the rest. The
    repeated scores are then summarised with ``bootstrapping``.

    Parameters
    ----------
    bulk : pandas.DataFrame
        One row per case (index = case id), one column per feature.
    method_name : str
        Label written to the ``method`` column of the result.
    metadata : pandas.DataFrame
        Indexed by case id and containing an ``isMetastatic`` column. If a case
        occurs several times, its first row provides the label.
    random_state : int or None, default 2024
        Seed of the generator that drives all train/test splits, so repeated
        calls produce the same splits.
    n_repeats : int, default 1000
        Number of random splits per training-set size.
    n_patients_grid : iterable of int, default (75, 100, 125)
        Training-set sizes to evaluate; each must be smaller than the number of
        cases in ``bulk``.

    Returns
    -------
    pandas.DataFrame
        Columns ``method``, ``f1_mean``, ``f1_std``, ``n_patients``; one row per
        training-set size. ``f1_mean`` and ``f1_std`` are the two values
        returned by ``bootstrapping`` for the repeated F1 scores.
    """
    # One label per case: keep the first row for each index value, then align to `bulk`.
    labels = metadata.loc[~metadata.index.duplicated(keep="first")]
    y = labels.loc[bulk.index, "isMetastatic"].values
    X = bulk.values

    # A single seeded generator is shared by all splits: each split differs from the
    # previous one, yet the whole sequence is reproducible.
    rng = np.random.RandomState(random_state)

    scores = []
    for n_patients in n_patients_grid:
        score_patient = []
        for _ in range(n_repeats):
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, train_size=n_patients, stratify=y, random_state=rng
            )

            # Fresh scaler + classifier for every split.
            pipe = Pipeline([
                ('scaler', StandardScaler()),
                ('model', LogisticRegression(max_iter=1000))
            ])

            pipe.fit(X_train, y_train)
            score_patient.append(f1_score(y_test, pipe.predict(X_test)))

        # Summarise the repeated F1 scores for this training-set size.
        score_mean, score_std = bootstrapping(score_patient)
        scores.append([method_name, score_mean, score_std, n_patients])

    out = pd.DataFrame(scores, columns=["method", "f1_mean", "f1_std", "n_patients"])
    return out

In [None]:
# Evaluate the observed bulk profile and every model column found in bulk_counts.
# Using `models` instead of a fixed list of names avoids a KeyError when a model is not
# part of this run and automatically includes any additional one.
observed_column = "bulk_norm_tpm_unstranded"
display_names = {observed_column: "bulk RNA-seq"}  # label shown in tables and plots

score_frames = []
for st_type in tqdm([observed_column, *models]):
    # Wide table: one row per case, one column per gene. A new name is used on purpose so
    # the `bulk` list of per-case frames built earlier is not overwritten.
    expression_by_case = bulk_counts.pivot(index="case_id", columns="gene_name", values=st_type)
    res = get_c_index(expression_by_case, display_names.get(st_type, st_type), metadata)
    score_frames.append(res)
scores = pd.concat(score_frames)
scores

In [None]:
# Turn both plot axes into categoricals with an explicit category order:
# methods by their average f1_mean (highest first), training sizes ascending.
method_order = (
    scores.groupby("method")["f1_mean"].mean().sort_values(ascending=False).index
)
scores["method"] = pd.Categorical(scores["method"], method_order)
patient_levels = sorted(scores["n_patients"].unique())
scores["n_patients"] = pd.Categorical(scores["n_patients"], patient_levels)
# Horizontal offset between methods, shared by all layers of the plot below.
position_dodge_width = 0.3
scores

In [None]:
# F1 versus training-set size: one dashed line per method, with error bars of +/- f1_std.
f1_plot = (
    p9.ggplot(scores, p9.aes(x="n_patients", y="f1_mean", color="method", group="method"))
    + p9.geom_line(
        linetype="dashed",
        position=p9.position_dodge(width=position_dodge_width),
    )
    + p9.geom_point(position=p9.position_dodge(width=position_dodge_width))
    + p9.theme_bw()
    + p9.geom_errorbar(
        p9.aes(x="n_patients", ymin="f1_mean-f1_std", ymax="f1_mean+f1_std", color="method"),
        position=p9.position_dodge(width=position_dodge_width),
        width=0.4,
        alpha=1,
        size=0.5,
    )
    + scale_color_prism(palette="colors")
    + p9.theme(axis_text_x=p9.element_text(angle=90, hjust=1))
    + p9.ylab("F1 score")
)
f1_plot

In [None]:
# Write the score table to the prediction folder; index=False leaves the row index out of the file.
scores.to_csv(f"{out_folder}/prediction/tumor_type_prediction.csv", index=False)