In [None]:
import torch

print(f"PyTorch version: {torch.__version__}")

# Check PyTorch has access to MPS (Metal Performance Shader, Apple's GPU architecture)
print(f"Is MPS (Metal Performance Shader) built? {torch.backends.mps.is_built()}")
print(f"Is MPS available? {torch.backends.mps.is_available()}")

# Set the device: the Apple-silicon GPU when available, otherwise fall back to the CPU
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

# Seed PyTorch's RNGs (all devices) so the stochastic parts of the notebook,
# e.g. weight initialisation and dropout in training mode, are repeatable.
SEED = 42
torch.manual_seed(SEED)

In [None]:
from datasets import load_dataset  # pip install datasets

# Work with just the first 1,000 rows of TREC's training split to keep the demo small.
trec = load_dataset(path='trec', split='train[:1000]')
trec

In [None]:
# Inspect a single record to see which fields are available
# (this notebook later reads the 'text' and 'coarse_label' columns).
trec[0]

In [None]:
from transformers import AutoTokenizer, AutoModel  # pip install transformers

# Tokenizer and encoder weights both come from the same pretrained checkpoint.
checkpoint = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint)

# Fixed benchmark batch: the first 64 questions. padding=True pads to the longest
# question in the batch; return_tensors='pt' yields PyTorch tensors.
text = trec['text'][:64]
tokens = tokenizer(text, max_length=512, truncation=True, padding=True, return_tensors='pt')

In [None]:
# Baseline run: place both the encoder and the benchmark batch on the CPU.
device = torch.device('cpu')
model = model.to(device)    # nn.Module.to moves parameters in place and returns the same module
tokens = tokens.to(device)  # BatchEncoding.to moves its tensors and returns the same object
device

In [None]:
%%timeit
model(**tokens)

In [None]:
# Move the encoder and benchmark batch to the Apple-silicon GPU. A hard-coded
# torch.device('mps') would fail on machines without MPS, so fall back to the
# CPU there (with a visible note) and keep the notebook runnable end-to-end.
if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    print("MPS not available - falling back to CPU; the next timing will not reflect GPU speed.")
    device = torch.device('cpu')
model.to(device)
tokens.to(device)
device

In [None]:
%%timeit
model(**tokens)

In [None]:
from transformers import BertForSequenceClassification, BertConfig

config = BertConfig.from_pretrained('bert-base-uncased')
# Prefer the class count declared by the label feature (a ClassLabel exposes
# num_classes): max() over a 1K-row slice would undercount if the highest label id
# happened not to appear in the slice. The max()-based count remains a fallback.
label_feature = trec.features['coarse_label']
config.num_labels = getattr(label_feature, 'num_classes', None) or max(trec['coarse_label']) + 1
# Load the pretrained encoder weights. Constructing the model from the config alone
# would initialise every weight randomly; here only the new classification head is
# freshly initialised. .to(device) places the model on the currently selected device.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', config=config).to(device)

In [None]:
# Fine-tune the classifier for one pass over the 1K-row TREC slice on `device`.

# Build the training DataLoader (it was previously never defined, so this cell
# failed on a fresh kernel). Tokenizing every row in a single call with
# padding=True pads all examples to one common length, so the DataLoader's default
# collate function can stack the per-example dicts into batch tensors.
BATCH_SIZE = 16
train_enc = tokenizer(
    list(trec['text']), max_length=512,
    truncation=True, padding=True,
    return_tensors='pt'
)
train_labels = torch.tensor(list(trec['coarse_label']))
train_records = [
    {
        'input_ids': train_enc['input_ids'][i],
        'attention_mask': train_enc['attention_mask'][i],
        'labels': train_labels[i],
    }
    for i in range(len(train_labels))
]
loader = torch.utils.data.DataLoader(train_records, batch_size=BATCH_SIZE, shuffle=True)

# activate training mode of model (enables dropout)
model.train()

# initialize adam optimizer
optim = torch.optim.Adam(model.parameters(), lr=5e-5)

# single pass over the training data
for batch in loader:
    # move every tensor in the batch to the same device as the model
    batch_mps = {key: tensor.to(device) for key, tensor in batch.items()}
    # clear gradients accumulated by the previous step
    optim.zero_grad()
    # forward pass; because `labels` is passed, the model also returns the loss
    outputs = model(**batch_mps)
    loss = outputs.loss
    # backpropagate: compute gradients for every parameter that requires grad
    loss.backward()
    # update parameters
    optim.step()

# loss on the final batch (.item() copies the scalar back to the host)
loss.item()