# Dual Objective:
- A) Use AFLite to greedily solve for $\text{arg min}_{S \subset \mathcal{D}, ~|S| \geq n}\mathcal{R}(\Phi, ~S, ~\mathcal{M})$
- B) Fine-tune GPT-2 with the resulting filtered dataset

### 1. Imports and Global Settings

In [1]:
from datasets import load_dataset, disable_caching
from tqdm.notebook import tqdm
from transformers import GPT2TokenizerFast, DataCollatorWithPadding, set_seed
import torch
from torch.nn.functional import one_hot
import copy
import numpy as np
from utils_ import tokenize, train_classifier, predict, select_k
import pickle
# Global settings: prefer the GPU when one is visible, fix the random seed
# for reproducibility and turn off the `datasets` on-disk cache.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
set_seed(42)
disable_caching()

### 2. Pre-Processing
- Get the SNLI dataset (train fold) and shuffle it with the same seed that was used to obtain the GPT-2-based feature representation (see notebook [Filtering_Part1.ipynb](https://github.com/shashiniyer/adversarial_nli_gpt2/blob/main/Filtering_Part1.ipynb))
- Remove instances without gold standard labels, i.e., label = -1
- One-hot encoding for labels
- Partition data 10%/90%; use the 90% as `train`
- Tokenise train

In [2]:
# Load the SNLI train fold and shuffle it with a fixed seed.
snli_train = load_dataset('snli', split='train').shuffle(seed=42)

# Drop instances without a gold-standard label (label == -1).
snli_train = snli_train.filter(lambda x: x['label'] != -1)

# Replace the integer labels by float32 one-hot vectors over the 3 classes;
# the mapping runs on whole batches at a time.
snli_train = snli_train.map(
    lambda x: {'label': one_hot(torch.tensor(x['label']), 3).type(torch.float32).numpy()},
    batched=True,
)

# The first tenth of the shuffled data is held out; the remaining rows form `train`.
held_out_size = len(snli_train) // 10
train = snli_train.select(range(held_out_size, len(snli_train)))

Reusing dataset snli (/home/shana92/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)


  0%|          | 0/551 [00:00<?, ?ba/s]

  0%|          | 0/550 [00:00<?, ?ba/s]

In [3]:
# Set up the tokeniser.
# Padding goes on the left because GPT-2 uses the last token for prediction.
# NOTE(review): `padding` and `truncation` look like call-time options rather
# than loading options - confirm they have an effect when passed here.
tokenizer_options = {'padding_side': 'left', 'padding': True, 'truncation': True}
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", **tokenizer_options)

# GPT-2 ships without a pad token, so pad with the end-of-sequence token.
tokenizer.pad_token = tokenizer.eos_token

In [4]:
# Tokenize data: premise and hypothesis are joined with '|' into one sequence
# before being handed to the project helper `tokenize`.
train = train.map(lambda x: tokenize(tokenizer, x['premise'] + '|' + x['hypothesis']))
len_bef_exclusion = len(train)

# Exclude instances flagged by `tokenize` (those with > 128 tokens).
train = train.filter(lambda x: not x['exclude'])
len_aft_exclusion = len(train)

# Report how many sequences were dropped, as a count and as a share of the
# pre-filter total. The denominator is the pre-filter size, which is
# necessarily > 0 whenever at least one sequence was excluded.
num_excluded = len_bef_exclusion - len_aft_exclusion
if num_excluded > 0:
    print(f'{num_excluded} ({num_excluded / len_bef_exclusion:.2%}) sequences excluded')

  0%|          | 0/494431 [00:00<?, ?ex/s]

  0%|          | 0/495 [00:00<?, ?ba/s]

In [5]:
# Expose only the columns the model consumes, returned as PyTorch tensors.
model_columns = ['label', 'input_ids', 'attention_mask']
train.set_format(type='torch', columns=model_columns)

### 3. Set up inputs for AFLite

In [6]:
# Load the feature representation, Phi, with the linear layer attached.
# NOTE(review): torch.load unpickles the file - only load checkpoints that
# were produced by a trusted source.
model = torch.load('feature_rep.pth')

# Keep the reference copy on the CPU.
model.to('cpu')

# Freeze every parameter tensor except the very last one, so that only the
# final tensor returned by `model.parameters()` remains trainable.
all_params = list(model.parameters())
num_layers = len(all_params)
for frozen_param in all_params[:-1]:
    frozen_param.requires_grad = False

In [7]:
# Data collator - https://huggingface.co/docs/transformers/main_classes/data_collator
# A callable helper that assembles batches for the model; it is configured to
# pad to a fixed length of 128 and to return PyTorch tensors.
data_collator = DataCollatorWithPadding(
    tokenizer,
    padding='max_length',
    max_length=128,
    return_tensors='pt',
)

In [8]:
# AFLite hyper-parameters - constrained by the training time available
m = 30         # number of random partitions (classifiers) per AFLite iteration
n = 195_000    # AFLite keeps iterating while the dataset has more than n instances
t = 50_000     # number of instances drawn for training in each partition
k = 100_000    # passed to select_k as the cap on removals per iteration
tau = 0.75     # passed to select_k as the predictability-score threshold
AFLite_seeds = list(range(5))  # one full AFLite run per seed

In [9]:
# Hyper-parameters for the model training done inside the AFLite loop.

# Batch size is constrained by GPU memory.
batch_size = 128

# Learning rate set to match Le et al. (2020) - https://arxiv.org/abs/2002.04108
lr = 1e-5

### 4.  AFLite Procedure

In [None]:
# Containers for the outputs of each AFLite run, keyed by seed:
#   filtered_datasets[seed] - the dataset S left after filtering
#   removed_idx[seed]       - comma-separated string of the removed indices
filtered_datasets = {}
removed_idx = {x: '' for x in AFLite_seeds}

# begin procedure
for seed in AFLite_seeds:

    # first step of AFLite; initialise S with the full training set
    S = copy.deepcopy(train)

    # initialise iteration index
    it_idx = 0

    while len(S) > n:

        # update iteration index
        it_idx += 1
        all_idx = np.arange(len(S))

        # E[i] collects the out-of-sample predictions made for instance i
        E = {x: [] for x in range(len(S))}

        for j in range(m):

            # randomly partition S into (S\T_j, T_j) s.t. |S\T_j| = t
            # Both index arrays are sorted so that the row order of `tr`/`te`
            # is explicit instead of depending on set iteration order.
            # NOTE(review): the generator is seeded with j only, so for a given
            # len(S) the same partitions are drawn for every AFLite seed -
            # confirm this is intended.
            tr_idx = np.sort(np.random.default_rng(j).choice(all_idx, t, replace = False))
            te_idx = np.setdiff1d(all_idx, tr_idx)
            tr, te = S.select(tr_idx.tolist()), S.select(te_idx.tolist())
            print(f'Seed {seed} - Iteration {it_idx} - Model {j + 1} - Begin')

            # train a fresh copy of the partially frozen model on S\T_j, i.e. tr
            classifier = copy.deepcopy(model)
            dataloader = torch.utils.data.DataLoader(tr, batch_size=batch_size,
                                                     shuffle=True, collate_fn=data_collator)
            optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, classifier.parameters()), lr = lr)
            trained_classifier = train_classifier(classifier, dataloader, optimizer, device)

            # for all instances i in T_j, add predictions to E(i)
            # (the test loader is not shuffled, so prediction order follows te_idx)
            te_dataloader = torch.utils.data.DataLoader(te, batch_size=batch_size, collate_fn=data_collator)
            preds = predict(trained_classifier, te_dataloader, device)
            print(f'Seed {seed} - Iteration {it_idx} - Model {j + 1} - Done')

            # there are as many predictions as test instances
            for pred_idx, data_idx in enumerate(te_idx.tolist()):
                E[data_idx].append(preds[pred_idx])

        # for all instances in S, compute the predictability score, i.e. the
        # fraction of out-of-sample predictions that equal the gold label
        # in the corner case that there are no predictions for an instance,
        # we do not filter it out (score 0)
        # NOTE(review): `y_hat == gold` must evaluate to a single truth value;
        # labels were one-hot encoded during pre-processing, so confirm that
        # `predict` returns values for which this comparison is a scalar.
        pred_scores = []
        for idx, votes in E.items():
            if not votes:
                pred_scores.append(0)
                continue
            gold = S[idx]['label']  # looked up once per instance, not once per prediction
            pred_scores.append(sum(1 for y_hat in votes if y_hat == gold) / len(votes))

        # select up to k instances with the highest predictability scores subject to score >= tau
        selected_idx = select_k(pred_scores, tau, k, seed)
        num_selected = selected_idx.shape[0]

        if num_selected > 0:

            # cache instances selected for removal
            # NOTE(review): these are positions within the current S, which
            # shrinks every iteration, so indices cached in different
            # iterations are not directly comparable.
            removed_idx[seed] += ',' + ','.join([str(idx) for idx in selected_idx])

            # filter out selected instances, keeping the survivors in their current order
            keep_idx = np.setdiff1d(all_idx, np.asarray(selected_idx))
            S = S.select(keep_idx.tolist())

        # early stopping: AFLite terminates as soon as fewer than k instances
        # could be removed in an iteration (this includes removing none)
        if num_selected < k:

            break

    # cache file
    filtered_datasets[seed] = S

    # print number of instances in S, for creating random baseline
    print(f'Number of instances in S (seed {seed}): {len(S)}')

