In [8]:
import numpy as np
import pandas as pd
import yaml
import pickle
from metrics import *

import warnings


def warn(*args, **kwargs):
    """Accept any positional/keyword arguments and do nothing."""
    return None


# Swap the module-level hook for the no-op above, so every later call that
# goes through ``warnings.warn`` is silently discarded.
warnings.warn = warn

In [9]:
def load_config(config_path):
    """Parse the YAML file at ``config_path`` and return the loaded object.

    The file is opened in text read mode and handed to ``yaml.safe_load``;
    whatever that call produces is returned unchanged.
    """
    with open(config_path, 'r') as stream:
        return yaml.safe_load(stream)
# Class-name lists used to build the multi-hot targets below; indexed as
# CATEGORIES_YAML["level1"], ["level2"], ["level3"] by the evaluation cell.
CATEGORIES_YAML = load_config("/hai/scratch/zwefers/seq2loc/metadata/level_classes.yaml")

In [10]:
# Evaluate every embedding model at every label level:
#   1. per fold, pick one decision threshold per class on the val_* predictions,
#   2. apply that fold's thresholds to the fold's test-set probabilities,
#   3. ensemble the five folds on the test set and score the ensemble.
# Per-class and averaged metrics are written as CSV files under outdir/<model>/.
trainset = pd.read_csv("/hai/scratch/zwefers/seq2loc/metadata/hpa_uniprot_combined_trainset.csv")
testset = pd.read_csv("/hai/scratch/zwefers/seq2loc/metadata/hou_testset.csv")
outdir = "outputs/seq2locbench/"
# Shares testset's index so that every prediction column added below lines up
# row-for-row with testset when the two frames are concatenated side by side.
predictions_df = pd.DataFrame(index=testset.index)


