In [17]:
%load_ext autoreload
%autoreload 2
import argparse
import os
import statistics
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import yaml
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset, TensorDataset
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

import wandb
from src import BertClassifier
from src import datasets as data_utils
from src import influence, train_utils, utils
from src.datasets import create_loo_dataset, create_test_sst2, create_train_sst2

device = utils.get_device()

config = utils.load_config(
    "model_params/bert_classifier.yaml", epochs=5, num_training_examples=1000
)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Create Datasets

In [21]:
USE_BERT_EMBEDDINGS = True

# Create datasets
train_dataset = create_train_sst2(
    num_samples=config["num_training_examples"],
    tokenizer_name=config["bert_model_name"],
    max_seq_len=config["max_sequence_length"],
    device=device,
    use_bert_embeddings=USE_BERT_EMBEDDINGS,
)

test_dataset = create_test_sst2(
    tokenizer_name=config["bert_model_name"],
    max_seq_len=config["max_sequence_length"],
    device=device,
    use_bert_embeddings=USE_BERT_EMBEDDINGS,
)
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=1)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 14000.94it/s]
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████

## Train model

In [22]:
full_model, fdf, full_test_loss, full_test_acc = train_utils.train_bert_model(
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    config=config,
    use_bert_embeddings=USE_BERT_EMBEDDINGS,
)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016669544999604113, max=1.0…

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:02<00:00, 29.66batch/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:02<00:00, 29.56batch/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:02<00:00, 30.06batch/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:02<00:00, 29.70batch/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:02<00:00, 29.95batch/s]


VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▃▅▆█
train/accuracy,▁▆███
train/batch_loss,▆▇▆█▅▅▄▅▄▄▄▃▄▄▃▂▃▃▄▆▄▂▂▄▃▃▁▄▁▂▄▁▅▁▅▃▃▂▃▅
train/loss,█▄▂▁▁

0,1
epoch,5.0
test/accuracy,80.27523
test/loss,0.41442
train/accuracy,85.21825
train/batch_loss,0.13145
train/loss,0.36353


In [23]:
baseline_test_loss = fdf[fdf.test_guid == TEST_GUID].loss.squeeze()
baseline_test_loss

0.1746881753206253

## Compute Influence Function

In [24]:
TEST_GUID = 716

infl = influence.compute_influence(
    full_model,
    TEST_GUID,
    param_influence=list(full_model.classifier.parameters()),
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    use_bert_embeddings=USE_BERT_EMBEDDINGS,
    lissa_r=2,
    lissa_depth=1,
    damping=5e-3,
    scale=100,
)

LiSSA reps: 2 and num_iterations: 1000
Recursion at depth 0: norm is 4.655922
Recursion at depth 200: norm is 127.890549
Recursion at depth 400: norm is 165.555710
Recursion at depth 600: norm is 177.680008
Recursion at depth 800: norm is 182.073212
Recursion at depth 999: norm is 183.608948
Recursion at depth 0: norm is 4.609522
Recursion at depth 200: norm is 128.841476
Recursion at depth 400: norm is 166.395264
Recursion at depth 600: norm is 178.216537
Recursion at depth 800: norm is 181.195511
Recursion at depth 999: norm is 182.774078


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 144.65it/s]


In [25]:
# Most negative influence is most helpful
helpful_idxs = np.argsort(infl)[:10]
helpful_idxs

array([262, 447, 536, 526, 722, 190, 326, 293, 393, 143])

In [26]:
np.take(infl, helpful_idxs)

array([-4.30263615, -2.73550391, -2.63446903, -2.57483053, -2.51361799,
       -2.45602036, -2.43724656, -2.34256577, -2.29062819, -2.20918775])

## Compute Input Influence Function

In [32]:
input_infl = influence.compute_input_influence(
    full_model,
    TEST_GUID,
    param_influence=list(full_model.classifier.parameters()),
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    use_bert_embeddings=USE_BERT_EMBEDDINGS,
    lissa_r=2,
    lissa_depth=1,
    damping=5e-3,
    scale=100,
    training_indices=helpful_idxs
)



LiSSA reps: 2 and num_iterations: 1000
Recursion at depth 0: norm is 4.539837
Recursion at depth 200: norm is 128.686966
Recursion at depth 400: norm is 165.868591
Recursion at depth 600: norm is 177.595367
Recursion at depth 800: norm is 181.906235
Recursion at depth 999: norm is 183.765076
Recursion at depth 0: norm is 4.443922
Recursion at depth 200: norm is 129.020157
Recursion at depth 400: norm is 165.915131
Recursion at depth 600: norm is 177.849350
Recursion at depth 800: norm is 182.245407
Recursion at depth 999: norm is 183.830505


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:25<00:00, 11.68it/s]


In [36]:
input_infl.keys()

dict_keys([tensor([143], device='cuda:0', dtype=torch.int32), tensor([190], device='cuda:0', dtype=torch.int32), tensor([262], device='cuda:0', dtype=torch.int32), tensor([293], device='cuda:0', dtype=torch.int32), tensor([326], device='cuda:0', dtype=torch.int32), tensor([393], device='cuda:0', dtype=torch.int32), tensor([447], device='cuda:0', dtype=torch.int32), tensor([526], device='cuda:0', dtype=torch.int32), tensor([536], device='cuda:0', dtype=torch.int32), tensor([722], device='cuda:0', dtype=torch.int32)])

In [44]:
def perturb_datapoint(dataset, data_guid, perturbation):
    pass

In [48]:
alpha = 0.01

# peturb input by this amount
input_infl[262].shape

torch.Size([64, 768])

In [39]:
input_infl[torch.tensor([262])]

KeyError: tensor([262])

In [None]:
import pickle

with open('input_infl.pkl', 'wb') as fh:
    pickle.dump(input_infl, fh)

In [31]:
262 in helpful_idxs

True

In [None]:
-

## Word2Vec

In [None]:
from gensim.models import KeyedVectors, Word2Vec

w2v = KeyedVectors.load("word2vec/glove-twitter-100.kv")

In [None]:
w2v["spielberg"]