In [1]:
# Set JAX platform to CPU for highest matmul precision
import os
#os.environ["JAX_PLATFORM_NAME"] = "cpu"

In [2]:
from transformers import FlaxBloomForCausalLM, BloomForCausalLM, AutoTokenizer
import numpy as np
import torch
import jax

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
#model_id = "bigscience/bigscience-small-testing"
model_id = "bigscience/bloom-350m"

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-350m")

pt_model = BloomForCausalLM.from_pretrained(model_id)
flax_model = FlaxBloomForCausalLM.from_pretrained(model_id, from_pt=True)

Some of the weights of FlaxBloomForCausalLM were initialized in float16 precision from the model checkpoint at bigscience/bloom-350m:
[('transformer', 'h', '0', 'input_layernorm', 'bias'), ('transformer', 'h', '0', 'input_layernorm', 'scale'), ('transformer', 'h', '0', 'mlp', 'dense_4h_to_h', 'bias'), ('transformer', 'h', '0', 'mlp', 'dense_4h_to_h', 'kernel'), ('transformer', 'h', '0', 'mlp', 'dense_h_to_4h', 'bias'), ('transformer', 'h', '0', 'mlp', 'dense_h_to_4h', 'kernel'), ('transformer', 'h', '0', 'post_attention_layernorm', 'bias'), ('transformer', 'h', '0', 'post_attention_layernorm', 'scale'), ('transformer', 'h', '0', 'self_attention', 'dense', 'bias'), ('transformer', 'h', '0', 'self_attention', 'dense', 'kernel'), ('transformer', 'h', '0', 'self_attention', 'query_key_value', 'bias'), ('transformer', 'h', '0', 'self_attention', 'query_key_value', 'kernel'), ('transformer', 'h', '1', 'input_layernorm', 'bias'), ('transformer', 'h', '1', 'input_layernorm', 'scale'), ('transf

In [4]:
input_str = [10*"hello this string is definitely longer", "Hey you"]

inputs_pt = tokenizer(input_str, return_tensors="pt", padding=True, truncation=True)
inputs_np = tokenizer(input_str, return_tensors="np", padding=True, truncation=True)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [5]:
with torch.no_grad():
    logits_pt = pt_model(**inputs_pt).logits
    logits_pt_single = pt_model(inputs_pt.input_ids[:1]).logits

In [6]:
# default matmul precision (bfloat16)
logits_fx = flax_model(**inputs_np).logits
logits_fx_single = flax_model(inputs_np.input_ids[:1]).logits

print("batched padded pt vs padded flax")
print(np.max(np.abs(logits_pt[1, :2].numpy() - np.array(logits_fx[1, :2]))))

print("batched full pt vs full flax")
print(np.max(np.abs(logits_pt[0].numpy() - np.array(logits_fx[0]))))

print("single pt vs flax")
print(np.max(np.abs(logits_pt_single[0].numpy() - np.array(logits_fx_single[0]))))

print("single flax vs flax")
print(np.max(np.abs(np.array(logits_fx[0]) - np.array(logits_fx_single[0]))))

2022-07-06 13:22:48.113268: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory


batched padded pt vs padded flax
1.8187866
batched full pt vs full flax
4.13562
single pt vs flax
5.2735596
single flax vs flax
2.4197693


In [7]:
# highest matmul precision (float32)
with jax.default_matmul_precision('float32'):
    logits_fx = flax_model(**inputs_np).logits
    logits_fx_single = flax_model(inputs_np.input_ids[:1]).logits
    
print("batched padded pt vs padded flax")
print(np.max(np.abs(logits_pt[1, :2].numpy() - np.array(logits_fx[1, :2]))))

print("batched full pt vs full flax")
print(np.max(np.abs(logits_pt[0].numpy() - np.array(logits_fx[0]))))

print("single pt vs flax")
print(np.max(np.abs(logits_pt_single[0].numpy() - np.array(logits_fx_single[0]))))

print("single flax vs flax")
print(np.max(np.abs(np.array(logits_fx[0]) - np.array(logits_fx_single[0]))))

batched padded pt vs padded flax
0.0032348633
batched full pt vs full flax
0.0056152344
single pt vs flax
0.0048828125
single flax vs flax
0.00036621094


## JIT the fprop and watch the magic happen

In [8]:
@jax.jit
def flax_model_jitted(input_ids, attention_mask=None, **kwargs):
    return flax_model(input_ids, attention_mask=attention_mask, **kwargs)

In [None]:
# microbench jit compile time for batch
%time logits_fx = flax_model_jitted(**inputs_np).logits.block_until_ready()

tcmalloc: large alloc 2237095936 bytes == 0x29d86a000 @  0x7fbbe9d99680 0x7fbbe9dba824 0x58f8b8 0x586650 0x5869d4 0x619464 0x6194c2 0x6213b0 0x62177a 0x5c47d0 0x5f6517 0x7fba62077bf1 0x7fba618dab27 0x7fba61ab56a6 0x7fba619d1c46 0x7fba619d1f77 0x7fba619d833d 0x7fba619d8e80 0x7fba61a500b0 0x7fba619d9197 0x7fba619d9a5b 0x7fba619da21b 0x7fba61a3b5e6 0x7fba618ebfab 0x7fba618ec2a6 0x7fba619d9197 0x7fba619d9a5b 0x7fba619da21b 0x7fba61a1096d 0x7fba61a109c6 0x7fba619d9197
tcmalloc: large alloc 2237431808 bytes == 0x1e5272000 @  0x7fbbe9d99680 0x7fbbe9dba824 0x58f8b8 0x586650 0x5869d4 0x619464 0x6195b6 0x6217b3 0x5042cb 0x56b1da 0x5f6836 0x570035 0x5f6836 0x56b0ae 0x56939a 0x5f6a13 0x5f3547 0x56c8cd 0x5f6836 0x56b1da 0x56939a 0x5f6a13 0x5f3547 0x56c8cd 0x56939a 0x5f6a13 0x5f3547 0x56c8cd 0x56939a 0x5f6a13 0x5f3547
tcmalloc: large alloc 2237431808 bytes == 0x29d86a000 @  0x7fbbe9d99680 0x7fbbe9db9ff4 0x7fba794cd1de 0x7fba794cf979 0x7fba79505533 0x7fba794e4991 0x5f3989 0x5f3e1e 0x50b183 0x56c28c 0

In [None]:
# microbench compiled fprop -> should be ~ms, if on the order of seconds inidicates a recompilation
%time logits_fx = flax_model_jitted(**inputs_np).logits.block_until_ready()

In [None]:
# microbench jit compile time for single input
%time logits_fx_single = flax_model_jitted(inputs_np.input_ids[:1]).logits.block_until_ready()

In [None]:
# microbench compiled fprop for single input -> should be ~ms, if on the order of seconds inidicates a recompilation
%time logits_fx_single = flax_model_jitted(inputs_np.input_ids[:1]).logits.block_until_ready()

In [None]:
# verify correctness of jit-compiled fprop
print("batched padded pt vs padded flax")
print(np.max(np.abs(logits_pt[1, :2].numpy() - np.array(logits_fx[1, :2]))))

print("batched full pt vs full flax")
print(np.max(np.abs(logits_pt[0].numpy() - np.array(logits_fx[0]))))

print("single pt vs flax")
print(np.max(np.abs(logits_pt_single[0].numpy() - np.array(logits_fx_single[0]))))

print("single flax vs flax")
print(np.max(np.abs(np.array(logits_fx[0]) - np.array(logits_fx_single[0]))))