# Filtering the SNLI Dataset using AFLite
### 1. Imports and GPU set up

In [1]:
from datasets import load_dataset
from tqdm.notebook import tqdm
from transformers import GPT2TokenizerFast, GPT2ForSequenceClassification, \
    DataCollatorWithPadding, TrainingArguments, set_seed
import torch
from torch.nn.functional import one_hot
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Fixed seed so that repeated runs of the notebook are comparable.
set_seed(42)

### 2. Pre-Processing
- Get SNLI Dataset (Train fold) and shuffle it
- One-hot encoding for labels
- Remove instances without gold standard labels, i.e., label = -1
- Partition data 10%/90% into `warmup` and `train_`
- Tokenise warmup

In [2]:
# Load the SNLI training fold in a fixed, seeded shuffled order.
snli_train = load_dataset('snli', split='train').shuffle(seed=42)


def _has_gold_label(example):
    """Return True when the example carries a gold label (missing labels are -1)."""
    return example['label'] != -1


def _one_hot_labels(batch):
    """Turn a batch of integer class ids into float32 one-hot rows of width 3."""
    encoded = one_hot(torch.tensor(batch['label']), 3)
    return {'label': encoded.type(torch.float32).numpy()}


# Drop unlabelled instances first, then one-hot encode the remaining labels.
snli_train = snli_train.filter(_has_gold_label)
snli_train = snli_train.map(_one_hot_labels, batched=True)

# First tenth of the shuffled data -> warmup, remaining nine tenths -> train_.
warmup_size = int(len(snli_train) / 10)
warmup = snli_train.select(range(0, warmup_size))
train_ = snli_train.select(range(warmup_size, len(snli_train)))

Reusing dataset snli (/home/shana92/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)
Loading cached shuffled indices for dataset at /home/shana92/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-3c48a07b49c48dd6.arrow
Loading cached processed dataset at /home/shana92/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-ed3322adf0d6443a.arrow


  0%|          | 0/550 [00:00<?, ?ba/s]

In [3]:
# Pad on the left because GPT2 uses the last token for prediction, so the real
# tokens have to sit at the end of the sequence.
# `padding` / `truncation` are options of the tokenizer *call*, not of loading;
# passing them to `from_pretrained` does not make later calls pad or truncate,
# so they are intentionally not given here.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", padding_side='left')
tokenizer.pad_token = tokenizer.eos_token  # pad with the 'eos' token

In [4]:
# tokenize data: premise and hypothesis are joined with '|' into one sequence.
# Truncate explicitly to 120 tokens, the fixed length batches are padded to:
# padding never shortens a sequence, so a longer example would otherwise leave
# a batch with rows of different lengths.
warmup = warmup.map(
    lambda x: tokenizer(x['premise'] + '|' + x['hypothesis'],
                        truncation=True, max_length=120))

  0%|          | 0/54936 [00:00<?, ?ex/s]

In [5]:
# Expose only the columns the model consumes and return them as PyTorch tensors.
# Note: set_format changes `warmup` in place.
model_columns = ['label', 'input_ids', 'attention_mask']
warmup.set_format(type='torch', columns=model_columns)

### 3. Model:

In [6]:
# Data collator: a callable helper that assembles the batches sent to the model.
# Every batch is padded to a fixed length of 120 tokens and returned as PyTorch tensors.
# Docs: https://huggingface.co/docs/transformers/main_classes/data_collator
data_collator = DataCollatorWithPadding(
    tokenizer,
    padding='max_length',
    max_length=120,
    return_tensors='pt',
)

In [7]:
# GPT2 with a 3-way classification head. Labels are one-hot float vectors,
# hence the "multi_label_classification" problem type.
model = GPT2ForSequenceClassification.from_pretrained(
    "gpt2",
    num_labels=3,
    problem_type="multi_label_classification",
)

