In [1]:
%load_ext autoreload
%autoreload 2
import os

# Set the Weights & Biases "silent" flag before `wandb` is imported in the next cell.
os.environ["WANDB_SILENT"] = "true"

In [2]:
import argparse
import os
import statistics
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import yaml
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset, TensorDataset
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

import wandb
from src import BertClassifier
from src import datasets as data_utils
from src import influence, train_utils, utils
from src.datasets import create_loo_dataset, create_test_sst2, create_train_sst2

# Compute device chosen by the project helper; passed to dataset creation below.
device = utils.get_device()

# Load the classifier configuration from YAML. The keyword arguments
# presumably override the values stored in the file — confirm against
# utils.load_config.
config = utils.load_config(
    "model_params/bert_classifier.yaml", epochs=5, num_training_examples=1000
)

## Create Datasets

In [3]:
# Flag forwarded to dataset creation here and to training / influence calls later.
USE_BERT_EMBEDDINGS = True

# Create datasets
# Training set: limited to the number of examples requested in the config.
train_dataset = create_train_sst2(
    num_samples=config["num_training_examples"],
    tokenizer_name=config["bert_model_name"],
    max_seq_len=config["max_sequence_length"],
    device=device,
    use_bert_embeddings=USE_BERT_EMBEDDINGS,
)

# Test set: built with the same tokenizer and sequence-length settings.
test_dataset = create_test_sst2(
    tokenizer_name=config["bert_model_name"],
    max_seq_len=config["max_sequence_length"],
    device=device,
    use_bert_embeddings=USE_BERT_EMBEDDINGS,
)
# Unshuffled loader yielding one test example per batch.
# NOTE(review): `test_dataloader` is not referenced again in the visible cells.
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=1)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 14255.58it/s]
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████

## Train model

In [4]:
# Train the baseline model on the unperturbed training set. `original_df`
# holds one row per test example (test_guid, logits, pred, label, loss — see
# the table displayed further down).
full_model, original_df, test_loss, test_acc = train_utils.train_bert_model(
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    config=config,
    use_bert_embeddings=USE_BERT_EMBEDDINGS,
)
# Display the aggregate test loss and accuracy.
test_loss, test_acc

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:02<00:00, 27.04batch/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████

(0.41442395658465236, 80.27522935779817)

## Adversarial Attack

In [5]:
def perturb_datapoint(dataset, data_guid, perturbation):
    """Add ``perturbation`` to one example's inputs, modifying ``dataset`` in place.

    Args:
        dataset: Dataset exposing a ``tensors`` tuple ordered as
            ``(guid, inputs, attn_mask, labels)`` (e.g. a ``TensorDataset``).
        data_guid: Row index of the example to perturb; it must equal the
            guid stored at that row.
        perturbation: Tensor added element-wise to the example's inputs.

    Returns:
        ``(inputs_before, inputs_after)``: a detached copy of the inputs taken
        before the update, and the updated inputs.

    Raises:
        AssertionError: If the guid stored at row ``data_guid`` differs from
            ``data_guid``.
    """
    # Read from the `dataset` argument (not the notebook-global
    # `train_dataset`) so the function perturbs whatever it is handed.
    guid, inputs, _attn_mask, _labels = [t[data_guid] for t in dataset.tensors]
    assert guid.squeeze() == data_guid, (
        f"row {data_guid} holds guid {guid}, expected {data_guid}"
    )

    inputs_before = inputs.detach().clone()
    # Integer indexing of a tensor yields a view, so this in-place add writes
    # through to the dataset. The perturbation is moved to the device the
    # stored inputs already live on, which the in-place add requires.
    inputs += perturbation.to(inputs.device)
    return inputs_before, inputs


