# Finetune a BERT model on the SWAG Dataset

Task Description: Multiple choice. Given a context and several candidate answers, the model must select the correct candidate.

Original Tutorial: https://huggingface.co/docs/transformers/tasks/multiple_choice

In [None]:
!pip install -q transformers datasets evaluate accelerate

# Load SWAG dataset

In [None]:
from datasets import load_dataset

swag = load_dataset("swag", "regular")

In [None]:
# Look at the data
import pprint
pprint.pprint(swag['train'][0])

# Each example has a context (sent1, sent2), four candidate endings (ending0-ending3), and a label for the correct ending


In [None]:
# Preprocessing
## Load the tokenizer
from transformers import AutoTokenizer

checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Preprocessing
We need to create a preprocess function that we will apply to every instance in the dataset. The preprocess function needs to:

1. Make four copies of the sent1 field, one per candidate ending, so that every ending is paired with the same opening sentence.

2. Combine sent2 with each of the four possible sentence endings.

3. Flatten these two lists so you can tokenize them, and then unflatten them afterward so each example has a corresponding input_ids, attention_mask, and labels field.

In [None]:
ending_names = ["ending0", "ending1", "ending2", "ending3"]


def preprocess_function(examples):
    """Tokenize a batch of SWAG examples as (context, candidate ending) pairs.

    Args:
        examples: Batch of dataset columns. Reads ``sent1``, ``sent2`` and
            every column named in ``ending_names``.

    Returns:
        Dict mapping each tokenizer output field to a list with one entry per
        example; each entry holds one tokenized sequence per candidate ending.
    """
    num_choices = len(ending_names)

    # Build the flat lists directly (example-major, ending-minor order)
    # instead of nesting and flattening with sum(..., []), which is quadratic.
    first_sentences = [context for context in examples["sent1"] for _ in range(num_choices)]
    second_sentences = [
        f"{header} {examples[end][i]}"
        for i, header in enumerate(examples["sent2"])
        for end in ending_names
    ]

    tokenized_examples = tokenizer(first_sentences, second_sentences, truncation=True)

    # Regroup the flat tokenizer output into chunks of num_choices per example.
    return {
        k: [v[i : i + num_choices] for i in range(0, len(v), num_choices)]
        for k, v in tokenized_examples.items()
    }

In [None]:
# Apply preprocessing over the entire dataset; batched=True passes multiple examples to the function per call
tokenized_swag = swag.map(preprocess_function, batched = True)

In [None]:
tokenized_swag['train'][0]

In [None]:
# Create a batch of examples, with dynamic padding. Create the appropriate collator function


from dataclasses import dataclass
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
from typing import Optional, Union
import torch


@dataclass
class DataCollatorForMultipleChoice:
    """
    Data collator that dynamically pads multiple-choice inputs.

    Each feature is expected to hold, for every tokenizer field, one sequence
    per candidate choice, plus a ``label`` (or ``labels``) entry. The choices
    are flattened so the tokenizer can pad them together, then reshaped back to
    ``(batch_size, num_choices, -1)``.
    """

    # Tokenizer whose ``pad`` method is used to pad the flattened choices.
    tokenizer: PreTrainedTokenizerBase
    # The three options below are forwarded unchanged to ``tokenizer.pad``.
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        """Collate ``features`` into a padded batch of tensors.

        Args:
            features: Non-empty list of per-example dicts as described on the
                class. The dicts are not modified.

        Returns:
            Dict of tensors shaped ``(batch_size, num_choices, -1)`` plus a
            ``labels`` tensor of dtype ``torch.int64``.
        """
        label_name = "label" if "label" in features[0] else "labels"
        # Read the labels instead of popping them so the caller's feature
        # dicts stay intact and can be collated again.
        labels = [feature[label_name] for feature in features]
        batch_size = len(features)
        num_choices = len(features[0]["input_ids"])

        # One dict per (example, choice) pair, built flat in a single pass;
        # the label is left out because it is not a per-choice field.
        flattened_features = [
            {k: v[i] for k, v in feature.items() if k != label_name}
            for feature in features
            for i in range(num_choices)
        ]

        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # Restore the (example, choice) structure that was flattened above.
        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
        batch["labels"] = torch.tensor(labels, dtype=torch.int64)
        return batch

