Import libraries

In [2]:
import os 
import copy
import collections
import numpy as np 
import pandas as pd
from tqdm import tqdm
from pathlib import Path
from pprint import pprint
import matplotlib.pyplot as plt 

import torch

import rdkit.Chem as Chem
from rdkit import DataStructs
from rdkit.Chem import AllChem

from utils import load_pickle, pickle_data

Rough idea: 

    - In this part of the analysis, we look at the worst mistakes made by all the models and try to understand how data can potentially harm predictions 
    - We start with the hypothesis that having good coverage is helpful
        - To this effect, we look at how IDENTICAL molecules can affect the performance

Settings

In [3]:
TOP_K = 1000
CACHE_FOLDER = "./cache"
if not os.path.exists(CACHE_FOLDER): os.makedirs(CACHE_FOLDER)

baseline_models_folder = "../FP_prediction/baseline_models/best_models"
mist_folder = "../FP_prediction/mist/best_models"

DATA_FOLDER = Path("/data/rbg/users/klingmin/projects/MS_processing/data")
RESULTS_FOLDER = []

for folder in [mist_folder, baseline_models_folder]:
    for dataset in os.listdir(folder):
        for checkpoint in os.listdir(os.path.join(folder, dataset)):
            RESULTS_FOLDER.append(os.path.join(folder, dataset, checkpoint))

DIFF_EXPT_CONDITIONS = ["diff_adduct", "diff_instrument", "diff_CE",
                        "diff_adduct_instrument", "diff_adduct_CE",
                        "diff_instrument_CE", "diff_adduct_instrument_CE"]

Helper Functions

In [7]:
SELECTED_ID = 13
CHECKPOINT = RESULTS_FOLDER[SELECTED_ID]
DATASET = CHECKPOINT.split("/")[-2]
CURRENT_DATA_FOLDER = DATA_FOLDER / DATASET / "frags_preds"
CHECKPOINT_NAME = CHECKPOINT.split("/")[-1]
print(f"Analyzing: {CHECKPOINT_NAME} now")

test_ids = load_pickle(os.path.join(CHECKPOINT, "test_ids.pkl"))
# pairwise_scores = load_pickle(os.path.join(CHECKPOINT, "EK-FAC_scores_formatted.pkl"))

# Get the 3 different categories of test samples 
bad_mistakes = test_ids[:TOP_K]
good_predictions = test_ids[TOP_K: TOP_K * 2]
random_tests = test_ids[TOP_K * 2:]

Analyzing: MSG_formula_4096_scaffold_vanilla now


`` Task 1: Understand what harms predictions ``

1. Look at the worst mistakes and we aim to answer the following questions: 

    - What are actually harmful
    - Why are they harmful

``Hypothesis 1: molecule similarity``

Maybe the more different the molecule, the more harmful it is?

In [25]:
for test_id in test_ids:
    
    test_id = test_id.split("/")[-1]
    test_record = load_pickle(CURRENT_DATA_FOLDER / test_id.replace(".ms", ".pkl"))

    # Get the scores 
    train_records = pairwise_scores[test_id.replace(".pkl", "")]
    train_records = sorted(train_records.items(), key = lambda x: x[1])

    # Look at the most helpful 
    top_k_most_helpful = [load_pickle(CURRENT_DATA_FOLDER / f"{name}.pkl") for name, _ in train_records[-TOP_K:]]
    top_k_most_harmful = [load_pickle(CURRENT_DATA_FOLDER / f"{name}.pkl") for name, _ in train_records[:TOP_K]]

    print(top_k_most_helpful[0])
    a = z 

    # # Look at the most harmful 
    # top_k_most_harmful = sorted(harmful_train, key = lambda x: x[1])[:TOP_K]
    # 

    # Look at things at the feature level 

    avg_count_helpful, avg_count_harmful = [], []
    for i in range(TOP_K):
        avg_count_helpful.append(sum([int(c) for c in top_k_most_helpful[i]["FPs"]["morgan4_4096"]]))
        avg_count_harmful.append(sum([int(c) for c in top_k_most_harmful[i]["FPs"]["morgan4_4096"]]))


    print(np.mean(avg_count_helpful))
    print(np.mean(avg_count_harmful))
    a = z