def perform_attack(
    model,
    config,
    train_dataset,
    test_dataset,
    target_test_guid,
    target_train_guid=None,
    alpha=2e-2,
):
    """Run one round of the influence-guided perturbation attack, then retrain.

    When ``target_train_guid`` is ``None``, the training example to perturb is
    the one with the most negative influence score for ``target_test_guid``.
    That example's inputs are shifted by ``alpha`` times its input-influence
    value (``train_dataset`` is modified in place) and a model is retrained on
    the perturbed data.

    Args:
        model: Model whose ``classifier`` parameters are used for both
            influence computations.
        config: Training configuration forwarded to
            ``train_utils.train_bert_model``.
        train_dataset: Training data; one example is perturbed in place.
        test_dataset: Test data containing the target example.
        target_test_guid: Guid of the test example under attack.
        target_train_guid: Training example to perturb; selected from the
            influence scores when ``None``.
        alpha: Multiplier applied to the input influence to form the
            perturbation.

    Returns:
        ``(model, df, infl, input_infl)``: the retrained model, its results
        frame with an added ``perturbed_guid`` column, the influence scores
        (``None`` when ``target_train_guid`` was supplied), and the
        input-influence result.
    """
    # LiSSA settings shared by both influence computations.
    lissa_settings = {"lissa_r": 2, "lissa_depth": 1, "damping": 5e-3, "scale": 100}

    train_influence = None
    if target_train_guid is None:
        print("---Computing Influence Function---")
        train_influence = influence.compute_influence(
            model,
            target_test_guid,
            param_influence=list(model.classifier.parameters()),
            train_dataset=train_dataset,
            test_dataset=test_dataset,
            use_bert_embeddings=True,
            **lissa_settings,
        )

        # Ascending order puts the most negative (most helpful) score first.
        ranked = np.argsort(train_influence)
        target_train_guid = ranked[0]

    print("---Computing Input Influence Function---")
    input_influence = influence.compute_input_influence(
        model,
        target_test_guid,
        param_influence=list(model.classifier.parameters()),
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        use_bert_embeddings=True,
        training_indices=[target_train_guid],
        **lissa_settings,
    )

    print(f"---Perturbing training guid {target_train_guid}---")
    delta = alpha * input_influence[target_train_guid]
    perturb_datapoint(train_dataset, target_train_guid, delta)

    print("---Retraining on perturbed data---")
    # Fit a model on the training set that now contains the perturbed example.
    retrained, results, _test_loss, _test_acc = train_utils.train_bert_model(
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        config=config,
        use_bert_embeddings=True,
    )
    results["perturbed_guid"] = target_train_guid
    return retrained, results, train_influence, input_influence

In [17]:
# Correctly classified test examples, highest loss first. The first row of the
# saved output (guid 497) is the value assigned to TEST_GUID in the next cell.
original_df[original_df.pred == original_df.label].sort_values('loss', ascending=False)

Unnamed: 0,test_guid,logits,pred,label,loss
497,497,"[0.0020433236, -0.011170415]",0,0,0.686562
870,870,"[-0.0993969, -0.13528356]",0,0,0.675365
271,271,"[0.0107166935, 0.050429046]",1,1,0.673488
695,695,"[-0.43976375, -0.4852701]",0,0,0.670653
230,230,"[-0.25980154, -0.21333694]",1,1,0.670185
...,...,...,...,...,...
613,613,"[-3.3422766, 3.141655]",1,1,0.001527
303,303,"[-3.5994592, 2.8959098]",1,1,0.001509
837,837,"[-3.6050532, 3.4566987]",1,1,0.000857
443,443,"[-3.6812828, 3.639783]",1,1,0.000661


In [6]:
# TEST_GUID = 716
# Guid of the test example targeted by the attack.
TEST_GUID = 497

# Loss of the baseline model on the target example, kept for comparison with
# the post-attack losses.
baseline_test_loss = original_df[original_df.test_guid == TEST_GUID].loss.squeeze()
baseline_test_loss

0.6865621209144592

In [19]:
# target_train_guid = 262

