In [1]:
# Optional: force JAX onto the CPU backend for highest matmul precision.
# The override below is currently commented out, so JAX picks its default backend.
# NOTE(review): presumably this has to run before `jax` is imported to take effect — confirm.
import os
# os.environ["JAX_PLATFORM_NAME"] = "cpu"

In [9]:
from transformers import FlaxGoldenGateForCausalLM, GoldenGateForCausalLM, AutoTokenizer
import numpy as np
import torch
import jax
import tempfile

## 1. Load pre-trained PyTorch model

In [7]:
# Single source of truth for the checkpoint name: the tokenizer and the model
# are both loaded from `model_id`, so they cannot drift apart if it is changed.
model_id = "gg-hf/golden-gate-2b"

tokenizer = AutoTokenizer.from_pretrained(model_id)
pt_model = GoldenGateForCausalLM.from_pretrained(model_id)

Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.76it/s]


OSError: Can't load the model for 'gg-hf/golden-gate-2b'. If you were trying to load it from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. Otherwise, make sure 'gg-hf/golden-gate-2b' is the correct path to a directory containing a file named flax_model.msgpack or pytorch_model.bin.

## 2. Convert PyTorch weights to Flax

In [10]:
# Round-trip through a throwaway directory: dump the PyTorch checkpoint, then
# let the Flax class convert those weights while loading (from_pt=True).
# `flax_model` stays usable after the directory is cleaned up on exit.
with tempfile.TemporaryDirectory() as ckpt_dir:
    # safe_serialization=False: presumably writes the classic PyTorch .bin
    # format that the from_pt loader reads — confirm.
    pt_model.save_pretrained(ckpt_dir, safe_serialization=False)
    flax_model = FlaxGoldenGateForCausalLM.from_pretrained(
        ckpt_dir,
        from_pt=True,
    )

  return self.fget.__get__(instance, owner)()
Some weights of the model checkpoint at /tmp/tmpspc7r7v9 were not used when initializing FlaxGoldenGateForCausalLM: {('lm_head', 'kernel')}
- This IS expected if you are initializing FlaxGoldenGateForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxGoldenGateForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [11]:
# Two prompts of very different lengths, so the shorter one has to be padded.
input_str = ["hello this string is definitely longer" * 10, "Hey you"]

# Identical tokenization for both frameworks; only the returned tensor type
# differs ("pt" -> torch tensors, "np" -> numpy arrays).
inputs_pt, inputs_np = (
    tokenizer(input_str, return_tensors=framework, padding=True, truncation=True)
    for framework in ("pt", "np")
)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [12]:
# Reference logits from the PyTorch model; inference only, so autograd is off.
with torch.no_grad():
    # Whole padded batch, attention mask included.
    logits_pt = pt_model(**inputs_pt).logits
    # First prompt on its own, passed without an attention mask.
    logits_pt_single = pt_model(inputs_pt["input_ids"][:1]).logits

In [13]:
# default matmul precision (bfloat16)
# NOTE(review): the actual default presumably depends on the JAX backend — confirm.
logits_fx = flax_model(**inputs_np).logits
logits_fx_single = flax_model(inputs_np["input_ids"][:1]).logits

# (label, reference logits, candidate logits) for each max-abs-difference check.
# NOTE(review): the `[1, :2]` slice assumes the real tokens of the short prompt
# sit in the first two positions; if the tokenizer pads on the left these are
# pad positions instead — confirm the padding side.
comparisons = (
    ("batched padded pt vs padded flax", logits_pt[1, :2].numpy(), np.array(logits_fx[1, :2])),
    ("batched full pt vs full flax", logits_pt[0].numpy(), np.array(logits_fx[0])),
    ("single pt vs flax", logits_pt_single[0].numpy(), np.array(logits_fx_single[0])),
    ("single flax vs flax", np.array(logits_fx[0]), np.array(logits_fx_single[0])),
)

for label, reference, candidate in comparisons:
    print(label)
    print(np.max(np.abs(reference - candidate)))

batched padded pt vs padded flax
68.04801
batched full pt vs full flax
1.4137669
single pt vs flax
1.4139137
single flax vs flax
0.938324


In [14]:
# highest matmul precision (float32)
# Only the Flax forward passes run inside the precision context; the
# comparisons below reuse the PyTorch logits computed earlier.
with jax.default_matmul_precision('float32'):
    logits_fx = flax_model(**inputs_np).logits
    logits_fx_single = flax_model(inputs_np["input_ids"][:1]).logits

# (label, reference logits, candidate logits) for each max-abs-difference check.
# NOTE(review): the `[1, :2]` slice assumes the real tokens of the short prompt
# sit in the first two positions; if the tokenizer pads on the left these are
# pad positions instead — confirm the padding side.
comparisons = (
    ("batched padded pt vs padded flax", logits_pt[1, :2].numpy(), np.array(logits_fx[1, :2])),
    ("batched full pt vs full flax", logits_pt[0].numpy(), np.array(logits_fx[0])),
    ("single pt vs flax", logits_pt_single[0].numpy(), np.array(logits_fx_single[0])),
    ("single flax vs flax", np.array(logits_fx[0]), np.array(logits_fx_single[0])),
)