{'id_': 'MassSpecGymID0230911', 'precursor_type': '[M+H]+', 'precursor_MZ': 564.26798, 'precursor_MZ_final': 564.26798, 'spectrum_type': 'MS2', 'instrument_type': 'Orbitrap', 'collision_energy': 60.0, 'NCE': None, 'inchikey_original': 'BZXMLCVDKDXRQY', 'inchikey': 'BZXMLCVDKDXRQY-UHFFFAOYSA-N', 'smiles': 'CC(C)(C)C1CCC(CC1)N(CC2=CC=C(C=C2)C(=O)NCCC(=O)O)C(=O)NC3=CC=C(C=C3)OC(F)(F)F', 'formula': 'C29H36F3N3O5', 'exact_mass': 563.260704, 'peaks': [{'mz': 41.038734, 'intensity': 1.106, 'intensity_norm': 1.106, 'comment': {'f': '', 'l': '', 'f_pred': 'C3H5'}}, {'mz': 53.038605, 'intensity': 0.8009999999999999, 'intensity_norm': 0.8009999999999999, 'comment': {'f': '', 'l': '', 'f_pred': 'C4H5'}}, {'mz': 55.054268, 'intensity': 30.836999999999996, 'intensity_norm': 30.836999999999996, 'comment': {'f': '', 'l': '', 'f_pred': 'C4H7'}}, {'mz': 57.069897, 'intensity': 62.135, 'intensity_norm': 62.135, 'comment': {'f': '', 'l': '', 'f_pred': 'C4H9'}}, {'mz': 57.072849, 'intensity': 0.81599999999

NameError: name 'z' is not defined

In [None]:
def get_diff_expt_param_no_energy(test_info, train_dict):

    if len(train_dict) == 0 : return {}

    test_adduct = test_info["precursor_type"]
    test_instrument = test_info["instrument_type"]

    diff_adduct = 0
    diff_instrument = 0 
    diff_adduct_instrument = 0 
    reason_diff = {}

    for _, rec in train_dict.items():

        train_info = rec["train_info"]
        train_adduct = train_info["precursor_type"]
        train_instrument = train_info["instrument_type"]

        check_diff_instrument = test_instrument != train_instrument
        check_diff_adduct = test_adduct != train_adduct

        if check_diff_adduct and check_diff_instrument: diff_adduct_instrument += 1 
        elif check_diff_instrument: diff_instrument += 1 
        elif check_diff_adduct: diff_adduct += 1 

    reason_diff["diff_adduct"] = diff_adduct
    reason_diff["diff_instrument"] = diff_instrument
    reason_diff["diff_adduct_instrument"] = diff_adduct_instrument
    
    return reason_diff

In [None]:
def analyze_influence_records(test_info, train_recs, include_energy = True):

    # Get number of times this training sample is helpful 
    helpful = {k:v for k,v in train_recs.items() if v["score"] > 0}
    harmful = {k:v for k,v in train_recs.items() if v["score"] < 0}

    # Separate out to experimental settings 
    helpful_same_expt = {k:v for k,v in helpful.items() if same_expt(test_info, v, include_energy = True)}
    helpful_diff_expt = {k:v for k,v in helpful.items() if not same_expt(test_info, v, include_energy = True)}

    harmful_same_expt = {k:v for k,v in harmful.items() if same_expt(test_info, v, include_energy = True)}
    harmful_diff_expt = {k:v for k,v in harmful.items() if not same_expt(test_info, v, include_energy = True)}

    # Get the reasons for different experiments
    if include_energy:
        reason_helpful_diff_expt = get_diff_expt_param_w_energy(test_info, helpful_diff_expt)
        reason_harmful_diff_expt = get_diff_expt_param_w_energy(test_info, harmful_diff_expt)

    else: 
        reason_helpful_diff_expt = get_diff_expt_param_no_energy(test_info, helpful_diff_expt)
        reason_harmful_diff_expt = get_diff_expt_param_no_energy(test_info, harmful_diff_expt)

    # Get the distance between the helpful and harmful molecules 
    helpful_dist = compute_dist_mols(test_info, helpful)
    harmful_dist = compute_dist_mols(test_info, harmful)

    # Merge the results 
    results = {} 
    results["helpful"] = {"count": len(helpful)}
    results["harmful"] = {"count": len(harmful)}

    results["helpful"]["same_expt_count"] = len(helpful_same_expt) 
    results["helpful"]["diff_expt_count"] = len(helpful_diff_expt) 
    results["helpful"]["diff_expt_reason"] = reason_helpful_diff_expt
    results["helpful"]["dist"] = helpful_dist

    results["harmful"]["same_expt_count"] = len(harmful_same_expt) 
    results["harmful"]["diff_expt_count"] = len(harmful_diff_expt)
    results["harmful"]["diff_expt_reason"] = reason_harmful_diff_expt
    results["harmful"]["dist"] = harmful_dist

    results["same_expt"] = {"helpful": len(helpful_same_expt), 
                            "harmful": len(harmful_same_expt),
                            "count": len(helpful_same_expt) + len(harmful_same_expt)}

    results["diff_expt"] = {"helpful": len(helpful_diff_expt), 
                            "harmful": len(harmful_diff_expt),
                            "count": len(helpful_diff_expt) + len(harmful_diff_expt)}
    
    return results

In [None]:
def merge_reasons(master, sub):

    if sub == {}: return master

    for k, _ in master.items():
        if sub[k] is None: continue 
        master[k] += sub[k]

    return master 

def merge_results_wo_energy(results):

    merged = {}
    merged["diff_mol"] = {"same_expt" : {"helpful": 0,
                                         "harmful": 0,
                                         "count": 0},
                          "diff_expt" : {"helpful": 0,
                                         "harmful": 0,
                                         "count": 0},
                                         
                          "harmful" : {"diff_expt_count": 0,
                                       "same_expt_count": 0,
                                       "count": 0,
                                       "dist": 0,
                                       "diff_expt_reason": {"diff_adduct": 0,
                                                           "diff_adduct_instrument": 0,
                                                           "diff_instrument": 0}},
                          "helpful" : {"diff_expt_count": 0,
                                       "same_expt_count": 0,
                                       "count": 0,
                                       "dist": 0,
                                       "diff_expt_reason": {"diff_adduct": 0,
                                                           "diff_adduct_instrument": 0,
                                                           "diff_instrument": 0}}}
    
    merged["same_mol"] = {"same_expt" : {"helpful": 0,
                                         "harmful": 0,
                                         "count": 0},
                          "diff_expt" : {"helpful": 0,
                                         "harmful": 0,
                                         "count": 0},
                          "harmful" : {"diff_expt_count": 0,
                                       "same_expt_count": 0,
                                       "count": 0,
                                       "diff_expt_reason": {"diff_adduct": 0,
                                                           "diff_adduct_instrument": 0,
                                                           "diff_instrument": 0}},
                          "helpful" : {"diff_expt_count": 0,
                                       "same_expt_count": 0,
                                       "count": 0,
                                       "diff_expt_reason": {"diff_adduct": 0,
                                                           "diff_adduct_instrument": 0,
                                                           "diff_instrument": 0}}}

    for _, r in results.items():

        # For different mols
        merged["diff_mol"]["same_expt"]["helpful"] += r["diff_mol"]["same_expt"]["helpful"] 
        merged["diff_mol"]["same_expt"]["harmful"] += r["diff_mol"]["same_expt"]["harmful"] 
        merged["diff_mol"]["same_expt"]["count"] += r["diff_mol"]["same_expt"]["count"] 

        merged["diff_mol"]["diff_expt"]["helpful"] += r["diff_mol"]["diff_expt"]["helpful"] 
        merged["diff_mol"]["diff_expt"]["harmful"] += r["diff_mol"]["diff_expt"]["harmful"] 
        merged["diff_mol"]["diff_expt"]["count"] += r["diff_mol"]["diff_expt"]["count"]   

        merged["diff_mol"]["harmful"]["diff_expt_count"] += r["diff_mol"]["harmful"]["diff_expt_count"] 
        merged["diff_mol"]["harmful"]["same_expt_count"] += r["diff_mol"]["harmful"]["same_expt_count"] 
        merged["diff_mol"]["harmful"]["count"] += r["diff_mol"]["harmful"]["count"]   
        merged["diff_mol"]["harmful"]["dist"] += r["diff_mol"]["harmful"]["dist"]   
        merged["diff_mol"]["harmful"]["diff_expt_reason"] = merge_reasons(merged["diff_mol"]["harmful"]["diff_expt_reason"], r["diff_mol"]["harmful"]["diff_expt_reason"]) 

        merged["diff_mol"]["helpful"]["diff_expt_count"] += r["diff_mol"]["helpful"]["diff_expt_count"] 
        merged["diff_mol"]["helpful"]["same_expt_count"] += r["diff_mol"]["helpful"]["same_expt_count"] 
        merged["diff_mol"]["helpful"]["count"] += r["diff_mol"]["helpful"]["count"]   
        merged["diff_mol"]["helpful"]["dist"] += r["diff_mol"]["helpful"]["dist"]   
        merged["diff_mol"]["helpful"]["diff_expt_reason"] = merge_reasons(merged["diff_mol"]["helpful"]["diff_expt_reason"], r["diff_mol"]["helpful"]["diff_expt_reason"]) 

        # For same mols
        merged["same_mol"]["same_expt"]["helpful"] += r["same_mol"]["same_expt"]["helpful"] 
        merged["same_mol"]["same_expt"]["harmful"] += r["same_mol"]["same_expt"]["harmful"] 
        merged["same_mol"]["same_expt"]["count"] += r["same_mol"]["same_expt"]["count"] 

        merged["same_mol"]["diff_expt"]["helpful"] += r["same_mol"]["diff_expt"]["helpful"] 
        merged["same_mol"]["diff_expt"]["harmful"] += r["same_mol"]["diff_expt"]["harmful"] 
        merged["same_mol"]["diff_expt"]["count"] += r["same_mol"]["diff_expt"]["count"] 

        merged["same_mol"]["harmful"]["diff_expt_count"] += r["same_mol"]["harmful"]["diff_expt_count"] 
        merged["same_mol"]["harmful"]["same_expt_count"] += r["same_mol"]["harmful"]["same_expt_count"] 
        merged["same_mol"]["harmful"]["count"] += r["same_mol"]["harmful"]["count"]   
        merged["same_mol"]["harmful"]["diff_expt_reason"] = merge_reasons(merged["same_mol"]["harmful"]["diff_expt_reason"], r["same_mol"]["harmful"]["diff_expt_reason"]) 

        merged["same_mol"]["helpful"]["diff_expt_count"] += r["same_mol"]["helpful"]["diff_expt_count"] 
        merged["same_mol"]["helpful"]["same_expt_count"] += r["same_mol"]["helpful"]["same_expt_count"] 
        merged["same_mol"]["helpful"]["count"] += r["same_mol"]["helpful"]["count"]   
        merged["same_mol"]["helpful"]["diff_expt_reason"] = merge_reasons(merged["same_mol"]["helpful"]["diff_expt_reason"], r["same_mol"]["helpful"]["diff_expt_reason"]) 

    return merged

def merge_results_w_energy(results):

    merged = {}
    merged["diff_mol"] = {"same_expt" : {"helpful": 0,
                                         "harmful": 0,
                                         "count": 0},
                          "diff_expt" : {"helpful": 0,
                                         "harmful": 0,
                                         "count": 0},
                                         
                          "harmful" : {"diff_expt_count": 0,
                                       "same_expt_count": 0,
                                       "count": 0,
                                       "dist": 0,
                                       "diff_expt_reason": {"total_diff_CE": 0,
                                                           "diff_CE": 0,
                                                           "diff_adduct": 0,
                                                           "diff_adduct_CE": 0,
                                                           "diff_adduct_instrument": 0,
                                                           "diff_adduct_instrument_CE": 0,
                                                           "diff_instrument": 0,
                                                           "diff_instrument_CE": 0}},
                          "helpful" : {"diff_expt_count": 0,
                                       "same_expt_count": 0,
                                       "count": 0,
                                       "dist": 0,
                                       "diff_expt_reason": {"total_diff_CE": 0,
                                                           "diff_CE": 0,
                                                           "diff_adduct": 0,
                                                           "diff_adduct_CE": 0,
                                                           "diff_adduct_instrument": 0,
                                                           "diff_adduct_instrument_CE": 0,
                                                           "diff_instrument": 0,
                                                           "diff_instrument_CE": 0}}}
    
    merged["same_mol"] = {"same_expt" : {"helpful": 0,
                                         "harmful": 0,
                                         "count": 0},
                          "diff_expt" : {"helpful": 0,
                                         "harmful": 0,
                                         "count": 0},
                          "harmful" : {"diff_expt_count": 0,
                                       "same_expt_count": 0,
                                       "count": 0,
                                       "diff_expt_reason": {"total_diff_CE":0,
                                                           "diff_CE": 0,
                                                           "diff_adduct": 0,
                                                           "diff_adduct_CE": 0,
                                                           "diff_adduct_instrument": 0,
                                                           "diff_adduct_instrument_CE": 0,
                                                           "diff_instrument": 0,
                                                           "diff_instrument_CE": 0}},
                          "helpful" : {"diff_expt_count": 0,
                                       "same_expt_count": 0,
                                       "count": 0,
                                       "diff_expt_reason": {"total_diff_CE": 0,
                                                           "diff_CE": 0,
                                                           "diff_adduct": 0,
                                                           "diff_adduct_CE": 0,
                                                           "diff_adduct_instrument": 0,
                                                           "diff_adduct_instrument_CE": 0,
                                                           "diff_instrument": 0,
                                                           "diff_instrument_CE": 0}}}

    for _, r in results.items():

        # For different mols
        merged["diff_mol"]["same_expt"]["helpful"] += r["diff_mol"]["same_expt"]["helpful"] 
        merged["diff_mol"]["same_expt"]["harmful"] += r["diff_mol"]["same_expt"]["harmful"] 
        merged["diff_mol"]["same_expt"]["count"] += r["diff_mol"]["same_expt"]["count"] 

        merged["diff_mol"]["diff_expt"]["helpful"] += r["diff_mol"]["diff_expt"]["helpful"] 
        merged["diff_mol"]["diff_expt"]["harmful"] += r["diff_mol"]["diff_expt"]["harmful"] 
        merged["diff_mol"]["diff_expt"]["count"] += r["diff_mol"]["diff_expt"]["count"]   

        merged["diff_mol"]["harmful"]["diff_expt_count"] += r["diff_mol"]["harmful"]["diff_expt_count"] 
        merged["diff_mol"]["harmful"]["same_expt_count"] += r["diff_mol"]["harmful"]["same_expt_count"] 
        merged["diff_mol"]["harmful"]["count"] += r["diff_mol"]["harmful"]["count"]   
        merged["diff_mol"]["harmful"]["dist"] += r["diff_mol"]["harmful"]["dist"]   
        merged["diff_mol"]["harmful"]["diff_expt_reason"] = merge_reasons(merged["diff_mol"]["harmful"]["diff_expt_reason"], r["diff_mol"]["harmful"]["diff_expt_reason"]) 

        merged["diff_mol"]["helpful"]["diff_expt_count"] += r["diff_mol"]["helpful"]["diff_expt_count"] 
        merged["diff_mol"]["helpful"]["same_expt_count"] += r["diff_mol"]["helpful"]["same_expt_count"] 
        merged["diff_mol"]["helpful"]["count"] += r["diff_mol"]["helpful"]["count"]   
        merged["diff_mol"]["helpful"]["dist"] += r["diff_mol"]["helpful"]["dist"]   
        merged["diff_mol"]["helpful"]["diff_expt_reason"] = merge_reasons(merged["diff_mol"]["helpful"]["diff_expt_reason"], r["diff_mol"]["helpful"]["diff_expt_reason"]) 

        # For same mols
        merged["same_mol"]["same_expt"]["helpful"] += r["same_mol"]["same_expt"]["helpful"] 
        merged["same_mol"]["same_expt"]["harmful"] += r["same_mol"]["same_expt"]["harmful"] 
        merged["same_mol"]["same_expt"]["count"] += r["same_mol"]["same_expt"]["count"] 

        merged["same_mol"]["diff_expt"]["helpful"] += r["same_mol"]["diff_expt"]["helpful"] 
        merged["same_mol"]["diff_expt"]["harmful"] += r["same_mol"]["diff_expt"]["harmful"] 
        merged["same_mol"]["diff_expt"]["count"] += r["same_mol"]["diff_expt"]["count"] 

        merged["same_mol"]["harmful"]["diff_expt_count"] += r["same_mol"]["harmful"]["diff_expt_count"] 
        merged["same_mol"]["harmful"]["same_expt_count"] += r["same_mol"]["harmful"]["same_expt_count"] 
        merged["same_mol"]["harmful"]["count"] += r["same_mol"]["harmful"]["count"]   
        merged["same_mol"]["harmful"]["diff_expt_reason"] = merge_reasons(merged["same_mol"]["harmful"]["diff_expt_reason"], r["same_mol"]["harmful"]["diff_expt_reason"]) 

        merged["same_mol"]["helpful"]["diff_expt_count"] += r["same_mol"]["helpful"]["diff_expt_count"] 
        merged["same_mol"]["helpful"]["same_expt_count"] += r["same_mol"]["helpful"]["same_expt_count"] 
        merged["same_mol"]["helpful"]["count"] += r["same_mol"]["helpful"]["count"]   
        merged["same_mol"]["helpful"]["diff_expt_reason"] = merge_reasons(merged["same_mol"]["helpful"]["diff_expt_reason"], r["same_mol"]["helpful"]["diff_expt_reason"]) 

    return merged

def merge_results(results, include_energy):
    if include_energy: return merge_results_w_energy(results)
    return merge_results_wo_energy(results)

In [None]:
def get_average_reasons_w_energy(results):

    results = copy.deepcopy(results)
    new_reasons = {}
    n_expt_diff_CE = results["diff_expt_reason"]["diff_CE"] + \
                     results["diff_expt_reason"]["diff_adduct_CE"] +\
                     results["diff_expt_reason"]["diff_adduct_instrument_CE"] +\
                     results["diff_expt_reason"]["diff_instrument_CE"]

    new_reasons["avg_CE"]= round(results["diff_expt_reason"]["total_diff_CE"] / n_expt_diff_CE, 3) 
    new_reasons["percent_diff_CE"]= round(results["diff_expt_reason"]["diff_CE"] / (1  + results["diff_expt_count"]) * 100, 3) 
    new_reasons["percent_diff_adduct"]= round(results["diff_expt_reason"]["diff_adduct"] / (1  + results["diff_expt_count"]) * 100, 3) 
    new_reasons["percent_diff_adduct_CE"]= round(results["diff_expt_reason"]["diff_adduct_CE"] / (1  + results["diff_expt_count"])* 100, 3) 
    new_reasons["percent_diff_adduct_instrument"]= round(results["diff_expt_reason"]["diff_adduct_instrument"] / (1  + results["diff_expt_count"])* 100, 3) 
    new_reasons["percent_diff_adduct_instrument_CE"]= round(results["diff_expt_reason"]["diff_adduct_instrument_CE"] / (1  + results["diff_expt_count"])* 100, 3) 
    new_reasons["percent_diff_instrument"]= round(results["diff_expt_reason"]["diff_instrument"] / (1  + results["diff_expt_count"]) * 100, 3) 
    new_reasons["percent_diff_instrument_CE"]= round(results["diff_expt_reason"]["diff_instrument_CE"] / (1  + results["diff_expt_count"]) * 100, 3) 

    results["diff_expt_reason"] = new_reasons

    return results 

def get_average_reasons_wo_energy(results):

    results = copy.deepcopy(results)
    new_reasons = {}

    new_reasons["percent_diff_adduct"]= round(results["diff_expt_reason"]["diff_adduct"] / (1  + results["diff_expt_count"]) * 100, 3) 
    new_reasons["percent_diff_adduct_instrument"]= round(results["diff_expt_reason"]["diff_adduct_instrument"] / (1  + results["diff_expt_count"]) * 100, 3) 
    new_reasons["percent_diff_instrument"]= round(results["diff_expt_reason"]["diff_instrument"] /(1  + results["diff_expt_count"]) * 100, 3) 

    results["diff_expt_reason"] = new_reasons

    return results 

def get_average_reasons(results, include_energy = True):

    if include_energy: return get_average_reasons_w_energy(results)
    
    return get_average_reasons_wo_energy(results)

def get_average(merged_results, include_energy = True):
    
    # Diff mol
    merged_results["diff_mol"]["diff_expt"]["percent_harmful"] = round(merged_results["diff_mol"]["diff_expt"]["harmful"] / merged_results["diff_mol"]["diff_expt"]["count"] * 100, 3) 
    merged_results["diff_mol"]["diff_expt"]["percent_helpful"] = round(merged_results["diff_mol"]["diff_expt"]["helpful"] / merged_results["diff_mol"]["diff_expt"]["count"] * 100, 3) 

    merged_results["diff_mol"]["same_expt"]["percent_harmful"] = round(merged_results["diff_mol"]["same_expt"]["harmful"] / merged_results["diff_mol"]["same_expt"]["count"] * 100, 3) 
    merged_results["diff_mol"]["same_expt"]["percent_helpful"] = round(merged_results["diff_mol"]["same_expt"]["helpful"] / merged_results["diff_mol"]["same_expt"]["count"] * 100, 3) 
    
    merged_results["diff_mol"]["harmful"]["percent"] = round(merged_results["diff_mol"]["harmful"]["count"] / (merged_results["diff_mol"]["harmful"]["count"] + merged_results["diff_mol"]["helpful"]["count"] ) * 100, 3) 
    merged_results["diff_mol"]["helpful"]["percent"] = round(merged_results["diff_mol"]["helpful"]["count"] / (merged_results["diff_mol"]["harmful"]["count"] + merged_results["diff_mol"]["helpful"]["count"] ) * 100, 3) 
    
    merged_results["diff_mol"]["harmful"]["percent_same_expt"] = round(merged_results["diff_mol"]["harmful"]["same_expt_count"] / (merged_results["diff_mol"]["harmful"]["same_expt_count"] + merged_results["diff_mol"]["helpful"]["diff_expt_count"] ) * 100, 3) 
    merged_results["diff_mol"]["harmful"]["percent_diff_expt"] = round(merged_results["diff_mol"]["harmful"]["diff_expt_count"] / (merged_results["diff_mol"]["harmful"]["same_expt_count"] + merged_results["diff_mol"]["helpful"]["diff_expt_count"] ) * 100, 3) 
    merged_results["diff_mol"]["helpful"]["percent_same_expt"] = round(merged_results["diff_mol"]["helpful"]["same_expt_count"] / (merged_results["diff_mol"]["helpful"]["same_expt_count"] + merged_results["diff_mol"]["helpful"]["diff_expt_count"] ) * 100, 3) 
    merged_results["diff_mol"]["helpful"]["percent_diff_expt"] = round(merged_results["diff_mol"]["helpful"]["diff_expt_count"] / (merged_results["diff_mol"]["helpful"]["same_expt_count"] + merged_results["diff_mol"]["helpful"]["diff_expt_count"] ) * 100, 3) 

    merged_results["diff_mol"]["harmful"]["avg_dist"] = round(merged_results["diff_mol"]["harmful"]["dist"]  / merged_results["diff_mol"]["harmful"]["count"], 3) 
    merged_results["diff_mol"]["helpful"]["avg_dist"] = round(merged_results["diff_mol"]["helpful"]["dist"]  / merged_results["diff_mol"]["helpful"]["count"], 3) 
    
    # Same mol
    merged_results["same_mol"]["harmful"]["percent"] = round(merged_results["same_mol"]["harmful"]["count"] / (merged_results["same_mol"]["harmful"]["count"] + merged_results["same_mol"]["helpful"]["count"] ) * 100, 3) 
    merged_results["same_mol"]["helpful"]["percent"] = round(merged_results["same_mol"]["helpful"]["count"] / (merged_results["same_mol"]["harmful"]["count"] + merged_results["same_mol"]["helpful"]["count"] ) * 100, 3) 
    # Diff expt
    merged_results["diff_expt"] = {}
    merged_results["diff_expt"]["count"] = merged_results["same_mol"]["diff_expt"]["count"] + merged_results["diff_mol"]["diff_expt"]["count"]
    merged_results["diff_expt"]["harmful"] = merged_results["same_mol"]["diff_expt"]["harmful"] + merged_results["diff_mol"]["diff_expt"]["harmful"]
    merged_results["diff_expt"]["helpful"] = merged_results["same_mol"]["diff_expt"]["helpful"] + merged_results["diff_mol"]["diff_expt"]["helpful"]

    merged_results["diff_expt"]["percent_harmful"] = round(merged_results["diff_expt"]["harmful"] / merged_results["diff_expt"]["count"] * 100,3)
    merged_results["diff_expt"]["percent_helpful"] = round(merged_results["diff_expt"]["helpful"] / merged_results["diff_expt"]["count"] * 100,3)

    # Same expt
    merged_results["same_expt"] = {}
    merged_results["same_expt"]["count"] = merged_results["same_mol"]["same_expt"]["count"] + merged_results["diff_mol"]["same_expt"]["count"]
    merged_results["same_expt"]["harmful"] = merged_results["same_mol"]["same_expt"]["harmful"] + merged_results["diff_mol"]["same_expt"]["harmful"]
    merged_results["same_expt"]["helpful"] = merged_results["same_mol"]["same_expt"]["helpful"] + merged_results["diff_mol"]["same_expt"]["helpful"]

    merged_results["same_expt"]["percent_harmful"] = round(merged_results["same_expt"]["harmful"] / merged_results["same_expt"]["count"] * 100,3)
    merged_results["same_expt"]["percent_helpful"] = round(merged_results["same_expt"]["helpful"] / merged_results["same_expt"]["count"] * 100,3)

    merged_results["diff_mol"]["harmful"]["diff_expt_reason"] = get_average_reasons(merged_results["diff_mol"]["harmful"], include_energy = include_energy)
    merged_results["diff_mol"]["helpful"]["diff_expt_reason"] = get_average_reasons(merged_results["diff_mol"]["helpful"], include_energy = include_energy)

    merged_results["same_mol"]["harmful"]["diff_expt_reason"] = get_average_reasons(merged_results["same_mol"]["harmful"], include_energy = include_energy)
    merged_results["same_mol"]["helpful"]["diff_expt_reason"] = get_average_reasons(merged_results["same_mol"]["helpful"], include_energy = include_energy)

    return merged_results

Get information on the random results

In [None]:
folders = ["../FP_prediction/mist/models_cached/morgan4_4096", "../FP_prediction/baseline_models/models_cached/morgan4_4096"]
all_info_random_path = os.path.join(cache_folder, "all_info_random.pkl")

for folder in folders:
    
    
    folder = Path(folder)

    for dataset in ["canopus", "massspecgym"]:
        
        if dataset == "canopus": include_energy = False
        else: include_energy = True 
            
            if dataset == "canopus" and "mist" not in folder: inchikey_key = "inchikey_original"
            else: inchikey_key = "inchikey"

            current_folder = os.path.join(folder, dataset)
            if dataset not in all_info_random: all_info_random[dataset] = {}

            for model in tqdm(os.listdir(current_folder)):

                # Analyze for random only 
                if "random" not in model: continue

                model_folder = os.path.join(current_folder, model)
                influence_scores_path = os.path.join(model_folder, "EK-FAC_scores_w_info.pkl")
                if not os.path.exists(influence_scores_path): continue
                influence_scores = load_pickle(influence_scores_path)
                merged_results = {}
                
                for test_id, results in tqdm(influence_scores.items()):
                    
                    test_info = list(results.values())[0]["test_info"]
                    test_mol = test_info["inchikey"]

                    train_same_mol = {train_id: v for train_id, v in results.items() if test_mol[:14] == v["train_info"][inchikey_key][:14]}
                    train_diff_mol = {train_id: v for train_id, v in results.items() if test_mol[:14] != v["train_info"][inchikey_key][:14]}

                    same_mol_analysis = analyze_influence_records(test_info, train_same_mol, include_energy = include_energy)
                    diff_mol_analysis = analyze_influence_records(test_info, train_diff_mol, include_energy = include_energy)

                    merged_results[test_id] = {"same_mol": same_mol_analysis,
                                            "diff_mol": diff_mol_analysis}
                    
                merged_results = merge_results(merged_results, include_energy = include_energy)        
                merged_results = get_average(merged_results, include_energy = include_energy)
                all_info_random[dataset][model] = merged_results

    pickle_data(all_info_random, all_info_random_path)

all_info_random = load_pickle(all_info_random_path)

Get percentage of positive vs negative influence scores

In [None]:
folders = ["../FP_prediction/baseline_models/models_cached/morgan4_4096", ]
percent_scores_path = os.path.join(cache_folder, "percent_scores.pkl")

if not os.path.exists(percent_scores_path):

    all_percent_scores = {}

    for folder in folders:
        
        for dataset in tqdm(["massspecgym", "canopus"]):
            
            if dataset not in all_percent_scores: all_percent_scores[dataset] = {} 

            current_folder = os.path.join(folder, dataset)

            for model in tqdm(os.listdir(current_folder)):

                model_folder = os.path.join(current_folder, model)
            
                train_ids_path = os.path.join(model_folder, "train_ids.pkl")
                test_ids_path = os.path.join(model_folder, "test_ids.pkl")
                influence_scores_path = os.path.join(model_folder, "EK-FAC_scores.pkl")

                train_ids = load_pickle(train_ids_path)
                test_ids = load_pickle(test_ids_path)
                influence_scores = load_pickle(influence_scores_path)

                percent_pos = round((influence_scores["all_modules"] > 0).float().mean(-1).mean().item() * 100, 3)
                percent_neg = round((influence_scores["all_modules"] < 0).float().mean(-1).mean().item() * 100, 3)

                all_percent_scores[dataset][model] = {"pos": percent_pos, "neg": percent_neg}
    
    pickle_data(all_percent_scores, percent_scores_path)

else:
    all_percent_scores = load_pickle(percent_scores_path)

pprint(all_percent_scores)

Plot percentage of negative training examples

In [None]:
for dataset in ["massspecgym", "canopus"]:

    plt.figure(figsize=(8, 6))

    models = ["binned_4096", "MS_4096", "formula_4096"]
    splits = ["scaffold", "inchikey", "random"]

    bar_width = 0.23
    random, scaffold, inchikey = [],[],[]

    for model in models:
        random.extend([s["neg"] for k, s in all_percent_scores[dataset].items() if model in k and "random" in k])
        scaffold.extend([s["neg"] for k, s in all_percent_scores[dataset].items() if model in k and "scaffold" in k])   
        inchikey.extend([s["neg"] for k, s in all_percent_scores[dataset].items() if model in k and "inchikey" in k])    

    random_bar = np.arange(len(models)) 
    scaffold_bar = [x + bar_width + 0.01 for x in random_bar] 
    inchikey_bar = [x + bar_width + 0.01 for x in scaffold_bar] 

    plt.bar(random_bar, random, width = bar_width, label ='random', color = "#006d77") 
    plt.bar(scaffold_bar, scaffold, width = bar_width, label ='scaffold', color = "#83c5be") 
    plt.bar(inchikey_bar, inchikey, width = bar_width, label ='inchikey', color = "#e94f37") 

    for i in range(len(random)):
        plt.text(random_bar[i], random[i], str(round(random[i], 1)) + "%", ha='center', va='bottom')
        plt.text(scaffold_bar[i] + 0.01, scaffold[i], str(round(scaffold[i], 1)) + "%", ha='center', va='bottom')
        plt.text(inchikey_bar[i] + 0.01, inchikey[i], str(round(inchikey[i], 1)) + "%", ha='center', va='bottom')

    if dataset == "canopus":
        plt.title(f"Canopus", fontsize = 14)
    if dataset == "massspecgym":
        plt.title(f"MassSpecGym", fontsize = 14)
    
    plt.xlabel('Models', fontsize = 12)
    plt.ylabel('Train samples with negative influence score (%)', fontsize = 12)
    plt.xticks([r + bar_width for r in range(len(models))], ['Binned MLP', 'MS Transformer', 'Formula Transformer'])
    plt.legend()
    plt.show()
    

Look at distribution of harmful vs helpful for identical molecules

In [None]:
for dataset in ["massspecgym", "canopus"]:

    fig, ax = plt.subplots(layout='constrained', figsize=(8, 6))
    bar_width = 0.23
    helpful, harmful = [],[]

    models = ["binned_4096", "MS_4096", "formula_4096", "mist_4096"]

    for model in models:

        helpful.extend([s["same_mol"]["helpful"]["percent"] for k, s in all_info_random[dataset].items() if model in k])
        harmful.extend([s["same_mol"]["harmful"]["percent"] for k, s in all_info_random[dataset].items() if model in k])

        assert helpful[-1] + harmful[-1] == 100
        
    helpful_bar = np.arange(len(helpful))
    harmful_bar = [x + bar_width + 0.01 for x in helpful_bar] 
    
    # plot bars
    plt.bar(helpful_bar, helpful, width = bar_width, label = "Helpful", color = "#D3D3D3")
    plt.bar(harmful_bar, harmful, width = bar_width, label = "Harmful", color = "#F6BD60") 

    if dataset == "canopus":
        plt.title(f"Canopus", fontsize = 14)
    if dataset == "massspecgym":
        plt.title(f"MassSpecGym", fontsize = 14)
    
    plt.xlabel('Models', fontsize = 12)
    plt.ylabel('Effect of identical molecules on test samples', fontsize = 12)
    plt.ylim(bottom = 0, top = 100)
    plt.xticks([r + bar_width /2 for r in range(len(models))], ['Binned MLP', 'MS Transformer', 'Formula Transformer', "MIST"])

    for i in range(len(helpful)):
        plt.text(helpful_bar[i], helpful[i], str(round(helpful[i], 1)) + "%", ha='center', va='bottom')
        plt.text(harmful_bar[i], harmful[i], str(round(harmful[i], 1)) + "%", ha='center', va='bottom')

    # pos = []
    # for r in range(len(harmful_bar)):
    #     pos.extend([r + bar_width / 10, r + bar_width])

    # plt.xticks(pos, ["Harmful", "Helpful"] * len(harmful_bar), fontsize = 8)

    # # plt.xlabel(" "  * 15 + "Binned MLP" + " " * 45 + "MS Transformer" + " " * 30 + "Formula Transformer")
    # # plt.xticks([r + bar_width/2 for r in range(len(helpful_bar))], ['Binned MLP', 'MS Transformer', 'Formula Transformer'])
    # ax2 = ax.secondary_xaxis(-0.05)
    # ax2.set_xticks([])
    # ax2.set_xticks([r + bar_width/2 for r in range(len(helpful_bar))], ['Binned MLP', 'MS Transformer', 'Formula Transformer'])

    plt.legend()

In [None]:
for dataset in ["massspecgym", "canopus"]:

    fig, ax = plt.subplots(layout='constrained', figsize=(8, 6))
    bar_width = 0.23
    helpful, harmful = [],[]

    models = ["binned_4096", "MS_4096", "formula_4096", "mist_4096"]

    for model in models:

        helpful.extend([s["diff_mol"]["helpful"]["avg_dist"] for k, s in all_info_random[dataset].items() if model in k])
        harmful.extend([s["diff_mol"]["harmful"]["avg_dist"] for k, s in all_info_random[dataset].items() if model in k])

    helpful_bar = np.arange(len(helpful))
    harmful_bar = [x + bar_width + 0.01 for x in helpful_bar] 
    
    # plot bars
    plt.bar(helpful_bar, helpful, width = bar_width, label = "Helpful", color = "#D3D3D3")
    plt.bar(harmful_bar, harmful, width = bar_width, label = "Harmful", color = "#F6BD60") 

    if dataset == "canopus":
        plt.title(f"Canopus", fontsize = 14)
    if dataset == "massspecgym":
        plt.title(f"MassSpecGym", fontsize = 14)
    
    plt.xlabel('Models', fontsize = 12)
    plt.ylabel('Tanimoto similarity of training molecules', fontsize = 12)
    plt.ylim(bottom = 0, top = 0.3)

    # plt.ylim(bottom = 0, top = 100)
    plt.xticks([r + bar_width /2 for r in range(len(models))], ['Binned MLP', 'MS Transformer', 'Formula Transformer', "MIST"])

    for i in range(len(helpful)):
        plt.text(helpful_bar[i], helpful[i], str(round(helpful[i], 3)), ha='center', va='bottom')
        plt.text(harmful_bar[i], harmful[i], str(round(harmful[i], 3)), ha='center', va='bottom')

    # pos = []
    # for r in range(len(harmful_bar)):
    #     pos.extend([r + bar_width / 10, r + bar_width])

    # plt.xticks(pos, ["Harmful", "Helpful"] * len(harmful_bar), fontsize = 8)

    # # plt.xlabel(" "  * 15 + "Binned MLP" + " " * 45 + "MS Transformer" + " " * 30 + "Formula Transformer")
    # # plt.xticks([r + bar_width/2 for r in range(len(helpful_bar))], ['Binned MLP', 'MS Transformer', 'Formula Transformer'])
    # ax2 = ax.secondary_xaxis(-0.05)
    # ax2.set_xticks([])
    # ax2.set_xticks([r + bar_width/2 for r in range(len(helpful_bar))], ['Binned MLP', 'MS Transformer', 'Formula Transformer'])

    plt.legend()

In [None]:
models = ["binned_4096", "MS_4096", "formula_4096", "mist_4096"]
datasets = ["canopus", "massspecgym"]

for dataset in datasets:

    for model in models:

        harmful_identical = [s["same_mol"]["same_expt"]["harmful"] for k, s in all_info_random[dataset].items() if model in k][0]
        helpful_identical = [s["same_mol"]["same_expt"]["helpful"] for k, s in all_info_random[dataset].items() if model in k][0]
        total = harmful_identical + helpful_identical

        print(dataset, model, round(harmful_identical / (total) * 100, 1), harmful_identical, total)


Helpful and harmful, distribution of experiments

- do we have the same molecule that can contribute both negatively and positively to the test set

In [None]:
for dataset, results in all_info_random.items():

    fig, ax = plt.subplots(layout='constrained', figsize=(8, 6))
    bar_width = 0.23
    helpful, harmful = [],[]

    models = ["binned_4096", "MS_4096", "formula_4096", "mist_4096"]

    for model in models:

        helpful.extend([s["diff_mol"]["helpful"]["percent"] for k, s in all_info_random[dataset].items() if model in k])
        harmful.extend([s["diff_mol"]["harmful"]["percent"] for k, s in all_info_random[dataset].items() if model in k])

    helpful_bar = np.arange(len(helpful))
    harmful_bar = [x + bar_width + 0.01 for x in helpful_bar] 
    
    # plot bars
    plt.bar(helpful_bar, helpful, width = bar_width, label = "Helpful", color = "#D3D3D3")
    plt.bar(harmful_bar, harmful, width = bar_width, label = "Harmful", color = "#F6BD60") 

    if dataset == "canopus":
        plt.title(f"Canopus", fontsize = 14)
    if dataset == "massspecgym":
        plt.title(f"MassSpecGym", fontsize = 14)
    
    plt.xlabel('Models', fontsize = 12)
    plt.ylabel('Tanimoto similarity of training molecules', fontsize = 12)
    plt.ylim(bottom = 0, top = 65)

    # plt.ylim(bottom = 0, top = 100)
    plt.xticks([r + bar_width /2 for r in range(len(models))], ['Binned MLP', 'MS Transformer', 'Formula Transformer', "MIST"])

    for i in range(len(helpful)):
        plt.text(helpful_bar[i], helpful[i], str(round(helpful[i], 1)) + "%", ha='center', va='bottom', fontsize = 8)
        plt.text(harmful_bar[i], harmful[i], str(round(harmful[i], 1)) + "%", ha='center', va='bottom', fontsize = 8)
    
    plt.legend()

In [None]:
dataset = "massspecgym"

for reason in ["percent_diff_CE", "percent_diff_adduct", "percent_diff_instrument",  "percent_diff_adduct_CE", "percent_diff_adduct_instrument", "percent_diff_instrument_CE", "percent_diff_adduct_instrument_CE"]:

    reason_mapping = {'percent_diff_CE': "Collision Energy",
                        'percent_diff_adduct': "Adduct",
                        'percent_diff_adduct_CE': "Collision Energy + Adduct",
                        'percent_diff_adduct_instrument': "Instrument + Adduct",
                        'percent_diff_adduct_instrument_CE': "Instrument + Adduct + CE",
                        'percent_diff_instrument': "Instrument",
                        'percent_diff_instrument_CE':  "Instrument + CE"}

    reason_key = reason_mapping[reason]
    fig, ax = plt.subplots(layout='constrained', figsize=(8, 6))
    bar_width = 0.23
    harmful = []
    helpful = []

    for model in ["MSG_binned_4096_random", "MSG_MS_4096_random", "MSG_formula_4096_random", "MSG_mist_4096_random"]:
        
        helpful.append(all_info_random[dataset][model]["diff_mol"]["helpful"]["diff_expt_reason"]["diff_expt_reason"][reason])
        harmful.append(all_info_random[dataset][model]["diff_mol"]["harmful"]["diff_expt_reason"]["diff_expt_reason"][reason])

    helpful_bar = np.arange(len(helpful))
    harmful_bar = [x + bar_width + 0.01 for x in helpful_bar] 

    # plot bars
    plt.bar(helpful_bar, helpful, width = bar_width, label = "Helpful", color = "#D3D3D3")
    plt.bar(harmful_bar, harmful, width = bar_width, label = "Harmful", color = "#F6BD60")

    if dataset == "canopus":
        plt.title(f"Canopus", fontsize = 14)
    if dataset == "massspecgym":
        plt.title(f"{reason_key} (MassSpecGym)", fontsize = 14)

    plt.xlabel('Models', fontsize = 12)
    plt.ylabel('Percentage train samples (%)', fontsize = 12)
    plt.xticks([r + bar_width /2 for r in range(len(models))], ['Binned MLP', 'MS Transformer', 'Formula Transformer', "MIST"])

    for i in range(len(helpful)):
        plt.text(helpful_bar[i], helpful[i], str(round(helpful[i], 1)) + "%", ha='center', va='bottom', fontsize = 8)
        plt.text(harmful_bar[i], harmful[i], str(round(harmful[i], 1)) + "%", ha='center', va='bottom', fontsize = 8)

    plt.legend()

In [None]:
dataset = "massspecgym"

fig, ax = plt.subplots(layout='constrained', figsize=(8, 6))
bar_width = 0.08
diff_CE, diff_adduct, diff_adduct_CE, diff_adduct_instrument, diff_adduct_instrument_CE, diff_instrument, diff_instrument_CE = [], [], [], [], [], [], []
diff_CE, diff_adduct, diff_adduct_CE, diff_adduct_instrument, diff_adduct_instrument_CE, diff_instrument, diff_instrument_CE = [], [], [], [], [], [], []

for model in ["MSG_binned_4096_random", "MSG_MS_4096_random", "MSG_formula_4096_random", "MSG_mist_4096_random"]:

    diff_CE.append(all_info_random[dataset][model]["diff_mol"]["harmful"]["diff_expt_reason"]["diff_expt_reason"]["percent_diff_CE"])
    diff_adduct.append(all_info_random[dataset][model]["diff_mol"]["harmful"]["diff_expt_reason"]["diff_expt_reason"]["percent_diff_adduct"])
    diff_instrument.append(all_info_random[dataset][model]["diff_mol"]["harmful"]["diff_expt_reason"]["diff_expt_reason"]["percent_diff_instrument"])
    diff_adduct_CE.append(all_info_random[dataset][model]["diff_mol"]["harmful"]["diff_expt_reason"]["diff_expt_reason"]["percent_diff_adduct_CE"])
    diff_adduct_instrument.append(all_info_random[dataset][model]["diff_mol"]["harmful"]["diff_expt_reason"]["diff_expt_reason"]["percent_diff_adduct_instrument"])
    diff_instrument_CE.append(all_info_random[dataset][model]["diff_mol"]["harmful"]["diff_expt_reason"]["diff_expt_reason"]["percent_diff_instrument_CE"])
    diff_adduct_instrument_CE.append(all_info_random[dataset][model]["diff_mol"]["harmful"]["diff_expt_reason"]["diff_expt_reason"]["percent_diff_adduct_instrument_CE"])

diff_CE_bar = np.arange(len(diff_CE))
diff_adduct_bar = [x + bar_width + 0.01 for x in diff_CE_bar] 
diff_instrument_bar = [x + bar_width + 0.01 for x in diff_adduct_bar] 
diff_adduct_CE_bar = [x + bar_width + 0.01 for x in diff_instrument_bar] 
diff_adduct_instrument_bar = [x + bar_width + 0.01 for x in diff_adduct_CE_bar] 
diff_instrument_CE_bar = [x + bar_width + 0.01 for x in diff_adduct_instrument_bar] 
diff_adduct_instrument_CE_bar = [x + bar_width + 0.01 for x in diff_instrument_CE_bar] 

# plot bars
plt.bar(diff_CE_bar, diff_CE, width = bar_width, label = "% diff CE") #, color = "#D3D3D3")
plt.bar(diff_adduct_bar, diff_adduct, width = bar_width, label = "% diff adduct") #, color = "#F6BD60")
plt.bar(diff_instrument_bar, diff_instrument, width = bar_width, label = "% diff instrument") #, color = "#F6BD60")
plt.bar(diff_adduct_CE_bar, diff_adduct_CE, width = bar_width, label = "% diff adduct + CE") #, color = "#F6BD60")
plt.bar(diff_adduct_instrument_bar, diff_adduct_instrument, width = bar_width, label = "% diff adduct + instrument") #, color = "#F6BD60")
plt.bar(diff_instrument_CE_bar, diff_instrument_CE, width = bar_width, label = "% diff instrument + CE") #, color = "#F6BD60")
plt.bar(diff_adduct_instrument_CE_bar, diff_adduct_instrument_CE, width = bar_width, label = "% diff adduct + instrument + CE") #, color = "#F6BD60")

if dataset == "canopus":
    plt.title(f"Canopus", fontsize = 14)
if dataset == "massspecgym":
    plt.title(f"MassSpecGym (Harmful train samples)", fontsize = 14)

plt.xlabel('Models', fontsize = 12)
plt.ylabel('Percentage train samples (%)', fontsize = 12)
plt.ylim(bottom = 0, top = 100)
plt.xticks([r + bar_width * 4 for r in range(len(models))], ['Binned MLP', 'MS Transformer', 'Formula Transformer', "MIST"])

# for i in range(len(helpful)):
#     plt.text(helpful_bar[i], helpful[i], str(round(helpful[i], 1)) + "%", ha='center', va='bottom', fontsize = 8)
#     plt.text(harmful_bar[i], harmful[i], str(round(harmful[i], 1)) + "%", ha='center', va='bottom', fontsize = 8)

plt.legend()

In [None]:
all_info_random[dataset]["MSG_formula_4096_random"]["diff_mol"]["helpful"]["diff_expt_reason"]["diff_expt_reason"]

In [None]:
import os 
from tqdm import tqdm
from pathlib import Path
from utils import load_pickle, update_dict, same_expt, get_stats, pickle_data

def main(expt_folder, data):

    train_ids_path = expt_folder / "train_ids.pkl"
    test_ids_path = expt_folder / "test_ids.pkl"
    output_path = folder / dataset / expt / "scores_stats.pkl"

    if os.path.exists(output_path): return 
            
    scores = load_pickle(scores_path)["all_modules"]
    train_ids, test_ids = load_pickle(train_ids_path), load_pickle(test_ids_path)
    
    # Create a master list to conslidate the statistics
    master_stats = {"same_mol" : {"helpful": 0, "harmful": 0},
                    "harmful": {"mol_sim": 0, "MS_sim": 0, "count": 0},
                    "helpful": {"mol_sim": 0, "MS_sim": 0, "count": 0},
                    "same_mol_harmful_conditions" : {e: 0 for e in DIFF_EXPT_CONDITIONS},
                    "same_mol_helpful_conditions" : {e: 0 for e in DIFF_EXPT_CONDITIONS},
                    "same_mol_same_conditions": {"helpful": 0, "harmful": 0}}

    # Iterate through now 
    for test_idx, test_id in tqdm(enumerate(test_ids)):

        test_info = data[Path(test_id).stem]
        test_mol = test_info["inchikey_original"]

        train_records = {Path(train_id).stem: update_dict(scores[test_idx, train_idx].item(), data[Path(train_id).stem], "score") for train_idx, train_id in enumerate(train_ids)}

        # Get harmful and helpful molecules 
        harmful = {k:v for k,v in train_records.items() if v["score"] < 0}
        helpful = {k:v for k,v in train_records.items() if v["score"] > 0}
        
        top_k_harmful = {r[0]: r[1] for r in sorted(harmful.items(), key = lambda x: x[1]["score"])[:TOP_K]}
        top_k_helpful = {r[0]: r[1] for r in sorted(helpful.items(), key = lambda x: x[1]["score"], reverse = True)[:TOP_K]}

        # Look at how many harmful are of same / different molecules
        harmful_same = {k:v for k, v in harmful.items() if v["inchikey_original"][:14] == test_mol[:14]}
        helpful_same = {k:v for k, v in helpful.items() if v["inchikey_original"][:14] == test_mol[:14]}

        harmful_diff = {k:v for k, v in harmful.items() if v["inchikey_original"][:14] != test_mol[:14]}
        helpful_diff = {k:v for k, v in helpful.items() if v["inchikey_original"][:14] != test_mol[:14]}
        
        # Get same molecule but with either same or different experimental conditions 
        same_mol_same_conditions = {k: v for k, v in train_records.items() \
                                    if v["inchikey_original"][:14] == test_mol[:14] 
                                    and same_expt(test_info, v)}

        same_mol_diff_conditions = {k: v for k, v in train_records.items() \
                                    if v["inchikey_original"][:14] == test_mol[:14] 
                                    and not same_expt(test_info, v)}
        
        # Look at how many of the same mol with same / diff conditions are helpful / harmful
        harmful_same_mol_same_conditions = {k:v for k,v in same_mol_same_conditions.items() if v["score"] < 0}
        helpful_same_mol_same_conditions = {k:v for k,v in same_mol_same_conditions.items() if v["score"] > 0}

        harmful_same_mol_diff_conditions = {k:v for k,v in same_mol_diff_conditions.items() if v["score"] < 0}
        helpful_same_mol_diff_conditions = {k:v for k,v in same_mol_diff_conditions.items() if v["score"] > 0}

        # Get the statistics now
        harmful_stats = get_stats(test_info, top_k_harmful, compute_mol_dist= True, compute_MS_dist= True)
        helpful_stats = get_stats(test_info, top_k_helpful, compute_mol_dist= True, compute_MS_dist= True)
        
        harmful_same_mol_diff_conditions_stats = get_stats(test_info, harmful_same_mol_diff_conditions) 
        helpful_same_mol_diff_conditions_stats = get_stats(test_info, helpful_same_mol_diff_conditions) 

        # sum everything up now 
        master_stats["same_mol"]["harmful"] += len(harmful_same)
        master_stats["same_mol"]["helpful"] += len(helpful_same)

        master_stats["harmful"]["mol_sim"] += harmful_stats["mol_sim"]
        master_stats["harmful"]["MS_sim"] += harmful_stats["MS_sim"]
        master_stats["harmful"]["count"] += harmful_stats["mol_sim_n_train"]
                        
        master_stats["helpful"]["mol_sim"] += helpful_stats["mol_sim"]
        master_stats["helpful"]["MS_sim"] += helpful_stats["MS_sim"]
        master_stats["helpful"]["count"] += helpful_stats["mol_sim_n_train"]
        
        master_stats["same_mol_same_conditions"]["harmful"] += len(harmful_same_mol_same_conditions)
        master_stats["same_mol_same_conditions"]["helpful"] += len(helpful_same_mol_same_conditions)

        # Add to the conditions now 
        for k,v in harmful_same_mol_diff_conditions_stats["diff_expt_count"].items():
            if k == "total": continue
            master_stats["same_mol_harmful_conditions"][k] += v 

        for k,v in helpful_same_mol_diff_conditions_stats["diff_expt_count"].items():
            if k == "total": continue
            master_stats["same_mol_helpful_conditions"][k] += v 
    
    pickle_data(master_stats, output_path)

if __name__ == "__main__":

    TOP_K = 500
    CACHE_FOLDER = "./cache"
    DATA_FOLDER = Path("/data/rbg/users/klingmin/projects/MS_processing/data")
    RESULTS_FOLDER = ["../FP_prediction/baseline_models/models_cached/w_meta", "../FP_prediction/baseline_models/models_cached/wo_meta"]
    RESULTS_FOLDER = [Path(f) for f in RESULTS_FOLDER]
    if not os.path.exists(CACHE_FOLDER): os.makedirs(CACHE_FOLDER)

    DIFF_EXPT_CONDITIONS = ["diff_adduct", "diff_instrument", "diff_CE",
                            "diff_adduct_instrument", "diff_adduct_CE",
                            "diff_instrument_CE", "diff_adduct_instrument_CE"]

    # Get statistics for each experiment 
    for folder in RESULTS_FOLDER:

        for dataset in os.listdir(folder): 

            data = load_pickle(DATA_FOLDER / dataset / f"{dataset}_w_mol_info.pkl")
            data = {str(rec["id_"]): rec for rec in data}

            for expt in os.listdir(folder / dataset):

                expt_folder = folder / dataset / expt
                scores_path = expt_folder / "EK-FAC_scores.pkl"

                if not os.path.exists(scores_path): continue 

                # Proceed to process the influence scores now 
                main(expt_folder, data)

