# Training Bert with Torch MPS

[ipynb file](https://github.com/jamescalam/pytorch-mps/blob/main/code/01mps_training_bert.ipynb)

In [2]:
import platform
platform.platform()

'macOS-13.0-arm64-arm-64bit'

In [3]:
import torch
torch.has_mps

True

In [4]:
from datasets import load_dataset  # pip install datasets

# load the first 1K rows of the TREC dataset
trec = load_dataset('trec', split='train[:1000]')
trec

Downloading builder script:   0%|          | 0.00/5.09k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/3.34k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

Downloading and preparing dataset trec/default to /Users/samgreen/.cache/huggingface/datasets/trec/default/2.0.0/f2469cab1b5fceec7249fda55360dfdbd92a7a5b545e91ea0f78ad108ffac1c2...


Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/336k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/23.4k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/5452 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/500 [00:00<?, ? examples/s]

Dataset trec downloaded and prepared to /Users/samgreen/.cache/huggingface/datasets/trec/default/2.0.0/f2469cab1b5fceec7249fda55360dfdbd92a7a5b545e91ea0f78ad108ffac1c2. Subsequent calls will reuse this data.


Dataset({
    features: ['text', 'coarse_label', 'fine_label'],
    num_rows: 1000
})

In [10]:
trec[0]


{'text': 'How did serfdom develop in and then leave Russia ?',
 'coarse_label': 2,
 'fine_label': 26}

In [6]:
from transformers import AutoTokenizer, AutoModel  # pip install transformers

# initialize BERT tokenizer and model
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModel.from_pretrained('bert-base-uncased')

# take the first 64 rows of the trec data
text = trec['text'][:64]
# tokenize text using the BERT tokenizer
tokens = tokenizer(
    text, max_length=512,
    truncation=True, padding=True,
    return_tensors='pt'
)

Downloading:   0%|          | 0.00/420M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [13]:
tokens[0].ids

[101,
 2129,
 2106,
 14262,
 2546,
 9527,
 4503,
 1999,
 1998,
 2059,
 2681,
 3607,
 1029,
 102,
 0,
 0,
 0,
 0,
 0,
 0,
 0]

In [7]:
device = torch.device('mps')
model.to(device)
tokens.to(device)
device

device(type='mps')

In [8]:
%%timeit
model(**tokens)

122 ms ± 32.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [15]:
# one-hot encode the labels
import numpy as np

# initialize array to be used
labels = np.zeros(
    (len(trec), max(trec['coarse_label'])+1)
)
# one-hot encode
labels[np.arange(len(trec)), trec['coarse_label']] = 1
labels[:5]

array([[0., 0., 1., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0.]])

In [16]:
labels = torch.Tensor(labels)


In [18]:
class TrecDataset(torch.utils.data.Dataset):
    def __init__(self, tokens, labels):
        self.tokens = tokens
        self.labels = labels

    def __getitem__(self, idx):
        input_ids = self.tokens[idx].ids
        attention_mask = self.tokens[idx].attention_mask
        labels = self.labels[idx]
        return {
            'input_ids': torch.tensor(input_ids),
            'attention_mask': torch.tensor(attention_mask),
            'labels': torch.tensor(labels)
        }

    def __len__(self):
        return len(self.labels)

dataset = TrecDataset(tokens, labels)

In [19]:
loader = torch.utils.data.DataLoader(
    dataset, batch_size=64
)

In [11]:
from transformers import BertForSequenceClassification, BertConfig

config = BertConfig.from_pretrained('bert-base-uncased')
config.num_labels = max(trec['coarse_label'])+1  # create 6 outputs
model = BertForSequenceClassification(config).to(device)
# remember to move to MPS with .to(device)

In [20]:
for param in model.bert.parameters():
    param.requires_grad = False

In [21]:
# activate training mode of model
model.train()

# initialize adam optimizer
optim = torch.optim.Adam(model.parameters(), lr=5e-5)

# begin training loop
for batch in loader:
  	# note that we move everything to the MPS device
    batch_mps = {
        'input_ids': batch['input_ids'].to(device),
        'attention_mask': batch['attention_mask'].to(device),
        'labels': batch['labels'].to(device)
    }
    # initialize calculated gradients (from prev step)
    optim.zero_grad()
    # train model on batch and return outputs (incl. loss)
    outputs = model(**batch_mps)
    # extract loss
    loss = outputs[0]
    # calculate loss for every parameter that needs grad update
    loss.backward()
    # update parameters
    optim.step()

  'labels': torch.tensor(labels)


IndexError: list index out of range