# Evaluate

We want to create a `compute_metrics` function that monitors a metric during training. For this task, use the accuracy metric.

In [None]:
import evaluate

accuracy = evaluate.load("accuracy")

In [None]:
import numpy as np


def compute_metrics(eval_pred):
    """Score an evaluation pass with the accuracy metric.

    Args:
        eval_pred: Pair of (model scores, reference labels).

    Returns:
        Whatever ``accuracy.compute`` returns for the arg-max predictions.
    """
    scores, references = eval_pred
    # The index of the highest score along axis 1 is the predicted choice.
    predicted = np.argmax(scores, axis=1)
    return accuracy.compute(predictions=predicted, references=references)

# Train using the Trainer API
The main training steps are:

1. Define the training hyperparameters in TrainingArguments. At the end of each epoch, the Trainer will evaluate the accuracy metric and save a training checkpoint.

2. Pass the training arguments to a Trainer function alongside the model, dataset, tokenizer, data collator.

3. Call train() to finetune the model

In [None]:
from transformers import AutoModelForMultipleChoice, TrainingArguments, Trainer

# Reuse the checkpoint chosen for the tokenizer so the model and tokenizer always match.
model = AutoModelForMultipleChoice.from_pretrained(checkpoint)

In [None]:
# Hyperparameters for fine-tuning: evaluation and checkpointing both run once per epoch.
training_args = TrainingArguments(
    output_dir="my_awesome_swag_model",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    # NOTE(review): presumably relies on save_strategy matching evaluation_strategy - confirm in the TrainingArguments docs
    load_best_model_at_end=True,
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    # NOTE(review): presumably requires a Hugging Face Hub login; no login step is visible in this notebook - confirm
    push_to_hub=True,
)

# Wire together the model, the tokenized train/validation splits, the
# dynamic-padding collator and the accuracy metric.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_swag["train"],
    eval_dataset=tokenized_swag["validation"],
    tokenizer=tokenizer,
    data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),
    compute_metrics=compute_metrics,
)

# Run the fine-tuning loop.
trainer.train()

In [None]:
trainer.save_model("swag_multiple_model")

In [None]:
# In this case, the tokenizer was not saved automatically, save it manually in the model folder for inference
tokenizer.save_pretrained("swag_multiple_model", legacy_format=False)

# Inference

Use model for inference using PyTorch

In [None]:
prompt = "France has a bread law, Le Décret Pain, with strict rules on what is allowed in a traditional baguette."
candidate1 = "The law does not apply to croissants and brioche."
candidate2 = "The law applies to baguettes."

In [None]:
from transformers import AutoTokenizer

# Reload the tokenizer from the directory the fine-tuned model was saved to.
tokenizer = AutoTokenizer.from_pretrained("swag_multiple_model")
# Encode each (prompt, candidate) pair; padding=True lets both pairs be returned as one tensor.
inputs = tokenizer([[prompt, candidate1], [prompt, candidate2]], return_tensors="pt", padding=True)
# Index of the expected answer (0 -> candidate1), with a batch dimension added.
labels = torch.tensor(0).unsqueeze(0)

In [None]:
from transformers import AutoModelForMultipleChoice

# Load the fine-tuned model from the directory it was saved to.
model = AutoModelForMultipleChoice.from_pretrained("swag_multiple_model")
# unsqueeze(0) adds a leading batch dimension to every input tensor.
# Gradients are not needed for inference, so skip building the autograd graph.
with torch.no_grad():
    outputs = model(**{k: v.unsqueeze(0) for k, v in inputs.items()}, labels=labels)
logits = outputs.logits

In [None]:
predicted_class = logits.argmax().item()
predicted_class