# Per-iteration outputs of perform_attack, appended in loop order.
hist = {
    "loss_df": [],
    "influence": [],
    "input_influence": [],
}
# Start from the baseline model; each iteration hands its retrained model to the next.
model = full_model
# NOTE: perform_attack perturbs `train_dataset` in place, so perturbations
# accumulate across iterations and re-running this cell perturbs it further.
# NOTE(review): the saved output below stops after iteration 9, so the
# 15-iteration run was presumably interrupted.
for i in range(15):
    model, loss_df, infl, input_infl = perform_attack(
        model=model,
        config=config,
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        target_test_guid=TEST_GUID,
        alpha=5e-1,
        # target_train_guid=target_train_guid,
    )
    # Tag every row with the iteration so the frames can be concatenated later.
    loss_df["iter"] = i

    hist["loss_df"].append(loss_df)
    hist["influence"].append(infl)
    hist["input_influence"].append(input_infl)

---Computing Influence Function---
LiSSA reps: 2 and num_iterations: 1000
Recursion at depth 0: norm is 13.830604
Recursion at depth 200: norm is 310.815369
Recursion at depth 400: norm is 400.959137
Recursion at depth 600: norm is 432.403717
Recursion at depth 800: norm is 442.971222
Recursion at depth 999: norm is 445.291595
Recursion at depth 0: norm is 13.849314
Recursion at depth 200: norm is 310.803345
Recursion at depth 400: norm is 402.894745
Recursion at depth 600: norm is 432.494446
Recursion at depth 800: norm is 443.604462
Recursion at depth 999: norm is 445.130676


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 149.13it/s]


---Computing Input Influence Function---
LiSSA reps: 2 and num_iterations: 1000
Recursion at depth 0: norm is 13.806417
Recursion at depth 200: norm is 313.570862
Recursion at depth 400: norm is 402.886688
Recursion at depth 600: norm is 432.520599
Recursion at depth 800: norm is 444.473999
Recursion at depth 999: norm is 444.417725
Recursion at depth 0: norm is 14.480818
Recursion at depth 200: norm is 311.267944
Recursion at depth 400: norm is 402.693848
Recursion at depth 600: norm is 434.224945
Recursion at depth 800: norm is 440.267578
Recursion at depth 999: norm is 448.659424


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:08<00:00, 112.68it/s]


---Perturbing training guid 146---
---Retraining on perturbed data---


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:02<00:00, 29.83batch/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████

---Computing Influence Function---
LiSSA reps: 2 and num_iterations: 1000
Recursion at depth 0: norm is 15.687018
Recursion at depth 200: norm is 349.274170
Recursion at depth 400: norm is 450.368378
Recursion at depth 600: norm is 486.081757
Recursion at depth 800: norm is 498.160797
Recursion at depth 999: norm is 500.448059
Recursion at depth 0: norm is 15.503262
Recursion at depth 200: norm is 349.021698
Recursion at depth 400: norm is 453.112122
Recursion at depth 600: norm is 485.510956
Recursion at depth 800: norm is 498.992126
Recursion at depth 999: norm is 500.939545


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 149.06it/s]


---Computing Input Influence Function---
LiSSA reps: 2 and num_iterations: 1000
Recursion at depth 0: norm is 15.583508
Recursion at depth 200: norm is 352.369141
Recursion at depth 400: norm is 452.870331
Recursion at depth 600: norm is 486.043640
Recursion at depth 800: norm is 499.626282
Recursion at depth 999: norm is 499.591919
Recursion at depth 0: norm is 16.258253
Recursion at depth 200: norm is 350.180359
Recursion at depth 400: norm is 452.665771
Recursion at depth 600: norm is 488.268982
Recursion at depth 800: norm is 494.840302
Recursion at depth 999: norm is 504.537476


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:09<00:00, 107.28it/s]


---Perturbing training guid 685---
---Retraining on perturbed data---


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:02<00:00, 29.47batch/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████

---Computing Influence Function---
LiSSA reps: 2 and num_iterations: 1000
Recursion at depth 0: norm is 16.321352
Recursion at depth 200: norm is 362.926178
Recursion at depth 400: norm is 467.813629
Recursion at depth 600: norm is 505.231750
Recursion at depth 800: norm is 517.767822
Recursion at depth 999: norm is 520.059448
Recursion at depth 0: norm is 16.107344
Recursion at depth 200: norm is 362.598480
Recursion at depth 400: norm is 470.840057
Recursion at depth 600: norm is 504.545197
Recursion at depth 800: norm is 518.691223
Recursion at depth 999: norm is 520.661194


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 147.38it/s]


