In [1]:
# Set JAX platform to CPU for highest matmul precision
import os
#os.environ["JAX_PLATFORM_NAME"] = "cpu"

In [2]:
from transformers import FlaxBloomForCausalLM, BloomForCausalLM, AutoTokenizer
import numpy as np
import torch
import jax

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
#model_id = "bigscience/bigscience-small-testing"
model_id = "bigscience/bloom-350m"
scan_model_id = "sanchit-gandhi/bloom-350m-scan"

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-350m")

pt_model = BloomForCausalLM.from_pretrained(model_id)
flax_model = FlaxBloomForCausalLM.from_pretrained(scan_model_id, use_scan=True)

Downloading: 100%|██████████████████████████████████████████████████████████████████████| 748/748 [00:00<00:00, 574kB/s]
Downloading: 100%|█████████████████████████████████████████████████████████████████| 1.04G/1.04G [00:11<00:00, 97.8MB/s]
tcmalloc: large alloc 1118437376 bytes == 0x1215de000 @  0x7f9741f04680 0x7f9741f25824 0x5f8a01 0x648cf1 0x5c4676 0x4f290e 0x64f718 0x5048b3 0x56b1da 0x56939a 0x50aaa0 0x56c28c 0x56939a 0x68d047 0x6003a4 0x5c4a40 0x56b0ae 0x5002d8 0x56cadf 0x5002d8 0x56cadf 0x5002d8 0x503fb6 0x56b1da 0x5f6836 0x56b0ae 0x5f6836 0x56b1da 0x56939a 0x5f6a13 0x50aa2c
Some of the weights of FlaxBloomForCausalLM were initialized in float16 precision from the model checkpoint at sanchit-gandhi/bloom-350m-scan:
[('transformer', 'h', 'FlaxBloomBlockLayers', 'input_layernorm', 'bias'), ('transformer', 'h', 'FlaxBloomBlockLayers', 'input_layernorm', 'scale'), ('transformer', 'h', 'FlaxBloomBlockLayers', 'mlp', 'dense_4h_to_h', 'bias'), ('transformer', 'h', 'FlaxBloomBlockLayer

In [4]:
input_str = [10*"hello this string is definitely longer", "Hey you"]

inputs_pt = tokenizer(input_str, return_tensors="pt", padding=True, truncation=True)
inputs_np = tokenizer(input_str, return_tensors="np", padding=True, truncation=True)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [5]:
with torch.no_grad():
    logits_pt = pt_model(**inputs_pt).logits
    logits_pt_single = pt_model(inputs_pt.input_ids[:1]).logits

In [6]:
# default matmul precision (bfloat16)
logits_fx = flax_model(**inputs_np).logits
logits_fx_single = flax_model(inputs_np.input_ids[:1]).logits

print("batched padded pt vs padded flax")
print(np.max(np.abs(logits_pt[1, :2].numpy() - np.array(logits_fx[1, :2]))))

print("batched full pt vs full flax")
print(np.max(np.abs(logits_pt[0].numpy() - np.array(logits_fx[0]))))

print("single pt vs flax")
print(np.max(np.abs(logits_pt_single[0].numpy() - np.array(logits_fx_single[0]))))

print("single flax vs flax")
print(np.max(np.abs(np.array(logits_fx[0]) - np.array(logits_fx_single[0]))))

2022-07-06 13:34:14.872740: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory


batched padded pt vs padded flax
1.8187866
batched full pt vs full flax
4.532715
single pt vs flax
4.8477783
single flax vs flax
2.32901


In [7]:
# highest matmul precision (float32)
with jax.default_matmul_precision('float32'):
    logits_fx = flax_model(**inputs_np).logits
    logits_fx_single = flax_model(inputs_np.input_ids[:1]).logits
    
print("batched padded pt vs padded flax")
print(np.max(np.abs(logits_pt[1, :2].numpy() - np.array(logits_fx[1, :2]))))

print("batched full pt vs full flax")
print(np.max(np.abs(logits_pt[0].numpy() - np.array(logits_fx[0]))))

print("single pt vs flax")
print(np.max(np.abs(logits_pt_single[0].numpy() - np.array(logits_fx_single[0]))))

print("single flax vs flax")
print(np.max(np.abs(np.array(logits_fx[0]) - np.array(logits_fx_single[0]))))

batched padded pt vs padded flax
0.0031738281
batched full pt vs full flax
0.005065918
single pt vs flax
0.0045776367
single flax vs flax
0.0004272461


## JIT the fprop and watch the magic happen

In [8]:
@jax.jit
def flax_model_jitted(input_ids, attention_mask=None, **kwargs):
    return flax_model(input_ids, attention_mask=attention_mask, **kwargs)

In [9]:
# microbench jit compile time for batch
%time logits_fx = flax_model_jitted(**inputs_np).logits.block_until_ready()

tcmalloc: large alloc 2236956672 bytes == 0x2d2610000 @  0x7f9741f04680 0x7f9741f25824 0x58f8b8 0x586650 0x5869d4 0x619464 0x6195b6 0x6217b3 0x5042cb 0x56b1da 0x5f6836 0x570035 0x5f6836 0x56b0ae 0x56939a 0x5f6a13 0x5f3547 0x56c8cd 0x5f6836 0x56b1da 0x56939a 0x5f6a13 0x5f3547 0x56c8cd 0x56939a 0x5f6a13 0x5f3547 0x56c8cd 0x56939a 0x5f6a13 0x5f3547
tcmalloc: large alloc 2236956672 bytes == 0x1ef80c000 @  0x7f9741f04680 0x7f9741f24ff4 0x7f95d16381de 0x7f95d163a979 0x7f95d1670533 0x7f95d164f991 0x5f3989 0x5f3e1e 0x50b183 0x56c28c 0x5f6836 0x5f3547 0x56c8cd 0x56939a 0x5f6a13 0x56b0ae 0x5f6836 0x56b0ae 0x56939a 0x5f6a13 0x5f3547 0x56c8cd 0x5f6836 0x56b1da 0x56939a 0x5f6a13 0x5f3547 0x56c8cd 0x56939a 0x5f6a13 0x5f3547
tcmalloc: large alloc 2146123776 bytes == 0x4fe574000 @  0x7f9741f04680 0x7f9741f24ff4 0x7f95d54228ca 0x7f95d46f3cb7 0x7f95d46e9e17 0x7f95d46e3249 0x7f95d382a611 0x7f95d3838ad0 0x7f95d188aeaa 0x7f95d166ff56 0x7f95d1670597 0x7f95d164f991 0x5f3989 0x5f3e1e 0x50b183 0x56c28c 0x5f683

CPU times: user 2min 7s, sys: 13 s, total: 2min 20s
Wall time: 2min 14s


In [10]:
# microbench compiled fprop -> should be ~ms, if on the order of seconds inidicates a recompilation
%time logits_fx = flax_model_jitted(**inputs_np).logits.block_until_ready()

CPU times: user 3.77 ms, sys: 1.02 ms, total: 4.79 ms
Wall time: 11.4 ms


In [11]:
# microbench jit compile time for single input
%time logits_fx_single = flax_model_jitted(inputs_np.input_ids[:1]).logits.block_until_ready()

tcmalloc: large alloc 2236956672 bytes == 0x1e380a000 @  0x7f9741f04680 0x7f9741f25824 0x58f8b8 0x586650 0x5869d4 0x619464 0x6195b6 0x6217b3 0x5042cb 0x56b1da 0x5f6836 0x570035 0x5f6836 0x56b0ae 0x56939a 0x5f6a13 0x5f3547 0x56c8cd 0x5f6836 0x56b1da 0x56939a 0x5f6a13 0x5f3547 0x56c8cd 0x56939a 0x5f6a13 0x5f3547 0x56c8cd 0x56939a 0x5f6a13 0x5f3547


CPU times: user 1min 32s, sys: 6.46 s, total: 1min 38s
Wall time: 1min 35s


In [12]:
# microbench compiled fprop for single input -> should be ~ms, if on the order of seconds inidicates a recompilation
%time logits_fx_single = flax_model_jitted(inputs_np.input_ids[:1]).logits.block_until_ready()

CPU times: user 3.17 ms, sys: 694 µs, total: 3.87 ms
Wall time: 10.1 ms


In [13]:
# verify correctness of jit-compiled fprop
print("batched padded pt vs padded flax")
print(np.max(np.abs(logits_pt[1, :2].numpy() - np.array(logits_fx[1, :2]))))

print("batched full pt vs full flax")
print(np.max(np.abs(logits_pt[0].numpy() - np.array(logits_fx[0]))))

print("single pt vs flax")
print(np.max(np.abs(logits_pt_single[0].numpy() - np.array(logits_fx_single[0]))))

print("single flax vs flax")
print(np.max(np.abs(np.array(logits_fx[0]) - np.array(logits_fx_single[0]))))

batched padded pt vs padded flax
1.8187866
batched full pt vs full flax
4.532715
single pt vs flax
4.8477783
single flax vs flax
2.32901
