# Fine-Tune BERT for Q&A with Apple MLX

and compare it with the PyTorch Hugging Face implementation

In [None]:
# Auto-reload locally edited modules (e.g. qa.py, utils.py) so changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import time, math

import torch
import numpy as np

from transformers import BertTokenizerFast, BertForQuestionAnswering

In [None]:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

In [None]:
from utils import load_processed_datasets
from qa import load_model_tokenizer, batch_iterate, loss_fn, eval_fn

# file from mlx repo
from model_mlx import load_model

# Load

In [None]:
# Hugging Face checkpoint name used below to load the tokenizer and model.
bert_model = "bert-base-uncased"

In [None]:
# # for Bert()
# batch = ["This is an example of BERT working on MLX."]
# tokens = tokenizer(batch, return_tensors="mlx", padding=True)
# output, pooled = model(**tokens)

In [None]:
# MLX weights file (presumably converted from the `bert_model` checkpoint).
mlx_weights_path = "weights/bert-base-uncased.npz"

# Build the MLX QA model and its tokenizer via the local `qa` helper.
model, tokenizer = load_model_tokenizer(
    hf_model=bert_model,
    mlx_weights_path=mlx_weights_path,
)

In [None]:
# Load the preprocessed train/validation/test splits via the local `utils` helper.
# NOTE(review): filter_size=100 presumably limits the dataset size — confirm in utils.py.
train_ds, valid_ds, test_ds = load_processed_datasets(
    filter_size=100,
    model_max_length=tokenizer.model_max_length,
    tokenizer=tokenizer,
)

train_ds.shape, valid_ds.shape, test_ds.shape

In [None]:
# Take the first batch (batch_size=3) from the training iterator and list its fields.
batch = next(batch_iterate(train_ds, batch_size=3))
batch.keys()

In [None]:
# Convert each field of the batch into an MLX array, in a fixed field order.
# TODO: consider doing the mx.array() conversion inside preprocess_tokenize_function().
input_ids, token_type_ids, attention_mask, start_positions, end_positions = [
    mx.array(batch[column])
    for column in (
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "start_positions",
        "end_positions",
    )
]

input_ids.shape, token_type_ids.shape, attention_mask.shape, input_ids.mean()

In [None]:
# Shapes of the start/end position labels for this batch.
start_positions.shape, end_positions.shape

# Train

Follows the training loop in `transformer_lm/main.py` from the MLX examples repo.

In [None]:
# Number of examples in the validation split.
len(valid_ds)

### Train on single batch

In [None]:
# AdamW optimizer (learning rate 1e-5) that the training step below updates the model with.
optimizer = optim.AdamW(learning_rate=1e-5)

In [None]:
# Fields available in the current batch.
batch.keys()

In [None]:
# Forward pass on the single batch; the model's output is unpacked into start and end logits.
# NOTE(review): start/end positions are also passed to the model — confirm the MLX model
# accepts them and whether it uses them, since only the two logits are used here.
start_logits, end_logits = model(
    input_ids=input_ids,
    token_type_ids=token_type_ids,
    attention_mask=attention_mask,
    start_positions=start_positions,
    end_positions=end_positions) 

# Cross-entropy of each head against its position labels. Presumably unreduced
# (MLX's cross_entropy defaults to reduction="none"); the shapes below confirm.
a = nn.losses.cross_entropy(start_logits, start_positions)
b = nn.losses.cross_entropy(end_logits, end_positions)

a.shape, b.shape

In [None]:
# Display the start/end loss arrays computed above.
a, b

In [None]:
# Loss from qa.loss_fn with reduce=True (presumably a single scalar — see qa.py).
loss_fn(model, input_ids, token_type_ids, attention_mask, start_positions, end_positions, reduce=True)

In [None]:
# Same loss with reduce=False (presumably one value per example — see qa.py).
loss_fn(model, input_ids, token_type_ids, attention_mask, start_positions, end_positions, reduce=False)

In [None]:
# Wrap loss_fn so that calling it returns (loss, gradients w.r.t. the model's trainable parameters).
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

