In [1]:
# Optional: uncomment the env var below to force JAX onto the CPU backend.
# On CPU, matmuls run at full float32 precision; on accelerators the JAX
# default matmul precision may use reduced-precision passes. The variable must
# be set before JAX is imported/initialized (next cell) to take effect.
import os
#os.environ["JAX_PLATFORM_NAME"] = "cpu"

In [2]:
from transformers import FlaxBloomForCausalLM, BloomForCausalLM, AutoTokenizer
import numpy as np
import torch
import jax

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Checkpoint under test. The tiny "bigscience/bigscience-small-testing" model
# keeps this parity check fast; "bigscience/bloom-350m" is a drop-in
# alternative for a full-size run.
model_id = "bigscience/bigscience-small-testing"

# The tokenizer is always loaded from bloom-350m, independently of `model_id`
# (presumably both checkpoints share the BLOOM vocabulary -- TODO confirm).
tokenizer_id = "bigscience/bloom-350m"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

# The same checkpoint is loaded twice: natively in PyTorch, and converted from
# the PyTorch weights into Flax. The conversion reports (see output) that the
# Flax params keep the checkpoint's bfloat16 precision.
pt_model = BloomForCausalLM.from_pretrained(model_id)
flax_model = FlaxBloomForCausalLM.from_pretrained(
    model_id,
    from_pt=True,
)

