# Fine Tune BERT for Q&A with Apple MLX

and compare to PyTorch HuggingFace implementation

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
import torch
from transformers import BertTokenizerFast, BertForQuestionAnswering

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
from utils import load_squad
from model import load_model

# Load

In [6]:
squad = load_squad(filter_size=500)
squad

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 250
    })
    valid: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 125
    })
    test: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 125
    })
})

In [13]:
# MLX
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
from model import Bert

bert_model = "bert-base-uncased"
mlx_weights_path = "weights/bert-base-uncased.npz"

config = AutoConfig.from_pretrained(bert_model)
model = Bert(config, add_pooler=True)
tokenizer = AutoTokenizer.from_pretrained(bert_model)



In [None]:
# PyTorch HF
pre_train_model = bert_model
tokenizerhf = BertTokenizerFast.from_pretrained(pre_train_model)
modelhf = BertForQuestionAnswering.from_pretrained(pre_train_model)

# MLX

In [27]:
import mlx.core as mx
from mlx.utils import tree_map

In [24]:
batch = ["This is an example of BERT working on MLX."]

In [25]:
tokens = tokenizer(batch, return_tensors="np", padding=True)
tokens = {key: mx.array(v) for key, v in tokens.items()}
tokens

{'input_ids': array([[101, 2023, 2003, ..., 2595, 1012, 102]], dtype=int64),
 'token_type_ids': array([[0, 0, 0, ..., 0, 0, 0]], dtype=int64),
 'attention_mask': array([[1, 1, 1, ..., 1, 1, 1]], dtype=int64)}

In [26]:
tokens = tokenizer(batch, return_tensors="mlx", padding=True)
tokens

{'input_ids': array([[101, 2023, 2003, ..., 2595, 1012, 102]], dtype=int32), 'token_type_ids': array([[0, 0, 0, ..., 0, 0, 0]], dtype=int32), 'attention_mask': array([[1, 1, 1, ..., 1, 1, 1]], dtype=int32)}

In [37]:
tokens['input_ids'].shape

(1, 13)

In [19]:
tokens['input_ids'].shape, tokens['input_ids']

((1, 13), array([[101, 2023, 2003, ..., 2595, 1012, 102]], dtype=int64))

# HF

In [40]:
from transformers import BertModel as BertModelHF

In [43]:
bert_model = "bert-base-uncased"
config = AutoConfig.from_pretrained(bert_model)

model = BertModelHF(config)
tokenizer = AutoTokenizer.from_pretrained(bert_model)



In [50]:
batch2 = tokenizer(batch, return_tensors="pt", padding=True)
batch2

