# Dual Objective:
- A) Use AFLite to greedily solve for $\text{arg min}_{S \subset \mathcal{D}, ~|S| \geq n}\mathcal{R}(\Phi, ~S, ~\mathcal{M})$
- B) Fine-tune GPT-2 with the resulting filtered dataset

### 1. Imports and Global Settings

In [1]:
from datasets import load_dataset, disable_caching
from tqdm.notebook import tqdm
from transformers import GPT2TokenizerFast, DataCollatorWithPadding, set_seed
import torch
from torch.nn.functional import one_hot
import copy
import numpy as np
from utils_ import tokenize, train_classifier, predict
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
set_seed(42)
disable_caching()

### 2. Pre-Processing
- Get SNLI Dataset (Train fold) and shuffle it using the same seed as used for obtaining GPT-2 based Feature Representation (see notebook [Filtering_Part1.ipynb](https://github.com/shashiniyer/adversarial_nli_gpt2/blob/main/Filtering_Part1.ipynb))
- Remove instances without gold standard labels, i.e., label = -1
- One-hot encoding for labels
- Partition data 10%/90%; use the 90% as `train`
- Tokenise train

In [2]:
snli_train = load_dataset('snli', split = 'train').shuffle(seed = 42)
snli_train = snli_train.filter(lambda x: x['label'] != -1).map( \
    lambda x: {'label': one_hot(torch.tensor(x['label']), 3).type(torch.float32).numpy()}, \
    batched = True)
train = snli_train.select(range(int(len(snli_train)/10), len(snli_train)))

Reusing dataset snli (/home/shana92/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)


  0%|          | 0/551 [00:00<?, ?ba/s]

  0%|          | 0/550 [00:00<?, ?ba/s]

In [3]:
# set up tokeniser
# padding to left because GPT2 uses last token for prediction
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", padding_side = 'left', \
                                              padding = True, truncation = True)
tokenizer.pad_token = tokenizer.eos_token # pad with 'eos' token

In [4]:
# tokenize data
train = train.map(lambda x: tokenize(tokenizer, x['premise'] + '|' + x['hypothesis']))
len_bef_exclusion = len(train)

# exclude instances with > 128 tokens
train = train.filter(lambda x: x['exclude'] == False)
len_aft_exclusion = len(train)

# print message if instances were in fact excluded
if len_bef_exclusion - len_aft_exclusion > 0:
    
    print(f'{len_bef_exclusion - len_aft_exclusion} ' + \
          f'({(en_bef_exclusion/len_aft_exclusion - 1):>2f}%) sequences excluded')

  0%|          | 0/494431 [00:00<?, ?ex/s]

  0%|          | 0/495 [00:00<?, ?ba/s]

In [5]:
# keep only needed columns, set data format to PyTorch
train.set_format(type = 'torch', columns = ['label', 'input_ids', 'attention_mask'])

### 3. Set up inputs for AFLite

In [6]:
# load in the feature representation, Phi, with linear layer attached
model = torch.load('feature_rep.pth')

# move model to CPU
model.to('cpu')

# freeze all layers except the last
num_layers = sum(1 for _ in model.parameters())
for idx, param in enumerate(model.parameters()):
    
    if idx != num_layers - 1:
        
        # freeze
        param.requires_grad = False

In [7]:
# set up data collator - https://huggingface.co/docs/transformers/main_classes/data_collator
# this is a (callable) helper object that sends batches of data to the model
data_collator = DataCollatorWithPadding(tokenizer, padding = 'max_length', \
                                         return_tensors = 'pt', max_length = 128)

In [8]:
# hyper-parameters - constrained by training time available
m = 30
n = 195000
t = 50000
k = 100000
tau = 0.75
AFLite_seeds = [0, 1, 2, 3, 4]

In [9]:
# hyper-parameters for model training within AFLite implementation
batch_size = 128 # constrained by GPU memory
lr = 1e-5 # set to match Le et al. (2020) - https://arxiv.org/abs/2002.04108

### 4. Utility Function for AFLite

In [10]:
def select_k(pred_scores, tau, k):
    
    """
        Select up to k instances with the highest predictability
        scores (see report) subject to score >= tau
    """
    
    k_idx = [] # initialise
    pred_scores = np.array(pred_scores) # format as numpy array for subsequent steps
    sorted_descended_idx = np.argsort(-pred_scores) # sorting because we want to select instances with high pred_score
    sorted_pred_scores = pred_scores[sorted_descended_idx] # cache
    
    for idx in sorted_descended_idx:
        
        if pred_scores[idx] < tau:
            
            return(k_idx)
        
        elif idx == sorted_descended_idx.shape[0]:
            
            k_idx += idx
        
        elif pred_scores[idx] == pred_scores[idx + 1]:
            
            candidates = sorted_descended_idx[sorted_pred_scores == pred_scores[idx]]
            max_candidates_to_select = k - len(k_idx)
            
            if max_candidates_to_select >= candidates.shape[0]:
                
                k_idx += candidates.tolist()
            
            else:
                
                # randomly select a subset of `max_candidates_to_select` candidates
                k_idx += np.random.default_rng(42).choice(candidates, max_candidates_to_select, replace = False)
        
        else:
            
            k_idx += idx
        
        if len(k_idx) == k:
            
            return(k_idx)

### 5.  AFLite Procedure

In [11]:
filtered_datasets = {}

for seed in AFLite_seeds:
    
    # first step of AFLite; initialise S
    S = copy.deepcopy(train)
    
    # initialise iteration index
    it_idx = 0
    
    while len(S) > n:
        
        # update iteration index
        it_idx += 1
        
        # initialise multiset for Out-Of-Sample predictions
        E = {x: [] for x in range(len(S))}

        for j in range(m):
            
            # randomly partition S into (S\T_j, T_j) s.t. |S\T_j| = t
            tr_idx = set(np.random.default_rng(j).choice(np.arange(len(S)), t, replace = False))
            te_idx = set(range(len(S))) - tr_idx
            tr, te = S.select(tr_idx), S.select(te_idx)
            print(f'Seed {seed} - Iteration {it_idx} - Model {j} - Begin')
                        
            # train classifier on S\T_j, i.e. tr
            classifier = copy.deepcopy(model)
            dataloader = torch.utils.data.DataLoader(tr, batch_size=batch_size, \
                                 shuffle=True, collate_fn=data_collator)
            optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, classifier.parameters()), lr = lr)
            trained_classifier = train_classifier(classifier, dataloader, optimizer, device)
            
            # for all instances i in T_j, add predictions to E(i)
            te_dataloader = torch.utils.data.DataLoader(te, batch_size=batch_size, collate_fn=data_collator)
            preds = predict(trained_classifier, te_dataloader, device)
            print(f'Seed {seed} - Iteration {it_idx} - Model {m} - Done')
            
            for pred_idx, data_idx in enumerate(te_idx): # there are as many predictions as test instances
                
                E[data_idx] += [preds[pred_idx]]
        
        # for all instances in S, compute predictability score
        pred_scores = [sum([1 for y_hat in x if y_hat == s[idx]['label']])/len(x) \
                       for idx, x in enumerate(E.values())]
        
        
        # Select up to k instances with the highest predictability scores subject to score >= tau
        selected_idx = select_k(pred_scores, tau, k)
        
        # filter
        S = S.select(set(range(len(S))) - selected_idx)
        
        # early stopping
        if len(selected_idx) < k:
            
            break
    
    # cache file
    filtered_datasets[seed] = S
    
    # print number of instances in S, for creating random baseline
    print(f'Number of instances in S (seed {seed}): {len(S)}')