Some of the weights of FlaxBloomForCausalLM were initialized in bfloat16 precision from the model checkpoint at bigscience/bigscience-small-testing:
[('transformer', 'h', '0', 'input_layernorm', 'bias'), ('transformer', 'h', '0', 'input_layernorm', 'scale'), ('transformer', 'h', '0', 'mlp', 'dense_4h_to_h', 'bias'), ('transformer', 'h', '0', 'mlp', 'dense_4h_to_h', 'kernel'), ('transformer', 'h', '0', 'mlp', 'dense_h_to_4h', 'bias'), ('transformer', 'h', '0', 'mlp', 'dense_h_to_4h', 'kernel'), ('transformer', 'h', '0', 'post_attention_layernorm', 'bias'), ('transformer', 'h', '0', 'post_attention_layernorm', 'scale'), ('transformer', 'h', '0', 'self_attention', 'dense', 'bias'), ('transformer', 'h', '0', 'self_attention', 'dense', 'kernel'), ('transformer', 'h', '0', 'self_attention', 'query_key_value', 'bias'), ('transformer', 'h', '0', 'self_attention', 'query_key_value', 'kernel'), ('transformer', 'h', '1', 'input_layernorm', 'bias'), ('transformer', 'h', '1', 'input_layernorm', 'sc

In [4]:
# One long and one short sequence, so batching forces the short one to be
# padded. The phrase is repeated 10x with no separator (giving
# "...longerhello..."), which simply yields a long token sequence.
input_str = [10*"hello this string is definitely longer", "Hey you"]

# `truncation=True` is intentionally omitted: with no max_length and no
# predefined model maximum, the tokenizer ignored it and only emitted a
# warning. Padding brings both sequences to the same length.
inputs_pt = tokenizer(input_str, return_tensors="pt", padding=True)
inputs_np = tokenizer(input_str, return_tensors="np", padding=True)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [5]:
# PyTorch reference logits, with gradients disabled: one call on the padded
# batch, and one on the first sequence alone (the long one, so it carries no
# padding and needs no attention mask).
with torch.no_grad():
    first_ids = inputs_pt.input_ids[:1]
    batch_out = pt_model(**inputs_pt)
    logits_pt = batch_out.logits
    logits_pt_single = pt_model(first_ids).logits

In [6]:
# Flax forward pass at JAX's default matmul precision. On accelerators (e.g.
# TPU) the default may use bfloat16 matmul passes, which would explain the
# ~1e-4 differences seen here vs ~1e-7 at float32 precision below.
logits_fx = flax_model(**inputs_np).logits
logits_fx_single = flax_model(inputs_np.input_ids[:1]).logits

# Non-padding positions of the padded (second) sequence, taken from the
# attention mask, so the check does not depend on the tokenizer's padding side
# or on how many tokens "Hey you" produces.
pad_row_mask = inputs_np.attention_mask[1].astype(bool)

comparisons = [
    ("batched padded pt vs padded flax",
     logits_pt[1].numpy()[pad_row_mask], np.array(logits_fx[1])[pad_row_mask]),
    ("batched full pt vs full flax",
     logits_pt[0].numpy(), np.array(logits_fx[0])),
    ("single pt vs flax",
     logits_pt_single[0].numpy(), np.array(logits_fx_single[0])),
    ("single flax vs flax",
     np.array(logits_fx[0]), np.array(logits_fx_single[0])),
]
# Print the max absolute difference for each pair.
for label, lhs, rhs in comparisons:
    print(label)
    print(np.max(np.abs(lhs - rhs)))

2022-07-06 13:20:07.679762: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory


batched padded pt vs padded flax
0.00037539005
batched full pt vs full flax
0.0003239885
single pt vs flax
0.0003239885
single flax vs flax
0.0


In [7]:
# Same comparison with JAX matmuls forced to full float32 precision. Only the
# two Flax forward passes run inside the context manager; the PyTorch logits
# were computed earlier.
with jax.default_matmul_precision('float32'):
    logits_fx = flax_model(**inputs_np).logits
    logits_fx_single = flax_model(inputs_np.input_ids[:1]).logits

# Non-padding positions of the padded (second) sequence, taken from the
# attention mask, so the check does not depend on the tokenizer's padding side
# or on how many tokens "Hey you" produces.
pad_row_mask = inputs_np.attention_mask[1].astype(bool)

comparisons = [
    ("batched padded pt vs padded flax",
     logits_pt[1].numpy()[pad_row_mask], np.array(logits_fx[1])[pad_row_mask]),
    ("batched full pt vs full flax",
     logits_pt[0].numpy(), np.array(logits_fx[0])),
    ("single pt vs flax",
     logits_pt_single[0].numpy(), np.array(logits_fx_single[0])),
    ("single flax vs flax",
     np.array(logits_fx[0]), np.array(logits_fx_single[0])),
]
# Print the max absolute difference for each pair.
for label, lhs, rhs in comparisons:
    print(label)
    print(np.max(np.abs(lhs - rhs)))

batched padded pt vs padded flax
5.9604645e-08
batched full pt vs full flax
1.4901161e-07
single pt vs flax
1.1920929e-07
single flax vs flax
5.9604645e-08


## JIT-compile the forward pass and watch the magic happen

In [8]:
def _flax_forward(params, input_ids, attention_mask, **kwargs):
    """Run `flax_model` with explicitly supplied parameters (pure w.r.t. params)."""
    return flax_model(input_ids, attention_mask=attention_mask, params=params, **kwargs)


# Compile the pure forward pass once per input shape/structure. Because the
# parameters are a traced argument rather than a global read at trace time,
# they are not captured into the compiled program. A later reassignment of
# `flax_model.params` is therefore used instead of stale trace-time values.
_flax_forward_jitted = jax.jit(_flax_forward)


def flax_model_jitted(input_ids, attention_mask=None, params=None, **kwargs):
    """JIT-compiled forward pass of `flax_model`.

    Args:
        input_ids: token ids as produced by the tokenizer.
        attention_mask: optional attention mask; None passes no mask.
        params: optional parameter pytree; defaults to `flax_model.params`,
            read at call time (not frozen at compile time).
        **kwargs: extra arguments forwarded to `flax_model`. They are traced by
            `jax.jit`, so they must be arrays (or pytrees of arrays).

    Returns:
        The model output object; callers read its `.logits`.
    """
    if params is None:
        params = flax_model.params
    return _flax_forward_jitted(params, input_ids, attention_mask, **kwargs)

In [9]:
# Micro-benchmark, batched inputs, first call: the forward pass has not yet
# been compiled for these input shapes, so this timing is dominated by tracing
# and XLA compilation. `.block_until_ready()` waits for JAX's asynchronously
# dispatched computation to finish, so the measurement includes the compute.
%time logits_fx = flax_model_jitted(**inputs_np).logits.block_until_ready()

CPU times: user 6.77 s, sys: 414 ms, total: 7.18 s
Wall time: 6.54 s


In [10]:
# Micro-benchmark, batched inputs, second call: same shapes as the previous
# cell, so the compiled program should be reused and this should take ~ms.
# A time on the order of seconds indicates a recompilation.
%time logits_fx = flax_model_jitted(**inputs_np).logits.block_until_ready()

CPU times: user 2.62 ms, sys: 718 µs, total: 3.33 ms
Wall time: 2.93 ms


In [11]:
# Micro-benchmark, single input, first call: a batch of one sequence
# (`[:1]`), passed without an attention mask, differs from the batched call in
# both shape and argument structure. This call therefore triggers a fresh
# trace + compile.
%time logits_fx_single = flax_model_jitted(inputs_np.input_ids[:1]).logits.block_until_ready()

CPU times: user 5.94 s, sys: 239 ms, total: 6.18 s
Wall time: 5.85 s


In [12]:
# Micro-benchmark, single input, second call: the compiled program for this
# shape should be reused, so this should take ~ms. A time on the order of
# seconds indicates a recompilation.
%time logits_fx_single = flax_model_jitted(inputs_np.input_ids[:1]).logits.block_until_ready()

CPU times: user 2.28 ms, sys: 559 µs, total: 2.84 ms
Wall time: 1.39 ms


In [13]:
# Verify the jit-compiled forward pass (default matmul precision) against the
# PyTorch reference and against itself (batched vs single input).

# Non-padding positions of the padded (second) sequence, taken from the
# attention mask, so the check does not depend on the tokenizer's padding side
# or on how many tokens "Hey you" produces.
pad_row_mask = inputs_np.attention_mask[1].astype(bool)

comparisons = [
    ("batched padded pt vs padded flax",
     logits_pt[1].numpy()[pad_row_mask], np.array(logits_fx[1])[pad_row_mask]),
    ("batched full pt vs full flax",
     logits_pt[0].numpy(), np.array(logits_fx[0])),
    ("single pt vs flax",
     logits_pt_single[0].numpy(), np.array(logits_fx_single[0])),
    ("single flax vs flax",
     np.array(logits_fx[0]), np.array(logits_fx_single[0])),
]
# Print the max absolute difference for each pair.
for label, lhs, rhs in comparisons:
    print(label)
    print(np.max(np.abs(lhs - rhs)))

batched padded pt vs padded flax
0.00037539005
batched full pt vs full flax
0.0003239885
single pt vs flax
0.0003239885
single flax vs flax
0.0
