`` This script designs various strategies to sample training data using the influence function scores ``

Import libraries

In [3]:
import os
import pickle
import numpy as np
from tqdm import tqdm 
import torch

from utils import load_pickle, pickle_data

Settings

In [2]:
# Define the percentage of samples that we will be using for training
ratio_list = [0.05, 0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.75, 0.80, 0.85, 0.90]

Consolidate all the results

In [4]:
main_results_folder = "../../FP_prediction"
all_folders = [] 

for model in os.listdir(main_results_folder):

    main_folder = os.path.join(main_results_folder, model)
    for dataset in os.listdir(os.path.join(main_folder, "best_models")):

        if "sieved" in dataset: continue 
        if "sampled_random" in dataset: continue 
        dataset_folder = os.path.join(main_folder, "best_models", dataset)
        for checkpoint in os.listdir(dataset_folder):
            if "EK-FAC_scores.pkl" in os.listdir(os.path.join(dataset_folder, checkpoint)):
                assert "EK-FAC_self_scores.pkl" in os.listdir(os.path.join(dataset_folder, checkpoint))
                all_folders.append(os.path.join(dataset_folder, checkpoint))

print(f"There are {len(all_folders)} checkpoints for analysis")

There are 28 checkpoints for analysis


`` Strategy 1: Pick the top k most important training datapoints``
    
    Technically, we need to define a metric of importance

Recall the equation is:

$$f(\theta^\star(\epsilon)) - f(\theta^\star) \approx -\nabla_\theta f(\theta^\star)^\top H^{-1} \nabla_\theta \mathcal{L}(z_m, \theta^\star)$$

To upweigh the examples, we set $$\epsilon = \frac{1}{n}$$

Since we want $$(f(\theta^\star(\epsilon)) - f(\theta^\star))\epsilon < 0$$

$$\Rightarrow -\nabla_\theta f(\theta^\star)^\top H^{-1} \nabla_\theta \mathcal{L}(z_m, \theta^\star) \, \epsilon < 0$$

Therefore, we need to pick the training samples that can lower the loss the most

In [5]:
for IDX in tqdm(range(len(all_folders))):

    CHECKPOINT = all_folders[IDX]
    sample_ids_folder = os.path.join(CHECKPOINT, "sample_ids")
    if not os.path.exists(sample_ids_folder): os.makedirs(sample_ids_folder)

    IF = load_pickle(os.path.join(CHECKPOINT, "EK-FAC_scores.pkl"))["all_modules"].T
    train_indices = load_pickle(os.path.join(CHECKPOINT, "train_ids.pkl"))

    helpful_score = torch.topk(-IF.sum(-1), k = IF.shape[0])
    helpful_score_idx_sorted = helpful_score.indices.numpy().tolist()

    for ratio in ratio_list:

        ratio_int = int(ratio * 100)
        n_train = int(round(ratio * IF.shape[0], 0))

        selected_train_ids = [train_indices[i] for i in helpful_score_idx_sorted[:n_train]]
        output_path = os.path.join(sample_ids_folder, f"top_k_helpful_{ratio_int}.pkl")
        pickle_data(selected_train_ids, output_path)

100%|██████████| 28/28 [01:22<00:00,  2.93s/it]


`` Strategy 2: Remove the top k most negative samples ``

Recall once again, the equation is:

$$f(\theta^\star(\epsilon)) - f(\theta^\star) \approx -\nabla_\theta f(\theta^\star)^\top H^{-1} \nabla_\theta \mathcal{L}(z_m, \theta^\star)$$

To remove the examples, we set $$\epsilon = -\frac{1}{n}$$

Since we want $$(f(\theta^\star(\epsilon)) - f(\theta^\star))\epsilon < 0$$

$$\Rightarrow \nabla_\theta f(\theta^\star)^\top H^{-1} \nabla_\theta \mathcal{L}(z_m, \theta^\star) \, \epsilon > 0$$

