Import libraries

In [1]:
import os 
from bisect import bisect_left
from pprint import pprint

import numpy as np
import pandas as pd
import torch
from scipy import stats

from utils import load_pickle, load_json, pickle_data

Settings

In [16]:
# Folder holding pickled intermediate results so expensive cells can be skipped on re-run.
cache_folder = "./cache"
# exist_ok=True keeps the cell idempotent and avoids the check-then-create race.
os.makedirs(cache_folder, exist_ok=True)

# Rank-fraction cut-off: a sample is flagged as difficult for a run when its
# rank fraction (position in the ascending loss ordering / number of samples)
# is strictly greater than this value.
THRESHOLD = 0.9

Helper Functions

In [17]:
def get_correlation(results, method1, method2):
    """Correlate the per-sample scores of two methods.

    Parameters
    ----------
    results : dict
        Maps a sample id to a dict keyed by method name. Each method entry is
        indexed as ``entry[0]`` (loss) and ``entry[1]`` (rank).
    method1, method2 : hashable
        Method keys to compare. Only samples containing both keys are used.

    Returns
    -------
    tuple
        ``(rank_correlation, linear_correlation, n_shared)`` where
        ``rank_correlation`` is the Spearman correlation of the ranks,
        ``linear_correlation`` is the Pearson correlation of the losses (both
        rounded to 3 decimals) and ``n_shared`` is the number of samples that
        have both methods. When fewer than two samples are shared the
        correlations are undefined and are returned as ``nan``.
    """
    shared = [
        (v[method1], v[method2])
        for v in results.values()
        if method1 in v and method2 in v
    ]
    n_shared = len(shared)

    # pearsonr needs at least two observations; report "undefined" instead of raising.
    if n_shared < 2:
        nan = float("nan")
        return nan, nan, n_shared

    loss_method1 = [first[0] for first, _ in shared]
    loss_method2 = [second[0] for _, second in shared]
    rank_method1 = [first[1] for first, _ in shared]
    rank_method2 = [second[1] for _, second in shared]

    rank_correlation = round(float(stats.spearmanr(rank_method1, rank_method2).statistic), 3)
    linear_correlation = round(float(stats.pearsonr(loss_method1, loss_method2).statistic), 3)

    return rank_correlation, linear_correlation, n_shared

Get samples that appear in the test set of every split

In [37]:
repeated_test_ids_path = os.path.join(cache_folder, "repeated_test_ids.pkl")

# Build the cache only when it is missing; otherwise the pickled result is reused.
if not os.path.exists(repeated_test_ids_path):

    splits_folder = "/data/rbg/users/klingmin/projects/MS_processing/data_splits"
    repeated_test_ids = {}

    for dataset in os.listdir(splits_folder):

        dataset_splits_folder = os.path.join(splits_folder, dataset, "splits")
        all_ids = None  # None = no split seen yet for this dataset

        for split_file in os.listdir(dataset_splits_folder):

            # Split files with "CF" in their name are left out of the intersection.
            if "CF" in split_file: continue

            current_ids = load_json(os.path.join(dataset_splits_folder, split_file))["test"]

            if all_ids is None:
                # First split seeds the running intersection.
                all_ids = set(current_ids)
            else:
                all_ids.intersection_update(current_ids)

        # A dataset without any usable split gets an empty set rather than None,
        # so that len() and membership tests on the cached value keep working.
        repeated_test_ids[dataset] = all_ids if all_ids is not None else set()

    pickle_data(repeated_test_ids, repeated_test_ids_path)

# Always read back from the cache so fresh and cached runs use the same object.
repeated_test_ids = load_pickle(repeated_test_ids_path)

for dataset, rec in repeated_test_ids.items():
    print(dataset, len(rec))

nist2023 2926
nist2020 3052
massspecgym 387
canopus 65


Load in the scores

In [20]:
results_folder = "/data/rbg/users/klingmin/projects/ML_MS_analysis/FP_prediction/baseline_models/models_cached/morgan4_4096"
mist_results_folder = "/data/rbg/users/klingmin/projects/ML_MS_analysis/FP_prediction/mist/models_cached/morgan4_4096"
merged_results_path = os.path.join(cache_folder, "merged_test_results.pkl")

