# Objective: Use AFLite to greedily solve for $\text{arg min}_{S \subset \mathcal{D}, ~|S| \geq n}\mathcal{R}(\Phi, ~S, ~\mathcal{M})$

### 1. Imports and Global Settings

In [1]:
from datasets import load_dataset, disable_caching
from transformers import GPT2TokenizerFast, DataCollatorWithPadding, set_seed
import torch
from torch.nn.functional import one_hot
import copy
import numpy as np
import sys
sys.path.append('..')
from utils_ import tokenize, train_classifier, predict, select_k
import pickle
import itertools
# run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# fix the random seed so that runs are repeatable
set_seed(42)

# switch off caching in the `datasets` library
disable_caching()

### 2. Pre-Processing
- Get SNLI Dataset (Train fold) and shuffle it using the same seed as used for obtaining GPT-2 based Feature Representation (see notebook [Filtering_Part1.ipynb](https://github.com/shashiniyer/adversarial_nli_gpt2/blob/main/Filtering_Part1.ipynb))
- Remove instances without gold standard labels, i.e., label = -1
- One-hot encoding for labels
- Partition data 10%/90%; use the 90% as `train`
- Tokenise `train` and exclude sequences with more than 128 tokens

In [2]:
# load the SNLI train fold and shuffle it with a fixed seed
snli_train = load_dataset('snli', split = 'train').shuffle(seed = 42)

# drop instances without a gold standard label (label == -1)
snli_train = snli_train.filter(lambda x: x['label'] != -1)

# replace the integer labels by one-hot float32 vectors over the 3 classes (batch-wise)
snli_train = snli_train.map(
    lambda x: {'label': one_hot(torch.tensor(x['label']), 3).type(torch.float32).numpy()},
    batched = True
)

# skip the first 10% of the shuffled data; the remaining 90% is `train`
split_at = int(len(snli_train)/10)
train = snli_train.select(range(split_at, len(snli_train)))

Reusing dataset snli (/home/shana92/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)


  0%|          | 0/551 [00:00<?, ?ba/s]

  0%|          | 0/550 [00:00<?, ?ba/s]

In [3]:
# set up tokeniser
# pad on the left because GPT2 uses the last token for prediction
# NOTE(review): `padding`/`truncation` are normally arguments of the tokenisation call,
# not of `from_pretrained` - confirm that they have the intended effect here
tokenizer = GPT2TokenizerFast.from_pretrained(
    "gpt2-medium",
    padding_side = 'left',
    padding = True,
    truncation = True
)

# pad with the 'eos' token
tokenizer.pad_token = tokenizer.eos_token

In [4]:
# tokenize data: premise and hypothesis are joined with '|' and passed to `tokenize`
train = train.map(lambda x: tokenize(tokenizer, x['premise'] + '|' + x['hypothesis']))
len_bef_exclusion = len(train)

# exclude instances with > 128 tokens (flagged by `tokenize` through the 'exclude' column)
train = train.filter(lambda x: x['exclude'] == False)
len_aft_exclusion = len(train)

# print message if instances were in fact excluded
n_excluded = len_bef_exclusion - len_aft_exclusion
if n_excluded > 0:

    # report the excluded share relative to the number of sequences before exclusion;
    # this denominator is > 0 whenever anything was excluded, so no division by zero
    print(f'{n_excluded} ' + \
          f'({n_excluded / len_bef_exclusion * 100:.2f}%) sequences excluded')

  0%|          | 0/494431 [00:00<?, ?ex/s]

  0%|          | 0/495 [00:00<?, ?ba/s]

In [5]:
# keep only needed columns ('label', 'input_ids', 'attention_mask'), set data format to PyTorch
# note: this call modifies `train` in place rather than returning a new dataset
train.set_format(type = 'torch', columns = ['label', 'input_ids', 'attention_mask'])

### 3. Set up inputs for AFLite

In [6]:
# load in the feature representation, Phi, with linear layer attached
# `map_location` makes the load work on CPU-only machines even if the checkpoint
# was saved from a GPU; the model is kept on the CPU at this point either way
# NOTE(review): this unpickles a full model object - only load checkpoints you trust
model = torch.load('feature_rep.pth', map_location = 'cpu')

# move model to CPU (no-op after the CPU-mapped load; kept for explicitness)
model.to('cpu')

# freeze all layers except the last
# NOTE(review): what stays trainable is the *last parameter tensor* only; if the final
# layer has both a weight and a bias, this is presumably just the bias - confirm
num_layers = sum(1 for _ in model.parameters())
for idx, param in enumerate(model.parameters()):

    if idx != num_layers - 1:

        # freeze
        param.requires_grad = False

In [7]:
# set up data collator - https://huggingface.co/docs/transformers/main_classes/data_collator
# callable helper used as `collate_fn` of the DataLoaders: builds PyTorch batches
# with every sequence padded to a fixed length of 128 tokens
data_collator = DataCollatorWithPadding(
    tokenizer,
    padding = 'max_length',
    max_length = 128,
    return_tensors = 'pt'
)

In [8]:
# hyper-parameters - constrained by training time available
m = 30                          # classifiers (random partitions) trained per AFLite iteration
n = 195_000                     # target size: filtering continues while |S| > n
t = 50_000                      # number of training instances |S\T_j| in each partition
k = 100_000                     # maximum number of instances removed per iteration
tau = 0.75                      # minimum predictability score for an instance to be removable
AFLite_seeds = list(range(5))   # one full AFLite run per seed: 0, 1, 2, 3, 4

In [9]:
# hyper-parameters for model training within AFLite implementation
# learning rate - set to match Le Bras et al. (2020) - https://arxiv.org/abs/2002.04108
lr = 1e-5
# mini-batch size - constrained by GPU memory
batch_size = 128

### 4.  AFLite Procedure

In [None]:
# set up containers to collect outputs
# filtered_datasets: seed -> dataset S that remains after filtering
# removed_idx:       seed -> string ',i1,i2,...' of the removed instances, given as row
#                    indices into `train` (not into the shrinking S)
filtered_datasets = {}
removed_idx = {x: '' for x in AFLite_seeds}

# begin procedure
for seed in AFLite_seeds:

    # first step of AFLite; initialise S
    S = copy.deepcopy(train)

    # orig_idx[i] is the row of `train` that currently sits at position i of S;
    # S is re-indexed from 0 after every removal, so this mapping is needed to
    # report removed instances consistently across iterations
    orig_idx = np.arange(len(S))

    # initialise iteration index
    it_idx = 0

    # NOTE(review): |S| is only checked before filtering, so removing up to k instances
    # can take |S| below n; cap the slice at len(S) - n if |S| >= n must hold
    while len(S) > n:

        # update iteration index
        it_idx += 1

        # initialise multiset for Out-Of-Sample predictions: E[i] <-> predictions for S[i]
        E = [[] for _ in range(len(S))]

        for j in range(m):

            # randomly partition S into (S\T_j, T_j) s.t. |S\T_j| = t
            # the generator is seeded with (seed, iteration, model) so that AFLite runs
            # with different seeds draw different partitions while staying reproducible
            rng = np.random.default_rng([seed, it_idx, j])
            tr_idx = np.sort(rng.choice(len(S), t, replace = False))
            in_train = np.zeros(len(S), dtype = bool)
            in_train[tr_idx] = True
            te_idx = np.flatnonzero(~in_train) # ascending positions of T_j
            tr, te = S.select(tr_idx.tolist()), S.select(te_idx.tolist())
            print(f'Seed {seed} - Iteration {it_idx} - Model {j + 1} - Begin')

            # train classifier on S\T_j, i.e. tr; start from a fresh copy of `model`
            classifier = copy.deepcopy(model)
            dataloader = torch.utils.data.DataLoader(tr, batch_size=batch_size, \
                                 shuffle=True, collate_fn=data_collator)
            optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, classifier.parameters()), lr = lr)
            trained_classifier = train_classifier(classifier, dataloader, optimizer, device)

            # for all instances i in T_j, add predictions to E(i)
            # the test loader is not shuffled, so prediction p belongs to position te_idx[p]
            te_dataloader = torch.utils.data.DataLoader(te, batch_size=batch_size, collate_fn=data_collator)
            preds = predict(trained_classifier, te_dataloader, device)
            print(f'Seed {seed} - Iteration {it_idx} - Model {j + 1} - Done')

            for pred_idx, data_idx in enumerate(te_idx.tolist()): # there are as many predictions as test instances

                E[data_idx].append(preds[pred_idx])

        # for all instances in S, compute predictability score
        # in the corner case that there are no predictions for an instance, we do not filter it out
        # (its padded row holds only -1, so it matches no label and scores 0 / 1 = 0)
        lengths = torch.tensor([max(len(x), 1) for x in E])
        preds_padded = torch.tensor(list(itertools.zip_longest(*E, fillvalue=-1))).transpose(0, 1)
        labels = torch.repeat_interleave(S['label'].argmax(1), int(lengths.max())).reshape(preds_padded.size())

        pred_matches = torch.eq(preds_padded, labels)
        pred_match_totals = torch.sum(pred_matches, axis = 1)
        pred_scores = pred_match_totals / lengths

        # select up to k instances with the highest predictability scores subject to score >= tau
        selected_idx = select_k(pred_scores, tau, k, seed)

        # plain Python ints, whether select_k returns a numpy array or a torch tensor
        # (tensor elements hash by identity and would not behave as indices in a set)
        selected = torch.as_tensor(selected_idx).reshape(-1).tolist()

        if len(selected) > 0:

            # cache instances selected for removal, as row indices of `train`
            removed_idx[seed] += ',' + ','.join([str(int(orig_idx[pos])) for pos in selected])

            # filter out selected instances and keep the index mapping in sync
            keep = np.ones(len(S), dtype = bool)
            keep[selected] = False
            keep_pos = np.flatnonzero(keep)
            S = S.select(keep_pos.tolist())
            orig_idx = orig_idx[keep_pos]

        # early stopping: fewer than k instances reached the threshold, so stop now
        # instead of spending another m classifier trainings on a further pass
        if len(selected) < k:

            break

    # cache file
    filtered_datasets[seed] = S

    # print number of instances in S, for creating random baseline
    print(f'Number of instances in S (seed {seed}): {len(S)}')