---Computing Input Influence Function---
LiSSA reps: 2 and num_iterations: 1000
Recursion at depth 0: norm is 16.248802
Recursion at depth 200: norm is 366.209595
Recursion at depth 400: norm is 470.613739
Recursion at depth 600: norm is 505.030670
Recursion at depth 800: norm is 519.433899
Recursion at depth 999: norm is 519.080139
Recursion at depth 0: norm is 16.900284
Recursion at depth 200: norm is 363.967651
Recursion at depth 400: norm is 470.535309
Recursion at depth 600: norm is 507.528381
Recursion at depth 800: norm is 514.166199
Recursion at depth 999: norm is 524.315979


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:08<00:00, 120.09it/s]


---Perturbing training guid 920---
---Retraining on perturbed data---


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:02<00:00, 29.14batch/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████

---Computing Influence Function---
LiSSA reps: 2 and num_iterations: 1000
Recursion at depth 0: norm is 16.776709
Recursion at depth 200: norm is 372.787781
Recursion at depth 400: norm is 480.538300
Recursion at depth 600: norm is 519.034363
Recursion at depth 800: norm is 531.914124
Recursion at depth 999: norm is 534.267395
Recursion at depth 0: norm is 16.537794
Recursion at depth 200: norm is 372.477020
Recursion at depth 400: norm is 483.608032
Recursion at depth 600: norm is 518.455872
Recursion at depth 800: norm is 532.824402
Recursion at depth 999: norm is 535.007141


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 143.77it/s]


---Computing Input Influence Function---
LiSSA reps: 2 and num_iterations: 1000
Recursion at depth 0: norm is 16.698921
Recursion at depth 200: norm is 376.125092
Recursion at depth 400: norm is 483.395966
Recursion at depth 600: norm is 518.817444
Recursion at depth 800: norm is 533.589111
Recursion at depth 999: norm is 533.149719
Recursion at depth 0: norm is 17.365234
Recursion at depth 200: norm is 373.968262
Recursion at depth 400: norm is 483.402222
Recursion at depth 600: norm is 521.470398
Recursion at depth 800: norm is 528.118896
Recursion at depth 999: norm is 538.658203


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:08<00:00, 120.04it/s]


---Perturbing training guid 131---
---Retraining on perturbed data---


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:02<00:00, 29.91batch/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████

---Computing Influence Function---
LiSSA reps: 2 and num_iterations: 1000
Recursion at depth 0: norm is 16.947731
Recursion at depth 200: norm is 376.468536
Recursion at depth 400: norm is 485.480774
Recursion at depth 600: norm is 524.282043
Recursion at depth 800: norm is 537.295105
Recursion at depth 999: norm is 539.736755
Recursion at depth 0: norm is 16.693722
Recursion at depth 200: norm is 376.155090
Recursion at depth 400: norm is 488.640350
Recursion at depth 600: norm is 524.007324
Recursion at depth 800: norm is 538.372803
Recursion at depth 999: norm is 540.655212


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 147.14it/s]


---Computing Input Influence Function---
LiSSA reps: 2 and num_iterations: 1000
Recursion at depth 0: norm is 16.869684
Recursion at depth 200: norm is 379.934998
Recursion at depth 400: norm is 488.224854
Recursion at depth 600: norm is 524.116089
Recursion at depth 800: norm is 539.290771
Recursion at depth 999: norm is 538.798218
Recursion at depth 0: norm is 17.536266
Recursion at depth 200: norm is 377.769562
Recursion at depth 400: norm is 488.342285
Recursion at depth 600: norm is 526.871399
Recursion at depth 800: norm is 533.519165
Recursion at depth 999: norm is 544.154114


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:09<00:00, 108.37it/s]