In [None]:
# One forward + backward pass on the single batch; no optimizer update happens here.
loss, grads = loss_and_grad_fn(model, input_ids, token_type_ids, attention_mask, start_positions, end_positions)

# loss value, and (top-level) keys of the gradient tree for model's trainable parameters
loss.item(), grads.keys()

### Train on full dataset

In [None]:
from functools import partial

In [None]:
# Training hyperparameters (candidates for command-line args in a script version).
num_iters = 100        # number of training iterations (batches) to run
batch_size = 4         # batch size passed to eval_fn for validation in the loop below
steps_per_report = 1   # print the mean training loss every N iterations
steps_per_eval = 1     # run validation every N iterations
n_epoch = 1            # NOTE(review): not referenced by the visible training loop

In [None]:
# Model and optimizer state that the compiled step reads and updates in place.
# Defined in the notebook (rather than qa.py) because `step` must close over `state`.
state = [model.state, optimizer.state]


@partial(mx.compile, inputs=state, outputs=state)
def step(input_ids, token_type_ids, attention_mask, start_positions, end_positions):
    """Run one compiled training step on a batch and return its loss.

    Computes `loss_fn` and its gradients w.r.t. the model's trainable
    parameters, applies them through `optimizer`, and returns the loss as a
    (lazy) MLX array.

    NOTE(review): if the model uses randomness (e.g. dropout), MLX's compile
    docs recommend also listing `mx.random.state` in `state` — confirm.
    """
    loss_and_grad = nn.value_and_grad(model, loss_fn)
    batch_loss, batch_grads = loss_and_grad(
        model, input_ids, token_type_ids, attention_mask, start_positions, end_positions
    )
    optimizer.update(model, batch_grads)
    return batch_loss

In [None]:


# Training loop modelled on MLX's transformer_lm example. It runs the compiled `step`
# on each batch, reports the mean training loss every `steps_per_report` iterations,
# and reports the validation loss every `steps_per_eval` iterations.
train_batch_size = 16  # training batch size; `batch_size` is only used for validation here
train_iterator = batch_iterate(train_ds, batch_size=train_batch_size)
batch_fields = ("input_ids", "token_type_ids", "attention_mask", "start_positions", "end_positions")
losses = []
tic = time.perf_counter()

for it, batch in zip(range(num_iters), train_iterator):
    input_ids, token_type_ids, attention_mask, start_positions, end_positions = [
        mx.array(batch[field]) for field in batch_fields
    ]

    loss = step(input_ids, token_type_ids, attention_mask, start_positions, end_positions)
    # MLX is lazy: force evaluation of the updated model/optimizer state.
    mx.eval(state)
    losses.append(loss.item())

    if (it + 1) % steps_per_report == 0:
        train_loss = np.mean(losses)
        toc = time.perf_counter()
        print(
            f"Iter {it + 1}: Train loss {train_loss:.3f}, "
            f"It/sec {steps_per_report / (toc - tic):.3f}"
        )
        losses = []
        tic = time.perf_counter()

    if (it + 1) % steps_per_eval == 0:
        # Time validation on its own clock so "Val took" never includes
        # training steps run since the last report.
        eval_tic = time.perf_counter()
        val_loss = eval_fn(valid_ds, model, batch_size=batch_size)
        eval_toc = time.perf_counter()
        print(
            f"Iter {it + 1}: "
            f"Val loss {val_loss:.3f}, "
            f"Val ppl {math.exp(val_loss):.3f}, "
            f"Val took {(eval_toc - eval_tic):.3f}s, "
        )
        # Shift the training clock forward by the time spent validating, so the
        # next It/sec measures training only without discarding the training
        # time already accumulated since the last report.
        tic += time.perf_counter() - eval_tic

# Misc

In [None]:
# MLX: build a Bert (with pooler) directly from the HF config, plus the HF tokenizer.
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
from model import Bert

bert_model = "bert-base-uncased"
# NOTE(review): mlx_weights_path is not used in this cell, so no weights are loaded
# into `model` here — confirm whether Bert(...) is meant to load them.
mlx_weights_path = "weights/bert-base-uncased.npz"

