In [None]:
import jax
import jax.numpy as jnp
from datasets import load_dataset
from jax import jit
from transformers import FlaxRobertaModel, RobertaTokenizerFast

# Stream the English OSCAR split so only the records we actually read are fetched.
dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)

# Use the text of the first streamed record as a sample input.
dummy_input = next(iter(dataset))["text"]

# Number of leading token ids kept from the tokenized sample.
SEQ_LEN = 10

tokenizer = RobertaTokenizerFast.from_pretrained("FacebookAI/roberta-base")
# return_tensors="np" gives a (1, num_tokens) array for one string; keep at most SEQ_LEN ids.
input_ids = tokenizer(dummy_input, return_tensors="np").input_ids[:, :SEQ_LEN]

# NOTE(review): the tokenizer is roberta-base but the weights are "julien-c/dummy-unknown";
# confirm model.config.vocab_size covers every id in input_ids, otherwise the embedding
# lookup indexes out of range and the outputs are not meaningful.
model = FlaxRobertaModel.from_pretrained("julien-c/dummy-unknown")

# Run a forward pass. The result is a FlaxBaseModelOutputWithPooling-style object
# (newer transformers versions may return the ...AndCrossAttentions variant).
# Check its shape instead of dumping every array.
outputs = model(input_ids)
assert outputs.last_hidden_state.shape[:2] == input_ids.shape, outputs.last_hidden_state.shape
type(outputs).__name__, outputs.last_hidden_state.shape

In [None]:
# Display the truncated token ids from the first cell: a (1, <=10) integer array.
input_ids

In [None]:
# A list whose last element is a tuple of arrays. JAX flattens through both
# containers, so the leaves are 1, 2 and the two arrays.
# A distinct name avoids reusing `params` for a list, a dict and a NamedTuple.
params_list = [1, 2, (jnp.arange(3), jnp.ones(2))]
print(jax.tree.structure(params_list))
print(jax.tree.leaves(params_list))

In [None]:
# Dicts are pytree nodes too. JAX flattens them in sorted-key order,
# so the leaves come out as W, b, n regardless of insertion order.
params_dict = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}
print(jax.tree.structure(params_dict))
print(jax.tree.leaves(params_dict))

In [None]:
from typing import NamedTuple


class Params(NamedTuple):
    """Two-field record used to show that NamedTuples flatten like tuples in JAX."""

    a: int
    b: float


params = Params(a=1, b=5.0)
# Print the tree structure first, then its leaves.
for tree_view in (jax.tree.structure, jax.tree.leaves):
    print(tree_view(params))

In [None]:
# Named sizes keep the matrix and the batch consistent on the contracted
# dimension (IN_DIM), which was previously two independent literals.
SEED = 0
OUT_DIM, IN_DIM = 150, 100
BATCH_SIZE = 10

key1, key2 = jax.random.split(jax.random.PRNGKey(SEED))
mat = jax.random.normal(key1, (OUT_DIM, IN_DIM))
batched_x = jax.random.normal(key2, (BATCH_SIZE, IN_DIM))


def apply_matrix(x):
    """Apply the module-level matrix ``mat`` to ``x`` via ``jnp.dot``.

    For a 1-D ``x`` of length IN_DIM the result has length OUT_DIM.
    """
    return jnp.dot(mat, x)

In [None]:
def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

In [None]:
@jit
def vmap_batched_apply_matrix(batched_x):
  return jax.vmap(apply_matrix)(batched_x)

vmap_batched_apply_matrix(batched_x)
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()