for label, reference, candidate in comparisons:
    print(label)
    print(np.max(np.abs(reference - candidate)))

batched padded pt vs padded flax
66.512764
batched full pt vs full flax
0.0005722046
single pt vs flax
0.00086021423
single flax vs flax
0.00028800964


## JIT the forward pass (fprop) and watch the magic happen

In [8]:
@jax.jit
def flax_model_jitted(input_ids, attention_mask=None, **kwargs):
    """JIT-compiled forward pass through the notebook-level `flax_model`.

    Args:
        input_ids: token ids forwarded positionally to `flax_model`.
        attention_mask: optional mask forwarded as a keyword argument;
            `None` is passed through unchanged.
        **kwargs: any further keyword arguments, forwarded as-is.

    Returns:
        Whatever `flax_model` returns for these inputs (callers in this
        notebook read its `.logits` attribute).
    """
    # NOTE(review): `flax_model` (and hence its parameters) is captured from
    # the enclosing scope rather than passed in as an argument — confirm this
    # is acceptable for compile time / memory with a model of this size.
    outputs = flax_model(input_ids, attention_mask=attention_mask, **kwargs)
    return outputs

In [None]:
# Micro-benchmark: call the jitted forward pass on the padded batch. When this
# is the first call with these input shapes, the timing includes JIT compilation.
# block_until_ready() is called on the logits so the timing covers the finished result.
%time logits_fx = flax_model_jitted(**inputs_np).logits.block_until_ready()

tcmalloc: large alloc 2237095936 bytes == 0x29d86a000 @  0x7fbbe9d99680 0x7fbbe9dba824 0x58f8b8 0x586650 0x5869d4 0x619464 0x6194c2 0x6213b0 0x62177a 0x5c47d0 0x5f6517 0x7fba62077bf1 0x7fba618dab27 0x7fba61ab56a6 0x7fba619d1c46 0x7fba619d1f77 0x7fba619d833d 0x7fba619d8e80 0x7fba61a500b0 0x7fba619d9197 0x7fba619d9a5b 0x7fba619da21b 0x7fba61a3b5e6 0x7fba618ebfab 0x7fba618ec2a6 0x7fba619d9197 0x7fba619d9a5b 0x7fba619da21b 0x7fba61a1096d 0x7fba61a109c6 0x7fba619d9197
tcmalloc: large alloc 2237431808 bytes == 0x1e5272000 @  0x7fbbe9d99680 0x7fbbe9dba824 0x58f8b8 0x586650 0x5869d4 0x619464 0x6195b6 0x6217b3 0x5042cb 0x56b1da 0x5f6836 0x570035 0x5f6836 0x56b0ae 0x56939a 0x5f6a13 0x5f3547 0x56c8cd 0x5f6836 0x56b1da 0x56939a 0x5f6a13 0x5f3547 0x56c8cd 0x56939a 0x5f6a13 0x5f3547 0x56c8cd 0x56939a 0x5f6a13 0x5f3547
tcmalloc: large alloc 2237431808 bytes == 0x29d86a000 @  0x7fbbe9d99680 0x7fbbe9db9ff4 0x7fba794cd1de 0x7fba794cf979 0x7fba79505533 0x7fba794e4991 0x5f3989 0x5f3e1e 0x50b183 0x56c28c 0

In [None]:
# Micro-benchmark: the same batched call again, now hitting the compiled forward pass.
# Should take ~ms; a runtime on the order of seconds indicates a recompilation.
%time logits_fx = flax_model_jitted(**inputs_np).logits.block_until_ready()

In [None]:
# Micro-benchmark: call the jitted forward pass on the first prompt only (no
# attention mask). When this input shape has not been seen before, the timing
# includes JIT compilation.
%time logits_fx_single = flax_model_jitted(inputs_np.input_ids[:1]).logits.block_until_ready()

In [None]:
# Micro-benchmark: the same single-input call again, now hitting the compiled forward pass.
# Should take ~ms; a runtime on the order of seconds indicates a recompilation.
%time logits_fx_single = flax_model_jitted(inputs_np.input_ids[:1]).logits.block_until_ready()

In [None]:
# verify correctness of jit-compiled fprop
# `logits_fx` / `logits_fx_single` are whatever the most recently executed
# cells assigned to them; the PyTorch logits serve as the reference.
# NOTE(review): the `[1, :2]` slice assumes the real tokens of the short prompt
# sit in the first two positions; if the tokenizer pads on the left these are
# pad positions instead — confirm the padding side.
comparisons = (
    ("batched padded pt vs padded flax", logits_pt[1, :2].numpy(), np.array(logits_fx[1, :2])),
    ("batched full pt vs full flax", logits_pt[0].numpy(), np.array(logits_fx[0])),
    ("single pt vs flax", logits_pt_single[0].numpy(), np.array(logits_fx_single[0])),
    ("single flax vs flax", np.array(logits_fx[0]), np.array(logits_fx_single[0])),
)

for label, reference, candidate in comparisons:
    print(label)
    print(np.max(np.abs(reference - candidate)))