In [14]:
%load_ext autoreload
%autoreload 2

import argparse
import os
from pathlib import Path

import numpy as np
import pandas as pd
import src.BertClassifier as BertClassifier
import src.utils as utils
import torch
import yaml
from src.datasets import create_loo_dataset, create_test_sst2, create_train_sst2
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm

import wandb

device = utils.get_device()
config_path = "loo/classifier.yaml"
epochs = 3
num_training_examples = 10000

with open(config_path, "r") as stream:
    config = yaml.safe_load(stream)
config.update({"epochs": epochs, "num_training_examples": num_training_examples})

# Create datasets
train_dataset = create_train_sst2(
    device,
    num_samples=config["num_training_examples"],
    tokenizer_name=config["bert_model_name"],
    max_seq_len=config["max_sequence_length"],
)

test_dataset = create_test_sst2(
    device,
    tokenizer_name=config["bert_model_name"],
    max_seq_len=config["max_sequence_length"],
)
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=1)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


100%|████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:00<00:00, 14475.99it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████| 872/872 [00:00<00:00, 10838.25it/s]


## Full Model

In [2]:
# loo_dataset = create_loo_dataset(train_dataset, 0)
# train_dataloader = DataLoader(
#     loo_dataset, batch_size=config["batch_size"], shuffle=True
# )
train_dataloader = DataLoader(
    train_dataset, batch_size=config["batch_size"], shuffle=True
)

# Create classifcation model
full_model = BertClassifier.create_bert_classifier(
    config["bert_model_name"],
    classifier_type=config["classifier_type"],
    classifier_hidden_size=config["classifier_hidden_size"],
    classifier_drop_out=config["classifier_drop_out"],
    freeze_bert=True,
    random_state=42,
)
full_model.classifier.load_state_dict(
    torch.load("loo_10k/run_0/init_classifier_params.pt")
)
# torch.save(model.classifier.state_dict(), 'classifier_params.pt')

optimizer = Adam(full_model.classifier.parameters(), lr=config["learning_rate"])
loss_fn = torch.nn.CrossEntropyLoss()

run = wandb.init(project="LOO-test", tags=["full"], config=config)

timings = utils.train(
    config=config,
    model=full_model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    train_dataloader=train_dataloader,
    val_dataloader=None,
)