# write out list of removed indices for further analysis
with open('removed_idx.pkl', 'wb') as f:
    pickle.dump(removed_idx, f)

Seed 0 - Iteration 1 - Model 1 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.351462  [    0/50000]
loss: 0.339478  [ 4992/50000]
loss: 0.323733  [ 9984/50000]
loss: 0.371513  [14976/50000]
loss: 0.327074  [19968/50000]
loss: 0.414310  [24960/50000]
loss: 0.360435  [29952/50000]
loss: 0.316903  [34944/50000]
loss: 0.329192  [39936/50000]
loss: 0.394657  [44928/50000]
loss: 0.372505  [31200/50000]
Epoch average loss: 0.36408787965774536


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.384505  [    0/50000]
loss: 0.419292  [ 4992/50000]
loss: 0.439487  [ 9984/50000]
loss: 0.422065  [14976/50000]
loss: 0.387640  [19968/50000]
loss: 0.368946  [24960/50000]
loss: 0.349701  [29952/50000]
loss: 0.391921  [34944/50000]
loss: 0.347152  [39936/50000]
loss: 0.438856  [44928/50000]
loss: 0.348565  [31200/50000]
Epoch average loss: 0.3647885024547577


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.370970  [    0/50000]
loss: 0.397175  [ 4992/50000]
loss: 0.425714  [ 9984/50000]
loss: 0.365622  [14976/50000]
loss: 0.376260  [19968/50000]
loss: 0.401830  [24960/50000]
loss: 0.397583  [29952/50000]
loss: 0.368095  [34944/50000]
loss: 0.393686  [39936/50000]
loss: 0.346776  [44928/50000]
loss: 0.398886  [31200/50000]
Epoch average loss: 0.3642916977405548
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 1 - Done
Seed 0 - Iteration 1 - Model 2 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.452623  [    0/50000]
loss: 0.356459  [ 4992/50000]
loss: 0.327671  [ 9984/50000]
loss: 0.355331  [14976/50000]
loss: 0.380826  [19968/50000]
loss: 0.385904  [24960/50000]
loss: 0.325877  [29952/50000]
loss: 0.453212  [34944/50000]
loss: 0.400007  [39936/50000]
loss: 0.353646  [44928/50000]
loss: 0.425990  [31200/50000]
Epoch average loss: 0.3637206256389618


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.401627  [    0/50000]
loss: 0.381684  [ 4992/50000]
loss: 0.378271  [ 9984/50000]
loss: 0.373003  [14976/50000]
loss: 0.325675  [19968/50000]
loss: 0.369392  [24960/50000]
loss: 0.430628  [29952/50000]
loss: 0.337135  [34944/50000]
loss: 0.329574  [39936/50000]
loss: 0.341632  [44928/50000]
loss: 0.334829  [31200/50000]
Epoch average loss: 0.3633255660533905


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.339411  [    0/50000]
loss: 0.377477  [ 4992/50000]
loss: 0.333938  [ 9984/50000]
loss: 0.350429  [14976/50000]
loss: 0.367987  [19968/50000]
loss: 0.329853  [24960/50000]
loss: 0.338207  [29952/50000]
loss: 0.333629  [34944/50000]
loss: 0.447789  [39936/50000]
loss: 0.362146  [44928/50000]
loss: 0.446653  [31200/50000]
Epoch average loss: 0.3627510368824005
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 2 - Done
Seed 0 - Iteration 1 - Model 3 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.366838  [    0/50000]
loss: 0.346938  [ 4992/50000]
loss: 0.372451  [ 9984/50000]
loss: 0.320485  [14976/50000]
loss: 0.356794  [19968/50000]
loss: 0.369373  [24960/50000]
loss: 0.363575  [29952/50000]
loss: 0.308874  [34944/50000]
loss: 0.393721  [39936/50000]
loss: 0.382872  [44928/50000]
loss: 0.380604  [31200/50000]
Epoch average loss: 0.36189162731170654


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.301175  [    0/50000]
loss: 0.349950  [ 4992/50000]
loss: 0.352163  [ 9984/50000]
loss: 0.348009  [14976/50000]
loss: 0.409791  [19968/50000]
loss: 0.365269  [24960/50000]
loss: 0.377372  [29952/50000]
loss: 0.361114  [34944/50000]
loss: 0.318849  [39936/50000]
loss: 0.366322  [44928/50000]
loss: 0.283538  [31200/50000]
Epoch average loss: 0.3623788058757782


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.402654  [    0/50000]
loss: 0.303437  [ 4992/50000]
loss: 0.339138  [ 9984/50000]
loss: 0.330377  [14976/50000]
loss: 0.326630  [19968/50000]
loss: 0.367296  [24960/50000]
loss: 0.391380  [29952/50000]
loss: 0.333377  [34944/50000]
loss: 0.347316  [39936/50000]
loss: 0.327635  [44928/50000]
loss: 0.384607  [31200/50000]
Epoch average loss: 0.36186596751213074
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 3 - Done
Seed 0 - Iteration 1 - Model 4 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.450345  [    0/50000]
loss: 0.317346  [ 4992/50000]
loss: 0.383968  [ 9984/50000]
loss: 0.351726  [14976/50000]
loss: 0.438074  [19968/50000]
loss: 0.399947  [24960/50000]
loss: 0.372646  [29952/50000]
loss: 0.355177  [34944/50000]
loss: 0.349292  [39936/50000]
loss: 0.403655  [44928/50000]
loss: 0.445391  [31200/50000]
Epoch average loss: 0.36558693647384644


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.363079  [    0/50000]
loss: 0.376719  [ 4992/50000]
loss: 0.353817  [ 9984/50000]
loss: 0.318308  [14976/50000]
loss: 0.328502  [19968/50000]
loss: 0.455255  [24960/50000]
loss: 0.387360  [29952/50000]
loss: 0.333444  [34944/50000]
loss: 0.344278  [39936/50000]
loss: 0.293226  [44928/50000]
loss: 0.403219  [31200/50000]
Epoch average loss: 0.3643246293067932


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.395858  [    0/50000]
loss: 0.365659  [ 4992/50000]
loss: 0.390630  [ 9984/50000]
loss: 0.324348  [14976/50000]
loss: 0.278239  [19968/50000]
loss: 0.363707  [24960/50000]
loss: 0.341512  [29952/50000]
loss: 0.475533  [34944/50000]
loss: 0.366447  [39936/50000]
loss: 0.297043  [44928/50000]
loss: 0.362669  [31200/50000]
Epoch average loss: 0.3626748323440552
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 4 - Done
Seed 0 - Iteration 1 - Model 5 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.354864  [    0/50000]
loss: 0.423739  [ 4992/50000]
loss: 0.302279  [ 9984/50000]
loss: 0.379920  [14976/50000]
loss: 0.327253  [19968/50000]
loss: 0.316194  [24960/50000]
loss: 0.353090  [29952/50000]
loss: 0.371466  [34944/50000]
loss: 0.357100  [39936/50000]
loss: 0.318972  [44928/50000]
loss: 0.248706  [31200/50000]
Epoch average loss: 0.36417701840400696


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.337895  [    0/50000]
loss: 0.337823  [ 4992/50000]
loss: 0.417160  [ 9984/50000]
loss: 0.353334  [14976/50000]
loss: 0.247678  [19968/50000]
loss: 0.386853  [24960/50000]
loss: 0.353052  [29952/50000]
loss: 0.301107  [34944/50000]
loss: 0.368206  [39936/50000]
loss: 0.361346  [44928/50000]
loss: 0.377103  [31200/50000]
Epoch average loss: 0.363930344581604


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.437157  [    0/50000]
loss: 0.371344  [ 4992/50000]
loss: 0.338309  [ 9984/50000]
loss: 0.340573  [14976/50000]
loss: 0.400044  [19968/50000]
loss: 0.385437  [24960/50000]
loss: 0.421255  [29952/50000]
loss: 0.399701  [34944/50000]
loss: 0.279014  [39936/50000]
loss: 0.359794  [44928/50000]
loss: 0.394982  [31200/50000]
Epoch average loss: 0.3622462749481201
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

In [None]:
# write out list of removed indices for further analysis
# (same dump as at the end of the AFLite cell; lets the file be re-written
# without re-running the filtering)
with open('removed_idx.pkl', mode = 'wb') as f:
    pickle.dump(removed_idx, file = f)