Import libraries

In [9]:
import os 
import json
import itertools
import numpy as np
import pandas as pd
from scipy import stats
from utils import load_pickle, load_json

import torch

Helper Functions

In [61]:
def get_correlation(results, method1, method2):
    """Correlate two methods' per-sample scores on the samples both methods scored.

    Args:
        results: mapping of sample id -> {method name: (loss, fractional_rank)}.
            Samples that lack either method are skipped.
        method1, method2: method (model-run) names to compare.

    Returns:
        (rank_correlation, linear_correlation, support):
        - rank_correlation: Spearman correlation of the fractional ranks.
        - linear_correlation: Pearson correlation of the losses.
        - support: the number of shared samples.
        Both correlations are rounded to 3 decimals. With fewer than two shared
        samples the correlations are undefined, so both are returned as NaN
        instead of letting scipy.stats.pearsonr raise ValueError.
    """
    # One pass: collect the (loss, rank) pairs of both methods for shared samples.
    shared = [
        (scores[method1], scores[method2])
        for scores in results.values()
        if method1 in scores and method2 in scores
    ]
    support = len(shared)

    if support < 2:
        nan = float("nan")
        return nan, nan, support

    loss_method1 = [first[0] for first, _ in shared]
    loss_method2 = [second[0] for _, second in shared]
    rank_method1 = [first[1] for first, _ in shared]
    rank_method2 = [second[1] for _, second in shared]

    rank_correlation = round(float(stats.spearmanr(rank_method1, rank_method2).statistic), 3)
    linear_correlation = round(float(stats.pearsonr(loss_method1, loss_method2).statistic), 3)

    return rank_correlation, linear_correlation, support

Get the test samples that appear in the test set of every split (CF splits excluded)

In [25]:
# For each dataset, intersect the "test" ids of every split file (skipping any
# split whose filename contains "CF"), giving the samples that are test samples
# under all of those splits.
splits_folder = "/data/rbg/users/klingmin/projects/MS_processing/data_splits"
test_ids = {}

for dataset in os.listdir(splits_folder):

    splits_dir = os.path.join(splits_folder, dataset, "splits")
    shared_ids = None  # stays None until the first non-CF split is read

    for split_file in os.listdir(splits_dir):

        if "CF" in split_file:
            continue

        current_ids = set(load_json(os.path.join(splits_dir, split_file))["test"])
        shared_ids = current_ids if shared_ids is None else shared_ids & current_ids

    # A dataset with only CF splits has no shared ids; previously this crashed on len(None).
    test_ids[dataset] = shared_ids if shared_ids is not None else set()
    print(dataset, len(test_ids[dataset]), list(test_ids[dataset])[:3])

nist2023 2926 ['296987.pkl', '7072.pkl', '343565.pkl']
nist2020 3052 ['1629975.pkl', '3172409.pkl', '3033903.pkl']
massspecgym 387 ['MassSpecGymID0027651.pkl', 'MassSpecGymID0208724.pkl', 'MassSpecGymID0226518.pkl']
canopus 65 ['8196.pkl', '1666.pkl', '11457.pkl']


Load each model run's per-sample test loss and its fractional loss rank

In [44]:
results_folder = "/data/rbg/users/klingmin/projects/ML_MS_analysis/FP_prediction/baseline_models/models_cached/morgan4_4096"
mist_results_folder = "/data/rbg/users/klingmin/projects/ML_MS_analysis/FP_prediction/mist/models_cached/morgan4_4096"

# dataset -> sample id -> model-run folder name -> (loss, fractional rank)
all_results = {}

for folder in [results_folder, mist_results_folder]:

    for dataset in os.listdir(folder):

        dataset_folder = os.path.join(folder, dataset)
        dataset_results = all_results.setdefault(dataset, {})

        for f in os.listdir(dataset_folder):

            test_filepath = os.path.join(dataset_folder, f, "test_results.pkl")
            # NOTE(review): load_pickle presumably unpickles the file; only point it at trusted files.
            test_results = load_pickle(test_filepath)

            # Fractional rank = position of the sample's loss in the losses sorted in
            # descending order, divided by the number of samples (0 = highest loss).
            # Tied losses share the position of the first member of the tie, matching
            # the previous list.index() lookup, but in one O(n log n) pass instead of
            # O(n^2). Assumes no NaN losses (sorting is ill-defined with NaN).
            order = sorted(test_results, key=lambda sid: test_results[sid]["loss"], reverse=True)
            n_samples = len(order)
            rank_of = {}
            run_start = 0
            for position, sid in enumerate(order):
                if position > 0 and not (test_results[sid]["loss"] == test_results[order[position - 1]]["loss"]):
                    run_start = position
                rank_of[sid] = run_start / n_samples

            # Add to all results, keeping the pickle's original sample order.
            for k, v in test_results.items():
                rank = rank_of[k]  # look up before k may be converted below
                if torch.is_tensor(k): k = k.item()
                dataset_results.setdefault(k, {})[f] = (v["loss"], rank)