---Perturbing training guid 867---
---Retraining on perturbed data---


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016670745000010357, max=1.0…

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:02<00:00, 28.98batch/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:02<00:00, 30.09batch/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:02<00:00, 30.19batch/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:02<00:00, 30.02batch/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:02<00:00, 30.05batch/s]


---Computing Influence Function---
LiSSA reps: 2 and num_iterations: 1000
Recursion at depth 0: norm is 17.054840
Recursion at depth 200: norm is 378.969788
Recursion at depth 400: norm is 489.129395
Recursion at depth 600: norm is 527.901062
Recursion at depth 800: norm is 540.914490
Recursion at depth 999: norm is 543.728699
Recursion at depth 0: norm is 16.794306
Recursion at depth 200: norm is 379.287292
Recursion at depth 400: norm is 492.098999
Recursion at depth 600: norm is 527.523865
Recursion at depth 800: norm is 542.232910
Recursion at depth 999: norm is 544.948669


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 146.95it/s]


---Computing Input Influence Function---
LiSSA reps: 2 and num_iterations: 1000
Recursion at depth 0: norm is 16.974543
Recursion at depth 200: norm is 382.689362
Recursion at depth 400: norm is 491.732819
Recursion at depth 600: norm is 527.951111
Recursion at depth 800: norm is 543.403503
Recursion at depth 999: norm is 542.795471
Recursion at depth 0: norm is 17.645317
Recursion at depth 200: norm is 380.455078
Recursion at depth 400: norm is 492.186829
Recursion at depth 600: norm is 530.765259
Recursion at depth 800: norm is 537.337524
Recursion at depth 999: norm is 548.040344


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:08<00:00, 115.12it/s]


---Perturbing training guid 911---
---Retraining on perturbed data---


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:02<00:00, 29.34batch/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████

---Computing Influence Function---
LiSSA reps: 2 and num_iterations: 1000
Recursion at depth 0: norm is 17.347719
Recursion at depth 200: norm is 385.858215
Recursion at depth 400: norm is 498.396759
Recursion at depth 600: norm is 538.009460
Recursion at depth 800: norm is 551.000122
Recursion at depth 999: norm is 553.751526
Recursion at depth 0: norm is 17.060331
Recursion at depth 200: norm is 385.773651
Recursion at depth 400: norm is 501.779877
Recursion at depth 600: norm is 537.744141
Recursion at depth 800: norm is 552.808655
Recursion at depth 999: norm is 554.917358


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 144.71it/s]


---Computing Input Influence Function---
LiSSA reps: 2 and num_iterations: 1000
Recursion at depth 0: norm is 17.254351
Recursion at depth 200: norm is 389.463531
Recursion at depth 400: norm is 501.295929
Recursion at depth 600: norm is 537.377014
Recursion at depth 800: norm is 553.889771
Recursion at depth 999: norm is 553.747314
Recursion at depth 0: norm is 17.940121
Recursion at depth 200: norm is 387.318146
Recursion at depth 400: norm is 501.673035
Recursion at depth 600: norm is 540.898682
Recursion at depth 800: norm is 547.328125
Recursion at depth 999: norm is 558.003662


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:09<00:00, 107.24it/s]


---Perturbing training guid 183---
---Retraining on perturbed data---


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:02<00:00, 27.94batch/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████

---Computing Influence Function---
LiSSA reps: 2 and num_iterations: 1000
Recursion at depth 0: norm is 17.774250
Recursion at depth 200: norm is 396.140320
Recursion at depth 400: norm is 512.147583
Recursion at depth 600: norm is 553.026550
Recursion at depth 800: norm is 566.036560
Recursion at depth 999: norm is 569.827576
Recursion at depth 0: norm is 17.489773
Recursion at depth 200: norm is 396.125519
Recursion at depth 400: norm is 515.583435
Recursion at depth 600: norm is 552.404175
Recursion at depth 800: norm is 568.297913
Recursion at depth 999: norm is 570.307556


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:07<00:00, 141.80it/s]


