# Fine Tune BERT for Q&A with Apple MLX

and compare to PyTorch HuggingFace implementation

In [2]:
%load_ext autoreload
%autoreload 2

In [111]:
import time, math
from functools import partial

import torch
import numpy as np

from transformers import BertTokenizerFast, BertForQuestionAnswering

In [4]:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

In [65]:
from utils import load_processed_squad
from qa import load_model_tokenizer, batch_iterate, loss_fn, eval_fn, build_parser

# file from mlx repo
from model_mlx import load_model

# Load

In [86]:
parser = build_parser()
# pass empty string in jupyter
args = parser.parse_args("")

args.model_str, args.dataset_size

('bert-base-uncased', 1000)

In [84]:
args.load_weights, args.save_weights

('weights/bert-base-uncased.npz', 'weights/tmp-fine-tuned.npz')

In [7]:
# # for Bert()
# batch = ["This is an example of BERT working on MLX."]
# tokens = tokenizer(batch, return_tensors="mlx", padding=True)
# output, pooled = model(**tokens)

In [145]:
model, tokenizer = load_model_tokenizer(hf_model=args.model_str,
                                        weights_pretrain_path=args.load_weights)



In [87]:
train_ds, valid_ds, test_ds = load_processed_squad(filter_size=args.dataset_size,
                                            model_max_length=tokenizer.model_max_length, tokenizer=tokenizer)

train_ds.shape, valid_ds.shape, test_ds.shape

Map: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:00<00:00, 2177.07 examples/s]
Map: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:00<00:00, 4181.79 examples/s]
Map: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:00<00:00, 4051.47 examples/s]


((500, 5), (250, 5), (250, 5))

In [146]:
batch = next(batch_iterate(train_ds, batch_size=10))
batch.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'])

In [147]:
# TODO: mx.array() in preprocess_tokenize_function() ????
input_ids, token_type_ids, attention_mask, start_positions, end_positions = map(
    mx.array,
    (batch['input_ids'], batch['token_type_ids'], batch['attention_mask'], batch['start_positions'], batch['end_positions'])
)

# input_ids = mx.expand_dims(input_ids, 0)
# token_type_ids = mx.expand_dims(token_type_ids, 0)
# attention_mask = mx.expand_dims(attention_mask, 0)

input_ids.shape, token_type_ids.shape, attention_mask.shape, input_ids.mean()

((10, 512), (10, 512), (10, 512), array(1263.82, dtype=float32))

In [148]:
start_positions.shape, end_positions.shape

((10,), (10,))

# Basics

### Decode

In [151]:
# currently, this requires non-mlx inputs
tokenizer.decode(np.array(input_ids[0]).flatten())

'[CLS] how much money did the mrs. carter show limited edition fragrance make? [SEP] beyonce has worked with tommy hilfiger for the fragrances true star ( singing a cover version of " wishing on a star " ) and true star gold ; she also promoted emporio armani\'s diamonds fragrance in 2007. beyonce launched her first official fragrance, heat in 2010. the commercial, which featured the 1956 song " fever ", was shown after the water shed in the united kingdom as it begins with an image of beyonce appearing to lie naked in a room. in february 2011, beyonce launched her second fragrance, heat rush. beyonce\'s third fragrance, pulse, was launched in september 2011. in 2013, the mrs. carter show limited edition version of heat was released. the six editions of heat are the world\'s best - selling celebrity fragrance line, with sales of over $ 400 million. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [

# Train

follow transformer_lm/main/py 

In [35]:
len(valid_ds)

250

### Train on single batch

In [36]:
optimizer = optim.AdamW(learning_rate=1e-5)

In [37]:
batch.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'])

In [38]:
start_logits, end_logits = model(
    input_ids=input_ids,
    token_type_ids=token_type_ids,
    attention_mask=attention_mask,
    start_positions=start_positions,
    end_positions=end_positions) 

a = nn.losses.cross_entropy(start_logits, start_positions)
b = nn.losses.cross_entropy(end_logits, end_positions)

a.shape, b.shape

((10,), (10,))

In [39]:
a, b

