# Dual Objective:
- A) Use AFLite to greedily solve for $\operatorname*{arg\,min}_{S \subseteq \mathcal{D},\ |S| \geq n} \mathcal{R}(\Phi, S, \mathcal{M})$
- B) Fine-tune GPT-2 with the resulting filtered dataset

### 1. Imports and Global Settings

In [None]:
from datasets import load_dataset, disable_caching
from tqdm.notebook import tqdm
from transformers import GPT2TokenizerFast, DataCollatorWithPadding, set_seed
import torch
from torch.nn.functional import one_hot
import copy
import numpy as np
# run on the GPU when torch can see one, otherwise on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# fixed seed, shared with the shuffle in the pre-processing step
set_seed(42)
# switch off the `datasets` cache
disable_caching()

### 2. Pre-Processing
- Get SNLI Dataset (Train fold) and shuffle it using the same seed as used for obtaining GPT-2 based Feature Representation (see notebook [Filtering_Part1.ipynb](https://github.com/shashiniyer/adversarial_nli_gpt2/blob/main/Filtering_Part1.ipynb))
- Remove instances without gold standard labels, i.e., label = -1
- One-hot encoding for labels
- Partition data 10%/90%; use the 90% as `train`
- Tokenise train

In [None]:
# SNLI train fold, shuffled with seed 42
snli_train = load_dataset('snli', split = 'train')
snli_train = snli_train.shuffle(seed = 42)

# drop instances without a gold standard label (label == -1)
snli_train = snli_train.filter(lambda example: example['label'] != -1)

# replace every label by its float32 one-hot encoding over the 3 classes (applied batch-wise)
snli_train = snli_train.map(
    lambda batch: {'label': one_hot(torch.tensor(batch['label']), 3).type(torch.float32).numpy()},
    batched = True,
)

# skip the first 10% of the shuffled data; the remaining 90% is `train`
train = snli_train.select(range(len(snli_train) // 10, len(snli_train)))

In [None]:
# GPT-2 tokeniser
# inputs are padded on the left because GPT2 uses the last token for prediction
tokenizer = GPT2TokenizerFast.from_pretrained(
    "gpt2",
    padding_side = 'left',
    padding = True,
    truncation = True,
)
# the 'eos' token doubles as the padding token
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# tokenise every instance as a single string: premise and hypothesis joined by '|'
train = train.map(
    lambda example: tokenizer('|'.join((example['premise'], example['hypothesis'])))
)

In [None]:
# expose only the label and the tokenised model inputs, returned as PyTorch tensors
train.set_format(
    type = 'torch',
    columns = ['label', 'input_ids', 'attention_mask'],
)

### 3. Set up inputs for AFLite

In [None]:
# load in the feature representation, Phi, with linear layer attached
# NOTE(review): torch.load unpickles the file - only load a 'feature_rep.pth' you produced yourself
model = torch.load('feature_rep.pth')

# keep the template model on the CPU
model.to('cpu')

# freeze every parameter tensor except the very last one returned by model.parameters()
# (`num_layers` counts parameter tensors)
num_layers = len(list(model.parameters()))
for idx, param in enumerate(model.parameters()):
    if idx < num_layers - 1:
        param.requires_grad = False

In [None]:
# data collator - https://huggingface.co/docs/transformers/main_classes/data_collator
# callable helper used as `collate_fn` by the DataLoaders below; it pads every
# batch to a fixed length of 512 tokens and returns PyTorch tensors
data_collator = DataCollatorWithPadding(
    tokenizer,
    padding = 'max_length',
    max_length = 512,
    return_tensors = 'pt',
)

In [None]:
# AFLite hyper-parameters - set to match Le Bras et al. (2020) - https://arxiv.org/abs/2002.04108
m = 64        # number of random train/test partitions per filtering round
n = 182_000   # filtering rounds continue while |S| > n
t = 50_000    # number of training instances drawn in each partition
k = 10_000    # maximum number of instances removed per filtering round
tau = 0.75    # minimum predictability score for an instance to be removed
AFLite_seeds = list(range(5))  # one independent AFLite run per seed: 0..4

In [None]:
# hyper-parameters for model training within AFLite implementation
# batch size is constrained by GPU memory
batch_size = 16
# learning rate also set to match Le Bras et al. (2020) - https://arxiv.org/abs/2002.04108
lr = 1e-5

### 4. Utility Functions for AFLite

In [None]:
def train_classifier(classifier, dataloader, optimizer, device, npochs = 3):
    """
        Train `classifier` in place for a fixed number of epochs

        Args:
            classifier: model called as `classifier(**batch)`; its first output must be the loss
            dataloader: sized iterable with a `.dataset` attribute; each batch must support
                        `.to(device)`, unpack as keyword arguments and hold a 'labels' entry
            optimizer: optimizer over the trainable parameters of `classifier`
            device: device on which training runs
            npochs: number of passes over `dataloader`;
                    3 in Le Bras et al. (2020) - https://arxiv.org/abs/2002.04108

        Returns:
            The same classifier object, set to eval mode and moved back to CPU
    """

    # move classifier to device and set it in train mode
    classifier.to(device)
    classifier.train()

    # cache sizes
    train_data_size = len(dataloader.dataset)
    num_batches = len(dataloader)

    # log about ten times per epoch; max(1, ...) avoids a modulo-by-zero
    # when the dataloader has fewer than 10 batches
    log_every = max(1, num_batches // 10)

    for _ in range(npochs):

        # running sum of per-instance loss for this epoch
        curr_loss = 0.0

        for batch, data in tqdm(enumerate(dataloader), total = num_batches):

            # Torch requirement
            classifier.zero_grad()

            # Compute prediction and loss
            data = data.to(device)
            outputs = classifier(**data)
            batch_loss = outputs[0]

            # Backpropagation
            batch_loss.backward()
            torch.nn.utils.clip_grad_norm_(classifier.parameters(), 1.0)
            optimizer.step()

            # always work with a plain float from here on, so that the epoch
            # sum does not accumulate device tensors
            batch_loss_value = batch_loss.item()
            num_in_batch = len(data['labels'])

            # Log
            if batch % log_every == 0:
                current = batch * num_in_batch
                print(f"loss: {batch_loss_value:>7f}  [{current:>5d}/{train_data_size:>5d}]")

            # Add up batch-loss (weighted by batch size, the last batch may be smaller)
            curr_loss += batch_loss_value * num_in_batch

        # Average out curr_loss (left at 0.0 for an empty dataset)
        if train_data_size > 0:
            curr_loss /= train_data_size
        print(f'Epoch average loss: {curr_loss}')

    # set classifier to eval mode and move it to CPU
    classifier.eval()
    classifier.to('cpu')

    # print done
    print('Done!')

    return(classifier)

In [None]:
def predict(classifier, dataloader, device):
    """
        Predict a class index for every instance served by `dataloader`

        Args:
            classifier: model called as `classifier(**batch)` whose output has a `.logits`
                        attribute with one row per instance
            dataloader: sized iterable of batches that support `.to(device)`
            device: device on which inference runs

        Returns:
            1-D numpy array with the argmax class index of every instance, in dataloader
            order; empty if the dataloader yields no batch
    """

    classifier.to(device) # move classifier to device

    # predict in eval mode, then restore whatever mode the caller had set
    was_training = classifier.training
    classifier.eval()

    batch_preds = []

    with torch.no_grad():

        for data in tqdm(dataloader, total = len(dataloader)):

            logits = classifier(**data.to(device)).logits

            # bring predictions to the CPU right away: `.numpy()` below
            # does not work on CUDA tensors
            batch_preds.append(logits.argmax(1).cpu())

    classifier.train(was_training)
    classifier.to('cpu') # move classifier to cpu

    if not batch_preds:

        return(torch.empty(0, dtype = torch.long).numpy())

    # single concatenation instead of one per batch
    return(torch.cat(batch_preds).numpy())

In [None]:
def select_k(pred_scores, tau, k):
    
    """
        Select up to k instances with the highest predictability scores subject to score >= tau

        Args:
            pred_scores: sequence of predictability scores; position i holds the score of instance i
            tau: minimum score an instance needs in order to be eligible
            k: maximum number of instances to select

        Returns:
            List of plain-int positions into `pred_scores`, without duplicates, ordered from
            highest to lowest score. If more than k instances are eligible, the instances tied
            at the cut-off score are sub-sampled at random with a fixed seed (42). The list is
            empty when no score reaches tau, when `pred_scores` is empty or when k <= 0.
    """
    
    scores = np.asarray(pred_scores, dtype = float)
    
    if k <= 0 or scores.size == 0:
        
        return([])
    
    # positions with score >= tau (NaN scores never qualify), best score first;
    # the stable sort keeps equal scores in their original order
    eligible = np.flatnonzero(scores >= tau)
    ranked = eligible[np.argsort(-scores[eligible], kind = 'stable')]
    
    if ranked.size <= k:
        
        return(ranked.tolist())
    
    # more than k eligible instances: everything strictly above the k-th best score
    # is selected, the remaining slots are filled from the instances tied at that score
    ranked_scores = scores[ranked]
    cutoff = ranked_scores[k - 1]
    above = ranked[ranked_scores > cutoff]
    tied = ranked[ranked_scores == cutoff]
    slots_left = k - above.size
    
    if slots_left < tied.size:
        
        # randomly select a subset of `slots_left` tied candidates
        tied = np.random.default_rng(42).choice(tied, slots_left, replace = False)
    
    return(above.tolist() + tied.tolist())

### 5. AFLite Procedure

In [None]:
# AFLite: for every seed, repeatedly (1) train m classifiers on random size-t subsets
# of S, (2) score every instance by how often it is predicted correctly when held out,
# (3) remove up to k instances with score >= tau, until |S| <= n or fewer than k
# instances were removed in a round. Results are stored per seed in `filtered_datasets`.
filtered_datasets = {}

for seed in AFLite_seeds:
    
    # one random generator per AFLite run, so that different seeds draw different
    # partitions (and successive rounds within a run differ as well)
    rng = np.random.default_rng(seed)
    
    # first step of AFLite; initialise S
    # (`select` returns a new dataset and never modifies `train`, so no copy is needed)
    S = train

    while len(S) > n:
        
        num_instances = len(S)
        
        # gold class index of every instance; labels are stored one-hot
        # whereas `predict` returns class indices
        gold = np.asarray(S.with_format('numpy', columns = ['label'])[:]['label'])
        if gold.ndim > 1:
            gold = gold.argmax(axis = 1)
        
        # Out-Of-Sample predictions, kept as counts per instance:
        # how often it was held out, and how often it was then predicted correctly
        num_preds = np.zeros(num_instances, dtype = np.int64)
        num_correct = np.zeros(num_instances, dtype = np.int64)

        for j in range(m):
            
            # randomly partition S into (S\T_j, T_j) s.t. |S\T_j| = t
            tr_idx = np.sort(rng.choice(num_instances, t, replace = False))
            is_test = np.ones(num_instances, dtype = bool)
            is_test[tr_idx] = False
            te_idx = np.flatnonzero(is_test)
            # plain lists of ints are passed to `select`
            tr, te = S.select(tr_idx.tolist()), S.select(te_idx.tolist())
            
            print(str(seed) + '_' + str(j) + '_' + str(len(tr_idx)))
                        
            # train classifier on S\T_j, i.e. tr
            classifier = copy.deepcopy(model)
            dataloader = torch.utils.data.DataLoader(tr, batch_size=batch_size,
                                 shuffle=True, collate_fn=data_collator)
            optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, classifier.parameters()), lr = lr)
            trained_classifier = train_classifier(classifier, dataloader, optimizer, device)
            
            # for all instances i in T_j, record the prediction
            # (te_dataloader is not shuffled, so preds[i] belongs to instance te_idx[i])
            te_dataloader = torch.utils.data.DataLoader(te, batch_size=batch_size, collate_fn=data_collator)
            preds = np.asarray(predict(trained_classifier, te_dataloader, device))
            
            num_preds[te_idx] += 1
            num_correct[te_idx] += (preds == gold[te_idx])
        
        # for all instances in S, compute predictability score = share of correct
        # out-of-sample predictions; instances that were never held out get score 0
        pred_scores = np.divide(num_correct, num_preds, \
                                out = np.zeros(num_instances), where = num_preds > 0)
        
        # Select up to k instances with the highest predictability scores subject to score >= tau
        selected_idx = select_k(pred_scores, tau, k) or []
        removed = {int(i) for i in selected_idx}
        
        # filter
        S = S.select([i for i in range(num_instances) if i not in removed])
        
        # early stopping
        if len(removed) < k:
            
            break
    
    # cache file
    filtered_datasets[seed] = S
    
    # print number of instances in S, for creating random baseline
    print(f'Number of instances in S (seed {seed}): {len(S)}')

### 6. Fine-tuning GPT-2 with AFLite filtered datasets

In [None]:
# imported here because the notebook's import cell does not bring this class in
from transformers import GPT2ForSequenceClassification

# Free up some RAM
del train 
del model
# `trained_classifier` refers to the same object as `classifier`,
# so both names have to go before the last AFLite model can be released
del classifier
del trained_classifier

# Begin fine-tuning: one GPT-2 classifier per AFLite-filtered dataset,
# each saved to 'AFLite_fine_tuned_model_seed<seed>.pth'
for seed in AFLite_seeds:
    
    # instantiate GPT-2 based model
    # (multi-label problem type, matching the float one-hot labels of the datasets)
    model = GPT2ForSequenceClassification.from_pretrained("gpt2", 
                                      num_labels=3,
                                      problem_type="multi_label_classification")
    model.config.pad_token_id = model.config.eos_token_id # specify pad_token used by tokenizer
    
    # set up data loader
    dataloader = torch.utils.data.DataLoader(filtered_datasets[seed], batch_size=batch_size,
                                             shuffle=True, collate_fn=data_collator)
    
    # set up optimizer; all parameters are trained here
    optimizer = torch.optim.Adam(model.parameters(), lr = lr)    
    
    # fine-tune model and save it
    fine_tuned_model = train_classifier(model, dataloader, optimizer, device, npochs = 3)
    torch.save(fine_tuned_model, f'AFLite_fine_tuned_model_seed{seed}.pth')