---Computing Input Influence Function---
LiSSA reps: 2 and num_iterations: 1000
Recursion at depth 0: norm is 17.686131
Recursion at depth 200: norm is 399.852814
Recursion at depth 400: norm is 515.068787
Recursion at depth 600: norm is 552.106506
Recursion at depth 800: norm is 568.849060
Recursion at depth 999: norm is 568.734314
Recursion at depth 0: norm is 18.390631
Recursion at depth 200: norm is 397.482758
Recursion at depth 400: norm is 515.575012
Recursion at depth 600: norm is 555.962463
Recursion at depth 800: norm is 562.567688
Recursion at depth 999: norm is 572.964355


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:09<00:00, 109.25it/s]


---Perturbing training guid 333---
---Retraining on perturbed data---


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:02<00:00, 29.34batch/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████

---Computing Influence Function---
LiSSA reps: 2 and num_iterations: 1000
Recursion at depth 0: norm is 18.117043
Recursion at depth 200: norm is 403.823883
Recursion at depth 400: norm is 522.251770
Recursion at depth 600: norm is 563.839233
Recursion at depth 800: norm is 576.983154
Recursion at depth 999: norm is 580.884949
Recursion at depth 0: norm is 17.822470
Recursion at depth 200: norm is 403.666534
Recursion at depth 400: norm is 525.571899
Recursion at depth 600: norm is 563.074768
Recursion at depth 800: norm is 579.278442
Recursion at depth 999: norm is 581.391663


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 147.70it/s]


---Computing Input Influence Function---
LiSSA reps: 2 and num_iterations: 1000
Recursion at depth 0: norm is 18.046347
Recursion at depth 200: norm is 407.519012
Recursion at depth 400: norm is 525.048096
Recursion at depth 600: norm is 562.781067
Recursion at depth 800: norm is 579.856323
Recursion at depth 999: norm is 579.771362
Recursion at depth 0: norm is 18.746189
Recursion at depth 200: norm is 405.259186
Recursion at depth 400: norm is 525.513062
Recursion at depth 600: norm is 566.707092
Recursion at depth 800: norm is 573.350464
Recursion at depth 999: norm is 584.062622


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:09<00:00, 109.08it/s]


---Perturbing training guid 843---
---Retraining on perturbed data---


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:02<00:00, 29.94batch/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████

---Computing Influence Function---
LiSSA reps: 2 and num_iterations: 1000
Recursion at depth 0: norm is 18.141100
Recursion at depth 200: norm is 404.502472
Recursion at depth 400: norm is 522.870300
Recursion at depth 600: norm is 564.650940
Recursion at depth 800: norm is 577.708557
Recursion at depth 999: norm is 581.674500
Recursion at depth 0: norm is 17.836382
Recursion at depth 200: norm is 404.141693
Recursion at depth 400: norm is 526.183350
Recursion at depth 600: norm is 563.671204
Recursion at depth 800: norm is 580.138123
Recursion at depth 999: norm is 582.219238


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 148.30it/s]


---Computing Input Influence Function---
LiSSA reps: 2 and num_iterations: 1000
Recursion at depth 0: norm is 18.051645
Recursion at depth 200: norm is 408.585968
Recursion at depth 400: norm is 525.826538
Recursion at depth 600: norm is 563.662964
Recursion at depth 800: norm is 580.642761
Recursion at depth 999: norm is 580.527527
Recursion at depth 0: norm is 18.759806
Recursion at depth 200: norm is 405.813110
Recursion at depth 400: norm is 526.589111
Recursion at depth 600: norm is 567.518860
Recursion at depth 800: norm is 574.171387
Recursion at depth 999: norm is 585.055420


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:08<00:00, 120.26it/s]


---Perturbing training guid 119---
---Retraining on perturbed data---


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:02<00:00, 28.94batch/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████

In [21]:
# Stack the per-iteration result frames and show the target example's row
# from each iteration.
df = pd.concat(hist['loss_df'])
df[df.test_guid == TEST_GUID]

