Import libraries

In [1]:
import os 
import math
import json
import torch
import numpy as np 
from tqdm import tqdm 
from sklearn.metrics.pairwise import cosine_similarity
from utils import load_pickle, load_json, pickle_data, compute_MS_sim

Settings

In [None]:
# Cached similarity / nearest-neighbour results
baseline_folder = "./cache/baselines"
# Processed spectra and split definitions (absolute, machine-specific paths)
data_folder = "/data/rbg/users/klingmin/projects/MS_processing/data/"
splits_folder = "/data/rbg/users/klingmin/projects/MS_processing/data_splits/"

# Sub-directories of baseline_folder whose name contains "nist2023" (batched NIST results)
batched_data_folder = []
for entry_name in os.listdir(baseline_folder):
    entry_path = os.path.join(baseline_folder, entry_name)
    if "nist2023" in entry_name and os.path.isdir(entry_path):
        batched_data_folder.append(entry_path)

datasets = ["canopus", "massspecgym"]
splits = ["scaffold_vanilla", "inchikey_vanilla", "random", "LS"]

Helper functions

In [None]:
def string_to_bits(string):
    """Convert a fingerprint string such as "0110" into a 1-D numpy integer array.

    Every character is parsed with int(), one array element per character.
    """
    return np.array(list(map(int, string)))

def batch_jaccard_index(FP_pred, FP):
    """Row-wise Jaccard (Tanimoto) similarity between two batches of fingerprints.

    This is pure numpy; it does not need (and no longer uses) a torch no-grad context.

    Args:
        FP_pred: 2-D array-like, one fingerprint per row; nonzero entries count as set bits.
        FP: array-like with the same shape as FP_pred.

    Returns:
        np.ndarray with one score per row: |A AND B| / |A OR B|.
        Rows where both fingerprints are empty score 0, because the 1e-9 epsilon
        guards the division.
    """
    # Intersection = bitwise AND, union = bitwise OR, counted per row
    intersection = np.logical_and(FP, FP_pred).sum(axis=1)
    union = np.logical_or(FP, FP_pred).sum(axis=1)

    # Epsilon avoids division by zero when both rows are all-zero
    return intersection / (union + 1e-9)

def bin_MS(peaks, bin_resolution = 0.25, max_da = 2000):
    """Bin a peak list into a fixed-length intensity vector.

    Args:
        peaks: iterable of dicts with "mz" and "intensity" keys. It is consumed
            exactly once, so generators are supported.
        bin_resolution: bin width, in the same units as "mz".
        max_da: upper m/z bound; peaks at or beyond it are dropped.

    Returns:
        list of length ceil(max_da / bin_resolution). Entry k is the summed
        intensity of the peaks with floor(mz / bin_resolution) == k.
    """
    n_bins = int(math.ceil(max_da / bin_resolution))
    mz_binned = [0] * n_bins

    # Single pass over peaks. The previous two-pass extraction returned all
    # zeros when peaks was a generator.
    for p in peaks:
        bin_idx = math.floor(p["mz"] / bin_resolution)
        # Drop out-of-range peaks. A negative index would otherwise wrap around
        # and silently add intensity to the highest bins.
        if bin_idx < 0 or bin_idx >= n_bins:
            continue
        mz_binned[bin_idx] += p["intensity"]

    return mz_binned

Load in the data

In [None]:
def _load_records_by_id(name):
    """Load <data_folder>/<name>/<name>_w_mol_info_w_frag_CF_preds.pkl and index its records by str(id_)."""
    path = os.path.join(data_folder, name, f"{name}_w_mol_info_w_frag_CF_preds.pkl")
    return {str(record["id_"]): record for record in load_pickle(path)}


# dataset name -> {str(id_): record}
dataset_info = {}

canopus = _load_records_by_id("canopus")
print("Done loading canopus")

massspecgym = _load_records_by_id("massspecgym")
print("Done loading MSG")

dataset_info.update(canopus=canopus, massspecgym=massspecgym)

Merge the batched result files in each NIST2023 folder to get the nearest training neighbour for each test spectrum (NIST2023 only)

Nearest neighbour over the full training set