Seed 0 - Iteration 1 - Model 0 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.490815  [    0/50000]
loss: 0.462348  [ 4992/50000]
loss: 0.483250  [ 9984/50000]
loss: 0.501971  [14976/50000]
loss: 0.445378  [19968/50000]
loss: 0.507128  [24960/50000]
loss: 0.435846  [29952/50000]
loss: 0.411889  [34944/50000]
loss: 0.481307  [39936/50000]
loss: 0.452662  [44928/50000]
loss: 0.496749  [31200/50000]
Epoch average loss: 0.4680524170398712


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.488783  [    0/50000]
loss: 0.467965  [ 4992/50000]
loss: 0.482960  [ 9984/50000]
loss: 0.464695  [14976/50000]
loss: 0.479309  [19968/50000]
loss: 0.460052  [24960/50000]
loss: 0.459742  [29952/50000]
loss: 0.501513  [34944/50000]
loss: 0.476872  [39936/50000]
loss: 0.534709  [44928/50000]
loss: 0.460403  [31200/50000]
Epoch average loss: 0.46620702743530273


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.439318  [    0/50000]
loss: 0.496312  [ 4992/50000]
loss: 0.486560  [ 9984/50000]
loss: 0.445026  [14976/50000]
loss: 0.457667  [19968/50000]
loss: 0.512696  [24960/50000]
loss: 0.515036  [29952/50000]
loss: 0.453929  [34944/50000]
loss: 0.412722  [39936/50000]
loss: 0.464638  [44928/50000]
loss: 0.542555  [31200/50000]
Epoch average loss: 0.4677579998970032
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 1 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.496329  [    0/50000]
loss: 0.461750  [ 4992/50000]
loss: 0.414453  [ 9984/50000]
loss: 0.499598  [14976/50000]
loss: 0.489535  [19968/50000]
loss: 0.493609  [24960/50000]
loss: 0.491454  [29952/50000]
loss: 0.517127  [34944/50000]
loss: 0.455259  [39936/50000]
loss: 0.461471  [44928/50000]
loss: 0.489900  [31200/50000]
Epoch average loss: 0.4648999571800232


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.480835  [    0/50000]
loss: 0.514182  [ 4992/50000]
loss: 0.455320  [ 9984/50000]
loss: 0.439934  [14976/50000]
loss: 0.424020  [19968/50000]
loss: 0.511567  [24960/50000]
loss: 0.536579  [29952/50000]
loss: 0.472423  [34944/50000]
loss: 0.482909  [39936/50000]
loss: 0.452261  [44928/50000]
loss: 0.445191  [31200/50000]
Epoch average loss: 0.46628013253211975


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.433838  [    0/50000]
loss: 0.455367  [ 4992/50000]
loss: 0.442342  [ 9984/50000]
loss: 0.443768  [14976/50000]
loss: 0.433495  [19968/50000]
loss: 0.446716  [24960/50000]
loss: 0.519018  [29952/50000]
loss: 0.480844  [34944/50000]
loss: 0.473903  [39936/50000]
loss: 0.441200  [44928/50000]
loss: 0.526351  [31200/50000]
Epoch average loss: 0.46541374921798706
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 2 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.468877  [    0/50000]
loss: 0.444588  [ 4992/50000]
loss: 0.463039  [ 9984/50000]
loss: 0.452809  [14976/50000]
loss: 0.475705  [19968/50000]
loss: 0.507545  [24960/50000]
loss: 0.442597  [29952/50000]
loss: 0.422230  [34944/50000]
loss: 0.504980  [39936/50000]
loss: 0.464689  [44928/50000]
loss: 0.475628  [31200/50000]
Epoch average loss: 0.4662174582481384


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.434559  [    0/50000]
loss: 0.472283  [ 4992/50000]
loss: 0.515166  [ 9984/50000]
loss: 0.440969  [14976/50000]
loss: 0.435638  [19968/50000]
loss: 0.420971  [24960/50000]
loss: 0.485145  [29952/50000]
loss: 0.495289  [34944/50000]
loss: 0.464453  [39936/50000]
loss: 0.464244  [44928/50000]
loss: 0.437880  [31200/50000]
Epoch average loss: 0.4661409258842468


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.471737  [    0/50000]
loss: 0.412788  [ 4992/50000]
loss: 0.467256  [ 9984/50000]
loss: 0.425003  [14976/50000]
loss: 0.399822  [19968/50000]
loss: 0.389849  [24960/50000]
loss: 0.456639  [29952/50000]
loss: 0.413032  [34944/50000]
loss: 0.473176  [39936/50000]
loss: 0.409610  [44928/50000]
loss: 0.495577  [31200/50000]
Epoch average loss: 0.46622443199157715
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 3 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.482670  [    0/50000]
loss: 0.467089  [ 4992/50000]
loss: 0.488443  [ 9984/50000]
loss: 0.473134  [14976/50000]
loss: 0.512105  [19968/50000]
loss: 0.512495  [24960/50000]
loss: 0.485099  [29952/50000]
loss: 0.429189  [34944/50000]
loss: 0.452853  [39936/50000]
loss: 0.462465  [44928/50000]
loss: 0.515317  [31200/50000]
Epoch average loss: 0.466852605342865


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.485342  [    0/50000]
loss: 0.458100  [ 4992/50000]
loss: 0.455437  [ 9984/50000]
loss: 0.439894  [14976/50000]
loss: 0.443444  [19968/50000]
loss: 0.519521  [24960/50000]
loss: 0.455382  [29952/50000]
loss: 0.451171  [34944/50000]
loss: 0.459446  [39936/50000]
loss: 0.459204  [44928/50000]
loss: 0.475304  [31200/50000]
Epoch average loss: 0.4650052487850189


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.533792  [    0/50000]
loss: 0.471606  [ 4992/50000]
loss: 0.526378  [ 9984/50000]
loss: 0.412952  [14976/50000]
loss: 0.413553  [19968/50000]
loss: 0.475280  [24960/50000]
loss: 0.390599  [29952/50000]
loss: 0.534704  [34944/50000]
loss: 0.460793  [39936/50000]
loss: 0.422744  [44928/50000]
loss: 0.466432  [31200/50000]
Epoch average loss: 0.4644717574119568
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 4 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.433044  [    0/50000]
loss: 0.526016  [ 4992/50000]
loss: 0.445489  [ 9984/50000]
loss: 0.485510  [14976/50000]
loss: 0.439847  [19968/50000]
loss: 0.432596  [24960/50000]
loss: 0.465773  [29952/50000]
loss: 0.468632  [34944/50000]
loss: 0.491474  [39936/50000]
loss: 0.441192  [44928/50000]
loss: 0.419341  [31200/50000]
Epoch average loss: 0.4678686261177063


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.437656  [    0/50000]
loss: 0.469157  [ 4992/50000]
loss: 0.535088  [ 9984/50000]
loss: 0.466430  [14976/50000]
loss: 0.393062  [19968/50000]
loss: 0.467271  [24960/50000]
loss: 0.490180  [29952/50000]
loss: 0.471673  [34944/50000]
loss: 0.447780  [39936/50000]
loss: 0.473813  [44928/50000]
loss: 0.492109  [31200/50000]
Epoch average loss: 0.46632587909698486


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.524030  [    0/50000]
loss: 0.466867  [ 4992/50000]
loss: 0.433357  [ 9984/50000]
loss: 0.438189  [14976/50000]
loss: 0.443250  [19968/50000]
loss: 0.460037  [24960/50000]
loss: 0.468671  [29952/50000]
loss: 0.523864  [34944/50000]
loss: 0.450836  [39936/50000]
loss: 0.456779  [44928/50000]
loss: 0.424690  [31200/50000]
Epoch average loss: 0.465709924697876
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 5 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.455836  [    0/50000]
loss: 0.462964  [ 4992/50000]
loss: 0.456246  [ 9984/50000]
loss: 0.477656  [14976/50000]
loss: 0.453614  [19968/50000]
loss: 0.469198  [24960/50000]
loss: 0.409245  [29952/50000]
loss: 0.425582  [34944/50000]
loss: 0.456350  [39936/50000]
loss: 0.535352  [44928/50000]
loss: 0.456168  [31200/50000]
Epoch average loss: 0.4682939052581787


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.500058  [    0/50000]
loss: 0.450341  [ 4992/50000]
loss: 0.422983  [ 9984/50000]
loss: 0.448929  [14976/50000]
loss: 0.475844  [19968/50000]
loss: 0.493882  [24960/50000]
loss: 0.442765  [29952/50000]
loss: 0.521435  [34944/50000]
loss: 0.529140  [39936/50000]
loss: 0.467571  [44928/50000]
loss: 0.430057  [31200/50000]
Epoch average loss: 0.46802929043769836


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.451216  [    0/50000]
loss: 0.473506  [ 4992/50000]
loss: 0.436486  [ 9984/50000]
loss: 0.477905  [14976/50000]
loss: 0.505407  [19968/50000]
loss: 0.503532  [24960/50000]
loss: 0.531428  [29952/50000]
loss: 0.515665  [34944/50000]
loss: 0.444911  [39936/50000]
loss: 0.485046  [44928/50000]
loss: 0.491922  [31200/50000]
Epoch average loss: 0.4658316373825073
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 6 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.437287  [    0/50000]
loss: 0.459754  [ 4992/50000]
loss: 0.489548  [ 9984/50000]
loss: 0.409811  [14976/50000]
loss: 0.417509  [19968/50000]
loss: 0.431104  [24960/50000]
loss: 0.451993  [29952/50000]
loss: 0.485138  [34944/50000]
loss: 0.526658  [39936/50000]
loss: 0.459254  [44928/50000]
loss: 0.476036  [31200/50000]
Epoch average loss: 0.4688470661640167


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.435326  [    0/50000]
loss: 0.462147  [ 4992/50000]
loss: 0.465431  [ 9984/50000]
loss: 0.502050  [14976/50000]
loss: 0.433416  [19968/50000]
loss: 0.463284  [24960/50000]
loss: 0.439393  [29952/50000]
loss: 0.446625  [34944/50000]
loss: 0.480686  [39936/50000]
loss: 0.491207  [44928/50000]
loss: 0.547830  [31200/50000]
Epoch average loss: 0.4671629071235657


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.460736  [    0/50000]
loss: 0.437581  [ 4992/50000]
loss: 0.527526  [ 9984/50000]
loss: 0.491923  [14976/50000]
loss: 0.456977  [19968/50000]
loss: 0.510696  [24960/50000]
loss: 0.455466  [29952/50000]
loss: 0.519326  [34944/50000]
loss: 0.480725  [39936/50000]
loss: 0.474543  [44928/50000]
loss: 0.515977  [31200/50000]
Epoch average loss: 0.4668899178504944
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 7 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.514175  [    0/50000]
loss: 0.422671  [ 4992/50000]
loss: 0.387330  [ 9984/50000]
loss: 0.438503  [14976/50000]
loss: 0.442139  [19968/50000]
loss: 0.456264  [24960/50000]
loss: 0.476131  [29952/50000]
loss: 0.430742  [34944/50000]
loss: 0.510662  [39936/50000]
loss: 0.443369  [44928/50000]
loss: 0.496969  [31200/50000]
Epoch average loss: 0.46389877796173096


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.447469  [    0/50000]
loss: 0.438273  [ 4992/50000]
loss: 0.497786  [ 9984/50000]
loss: 0.555004  [14976/50000]
loss: 0.507763  [19968/50000]
loss: 0.462944  [24960/50000]
loss: 0.472244  [29952/50000]
loss: 0.490931  [34944/50000]
loss: 0.511012  [39936/50000]
loss: 0.452350  [44928/50000]
loss: 0.506961  [31200/50000]
Epoch average loss: 0.4642736613750458


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.470216  [    0/50000]
loss: 0.417639  [ 4992/50000]
loss: 0.448084  [ 9984/50000]
loss: 0.534630  [14976/50000]
loss: 0.437972  [19968/50000]
loss: 0.424240  [24960/50000]
loss: 0.419690  [29952/50000]
loss: 0.450670  [34944/50000]
loss: 0.528334  [39936/50000]
loss: 0.431874  [44928/50000]
loss: 0.436372  [31200/50000]
Epoch average loss: 0.46379101276397705
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 8 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.452210  [    0/50000]
loss: 0.492686  [ 4992/50000]
loss: 0.416774  [ 9984/50000]
loss: 0.452274  [14976/50000]
loss: 0.481391  [19968/50000]
loss: 0.489002  [24960/50000]
loss: 0.409274  [29952/50000]
loss: 0.444181  [34944/50000]
loss: 0.471738  [39936/50000]
loss: 0.470591  [44928/50000]
loss: 0.430503  [31200/50000]
Epoch average loss: 0.4669973850250244


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.426227  [    0/50000]
loss: 0.442118  [ 4992/50000]
loss: 0.446396  [ 9984/50000]
loss: 0.479805  [14976/50000]
loss: 0.533572  [19968/50000]
loss: 0.446090  [24960/50000]
loss: 0.433642  [29952/50000]
loss: 0.509674  [34944/50000]
loss: 0.490265  [39936/50000]
loss: 0.417314  [44928/50000]
loss: 0.489345  [31200/50000]
Epoch average loss: 0.4667733907699585


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.482425  [    0/50000]
loss: 0.492693  [ 4992/50000]
loss: 0.439013  [ 9984/50000]
loss: 0.472188  [14976/50000]
loss: 0.521811  [19968/50000]
loss: 0.422870  [24960/50000]
loss: 0.438534  [29952/50000]
loss: 0.426945  [34944/50000]
loss: 0.523431  [39936/50000]
loss: 0.456516  [44928/50000]
loss: 0.566117  [31200/50000]
Epoch average loss: 0.46429717540740967
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 9 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.451098  [    0/50000]
loss: 0.477350  [ 4992/50000]
loss: 0.475644  [ 9984/50000]
loss: 0.458085  [14976/50000]
loss: 0.433531  [19968/50000]
loss: 0.406473  [24960/50000]
loss: 0.535883  [29952/50000]
loss: 0.500435  [34944/50000]
loss: 0.498843  [39936/50000]
loss: 0.484204  [44928/50000]
loss: 0.453903  [31200/50000]
Epoch average loss: 0.4672124981880188


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.443359  [    0/50000]
loss: 0.496375  [ 4992/50000]
loss: 0.437462  [ 9984/50000]
loss: 0.480564  [14976/50000]
loss: 0.479855  [19968/50000]
loss: 0.427105  [24960/50000]
loss: 0.454992  [29952/50000]
loss: 0.459519  [34944/50000]
loss: 0.472355  [39936/50000]
loss: 0.475721  [44928/50000]
loss: 0.425852  [31200/50000]
Epoch average loss: 0.46523669362068176


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.484567  [    0/50000]
loss: 0.450713  [ 4992/50000]
loss: 0.448480  [ 9984/50000]
loss: 0.451641  [14976/50000]
loss: 0.475978  [19968/50000]
loss: 0.457994  [24960/50000]
loss: 0.452651  [29952/50000]
loss: 0.462930  [34944/50000]
loss: 0.492389  [39936/50000]
loss: 0.469986  [44928/50000]
loss: 0.451083  [31200/50000]
Epoch average loss: 0.4656510055065155
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 10 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.468861  [    0/50000]
loss: 0.461141  [ 4992/50000]
loss: 0.499373  [ 9984/50000]
loss: 0.428628  [14976/50000]
loss: 0.442285  [19968/50000]
loss: 0.387423  [24960/50000]
loss: 0.476965  [29952/50000]
loss: 0.483892  [34944/50000]
loss: 0.477934  [39936/50000]
loss: 0.477053  [44928/50000]
loss: 0.497267  [31200/50000]
Epoch average loss: 0.4649522006511688


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.435785  [    0/50000]
loss: 0.533157  [ 4992/50000]
loss: 0.455788  [ 9984/50000]
loss: 0.501789  [14976/50000]
loss: 0.482162  [19968/50000]
loss: 0.441350  [24960/50000]
loss: 0.486231  [29952/50000]
loss: 0.452827  [34944/50000]
loss: 0.461257  [39936/50000]
loss: 0.488426  [44928/50000]
loss: 0.437270  [31200/50000]
Epoch average loss: 0.464244544506073


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.460283  [    0/50000]
loss: 0.479496  [ 4992/50000]
loss: 0.569986  [ 9984/50000]
loss: 0.459675  [14976/50000]
loss: 0.443899  [19968/50000]
loss: 0.472687  [24960/50000]
loss: 0.538514  [29952/50000]
loss: 0.457028  [34944/50000]
loss: 0.506931  [39936/50000]
loss: 0.458667  [44928/50000]
loss: 0.469683  [31200/50000]
Epoch average loss: 0.4621706008911133
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 11 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.488408  [    0/50000]
loss: 0.504679  [ 4992/50000]
loss: 0.470553  [ 9984/50000]
loss: 0.500734  [14976/50000]
loss: 0.472549  [19968/50000]
loss: 0.473824  [24960/50000]
loss: 0.448258  [29952/50000]
loss: 0.424527  [34944/50000]
loss: 0.439294  [39936/50000]
loss: 0.479055  [44928/50000]
loss: 0.468200  [31200/50000]
Epoch average loss: 0.46702322363853455


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.481663  [    0/50000]
loss: 0.487122  [ 4992/50000]
loss: 0.479060  [ 9984/50000]
loss: 0.478496  [14976/50000]
loss: 0.452755  [19968/50000]
loss: 0.513808  [24960/50000]
loss: 0.490463  [29952/50000]
loss: 0.451463  [34944/50000]
loss: 0.445567  [39936/50000]
loss: 0.460318  [44928/50000]
loss: 0.473600  [31200/50000]
Epoch average loss: 0.4657803773880005


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.445372  [    0/50000]
loss: 0.483058  [ 4992/50000]
loss: 0.409557  [ 9984/50000]
loss: 0.495329  [14976/50000]
loss: 0.510445  [19968/50000]
loss: 0.445253  [24960/50000]
loss: 0.443186  [29952/50000]
loss: 0.501668  [34944/50000]
loss: 0.478168  [39936/50000]
loss: 0.441746  [44928/50000]
loss: 0.449356  [31200/50000]
Epoch average loss: 0.4645027220249176
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 12 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.509644  [    0/50000]
loss: 0.461033  [ 4992/50000]
loss: 0.442738  [ 9984/50000]
loss: 0.521098  [14976/50000]
loss: 0.417417  [19968/50000]
loss: 0.429046  [24960/50000]
loss: 0.476489  [29952/50000]
loss: 0.500365  [34944/50000]
loss: 0.454865  [39936/50000]
loss: 0.511599  [44928/50000]
loss: 0.540245  [31200/50000]
Epoch average loss: 0.46408531069755554


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.489265  [    0/50000]
loss: 0.461402  [ 4992/50000]
loss: 0.443232  [ 9984/50000]
loss: 0.504986  [14976/50000]
loss: 0.494491  [19968/50000]
loss: 0.504043  [24960/50000]
loss: 0.521035  [29952/50000]
loss: 0.409282  [34944/50000]
loss: 0.478492  [39936/50000]
loss: 0.549472  [44928/50000]
loss: 0.468177  [31200/50000]
Epoch average loss: 0.46384698152542114


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.492013  [    0/50000]
loss: 0.479922  [ 4992/50000]
loss: 0.459161  [ 9984/50000]
loss: 0.441321  [14976/50000]
loss: 0.479816  [19968/50000]
loss: 0.449824  [24960/50000]
loss: 0.465926  [29952/50000]
loss: 0.447740  [34944/50000]
loss: 0.392865  [39936/50000]
loss: 0.458488  [44928/50000]
loss: 0.503041  [31200/50000]
Epoch average loss: 0.4644061326980591
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 13 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.497591  [    0/50000]
loss: 0.530760  [ 4992/50000]
loss: 0.497899  [ 9984/50000]
loss: 0.478314  [14976/50000]
loss: 0.518068  [19968/50000]
loss: 0.486955  [24960/50000]
loss: 0.456379  [29952/50000]
loss: 0.457733  [34944/50000]
loss: 0.457381  [39936/50000]
loss: 0.414165  [44928/50000]
loss: 0.490314  [31200/50000]
Epoch average loss: 0.46755897998809814


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.474509  [    0/50000]
loss: 0.487641  [ 4992/50000]
loss: 0.515323  [ 9984/50000]
loss: 0.533875  [14976/50000]
loss: 0.448765  [19968/50000]
loss: 0.406020  [24960/50000]
loss: 0.432788  [29952/50000]
loss: 0.465067  [34944/50000]
loss: 0.463552  [39936/50000]
loss: 0.500254  [44928/50000]
loss: 0.453372  [31200/50000]
Epoch average loss: 0.4643569886684418


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.488214  [    0/50000]
loss: 0.468060  [ 4992/50000]
loss: 0.480280  [ 9984/50000]
loss: 0.427781  [14976/50000]
loss: 0.518617  [19968/50000]
loss: 0.498264  [24960/50000]
loss: 0.427276  [29952/50000]
loss: 0.422488  [34944/50000]
loss: 0.462825  [39936/50000]
loss: 0.448409  [44928/50000]
loss: 0.419108  [31200/50000]
Epoch average loss: 0.46405351161956787
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 14 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.512553  [    0/50000]
loss: 0.474466  [ 4992/50000]
loss: 0.468083  [ 9984/50000]
loss: 0.508476  [14976/50000]
loss: 0.437254  [19968/50000]
loss: 0.494658  [24960/50000]
loss: 0.447799  [29952/50000]
loss: 0.484724  [34944/50000]
loss: 0.504379  [39936/50000]
loss: 0.460049  [44928/50000]
loss: 0.482386  [31200/50000]
Epoch average loss: 0.4697825014591217


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.440578  [    0/50000]
loss: 0.564142  [ 4992/50000]
loss: 0.430481  [ 9984/50000]
loss: 0.516748  [14976/50000]
loss: 0.468006  [19968/50000]
loss: 0.515787  [24960/50000]
loss: 0.459798  [29952/50000]
loss: 0.478496  [34944/50000]
loss: 0.486582  [39936/50000]
loss: 0.448254  [44928/50000]
loss: 0.520714  [31200/50000]
Epoch average loss: 0.46750015020370483


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.441896  [    0/50000]
loss: 0.463895  [ 4992/50000]
loss: 0.473388  [ 9984/50000]
loss: 0.512665  [14976/50000]
loss: 0.489067  [19968/50000]
loss: 0.458733  [24960/50000]
loss: 0.447567  [29952/50000]
loss: 0.480978  [34944/50000]
loss: 0.477234  [39936/50000]
loss: 0.465917  [44928/50000]
loss: 0.465807  [31200/50000]
Epoch average loss: 0.4676004946231842
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 15 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.507034  [    0/50000]
loss: 0.478793  [ 4992/50000]
loss: 0.558938  [ 9984/50000]
loss: 0.516160  [14976/50000]
loss: 0.501782  [19968/50000]
loss: 0.458433  [24960/50000]
loss: 0.489200  [29952/50000]
loss: 0.445365  [34944/50000]
loss: 0.494896  [39936/50000]
loss: 0.434355  [44928/50000]
loss: 0.491686  [31200/50000]
Epoch average loss: 0.46463316679000854


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.497821  [    0/50000]
loss: 0.496871  [ 4992/50000]
loss: 0.429645  [ 9984/50000]
loss: 0.439909  [14976/50000]
loss: 0.461965  [19968/50000]
loss: 0.454651  [24960/50000]
loss: 0.490157  [29952/50000]
loss: 0.498831  [34944/50000]
loss: 0.483950  [39936/50000]
loss: 0.461069  [44928/50000]
loss: 0.525301  [31200/50000]
Epoch average loss: 0.46480026841163635


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.460685  [    0/50000]
loss: 0.387969  [ 4992/50000]
loss: 0.416599  [ 9984/50000]
loss: 0.434959  [14976/50000]
loss: 0.474065  [19968/50000]
loss: 0.408900  [24960/50000]
loss: 0.451993  [29952/50000]
loss: 0.408330  [34944/50000]
loss: 0.435167  [39936/50000]
loss: 0.467250  [44928/50000]
loss: 0.480955  [31200/50000]
Epoch average loss: 0.4625505805015564
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 16 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.498403  [    0/50000]
loss: 0.416919  [ 4992/50000]
loss: 0.527071  [ 9984/50000]
loss: 0.522447  [14976/50000]
loss: 0.404433  [19968/50000]
loss: 0.461720  [24960/50000]
loss: 0.527305  [29952/50000]
loss: 0.486443  [34944/50000]
loss: 0.524850  [39936/50000]
loss: 0.523264  [44928/50000]
loss: 0.432307  [31200/50000]
Epoch average loss: 0.46635153889656067


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.473207  [    0/50000]
loss: 0.488149  [ 4992/50000]
loss: 0.471921  [ 9984/50000]
loss: 0.423854  [14976/50000]
loss: 0.458517  [19968/50000]
loss: 0.463102  [24960/50000]
loss: 0.423498  [29952/50000]
loss: 0.420713  [34944/50000]
loss: 0.468155  [39936/50000]
loss: 0.476712  [44928/50000]
loss: 0.459700  [31200/50000]
Epoch average loss: 0.4639015197753906


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.473141  [    0/50000]
loss: 0.424607  [ 4992/50000]
loss: 0.408113  [ 9984/50000]
loss: 0.457685  [14976/50000]
loss: 0.536663  [19968/50000]
loss: 0.455483  [24960/50000]
loss: 0.467286  [29952/50000]
loss: 0.500957  [34944/50000]
loss: 0.447979  [39936/50000]
loss: 0.474516  [44928/50000]
loss: 0.473007  [31200/50000]
Epoch average loss: 0.46451324224472046
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 17 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.505127  [    0/50000]
loss: 0.500839  [ 4992/50000]
loss: 0.438300  [ 9984/50000]
loss: 0.495394  [14976/50000]
loss: 0.428733  [19968/50000]
loss: 0.496772  [24960/50000]
loss: 0.446269  [29952/50000]
loss: 0.496062  [34944/50000]
loss: 0.475473  [39936/50000]
loss: 0.456013  [44928/50000]
loss: 0.430101  [31200/50000]
Epoch average loss: 0.4672999083995819


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.455757  [    0/50000]
loss: 0.464213  [ 4992/50000]
loss: 0.433969  [ 9984/50000]
loss: 0.526357  [14976/50000]
loss: 0.475596  [19968/50000]
loss: 0.444367  [24960/50000]
loss: 0.443328  [29952/50000]
loss: 0.499319  [34944/50000]
loss: 0.432696  [39936/50000]
loss: 0.519779  [44928/50000]
loss: 0.447838  [31200/50000]
Epoch average loss: 0.46694666147232056


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.501469  [    0/50000]
loss: 0.450157  [ 4992/50000]
loss: 0.476071  [ 9984/50000]
loss: 0.430901  [14976/50000]
loss: 0.434559  [19968/50000]
loss: 0.447274  [24960/50000]
loss: 0.497300  [29952/50000]
loss: 0.482795  [34944/50000]
loss: 0.484207  [39936/50000]
loss: 0.451954  [44928/50000]
loss: 0.401366  [31200/50000]
Epoch average loss: 0.4666641354560852
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 18 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.482495  [    0/50000]
loss: 0.484927  [ 4992/50000]
loss: 0.501310  [ 9984/50000]
loss: 0.519095  [14976/50000]
loss: 0.548186  [19968/50000]
loss: 0.441271  [24960/50000]
loss: 0.462279  [29952/50000]
loss: 0.499985  [34944/50000]
loss: 0.503107  [39936/50000]
loss: 0.480263  [44928/50000]
loss: 0.470140  [31200/50000]
Epoch average loss: 0.4690442979335785


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.398455  [    0/50000]
loss: 0.502112  [ 4992/50000]
loss: 0.446061  [ 9984/50000]
loss: 0.451930  [14976/50000]
loss: 0.416348  [19968/50000]
loss: 0.398900  [24960/50000]
loss: 0.462558  [29952/50000]
loss: 0.443684  [34944/50000]
loss: 0.432096  [39936/50000]
loss: 0.497407  [44928/50000]
loss: 0.415352  [31200/50000]
Epoch average loss: 0.46500763297080994


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.496254  [    0/50000]
loss: 0.380434  [ 4992/50000]
loss: 0.473774  [ 9984/50000]
loss: 0.454833  [14976/50000]
loss: 0.456954  [19968/50000]
loss: 0.466537  [24960/50000]
loss: 0.437903  [29952/50000]
loss: 0.512309  [34944/50000]
loss: 0.446846  [39936/50000]
loss: 0.493539  [44928/50000]
loss: 0.454389  [31200/50000]
Epoch average loss: 0.4634169936180115
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 19 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.477555  [    0/50000]
loss: 0.476573  [ 4992/50000]
loss: 0.502242  [ 9984/50000]
loss: 0.480113  [14976/50000]
loss: 0.464464  [19968/50000]
loss: 0.538147  [24960/50000]
loss: 0.487186  [29952/50000]
loss: 0.472866  [34944/50000]
loss: 0.433335  [39936/50000]
loss: 0.482293  [44928/50000]
loss: 0.454832  [31200/50000]
Epoch average loss: 0.4704963266849518


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.534723  [    0/50000]
loss: 0.449893  [ 4992/50000]
loss: 0.457884  [ 9984/50000]
loss: 0.463584  [14976/50000]
loss: 0.472090  [19968/50000]
loss: 0.402379  [24960/50000]
loss: 0.488251  [29952/50000]
loss: 0.462015  [34944/50000]
loss: 0.425673  [39936/50000]
loss: 0.435067  [44928/50000]
loss: 0.503056  [31200/50000]
Epoch average loss: 0.4670102894306183


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.543958  [    0/50000]
loss: 0.503758  [ 4992/50000]
loss: 0.478292  [ 9984/50000]
loss: 0.519342  [14976/50000]
loss: 0.484957  [19968/50000]
loss: 0.448890  [24960/50000]
loss: 0.446593  [29952/50000]
loss: 0.456437  [34944/50000]
loss: 0.478185  [39936/50000]
loss: 0.468295  [44928/50000]
loss: 0.370702  [31200/50000]
Epoch average loss: 0.46878892183303833
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 20 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.475372  [    0/50000]
loss: 0.447662  [ 4992/50000]
loss: 0.465995  [ 9984/50000]
loss: 0.455226  [14976/50000]
loss: 0.443935  [19968/50000]
loss: 0.516199  [24960/50000]
loss: 0.500436  [29952/50000]
loss: 0.485346  [34944/50000]
loss: 0.448143  [39936/50000]
loss: 0.518699  [44928/50000]
loss: 0.450377  [31200/50000]
Epoch average loss: 0.4680536091327667


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.418401  [    0/50000]
loss: 0.510836  [ 4992/50000]
loss: 0.492478  [ 9984/50000]
loss: 0.460347  [14976/50000]
loss: 0.452218  [19968/50000]
loss: 0.473047  [24960/50000]
loss: 0.412962  [29952/50000]
loss: 0.467663  [34944/50000]
loss: 0.463992  [39936/50000]
loss: 0.460755  [44928/50000]
loss: 0.453841  [31200/50000]
Epoch average loss: 0.4667637348175049


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.484204  [    0/50000]
loss: 0.494017  [ 4992/50000]
loss: 0.485653  [ 9984/50000]
loss: 0.432420  [14976/50000]
loss: 0.434462  [19968/50000]
loss: 0.480072  [24960/50000]
loss: 0.404751  [29952/50000]
loss: 0.461657  [34944/50000]
loss: 0.457187  [39936/50000]
loss: 0.442042  [44928/50000]
loss: 0.430391  [31200/50000]
Epoch average loss: 0.4658651053905487
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 21 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.430043  [    0/50000]
loss: 0.486160  [ 4992/50000]
loss: 0.469001  [ 9984/50000]
loss: 0.478766  [14976/50000]
loss: 0.456151  [19968/50000]
loss: 0.477125  [24960/50000]
loss: 0.456789  [29952/50000]
loss: 0.487704  [34944/50000]
loss: 0.453655  [39936/50000]
loss: 0.439149  [44928/50000]
loss: 0.415876  [31200/50000]
Epoch average loss: 0.46819308400154114


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.507056  [    0/50000]
loss: 0.508202  [ 4992/50000]
loss: 0.484298  [ 9984/50000]
loss: 0.472675  [14976/50000]
loss: 0.475956  [19968/50000]
loss: 0.430125  [24960/50000]
loss: 0.445786  [29952/50000]
loss: 0.445598  [34944/50000]
loss: 0.471165  [39936/50000]
loss: 0.462389  [44928/50000]
loss: 0.433512  [31200/50000]
Epoch average loss: 0.46503567695617676


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.440933  [    0/50000]
loss: 0.506381  [ 4992/50000]
loss: 0.474983  [ 9984/50000]
loss: 0.437688  [14976/50000]
loss: 0.445092  [19968/50000]
loss: 0.453555  [24960/50000]
loss: 0.526353  [29952/50000]
loss: 0.495068  [34944/50000]
loss: 0.458941  [39936/50000]
loss: 0.536526  [44928/50000]
loss: 0.483201  [31200/50000]
Epoch average loss: 0.4664197266101837
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 22 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.494433  [    0/50000]
loss: 0.452008  [ 4992/50000]
loss: 0.496163  [ 9984/50000]
loss: 0.419790  [14976/50000]
loss: 0.478190  [19968/50000]
loss: 0.499334  [24960/50000]
loss: 0.415537  [29952/50000]
loss: 0.496038  [34944/50000]
loss: 0.443988  [39936/50000]
loss: 0.499493  [44928/50000]
loss: 0.520658  [31200/50000]
Epoch average loss: 0.46517202258110046


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.430317  [    0/50000]
loss: 0.492328  [ 4992/50000]
loss: 0.466971  [ 9984/50000]
loss: 0.487365  [14976/50000]
loss: 0.439910  [19968/50000]
loss: 0.467765  [24960/50000]
loss: 0.436565  [29952/50000]
loss: 0.517786  [34944/50000]
loss: 0.516299  [39936/50000]
loss: 0.470324  [44928/50000]
loss: 0.513472  [31200/50000]
Epoch average loss: 0.46545150876045227


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.476449  [    0/50000]
loss: 0.430233  [ 4992/50000]
loss: 0.452127  [ 9984/50000]
loss: 0.449089  [14976/50000]
loss: 0.457004  [19968/50000]
loss: 0.444404  [24960/50000]
loss: 0.420915  [29952/50000]
loss: 0.501941  [34944/50000]
loss: 0.490740  [39936/50000]
loss: 0.494690  [44928/50000]
loss: 0.464558  [31200/50000]
Epoch average loss: 0.4650108218193054
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 23 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.491024  [    0/50000]
loss: 0.485894  [ 4992/50000]
loss: 0.552416  [ 9984/50000]
loss: 0.492041  [14976/50000]
loss: 0.501383  [19968/50000]
loss: 0.453357  [24960/50000]
loss: 0.476226  [29952/50000]
loss: 0.440259  [34944/50000]
loss: 0.472606  [39936/50000]
loss: 0.493140  [44928/50000]
loss: 0.414294  [31200/50000]
Epoch average loss: 0.46768316626548767


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.462126  [    0/50000]
loss: 0.471112  [ 4992/50000]
loss: 0.442025  [ 9984/50000]
loss: 0.440201  [14976/50000]
loss: 0.453907  [19968/50000]
loss: 0.487184  [24960/50000]
loss: 0.420961  [29952/50000]
loss: 0.430418  [34944/50000]
loss: 0.462805  [39936/50000]
loss: 0.443001  [44928/50000]
loss: 0.406781  [31200/50000]
Epoch average loss: 0.4656551480293274


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.494686  [    0/50000]
loss: 0.457101  [ 4992/50000]
loss: 0.488097  [ 9984/50000]
loss: 0.548597  [14976/50000]
loss: 0.482627  [19968/50000]
loss: 0.443434  [24960/50000]
loss: 0.446603  [29952/50000]
loss: 0.518370  [34944/50000]
loss: 0.547678  [39936/50000]
loss: 0.481638  [44928/50000]
loss: 0.412869  [31200/50000]
Epoch average loss: 0.46717384457588196
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 24 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.449381  [    0/50000]
loss: 0.468391  [ 4992/50000]
loss: 0.471227  [ 9984/50000]
loss: 0.513466  [14976/50000]
loss: 0.470658  [19968/50000]
loss: 0.418146  [24960/50000]
loss: 0.444162  [29952/50000]
loss: 0.489960  [34944/50000]
loss: 0.474804  [39936/50000]
loss: 0.465466  [44928/50000]
loss: 0.432358  [31200/50000]
Epoch average loss: 0.46811777353286743


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.419841  [    0/50000]
loss: 0.513758  [ 4992/50000]
loss: 0.520749  [ 9984/50000]
loss: 0.393288  [14976/50000]
loss: 0.469809  [19968/50000]
loss: 0.514845  [24960/50000]
loss: 0.500509  [29952/50000]
loss: 0.436897  [34944/50000]
loss: 0.505542  [39936/50000]
loss: 0.559366  [44928/50000]
loss: 0.421984  [31200/50000]
Epoch average loss: 0.4658794701099396


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.448667  [    0/50000]
loss: 0.481965  [ 4992/50000]
loss: 0.466432  [ 9984/50000]
loss: 0.497261  [14976/50000]
loss: 0.466778  [19968/50000]
loss: 0.463703  [24960/50000]
loss: 0.450942  [29952/50000]
loss: 0.466630  [34944/50000]
loss: 0.437509  [39936/50000]
loss: 0.465097  [44928/50000]
loss: 0.456906  [31200/50000]
Epoch average loss: 0.46664413809776306
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 25 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.502757  [    0/50000]
loss: 0.419633  [ 4992/50000]
loss: 0.452767  [ 9984/50000]
loss: 0.528674  [14976/50000]
loss: 0.467477  [19968/50000]
loss: 0.454769  [24960/50000]
loss: 0.464534  [29952/50000]
loss: 0.504483  [34944/50000]
loss: 0.486964  [39936/50000]
loss: 0.510879  [44928/50000]
loss: 0.424436  [31200/50000]
Epoch average loss: 0.46565693616867065


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.518114  [    0/50000]
loss: 0.457981  [ 4992/50000]
loss: 0.495266  [ 9984/50000]
loss: 0.475037  [14976/50000]
loss: 0.472165  [19968/50000]
loss: 0.497320  [24960/50000]
loss: 0.510403  [29952/50000]
loss: 0.424662  [34944/50000]
loss: 0.480267  [39936/50000]
loss: 0.440383  [44928/50000]
loss: 0.460675  [31200/50000]
Epoch average loss: 0.46286654472351074


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.501746  [    0/50000]
loss: 0.498642  [ 4992/50000]
loss: 0.441506  [ 9984/50000]
loss: 0.501758  [14976/50000]
loss: 0.478087  [19968/50000]
loss: 0.459252  [24960/50000]
loss: 0.423987  [29952/50000]
loss: 0.495204  [34944/50000]
loss: 0.433456  [39936/50000]
loss: 0.467658  [44928/50000]
loss: 0.374183  [31200/50000]
Epoch average loss: 0.4630371332168579
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 26 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.471016  [    0/50000]
loss: 0.411591  [ 4992/50000]
loss: 0.516195  [ 9984/50000]
loss: 0.462409  [14976/50000]
loss: 0.500151  [19968/50000]
loss: 0.497262  [24960/50000]
loss: 0.472269  [29952/50000]
loss: 0.481052  [34944/50000]
loss: 0.452132  [39936/50000]
loss: 0.477811  [44928/50000]
loss: 0.449242  [31200/50000]
Epoch average loss: 0.4659276306629181


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.524131  [    0/50000]
loss: 0.491138  [ 4992/50000]
loss: 0.491060  [ 9984/50000]
loss: 0.479177  [14976/50000]
loss: 0.456984  [19968/50000]
loss: 0.454466  [24960/50000]
loss: 0.509189  [29952/50000]
loss: 0.497001  [34944/50000]
loss: 0.453285  [39936/50000]
loss: 0.423740  [44928/50000]
loss: 0.458048  [31200/50000]
Epoch average loss: 0.46621713042259216


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.464815  [    0/50000]
loss: 0.479472  [ 4992/50000]
loss: 0.433501  [ 9984/50000]
loss: 0.504584  [14976/50000]
loss: 0.470295  [19968/50000]
loss: 0.459972  [24960/50000]
loss: 0.457953  [29952/50000]
loss: 0.509463  [34944/50000]
loss: 0.487049  [39936/50000]
loss: 0.443185  [44928/50000]
loss: 0.442090  [31200/50000]
Epoch average loss: 0.4657314419746399
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 27 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.431224  [    0/50000]
loss: 0.476911  [ 4992/50000]
loss: 0.509927  [ 9984/50000]
loss: 0.470430  [14976/50000]
loss: 0.517648  [19968/50000]
loss: 0.459220  [24960/50000]
loss: 0.494984  [29952/50000]
loss: 0.451616  [34944/50000]
loss: 0.443977  [39936/50000]
loss: 0.450821  [44928/50000]
loss: 0.519790  [31200/50000]
Epoch average loss: 0.4679385721683502


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.464182  [    0/50000]
loss: 0.456767  [ 4992/50000]
loss: 0.462471  [ 9984/50000]
loss: 0.487661  [14976/50000]
loss: 0.510534  [19968/50000]
loss: 0.410501  [24960/50000]
loss: 0.478851  [29952/50000]
loss: 0.489471  [34944/50000]
loss: 0.462766  [39936/50000]
loss: 0.455909  [44928/50000]
loss: 0.509098  [31200/50000]
Epoch average loss: 0.46559658646583557


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.496724  [    0/50000]
loss: 0.529835  [ 4992/50000]
loss: 0.447723  [ 9984/50000]
loss: 0.498441  [14976/50000]
loss: 0.466099  [19968/50000]
loss: 0.506965  [24960/50000]
loss: 0.465093  [29952/50000]
loss: 0.413828  [34944/50000]
loss: 0.468896  [39936/50000]
loss: 0.498783  [44928/50000]
loss: 0.462995  [31200/50000]
Epoch average loss: 0.46492210030555725
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 28 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.483441  [    0/50000]
loss: 0.462353  [ 4992/50000]
loss: 0.457516  [ 9984/50000]
loss: 0.526151  [14976/50000]
loss: 0.458532  [19968/50000]
loss: 0.456002  [24960/50000]
loss: 0.527666  [29952/50000]
loss: 0.492890  [34944/50000]
loss: 0.501283  [39936/50000]
loss: 0.411950  [44928/50000]
loss: 0.462558  [31200/50000]
Epoch average loss: 0.46603724360466003


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.471339  [    0/50000]
loss: 0.508212  [ 4992/50000]
loss: 0.439703  [ 9984/50000]
loss: 0.471296  [14976/50000]
loss: 0.525493  [19968/50000]
loss: 0.463997  [24960/50000]
loss: 0.467105  [29952/50000]
loss: 0.453285  [34944/50000]
loss: 0.510396  [39936/50000]
loss: 0.458118  [44928/50000]
loss: 0.474276  [31200/50000]
Epoch average loss: 0.46495386958122253


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.505913  [    0/50000]
loss: 0.388348  [ 4992/50000]
loss: 0.492330  [ 9984/50000]
loss: 0.413598  [14976/50000]
loss: 0.425266  [19968/50000]
loss: 0.487858  [24960/50000]
loss: 0.506644  [29952/50000]
loss: 0.504897  [34944/50000]
loss: 0.438456  [39936/50000]
loss: 0.445709  [44928/50000]
loss: 0.457414  [31200/50000]
Epoch average loss: 0.4645688235759735
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Seed 0 - Iteration 1 - Model 29 - Begin


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.537453  [    0/50000]
loss: 0.548949  [ 4992/50000]
loss: 0.475675  [ 9984/50000]
loss: 0.521718  [14976/50000]
loss: 0.478341  [19968/50000]
loss: 0.452174  [24960/50000]
loss: 0.480157  [29952/50000]
loss: 0.492783  [34944/50000]
loss: 0.476948  [39936/50000]
loss: 0.417546  [44928/50000]
loss: 0.419418  [31200/50000]
Epoch average loss: 0.4665236175060272


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.496483  [    0/50000]
loss: 0.471519  [ 4992/50000]
loss: 0.459434  [ 9984/50000]
loss: 0.431235  [14976/50000]
loss: 0.476649  [19968/50000]
loss: 0.441918  [24960/50000]
loss: 0.416509  [29952/50000]
loss: 0.487409  [34944/50000]
loss: 0.533104  [39936/50000]
loss: 0.434549  [44928/50000]
loss: 0.425829  [31200/50000]
Epoch average loss: 0.46508893370628357


  0%|          | 0/391 [00:00<?, ?it/s]