Unnamed: 0,test_guid,logits,pred,label,loss,perturbed_guid,iter
497,497,"[-0.118081115, 0.1089515]",1,0,0.813093,146,0
497,497,"[-0.1613373, 0.15220653]",1,0,0.862158,685,1
497,497,"[-0.19281432, 0.18368623]",1,0,0.899013,920,2
497,497,"[-0.20458043, 0.19545211]",1,0,0.913035,131,3
497,497,"[-0.21231705, 0.20318845]",1,0,0.922327,867,4
497,497,"[-0.23281749, 0.2236888]",1,0,0.947227,911,5
497,497,"[-0.26563478, 0.25650555]",1,0,0.987916,183,6
497,497,"[-0.29135677, 0.2822266]",1,0,1.020512,333,7
497,497,"[-0.29223296, 0.2831043]",1,0,1.021634,843,8
497,497,"[-0.29912966, 0.29000124]",1,0,1.030484,119,9


In [None]:
# The per-iteration result frames are stored in hist["loss_df"]; the name
# `loss_hist` used here previously is not defined anywhere in the notebook.
# Requires at least five completed attack iterations.
df0 = hist["loss_df"][0]
df1 = hist["loss_df"][1]
df2 = hist["loss_df"][2]
df3 = hist["loss_df"][3]
df4 = hist["loss_df"][4]

In [None]:
df0[df0.test_guid == TEST_GUID]

In [None]:
df1[df1.test_guid == TEST_GUID]

In [None]:
df2[df2.test_guid == TEST_GUID]

In [None]:
df3[df3.test_guid == TEST_GUID]

In [None]:
df4[df4.test_guid == TEST_GUID]

## Scratch

In [None]:
# perform_attack returns (model, df, infl, input_infl); unpack all four values
# (unpacking into two names raises ValueError).
model2, loss_df2, infl2, input_infl2 = perform_attack(
    model, config, train_dataset, test_dataset, TEST_GUID
)
loss_df2[loss_df2.test_guid == TEST_GUID]

In [None]:
# perform_attack returns (model, df, infl, input_infl); unpack all four values
# (unpacking into two names raises ValueError).
model3, loss_df3, infl3, input_infl3 = perform_attack(
    model2, config, train_dataset, test_dataset, TEST_GUID
)
loss_df3[loss_df3.test_guid == TEST_GUID]

In [None]:
# - (scratch separator; a bare "-" is a SyntaxError and stops Run All)

## Compute Influence Function

In [None]:
# NOTE: this rebinds TEST_GUID (497 in the attack cells above), so any earlier
# cell re-run after this one will target a different test example.
TEST_GUID = 716

# Influence of the training examples on the loss of test example TEST_GUID,
# computed with respect to the baseline model's classifier parameters.
infl = influence.compute_influence(
    full_model,
    TEST_GUID,
    param_influence=list(full_model.classifier.parameters()),
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    use_bert_embeddings=USE_BERT_EMBEDDINGS,
    lissa_r=2,
    lissa_depth=1,
    damping=5e-3,
    scale=100,
)

In [None]:
# Most negative influence is most helpful
# argsort is ascending, so the first ten indices carry the ten lowest scores.
helpful_idxs = np.argsort(infl)[:10]
helpful_idxs

In [None]:
# Influence scores at the ten selected indices.
np.take(infl, helpful_idxs)

## Compute Input Influence Function

In [None]:
# Input influence for the selected training examples only (`training_indices`);
# the other settings match the influence call above.
input_infl = influence.compute_input_influence(
    full_model,
    TEST_GUID,
    param_influence=list(full_model.classifier.parameters()),
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    use_bert_embeddings=USE_BERT_EMBEDDINGS,
    lissa_r=2,
    lissa_depth=1,
    damping=5e-3,
    scale=100,
    training_indices=helpful_idxs,
)

In [None]:
# Index with the lowest (most negative) influence score.
best_idx = helpful_idxs[0]
best_idx

## Perturb the Best Idx

In [None]:
def get_guid(dataset, data_guid):
    """Unimplemented placeholder: accepts a dataset and a guid, returns ``None``."""
    return None


