# Fine-Tune BERT for Q&A with Apple MLX

This notebook fine-tunes BERT for question answering with Apple MLX and compares it to the PyTorch Hugging Face implementation.

In [2]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to .py files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [13]:
import torch
from transformers import BertTokenizerFast, BertForQuestionAnswering

In [8]:
from utils import load_squad
from model import load_model

# Load data and models

In [7]:
# Load a reduced SQuAD dataset for fast iteration.
# NOTE(review): filter_size=500 appears to cap the total number of examples
# (the recorded output shows 250/125/125 rows for train/valid/test) —
# confirm against utils.load_squad.
squad = load_squad(filter_size=500)
squad  # bare last expression displays the DatasetDict summary

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 250
    })
    valid: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 125
    })
    test: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 125
    })
})

In [10]:
# MLX
# Load the MLX BERT model and its tokenizer through the project's load_model helper.
# The weights path is derived from bert_model so the checkpoint name and the
# weights file cannot drift apart when the model is changed; for
# "bert-base-uncased" it resolves to "weights/bert-base-uncased.npz", exactly
# the value that was previously hard-coded.
# NOTE(review): assumes converted weights are stored as weights/<model-name>.npz — confirm.
bert_model = "bert-base-uncased"
mlx_weights_path = f"weights/{bert_model}.npz"
model, tokenizer = load_model(bert_model, mlx_weights_path)

In [17]:
# PyTorch HF
# Reuse the checkpoint name from the MLX cell (bert_model) so both
# implementations refer to the same pretrained checkpoint.
pre_train_model = bert_model
tokenizerhf = BertTokenizerFast.from_pretrained(pre_train_model)
# The recorded warning for this cell reports qa_outputs.weight and
# qa_outputs.bias as newly initialized, i.e. the question-answering head is
# not part of the checkpoint and must be trained before use.
modelhf = BertForQuestionAnswering.from_pretrained(pre_train_model)

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Appendix

In [11]:
# Display the MLX model's module tree (for comparison with the Hugging Face model).
model

Bert(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768)
    (token_type_embeddings): Embedding(2, 768)
    (position_embeddings): Embedding(512, 768)
    (norm): LayerNorm(768, eps=1e-12, affine=True)
  )
  (encoder): TransformerEncoder(
    (layers.0): TransformerEncoderLayer(
      (attention): MultiHeadAttention(
        (query_proj): Linear(input_dims=768, output_dims=768, bias=True)
        (key_proj): Linear(input_dims=768, output_dims=768, bias=True)
        (value_proj): Linear(input_dims=768, output_dims=768, bias=True)
        (out_proj): Linear(input_dims=768, output_dims=768, bias=True)
      )
      (ln1): LayerNorm(768, eps=1e-12, affine=True)
      (ln2): LayerNorm(768, eps=1e-12, affine=True)
      (linear1): Linear(input_dims=768, output_dims=3072, bias=True)
      (linear2): Linear(input_dims=3072, output_dims=768, bias=True)
      (gelu): GELU()
    )
    (layers.1): TransformerEncoderLayer(
      (attention): MultiHeadAttention(
        (q

In [21]:
# Display the Hugging Face model's module tree (for comparison with the MLX model).
modelhf

BertForQuestionAnswering(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, 