In [None]:
def _merge_nn_batches(batches):
    """Keep, for every test id, the highest-similarity training match across batches.

    Args:
        batches: iterable of (test_ids, train_ids, sims) triples. Each triple
            offers one candidate (train_ids[i], sims[i]) for test_ids[i].

    Returns:
        dict mapping test id -> {"train": train id, "sim": similarity}.
        Ties keep the candidate seen first. Test ids that appear only in later
        batches are still recorded. Returns {} when batches is empty.
    """
    best = {}
    for test_ids, train_ids, sims in batches:
        for test_id, train_id, sim in zip(test_ids, train_ids, sims):
            current = best.get(test_id)
            if current is None or current["sim"] < sim:
                best[test_id] = {"train": train_id, "sim": sim}
    return best


for folder in batched_data_folder:

    # "<name>_batched" directory -> "<name>.pkl" merged output; skip if already merged
    output_path = folder.replace("_batched", "") + ".pkl"
    if os.path.exists(output_path): continue

    # Sorted so that tie-breaking between batches is deterministic across runs
    files = sorted(os.path.join(folder, f) for f in os.listdir(folder))
    if not files:
        # Don't write an empty/None result; leave it to be retried on a later run
        print(f"Skipping: {folder} has no batch files")
        continue

    print(f"Processing: {folder} now")

    best_scores = _merge_nn_batches(load_pickle(f) for f in files)
    pickle_data(best_scores, output_path)

Compute the Jaccard scores

1. Nearest neighbour across the full training set, using MS similarity

a. CANOPUS and MassSpecGym (MSG)

In [None]:
k = 1  # only the top-1 neighbour (argmax) is evaluated below
compute = False

if compute:

    for dataset in datasets:

        for split in splits:

            # Read from the baseline cache, the same location the other sections use.
            # This previously used `folder`, a loop variable left over from the NIST
            # batching cell, which pointed at the wrong directory.
            similarity, test_ids, train_ids = load_pickle(os.path.join(baseline_folder, f"{dataset}_{split}.pkl"))

            # Row i scores test_ids[i] against every training spectrum; take the best column
            top_train_idx = np.argmax(similarity, axis = 1)

            # Jaccard score of each test fingerprint vs. its nearest training fingerprint
            all_jaccard = []

            for i, test in enumerate(test_ids):

                # dataset_info is keyed by str(id_)
                top_train = str(train_ids[top_train_idx[i]])
                test = str(test)

                test_FP = string_to_bits(dataset_info[dataset][test]["FPs"]["morgan4_4096"])
                train_FP = string_to_bits(dataset_info[dataset][top_train]["FPs"]["morgan4_4096"])

                # batch_jaccard_index compares row by row, so add a batch axis of size 1
                test_FP = np.expand_dims(test_FP, axis = 0)
                train_FP = np.expand_dims(train_FP, axis = 0)

                jaccard = batch_jaccard_index(train_FP, test_FP)
                all_jaccard.append(jaccard)

            print(dataset, split, np.mean(all_jaccard))

b. NIST2023

In [None]:
compute = False

if compute:

    # One pickle per NIST2023 spectrum id, each holding that spectrum's fingerprints
    nist_record_dir = os.path.join(data_folder, "nist2023", "frags_preds")

    for splitname in splits:

        # test id -> {"train": nearest training id, "sim": similarity}
        # (presumably the merged output of the batching cell)
        results = load_pickle(f"./cache/baselines/nist2023_{splitname}.pkl")

        all_jaccard = []

        for test_id, rec in tqdm(results.items()):

            train_record = load_pickle(os.path.join(nist_record_dir, f"{rec['train']}.pkl"))
            test_record = load_pickle(os.path.join(nist_record_dir, f"{test_id}.pkl"))

            # Leading batch axis: batch_jaccard_index works row by row
            train_FP = string_to_bits(train_record["FPs"]["morgan4_4096"])[np.newaxis, :]
            test_FP = string_to_bits(test_record["FPs"]["morgan4_4096"])[np.newaxis, :]

            all_jaccard.append(batch_jaccard_index(train_FP, test_FP))

        print(splitname, np.mean(all_jaccard))

# scaffold_vanilla 0.181914345591279
# inchikey_vanilla 0.22553987089030317
# random 0.653309319738518
# LS 0.17827926093869034

2. Nearest neighbour across the full training set, using DreaMS embedding similarity

a. CANOPUS and MassSpecGym (MSG)

In [None]:
k = 1  # only the top-1 neighbour (argmax) is evaluated below
compute = False

