# Training a Baseline Model: GPT-2 (medium) fine-tuned for NLI on a random 190k-example subset of the SNLI training data

### 1. Imports and Global Settings

In [1]:
from datasets import load_dataset, disable_caching
from transformers import GPT2ForSequenceClassification, GPT2TokenizerFast, DataCollatorWithPadding, set_seed
import torch
from torch.nn.functional import one_hot
import sys
sys.path.append('..')
from utils_ import tokenize, train_classifier
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Global reproducibility / caching settings for this notebook.
set_seed(42)
disable_caching()

### 2. Pre-Processing
- Get SNLI Dataset (Train fold) and shuffle it
- Remove instances without gold standard labels, i.e., label = -1
- Subset data to get 190k samples
- One-hot encoding for labels
- Tokenise data and exclude sequences with more than 128 tokens

In [2]:
# Load the SNLI train split and shuffle it with a fixed seed so that the
# subset taken below is a reproducible random sample.
train = load_dataset('snli', split='train').shuffle(seed=42)

# Drop instances without a gold-standard label (label == -1) *before*
# subsetting, so that exactly 190k labelled instances remain.
train = train.filter(lambda x: x['label'] != -1)
train = train.select(range(190000))

# One-hot encode the labels as float32 vectors over the 3 classes.
# This runs after the subset is taken so that only the 190k retained
# instances are encoded (the mapping is element-wise, so the result is the
# same as encoding first and subsetting afterwards).
train = train.map(
    lambda x: {'label': one_hot(torch.tensor(x['label']), 3).type(torch.float32).numpy()},
    batched=True,
)

Reusing dataset snli (/home/shana92/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)


  0%|          | 0/551 [00:00<?, ?ba/s]

  0%|          | 0/550 [00:00<?, ?ba/s]

In [3]:
# Tokeniser for gpt2-medium.
# Padding goes on the left because GPT2 uses the last token for prediction,
# so the real tokens have to sit at the end of the sequence.
# NOTE(review): `padding` / `truncation` are normally arguments of the
# tokeniser *call*; confirm they have any effect when passed at load time.
tokenizer_options = {'padding_side': 'left', 'padding': True, 'truncation': True}
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2-medium", **tokenizer_options)

# GPT2 has no dedicated padding token: pad with the 'eos' token instead.
tokenizer.pad_token = tokenizer.eos_token

In [4]:
# Tokenise each instance as a single "premise|hypothesis" string.
train = train.map(lambda x: tokenize(tokenizer, x['premise'] + '|' + x['hypothesis']))
len_bef_exclusion = len(train)

# exclude instances with > 128 tokens
# (relies on the 'exclude' column, presumably added by `tokenize` - confirm in utils_)
train = train.filter(lambda x: x['exclude'] == False)
len_aft_exclusion = len(train)

# print message if instances were in fact excluded
n_excluded = len_bef_exclusion - len_aft_exclusion
if n_excluded > 0:
    # Share of the tokenised set that was dropped. The denominator is the
    # size *before* exclusion: this is the actual percentage excluded, and it
    # cannot be zero here (n_excluded > 0), even if every sequence is dropped.
    pct_excluded = n_excluded / len_bef_exclusion * 100
    print(f'{n_excluded} ({pct_excluded:.2f}%) sequences excluded')

  0%|          | 0/190000 [00:00<?, ?ex/s]

  0%|          | 0/190 [00:00<?, ?ba/s]

In [5]:
# Expose only the columns needed for training, formatted as PyTorch tensors.
model_columns = ['label', 'input_ids', 'attention_mask']
train.set_format(type='torch', columns=model_columns)

### 3. Model

In [6]:
# Data collator - https://huggingface.co/docs/transformers/main_classes/data_collator
# A callable helper object that sends batches of data to the model; here it is
# configured to pad to a fixed length of 128 tokens and return PyTorch tensors.
data_collator = DataCollatorWithPadding(
    tokenizer,
    padding='max_length',
    max_length=128,
    return_tensors='pt',
)

In [7]:
# Training hyper-parameters
lr = 1e-5        # learning rate, set to match Le et al. (2020) - https://arxiv.org/abs/2002.04108
batch_size = 32  # limited by the available GPU memory

In [8]:
# Batch generator over the training set: shuffling is enabled and batches are
# assembled by the data collator. The batch size is constrained by GPU memory.
dataloader = torch.utils.data.DataLoader(
    train,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=data_collator,
)

In [9]:
# GPT2 (medium) classifier with a 3-label classification head.
# The multi-label problem type goes together with the one-hot float labels
# built during pre-processing.
model = GPT2ForSequenceClassification.from_pretrained(
    "gpt2-medium",
    problem_type="multi_label_classification",
    num_labels=3,
)

# The tokenizer pads with the 'eos' token, so register that id as the model's pad token.
model.config.pad_token_id = model.config.eos_token_id

Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2-medium and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
# Adam optimizer over all model parameters, using the learning rate set above.
# No separate loss function is created: the loss is in-built in the model.
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

In [11]:
# Train the classifier, then save whatever `train_classifier` returns to disk.
training_result = train_classifier(model, dataloader, optimizer, device)
torch.save(training_result, 'baseline_random_190k.pth')

  0%|          | 0/5938 [00:00<?, ?it/s]

loss: 1.086189  [    0/190000]
loss: 0.638672  [18976/190000]
loss: 0.619112  [37952/190000]
loss: 0.638432  [56928/190000]
loss: 0.532313  [75904/190000]
loss: 0.479692  [94880/190000]
loss: 0.234584  [113856/190000]
loss: 0.217729  [132832/190000]
loss: 0.327121  [151808/190000]
loss: 0.340143  [170784/190000]
loss: 0.316959  [189760/190000]
Epoch average loss: 0.4904515743255615


  0%|          | 0/5938 [00:00<?, ?it/s]

loss: 0.381749  [    0/190000]
loss: 0.283529  [18976/190000]
loss: 0.255294  [37952/190000]
loss: 0.273739  [56928/190000]
loss: 0.295765  [75904/190000]
loss: 0.202576  [94880/190000]
loss: 0.272400  [113856/190000]
loss: 0.310980  [132832/190000]
loss: 0.269270  [151808/190000]
loss: 0.245539  [170784/190000]
loss: 0.322159  [189760/190000]
Epoch average loss: 0.3022118806838989


  0%|          | 0/5938 [00:00<?, ?it/s]

loss: 0.239778  [    0/190000]
loss: 0.215871  [18976/190000]
loss: 0.243856  [37952/190000]
loss: 0.352319  [56928/190000]
loss: 0.343499  [75904/190000]
loss: 0.222338  [94880/190000]
loss: 0.098706  [113856/190000]
loss: 0.187588  [132832/190000]
loss: 0.304597  [151808/190000]
loss: 0.275192  [170784/190000]
loss: 0.314682  [189760/190000]
Epoch average loss: 0.2570127844810486
Done!
