# Evaluate Performance of Baseline Models
__Model performance will be evaluated on:__
1. In-Distribution sample (SNLI test split), in a zero-shot setting (see [Part 1](https://github.com/shashiniyer/adversarial_nli_gpt2/blob/main/gpt2-small/notebooks_and_scripts/Evaluation_Baselines_Part1.ipynb))
2. The following Out-of-Distribution samples:
    - HANS dataset (validation split), in a zero-shot setting (see [Part 1](https://github.com/shashiniyer/adversarial_nli_gpt2/blob/main/gpt2-small/notebooks_and_scripts/Evaluation_Baselines_Part1.ipynb))
    - NLI Diagnostics dataset, in a zero-shot setting (see [Part 1](https://github.com/shashiniyer/adversarial_nli_gpt2/blob/main/gpt2-small/notebooks_and_scripts/Evaluation_Baselines_Part1.ipynb))
    - Stress Test datasets, in a zero-shot setting (see [Part 1](https://github.com/shashiniyer/adversarial_nli_gpt2/blob/main/gpt2-small/notebooks_and_scripts/Evaluation_Baselines_Part1.ipynb))
    - ANLI datasets (test splits), after fine-tuning each baseline model separately on each round's training data (_this notebook_)

__Performance indicators:__ Classification accuracy and $R_K$
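$R_K$ is not defined in this notebook. Assuming it denotes the $K$-category correlation coefficient of Gorodkin (2004), the multiclass generalisation of the Matthews correlation coefficient, it can be computed from a confusion matrix. With $s$ examples, $c$ correct predictions, and $t_k$ / $p_k$ the true / predicted counts of class $k$:

$$R_K = \frac{c\,s - \sum_k p_k t_k}{\sqrt{\left(s^2 - \sum_k p_k^2\right)\left(s^2 - \sum_k t_k^2\right)}}$$

The plain-Python sketch below is illustrative only: `accuracy_and_rk` is a hypothetical name and is not the `evaluate_acc_rk` helper from `utils_`, whose source is not shown here.

```python
from math import sqrt


def accuracy_and_rk(y_true, y_pred, num_classes=3):
    """Accuracy and Gorodkin's R_K for integer class labels (illustrative sketch)."""
    # confusion[k][l] = number of examples with true class k predicted as class l
    confusion = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        confusion[t][p] += 1
    s = len(y_true)                                          # total examples
    c = sum(confusion[k][k] for k in range(num_classes))     # correct predictions
    t_k = [sum(confusion[k]) for k in range(num_classes)]    # true count per class
    p_k = [sum(confusion[l][k] for l in range(num_classes))  # predicted count per class
           for k in range(num_classes)]
    numerator = c * s - sum(p * t for p, t in zip(p_k, t_k))
    denominator = sqrt((s * s - sum(p * p for p in p_k)) * (s * s - sum(t * t for t in t_k)))
    # R_K is conventionally set to 0 when the denominator vanishes (e.g. constant predictions)
    rk = numerator / denominator if denominator > 0 else 0.0
    return c / s, rk


perfect = accuracy_and_rk([0, 1, 2, 0, 1, 2], [0, 1, 2, 0, 1, 2])
constant = accuracy_and_rk([0, 1, 2, 0, 1, 2], [0, 0, 0, 0, 0, 0])
```

Here `perfect` is `(1.0, 1.0)`, while `constant` scores 2/6 accuracy on the balanced toy labels yet has $R_K = 0$, which illustrates why reporting $R_K$ alongside accuracy is informative for three-way labels.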


## 1. Imports and Global Settings

In [1]:
from datasets import load_dataset, disable_caching
from transformers import GPT2TokenizerFast, DataCollatorWithPadding, set_seed
import torch
from torch.nn.functional import one_hot
from utils_ import tokenize, train_classifier, evaluate_acc_rk
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
set_seed(42)
#disable_caching()

In [2]:
# set up tokeniser
# padding to left because GPT2 uses last token for prediction
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", padding_side = 'left', \
                                              padding = True, truncation = True)
tokenizer.pad_token = tokenizer.eos_token # pad with 'eos' token

In [3]:
# set up data collator - https://huggingface.co/docs/transformers/main_classes/data_collator
# this is a callable that pads a list of examples to max_length and stacks them into batch tensors
data_collator = DataCollatorWithPadding(tokenizer, padding = 'max_length',
                                        return_tensors = 'pt', max_length = 128)
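The tokeniser cell pads on the left because GPT-2 predicts from the last token. A minimal pure-Python sketch of what left padding to a fixed `max_length` does (the helper `left_pad` is hypothetical and is not how `DataCollatorWithPadding` is implemented): after padding, the final position of every row holds a real token, so reading the output at index -1 is valid for the whole batch. A sequence longer than `max_length` cannot be padded to it, which is consistent with the notebook excluding pairs over 128 tokens.

```python
GPT2_EOS_ID = 50256  # GPT-2's <|endoftext|> id, which this notebook also uses as the pad token


def left_pad(sequences, pad_id, max_length=128):
    """Left-pad lists of token ids to max_length; return ids and attention masks."""
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = max_length - len(seq)
        if n_pad < 0:
            # padding alone cannot shorten a sequence
            raise ValueError(f'sequence of length {len(seq)} exceeds max_length={max_length}')
        input_ids.append([pad_id] * n_pad + list(seq))
        attention_mask.append([0] * n_pad + [1] * len(seq))
    return input_ids, attention_mask


ids, mask = left_pad([[10, 11, 12], [20, 21]], GPT2_EOS_ID, max_length=5)
# ids  == [[50256, 50256, 10, 11, 12], [50256, 50256, 50256, 20, 21]]
# mask == [[0, 0, 1, 1, 1], [0, 0, 0, 1, 1]]
last_tokens = [row[-1] for row in ids]  # [12, 21]: a real token in every row
```

With right padding the last column would instead be `[50256, 50256]` for this batch, i.e. pad tokens rather than content.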

## 2. Out-of-Distribution Evaluation - ANLI - Fine-Tuning
### 2.1. Data Read + Pre-Processing
- Get ANLI datasets
- One-hot encode labels
- Tokenise data (premise and hypothesis joined with `|`)
- Exclude sequences longer than 128 tokens
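The `tokenize` helper imported from `utils_` is not shown in this notebook, but `anli_data` relies on it returning an `exclude` flag alongside the token ids. A minimal sketch of such a helper, assuming it flags sequences longer than the collator's `max_length` of 128; the function body, and the toy `whitespace_tokenizer` used to exercise it, are hypothetical.

```python
def tokenize(tokenizer, text, max_length=128):
    """Hypothetical stand-in for utils_.tokenize, whose source is not shown here."""
    # assumption: calling the tokenizer on one string returns a mapping with
    # 'input_ids' and 'attention_mask', as Hugging Face tokenizers do
    encoded = tokenizer(text)
    return {
        'input_ids': encoded['input_ids'],
        'attention_mask': encoded['attention_mask'],
        # flag pairs that will not fit the collator's max_length
        'exclude': len(encoded['input_ids']) > max_length,
    }


def whitespace_tokenizer(text):
    """Toy tokenizer for illustration: one id per whitespace-separated word."""
    ids = [len(word) for word in text.split()]
    return {'input_ids': ids, 'attention_mask': [1] * len(ids)}


# premise and hypothesis are joined with '|' before tokenisation, as in anli_data
example = tokenize(whitespace_tokenizer, 'A man plays guitar.' + '|' + 'A person makes music.')
long_example = tokenize(whitespace_tokenizer, 'word ' * 200)
```

`example` has 7 toy tokens and is kept; `long_example` has 200 and would be filtered out.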

In [4]:
# function to read in data and pre-process
def anli_data(split, tokenizer):

    # read in data
    data = load_dataset('anli', split = split)
    
    # one-hot encode labels
    data = data.map(lambda x: {'label': one_hot(torch.tensor(x['label']), 3).type(torch.float32).numpy()}, \
        batched = True)
    
    # tokenize data
    data = data.map(lambda x: tokenize(tokenizer, x['premise'] + '|' + x['hypothesis']))
    len_bef_exclusion = len(data)

    # exclude instances flagged by tokenize() as longer than 128 tokens
    data = data.filter(lambda x: not x['exclude'])
    len_aft_exclusion = len(data)

    # print message if instances were in fact excluded
    # (note: the percentage is relative to the number of *retained* sequences)
    if len_bef_exclusion - len_aft_exclusion > 0:

        print(f'Split: {split} - {len_bef_exclusion - len_aft_exclusion} ' +
              f'({(len_bef_exclusion/len_aft_exclusion - 1)*100:>2f}%) sequences excluded')

    # keep only needed columns, set data format to PyTorch
    data.set_format(type = 'torch', columns = ['label', 'input_ids', 'attention_mask'])

    # return the processed dataset
    return data
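The exclusion message in `anli_data` computes `len_bef_exclusion/len_aft_exclusion - 1`, i.e. the number excluded as a percentage of the sequences that were *kept*, not of the original split. For small rates the two are close but not equal: the logged `test_r1` figure of 2.354145% is 23 out of 977 retained pairs, whereas 23 out of the 1,000 original pairs is 2.3%. The hypothetical helper `exclusion_stats` below contrasts the two, using the `train_r1` counts from the output.

```python
def exclusion_stats(n_before, n_after):
    """Excluded count plus two ways of expressing it as a percentage."""
    n_excluded = n_before - n_after
    pct_of_retained = (n_before / n_after - 1) * 100  # the formula anli_data prints
    pct_of_split = n_excluded / n_before * 100         # share of the original split
    return n_excluded, pct_of_retained, pct_of_split


# train_r1: 16,500 pairs retained + 446 excluded = 16,946 before filtering
n_excluded, pct_retained, pct_split = exclusion_stats(16946, 16500)
print(f'{n_excluded} excluded: {pct_retained:.6f}% of retained, {pct_split:.6f}% of split')
```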

### 2.2. Fine-tune both Baseline models on each ANLI round's training data, then evaluate on that round's test split

In [5]:
# set up dictionary of the models/file_names
models = {'Unfiltered': 'baseline_unfiltered.pth', 'Random 190k Subset': 'baseline_random_190k.pth'}

In [6]:
for model_name, file_name in models.items(): # two baseline models
    
    for rd in ['r1', 'r2', 'r3']: # rounds
        
        # fine-tune
        tr = anli_data('train_' + rd, tokenizer)
        tr_model = torch.load(file_name)
        tr_dataloader = torch.utils.data.DataLoader(tr, batch_size=64, shuffle=True, collate_fn=data_collator)
        optimizer = torch.optim.Adam(tr_model.parameters(), lr = 1e-5)
        trained_classifier = train_classifier(tr_model, tr_dataloader, optimizer, device)
        del tr
        
        # evaluate
        te = anli_data('test_' + rd, tokenizer)
        te_dataloader = torch.utils.data.DataLoader(te, batch_size=64, collate_fn=data_collator)
        acc, rk = evaluate_acc_rk(trained_classifier, te_dataloader, device)
        del te
        del tr_model
        print(f'Round: {rd} - Model: {model_name} - Accuracy: {acc*100:>3f}%, RK: {rk:>3f}')       

Reusing dataset anli (/home/shana92/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b)
Loading cached processed dataset at /home/shana92/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b/cache-1c80317fa3b1799d.arrow
Loading cached processed dataset at /home/shana92/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b/cache-bdd640fb06671ad1.arrow
Loading cached processed dataset at /home/shana92/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b/cache-9a90de1e791ac553.arrow


Split: train_r1 - 446 (2.703030%) sequences excluded


loss: 0.662190  [    0/16500]
loss: 0.684524  [ 1600/16500]
loss: 0.627868  [ 3200/16500]
loss: 0.631673  [ 4800/16500]
loss: 0.619842  [ 6400/16500]
loss: 0.622510  [ 8000/16500]
loss: 0.589818  [ 9600/16500]
loss: 0.639176  [11200/16500]
loss: 0.620326  [12800/16500]
loss: 0.632369  [14400/16500]
loss: 0.650170  [16000/16500]
Epoch average loss: 0.629946231842041


loss: 0.626826  [    0/16500]
loss: 0.637085  [ 1600/16500]
loss: 0.614172  [ 3200/16500]
loss: 0.588634  [ 4800/16500]
loss: 0.615306  [ 6400/16500]
loss: 0.623370  [ 8000/16500]
loss: 0.623684  [ 9600/16500]
loss: 0.600798  [11200/16500]
loss: 0.625641  [12800/16500]
loss: 0.628761  [14400/16500]
loss: 0.568816  [16000/16500]
Epoch average loss: 0.6183813810348511


loss: 0.604350  [    0/16500]
loss: 0.601835  [ 1600/16500]
loss: 0.603542  [ 3200/16500]
loss: 0.612788  [ 4800/16500]
loss: 0.580942  [ 6400/16500]
loss: 0.586551  [ 8000/16500]
loss: 0.592043  [ 9600/16500]
loss: 0.590704  [11200/16500]
loss: 0.625573  [12800/16500]
loss: 0.607084  [14400/16500]
loss: 0.575626  [16000/16500]
Epoch average loss: 0.6013104319572449
Done!


Reusing dataset anli (/home/shana92/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b)
Loading cached processed dataset at /home/shana92/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b/cache-3eb13b9046685257.arrow
Loading cached processed dataset at /home/shana92/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b/cache-23b8c1e9392456de.arrow
Loading cached processed dataset at /home/shana92/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b/cache-7edfd81a366fc563.arrow


Split: test_r1 - 23 (2.354145%) sequences excluded


Round: r1 - Model: Unfiltered - Accuracy: 32.139203%, RK: -0.026607


Reusing dataset anli (/home/shana92/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b)
Loading cached processed dataset at /home/shana92/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b/cache-1a3d1fa7bc8960a9.arrow
Loading cached processed dataset at /home/shana92/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b/cache-bd9c66b3ad3c2d6d.arrow


Split: train_r2 - 944 (2.120586%) sequences excluded


loss: 0.661627  [    0/44516]
loss: 0.639557  [ 4416/44516]
loss: 0.607458  [ 8832/44516]
loss: 0.566478  [13248/44516]
loss: 0.598534  [17664/44516]
loss: 0.622622  [22080/44516]
loss: 0.621659  [26496/44516]
loss: 0.568870  [30912/44516]
loss: 0.628582  [35328/44516]
loss: 0.621225  [39744/44516]
loss: 0.562283  [44160/44516]
Epoch average loss: 0.6068109273910522


loss: 0.617732  [    0/44516]
loss: 0.565455  [ 4416/44516]
loss: 0.622436  [ 8832/44516]
loss: 0.557771  [13248/44516]
loss: 0.593824  [17664/44516]
loss: 0.683575  [22080/44516]
loss: 0.525310  [26496/44516]
loss: 0.525068  [30912/44516]
loss: 0.564386  [35328/44516]
loss: 0.536801  [39744/44516]
loss: 0.576559  [44160/44516]
Epoch average loss: 0.5746095180511475


loss: 0.595745  [    0/44516]
loss: 0.561534  [ 4416/44516]
loss: 0.536243  [ 8832/44516]
loss: 0.586561  [13248/44516]
loss: 0.552988  [17664/44516]
loss: 0.468478  [22080/44516]
loss: 0.522253  [26496/44516]
loss: 0.528718  [30912/44516]
loss: 0.570660  [35328/44516]
loss: 0.481175  [39744/44516]
loss: 0.555486  [44160/44516]
Epoch average loss: 0.5410431623458862
Done!


Reusing dataset anli (/home/shana92/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b)
Loading cached processed dataset at /home/shana92/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b/cache-8b9d2434e465e150.arrow
Loading cached processed dataset at /home/shana92/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b/cache-972a846916419f82.arrow


Split: test_r2 - 18 (1.832994%) sequences excluded


Round: r2 - Model: Unfiltered - Accuracy: 32.281059%, RK: -0.017092


Reusing dataset anli (/home/shana92/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b)
Loading cached processed dataset at /home/shana92/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b/cache-0822e8f36c031199.arrow
Loading cached processed dataset at /home/shana92/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b/cache-17fc695a07a0ca6e.arrow


Split: train_r3 - 7468 (8.030885%) sequences excluded


loss: 0.649782  [    0/92991]
loss: 0.629987  [ 9280/92991]
loss: 0.622221  [18560/92991]
loss: 0.555973  [27840/92991]
loss: 0.536686  [37120/92991]
loss: 0.569254  [46400/92991]
loss: 0.604832  [55680/92991]
loss: 0.654975  [64960/92991]
loss: 0.556073  [74240/92991]
loss: 0.564328  [83520/92991]
loss: 0.563100  [92800/92991]
Epoch average loss: 0.5865782499313354


loss: 0.568452  [    0/92991]
loss: 0.580404  [ 9280/92991]
loss: 0.562465  [18560/92991]
loss: 0.577402  [27840/92991]
loss: 0.538368  [37120/92991]
loss: 0.545005  [46400/92991]
loss: 0.649020  [55680/92991]
loss: 0.619443  [64960/92991]
loss: 0.530617  [74240/92991]
loss: 0.541208  [83520/92991]
loss: 0.527690  [92800/92991]
Epoch average loss: 0.5556570291519165


loss: 0.502693  [    0/92991]
loss: 0.563620  [ 9280/92991]
loss: 0.554028  [18560/92991]
loss: 0.487002  [27840/92991]
loss: 0.524226  [37120/92991]
loss: 0.511838  [46400/92991]
loss: 0.462878  [55680/92991]
loss: 0.546365  [64960/92991]
loss: 0.567102  [74240/92991]
loss: 0.504202  [83520/92991]
loss: 0.564119  [92800/92991]
Epoch average loss: 0.5241260528564453
Done!


Reusing dataset anli (/home/shana92/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b)
Loading cached processed dataset at /home/shana92/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b/cache-3b8faa1837f8a88b.arrow
Loading cached processed dataset at /home/shana92/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b/cache-9a1de644815ef6d1.arrow


Split: test_r3 - 90 (8.108108%) sequences excluded


Round: r3 - Model: Unfiltered - Accuracy: 34.954956%, RK: 0.022752


Reusing dataset anli (/home/shana92/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b)


Split: train_r1 - 446 (2.703030%) sequences excluded


loss: 0.642953  [    0/16500]
loss: 0.640017  [ 1600/16500]
loss: 0.620945  [ 3200/16500]
loss: 0.632540  [ 4800/16500]
loss: 0.658361  [ 6400/16500]
loss: 0.613115  [ 8000/16500]
loss: 0.620597  [ 9600/16500]
loss: 0.641272  [11200/16500]
loss: 0.631798  [12800/16500]
loss: 0.609118  [14400/16500]
loss: 0.614423  [16000/16500]
Epoch average loss: 0.6336648464202881


loss: 0.632271  [    0/16500]
loss: 0.604586  [ 1600/16500]
loss: 0.633499  [ 3200/16500]
loss: 0.629539  [ 4800/16500]
loss: 0.628619  [ 6400/16500]
loss: 0.623759  [ 8000/16500]
loss: 0.624003  [ 9600/16500]
loss: 0.604767  [11200/16500]
loss: 0.617369  [12800/16500]
loss: 0.627418  [14400/16500]
loss: 0.596394  [16000/16500]
Epoch average loss: 0.623264729976654


loss: 0.607398  [    0/16500]
loss: 0.606147  [ 1600/16500]
loss: 0.602473  [ 3200/16500]
loss: 0.625929  [ 4800/16500]
loss: 0.592740  [ 6400/16500]
loss: 0.628374  [ 8000/16500]
loss: 0.610486  [ 9600/16500]
loss: 0.604055  [11200/16500]
loss: 0.601393  [12800/16500]
loss: 0.618940  [14400/16500]
loss: 0.612677  [16000/16500]
Epoch average loss: 0.6106846928596497
Done!


Reusing dataset anli (/home/shana92/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b)


Split: test_r1 - 23 (2.354145%) sequences excluded


Round: r1 - Model: Random 190k Subset - Accuracy: 31.320369%, RK: -0.037414


Reusing dataset anli (/home/shana92/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b)


Split: train_r2 - 944 (2.120586%) sequences excluded


loss: 0.617604  [    0/44516]
loss: 0.571803  [ 4416/44516]
loss: 0.627500  [ 8832/44516]
loss: 0.598281  [13248/44516]
loss: 0.607201  [17664/44516]
loss: 0.646916  [22080/44516]
loss: 0.626570  [26496/44516]
loss: 0.651115  [30912/44516]
loss: 0.639319  [35328/44516]
loss: 0.585977  [39744/44516]
loss: 0.578139  [44160/44516]
Epoch average loss: 0.6121343970298767


loss: 0.577696  [    0/44516]
loss: 0.598935  [ 4416/44516]
loss: 0.606488  [ 8832/44516]
loss: 0.594637  [13248/44516]
loss: 0.599124  [17664/44516]
loss: 0.579397  [22080/44516]
loss: 0.573437  [26496/44516]
loss: 0.628968  [30912/44516]
loss: 0.574040  [35328/44516]
loss: 0.581816  [39744/44516]
loss: 0.532921  [44160/44516]
Epoch average loss: 0.5869357585906982


loss: 0.621207  [    0/44516]
loss: 0.516411  [ 4416/44516]
loss: 0.537725  [ 8832/44516]
loss: 0.578803  [13248/44516]
loss: 0.537030  [17664/44516]
loss: 0.614550  [22080/44516]
loss: 0.554960  [26496/44516]
loss: 0.561645  [30912/44516]
loss: 0.592638  [35328/44516]
loss: 0.500187  [39744/44516]
loss: 0.537129  [44160/44516]
Epoch average loss: 0.5558214783668518
Done!


Reusing dataset anli (/home/shana92/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b)


Split: test_r2 - 18 (1.832994%) sequences excluded


Round: r2 - Model: Random 190k Subset - Accuracy: 33.299389%, RK: -0.001673


Reusing dataset anli (/home/shana92/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b)


Split: train_r3 - 7468 (8.030885%) sequences excluded


loss: 0.698562  [    0/92991]
loss: 0.635315  [ 9280/92991]
loss: 0.606245  [18560/92991]
loss: 0.567960  [27840/92991]
loss: 0.643162  [37120/92991]
loss: 0.600711  [46400/92991]
loss: 0.604694  [55680/92991]
loss: 0.631606  [64960/92991]
loss: 0.539954  [74240/92991]
loss: 0.652034  [83520/92991]
loss: 0.658233  [92800/92991]
Epoch average loss: 0.6017566323280334


loss: 0.619817  [    0/92991]
loss: 0.595634  [ 9280/92991]
loss: 0.562059  [18560/92991]
loss: 0.578261  [27840/92991]
loss: 0.568227  [37120/92991]
loss: 0.533602  [46400/92991]
loss: 0.578181  [55680/92991]
loss: 0.527934  [64960/92991]
loss: 0.542149  [74240/92991]
loss: 0.519473  [83520/92991]
loss: 0.545630  [92800/92991]
Epoch average loss: 0.5744589567184448


loss: 0.552824  [    0/92991]
loss: 0.582559  [ 9280/92991]
loss: 0.597908  [18560/92991]
loss: 0.549712  [27840/92991]
loss: 0.513818  [37120/92991]
loss: 0.490071  [46400/92991]
loss: 0.597688  [55680/92991]
loss: 0.553266  [64960/92991]
loss: 0.548793  [74240/92991]
loss: 0.598288  [83520/92991]
loss: 0.527880  [92800/92991]
Epoch average loss: 0.5501933097839355
Done!


Reusing dataset anli (/home/shana92/.cache/huggingface/datasets/anli/plain_text/0.1.0/aabce88453b06dff21c201855ea83283bab0390bff746deadb30b65695755c0b)


Split: test_r3 - 90 (8.108108%) sequences excluded


Round: r3 - Model: Random 190k Subset - Accuracy: 34.504506%, RK: 0.015251
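The fine-tuning loop calls `train_classifier` from `utils_`, whose source is not part of this notebook. The logs show three epochs per round, each with about ten progress lines reporting the number of examples seen so far, followed by an epoch average. The sketch below is a hypothetical single epoch written to be consistent with that output; `train_one_epoch`, the `(input_ids, attention_mask) -> logits` call signature, and the per-class `BCEWithLogitsLoss` on the one-hot float labels are all assumptions. One reason to suspect a per-class binary loss: binary cross-entropy of a uniform 1/3 prediction against a one-hot target averages about 0.637, close to the logged values, whereas 3-way cross-entropy at chance would be about 1.10.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader


def train_one_epoch(model, dataloader, optimizer, device, loss_fn=None):
    """Hypothetical sketch of one fine-tuning epoch (not the utils_ code)."""
    # assumption: one-hot float labels trained with a per-class binary loss
    loss_fn = loss_fn if loss_fn is not None else nn.BCEWithLogitsLoss()
    model.to(device)
    model.train()
    n_examples = len(dataloader.dataset)
    log_every = max(len(dataloader) // 10, 1)  # roughly ten progress lines per epoch
    total_loss = 0.0
    for batch_idx, batch in enumerate(dataloader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        # assumption: the classifier maps (input_ids, attention_mask) to 3 logits
        logits = model(input_ids, attention_mask)
        loss = loss_fn(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if batch_idx % log_every == 0:
            seen = batch_idx * dataloader.batch_size
            print(f'loss: {loss.item():>7f}  [{seen:>5d}/{n_examples:>5d}]')
    avg_loss = total_loss / len(dataloader)
    print(f'Epoch average loss: {avg_loss}')
    return avg_loss


class ToyClassifier(nn.Module):
    """Tiny stand-in for the GPT-2 classifier, for illustration only."""

    def __init__(self, vocab_size=50, n_classes=3):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, 8)
        self.head = nn.Linear(8, n_classes)

    def forward(self, input_ids, attention_mask):
        # classify from the last position, which left padding keeps non-pad
        return self.head(self.embed(input_ids[:, -1]))


torch.manual_seed(0)
toy_data = [
    {'input_ids': torch.randint(0, 50, (6,)),
     'attention_mask': torch.ones(6, dtype=torch.long),
     'label': nn.functional.one_hot(torch.tensor(i % 3), 3).float()}
    for i in range(20)
]
loader = DataLoader(toy_data, batch_size=4, shuffle=True)
model = ToyClassifier()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
avg_loss = train_one_epoch(model, loader, optimizer, torch.device('cpu'))
```

With `len(dataloader) // 10` as the logging interval, 258 batches of 64 give progress lines at 0, 1600, ..., 16000 examples, matching the `train_r1` log above.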
