Import libraries

In [1]:
import os
import json
import numpy as np
import pandas as pd
from tqdm import tqdm
from utils import load_json, load_pickle, pickle_data, write_json

import torch

Settings

In [10]:
baseline_folder = "../FP_prediction/baseline_models/best_models/"

thresholds = [0.40, 0.45, 0.50, 0.60, 0.65]

In [12]:
def sigmoid(x):
  """
  Computes the sigmoid function for each element in x.

  Args:
    x: A scalar or NumPy array.

  Returns:
    A scalar or NumPy array with the sigmoid of each element in x.
  """
  return 1 / (1 + np.exp(-x))

def to_binary(FP, threshold):

    FP = sigmoid(FP)
    FP = (FP > threshold).astype(int)

    return FP 

@torch.no_grad()
def batch_jaccard_index(FP_pred, FP):

    # Intersection = bitwise AND
    intersection = np.logical_and(FP, FP_pred).sum()

    # Union = bitwise OR
    union = np.logical_or(FP, FP_pred).sum()

    # Avoid division-by-zero by adding a small epsilon
    jaccard_scores = intersection / (union + 1e-9)

    return jaccard_scores

Consolidate results folders

In [13]:
# Manually add in (hack)
all_folders = []
all_test_performances = {"loss": {}, "jaccard": {}} 

for folder in [baseline_folder]:

    for dataset in os.listdir(folder):
        subfolder = os.path.join(folder, dataset)
        for checkpoint in os.listdir(subfolder):
            all_folders.append(os.path.join(subfolder, checkpoint))

In [14]:
for f in tqdm(all_folders):

    path =  os.path.join(f, "EK-FAC_scores.pkl")
    if not os.path.exists(path): 
        print(f)
        continue 

    # try:

    #     test_performance = load_json(path)

    # except Exception as e:

    #     test_performance = load_pickle(path)
    #     write_json(test_performance, path)
    #     print(e)

100%|██████████| 27/27 [00:00<00:00, 3095.09it/s]

../FP_prediction/baseline_models/best_models/massspecgym/MSG_binned_4096_scaffold_vanilla
../FP_prediction/baseline_models/best_models/massspecgym/MSG_binned_4096_inchikey_vanilla
../FP_prediction/baseline_models/best_models/nist2023/NIST2023_binned_meta_4096_scaffold_vanilla
../FP_prediction/baseline_models/best_models/nist2023/NIST2023_MS_meta_4096_random
../FP_prediction/baseline_models/best_models/nist2023/NIST2023_formula_meta_4096_random
../FP_prediction/baseline_models/best_models/nist2023/NIST2023_binned_meta_4096_inchikey_vanilla
../FP_prediction/baseline_models/best_models/nist2023/NIST2023_MS_meta_4096_scaffold_vanilla
../FP_prediction/baseline_models/best_models/nist2023/NIST2023_binned_meta_4096_random
../FP_prediction/baseline_models/best_models/nist2023/NIST2023_formula_4096_inchikey_vanilla
../FP_prediction/baseline_models/best_models/canopus/C_binned_meta_4096_random
../FP_prediction/baseline_models/best_models/canopus/C_binned_4096_inchikey_vanilla
../FP_prediction/ba




In [15]:
for f in tqdm(all_folders):

    f_list = f.split("/")
    model = f_list[-1]

    test_performance = load_json(os.path.join(f, "test_performance.json"))
    
    if "threshold" not in test_performance: 
        
        print(f"processing : {f} now")
        consolidated_jaccard_scores = {t: 0 for t in thresholds}
        test_results = load_pickle(os.path.join(f, "test_results.pkl"))
        
        # Iterate through 
        for k, v in tqdm(test_results.items()): 
            v_pred = v["pred"]
            v_true = v["GT"]

            for t in thresholds:

                v_binary = to_binary(np.array(v_pred), threshold = t)
                score = batch_jaccard_index(v_binary, v_true)
                consolidated_jaccard_scores[t] += score

        consolidated_jaccard_scores = {k : (v / len(test_results)) for k, v in consolidated_jaccard_scores.items()}
        best_t = max(consolidated_jaccard_scores, key=consolidated_jaccard_scores.get)
        test_performance["threshold"] = best_t 
        test_performance["jaccard"] = float(consolidated_jaccard_scores[best_t])
        write_json(test_performance, os.path.join(f, "test_performance.json"))

  0%|          | 0/27 [00:00<?, ?it/s]

processing : ../FP_prediction/baseline_models/best_models/nist2023/NIST2023_MS_meta_4096_random now


100%|██████████| 137898/137898 [08:25<00:00, 272.76it/s]
 44%|████▍     | 12/27 [10:20<12:55, 51.67s/it]

processing : ../FP_prediction/baseline_models/best_models/nist2023/NIST2023_formula_meta_4096_random now


100%|██████████| 137898/137898 [08:21<00:00, 274.81it/s]
 48%|████▊     | 13/27 [20:42<26:13, 112.39s/it]

processing : ../FP_prediction/baseline_models/best_models/nist2023/NIST2023_binned_meta_4096_random now


100%|██████████| 137898/137898 [08:34<00:00, 268.22it/s]
100%|██████████| 27/27 [31:46<00:00, 70.62s/it] 