# write out list of removed indices for further analysis
with open('removed_idx.pkl', 'wb') as f:
    pickle.dump(removed_idx, f)

Seed 0 - Iteration 1 - Model 1 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.490815  [    0/50000]
loss: 0.462348  [ 4992/50000]
loss: 0.483250  [ 9984/50000]
loss: 0.501971  [14976/50000]
loss: 0.445378  [19968/50000]
loss: 0.507128  [24960/50000]
loss: 0.435846  [29952/50000]
loss: 0.411889  [34944/50000]
loss: 0.481307  [39936/50000]
loss: 0.452662  [44928/50000]
loss: 0.496749  [31200/50000]
Epoch average loss: 0.4680524170398712


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.488783  [    0/50000]
loss: 0.467965  [ 4992/50000]
loss: 0.482960  [ 9984/50000]
loss: 0.464695  [14976/50000]
loss: 0.479309  [19968/50000]
loss: 0.460052  [24960/50000]
loss: 0.459742  [29952/50000]
loss: 0.501513  [34944/50000]
loss: 0.476872  [39936/50000]
loss: 0.534709  [44928/50000]
loss: 0.460403  [31200/50000]
Epoch average loss: 0.46620702743530273


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.439318  [    0/50000]
loss: 0.496312  [ 4992/50000]
loss: 0.486560  [ 9984/50000]
loss: 0.445026  [14976/50000]
loss: 0.457667  [19968/50000]
loss: 0.512696  [24960/50000]
loss: 0.515036  [29952/50000]
loss: 0.453929  [34944/50000]
loss: 0.412722  [39936/50000]
loss: 0.464638  [44928/50000]
loss: 0.542555  [31200/50000]
Epoch average loss: 0.4677579998970032
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 1 - Done
Seed 0 - Iteration 1 - Model 2 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.496329  [    0/50000]
loss: 0.461750  [ 4992/50000]
loss: 0.414453  [ 9984/50000]
loss: 0.499598  [14976/50000]
loss: 0.489535  [19968/50000]
loss: 0.493609  [24960/50000]
loss: 0.491454  [29952/50000]
loss: 0.517127  [34944/50000]
loss: 0.455259  [39936/50000]
loss: 0.461471  [44928/50000]
loss: 0.489900  [31200/50000]
Epoch average loss: 0.4648999571800232


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.480835  [    0/50000]
loss: 0.514182  [ 4992/50000]
loss: 0.455320  [ 9984/50000]
loss: 0.439934  [14976/50000]
loss: 0.424020  [19968/50000]
loss: 0.511567  [24960/50000]
loss: 0.536579  [29952/50000]
loss: 0.472423  [34944/50000]
loss: 0.482909  [39936/50000]
loss: 0.452261  [44928/50000]
loss: 0.445191  [31200/50000]
Epoch average loss: 0.46628013253211975


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.433838  [    0/50000]
loss: 0.455367  [ 4992/50000]
loss: 0.442342  [ 9984/50000]
loss: 0.443768  [14976/50000]
loss: 0.433495  [19968/50000]
loss: 0.446716  [24960/50000]
loss: 0.519018  [29952/50000]
loss: 0.480844  [34944/50000]
loss: 0.473903  [39936/50000]
loss: 0.441200  [44928/50000]
loss: 0.526351  [31200/50000]
Epoch average loss: 0.46541374921798706
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 2 - Done
Seed 0 - Iteration 1 - Model 3 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.468877  [    0/50000]
loss: 0.444588  [ 4992/50000]
loss: 0.463039  [ 9984/50000]
loss: 0.452809  [14976/50000]
loss: 0.475705  [19968/50000]
loss: 0.507545  [24960/50000]
loss: 0.442597  [29952/50000]
loss: 0.422230  [34944/50000]
loss: 0.504980  [39936/50000]
loss: 0.464689  [44928/50000]
loss: 0.475628  [31200/50000]
Epoch average loss: 0.4662174582481384


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.434559  [    0/50000]
loss: 0.472283  [ 4992/50000]
loss: 0.515166  [ 9984/50000]
loss: 0.440969  [14976/50000]
loss: 0.435638  [19968/50000]
loss: 0.420971  [24960/50000]
loss: 0.485145  [29952/50000]
loss: 0.495289  [34944/50000]
loss: 0.464453  [39936/50000]
loss: 0.464244  [44928/50000]
loss: 0.437880  [31200/50000]
Epoch average loss: 0.4661409258842468


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.471737  [    0/50000]
loss: 0.412788  [ 4992/50000]
loss: 0.467256  [ 9984/50000]
loss: 0.425003  [14976/50000]
loss: 0.399822  [19968/50000]
loss: 0.389849  [24960/50000]
loss: 0.456639  [29952/50000]
loss: 0.413032  [34944/50000]
loss: 0.473176  [39936/50000]
loss: 0.409610  [44928/50000]
loss: 0.495577  [31200/50000]
Epoch average loss: 0.46622443199157715
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 3 - Done
Seed 0 - Iteration 1 - Model 4 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.482670  [    0/50000]
loss: 0.467089  [ 4992/50000]
loss: 0.488443  [ 9984/50000]
loss: 0.473134  [14976/50000]
loss: 0.512105  [19968/50000]
loss: 0.512495  [24960/50000]
loss: 0.485099  [29952/50000]
loss: 0.429189  [34944/50000]
loss: 0.452853  [39936/50000]
loss: 0.462465  [44928/50000]
loss: 0.515317  [31200/50000]
Epoch average loss: 0.466852605342865


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.485342  [    0/50000]
loss: 0.458100  [ 4992/50000]
loss: 0.455437  [ 9984/50000]
loss: 0.439894  [14976/50000]
loss: 0.443444  [19968/50000]
loss: 0.519521  [24960/50000]
loss: 0.455382  [29952/50000]
loss: 0.451171  [34944/50000]
loss: 0.459446  [39936/50000]
loss: 0.459204  [44928/50000]
loss: 0.475304  [31200/50000]
Epoch average loss: 0.4650052487850189


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.533792  [    0/50000]
loss: 0.471606  [ 4992/50000]
loss: 0.526378  [ 9984/50000]
loss: 0.412952  [14976/50000]
loss: 0.413553  [19968/50000]
loss: 0.475280  [24960/50000]
loss: 0.390599  [29952/50000]
loss: 0.534704  [34944/50000]
loss: 0.460793  [39936/50000]
loss: 0.422744  [44928/50000]
loss: 0.466432  [31200/50000]
Epoch average loss: 0.4644717574119568
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

