# Objective: Fine-tune GPT-2 on the AFLite-filtered SNLI training data

### 1. Imports and Global Settings

In [1]:
from datasets import load_dataset, disable_caching
from transformers import GPT2TokenizerFast, DataCollatorWithPadding, GPT2ForSequenceClassification, set_seed
import torch
from torch.nn.functional import one_hot
from utils_ import tokenize, train_classifier
import pickle
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fix the random seed and switch off the `datasets` cache for this session
set_seed(42)
disable_caching()

In [2]:
# Tokeniser set-up.
# Padding goes on the left because GPT2 uses the last token for prediction.
# NOTE(review): `padding` and `truncation` are usually arguments of the
# tokenizer *call*, not of `from_pretrained` - confirm they have any effect here.
tokenizer = GPT2TokenizerFast.from_pretrained(
    "gpt2",
    padding_side='left',
    padding=True,
    truncation=True,
)

# Use the 'eos' token as the padding token
tokenizer.pad_token = tokenizer.eos_token

In [3]:
# Data collator - https://huggingface.co/docs/transformers/main_classes/data_collator
# A callable helper that turns lists of examples into the batches fed to the
# model; here every batch is padded to a fixed length of 128 and returned as
# PyTorch tensors.
data_collator = DataCollatorWithPadding(
    tokenizer,
    padding='max_length',
    max_length=128,
    return_tensors='pt',
)

In [4]:
# Load the indices that AFLite removed, one entry per seed.
# NOTE(review): unpickling can execute arbitrary code - only load a
# 'removed_idx.pkl' that was produced by a trusted run.
with open('removed_idx.pkl', mode='rb') as pkl_file:
    removed = pickle.load(pkl_file)

