In [14]:
%load_ext autoreload
%autoreload 2

import os
import functools
from pprint import pprint
from pathlib import Path
from typing import Any

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
from jax import numpy as jnp
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from flax import linen as nn

from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm
import kagglehub
#kagglehub.login()

# Device handles: the first host CPU device and the list of CUDA devices.
cpu_device = jax.devices("cpu")[0]
devices = jax.devices("cuda")

# Show which XLA-related environment variables are currently set.
pprint([f"{name} = {value}" for name, value in os.environ.items()
        if name.startswith("XLA")])

# Keep JAX's compilation cache in a directory under the user's home.
jax.config.update("jax_compilation_cache_dir",
                  str(Path("~/.cache/jax_compilation_cache").expanduser()))

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
['XLA_PYTHON_CLIENT_PREALLOCATE = false']


In [2]:
@functools.partial(jax.jit, static_argnums=(0, 1))
def casual_attention_mask(seq_len: int, max_seq_len: int):
  """Build a boolean causal mask of shape (seq_len, max_seq_len).

  Entry [i, j] is True exactly when j <= i, i.e. row i may look at columns
  0..i. Both sizes are static jit arguments, so every distinct pair of
  values is compiled separately.
  """
  row_idx = jnp.arange(seq_len)[:, None]
  col_idx = jnp.arange(max_seq_len)[None, :]
  return col_idx <= row_idx

In [3]:
# Download the Gemma Flax weights for the chosen variant via kagglehub and
# derive the checkpoint and tokenizer paths inside the download directory.
variant = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:"string"}
weights_dir = kagglehub.model_download(f'google/gemma/Flax/{variant}')
ckpt_path = os.path.join(weights_dir, variant)
vocab_path = os.path.join(weights_dir, 'tokenizer.model')

In [4]:
# Load the SentencePiece tokenizer and build the model definition.
# Checkpoint loading is disabled, so `ckpt_path` is not used in this cell:
#params = params_lib.load_and_format_params(ckpt_path)
vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)
# Disabled alternative: derive the config from the loaded checkpoint.
#transformer_config=transformer_lib.TransformerConfig.from_params(
#    params,
#    cache_size=1024  # Number of time steps in the transformer's cache
#)
# NOTE(review): the config comes from the `gemma2_2b` preset while the
# download above uses `variant` ('2b-it') -- confirm these are meant to match.
# cache_size=32 is the value later read back as `config.max_cache_length`
# -- TODO confirm that mapping in TransformerConfig.
config = transformer_lib.TransformerConfig.gemma2_2b(cache_size=32)
transformer = transformer_lib.Transformer(config)

In [5]:
def is_param(x):
  """Return True when `x` is an `nn.LogicallyPartitioned` box."""
  return isinstance(x, nn.LogicallyPartitioned)

def init_params(batch_size: int):
  """Initialise the transformer variables and a cache for `batch_size` rows.

  `transformer.init` is run with PRNG key 0 on an all-zero dummy token batch
  of length 1 (the original author notes this length does not matter for
  initialisation), an all-True attention mask and an unboxed float32 cache.

  Args:
    batch_size: number of sequences in the dummy batch and in the cache.

  Returns:
    A `(params, cache)` tuple: the variables returned by `transformer.init`
    and the cache exactly as returned by `config.init_cache` (i.e. *not*
    unboxed, so any `nn.LogicallyPartitioned` metadata stays available).
  """
  dummy_len = 1
  dummy_shape = (batch_size, dummy_len)
  tokens = jnp.zeros(dummy_shape, dtype=jnp.int32)
  token_positions = jnp.broadcast_to(
    jnp.arange(dummy_len, dtype=jnp.int32), dummy_shape)
  full_mask = jnp.ones(dummy_shape + (config.max_cache_length,),
                       dtype=jnp.bool)
  cache = config.init_cache(batch_size, jnp.float32)

  def unbox(leaf):
    # The model is fed raw arrays, so strip the partitioning wrapper.
    return leaf.value if is_param(leaf) else leaf

  unboxed_cache = jax.tree.map(unbox, cache, is_leaf=is_param)
  params = transformer.init(random.key(0), tokens, token_positions,
                            unboxed_cache, full_mask)
  return params, cache
  