def perturb_datapoint(dataset, data_guid, perturbation):
    """Add ``perturbation`` to one example's inputs, modifying ``dataset`` in place.

    Args:
        dataset: Dataset exposing a ``tensors`` tuple ordered as
            ``(guid, inputs, attn_mask, labels)`` (e.g. a ``TensorDataset``).
        data_guid: Row index of the example to perturb; it must equal the
            guid stored at that row.
        perturbation: Tensor added element-wise to the example's inputs.

    Returns:
        ``(inputs_before, inputs_after)``: a detached copy of the inputs taken
        before the update, and the updated inputs.

    Raises:
        AssertionError: If the guid stored at row ``data_guid`` differs from
            ``data_guid``.
    """
    # Read from the `dataset` argument (not the notebook-global
    # `train_dataset`) so the function perturbs whatever it is handed.
    guid, inputs, _attn_mask, _labels = [t[data_guid] for t in dataset.tensors]
    assert guid.squeeze() == data_guid, (
        f"row {data_guid} holds guid {guid}, expected {data_guid}"
    )

    inputs_before = inputs.detach().clone()
    # Integer indexing of a tensor yields a view, so this in-place add writes
    # through to the dataset. The perturbation is moved to the device the
    # stored inputs already live on, which the in-place add requires.
    inputs += perturbation.to(inputs.device)
    return inputs_before, inputs

In [None]:
# Multiplier applied to the input influence to form the perturbation.
alpha = 1e-2

perturb = alpha * input_infl[best_idx]
# Modifies train_dataset in place (see the perturb_datapoint docstring), so
# re-running this cell applies the perturbation again.
before, after = perturb_datapoint(train_dataset, best_idx, perturb)

## Retrain model with new dataset

In [None]:
# Retrain on the training set as it currently stands (including the
# perturbation applied above).
# NOTE: this rebinds `model` and `df`, which the attack-loop cells earlier in
# the notebook also assign.
model, df, full_test_loss, full_test_acc = train_utils.train_bert_model(
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    config=config,
    use_bert_embeddings=USE_BERT_EMBEDDINGS,
)

In [None]:
# Loss of the retrained model on the target test example.
# NOTE: this rebinds `test_loss`, which earlier held the baseline aggregate test loss.
test_loss = df[df.test_guid == TEST_GUID].loss.squeeze()
test_loss

In [None]:
df[df.test_guid == TEST_GUID]

In [None]:
# -- (scratch separator; a bare "--" is a SyntaxError and stops Run All)

In [None]:
before

In [None]:
after

In [None]:
# Inspect the stored tensors for the perturbed example. `best_idx` replaces
# `data_guid`, which exists only as a parameter of perturb_datapoint and is
# undefined at notebook scope (NameError).
guid, inputs, attn_mask, labels = [t[best_idx] for t in train_dataset.tensors]

inputs[0][0]

In [None]:
# Pull the four stored tensors for the hard-coded training index 262.
guid, inputs, attn_mask, labels = [t[262] for t in train_dataset.tensors]

In [None]:
inputs[0][0]

In [None]:
# inputs[0]

In [None]:
inputs.shape

In [None]:
# In-place add on `inputs`, which was obtained by indexing train_dataset.tensors.
# NOTE(review): this presumably writes through to train_dataset (tensor
# indexing returns a view), and re-running the cell adds the perturbation again.
inputs += perturb.to(device)

In [None]:
inputs

In [None]:
input_infl[262]

In [None]:
import pickle

# Persist the input-influence result to "input_infl.pkl" in the current
# working directory.
with open("input_infl.pkl", "wb") as fh:
    pickle.dump(input_infl, fh)

In [None]:
262 in helpful_idxs

In [None]:
# - (scratch separator; a bare "-" is a SyntaxError and stops Run All)

## Word2Vec

In [None]:
from gensim.models import KeyedVectors, Word2Vec

# Load saved keyed vectors from a local file; judging by the filename these
# are presumably 100-dimensional GloVe Twitter embeddings — confirm.
w2v = KeyedVectors.load("word2vec/glove-twitter-100.kv")

In [None]:
w2v["spielberg"]