Therefore, we need to pick the training samples that can lower the loss the most, when removed

In [None]:
for IDX in tqdm(range(len(all_folders))):

    CHECKPOINT = all_folders[IDX]
    print(f"Processing {CHECKPOINT} now")
    sample_ids_folder = os.path.join(CHECKPOINT, "sample_ids")
    if not os.path.exists(sample_ids_folder): os.makedirs(sample_ids_folder)

    IF = load_pickle(os.path.join(CHECKPOINT, "EK-FAC_scores.pkl"))["all_modules"].T
    train_indices = load_pickle(os.path.join(CHECKPOINT, "train_ids.pkl"))

    harmful_score = torch.topk(IF.sum(-1), k = IF.shape[0])
    harmful_score_idx_sorted = harmful_score.indices.numpy().tolist()
    removal_ratio_list = [1.0 - r for r in ratio_list]

    for ratio in tqdm(removal_ratio_list):

        ratio_int = int((1.0 - ratio) * 100)
        n_train_to_remove = int(round(ratio * IF.shape[0], 0))
        selected_train_ids = list(set(train_indices) - set(harmful_score_idx_sorted[:n_train_to_remove]))
        
        output_path = os.path.join(sample_ids_folder, f"remove_top_k_harmful_{ratio_int}.pkl")
        pickle_data(selected_train_ids, output_path)

  0%|          | 0/28 [00:00<?, ?it/s]

Processing ../../FP_prediction/baseline_models/best_models/massspecgym/MSG_binned_meta_4096_random now


100%|██████████| 10/10 [00:01<00:00,  7.65it/s]
  4%|▎         | 1/28 [00:02<01:08,  2.54s/it]

Processing ../../FP_prediction/baseline_models/best_models/massspecgym/MSG_formula_4096_scaffold_vanilla now


100%|██████████| 10/10 [00:01<00:00,  7.99it/s]
  7%|▋         | 2/28 [00:05<01:05,  2.50s/it]

Processing ../../FP_prediction/baseline_models/best_models/massspecgym/MSG_formula_meta_4096_random now


100%|██████████| 10/10 [00:01<00:00,  9.13it/s]
 11%|█         | 3/28 [00:07<00:58,  2.32s/it]

Processing ../../FP_prediction/baseline_models/best_models/massspecgym/MSG_MS_meta_4096_random now


100%|██████████| 10/10 [00:01<00:00,  8.54it/s]
 14%|█▍        | 4/28 [00:09<00:56,  2.37s/it]

Processing ../../FP_prediction/baseline_models/best_models/massspecgym/MSG_MS_4096_inchikey_vanilla now


100%|██████████| 10/10 [00:01<00:00,  8.98it/s]
 18%|█▊        | 5/28 [00:11<00:53,  2.34s/it]

Processing ../../FP_prediction/baseline_models/best_models/massspecgym/MSG_binned_4096_scaffold_vanilla now


100%|██████████| 10/10 [00:01<00:00,  7.84it/s]
 21%|██▏       | 6/28 [00:14<00:51,  2.35s/it]

Processing ../../FP_prediction/baseline_models/best_models/massspecgym/MSG_formula_4096_inchikey_vanilla now


100%|██████████| 10/10 [00:01<00:00,  8.30it/s]
 25%|██▌       | 7/28 [00:16<00:50,  2.40s/it]

Processing ../../FP_prediction/baseline_models/best_models/massspecgym/MSG_MS_4096_scaffold_vanilla now


100%|██████████| 10/10 [00:01<00:00,  8.42it/s]
 29%|██▊       | 8/28 [00:19<00:48,  2.42s/it]

Processing ../../FP_prediction/baseline_models/best_models/massspecgym/MSG_binned_4096_inchikey_vanilla now


100%|██████████| 10/10 [00:01<00:00,  9.06it/s]
 32%|███▏      | 9/28 [00:21<00:45,  2.42s/it]