# Trace `init_params` abstractly: only the shape/dtype structure of the
# parameters and cache is produced by `jax.eval_shape`.
BATCH_SIZE = 2
params_struct, cache_struct = jax.eval_shape(
  functools.partial(init_params, BATCH_SIZE))

In [6]:
# Collect every logical axis name carried by the boxed parameter/cache leaves.
axis_names = {
  name
  for leaf in jax.tree.leaves((params_struct, cache_struct), is_leaf=is_param)
  if is_param(leaf)
  for name in leaf.names
}
print(axis_names)

# Logical axis name -> mesh axis name (None leaves that dimension unsharded).
fsdp_rules = {
  None: None,
  "batch": None,
  "sequence": None,
  "vocab": None,
  "features": "x",  # sholto calls this 'd_model'
  "q_heads": None,
  "kv_heads": None,
  "head_dim": None,
  "ffw": None,
}
# Every logical axis found above must have a rule.
assert axis_names.issubset(fsdp_rules)

{'kv_heads', 'features', 'ffw', 'head_dim', 'q_heads', 'batch', 'sequence', 'vocab', None}


In [18]:
# One-dimensional device mesh over `devices`, with its single axis named "x".
mesh = Mesh(devices, ("x",))
def logical_to_physical(rules, x):
  """Map a leaf's logical axis names to mesh axis names.

  Args:
    rules: mapping from logical axis name to mesh axis name (or None).
    x: a tree leaf; either an `nn.LogicallyPartitioned` box or an object
      exposing `ndim`.

  Returns:
    A list with one entry per dimension: `rules[name]` for each of
    `x.names` when `x` is boxed, otherwise `None` for every dimension.

  Raises:
    KeyError: if a boxed leaf uses a name that is missing from `rules`.
  """
  if not is_param(x):
    return [None] * x.ndim
  return [rules[axis] for axis in x.names]

rules = fsdp_rules

def _to_named_sharding(leaf):
  """Build the `NamedSharding` on `mesh` for one leaf under `rules`."""
  spec = P(*logical_to_physical(rules, leaf))
  return NamedSharding(mesh, spec)

# Sharding trees mirroring the structure of the parameters and the cache.
params_sharding, cache_sharding = jax.tree.map(
  _to_named_sharding, (params_struct, cache_struct), is_leaf=is_param)

@functools.partial(jax.jit, static_argnums=(0,),
                   out_shardings=(params_sharding, cache_sharding))
def unpack_params(batch_size: int) -> tuple[dict[str, Any], dict[str, Any]]:
  """Initialise parameters and cache, returning them as plain arrays.

  Calls `init_params(batch_size)` and replaces every
  `nn.LogicallyPartitioned` leaf with its `.value`. The jit decorator places
  the two results according to `params_sharding` and `cache_sharding`.

  Args:
    batch_size: static batch size forwarded to `init_params`.

  Returns:
    The `(params, cache)` pair with the partitioning wrappers removed.
  """
  def strip_box(leaf):
    return leaf.value if is_param(leaf) else leaf

  boxed = init_params(batch_size)
  return jax.tree.map(strip_box, boxed, is_leaf=is_param)

In [15]:
# Materialise the initialised parameters and cache for the notebook batch size.
params, cache = unpack_params(BATCH_SIZE)

In [16]:
# Inspect how layer 0's MLP `linear` weight is laid out across devices.
jax.debug.visualize_array_sharding(params["params"]["layer_0"]["mlp"]["linear"])
# Disabled: refers to a `params_cache` variable that no cell shown here defines.
#jax.debug.visualize_array_sharding(params_cache["params"]["params"]["layer_0"]["scale"])

In [81]:
@functools.partial(jax.jit, static_argnames=("config",),
                   in_shardings=(params_sharding, cache_sharding, None))
