# Training a Baseline Model: GPT-2 fine-tuned for NLI using all of SNLI train data

### 1. Imports and Global Settings

In [None]:
import torch
from datasets import load_dataset, disable_caching
from torch.nn.functional import one_hot
from tqdm.notebook import tqdm
from transformers import (
    DataCollatorWithPadding,
    GPT2ForSequenceClassification,
    GPT2TokenizerFast,
    set_seed,
)
# Fixed seed (42) so repeated runs are comparable
set_seed(42)

# Switch off the `datasets` caching of processed data
disable_caching()

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### 2. Pre-Processing
- Get SNLI Dataset (Train fold)
- Remove instances without a gold-standard label, i.e., label = -1
- One-hot encoding for labels
- Tokenise data

In [None]:
# Load the SNLI training split
train = load_dataset('snli', split='train')

# Drop examples whose label is -1 (no gold-standard label)
train = train.filter(lambda example: example['label'] != -1)

# Replace the integer class ids with float32 one-hot vectors over the 3 classes.
# With batched=True the function receives a whole list of labels per call.
train = train.map(
    lambda examples: {
        'label': one_hot(torch.tensor(examples['label']), 3).float().numpy()
    },
    batched=True,
)

In [None]:
# set up tokeniser
# Pad on the left so the final position of every sequence is a real token
# (GPT2 uses the last token for prediction).
# NOTE: `padding` / `truncation` are arguments of the tokenizer *call*; given to
# `from_pretrained` they do not change how text is encoded, so they are not
# passed here to avoid suggesting that truncation is active.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", padding_side='left')

# GPT-2 has no dedicated pad token, so pad with the 'eos' token
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# tokenize data
def _tokenize_pairs(batch):
    """Encode each premise/hypothesis pair as the single string 'premise|hypothesis'.

    `batch` is a dict of column lists (as passed by `map(..., batched=True)`).
    Truncation is requested explicitly here because it only takes effect when
    given to the tokenizer call; 512 matches the length the data collator pads
    to, so no example can end up longer than the padded batch width.
    """
    texts = [
        premise + '|' + hypothesis
        for premise, hypothesis in zip(batch['premise'], batch['hypothesis'])
    ]
    return tokenizer(texts, truncation=True, max_length=512)


# batched=True lets the fast tokenizer encode many examples per call
train = train.map(_tokenize_pairs, batched=True)

In [None]:
# Expose only the label and the two model-input columns, returned as PyTorch tensors
train.set_format(
    type='torch',
    columns=['label', 'input_ids', 'attention_mask'],
)

### 3. Model

In [None]:
# Data collator: a callable helper that assembles individual examples into a batch.
# Docs: https://huggingface.co/docs/transformers/main_classes/data_collator
# Every example is padded to a fixed 512 tokens and returned as PyTorch tensors.
# NOTE(review): fixed-length padding makes every batch 512 tokens wide regardless
# of content; presumably deliberate - confirm before switching to dynamic padding.
data_collator = DataCollatorWithPadding(
    tokenizer,
    padding='max_length',
    max_length=512,
    return_tensors='pt',
)

In [None]:
# hyper-parameters for model training
# learning rate: also set to match Le et al. (2020) - https://arxiv.org/abs/2002.04108
lr = 1e-5
# number of examples per optimisation step: constrained by GPU memory
batch_size = 16

In [None]:
# Batch generator over the training set: shuffled, with batches assembled by the
# data collator. The batch size is constrained by GPU memory.
dataloader = torch.utils.data.DataLoader(
    dataset=train,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=data_collator,
)

In [None]:
# `model` was used below without ever being created, which raises NameError.
# Build GPT-2 with a 3-way sequence-classification head (one output per SNLI class).
# NOTE(review): reconstructed from the notebook title and the 3-class one-hot
# labels - confirm it matches the model that was originally trained.
model = GPT2ForSequenceClassification.from_pretrained("gpt2", num_labels=3)

# The classification head needs the pad id to handle batches of more than one
# example; use the same pad token as the tokenizer.
model.config.pad_token_id = tokenizer.pad_token_id

# move model to device
model.to(device)

In [None]:
# Adam over all model parameters at the configured learning rate.
# No separate loss object is created: the training loop reads the loss from the model output.
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

In [None]:
# Train
model.train()

size = len(dataloader.dataset)
num_batches = len(dataloader)

# Report progress about ten times per epoch. The interval is clamped to at
# least 1 so a dataloader with fewer than ten batches does not cause a
# modulo-by-zero error.
log_every = max(1, num_batches // 10)

for epoch in range(3):

    for batch, data in tqdm(enumerate(dataloader), total=num_batches):

        # Torch requirement: clear gradients accumulated by the previous step
        model.zero_grad()

        # Compute prediction and loss (the first element of the output is used as the loss)
        outputs = model(**data.to(device))
        loss = outputs[0]

        # Backpropagation, with the gradient norm clipped to 1.0
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        if batch % log_every == 0:
            # Use a separate name for the scalar so `loss` stays a tensor;
            # a later cell still calls `loss.item()` on it.
            loss_value, current = loss.item(), batch * len(data['labels'])
            print(f"loss: {loss_value:>7f}  [{current:>5d}/{size:>5d}]")

In [None]:
# Serialise the whole model object (not only its weights) to a file in the working directory
torch.save(obj=model, f='baseline_unfiltered.pth')

In [None]:
# check last batch loss
# float() accepts both a one-element tensor and a plain Python float, so this
# works even when the training loop has already converted `loss` to a number.
float(loss)