Processing ../../FP_prediction/baseline_models/best_models/nist2023/NIST2023_MS_meta_4096_inchikey_vanilla now


`` Strategy 3: We pick the top k most diverse ones ``

    Note that it's almost impossible to get the true optimal solution for this task given the size of our training data 
    (the runtime was prohibitively long)
    We would prematurely terminate the k-means clustering and get the centroid as the diverse samples

In [None]:
for IDX in tqdm(range(len(all_folders))):

    CHECKPOINT = all_folders[IDX]
    print(f"Processing {CHECKPOINT} now")
    sample_ids_folder = os.path.join(CHECKPOINT, "sample_ids")
    if not os.path.exists(sample_ids_folder): os.makedirs(sample_ids_folder)

    IF = load_pickle(os.path.join(CHECKPOINT, "EK-FAC_scores.pkl"))["all_modules"].T.numpy()
    train_indices = load_pickle(os.path.join(CHECKPOINT, "train_ids.pkl"))

    for ratio in ratio_list:

        ratio_int = int(ratio * 100)
        n_train = int(round(ratio * IF.shape[0], 0))        
        mbk = MiniBatchKMeans(n_clusters = n_train, batch_size = 1024, max_iter=5) # Let us use a small number to speed up the computation 
        mbk.fit(IF)

        # Get centroids
        centroids = mbk.cluster_centers_
        selected_indices, _ = pairwise_distances_argmin_min(centroids, IF)
        selected_train_ids = [train_indices[i] for i in selected_indices]

        output_path = os.path.join(sample_ids_folder, f"k_means_centroid_{ratio_int}.pkl")
        pickle_data(selected_train_ids, output_path)

`` Strategy 4: Removal of samples using self-influence scores ``

    The motivation is that many of these methods rely on the test score to subsample the important training records.
    We look at how well we can use self-influence score to determine the sampling strategy

    So now, we aim to remove potential misannotations in the data

Recall once again, the equation for self influence is:

$$f(\theta^\star(\epsilon)) - f(\theta^\star) \approx -\nabla_\theta \mathcal{L}(z_m, \theta^\star)^\top H^{-1} \nabla_\theta \mathcal{L}(z_m, \theta^\star)$$

To remove the examples, we set $$\epsilon = -\frac{1}{n}$$

Consider:

$$f(\theta^\star(\epsilon)) - f(\theta^\star) > 0$$

Therefore, it is suggested that the higher the self-IF score, the greater the lost increases when it is removed from the test --> this migh suggest that this is a misannotation where very little other training samples can be used to "support" its prediction? 

In [None]:
for IDX in tqdm(range(len(all_folders))):

    CHECKPOINT = all_folders[IDX]
    print(f"Processing {CHECKPOINT} now")
    sample_ids_folder = os.path.join(CHECKPOINT, "sample_ids")
    if not os.path.exists(sample_ids_folder): os.makedirs(sample_ids_folder)

    self_IF = load_pickle(os.path.join(CHECKPOINT, "EK-FAC_self_scores.pkl"))["all_modules"]
    train_indices = load_pickle(os.path.join(CHECKPOINT, "train_ids.pkl"))

    harmful_score = torch.topk(self_IF, k = self_IF.shape[0])
    harmful_score_idx_sorted = harmful_score.indices.numpy().tolist()
    removal_ratio_list = [1.0 - r for r in ratio_list]

    for ratio in tqdm(removal_ratio_list):

        ratio_int = int((1.0 - ratio) * 100)
        n_train_to_remove = int(round(ratio * self_IF.shape[0], 0))
        selected_train_ids = list(set(train_indices) - set(harmful_score_idx_sorted[:n_train_to_remove]))
        
        output_path = os.path.join(sample_ids_folder, f"remove_top_k_self_{ratio_int}.pkl")
        pickle_data(selected_train_ids, output_path)