test_loss, test_acc = utils.evaluate(full_model, test_dataloader)
wandb.summary["test/loss"] = test_loss
wandb.summary["test/accuracy"] = test_acc
wandb.finish()

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
[34m[1mwandb[0m: Currently logged in as: [33mpatcao[0m. Use [1m`wandb login --relogin`[0m to force relogin


100%|████████████████████████████████████████████████████████████████████████████████████████████████| 625/625 [00:22<00:00, 27.73batch/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████| 625/625 [00:21<00:00, 28.54batch/s]


VBox(children=(Label(value='0.003 MB of 0.013 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.213530…

0,1
epoch,▁█
train/accuracy,▁█
train/batch_loss,█▇▇▅▅▃▄▇▆▇▅▃▃▄▄▂▄▂▆▁█▂▄▁▂▅▂█▁▆█▂▄▂▅▆▃▁▃▅
train/loss,█▁

0,1
epoch,2.0
test/accuracy,83.37156
test/loss,0.37996
train/accuracy,83.18
train/batch_loss,0.11154
train/loss,0.37465


In [3]:
# Compute loss for each test point
fdf = utils.evaluate_loss_df(full_model, test_dataloader)
fdf

Unnamed: 0,test_guid,label,pred,loss
0,0,1,1,0.004901
1,1,0,0,0.063696
2,2,1,1,0.011256
3,3,1,1,0.058076
4,4,0,0,0.170423
...,...,...,...,...
867,867,0,1,0.938666
868,868,1,1,0.581234
869,869,0,1,1.234386
870,870,0,0,0.404960


In [4]:
fdf[fdf.test_guid == 869]

Unnamed: 0,test_guid,label,pred,loss
869,869,0,1,1.234386


In [5]:
fdf[fdf.test_guid == 869]

Unnamed: 0,test_guid,label,pred,loss
869,869,0,1,1.234386


## LOO Models

In [6]:
from torch.utils.data import Dataset, TensorDataset


def create_loo_dataset(sst2_dataset, loo_guid):
    guids, inputs, masks, labels = sst2_dataset.tensors
    loo_mask = ~(guids == loo_guid)
    return TensorDataset(
        guids[loo_mask], inputs[loo_mask], masks[loo_mask], labels[loo_mask]
    )

In [None]:
loss_fn = torch.nn.CrossEntropyLoss()
loo_dfs = []

for loo_guid in range(0, 5):
    loo_dataset = create_loo_dataset(train_dataset, loo_guid)
    loo_train_dataloader = DataLoader(
        loo_dataset, batch_size=config["batch_size"], shuffle=True
    )

    # Create classifcation model
    loo_model = BertClassifier.create_bert_classifier(
        config["bert_model_name"],
        classifier_type=config["classifier_type"],
        classifier_hidden_size=config["classifier_hidden_size"],
        classifier_drop_out=config["classifier_drop_out"],
        freeze_bert=True,
        random_state=42,
    )

    optimizer = Adam(loo_model.classifier.parameters(), lr=config["learning_rate"])
    run = wandb.init(project="LOO-test", tags=["full"], config=config)

    timings = utils.train(
        config=config,
        model=loo_model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        train_dataloader=loo_train_dataloader,
        val_dataloader=None,
    )

    test_loss, test_acc = utils.evaluate(loo_model, test_dataloader)
    wandb.summary["test/loss"] = test_loss
    wandb.summary["test/accuracy"] = test_acc
    wandb.finish()

    # Compute loss for each test point
    df = utils.evaluate_loss_df(loo_model, test_dataloader)
    df["loo_guid"] = loo_guid
    loo_dfs.append(df)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


100%|████████████████████████████████████████████████████████████████████████████████████████████████| 625/625 [00:22<00:00, 28.18batch/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████| 625/625 [00:21<00:00, 28.76batch/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████| 625/625 [00:21<00:00, 28.63batch/s]


0,1
epoch,▁▅█
train/accuracy,▁▇█
train/batch_loss,▅▅▄▄▄▃▄▃▁▅▄▃▃▃▆▂▄▂▅▂▃▄▃▆▂▅▄▂▃▃▁▂▂▄▂▂▃▂█▂
train/loss,█▂▁

0,1
epoch,3.0
test/accuracy,83.1422
test/loss,0.38197
train/accuracy,83.83867
train/batch_loss,0.04381
train/loss,0.36695


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


100%|████████████████████████████████████████████████████████████████████████████████████████████████| 625/625 [00:22<00:00, 28.26batch/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████| 625/625 [00:22<00:00, 28.29batch/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████| 625/625 [00:22<00:00, 27.92batch/s]


VBox(children=(Label(value='0.003 MB of 0.014 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.194228…

0,1
epoch,▁▅█
train/accuracy,▁▇█
train/batch_loss,▅▅▄▄▄▃▄▃▁▅▄▃▃▃▆▂▄▂▅▂▃▄▃▆▂▅▄▂▃▃▁▂▂▄▂▂▃▂█▂
train/loss,█▂▁

0,1
epoch,3.0
test/accuracy,83.25688
test/loss,0.38207
train/accuracy,83.85867
train/batch_loss,0.04374
train/loss,0.36699


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


100%|████████████████████████████████████████████████████████████████████████████████████████████████| 625/625 [00:22<00:00, 27.96batch/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████| 625/625 [00:22<00:00, 27.82batch/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████| 625/625 [00:22<00:00, 27.84batch/s]


VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▅█
train/accuracy,▁▇█
train/batch_loss,▅▅▄▄▄▃▄▃▁▅▄▃▃▃▆▂▄▂▅▂▃▄▃▆▂▅▄▂▃▃▁▂▂▄▂▂▃▂█▂
train/loss,█▂▁

0,1
epoch,3.0
test/accuracy,83.25688
test/loss,0.38208
train/accuracy,83.86867
train/batch_loss,0.04371
train/loss,0.36699


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


100%|████████████████████████████████████████████████████████████████████████████████████████████████| 625/625 [00:22<00:00, 27.77batch/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████| 625/625 [00:22<00:00, 27.92batch/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████| 625/625 [00:21<00:00, 28.53batch/s]


VBox(children=(Label(value='0.003 MB of 0.004 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.735468…

0,1
epoch,▁▅█
train/accuracy,▁▇█
train/batch_loss,▅▅▄▄▄▃▄▃▁▅▄▃▃▃▆▂▄▂▅▂▃▄▃▆▂▅▄▂▃▃▁▂▂▄▂▂▃▂█▂
train/loss,█▂▁

0,1
epoch,3.0
test/accuracy,83.25688
test/loss,0.38185
train/accuracy,83.87867
train/batch_loss,0.04357
train/loss,0.36673


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


 49%|██████████████████████████████████████████████▋                                                 | 304/625 [00:10<00:11, 28.64batch/s]

In [None]:
ldf = pd.concat(loo_dfs, axis=0)
# ldf = pd.concat([pd.read_csv('exps/loo_df.csv'), ldf], axis=0)
# ldf["loss_diff"] = ldf["loss"] - fdf[fdf.test_guid == test_guid].loss.squeeze()
# ldf

## Compute Influence

In [9]:
# import src.influence as inf_utils
# import torch.autograd as autograd

# test_guid = 869
# param_influence = list(full_model.classifier.parameters())
# influences = np.zeros(len(train_dataset))

# for guid, input_ids, input_mask, label_ids in test_dataloader:
#     if guid != test_guid:
#         continue

#     full_model.eval()
#     utils.set_seed(42)
#     train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=1)
#     train_dataloader_lissa = DataLoader(
#         train_dataset, batch_size=config["batch_size"], shuffle=True, drop_last=True
#     )

#     input_ids = input_ids.to(device)
#     input_mask = input_mask.to(device)
#     label_ids = label_ids.to(device)

#     # L_TEST gradient
#     full_model.zero_grad()
#     output = full_model(input_ids, input_mask)
#     test_loss = loss_fn(output, label_ids)
#     test_grads = autograd.grad(test_loss, param_influence)

#     # IVHP
#     full_model.train()

#     t = int(len(train_dataloader) * 0.25)
#     # r = int(len(train_dataloader) / t)
#     r = 1
#     print(f"Using r: {r} and t: {t}")

#     inverse_hvp = inf_utils.get_inverse_hvp_lissa(
#         test_grads,
#         full_model,
#         device,
#         param_influence,
#         train_dataloader_lissa,
#         damping=3e-3,
#         scale=1e4,
#         num_samples=r,
#         recursion_depth=t,
#     )

#     for train_guid, train_input_id, train_input_mask, train_label in tqdm(
#         train_dataloader
#     ):
#         full_model.train()
#         full_model.zero_grad()
#         train_output = full_model(train_input_id, train_input_mask)
#         train_loss = loss_fn(train_output, train_label)
#         train_grads = autograd.grad(train_loss, param_influence)
#         influences[train_guid] = torch.dot(
#             inverse_hvp, inf_utils.gather_flat_grad(train_grads)
#         ).item()

#     break

In [10]:
import src.influence as inf_utils

test_guid = 869
param_infl = list(full_model.classifier.parameters())
infl = inf_utils.compute_influence(
    full_model=full_model,
    test_guid=test_guid,
    param_influence=param_infl,
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    lissa_r=1,
    lissa_depth=0.25,
)

LiSSA reps: 1 and num_iterations: 2500
Recursion at depth 0: norm is 23.792892
Recursion at depth 200: norm is 1463.629517
Recursion at depth 400: norm is 1998.025513
Recursion at depth 600: norm is 2208.230469
Recursion at depth 800: norm is 2318.055908
Recursion at depth 1000: norm is 2370.135498
Recursion at depth 1200: norm is 2390.668945
Recursion at depth 1400: norm is 2399.083008
Recursion at depth 1600: norm is 2402.173828
Recursion at depth 1800: norm is 2407.013428
Recursion at depth 2000: norm is 2397.940430
Recursion at depth 2200: norm is 2412.884521
Recursion at depth 2400: norm is 2411.270264
Recursion at depth 2499: norm is 2415.950439


100%|██████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [01:05<00:00, 151.76it/s]


## Analysis

In [None]:
test_guid = 869

#ldf = pd.read_csv("loo_10k/all_loo_losses.csv")
ldf = ldf[ldf.test_guid == test_guid]
ldf["loss_diff"] = ldf["loss"] - fdf[fdf.test_guid == test_guid].loss.squeeze()
ldf["if_diff"] = (-100.0 / len(train_dataset)) * infl[:5]
ldf

In [12]:
test_guid = 869

ldf = pd.read_csv("loo_10k/all_loo_losses.csv")
ldf = ldf[ldf.test_guid == test_guid]
ldf["loss_diff"] = ldf["loss"] - fdf[fdf.test_guid == test_guid].loss.squeeze()
ldf["if_diff"] = (-100.0 / len(train_dataset)) * infl[:20]
ldf

Unnamed: 0,loo_guid,test_guid,label,loss,loss_diff,if_diff
869,0,869,0,1.888133,0.653746,-0.001626
1741,1,869,0,1.888622,0.654235,-0.001123
2613,2,869,0,1.888152,0.653765,0.000207
3485,3,869,0,1.863295,0.628909,-0.017191
4357,4,869,0,1.866936,0.632549,-0.001228
5229,5,869,0,1.867203,0.632817,-0.004064
6101,6,869,0,1.865716,0.63133,0.000959
6973,7,869,0,1.865781,0.631395,0.005427
7845,8,869,0,1.864583,0.630196,-0.002895
8717,9,869,0,1.868037,0.63365,0.018015


In [21]:
print(ldf[["loss_diff", "if_diff"]].corr())
ldf[["loss_diff", "if_diff"]]

           loss_diff  if_diff
loss_diff        1.0      1.0
if_diff          1.0      1.0


Unnamed: 0,loss_diff,if_diff
869,-0.033466,-0.001626
869,-0.026826,-0.001123


In [None]:
cdf[["loss_diff", "if_diff"]].plot.scatter("loss_diff", "if_diff")

In [None]:
----

In [None]:
influences[:10]

## LOO Model

In [None]:
loo_guid = 0

# Create train dataset
loo_dataset = create_loo_dataset(train_dataset, loo_guid)
train_dataloader = DataLoader(
    loo_dataset, batch_size=config["batch_size"], shuffle=True
)

# Create classifcation model
model = BertClassifier.create_bert_classifier(
    config["bert_model_name"],
    classifier_type=config["classifier_type"],
    classifier_hidden_size=config["classifier_hidden_size"],
    classifier_drop_out=config["classifier_drop_out"],
    freeze_bert=True,
    random_state=42,
)

optimizer = Adam(model.classifier.parameters(), lr=config["learning_rate"])
loss_fn = torch.nn.CrossEntropyLoss()


run = wandb.init(project="LOO-test", tags=["loo"], config=config)

timings = utils.train(
    config=config,
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    train_dataloader=train_dataloader,
    val_dataloader=None,
)

test_loss, test_acc = utils.evaluate(model, test_dataloader)
wandb.summary["test/loss"] = test_loss
wandb.summary["test/accuracy"] = test_acc

In [None]:
# Compute loss for each test point
df = utils.evaluate_loss_df(model, test_dataloader)
df

In [None]:
--

In [None]:
dfs = []

for loo_guid in tqdm(range(2000)):
    df = df.drop([c for c in df.columns if "Unnamed" in c], axis=1)
    df = df.rename(columns={"guid": "test_guid"})
    df["loo_guid"] = loo_guid
    df = df[["loo_guid", "test_guid", "loss", "label"]]
    dfs.append(df)

In [None]:
d = pd.concat(dfs)
df = pd.read_csv("loo/all_loo_losses.csv")

In [None]:
d.to_csv("loo/all_loo_losses.csv", index=False)

In [None]:
test_guid = 4

for test_guid in range(800):
    min_loss = df[df.test_guid == test_guid].loss.min()
    max_loss = df[df.test_guid == test_guid].loss.max()
    if min_loss != max_loss:
        print(min_loss, max_loss)

In [None]:
df