(array([6.27944, 6.1827, 6.67555, ..., 6.17367, 6.08465, 6.76645], dtype=float32),
 array([5.7783, 6.2521, 6.27144, ..., 5.73106, 6.10344, 6.12746], dtype=float32))

In [40]:
loss_fn(model, input_ids, token_type_ids, attention_mask, start_positions, end_positions, reduce=True)

array(6.18729, dtype=float32)

In [41]:
loss_fn(model, input_ids, token_type_ids, attention_mask, start_positions, end_positions, reduce=False)

array([6.02887, 6.2174, 6.4735, ..., 5.95237, 6.09405, 6.44696], dtype=float32)

In [42]:
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

In [43]:
loss, grads = loss_and_grad_fn(model, input_ids, token_type_ids, attention_mask, start_positions, end_positions)

# loss value, and gradients for model's trainable parameters
loss.item(), grads.keys()

(6.187294006347656, dict_keys(['model', 'qa_output']))

### Train on full dataset

In [44]:
# use args.
num_iters = 20
batch_size = 10
steps_per_report = 10
steps_per_eval = 10
n_epoch = 1

In [45]:
state = [model.state, optimizer.state]

# edit in qa.py
# need here because of state variable
@partial(mx.compile, inputs=state, outputs=state)
def step(input_ids, token_type_ids, attention_mask, start_positions, end_positions):
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    loss, grads = loss_and_grad_fn(
        model, input_ids, token_type_ids, attention_mask, start_positions, end_positions)
    optimizer.update(model, grads)
    return loss

In [47]:
train_iterator = batch_iterate(train_ds, batch_size=batch_size)
losses = []
tic = time.perf_counter()

for it, batch in zip(range(num_iters), train_iterator):
    input_ids, token_type_ids, attention_mask, start_positions, end_positions = map(
        mx.array,
        (batch['input_ids'], batch['token_type_ids'], batch['attention_mask'], batch['start_positions'], batch['end_positions'])
    )

    loss = step(input_ids, token_type_ids, attention_mask, start_positions, end_positions)
    mx.eval(state)
    losses.append(loss.item())
    if (it + 1) % steps_per_report == 0:
        train_loss = np.mean(losses)
        toc = time.perf_counter()
        print(
            f"Iter {it + 1}: Train loss {train_loss:.3f}, "
            f"It/sec {steps_per_report / (toc - tic):.3f}"
        )
        losses = []
        tic = time.perf_counter()
    if (it + 1) % steps_per_eval == 0:
        val_loss = eval_fn(valid_ds, model, batch_size=batch_size)
        toc = time.perf_counter()
        print(
            f"Iter {it + 1}: "
            f"Val loss {val_loss:.3f}, "
            f"Val ppl {math.exp(val_loss):.3f}, "
            f"Val took {(toc - tic):.3f}s, "
        )
        tic = time.perf_counter()

Iter 10: Train loss 3.420, It/sec 0.923
Iter 10: Val loss 4.030, Val ppl 56.280, Val took 8.778s, 
Iter 20: Train loss 3.598, It/sec 0.927
Iter 20: Val loss 3.699, Val ppl 40.426, Val took 8.885s, 


# Saving and loading new weights

In [48]:
from mlx.utils import tree_flatten

In [53]:
p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
print(f"Total parameters {p:.3f}M")
p = sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
print(f"Trainable parameters {p:.3f}M")

Total parameters 108.893M
Trainable parameters 108.893M


In [51]:
f = "tmp-fine-tuned.npz"
# mx.savez(f, **dict(tree_flatten(model.trainable_parameters())))

# Compare fine-tuned to raw model
To show that fine tuning did something

In [61]:
from model_mlx import BertQA

In [89]:
model_pre, tokenizer = load_model_tokenizer(hf_model=args.model_str,
                                            weights_pretrain_path=args.load_weights)



In [90]:
f = "weights/fine-tuned-tiny.npz"
model_fine, _ = load_model_tokenizer(hf_model=args.model_str,
                                     weights_finetuned_path=f)