config = AutoConfig.from_pretrained(bert_model)
# Rebinds `model` and `tokenizer`, replacing the MLX QA model trained above.
model = Bert(config, add_pooler=True)
tokenizer = AutoTokenizer.from_pretrained(bert_model)

In [None]:
# PyTorch HF: fast tokenizer and QA model, both loaded from the same checkpoint name.
pre_train_model = bert_model
tokenizerhf, modelhf = (
    BertTokenizerFast.from_pretrained(pre_train_model),
    BertForQuestionAnswering.from_pretrained(pre_train_model),
)

# MLX

In [None]:
import mlx.core as mx
from mlx.utils import tree_map

In [None]:
# A single example sentence for comparing tokenizer outputs (rebinds `batch`).
batch = ["This is an example of BERT working on MLX."]

In [None]:
# Tokenize to NumPy arrays, then convert every field to an MLX array.
tokens = {
    field: mx.array(values)
    for field, values in tokenizer(batch, return_tensors="np", padding=True).items()
}
tokens

In [None]:
# Shapes of the converted input ids and token type ids.
tokens['input_ids'].shape, tokens['token_type_ids'].shape

In [None]:
# Same tokenization, asking the tokenizer to return MLX tensors directly.
tokens = tokenizer(batch, return_tensors="mlx", padding=True)
tokens

In [None]:
# Shape of the MLX input ids.
tokens['input_ids'].shape

In [None]:
# Shape and values of the MLX input ids.
tokens['input_ids'].shape, tokens['input_ids']

# HF

In [None]:
from transformers import BertModel as BertModelHF

In [None]:
bert_model = "bert-base-uncased"
config = AutoConfig.from_pretrained(bert_model)

# NOTE(review): BertModelHF(config) builds the architecture from the config only;
# unlike from_pretrained() it presumably does not load pretrained weights — confirm intended.
# Also rebinds `model` and `tokenizer`.
model = BertModelHF(config)
tokenizer = AutoTokenizer.from_pretrained(bert_model)

In [None]:
# Tokenize the same example sentence into PyTorch tensors (padding=True).
batch2 = tokenizer(batch, return_tensors="pt", padding=True)
batch2

In [None]:
# Run the HF BertModel on the tokenized example.
# Prefer Apple's MPS backend; fall back to CPU when it is unavailable.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

input_ids = batch2['input_ids'].to(device)
token_type_ids = batch2['token_type_ids'].to(device)
attention_mask = batch2['attention_mask'].to(device)

model.to(device)
# Inference only: eval() disables dropout (BERT contains Dropout(p=0.1) layers,
# see the appendix), so repeated runs give the same outputs.
model.eval()

with torch.no_grad():
    outputs = model(input_ids, attention_mask=attention_mask,
                    token_type_ids=token_type_ids)
outputs.keys()

In [None]:
# outputs

In [None]:
# First field of the HF output (presumably last_hidden_state: one hidden vector per token);
# it is used as `sequence_output` for the hand-built QA head below.
outputs[0]

In [None]:
# Re-create HF's QA head by hand. A Linear layer maps each hidden vector to two
# scores, which are split into start and end logits. The layer is freshly
# initialised here, so only the shapes (not the values) are meaningful.
self_qa_output = torch.nn.Linear(config.hidden_size, 2)
self_qa_output.to(device)

sequence_output = outputs[0]