def prefill(params: dict[str, Any], cache: dict[str, Any], input: jax.Array,
            config: transformer_lib.TransformerConfig):
  """Run the transformer once over a padded batch of prompt tokens.

  Tokens equal to `vocab.pad_id()` are treated as padding: they are given
  position 0 and are masked out of attention both as queries and as keys.

  Args:
    params: model variables passed straight to `transformer.apply`.
    cache: cache passed straight to `transformer.apply`.
    input: 2-D array of token ids, one row per prompt.
    config: static (hashable) config; only `max_cache_length` is read here.

  Returns:
    The `(logits, cache)` pair returned by `transformer.apply`.
  """
  assert input.ndim == 2
  _, input_len = input.shape
  cache_len = config.max_cache_length
  assert input_len <= cache_len

  tokens = input.astype(jnp.int32)
  is_token = tokens != vocab.pad_id()

  # Positions count up along each row; padding slots are forced to 0.
  positions = jnp.where(
    is_token, jnp.broadcast_to(jnp.arange(input_len), tokens.shape), 0)

  # Extend the key-side validity mask with False up to the cache length.
  extra = max(0, cache_len - input_len)
  key_is_token = jnp.pad(is_token, [(0, 0), (0, extra)])

  # Causal mask, restricted to real query rows and real key columns.
  causal = casual_attention_mask(input_len, cache_len)[None, ...]
  attention_mask = causal & is_token[..., None] & key_is_token[..., None, :]

  logits, cache = transformer.apply(params, tokens, positions, cache,
                                    attention_mask)
  return logits, cache

In [None]:
# Leftover inspection cell (never executed): shows the model's `__call__`.
transformer.__call__

In [59]:
def stack_input(inputs: list[str]):
  """Tokenise prompts and stack them into one right-padded int32 batch.

  Each prompt is encoded with the notebook-level `vocab`; shorter rows are
  padded on the right with `vocab.pad_id()` up to the longest prompt.

  Args:
    inputs: prompts to encode.

  Returns:
    An int32 array with one row per prompt and one column per token of the
    longest prompt. An empty `inputs` yields an empty array of shape (0, 0).
  """
  if not inputs:
    # `max()` over no rows would raise; an empty batch is still a valid
    # 2-D token array for downstream code that checks `ndim == 2`.
    return jnp.zeros((0, 0), dtype=jnp.int32)

  encoded = [vocab.Encode(text) for text in inputs]
  max_len = max(len(ids) for ids in encoded)
  pad_id = vocab.pad_id()
  rows = [ids + [pad_id] * (max_len - len(ids)) for ids in encoded]
  # Explicit dtype: rows that are all empty would otherwise become float32.
  return jnp.array(rows, dtype=jnp.int32)
  
# Two prompts of different token lengths, so the shorter row gets padded.
batch_input = stack_input(["hello, how are you?", "The weather today is"])

In [71]:
# Display the padded token batch.
batch_input

Array([[ 17534, 235269,   1368,    708,    692, 235336],
       [   651,   8957,   3646,    603,      0,      0]], dtype=int32)

In [82]:
# Prefill the cache with the padded prompt batch; `config` is a static argument.
logits, new_cache = prefill(params, cache, batch_input, config)


In [85]:
# Layer 0's `end_index` after prefill (the recorded run printed [6 6]).
print(new_cache["layer_0"]["end_index"])
print()
# One slice of layer 0's key cache for the first batch row.
# NOTE(review): axes are presumably (batch, cache slot, kv head, head dim)
# -- TODO confirm against the cache layout.
new_cache["layer_0"]["k"][0, :, 0, 0]

[6 6]



Array([-0.74918693,  0.38351545,  0.7432269 ,  0.10356966,  0.17441484,
        0.01317915,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ], dtype=float32)

In [67]:
# Inspect the device layout of a 2-D slice of layer 0's key cache.
jax.debug.visualize_array_sharding(new_cache["layer_0"]["k"][0, 0, ...])

In [64]:
# Display layer 0's `end_index` entry of the prefilled cache.
new_cache["layer_0"]["end_index"]

Array([6, 6], dtype=int32)

In [34]:
# Single-prompt setup (batch of one): tokens, positions, cache and mask.
max_seq_len = config.max_cache_length
input_sequence = jnp.array(
  [vocab.Encode("Hello, what's your name")], dtype=jnp.int32)