In [91]:
p = sum(v.size for _, v in tree_flatten(model_pre.parameters())) / 10**6
print(f"Total parameters {p:.3f}M")
p = sum(v.size for _, v in tree_flatten(model_pre.trainable_parameters())) / 10**6
print(f"Trainable parameters {p:.3f}M")

p = sum(v.size for _, v in tree_flatten(model_fine.parameters())) / 10**6
print(f"Total parameters {p:.3f}M")
p = sum(v.size for _, v in tree_flatten(model_fine.trainable_parameters())) / 10**6
print(f"Trainable parameters {p:.3f}M")

Total parameters 108.893M
Trainable parameters 108.893M
Total parameters 108.893M
Trainable parameters 108.893M


In [92]:
len(test_ds)

250

In [95]:
model = model_pre

test_loss = eval_fn(test_ds, model, batch_size=args.batch_size)
test_loss, math.exp(test_loss)

(6.280390125274658, 533.9969489006398)

In [96]:
model = model_fine

test_loss = eval_fn(test_ds, model, batch_size=args.batch_size)
test_loss, math.exp(test_loss)

(3.391675977706909, 29.715713432679717)

# Inference

In [152]:
from utils import find_valid_answers, find_context_start_end

In [153]:
f = "weights/fine-tuned-tiny.npz"
model, tokenizer = load_model_tokenizer(hf_model=args.model_str,
                                        weights_finetuned_path=f)
str(tokenizer)[:50]

"BertTokenizerFast(name_or_path='bert-base-uncased'"

In [155]:
context = "HF Transformers is backed by the three most popular deep learning libraries - Jax, PyTorch and TensorFlow - with a seamless integration between them. It's straightforward to train your models with one before loading them for inference with the other"
question = "Which deep learning libraries back HF Transformers?"

question = "How many programming languages does BLOOM support?"
context = "BLOOM has 176 billion parameters and can generate text in 46 languages natural languages and 13 programming languages."

In [156]:
inputs = tokenizer(question, context, return_tensors="mlx")
inputs.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

In [158]:
# to decode, need to use HF
tokenizer.decode(np.array(inputs['input_ids']).flatten())

'[CLS] how many programming languages does bloom support? [SEP] bloom has 176 billion parameters and can generate text in 46 languages natural languages and 13 programming languages. [SEP]'

### Run inference

### Find valid answers
follow find_valid_answers()

In [160]:
print(inputs.sequence_ids())

[None, 0, 0, 0, 0, 0, 0, 0, 0, None, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, None]


In [175]:


context_start_index, context_end_index = find_context_start_end(
        inputs.sequence_ids())

context_start_index, context_end_index

(10, 28)

In [174]:
# check dims like this
# len(start_logits.shape), len(start_logits.flatten().shape)

In [178]:
start_logits, end_logits = model(**inputs)
start_logits.shape, end_logits.shape

# flatten first to extrac
start_logits = start_logits.flatten()[context_start_index: context_end_index + 1]
end_logits = end_logits.flatten()[context_start_index: context_end_index + 1]

start_logits.shape, end_logits.shape

((19,), (19,))

In [180]:
n_best_size = 5
top_k = min(n_best_size, len(start_logits))

top_k

5

In [183]:
topk_start_indices = np.argsort(start_logits)[-1: -n_best_size - 1: -1].tolist()
topk_end_indices = np.argsort(end_logits)[-1: -n_best_size - 1: -1].tolist()

topk_start_indices, topk_end_indices

([2, 10, 15, 3, 0], [3, 4, 2, 10, 15])

In [184]:
valid_answers = []

# score all top logits
for start in topk_start_indices:
    for end in topk_end_indices:
        if start <= end:
            valid_answers.append({
                "score": start_logits[start] + end_logits[end],
                # shift indeces back to input-zero'd
                "start": start + context_start_index,
                "end": end + context_start_index
            })
valid_answers.sort(key=lambda x: x['score'], reverse=True)

In [185]:
valid_answers

