In [1]:
from transformers import T5Tokenizer, FlaxMT5ForConditionalGeneration
import jax



In [2]:
# Single source of truth for the checkpoint so the model and tokenizer cannot drift apart.
MODEL_CHECKPOINT = "lewtun/tiny-random-mt5"

# `from_pt=True` is passed so the Flax model is built from the checkpoint's PyTorch weights.
model = FlaxMT5ForConditionalGeneration.from_pretrained(MODEL_CHECKPOINT, from_pt=True)
tokenizer = T5Tokenizer.from_pretrained(MODEL_CHECKPOINT)

# Tokenize the prompt into NumPy arrays (return_tensors="np") and keep only the token ids.
input_context = "The dog"
input_ids = tokenizer(input_context, return_tensors="np")["input_ids"]

Some weights of FlaxMT5ForConditionalGeneration were not initialized from the model checkpoint at lewtun/tiny-random-mt5 and are newly initialized: {('lm_head', 'kernel')}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
# Baseline: time a single call to the un-jitted `model.generate` using top-k sampling
# (top_k=30, do_sample=True, max_length=20). `.block_until_ready()` makes the timing
# wait for the computation to finish rather than stopping at JAX's asynchronous dispatch.
# NOTE(review): this is the first call, so the figure may include one-off tracing/compilation
# overhead as well — confirm with a second timed call before comparing against the JIT numbers.
%time outputs = model.generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True).sequences.block_until_ready()

CPU times: user 4.54 s, sys: 86 ms, total: 4.62 s
Wall time: 4.52 s


In [4]:
# Wrap `model.generate` in `jax.jit`. `max_length`, `top_k` and `do_sample` are declared
# static, so they are treated as compile-time constants: calling with different values
# for any of them triggers a recompilation.
jit_generate = jax.jit(
    model.generate,
    static_argnames=("max_length", "top_k", "do_sample"),
)

In [5]:
# First call to `jit_generate`: JAX has to trace and compile `generate` for these static
# argument values, so this timing is dominated by JIT compilation rather than generation.
# `.block_until_ready()` keeps the timer running until the result is actually computed.
%time outputs = jit_generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True).sequences.block_until_ready()

CPU times: user 6.34 s, sys: 166 ms, total: 6.51 s
Wall time: 6.46 s


In [6]:
# Second call with identical arguments: reuses the program compiled by the previous call,
# so this measures generation alone (should be << the JIT compile time measured above).
%time outputs = jit_generate(input_ids=input_ids, max_length=20, top_k=30, do_sample=True).sequences.block_until_ready()

CPU times: user 97.7 ms, sys: 17.5 ms, total: 115 ms
Wall time: 85.9 ms


In [None]:
# Outside of benchmarking, drop `.block_until_ready()`: nothing is being timed here,
# so there is no need to explicitly wait on the result.
outputs = jit_generate(
    input_ids=input_ids,
    max_length=20,
    top_k=30,
    do_sample=True,
).sequences
print(outputs)