if compute:

    for dataset in datasets:

        # Only MassSpecGym is evaluated in this run
        if dataset != "massspecgym": continue

        for split in splits:

            similarity, test_ids, train_ids = load_pickle(os.path.join(baseline_folder, f"{dataset}_{split}_w_emb.pkl"))

            # Row i scores test_ids[i] against every training spectrum; take the best column
            top_train_idx = np.argmax(similarity, axis = 1)

            # Every argmax column must map to an entry of train_ids. Fail up front with
            # context instead of printing a debug line and then hitting an anonymous
            # IndexError in the middle of the loop.
            n_train = len(train_ids)
            if len(top_train_idx) and top_train_idx.max() >= n_train:
                raise IndexError(
                    f"{dataset} {split}: argmax column {top_train_idx.max()} is out of range "
                    f"for {n_train} train ids (similarity shape {np.shape(similarity)})"
                )

            # Jaccard score of each test fingerprint vs. its nearest training fingerprint
            all_jaccard = []

            for i, test in enumerate(test_ids):

                # dataset_info is keyed by str(id_)
                top_train = str(train_ids[top_train_idx[i]])
                test = str(test)

                test_FP = string_to_bits(dataset_info[dataset][test]["FPs"]["morgan4_4096"])
                train_FP = string_to_bits(dataset_info[dataset][top_train]["FPs"]["morgan4_4096"])

                # batch_jaccard_index compares row by row, so add a batch axis of size 1
                test_FP = np.expand_dims(test_FP, axis = 0)
                train_FP = np.expand_dims(train_FP, axis = 0)

                jaccard = batch_jaccard_index(train_FP, test_FP)
                all_jaccard.append(jaccard)

            print(dataset, split, np.mean(all_jaccard))

# canopus scaffold_vanilla 0.2612578164438585
# canopus inchikey_vanilla 0.3756794990803615
# canopus random 0.6902512724412937
# canopus LS 0.23034789850227105
# massspecgym scaffold_vanilla 0.296826992523911
# massspecgym inchikey_vanilla 0.3212731824194335
# massspecgym random 0.8947928743650205
# massspecgym LS 0.24703036830506597

b. NIST2023

In [None]:
compute = True

if compute:

    # One pickle per NIST2023 spectrum id, each holding that spectrum's fingerprints
    nist_record_dir = os.path.join(data_folder, "nist2023", "frags_preds")

    for splitname in splits:

        # test id -> {"train": nearest training id (DreaMS embedding similarity), ...}
        results = load_pickle(f"./cache/baselines/nist2023_{splitname}_w_emb.pkl")

        all_jaccard = []

        for test_id, rec in tqdm(results.items()):

            train_record = load_pickle(os.path.join(nist_record_dir, f"{rec['train']}.pkl"))
            test_record = load_pickle(os.path.join(nist_record_dir, f"{test_id}.pkl"))

            # Leading batch axis: batch_jaccard_index works row by row
            train_FP = string_to_bits(train_record["FPs"]["morgan4_4096"])[np.newaxis, :]
            test_FP = string_to_bits(test_record["FPs"]["morgan4_4096"])[np.newaxis, :]

            all_jaccard.append(batch_jaccard_index(train_FP, test_FP))

        print(splitname, np.mean(all_jaccard))

# scaffold_vanilla 0.10875046435233811
# inchikey_vanilla 0.12172197379638601
# random 0.7942182863047585
# LS 0.25098703847393744

3. Nearest neighbour with an identical chemical formula (CF), using MS similarity

a. CANOPUS and MassSpecGym (MSG)

In [None]:
k = 1  # only the top-1 matching neighbour is evaluated below
compute = True

if compute:

    for dataset in datasets:

        # str(id_) -> record for this dataset
        records = dataset_info[dataset]

        for split in splits:

            similarity, test_ids, train_ids = load_pickle(os.path.join(baseline_folder, f"{dataset}_{split}.pkl"))

            # Jaccard score of each test fingerprint vs. its nearest same-formula
            # training fingerprint
            all_jaccard = []

            for i, test in enumerate(test_ids):

                # dataset_info is keyed by str(id_), so normalise ids before lookup
                test = str(test)
                test_CF = records[test]["formula"]

                # Training ids from most to least similar (argsort is ascending, hence
                # [::-1]); stop at the first one whose formula matches, or get None.
                ranked_train = (str(train_ids[j]) for j in np.argsort(similarity[i, :])[::-1])
                top_train = next((t for t in ranked_train if records[t]["formula"] == test_CF), None)

                # No training spectrum shares this formula: the test is left out of the mean
                if top_train is None: continue

                test_FP = string_to_bits(records[test]["FPs"]["morgan4_4096"])
                train_FP = string_to_bits(records[top_train]["FPs"]["morgan4_4096"])

                # batch_jaccard_index compares row by row, so add a batch axis of size 1
                test_FP = np.expand_dims(test_FP, axis = 0)
                train_FP = np.expand_dims(train_FP, axis = 0)
                jaccard = batch_jaccard_index(train_FP, test_FP)

                all_jaccard.append(jaccard)

            # Mean is over matched tests only, so report how many that is
            print(dataset, split, np.mean(all_jaccard), f"(n={len(all_jaccard)}/{len(test_ids)})")