# Resize input token embeddings matrix if num_tokens != config.vocab_size. - Source: HuggingFace
vocab_size = len(tokenizer)
model.resize_token_embeddings(vocab_size)

# Tell the model which token id is padding: the tokenizer pads with 'eos'.
model.config.pad_token_id = model.config.eos_token_id

Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
# Batch generator over the warmup split: shuffled batches of 92 examples,
# each assembled by the data collator.
dataloader = torch.utils.data.DataLoader(
    dataset=warmup,
    batch_size=92,
    shuffle=True,
    collate_fn=data_collator,
)

In [9]:
# move model to device (CUDA when available, otherwise CPU).
# As the last expression of the cell, the returned module is echoed as the cell output.
model.to(device)

GPT2ForSequenceClassification(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid

In [10]:
# Adam over all model parameters with a small learning rate for fine-tuning.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-5)

# NOTE(review): the training loop in this notebook takes its loss from the model
# output and never calls `loss_fn`; kept in case a later cell relies on it.
loss_fn = torch.nn.CrossEntropyLoss()

In [11]:
# Train: fine-tune the model on the warmup split for 3 epochs.
model.train()

size = len(dataloader.dataset)

# Report the loss roughly ten times per epoch. max(1, ...) keeps the interval
# valid when the dataloader has fewer than 10 batches (a plain
# `len(dataloader) / 10` would round down to 0 and make the modulo below fail).
log_every = max(1, len(dataloader) // 10)

for epoch in range(3):

    for batch, data in tqdm(enumerate(dataloader), total=len(dataloader)):

        # Clear the gradients left over from the previous step.
        model.zero_grad()

        # Forward pass; the first element of the model output is used as the loss
        # (the batch carries its `labels`).
        outputs = model(**data.to(device))
        loss = outputs[0]

        # Backpropagation, with the gradient norm clipped to 1.0 before the update.
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        if batch % log_every == 0:
            # `current` approximates the number of examples seen so far in this epoch.
            current = batch * len(data['labels'])
            # Read the scalar without rebinding `loss`, which stays a tensor.
            print(f"loss: {loss.item():>7f}  [{current:>5d}/{size:>5d}]")

  0%|          | 0/598 [00:00<?, ?it/s]

loss: 0.742174  [    0/54936]
loss: 0.640142  [ 5428/54936]
loss: 0.646829  [10856/54936]
loss: 0.635300  [16284/54936]
loss: 0.624931  [21712/54936]
loss: 0.623921  [27140/54936]
loss: 0.641142  [32568/54936]
loss: 0.632342  [37996/54936]
loss: 0.624371  [43424/54936]
loss: 0.642328  [48852/54936]
loss: 0.613522  [54280/54936]


  0%|          | 0/598 [00:00<?, ?it/s]

loss: 0.584989  [    0/54936]
loss: 0.616870  [ 5428/54936]
loss: 0.599203  [10856/54936]
loss: 0.562805  [16284/54936]
loss: 0.541916  [21712/54936]
loss: 0.531023  [27140/54936]
loss: 0.546053  [32568/54936]
loss: 0.551011  [37996/54936]
loss: 0.470505  [43424/54936]
loss: 0.532148  [48852/54936]
loss: 0.495916  [54280/54936]


  0%|          | 0/598 [00:00<?, ?it/s]

loss: 0.482832  [    0/54936]
loss: 0.497932  [ 5428/54936]
loss: 0.471682  [10856/54936]
loss: 0.446039  [16284/54936]
loss: 0.532555  [21712/54936]
loss: 0.478110  [27140/54936]
loss: 0.532668  [32568/54936]
loss: 0.454391  [37996/54936]
loss: 0.416018  [43424/54936]
loss: 0.467112  [48852/54936]
loss: 0.464416  [54280/54936]


In [12]:
# Persist the fine-tuned weights (state dict only) to 'feature_rep.pth'.
feature_rep_weights = model.state_dict()
torch.save(feature_rep_weights, 'feature_rep.pth')