{'input_ids': tensor([[  101,  2023,  2003,  2019,  2742,  1997, 14324,  2551,  2006, 19875,
          2595,  1012,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [56]:
device = torch.device("mps")

input_ids = batch2['input_ids'].to(device)
token_type_ids = batch2['token_type_ids'].to(device)
attention_mask = batch2['attention_mask'].to(device)

model.to(device)
model.train()

with torch.no_grad():
    outputs = model(input_ids, attention_mask=attention_mask,
                    token_type_ids=token_type_ids)
outputs.keys()

odict_keys(['last_hidden_state', 'pooler_output'])

In [None]:
# outputs

In [59]:
# this seems to be the main output
outputs[0]

tensor([[[-0.9749, -0.8778, -1.0495,  ...,  0.0232, -0.7137,  0.1091],
         [-0.2013, -2.1369, -0.1064,  ..., -0.2749,  0.2540, -0.1586],
         [-1.3305, -1.0369,  0.5428,  ..., -0.3688,  0.3484, -0.0307],
         ...,
         [-1.4051, -2.7570,  1.2217,  ...,  0.2561, -0.7478,  0.7259],
         [-0.3601,  0.1444, -0.0155,  ..., -0.7146,  1.0159,  0.1831],
         [-0.4500, -1.5011, -0.4606,  ..., -0.3835, -0.0132,  0.0782]]],
       device='mps:0')

In [64]:
self_qa_output = torch.nn.Linear(768, 2)
self_qa_output.to(device)

sequence_output = outputs[0]

logits = self_qa_output(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
# start_logits = start_logits.squeeze(-1)
# end_logits = end_logits.squeeze(-1)


In [65]:
logits.shape, start_logits.shape, end_logits.shape

(torch.Size([1, 13, 2]), torch.Size([1, 13, 1]), torch.Size([1, 13, 1]))

In [63]:
logits

tensor([[[ 0.4024,  0.4357],
         [-0.1407,  0.8876],
         [ 0.3326,  0.5539],
         [-0.2147,  0.6298],
         [ 0.0588,  0.9387],
         [-0.4208,  1.5841],
         [ 0.2975,  0.1298],
         [-0.1793,  0.6286],
         [-0.2853,  1.3027],
         [ 0.3102,  0.7564],
         [-0.1742,  0.9697],
         [-0.6204,  1.5335],
         [-0.4409,  0.9782]]], device='mps:0', grad_fn=<LinearBackward0>)

# Appendix

In [11]:
model

Bert(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768)
    (token_type_embeddings): Embedding(2, 768)
    (position_embeddings): Embedding(512, 768)
    (norm): LayerNorm(768, eps=1e-12, affine=True)
  )
  (encoder): TransformerEncoder(
    (layers.0): TransformerEncoderLayer(
      (attention): MultiHeadAttention(
        (query_proj): Linear(input_dims=768, output_dims=768, bias=True)
        (key_proj): Linear(input_dims=768, output_dims=768, bias=True)
        (value_proj): Linear(input_dims=768, output_dims=768, bias=True)
        (out_proj): Linear(input_dims=768, output_dims=768, bias=True)
      )
      (ln1): LayerNorm(768, eps=1e-12, affine=True)
      (ln2): LayerNorm(768, eps=1e-12, affine=True)
      (linear1): Linear(input_dims=768, output_dims=3072, bias=True)
      (linear2): Linear(input_dims=3072, output_dims=768, bias=True)
      (gelu): GELU()
    )
    (layers.1): TransformerEncoderLayer(
      (attention): MultiHeadAttention(
        (q

In [21]:
modelhf

BertForQuestionAnswering(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, 

### Pytorch model, input, output example


```python
>>> model
BertForQuestionAnswering(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermediate_act_fn): GELUActivation()
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
  )
  (qa_outputs): Linear(in_features=768, out_features=2, bias=True)
)
>>> 
batch = next(iter(train_dataloader))
>>> batch
{'input_ids': tensor([[  101,  1996, 13546,  ...,     0,     0,     0],
        [  101,  2129,  2116,  ...,     0,     0,     0],
        [  101, 19739,  6862,  ...,     0,     0,     0],
        ...,
        [  101,  2129,  2116,  ...,     0,     0,     0],
        [  101,  1996, 26129,  ...,     0,     0,     0],
        [  101,  1999,  2054,  ...,     0,     0,     0]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'start_positions': tensor([ 81, 267,  12, 149,  58,  80,  58,  74, 135,  98,  84, 107,  86,  28,
         37,  57]), 'end_positions': tensor([ 83, 269,  15, 149,  61,  82,  63,  74, 142, 100,  86, 107,  88,  28,
         38,  57])}
>>> input_ids = batch['input_ids'].to(device)
>>> attention_mask = batch['attention_mask'].to(device)
>>> start_positions = batch['start_positions'].to(device)
>>> end_positions = batch['end_positions'].to(device)
>>> outputs = model(input_ids, attention_mask=attention_mask,
...                             start_positions=start_positions, end_positions=end_positions)

>>> outputs
QuestionAnsweringModelOutput(loss=tensor(6.3726, device='mps:0', grad_fn=<DivBackward0>), start_logits=tensor([[-0.8564, -0.1199,  0.0125,  ...,  0.2456,  0.3330,  0.3860],
        [-0.7941, -0.1668,  0.2791,  ...,  0.1567,  0.1839,  0.1895],
        [-0.3875, -0.0241,  0.4691,  ...,  0.1442,  0.2343,  0.3314],
        ...,
        [-0.5861,  0.0238,  0.4390,  ...,  0.4266,  0.4844,  0.4618],
        [-0.8158, -0.1809, -0.2249,  ...,  0.3569,  0.3794,  0.3319],
        [-0.6281, -0.0430,  0.1375,  ...,  0.2855,  0.2233,  0.1889]],
       device='mps:0', grad_fn=<CloneBackward0>), end_logits=tensor([[ 0.3384, -0.1375, -0.0802,  ..., -0.0332, -0.0099, -0.0567],
        [ 0.2627, -0.2120, -0.3897,  ...,  0.0227, -0.0094,  0.0472],
        [ 0.6734,  0.2672, -0.0346,  ...,  0.0886,  0.1867,  0.1734],
        ...,
        [ 0.4112, -0.1477, -0.1991,  ...,  0.0158,  0.0874,  0.1071],
        [ 0.5058,  0.0427,  0.2843,  ...,  0.0442,  0.0791,  0.0314],
        [ 0.5014,  0.1437, -0.2467,  ...,  0.1379,  0.0093, -0.0738]],
       device='mps:0', grad_fn=<CloneBackward0>), hidden_states=None, attentions=None)
>>> outputs['start_logits']
tensor([[-0.8564, -0.1199,  0.0125,  ...,  0.2456,  0.3330,  0.3860],
        [-0.7941, -0.1668,  0.2791,  ...,  0.1567,  0.1839,  0.1895],
        [-0.3875, -0.0241,  0.4691,  ...,  0.1442,  0.2343,  0.3314],
        ...,
        [-0.5861,  0.0238,  0.4390,  ...,  0.4266,  0.4844,  0.4618],
        [-0.8158, -0.1809, -0.2249,  ...,  0.3569,  0.3794,  0.3319],
        [-0.6281, -0.0430,  0.1375,  ...,  0.2855,  0.2233,  0.1889]],
       device='mps:0', grad_fn=<CloneBackward0>)
>>> outputs['start_logits'].shape
torch.Size([16, 512])
>>> outputs.keys()
odict_keys(['loss', 'start_logits', 'end_logits'])

```