# canopus scaffold_vanilla 0.2848325413342964
# canopus inchikey_vanilla 0.38684986445373354
# canopus random 0.824085252277534
# canopus LS 0.34640551573294376
# massspecgym scaffold_vanilla 0.4549473041238532
# massspecgym inchikey_vanilla 0.333642300735451
# massspecgym random 0.95552811447368
# massspecgym LS 0.5200194596102301

b. NIST2023

In [None]:
compute = True

if compute:

    for splitname in splits:

        # test id -> {"train": <training id or "-">, "sim": ...}. A "-" entry has no
        # same-formula match (presumably) and is skipped below.
        results = load_pickle(f"./cache/baselines/nist2023_{splitname}_same_CF.pkl")

        # Jaccard score of each matched test fingerprint vs. its training neighbour
        all_jaccard = []

        for test_id, rec in tqdm(results.items()):

            top_train = rec["train"]
            if top_train == "-": continue

            train_FP = string_to_bits(load_pickle(os.path.join(data_folder, "nist2023", "frags_preds", f"{top_train}.pkl"))["FPs"]["morgan4_4096"])
            test_FP = string_to_bits(load_pickle(os.path.join(data_folder, "nist2023", "frags_preds", f"{test_id}.pkl"))["FPs"]["morgan4_4096"])

            # batch_jaccard_index compares row by row, so add a batch axis of size 1
            test_FP = np.expand_dims(test_FP, axis = 0)
            train_FP = np.expand_dims(train_FP, axis = 0)

            # The former per-item `print(jaccard, rec[@sim])` was a SyntaxError that
            # broke the whole cell, and it would have flooded the tqdm output; removed.
            jaccard = batch_jaccard_index(train_FP, test_FP)
            all_jaccard.append(jaccard)

        # Mean is over matched tests only, so report how many that is
        print(splitname, np.mean(all_jaccard), f"(n={len(all_jaccard)}/{len(results)})")

4. Nearest neighbour with identical experimental conditions (precursor type and instrument type), using MS similarity

a. CANOPUS and MassSpecGym (MSG)

In [None]:
k = 1  # only the top-1 matching neighbour is evaluated below
compute = False

if compute:

    for dataset in datasets:

        # Only MassSpecGym is evaluated in this run
        if dataset != "massspecgym": continue

        # str(id_) -> record for this dataset
        records = dataset_info[dataset]

        for split in splits:

            similarity, test_ids, train_ids = load_pickle(os.path.join(baseline_folder, f"{dataset}_{split}.pkl"))

            # Jaccard score of each test fingerprint vs. its nearest training fingerprint
            # recorded with the same precursor type and instrument type
            all_jaccard = []

            for i, test in enumerate(test_ids):

                # dataset_info is keyed by str(id_), so normalise ids before lookup
                test = str(test)
                test_adduct = records[test]["precursor_type"]
                test_instrument = records[test]["instrument_type"]

                # Training ids from most to least similar (argsort is ascending, hence
                # [::-1]); stop at the first one with matching conditions, or get None.
                ranked_train = (str(train_ids[j]) for j in np.argsort(similarity[i, :])[::-1])
                top_train = next(
                    (t for t in ranked_train
                     if records[t]["precursor_type"] == test_adduct
                     and records[t]["instrument_type"] == test_instrument),
                    None,
                )

                # No training spectrum with matching conditions: left out of the mean
                if top_train is None: continue

                test_FP = string_to_bits(records[test]["FPs"]["morgan4_4096"])
                train_FP = string_to_bits(records[top_train]["FPs"]["morgan4_4096"])

                # batch_jaccard_index compares row by row, so add a batch axis of size 1
                test_FP = np.expand_dims(test_FP, axis = 0)
                train_FP = np.expand_dims(train_FP, axis = 0)
                jaccard = batch_jaccard_index(train_FP, test_FP)

                all_jaccard.append(jaccard)

            # Mean is over matched tests only, so report how many that is
            print(dataset, split, np.mean(all_jaccard), f"(n={len(all_jaccard)}/{len(test_ids)})")