loss: 0.453289  [    0/50000]
loss: 0.441169  [ 4992/50000]
loss: 0.476270  [ 9984/50000]
loss: 0.439701  [14976/50000]
loss: 0.477623  [19968/50000]
loss: 0.436057  [24960/50000]
loss: 0.462508  [29952/50000]
loss: 0.505712  [34944/50000]
loss: 0.455471  [39936/50000]
loss: 0.495586  [44928/50000]
loss: 0.509076  [31200/50000]
Epoch average loss: 0.4643941819667816
Done!


  0%|          | 0/3473 [00:00<?, ?it/s]

Seed 0 - Iteration 1 - Model 30 - Done
Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/home/shana92/.conda/envs/text/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3397, in run_code
  File "/tmp/ipykernel_30150/1420380528.py", line 44, in <cell line: 3>
    pred_scores = [sum([1 for y_hat in x if y_hat == s[idx]['label']])/len(x) \
  File "/tmp/ipykernel_30150/1420380528.py", line 44, in <listcomp>
    pred_scores = [sum([1 for y_hat in x if y_hat == s[idx]['label']])/len(x) \
  File "/tmp/ipykernel_30150/1420380528.py", line 44, in <listcomp>
    pred_scores = [sum([1 for y_hat in x if y_hat == s[idx]['label']])/len(x) \
NameError: name 's' is not defined

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/shana92/.conda/envs/text/lib/python3.8/site-packages/pygments/styles/__init__.py", line 78, in get_style_by_name
ModuleNotFoundError: No module named 'pygments.styles.default'

During handling of the above exception, another exception occ

### 6. Fine-tuning GPT-2 with AFLite filtered datasets

In [None]:
# Free up some RAM
del train 
del model
del classifier

# Begin fine-tuning
for seed in AFLite_seeds:
    
    # instantiate new GPT-2 based model
    model = GPT2ForSequenceClassification.from_pretrained("gpt2", 
                                      num_labels=3,
                                      problem_type="multi_label_classification")
    model.config.pad_token_id = model.config.eos_token_id # specify pad_token used by tokenizer
    
    # set up data loader
    dataloader = torch.utils.data.DataLoader(filtered_datasets[seed], batch_size=batch_size, \
                                             shuffle=True, collate_fn=data_collator)
    
    # set up optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr = lr)    
    
    # fine-tune model
    torch.save(train_classifier(model, dataloader, optimizer, device), \
              'AFLite_fine_tuned_model_seed' + str(seed) + '.pth')