[{'score': array(10.2414, dtype=float32), 'start': 12, 'end': 13},
 {'score': array(9.44422, dtype=float32), 'start': 12, 'end': 14},
 {'score': array(8.82822, dtype=float32), 'start': 12, 'end': 12},
 {'score': array(8.22495, dtype=float32), 'start': 13, 'end': 13},
 {'score': array(8.077, dtype=float32), 'start': 12, 'end': 20},
 {'score': array(8.03333, dtype=float32), 'start': 12, 'end': 25},
 {'score': array(7.88629, dtype=float32), 'start': 10, 'end': 13},
 {'score': array(7.42905, dtype=float32), 'start': 20, 'end': 20},
 {'score': array(7.42777, dtype=float32), 'start': 13, 'end': 14},
 {'score': array(7.38538, dtype=float32), 'start': 20, 'end': 25},
 {'score': array(7.08911, dtype=float32), 'start': 10, 'end': 14},
 {'score': array(6.47311, dtype=float32), 'start': 10, 'end': 12},
 {'score': array(6.4699, dtype=float32), 'start': 25, 'end': 25},
 {'score': array(6.06056, dtype=float32), 'start': 13, 'end': 20},
 {'score': array(6.01688, dtype=float32), 'start': 13, 'end': 25}

In [189]:
print("\n", question)
for d in valid_answers[:5]:
    score = d['score']
    start = d['start']
    end = d['end']
    print(score)
    predict_answer_tokens = inputs.input_ids[0, start: end + 1]
    print(tokenizer.decode(np.array(predict_answer_tokens.flatten())))


 How many programming languages does BLOOM support?
array(10.2414, dtype=float32)
176 billion
array(9.44422, dtype=float32)
176 billion parameters
array(8.82822, dtype=float32)
176
array(8.22495, dtype=float32)
billion
array(8.077, dtype=float32)
176 billion parameters and can generate text in 46


# Misc

In [None]:
# MLX
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
from model import Bert

bert_model = "bert-base-uncased"
mlx_weights_path = "weights/bert-base-uncased.npz"

config = AutoConfig.from_pretrained(bert_model)
model = Bert(config, add_pooler=True)
tokenizer = AutoTokenizer.from_pretrained(bert_model)

In [None]:
# PyTorch HF
pre_train_model = bert_model
tokenizerhf = BertTokenizerFast.from_pretrained(pre_train_model)
modelhf = BertForQuestionAnswering.from_pretrained(pre_train_model)

# MLX

In [None]:
import mlx.core as mx
from mlx.utils import tree_map

In [None]:
batch = ["This is an example of BERT working on MLX."]

In [None]:
tokens = tokenizer(batch, return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()}
tokens

In [None]:
tokens['input_ids'].shape, tokens['token_type_ids'].shape

In [None]:
tokens = tokenizer(batch, return_tensors="mlx", padding=True)
tokens

In [None]:
tokens['input_ids'].shape

In [None]:
tokens['input_ids'].shape, tokens['input_ids']

# HF

In [None]:
from transformers import BertModel as BertModelHF

In [None]:
bert_model = "bert-base-uncased"
config = AutoConfig.from_pretrained(bert_model)

model = BertModelHF(config)
tokenizer = AutoTokenizer.from_pretrained(bert_model)

In [117]:
batch2 = tokenizer(batch, return_tensors="pt", padding=True)
batch2

ValueError: text input must be of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).

In [None]:
device = torch.device("mps")

input_ids = batch2['input_ids'].to(device)
token_type_ids = batch2['token_type_ids'].to(device)
attention_mask = batch2['attention_mask'].to(device)

model.to(device)
model.train()

with torch.no_grad():
    outputs = model(input_ids, attention_mask=attention_mask,
                    token_type_ids=token_type_ids)
outputs.keys()

In [None]:
# outputs

In [None]:
# this seems to be the main output
outputs[0]

In [None]:
self_qa_output = torch.nn.Linear(768, 2)
self_qa_output.to(device)

sequence_output = outputs[0]

