In [1]:
import numpy as np
import pandas as pd
import yaml
import pickle
from metrics import *

import warnings


def warn(*args, **kwargs):
    """Silent replacement for ``warnings.warn``.

    Takes any positional/keyword arguments ``warnings.warn`` would receive and
    discards them, returning ``None``.
    """
    return None


# Rebind the module attribute (instead of installing a filter) so Python-level
# calls that look up ``warnings.warn`` at call time are silenced regardless of
# the active warning filters. Code that bound ``warn`` before this line
# (e.g. ``from warnings import warn``) keeps the original function.
warnings.warn = warn

In [2]:
def load_config(config_path):
    """Parse a YAML file with ``yaml.safe_load`` and return the result.

    Args:
        config_path: path to the YAML file.

    Returns:
        Whatever ``yaml.safe_load`` yields for the document, e.g. a dict for a
        top-level mapping or ``None`` for an empty file.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the host locale.
    with open(config_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)


# Presumably maps "level1".."level3" to ordered lists of location labels
# (that is how the evaluation cell indexes it) — verify against the YAML file.
CATEGORIES_YAML = load_config("/hai/scratch/zwefers/seq2loc/metadata/location_levels.yaml")

In [5]:
# Evaluate 5-fold models per embedding type (`model`) and label granularity
# (`level`): tune per-class MCC thresholds on each fold's validation split,
# apply them to the held-out Hou test set, ensemble the folds, and write
# per-class / averaged metric CSVs under `outdir/<model>/`.
METADATA_DIR = "/hai/scratch/zwefers/seq2loc/metadata"
trainset = pd.read_csv(f"{METADATA_DIR}/uniprot_trainset.csv")
testset = pd.read_csv(f"{METADATA_DIR}/hou_testset.csv")
outdir = "outputs/seq2locbench/"


def _multi_hot(label_strings, categories):
    """Encode ';'-separated label strings as a 0/1 indicator matrix.

    Args:
        label_strings: pandas Series of strings such as ``"A;B"``.
        categories: ordered label names; column ``j`` of the result is 1 when
            ``categories[j]`` is one of the row's ';'-separated labels.

    Returns:
        ``np.ndarray`` with one row per entry of ``label_strings``.
    """
    return np.array([
        [1 if loc in locs else 0 for loc in categories]
        for locs in label_strings.str.split(";").to_list()
    ])


def _load_merged_predictions(pred_path, metadata, level, categories):
    """Load one pickled prediction frame and attach its ground-truth labels.

    Rows are inner-joined on ``ACC == uniprot_id``, so proteins absent from
    ``metadata`` are dropped.

    Returns:
        ``(merged_df, probs, targets)``: the joined frame, the row-wise stack
        of its ``preds`` vectors, and the multi-hot encoding of its
        ``level{level}`` column over ``categories``.
    """
    # read_pickle can execute arbitrary code: only point it at trusted outputs.
    merged = pd.read_pickle(pred_path).merge(
        metadata, left_on="ACC", right_on="uniprot_id", how="inner"
    )
    probs = np.stack(merged.preds.to_numpy())
    targets = _multi_hot(merged[f"level{level}"], categories)
    return merged, probs, targets


for model in ["prott5", "esm1"]:
    val_avg_df = []
    test_metrics_avg_df = []
    for level in [1, 2, 3]:
        categories = CATEGORIES_YAML[f"level{level}"]
        val_perclass_df = []
        test_accs_all = []
        test_targets_all = []
        test_probs_all = []
        test_preds_all = []
        thresholds_all = []

        for fold in range(5):
            # Validation split: pick one MCC-optimal threshold per class.
            val_path = f"{outdir}/{model}/{fold}_1Layer_uniprot_level{level}_testout.pkl"
            val_df, val_probs, val_targets = _load_merged_predictions(
                val_path, trainset, level, categories)

            thresholds = np.array([
                get_best_threshold_mcc(val_targets[:, i], val_probs[:, i])
                for i in range(val_targets.shape[1])
            ])
            thresholds_all.append(thresholds)
            # TODO: save thresholds

            _, val_metrics_perclass, val_metrics_avg = all_metrics(
                val_targets, val_probs, thresholds=thresholds)
            val_metrics_perclass["label"] = categories
            val_metrics_perclass["fold"] = fold
            val_metrics_avg["level"] = level
            val_metrics_avg["fold"] = fold
            val_perclass_df.append(val_metrics_perclass)
            val_avg_df.append(val_metrics_avg)

            # Held-out test set: apply this fold's thresholds.
            test_path = f"{outdir}/{model}/{fold}_1Layer_uniprot_level{level}_hou_testout.pkl"
            test_df, test_probs, test_targets = _load_merged_predictions(
                test_path, testset, level, categories)

            # A class is predicted when its prob exceeds the fold threshold; the
            # top-scoring class (all of them, on ties) is always predicted, so
            # every protein gets at least one label from each fold.
            test_preds = test_probs > thresholds[np.newaxis, :]
            test_max = test_probs == test_probs.max(axis=1)[:, np.newaxis]
            test_preds = np.logical_or(test_preds, test_max)

            test_accs_all.append(test_df["ACC"].to_numpy())
            test_targets_all.append(test_targets)
            test_probs_all.append(test_probs)
            test_preds_all.append(test_preds)

        val_perclass_df = pd.concat(val_perclass_df)
        val_perclass_df.to_csv(f"{outdir}/{model}/val_metrics_perclass_level{level}.csv", index=False)

        # Fold outputs are averaged row-by-row below, so every fold must list the
        # same proteins in the same order. Comparing targets alone is not enough:
        # different proteins can share a label vector.
        for fold_accs in test_accs_all[1:]:
            assert np.array_equal(fold_accs, test_accs_all[0]), (
                f"{model} level{level}: test rows differ across folds")
        test_targets_all = np.stack(test_targets_all)
        assert np.all(test_targets_all[0, :, :] == test_targets_all)
        test_targets = test_targets_all[0, :, :]
        test_probs = np.stack(test_probs_all).mean(axis=0)
        # Majority vote: a class is predicted when more than half the folds predict it.
        # NOTE(review): unlike the per-fold predictions, the vote can leave a
        # protein with no predicted class.
        test_preds = (np.stack(test_preds_all).mean(axis=0) > 0.5).astype(np.int32)
        # The ensembled predictions come from every fold, so summarise the
        # thresholds across folds instead of reusing only the last fold's.
        thresholds = np.stack(thresholds_all).mean(axis=0)

        # Cut out empty categories in testset like int-fils and plastid
        idxs = np.where(test_targets.sum(axis=0) != 0)[0]
        test_targets = test_targets[:, idxs]
        test_probs = test_probs[:, idxs]
        test_preds = test_preds[:, idxs]
        thresholds = thresholds[idxs]

        _, test_metrics_perclass, test_metrics_avg = all_metrics(
            test_targets,
            test_probs,
            y_pred_bin=test_preds,
            thresholds=thresholds,
        )
        test_metrics_perclass["label"] = np.array(categories)[idxs]
        test_metrics_perclass.to_csv(
            f"{outdir}/{model}/test_metrics_perclass_level{level}.csv", index=False)

        test_metrics_avg["level"] = level
        test_metrics_avg_df.append(test_metrics_avg)

    val_avg_df = pd.concat(val_avg_df)
    val_avg_df.to_csv(f"{outdir}/{model}/val_metrics_avg.csv", index=False)
    test_metrics_avg_df = pd.concat(test_metrics_avg_df)
    test_metrics_avg_df.to_csv(f"{outdir}/{model}/test_metrics_avg.csv", index=False)