### 5. Fine-tuning GPT-2 with AFLite filtered datasets

In [None]:
# The notebook's import cell does not bring GPT2ForSequenceClassification into
# scope, so import it here before it is used below.
from transformers import GPT2ForSequenceClassification

# Free up some RAM
del train
del model
del classifier

# Begin fine-tuning: one model per AFLite seed, each trained on the dataset
# that survived filtering for that seed.
for seed in AFLite_seeds:

    # instantiate new GPT-2 based model with a 3-way classification head
    model = GPT2ForSequenceClassification.from_pretrained(
        "gpt2",
        num_labels=3,
        problem_type="multi_label_classification",
    )
    model.config.pad_token_id = model.config.eos_token_id # specify pad_token used by tokenizer

    # set up data loader
    dataloader = torch.utils.data.DataLoader(filtered_datasets[seed], batch_size=batch_size,
                                             shuffle=True, collate_fn=data_collator)

    # set up optimizer (all parameters are trained here, nothing is frozen)
    optimizer = torch.optim.Adam(model.parameters(), lr = lr)

    # fine-tune the model and save whatever `train_classifier` returns
    fine_tuned = train_classifier(model, dataloader, optimizer, device)
    torch.save(fine_tuned, f'AFLite_fine_tuned_model_seed{seed}.pth')