logits = self_qa_output(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
# start_logits = start_logits.squeeze(-1)
# end_logits = end_logits.squeeze(-1)


In [None]:
logits.shape, start_logits.shape, end_logits.shape

In [None]:
logits

# Appendix

In [None]:
model

In [None]:
modelhf

### Pytorch model, input, output example


```python
>>> model
BertForQuestionAnswering(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermediate_act_fn): GELUActivation()
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
  )
  (qa_outputs): Linear(in_features=768, out_features=2, bias=True)
)
>>> 
batch = next(iter(train_dataloader))
>>> batch
{'input_ids': tensor([[  101,  1996, 13546,  ...,     0,     0,     0],
        [  101,  2129,  2116,  ...,     0,     0,     0],
        [  101, 19739,  6862,  ...,     0,     0,     0],
        ...,
        [  101,  2129,  2116,  ...,     0,     0,     0],
        [  101,  1996, 26129,  ...,     0,     0,     0],
        [  101,  1999,  2054,  ...,     0,     0,     0]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'start_positions': tensor([ 81, 267,  12, 149,  58,  80,  58,  74, 135,  98,  84, 107,  86,  28,
         37,  57]), 'end_positions': tensor([ 83, 269,  15, 149,  61,  82,  63,  74, 142, 100,  86, 107,  88,  28,
         38,  57])}
>>> input_ids = batch['input_ids'].to(device)
>>> attention_mask = batch['attention_mask'].to(device)
>>> start_positions = batch['start_positions'].to(device)
>>> end_positions = batch['end_positions'].to(device)
>>> outputs = model(input_ids, attention_mask=attention_mask,
...                             start_positions=start_positions, end_positions=end_positions)

>>> outputs
QuestionAnsweringModelOutput(loss=tensor(6.3726, device='mps:0', grad_fn=<DivBackward0>), start_logits=tensor([[-0.8564, -0.1199,  0.0125,  ...,  0.2456,  0.3330,  0.3860],
        [-0.7941, -0.1668,  0.2791,  ...,  0.1567,  0.1839,  0.1895],
        [-0.3875, -0.0241,  0.4691,  ...,  0.1442,  0.2343,  0.3314],
        ...,
        [-0.5861,  0.0238,  0.4390,  ...,  0.4266,  0.4844,  0.4618],
        [-0.8158, -0.1809, -0.2249,  ...,  0.3569,  0.3794,  0.3319],
        [-0.6281, -0.0430,  0.1375,  ...,  0.2855,  0.2233,  0.1889]],
       device='mps:0', grad_fn=<CloneBackward0>), end_logits=tensor([[ 0.3384, -0.1375, -0.0802,  ..., -0.0332, -0.0099, -0.0567],
        [ 0.2627, -0.2120, -0.3897,  ...,  0.0227, -0.0094,  0.0472],
        [ 0.6734,  0.2672, -0.0346,  ...,  0.0886,  0.1867,  0.1734],
        ...,
        [ 0.4112, -0.1477, -0.1991,  ...,  0.0158,  0.0874,  0.1071],
        [ 0.5058,  0.0427,  0.2843,  ...,  0.0442,  0.0791,  0.0314],
        [ 0.5014,  0.1437, -0.2467,  ...,  0.1379,  0.0093, -0.0738]],
       device='mps:0', grad_fn=<CloneBackward0>), hidden_states=None, attentions=None)
>>> outputs['start_logits']
tensor([[-0.8564, -0.1199,  0.0125,  ...,  0.2456,  0.3330,  0.3860],
        [-0.7941, -0.1668,  0.2791,  ...,  0.1567,  0.1839,  0.1895],
        [-0.3875, -0.0241,  0.4691,  ...,  0.1442,  0.2343,  0.3314],
        ...,
        [-0.5861,  0.0238,  0.4390,  ...,  0.4266,  0.4844,  0.4618],
        [-0.8158, -0.1809, -0.2249,  ...,  0.3569,  0.3794,  0.3319],
        [-0.6281, -0.0430,  0.1375,  ...,  0.2855,  0.2233,  0.1889]],
       device='mps:0', grad_fn=<CloneBackward0>)
>>> outputs['start_logits'].shape
torch.Size([16, 512])
>>> outputs.keys()
odict_keys(['loss', 'start_logits', 'end_logits'])

```