Before getting into the model and fine-tuning, we need to download some data to fine-tune the model on.

In [None]:
from datasets import load_dataset

# load the first 1K rows of the TREC dataset
trec = load_dataset('trec', split='train[:1000]')
trec

In [None]:
trec[0]

There are also a few data preparation steps to perform: tokenize the input text (`text`), one-hot encode the labels (`label-coarse`), and then wrap both in a PyTorch dataset and dataloader.

For tokenization we need a *tokenizer*; we will use the `bert-base-uncased` tokenizer from Hugging Face (HF).

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# the dataset is small, so we can tokenize everything in one call
tokens = tokenizer(
    trec['text'], max_length=512,
    truncation=True, padding='max_length'
)

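This returns a `BatchEncoding` object. With a fast tokenizer (the default for `AutoTokenizer`), indexing it by position returns the underlying `Encoding` objects, one per input text: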

In [None]:
tokens[:2]

We can then access an individual sample's components, for example its token IDs:

In [None]:
tokens[0].ids
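With `padding='max_length'` and `truncation=True`, every `ids` list has exactly 512 entries, and `attention_mask` marks which positions are real tokens (1) and which are padding (0). The hypothetical helper `pad_and_truncate` below mimics that behaviour on a plain list of IDs, assuming BERT's `[PAD]` ID of 0. It is a simplified sketch, not how `transformers` implements it: the real tokenizer also keeps room for special tokens such as `[CLS]` and `[SEP]` when truncating.

```python
def pad_and_truncate(ids, max_length, pad_id=0):
    """Hypothetical helper: fixed-length ids plus a matching attention mask."""
    # truncation: keep at most max_length token IDs
    ids = list(ids[:max_length])
    # attention mask: 1 for every real token that was kept
    mask = [1] * len(ids)
    # padding: fill the remainder with pad_id, masked out with 0
    n_pad = max_length - len(ids)
    return ids + [pad_id] * n_pad, mask + [0] * n_pad

# made-up token IDs (101 and 102 are [CLS] and [SEP] in bert-base-uncased)
padded_ids, padded_mask = pad_and_truncate([101, 7, 8, 9, 102], max_length=8)
print(padded_ids)   # [101, 7, 8, 9, 102, 0, 0, 0]
print(padded_mask)  # [1, 1, 1, 1, 1, 0, 0, 0]
```

TREC questions are short, so most of each 512-position row is likely padding; `padding='longest'` would produce shorter rows.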

Next we one-hot encode our labels.

In [None]:
import numpy as np

# initialize a zero array of shape (num_samples, num_classes)
labels = np.zeros(
    (len(trec), max(trec['label-coarse'])+1)
)
# one-hot encode
labels[np.arange(len(trec)), trec['label-coarse']] = 1
labels[:5]
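The one-hot step relies on NumPy fancy indexing: pairing `np.arange(n)` (row indices) with the label list (column indices) sets exactly one cell per row. A minimal sketch on hypothetical toy labels, with `np.eye` shown as an equivalent alternative:

```python
import numpy as np

# hypothetical toy labels standing in for trec['label-coarse']
toy_labels = [0, 2, 1, 2]
num_classes = max(toy_labels) + 1

# fancy indexing: row i gets a 1 in column toy_labels[i]
one_hot = np.zeros((len(toy_labels), num_classes))
one_hot[np.arange(len(toy_labels)), toy_labels] = 1

# equivalent alternative: pick rows of the identity matrix
one_hot_eye = np.eye(num_classes)[toy_labels]

print(one_hot)
```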

In [None]:
import torch

labels = torch.Tensor(labels)
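`torch.Tensor(...)` and `torch.tensor(...)` behave differently: the former builds a tensor of the default dtype (float32 unless changed), while the latter infers the dtype from the data, so a NumPy float64 array stays float64. `torch.from_numpy` shares memory with the array instead of copying. The dtype matters here because it decides what the model later sees as labels. A small sketch:

```python
import numpy as np
import torch

arr = np.zeros((2, 3))  # NumPy defaults to float64

# torch.Tensor: default dtype (float32), copied from the float64 array
as_float32 = torch.Tensor(arr)
# torch.tensor: dtype inferred from the data (float64 here), also a copy
as_float64 = torch.tensor(arr)
# torch.from_numpy: shares the array's memory, so later edits show through
shared = torch.from_numpy(arr)
arr[0, 0] = 1.0

print(as_float32.dtype, as_float64.dtype)            # torch.float32 torch.float64
print(shared[0, 0].item(), as_float32[0, 0].item())  # 1.0 0.0
```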

Now we're ready to create the dataset object.

In [None]:
class TrecDataset(torch.utils.data.Dataset):
    def __init__(self, tokens, labels):
        self.tokens = tokens
        self.labels = labels

    def __getitem__(self, idx):
        input_ids = self.tokens[idx].ids
        attention_mask = self.tokens[idx].attention_mask
        labels = self.labels[idx]
        return {
            'input_ids': torch.tensor(input_ids),
            'attention_mask': torch.tensor(attention_mask),
            'labels': labels  # already a float tensor; re-wrapping would copy and warn
        }

    def __len__(self):
        return len(self.labels)

dataset = TrecDataset(tokens, labels)

In [None]:
loader = torch.utils.data.DataLoader(
    dataset, batch_size=64
)
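The `DataLoader` calls `__getitem__` for each index and, with its default collate function, stacks each dict key along a new leading batch dimension; the last batch is smaller when the dataset size is not a multiple of `batch_size`. A minimal sketch with a hypothetical `ToyDataset` standing in for `TrecDataset`:

```python
import torch

class ToyDataset(torch.utils.data.Dataset):
    """Hypothetical stand-in for TrecDataset: n fixed-length rows."""
    def __init__(self, n=10, seq_len=5, num_classes=3):
        self.input_ids = torch.arange(n * seq_len).reshape(n, seq_len)
        self.labels = torch.nn.functional.one_hot(
            torch.arange(n) % num_classes, num_classes
        ).float()

    def __getitem__(self, idx):
        # one sample: a dict of tensors, like TrecDataset returns
        return {'input_ids': self.input_ids[idx], 'labels': self.labels[idx]}

    def __len__(self):
        return len(self.labels)

toy_loader = torch.utils.data.DataLoader(ToyDataset(), batch_size=4)
batches = list(toy_loader)
# default collate stacks each key: shapes (4, 5), (4, 5), (2, 5)
print([tuple(b['input_ids'].shape) for b in batches])
```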

Now let's set up a BERT model; we'll use it with our TREC data to compare training-step time on CPU vs MPS (PyTorch's Apple Metal Performance Shaders backend).

In [None]:
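from transformers import BertForSequenceClassification, BertConfig

config = BertConfig.from_pretrained('bert-base-uncased')
config.num_labels = max(trec['label-coarse'])+1
# load the pretrained BERT weights; BertForSequenceClassification(config)
# on its own would initialize every layer randomly
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased', config=config
)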
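The labels' dtype steers which loss the model uses. In recent `transformers` versions, when `config.problem_type` is unset and `num_labels > 1`, `BertForSequenceClassification` treats `long`/`int` labels as single-label classification (`CrossEntropyLoss`) and any other dtype, including our float one-hot tensor, as multi-label classification (`BCEWithLogitsLoss`). The sketch below does not call `transformers`; it computes both losses directly in PyTorch on made-up logits so the difference is visible.

```python
import torch

# made-up logits for a batch of 4 samples and 3 classes
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 1.5, 0.3],
                       [-0.5, 0.2, 2.2],
                       [1.0, 1.0, 1.0]])
class_ids = torch.tensor([0, 1, 2, 0])  # integer (long) labels
one_hot = torch.nn.functional.one_hot(class_ids, num_classes=3).float()

# float one-hot labels -> multi-label style: independent sigmoid per class
bce = torch.nn.BCEWithLogitsLoss()(logits, one_hot)
# long class indices -> single-label style: softmax across classes
ce = torch.nn.CrossEntropyLoss()(logits, class_ids)

print(f'BCEWithLogits: {bce.item():.4f}  CrossEntropy: {ce.item():.4f}')
```

If the single-label setup is preferred, one option is to pass `labels.argmax(dim=1)` (a long tensor) instead of the one-hot rows.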

Fine-tuning the entire BERT model on a first-gen M1 Mac is not going to work, but we can fine-tune the classification head. Let's test that by freezing every parameter in `model.bert` (the encoder layers and the pooler), which leaves only the final linear classification layer to be trained.

In [None]:
for param in model.bert.parameters():
    param.requires_grad = False
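A quick way to confirm what is still trainable is to count parameters with `requires_grad=True`. The sketch below uses a hypothetical `ToyClassifier` (a linear "encoder" plus a linear head) so it runs without downloading BERT; the same two `sum(...)` lines work on `model`.

```python
import torch

class ToyClassifier(torch.nn.Module):
    """Hypothetical stand-in: an encoder body plus a linear head."""
    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Linear(16, 8)    # plays the role of model.bert
        self.classifier = torch.nn.Linear(8, 3)  # plays the role of the head

    def forward(self, x):
        return self.classifier(torch.relu(self.encoder(x)))

toy = ToyClassifier()
for param in toy.encoder.parameters():
    param.requires_grad = False  # freeze the body, as with model.bert

total = sum(p.numel() for p in toy.parameters())
trainable = sum(p.numel() for p in toy.parameters() if p.requires_grad)
print(f'{trainable} trainable of {total} parameters')  # 27 trainable of 163 parameters
```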

Training prep

In [None]:
# activate training mode of model
model.train()

# initialize Adam optimizer (no weight decay); frozen params get no gradients,
# so only the classifier is updated
optim = torch.optim.Adam(model.parameters(), lr=5e-5)
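Passing all of `model.parameters()` to Adam works because frozen parameters never receive a `.grad`, and PyTorch optimizers skip parameters whose gradient is `None`. Filtering to trainable parameters is a slightly more explicit alternative. A small check on hypothetical toy layers (`body` frozen, `head` trainable):

```python
import torch

torch.manual_seed(0)
body = torch.nn.Linear(4, 4)  # hypothetical frozen part
head = torch.nn.Linear(4, 2)  # hypothetical trainable part
for p in body.parameters():
    p.requires_grad = False

# explicit variant: hand the optimizer only the trainable parameters
toy_optim = torch.optim.Adam(
    [p for p in list(body.parameters()) + list(head.parameters()) if p.requires_grad],
    lr=5e-5,
)

body_before = body.weight.clone()
head_before = head.weight.clone()

toy_loss = head(body(torch.randn(8, 4))).pow(2).mean()
toy_optim.zero_grad()
toy_loss.backward()
toy_optim.step()

# frozen weights unchanged, trainable weights updated
print(torch.equal(body.weight, body_before), torch.equal(head.weight, head_before))  # True False
```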

In [None]:
from time import time
from tqdm.auto import tqdm

loop_time = []

# setup loop (using tqdm for the progress bar)
loop = tqdm(loader, leave=True)
for batch in loop:
    t0 = time()
    # reset gradients accumulated in the previous step
    optim.zero_grad()
    # forward pass; HF models return the loss when 'labels' are in the batch
    outputs = model(**batch)
    # extract loss
    loss = outputs[0]
    # backpropagate: compute gradients for every parameter with requires_grad=True
    loss.backward()
    # update parameters
    optim.step()
    loop_time.append(time()-t0)
    # print relevant info to progress bar
    loop.set_postfix(loss=loss.item())
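The loop above runs on the CPU. For the MPS side of the comparison, the model and every batch need to move to the `mps` device, and because MPS kernels run asynchronously, the clock should be read only after the device has finished. A minimal sketch, assuming PyTorch 1.12+ for `torch.backends.mps` and 2.0+ for `torch.mps.synchronize` (guarded with `hasattr`); `pick_device`, `sync` and `timed_step` are hypothetical helpers, shown on a toy linear model:

```python
from time import perf_counter

import torch

def pick_device():
    # use the MPS backend when this PyTorch build and machine support it
    if torch.backends.mps.is_available():
        return torch.device('mps')
    return torch.device('cpu')

def sync(device):
    # wait for queued MPS work before reading the clock
    if device.type == 'mps' and hasattr(torch, 'mps'):
        torch.mps.synchronize()

def timed_step(model, optim, loss_fn, x, y, device):
    """Hypothetical helper: run one training step on `device`, return seconds."""
    x, y = x.to(device), y.to(device)
    sync(device)
    t0 = perf_counter()
    optim.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optim.step()
    sync(device)
    return perf_counter() - t0

device = pick_device()
toy_model = torch.nn.Linear(32, 6).to(device)  # move before building the optimizer
toy_optim = torch.optim.Adam(toy_model.parameters(), lr=5e-5)
toy_times = [
    timed_step(toy_model, toy_optim, torch.nn.CrossEntropyLoss(),
               torch.randn(64, 32), torch.randint(0, 6, (64,)), device)
    for _ in range(3)
]
print(device, toy_times)
```

For the BERT loop, the equivalent moves would be `model.to(device)` before creating the optimizer and `batch = {k: v.to(device) for k, v in batch.items()}` inside the loop. `perf_counter` is used because it is a monotonic, high-resolution clock.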

In [None]:
loop_time
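The first step often includes one-off setup costs, so it can help to summarize `loop_time` with the warm-up step excluded. A minimal stdlib sketch; `summarize_step_times` is a hypothetical helper, and it raises `statistics.StatisticsError` if `warmup` leaves no steps:

```python
from statistics import mean, median

def summarize_step_times(times, warmup=1):
    """Hypothetical helper: drop the first `warmup` steps, then summarize."""
    steady = times[warmup:]
    return {
        'steps': len(steady),
        'mean_s': mean(steady),
        'median_s': median(steady),
        'total_s': sum(times),  # includes the warm-up steps
    }

print(summarize_step_times([2.5, 1.0, 1.5, 0.5]))
# {'steps': 3, 'mean_s': 1.0, 'median_s': 1.0, 'total_s': 5.5}
```

Applied to the notebook, `summarize_step_times(loop_time)` gives one number per device that is easier to compare than the raw list.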

---