### 2. Pre-processing Routine for SNLI
- Get SNLI Dataset (Train fold) and shuffle it using the same seed as used for obtaining GPT-2 based Feature Representation (see notebook [Filtering_Part1.ipynb](https://github.com/shashiniyer/adversarial_nli_gpt2/blob/main/gpt2-small/notebooks_and_scripts/Filtering_Part1.ipynb))
- Remove instances without gold standard labels, i.e., label = -1
- One-hot encoding for labels
- Partition data 10%/90%; use the 90% as unfiltered `train`
- Filter out the instances that the AFLite run removed for the given seed (see notebook [Filtering_Part2.ipynb](https://github.com/shashiniyer/adversarial_nli_gpt2/blob/main/gpt2-small/notebooks_and_scripts/Filtering_Part2.ipynb))
- Tokenise each `premise|hypothesis` pair and exclude instances with more than 128 tokens

In [5]:
def preprocess_snli(seed):
    """Build the AFLite-filtered, tokenised SNLI training set for one AFLite seed.

    Pipeline:
      1. Load the SNLI train fold and shuffle it with the fixed seed 42.
      2. Drop instances whose label is -1 and one-hot encode the remaining
         labels over 3 classes (float32).
      3. Keep everything after the first 10% as the unfiltered training pool.
      4. Drop the positions listed in the module-level `removed[seed]`.
      5. Tokenise "premise|hypothesis" with the module-level `tokenizer`.
      6. Drop instances whose `exclude` field is True and report how many
         were dropped.

    Args:
        seed: index/key into the module-level `removed` object; selects which
            AFLite run's removed indices are applied. It does NOT affect the
            shuffle, which always uses seed 42.

    Returns:
        The filtered dataset, formatted as PyTorch tensors and restricted to
        the columns 'label', 'input_ids' and 'attention_mask'.
    """

    # set up `train` dataset (shuffle seed is fixed so that row positions line
    # up with the indices stored in `removed`)
    snli_train = load_dataset('snli', split = 'train').shuffle(seed = 42)
    snli_train = snli_train.filter(lambda x: x['label'] != -1).map(
        lambda x: {'label': one_hot(torch.tensor(x['label']), 3).type(torch.float32).numpy()},
        batched = True)
    train = snli_train.select(range(len(snli_train) // 10, len(snli_train)))

    # filter out what AFLite run removed
    # NOTE(review): the first comma-separated field is skipped; presumably it
    # is not an index - confirm against the notebook that wrote `removed`.
    removed_idx = {int(x) for x in removed[seed].split(',')[1:]}
    # pass an ascending list (not an unordered set) so row order is deterministic
    kept_idx = [i for i in range(len(train)) if i not in removed_idx]
    train = train.select(kept_idx)

    # tokenize data
    train = train.map(lambda x: tokenize(tokenizer, x['premise'] + '|' + x['hypothesis']))
    len_bef_exclusion = len(train)

    # exclude instances with > 128 tokens
    # (assumes `tokenize` sets the `exclude` field accordingly - TODO confirm)
    train = train.filter(lambda x: x['exclude'] == False)
    len_aft_exclusion = len(train)

    # print message if instances were in fact excluded
    n_excluded = len_bef_exclusion - len_aft_exclusion
    if n_excluded > 0:
        # percentage of the pre-exclusion rows; dividing by the post-exclusion
        # count would overstate the share and fail when nothing is left
        pct_excluded = 100 * n_excluded / len_bef_exclusion
        print(f'{n_excluded} ({pct_excluded:.2f}%) sequences excluded')

    # keep only needed columns, set data format to PyTorch
    train.set_format(type = 'torch', columns = ['label', 'input_ids', 'attention_mask'])

    return train

### 3. Run fine-tuning step and save resulting models

In [6]:
# Training hyper-parameters
batch_size = 64  # constrained by GPU memory
lr = 1e-5  # set to match Le et al. (2020) - https://arxiv.org/abs/2002.04108

# AFLite seeds, taken from Filtering_Part2.ipynb
AFLite_seeds = [0, 1, 2, 3, 4]

# One independent fine-tuning run per AFLite seed
for seed in AFLite_seeds:

    print(f'Seed {seed} begin')

    # build the filtered training set for this seed
    data = preprocess_snli(seed)
    print('Filtered Data Read In')

    # start every run from a fresh pre-trained GPT-2 with num_labels=3
    model = GPT2ForSequenceClassification.from_pretrained(
        "gpt2",
        num_labels=3,
        problem_type="multi_label_classification",
    )
    # tell the model which pad token the tokenizer uses (the eos token)
    model.config.pad_token_id = model.config.eos_token_id

    # shuffled mini-batches, assembled by the shared data collator
    dataloader = torch.utils.data.DataLoader(
        data,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=data_collator,
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # fine-tune, then save whatever `train_classifier` returns
    trained = train_classifier(model, dataloader, optimizer, device)
    torch.save(trained, f'AFLite_fine_tuned_model_seed{seed}.pth')

    print(f'Seed {seed} complete')

Seed 0 begin


Reusing dataset snli (/home/shana92/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)


  0%|          | 0/551 [00:00<?, ?ba/s]

  0%|          | 0/550 [00:00<?, ?ba/s]

  0%|          | 0/262396 [00:00<?, ?ex/s]

  0%|          | 0/263 [00:00<?, ?ba/s]

Filtered Data Read In


Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/4100 [00:00<?, ?it/s]

loss: 1.757688  [    0/262396]
loss: 0.639000  [26240/262396]
loss: 0.604016  [52480/262396]
loss: 0.598526  [78720/262396]
loss: 0.604191  [104960/262396]
loss: 0.579932  [131200/262396]
loss: 0.542527  [157440/262396]
loss: 0.543326  [183680/262396]
loss: 0.473438  [209920/262396]
loss: 0.464315  [236160/262396]
Epoch average loss: 0.5419945120811462


  0%|          | 0/4100 [00:00<?, ?it/s]

loss: 0.427399  [    0/262396]
loss: 0.435798  [26240/262396]
loss: 0.367921  [52480/262396]
loss: 0.435413  [78720/262396]
loss: 0.441651  [104960/262396]
loss: 0.429242  [131200/262396]
loss: 0.329641  [157440/262396]
loss: 0.403802  [183680/262396]
loss: 0.369968  [209920/262396]
loss: 0.447986  [236160/262396]
Epoch average loss: 0.39370766282081604


  0%|          | 0/4100 [00:00<?, ?it/s]

loss: 0.384310  [    0/262396]
loss: 0.349832  [26240/262396]
loss: 0.259876  [52480/262396]
loss: 0.349334  [78720/262396]
loss: 0.342495  [104960/262396]
loss: 0.305910  [131200/262396]
loss: 0.260810  [157440/262396]
loss: 0.334609  [183680/262396]
loss: 0.311066  [209920/262396]
loss: 0.316424  [236160/262396]
Epoch average loss: 0.341413289308548
Done!
Seed 0 complete
Seed 1 begin


Reusing dataset snli (/home/shana92/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)


  0%|          | 0/551 [00:00<?, ?ba/s]

  0%|          | 0/550 [00:00<?, ?ba/s]

  0%|          | 0/262079 [00:00<?, ?ex/s]

  0%|          | 0/263 [00:00<?, ?ba/s]

Filtered Data Read In


Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/4095 [00:00<?, ?it/s]

loss: 0.939102  [    0/262079]
loss: 0.650807  [26176/262079]
loss: 0.637025  [52352/262079]
loss: 0.572087  [78528/262079]
loss: 0.463744  [104704/262079]
loss: 0.473584  [130880/262079]
loss: 0.484078  [157056/262079]
loss: 0.526533  [183232/262079]
loss: 0.444780  [209408/262079]
loss: 0.527884  [235584/262079]
loss: 0.406123  [261760/262079]
Epoch average loss: 0.5199455618858337


  0%|          | 0/4095 [00:00<?, ?it/s]

loss: 0.499139  [    0/262079]
loss: 0.459934  [26176/262079]
loss: 0.376811  [52352/262079]
loss: 0.307248  [78528/262079]
loss: 0.328775  [104704/262079]
loss: 0.423708  [130880/262079]
loss: 0.375343  [157056/262079]
loss: 0.366637  [183232/262079]
loss: 0.400426  [209408/262079]
loss: 0.312537  [235584/262079]
loss: 0.321643  [261760/262079]
Epoch average loss: 0.37954846024513245


  0%|          | 0/4095 [00:00<?, ?it/s]

loss: 0.400075  [    0/262079]
loss: 0.369990  [26176/262079]
loss: 0.393274  [52352/262079]
loss: 0.304584  [78528/262079]
loss: 0.413210  [104704/262079]
loss: 0.334329  [130880/262079]
loss: 0.273709  [157056/262079]
loss: 0.366422  [183232/262079]
loss: 0.397555  [209408/262079]
loss: 0.340288  [235584/262079]
loss: 0.259924  [261760/262079]
Epoch average loss: 0.3353414237499237
Done!
Seed 1 complete
Seed 2 begin


Reusing dataset snli (/home/shana92/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)


  0%|          | 0/551 [00:00<?, ?ba/s]

  0%|          | 0/550 [00:00<?, ?ba/s]

  0%|          | 0/262330 [00:00<?, ?ex/s]

  0%|          | 0/263 [00:00<?, ?ba/s]

Filtered Data Read In


Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/4099 [00:00<?, ?it/s]

loss: 1.861917  [    0/262330]
loss: 0.634136  [26176/262330]
loss: 0.630388  [52352/262330]
loss: 0.548637  [78528/262330]
loss: 0.526299  [104704/262330]
loss: 0.506436  [130880/262330]
loss: 0.492384  [157056/262330]
loss: 0.491082  [183232/262330]
loss: 0.482625  [209408/262330]
loss: 0.463158  [235584/262330]
loss: 0.449278  [261760/262330]
Epoch average loss: 0.5308033227920532


  0%|          | 0/4099 [00:00<?, ?it/s]

loss: 0.323597  [    0/262330]
loss: 0.441712  [26176/262330]
loss: 0.379708  [52352/262330]
loss: 0.404740  [78528/262330]
loss: 0.389663  [104704/262330]
loss: 0.398999  [130880/262330]
loss: 0.351477  [157056/262330]
loss: 0.333535  [183232/262330]
loss: 0.309637  [209408/262330]
loss: 0.405708  [235584/262330]
loss: 0.371995  [261760/262330]
Epoch average loss: 0.38767144083976746


  0%|          | 0/4099 [00:00<?, ?it/s]

loss: 0.411243  [    0/262330]
loss: 0.455609  [26176/262330]
loss: 0.262094  [52352/262330]
loss: 0.353340  [78528/262330]
loss: 0.261723  [104704/262330]
loss: 0.256272  [130880/262330]
loss: 0.254492  [157056/262330]
loss: 0.303975  [183232/262330]
loss: 0.328739  [209408/262330]
loss: 0.284473  [235584/262330]
loss: 0.280484  [261760/262330]
Epoch average loss: 0.3402520418167114
Done!
Seed 2 complete
Seed 3 begin


Reusing dataset snli (/home/shana92/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)


  0%|          | 0/551 [00:00<?, ?ba/s]

  0%|          | 0/550 [00:00<?, ?ba/s]

  0%|          | 0/256809 [00:00<?, ?ex/s]

  0%|          | 0/257 [00:00<?, ?ba/s]

Filtered Data Read In


Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/4013 [00:00<?, ?it/s]

loss: 1.796898  [    0/256809]
loss: 0.626781  [25664/256809]
loss: 0.649242  [51328/256809]
loss: 0.561202  [76992/256809]
loss: 0.579713  [102656/256809]
loss: 0.528545  [128320/256809]
loss: 0.484664  [153984/256809]
loss: 0.541037  [179648/256809]
loss: 0.485742  [205312/256809]
loss: 0.453123  [230976/256809]
loss: 0.394164  [256640/256809]
Epoch average loss: 0.54237300157547


  0%|          | 0/4013 [00:00<?, ?it/s]

loss: 0.413500  [    0/256809]
loss: 0.423586  [25664/256809]
loss: 0.425314  [51328/256809]
loss: 0.379643  [76992/256809]
loss: 0.422790  [102656/256809]
loss: 0.446650  [128320/256809]
loss: 0.335778  [153984/256809]
loss: 0.360096  [179648/256809]
loss: 0.405407  [205312/256809]
loss: 0.263682  [230976/256809]
loss: 0.362599  [256640/256809]
Epoch average loss: 0.38935941457748413


  0%|          | 0/4013 [00:00<?, ?it/s]

loss: 0.334590  [    0/256809]
loss: 0.354069  [25664/256809]
loss: 0.356162  [51328/256809]
loss: 0.228001  [76992/256809]
loss: 0.287348  [102656/256809]
loss: 0.419076  [128320/256809]
loss: 0.260109  [153984/256809]
loss: 0.310465  [179648/256809]
loss: 0.277157  [205312/256809]
loss: 0.388664  [230976/256809]
loss: 0.246732  [256640/256809]
Epoch average loss: 0.33925527334213257
Done!
Seed 3 complete
Seed 4 begin


Reusing dataset snli (/home/shana92/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)


  0%|          | 0/551 [00:00<?, ?ba/s]

  0%|          | 0/550 [00:00<?, ?ba/s]

  0%|          | 0/265398 [00:00<?, ?ex/s]

  0%|          | 0/266 [00:00<?, ?ba/s]

Filtered Data Read In


Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/4147 [00:00<?, ?it/s]

loss: 1.528919  [    0/265398]
loss: 0.629002  [26496/265398]
loss: 0.644122  [52992/265398]
loss: 0.521323  [79488/265398]
loss: 0.615192  [105984/265398]
loss: 0.603087  [132480/265398]
loss: 0.637906  [158976/265398]
loss: 0.478359  [185472/265398]
loss: 0.548402  [211968/265398]
loss: 0.481871  [238464/265398]
loss: 0.448869  [264960/265398]
Epoch average loss: 0.5331658720970154


  0%|          | 0/4147 [00:00<?, ?it/s]

loss: 0.430480  [    0/265398]
loss: 0.381909  [26496/265398]
loss: 0.373069  [52992/265398]
loss: 0.370104  [79488/265398]
loss: 0.423201  [105984/265398]
loss: 0.445478  [132480/265398]
loss: 0.397075  [158976/265398]
loss: 0.329071  [185472/265398]
loss: 0.320523  [211968/265398]
loss: 0.317310  [238464/265398]
loss: 0.354632  [264960/265398]
Epoch average loss: 0.38367968797683716


  0%|          | 0/4147 [00:00<?, ?it/s]

loss: 0.354580  [    0/265398]
loss: 0.364377  [26496/265398]
loss: 0.302272  [52992/265398]
loss: 0.432563  [79488/265398]
loss: 0.349704  [105984/265398]
loss: 0.283945  [132480/265398]
loss: 0.290156  [158976/265398]
loss: 0.301987  [185472/265398]
loss: 0.365680  [211968/265398]
loss: 0.317792  [238464/265398]
loss: 0.377326  [264960/265398]
Epoch average loss: 0.3356095552444458
Done!
Seed 4 complete