Analyze MassSpecGym (MSG)

In [58]:
# MassSpecGym scores restricted to the sample ids in test_ids["massspecgym"]
# (ids stored there as "<id>.pkl").
msg_scores = all_results.get("massspecgym", {})
MSG_test_results_sieved = {
    sample_id: scores
    for sample_id, scores in msg_scores.items()
    if f"{sample_id}.pkl" in test_ids["massspecgym"]
}

In [78]:
# Losses were sorted with reverse=True (descending), so a smaller fractional rank
# means a higher loss, i.e. a more difficult example.
THRESHOLD = 0.2
kept_set = {}

# Model families are matched by substring of the run folder name (case-sensitive).
for model in ["binned", "mist", "MS"]:

    hard_ids = []

    for sample_id, runs in MSG_test_results_sieved.items():

        # Fractional ranks from every run of this model family that scored the sample.
        ranks = [score[1] for run_name, score in runs.items() if model in run_name]

        # Keep the sample only if at least one run scored it and every such run places
        # it within the hardest THRESHOLD fraction. Previously a sample with no
        # matching runs was kept vacuously.
        if ranks and all(rank <= THRESHOLD for rank in ranks):
            hard_ids.append(sample_id)

    kept_set[model] = hard_ids


In [1]:
# Show the hard-example ids kept for each model family. The NameError output below
# is stale: it came from executing this cell first (execution count 1), before the
# cell that builds kept_set. Run the notebook top to bottom to reproduce.
kept_set

NameError: name 'kept_set' is not defined

In [31]:
# Pairwise agreement between model runs on one dataset's test samples.
CORRELATION_DATASET = "canopus"
MIN_SUPPORT = 300  # skip pairs compared on fewer shared samples than this

dataset_results = all_results[CORRELATION_DATASET]

# NOTE(review): `all_methods` was previously defined in a cell that no longer exists,
# so this cell failed on Restart & Run All. It is reconstructed here as every run
# name seen in this dataset, in first-seen order; confirm it matches the intended list.
all_methods = list(dict.fromkeys(run for runs in dataset_results.values() for run in runs))

for m1, m2 in itertools.combinations(all_methods, 2):

    # Count shared samples first, so low-overlap pairs are skipped before correlating
    # (scipy.stats.pearsonr raises on fewer than two points).
    support = sum(1 for runs in dataset_results.values() if m1 in runs and m2 in runs)
    if support < MIN_SUPPORT: continue

    spearman, pearson, support = get_correlation(dataset_results, m1, m2)

    # Drop the first "_"-separated component of each run name for display.
    name1 = "_".join(m1.split("_")[1:])
    name2 = "_".join(m2.split("_")[1:])

    # Columns: run1 run2 pearson(loss) spearman(rank) n_shared
    print(name1, name2, pearson, spearman, support)


MS_4096_random formula_4096_inchikey_vanilla 0.521 0.521 371
MS_4096_random binned_4096_inchikey_vanilla 0.536 0.527 424
MS_4096_random MS_4096_inchikey_vanilla 0.473 0.496 424
MS_4096_random binned_4096_random 0.881 0.893 2744
MS_4096_random formula_4096_random 0.797 0.807 394
MS_4096_random binned_4096_scaffold_vanilla 0.452 0.447 415
MS_4096_random formula_4096_scaffold_vanilla 0.47 0.394 365
MS_4096_random MS_4096_scaffold_vanilla 0.364 0.36 415
formula_4096_inchikey_vanilla binned_4096_inchikey_vanilla 0.828 0.833 329
formula_4096_inchikey_vanilla MS_4096_inchikey_vanilla 0.853 0.848 329
formula_4096_inchikey_vanilla binned_4096_random 0.467 0.47 371
formula_4096_inchikey_vanilla formula_4096_random 0.45 0.583 360
formula_4096_inchikey_vanilla binned_4096_scaffold_vanilla 0.633 0.65 421
formula_4096_inchikey_vanilla MS_4096_scaffold_vanilla 0.734 0.71 421
binned_4096_inchikey_vanilla MS_4096_inchikey_vanilla 0.858 0.884 2690
binned_4096_inchikey_vanilla binned_4096_random 0.475 0.