logits = self_qa_output(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
# Drop the trailing size-1 axis so each logits tensor is (batch, seq_len).
# This matches the start_logits shape of HF's BertForQuestionAnswering in the appendix.
start_logits = start_logits.squeeze(-1).contiguous()
end_logits = end_logits.squeeze(-1).contiguous()


In [None]:
# Shapes of the combined logits and the split start/end logits.
logits.shape, start_logits.shape, end_logits.shape

In [None]:
# Raw combined (start, end) logits from the hand-built QA head.
logits

# Appendix

In [None]:
# Display `model`; when the notebook runs top to bottom this is the HF BertModel built in the HF section.
model

In [None]:
# Display the HF BertForQuestionAnswering model loaded in the Misc section.
modelhf

### PyTorch model, input, and output example


```python
>>> model
BertForQuestionAnswering(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermediate_act_fn): GELUActivation()
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
  )
  (qa_outputs): Linear(in_features=768, out_features=2, bias=True)
)
>>> batch = next(iter(train_dataloader))
>>> batch
{'input_ids': tensor([[  101,  1996, 13546,  ...,     0,     0,     0],
        [  101,  2129,  2116,  ...,     0,     0,     0],
        [  101, 19739,  6862,  ...,     0,     0,     0],
        ...,
        [  101,  2129,  2116,  ...,     0,     0,     0],
        [  101,  1996, 26129,  ...,     0,     0,     0],
        [  101,  1999,  2054,  ...,     0,     0,     0]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'start_positions': tensor([ 81, 267,  12, 149,  58,  80,  58,  74, 135,  98,  84, 107,  86,  28,
         37,  57]), 'end_positions': tensor([ 83, 269,  15, 149,  61,  82,  63,  74, 142, 100,  86, 107,  88,  28,
         38,  57])}
>>> input_ids = batch['input_ids'].to(device)
>>> attention_mask = batch['attention_mask'].to(device)
>>> start_positions = batch['start_positions'].to(device)
>>> end_positions = batch['end_positions'].to(device)
>>> outputs = model(input_ids, attention_mask=attention_mask,
...                             start_positions=start_positions, end_positions=end_positions)

>>> outputs
QuestionAnsweringModelOutput(loss=tensor(6.3726, device='mps:0', grad_fn=<DivBackward0>), start_logits=tensor([[-0.8564, -0.1199,  0.0125,  ...,  0.2456,  0.3330,  0.3860],
        [-0.7941, -0.1668,  0.2791,  ...,  0.1567,  0.1839,  0.1895],
        [-0.3875, -0.0241,  0.4691,  ...,  0.1442,  0.2343,  0.3314],
        ...,
        [-0.5861,  0.0238,  0.4390,  ...,  0.4266,  0.4844,  0.4618],
        [-0.8158, -0.1809, -0.2249,  ...,  0.3569,  0.3794,  0.3319],
        [-0.6281, -0.0430,  0.1375,  ...,  0.2855,  0.2233,  0.1889]],
       device='mps:0', grad_fn=<CloneBackward0>), end_logits=tensor([[ 0.3384, -0.1375, -0.0802,  ..., -0.0332, -0.0099, -0.0567],
        [ 0.2627, -0.2120, -0.3897,  ...,  0.0227, -0.0094,  0.0472],
        [ 0.6734,  0.2672, -0.0346,  ...,  0.0886,  0.1867,  0.1734],
        ...,
        [ 0.4112, -0.1477, -0.1991,  ...,  0.0158,  0.0874,  0.1071],
        [ 0.5058,  0.0427,  0.2843,  ...,  0.0442,  0.0791,  0.0314],
        [ 0.5014,  0.1437, -0.2467,  ...,  0.1379,  0.0093, -0.0738]],
       device='mps:0', grad_fn=<CloneBackward0>), hidden_states=None, attentions=None)
>>> outputs['start_logits']
tensor([[-0.8564, -0.1199,  0.0125,  ...,  0.2456,  0.3330,  0.3860],
        [-0.7941, -0.1668,  0.2791,  ...,  0.1567,  0.1839,  0.1895],
        [-0.3875, -0.0241,  0.4691,  ...,  0.1442,  0.2343,  0.3314],
        ...,
        [-0.5861,  0.0238,  0.4390,  ...,  0.4266,  0.4844,  0.4618],
        [-0.8158, -0.1809, -0.2249,  ...,  0.3569,  0.3794,  0.3319],
        [-0.6281, -0.0430,  0.1375,  ...,  0.2855,  0.2233,  0.1889]],
       device='mps:0', grad_fn=<CloneBackward0>)
>>> outputs['start_logits'].shape
torch.Size([16, 512])
>>> outputs.keys()
odict_keys(['loss', 'start_logits', 'end_logits'])

```