# Calling BERT model from JAX (with BERT weights in JAX)

In [3]:
from __future__ import annotations

import random
import time

from transformers import BertTokenizer, BertModel
from datasets import load_dataset
import torch
from torch import Tensor
from torch.func import functional_call
import jax
from jax import numpy as jnp, Array

from torch2jax import tree_t2j, torch2jax_with_vjp


### Loading the dataset and the model (in PyTorch)

In [4]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training split of WikiText-2; used below purely as a pool of sample sentences.
dataset = load_dataset("wikitext", "wikitext-2-v1", split="train")

# Pretrained BERT tokenizer and encoder; the encoder is moved to `device` and
# switched to inference mode (both calls return the module itself, so they chain).
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased").to(device).eval()


def tokenizer_torch(text: list[str]) -> dict[str, Tensor]:
    """Tokenize a batch of strings and move every resulting tensor to `device`.

    Args:
        text: The strings to tokenize as a single batch.

    Returns:
        A dict mapping each key produced by the tokenizer to its tensor on `device`.
        The batch is tokenized with padding and truncation enabled.
    """
    batch = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
    on_device = {}
    for name, tensor in batch.items():
        on_device[name] = tensor.to(device)
    return on_device

Found cached dataset wikitext (/home/rdyro/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### Let's convert the torch model to a function, using `torch.func.functional_call`

In [5]:
# Collect the model state as plain name -> tensor dicts so that it can be passed
# explicitly to a functional version of the forward pass.
params = dict(model.named_parameters())
buffers = dict(model.named_buffers())

def torch_fwd_fn(params, buffers, input):
    """Functional forward pass of `model` that returns its pooled output.

    Args:
        params: Name -> tensor mapping used in place of the model's parameters.
        buffers: Name -> tensor mapping used in place of the model's buffers.
        input: Mapping forwarded to the model as keyword arguments.

    Returns:
        The `pooler_output` attribute of the model's output.
    """
    outputs = functional_call(model, (params, buffers), args=(), kwargs=input)
    return outputs.pooler_output

### We do not need to specify the output; the library will call the torch function once to infer it

In [7]:
nb = 50  # number of sentences per batch

# Draw 1000 rows (with replacement), keep only texts longer than 100 characters,
# and take the first `nb` of them as the example batch.
sampled_rows = random.choices(dataset, k=int(1e3))
long_texts = [row["text"] for row in sampled_rows if len(row["text"]) > 100]
text = long_texts[:nb]
encoded_text = tokenizer_torch(text)

# JAX copies of the model state, passed as explicit arguments to the wrapped function.
params_jax, buffers_jax = tree_t2j(params), tree_t2j(buffers)

# Wrap the torch forward function for JAX; the example arguments are supplied so
# that no output specification has to be given by hand. The wrapper is then jitted.
jax_fwd_fn = jax.jit(torch2jax_with_vjp(torch_fwd_fn, params, buffers, encoded_text))

In [9]:
# Benchmark the JAX-wrapped forward pass against native PyTorch on freshly sampled
# batches and print the norm of the difference between the two outputs.
# NOTE(review): the tokenizer pads to the longest sequence in each batch, so input
# shapes can differ between iterations and `jax.jit` may retrace; the JAX timing
# can therefore include tracing/compilation cost -- confirm if that matters.
for i in range(10):
    text = [x["text"] for x in random.choices(dataset, k=int(1e3)) if len(x["text"]) > 100][:nb]
    encoded_text = tokenizer_torch(text)
    encoded_text_jax = tree_t2j(encoded_text)

    # JAX dispatches work asynchronously: wait for the result before stopping the
    # clock, otherwise the measurement may not cover the whole computation.
    start = time.perf_counter()
    out1 = jax_fwd_fn(params_jax, buffers_jax, encoded_text_jax)
    out1.block_until_ready()
    t = time.perf_counter() - start
    print(f"JAX version took:   {t:.4e} s")

    start = time.perf_counter()
    with torch.no_grad():
        out2 = model(**encoded_text).pooler_output
    # CUDA kernels are launched asynchronously too, but synchronizing is only valid
    # when CUDA is present; the notebook explicitly supports a CPU fallback.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t = time.perf_counter() - start
    print(f"Torch version took: {t:.4e} s")
    print(f"err = {jnp.linalg.norm(out1 - tree_t2j(out2)):.4e}")


JAX version took:   2.9324e-01 s
Torch version took: 1.7142e-01 s
err = 0.0000e+00
JAX version took:   3.5733e-01 s
Torch version took: 2.3677e-01 s
err = 0.0000e+00
JAX version took:   3.1890e-01 s
Torch version took: 1.9941e-01 s
err = 0.0000e+00
JAX version took:   4.4257e-01 s
Torch version took: 3.2214e-01 s
err = 0.0000e+00
JAX version took:   2.9708e-01 s
Torch version took: 1.7585e-01 s
err = 0.0000e+00
JAX version took:   3.2778e-01 s
Torch version took: 2.0791e-01 s
err = 0.0000e+00
JAX version took:   2.7836e-01 s
Torch version took: 1.5862e-01 s
err = 0.0000e+00
JAX version took:   3.1563e-01 s
Torch version took: 1.9557e-01 s
err = 0.0000e+00
JAX version took:   3.0478e-01 s
Torch version took: 1.8507e-01 s
err = 0.0000e+00
JAX version took:   2.7461e-01 s
Torch version took: 1.5416e-01 s
err = 0.0000e+00