def _load_outputs(pkl_path, metadata):
    """Read a pickled prediction table and attach ``metadata`` columns to it.

    Rows are matched on ``ACC`` (prediction table) == ``uniprot_id`` (metadata)
    with an inner join, so predictions without a metadata row are dropped.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    outputs = pd.read_pickle(pkl_path)
    return outputs.merge(metadata, left_on="ACC", right_on="uniprot_id", how="inner")


def _encode_targets(label_strings, classes):
    """Multi-hot encode a Series of ';'-separated label strings.

    Returns an array with one row per entry of ``label_strings`` and one column
    per entry of ``classes``; a cell is 1 when that class name is among the
    row's labels and 0 otherwise.
    """
    rows = []
    for locs in label_strings.str.split(";").to_list():
        present = set(locs)
        rows.append([1 if loc in present else 0 for loc in classes])
    return np.array(rows)


for model in ["prott5", "esm1"]:
    val_avg_frames = []
    test_avg_frames = []
    for level in [1, 2, 3]:
        classes = CATEGORIES_YAML[f"level{level}"]
        labels = np.array(classes)

        val_perclass_frames = []
        test_ids_all = []
        test_targets_all = []
        test_probs_all = []
        test_preds_all = []

        for fold in range(5):
            # --- threshold selection on this fold's val_* predictions ---
            val_path = f"{outdir}/{model}/{fold}_1Layer_combined_level{level}_testout.pkl"
            val_df = _load_outputs(val_path, trainset)
            val_probs = np.stack(val_df.preds.to_numpy())
            val_targets = _encode_targets(val_df[f"level{level}"], classes)

            # One threshold per class column.
            thresholds = np.array([
                get_best_threshold_mcc(val_targets[:, i], val_probs[:, i])
                for i in range(val_targets.shape[1])
            ])
            #TODO: save thresholds

            _, val_metrics_perclass, val_metrics_avg = all_metrics(
                val_targets, val_probs, thresholds=thresholds)
            val_metrics_perclass["label"] = classes
            val_metrics_perclass["fold"] = fold
            val_metrics_avg["level"] = level
            val_metrics_avg["fold"] = fold
            val_perclass_frames.append(val_metrics_perclass)
            val_avg_frames.append(val_metrics_avg)

            # --- apply this fold's thresholds to the test set ---
            test_path = f"{outdir}/{model}/{fold}_1Layer_combined_level{level}_hou_testout.pkl"
            test_df = _load_outputs(test_path, testset)
            test_probs = np.stack(test_df.preds.to_numpy())
            test_targets = _encode_targets(test_df[f"level{level}"], classes)
            test_preds = test_probs > thresholds[np.newaxis, :]

            test_ids_all.append(test_df["uniprot_id"].to_numpy())
            test_targets_all.append(test_targets)
            test_probs_all.append(test_probs)
            test_preds_all.append(test_preds)

        val_perclass_df = pd.concat(val_perclass_frames)
        val_perclass_df.to_csv(f"{outdir}/{model}/val_metrics_perclass_level{level}.csv", index=False)

        # Averaging across folds is only meaningful when every fold's test
        # table lists the same proteins in the same order.
        test_ids = test_ids_all[0]
        assert all(np.array_equal(test_ids, fold_ids) for fold_ids in test_ids_all), \
            "test predictions are not in the same protein order across folds"
        test_targets_all = np.array(test_targets_all)
        assert np.all(test_targets_all[0, :, :] == test_targets_all)
        test_targets = test_targets_all[0, :, :]

        # Ensemble: mean probability across folds, and a majority vote of the
        # per-fold binary predictions (a label is kept when more than half of
        # the five folds predicted it).
        test_probs = np.array(test_probs_all).mean(axis=0)
        test_preds = (np.array(test_preds_all).mean(axis=0) > 0.5).astype(np.int32)

        # Store the predicted label sets in testset row order. Matching on
        # uniprot_id (instead of relying on the pickle's row order) keeps each
        # prediction next to the right protein; testset rows without a
        # prediction get NaN.
        predicted_labels = [set(labels[np.where(pred == 1)[0]]) for pred in test_preds]
        predicted_by_id = dict(zip(test_ids, predicted_labels))
        predictions_df[f"DL2_{model}_level{level}_predictions"] = [
            predicted_by_id.get(uid, np.nan) for uid in testset["uniprot_id"]
        ]

        #Cut out empty categories in testset like int-fils and plastid
        idxs = np.where(test_targets.sum(axis=0) != 0)[0]
        test_targets = test_targets[:, idxs]
        test_probs = test_probs[:, idxs]
        test_preds = test_preds[:, idxs]
        # NOTE(review): `thresholds` at this point is the array left over from
        # the last fold. The ensemble predictions are passed explicitly through
        # y_pred_bin, so confirm how all_metrics uses `thresholds` here.
        thresholds = thresholds[idxs]

        _, test_metrics_perclass, test_metrics_avg = all_metrics(
            test_targets,
            test_probs,
            y_pred_bin=test_preds,
            thresholds=thresholds,
        )
        test_metrics_perclass["label"] = labels[idxs]
        test_metrics_perclass.to_csv(
            f"{outdir}/{model}/test_metrics_perclass_level{level}.csv", index=False)

        test_metrics_avg["level"] = level
        test_avg_frames.append(test_metrics_avg)

    # These two names are rebuilt for every model, so once the loop finishes
    # they hold the results of the last model in the list only.
    val_avg_df = pd.concat(val_avg_frames)
    val_avg_df.to_csv(f"{outdir}/{model}/val_metrics_avg.csv", index=False)
    test_metrics_avg_df = pd.concat(test_avg_frames)
    test_metrics_avg_df.to_csv(f"{outdir}/{model}/test_metrics_avg.csv", index=False)

In [11]:
# Turn each ';'-separated label string in testset into a set of labels.
# Values that are not strings (for example sets left behind by an earlier run
# of this cell) are kept as they are, so re-running the cell is safe.
for _level_col in ("level1", "level2", "level3"):
    testset[_level_col] = testset[_level_col].apply(
        lambda value: set(value.split(";")) if isinstance(value, str) else value
    )

In [12]:
# Show ground-truth labels next to the ensemble predictions. An axis=1 concat
# pairs rows by index label, so predictions_df must be in the same row order
# as testset for each prediction to sit beside the right protein.
pd.concat([testset, predictions_df], axis=1)

Unnamed: 0,uniprot_id,ensembl_id,level1,level2,level3,sequence,DL2_prott5_level1_predictions,DL2_prott5_level2_predictions,DL2_prott5_level3_predictions,DL2_esm1_level1_predictions,DL2_esm1_level2_predictions,DL2_esm1_level3_predictions
0,A0A087X0K9,ENSG00000104067,{plasma-membrane},{plasma-membrane},{plasma-membrane},MSARAAAAKSTAMEETAIWEQHTVTLHRAPGFGFGIAISGGRDNPH...,"{vesicles, cytosol, microtubules, cytoskeleton...","{vesicles, cytosol, cytoskeleton}","{cytosol, cytoskeleton}",{nucleoplasm},{nucleus},{nucleus}
1,A0A0C4DGS1,ENSG00000244038,{endoplasmic-reticulum},{endoplasmic-reticulum},{endomembrane-system},MEPSTAARAWALFWLLLPLLGAVCASGPRTLVLLDNLNVRETHSLF...,"{vesicles, cytosol, microtubules, cytoskeleton...","{cytosol, cytoskeleton}","{cytosol, cytoskeleton}","{nuclear-bodies, nucleoplasm, nucleoli-fibrill...",{nucleus},{nucleus}
2,A0A1B0GTU4,ENSG00000089159,{plasma-membrane},{plasma-membrane},{plasma-membrane},MDDLDALLADLESTTSHISKRPVFLSEETPYSYPTGNHTYQEIAVP...,"{vesicles, microtubules, nucleoplasm}","{vesicles, nucleus}",{nucleus},"{nuclear-bodies, nucleoplasm, nucleoli-fibrill...",{nucleus},{nucleus}
3,A0A2R8YG42,ENSG00000102606,{plasma-membrane},{plasma-membrane},{plasma-membrane},MNSAEQTVTWLITLGVLESPKKTISDPEGFLQASLKDGVVLCRLLE...,"{vesicles, cytosol, nuclear-bodies, nucleoli-f...","{vesicles, cytosol, nucleus}",{nucleus},"{vesicles, lysosomes, nucleoli-fibrillar-cente...","{vesicles, plasma-membrane}","{endomembrane-system, plasma-membrane}"
4,A0A6Q8PGB0,ENSG00000181222,{nucleoplasm},{nucleus},{nucleus},MHGGGPPSGDSACPLRTIKRVQFGVLSPDELKRMSVTEGGIKYPET...,"{vesicles, cytosol, nuclear-bodies, microtubul...","{vesicles, cytosol, cytoskeleton}","{cytosol, cytoskeleton}","{cytosol, nucleoplasm, nucleoli-fibrillar-cent...",{nucleus},{nucleus}
...,...,...,...,...,...,...,...,...,...,...,...,...
3809,Q9Y6X4,ENSG00000198780,{nuclear-membrane},{nucleus},{nucleus},MAFPVDMLENCSHEELENSAEDYMSDLRCGDPENPECFSLLNITIP...,{cytosol},"{nucleus, cytosol}",{cytosol},"{cytosol, nucleoplasm}","{nucleus, cytosol}","{nucleus, cytosol}"
3810,Q9Y6X8,ENSG00000178764,{nucleoplasm},{nucleus},{nucleus},MASKRKSTTPCMVRTSQVVEQDVPEEVDRAKEKGIGTPQPDVAKDS...,"{nucleoli-fibrillar-center, nucleoplasm}","{nucleus, nucleoli}","{nucleus, nucleoli}","{nucleoli, nucleoplasm}","{nucleus, cytosol}","{nucleus, cytosol}"
3811,Q9Y6X9,ENSG00000133422,"{cytosol, nucleoplasm}","{nucleus, cytosol}","{nucleus, cytosol}",MAFTNYSSLNRAQLTFEYLHTNSTTHEFLFGALAELVDNARDADAT...,"{vesicles, mitochondria, endoplasmic-reticulum...","{vesicles, mitochondria}",{mitochondria},"{mitochondria, endoplasmic-reticulum}",{endoplasmic-reticulum},{}
3812,Q9Y6Y0,ENSG00000116679,{cytosol},{cytosol},{cytosol},MIPNGYLMFEDENFIESSVAKLNALRKSGQFCDVRLQVCGHEMLAH...,{mitochondria},{mitochondria},{mitochondria},"{vesicles, mitochondria, endoplasmic-reticulum}",{mitochondria},{mitochondria}


In [13]:
# Per-fold averaged val_* metrics. The evaluation loop rebuilds val_avg_df for
# each model, so this shows only the last model in its list ("esm1").
val_avg_df

Unnamed: 0,macro_ap,micro_ap,acc,f1_macro,f1_micro,jaccard_macro,jaccard_micro,rocauc_macro,rocauc_micro,mlrap,cov_error,num_labels,level,fold
0,0.313532,0.379915,0.256241,0.335243,0.508623,0.233324,0.508623,0.828956,0.849241,0.59548,5.681354,2.304463,1,0
0,0.323443,0.450596,0.193683,0.343265,0.43603,0.236491,0.43603,0.839082,0.865023,0.636962,5.307547,3.11916,1,1
0,0.307749,0.477932,0.215623,0.336025,0.429958,0.23341,0.429958,0.828384,0.869527,0.661702,5.085493,2.940987,1,2
0,0.308154,0.43995,0.259266,0.337292,0.475554,0.231973,0.475554,0.832141,0.86801,0.635695,5.158283,2.310136,1,3
0,0.322266,0.449575,0.257283,0.354143,0.497394,0.244661,0.497394,0.842035,0.872175,0.641016,5.091563,2.192584,1,4
0,0.562774,0.622263,0.30711,0.558278,0.597457,0.414985,0.597457,0.848554,0.87429,0.751466,2.888994,1.687595,2,0
0,0.560204,0.631755,0.351239,0.553215,0.61208,0.410425,0.61208,0.85422,0.877866,0.756554,2.858143,1.516928,2,1
0,0.56896,0.64016,0.405145,0.564738,0.62816,0.42105,0.62816,0.854312,0.882543,0.76537,2.770191,1.352563,2,2
0,0.554558,0.62406,0.345499,0.54341,0.592685,0.400785,0.592685,0.85316,0.877325,0.752287,2.843608,1.570537,2,3
0,0.567367,0.625252,0.368331,0.568381,0.615424,0.420068,0.615424,0.857447,0.879482,0.756395,2.832009,1.465759,2,4


In [14]:
# Averaged test metrics of the fold ensemble, one row per level. Rebuilt for
# each model in the evaluation loop, so only the last model ("esm1") is shown.
test_metrics_avg_df

Unnamed: 0,macro_ap,micro_ap,acc,f1_macro,f1_micro,jaccard_macro,jaccard_micro,rocauc_macro,rocauc_micro,mlrap,cov_error,num_labels,level
0,0.324091,0.422995,0.108547,0.338775,0.422436,0.231574,0.422436,0.845203,0.874091,0.637396,4.271106,3.300472,1
0,0.589578,0.671429,0.356057,0.581389,0.63108,0.426826,0.63108,0.891232,0.904055,0.786128,2.188778,1.61248,2
0,0.661898,0.703067,0.45333,0.64237,0.685771,0.486806,0.685771,0.892814,0.89739,0.810943,1.974043,1.403513,3