assert input_sequence.ndim == 2
batch_size, input_len = input_sequence.shape
cache = config.init_cache(1, dtype=jnp.float32)

# Padding tokens get position 0.
input_mask = input_sequence != vocab.pad_id()
positions = input_mask * jnp.arange(input_len)

# Causal mask over the whole cache length, one copy per batch row.
attention_mask = jnp.broadcast_to(
  casual_attention_mask(input_len, max_seq_len),
  (batch_size, input_len, max_seq_len))

# Additionally hide padding tokens as keys within the prompt's own span.
attention_mask = attention_mask.at[:, :, :input_len].set(
  attention_mask[:, :, :input_len] & input_mask[:, None, :])

In [8]:
# Initialise the model variables with the CPU as JAX's default device.
with jax.default_device(cpu_device):
  # Disabled alternative: shape-only evaluation instead of real initialisation.
  #params = jax.eval_shape(transformer.init, random.key(0), 
  #                        input_sequence, positions, cache, attention_mask)
  params = transformer.init(random.key(0), input_sequence, positions, cache, 
                            attention_mask)

In [20]:
# Show the logical axis names attached to every boxed parameter.
jax.tree.map(lambda boxed: boxed.names, params, is_leaf=is_param)

{'params': {'embedder': {'input_embedding': ('vocab', 'd_model')},
  'final_norm': {'scale': ('d_model',)},
  'layer_0': {'attn': {'attn_vec_einsum': {'w': ('d_model',
      'head_dim',
      'features')},
    'kv_einsum': {'w': (None, 'd_model', 'features', 'head_dim')},
    'q_einsum': {'w': ('d_model', 'query_heads', 'head_dim')}},
   'mlp': {'gating_einsum': (None, 'features', 'ffw'),
    'linear': ('ffw', 'features')},
   'post_attention_norm': {'scale': ('d_model',)},
   'post_ffw_norm': {'scale': ('d_model',)},
   'pre_attention_norm': {'scale': ('d_model',)},
   'pre_ffw_norm': {'scale': ('d_model',)}},
  'layer_1': {'attn': {'attn_vec_einsum': {'w': ('d_model',
      'head_dim',
      'features')},
    'kv_einsum': {'w': (None, 'd_model', 'features', 'head_dim')},
    'q_einsum': {'w': ('d_model', 'query_heads', 'head_dim')}},
   'mlp': {'gating_einsum': (None, 'features', 'ffw'),
    'linear': ('ffw', 'features')},
   'post_attention_norm': {'scale': ('d_model',)},
   'post_f

In [31]:
# Fresh single-row cache, then one forward pass with the CPU as default device.
cache = config.init_cache(1, jnp.float32)
with jax.default_device(cpu_device):
  # `y` is the first value returned by `apply`; `cache` is rebound to the second.
  y, cache = transformer.apply(params, input_sequence, positions, cache, 
                               attention_mask)

{CudaDevice(id=0)}


Finally, build a sampler on top of your model and your tokenizer.

In [43]:
# Create a sampler with the right param shapes.
sampler = sampler_lib.Sampler(
    transformer=transformer,
    vocab=vocab,
    #params=params['transformer'],
    params=params,
)

You're ready to start sampling! This sampler uses just-in-time compilation, so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent.

In [46]:
# Prompts to sample from, one per batch row.
input_batch = [
    "\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):",
    "What are the planets of the solar system?",
]

# NOTE(review): the model above was configured with cache_size=32; confirm
# the sampler accepts 300 generation steps with a cache that small.
out_data = sampler(input_strings=input_batch, total_generation_steps=300)

# Print each prompt with its completion, a blank line and a '#' separator.
for input_string, out_string in zip(input_batch, out_data.text):
  print(f"Prompt:\n{input_string}\nOutput:\n{out_string}", "", '#' * 10,
        sep="\n")

ApplyScopeInvalidVariablesStructureError: Expect the `variables` (first argument) passed to apply() to be a dict with the structure {"params": ...}, but got a dict with an extra params layer, i.e.  {"params": {"params": ... } }. You should instead pass in your dict's ["params"]. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ApplyScopeInvalidVariablesStructureError)