# merged_results[dataset][sample_id][run_name] = (loss, rank_fraction), where
# rank_fraction is the sample's position in the ascending loss ordering of that
# run divided by the number of test samples in the run.
if not os.path.exists(merged_results_path):

    merged_results = {} 

    for folder in [results_folder, mist_results_folder]:

        for dataset in os.listdir(folder):

            dataset_folder = os.path.join(folder, dataset)
            dataset_results = merged_results.setdefault(dataset, {})

            for f in os.listdir(dataset_folder):

                test_filepath = os.path.join(dataset_folder, f, "test_results.pkl")
                test_results = load_pickle(test_filepath)
                test_scores = sorted(v["loss"] for v in test_results.values()) # Ascending order
                n_scores = len(test_scores)

                # Add to all results
                for k, v in test_results.items():
                    if torch.is_tensor(k): k = k.item()

                    # On a sorted list bisect_left returns the index of the first
                    # occurrence, i.e. the same value as list.index, without a
                    # linear scan per sample.
                    rank_fraction = bisect_left(test_scores, v["loss"]) / n_scores
                    dataset_results.setdefault(k, {})[f] = (v["loss"], rank_fraction)
    
    pickle_data(merged_results, merged_results_path)

else:

    merged_results = load_pickle(merged_results_path)

Get test samples that are consistently difficult (high loss rank)

In [21]:
# Restrict every dataset's results to the samples listed in repeated_test_ids.
merged_results_sieved = {}

for dataset, results in merged_results.items():

    kept = merged_results_sieved[dataset] = {}
    selected_ids = repeated_test_ids[dataset]

    # Entries of repeated_test_ids are matched against "<id>.pkl".
    kept.update(
        (id_, scores)
        for id_, scores in results.items()
        if f"{id_}.pkl" in selected_ids
    )

In [27]:
# Scores were sorted with reverse = False (ascending order), so the bigger the
# rank fraction, the bigger the loss, i.e. the more difficult the example.
datasets = ["massspecgym", "canopus", "nist2023"]
models = ["binned_4096", "mist_4096", "MS_4096", "mist"]
kept_set = {d : {} for d in datasets} 

for dataset in datasets:

    difficult_by_model = {}

    for model in models:

        # A sample is flagged for `model` when at least one run whose name
        # contains the model string ranks it above THRESHOLD.
        # NOTE(review): this is a substring match, so "mist" also matches
        # "mist_4096" runs -- confirm that this is intended.
        difficult_by_model[model] = {
            str(k)
            for k, s_pair in merged_results_sieved[dataset].items()
            if any(p[1] > THRESHOLD for run_name, p in s_pair.items() if model in run_name)
        }

        # Sorted so the stored list has a reproducible order.
        kept_set[dataset][model] = sorted(difficult_by_model[model])

    # Get difficult test samples across all models: explicit set intersection
    # instead of relying on a variable left over from the last model iteration.
    if difficult_by_model:
        difficult_everywhere = set.intersection(*difficult_by_model.values())
    else:
        difficult_everywhere = set()

    kept_set[dataset]["all"] = sorted(difficult_everywhere)


In [32]:
# Per dataset and model: share (in %) of the repeated test samples that were
# flagged as difficult, followed by the number of repeated test samples.
for dataset, per_model in kept_set.items():

    for model, flagged_ids in per_model.items():

        share = len(flagged_ids) / len(repeated_test_ids[dataset]) * 100
        n_repeated = len(repeated_test_ids[dataset])
        print(dataset, model, round(share, 1),  n_repeated)

    print()

massspecgym binned_4096 25.1 387
massspecgym mist_4096 17.3 387
massspecgym MS_4096 26.1 387
massspecgym mist 17.3 387
massspecgym all 9.0 387

canopus binned_4096 10.8 65
canopus mist_4096 7.7 65
canopus MS_4096 21.5 65
canopus mist 7.7 65
canopus all 0.0 65

nist2023 binned_4096 0.0 2926
nist2023 mist_4096 5.1 2926
nist2023 MS_4096 0.0 2926
nist2023 mist 5.1 2926
nist2023 all 0.0 2926