# canopus scaffold_vanilla 0.2057512538980379
# canopus inchikey_vanilla 0.2869633417787436
# canopus random 0.5315042784719447
# canopus LS 0.18068946559898919
# massspecgym scaffold_vanilla 0.22198074741488052
# massspecgym inchikey_vanilla 0.22648021955462153
# massspecgym random 0.7314962309652212
# massspecgym LS 0.1731793547288181

5. Nearest neighbour with an identical chemical formula (CF), using DreaMS embedding similarity

a. CANOPUS and MassSpecGym (MSG)

In [None]:
k = 1  # only the top-1 matching neighbour is evaluated below
compute = True

if compute:

    for dataset in datasets:

        # str(id_) -> record for this dataset
        records = dataset_info[dataset]

        for split in splits:

            similarity, test_ids, train_ids = load_pickle(os.path.join(baseline_folder, f"{dataset}_{split}_w_emb.pkl"))

            # Jaccard score of each test fingerprint vs. its nearest same-formula
            # training fingerprint (DreaMS embedding similarity)
            all_jaccard = []

            for i, test in enumerate(test_ids):

                # dataset_info is keyed by str(id_), so normalise ids before lookup
                test = str(test)
                test_CF = records[test]["formula"]

                # Training ids from most to least similar (argsort is ascending, hence
                # [::-1]); stop at the first one whose formula matches, or get None.
                ranked_train = (str(train_ids[j]) for j in np.argsort(similarity[i, :])[::-1])
                top_train = next((t for t in ranked_train if records[t]["formula"] == test_CF), None)

                # No training spectrum shares this formula: the test is left out of the mean
                if top_train is None: continue

                test_FP = string_to_bits(records[test]["FPs"]["morgan4_4096"])
                train_FP = string_to_bits(records[top_train]["FPs"]["morgan4_4096"])

                # batch_jaccard_index compares row by row, so add a batch axis of size 1
                test_FP = np.expand_dims(test_FP, axis = 0)
                train_FP = np.expand_dims(train_FP, axis = 0)
                jaccard = batch_jaccard_index(train_FP, test_FP)

                all_jaccard.append(jaccard)

            # Mean is over matched tests only, so report how many that is
            print(dataset, split, np.mean(all_jaccard), f"(n={len(all_jaccard)}/{len(test_ids)})")

# canopus scaffold_vanilla 0.29268316215631746
# canopus inchikey_vanilla 0.3866522313822027
# canopus random 0.8299599868742643
# canopus LS 0.35553839294013057
# massspecgym scaffold_vanilla 0.47139295762321953
# massspecgym inchikey_vanilla 0.35329460164624493
# massspecgym random 0.9656876401618149
# massspecgym LS 0.5272586113054561


b. NIST2023

In [None]:
compute = True

if compute:

    for splitname in splits:

        # test id -> {"train": <training id or "-">, ...}. A "-" entry has no
        # same-formula match (presumably) and is skipped below.
        results = load_pickle(f"./cache/baselines/nist2023_{splitname}_w_emb_same_CF.pkl")

        # Jaccard score of each matched test fingerprint vs. its training neighbour
        all_jaccard = []

        for test_id, rec in tqdm(results.items()):

            top_train = rec["train"]
            if top_train == "-": continue

            train_FP = string_to_bits(load_pickle(os.path.join(data_folder, "nist2023", "frags_preds", f"{top_train}.pkl"))["FPs"]["morgan4_4096"])
            test_FP = string_to_bits(load_pickle(os.path.join(data_folder, "nist2023", "frags_preds", f"{test_id}.pkl"))["FPs"]["morgan4_4096"])

            # batch_jaccard_index compares row by row, so add a batch axis of size 1
            test_FP = np.expand_dims(test_FP, axis = 0)
            train_FP = np.expand_dims(train_FP, axis = 0)

            jaccard = batch_jaccard_index(train_FP, test_FP)
            all_jaccard.append(jaccard)

        # Mean is over matched tests only, so report how many that is
        print(splitname, np.mean(all_jaccard), f"(n={len(all_jaccard)}/{len(results)})")

# scaffold_vanilla 0.09020181422765648
# inchikey_vanilla 0.09216237971678144