In [1]:
# Reload edited modules automatically before each cell executes, so changes to
# imported source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import os
from pprint import pprint
from pathlib import Path
import logging

# Turn off XLA's up-front GPU memory preallocation so that device memory is
# allocated on demand and the reported usage reflects what is really needed.
# This is set ahead of the `import jax` below.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
from jax.sharding import PartitionSpec as P
from flax.linen import logical_to_mesh_sharding
from flax import nnx

from gemma import params as params_lib
from gemma import transformer as transformer_lib
from gemma import sampler as sampler_lib
import sentencepiece as spm
import kagglehub
#kagglehub.login() # you might need to log in

# First host CPU device and the list of CUDA devices used for compute.
cpu_device = jax.devices("cpu")[0]
compute_devices = jax.devices("cuda")

# Point JAX's compilation cache at a directory under the user's home.
_compilation_cache_dir = Path("~/.cache/jax_compilation_cache").expanduser()
jax.config.update("jax_compilation_cache_dir", str(_compilation_cache_dir))

# Logger named after the gemma sampler module; its level is tuned when sampling.
sampler_logger = logging.getLogger(sampler_lib.__name__)

In [2]:
# Gemma v1 variants are published under a different Kaggle handle, e.g.:
#   variant = '2b-it'  # one of '2b', '2b-it', '7b', '7b-it'
#   weights_dir = kagglehub.model_download(f'google/gemma/Flax/{variant}')

# Download the checkpoint for the chosen variant via kagglehub.
variant = "gemma2-2b-it"
model_handle = f"google/gemma-2/flax/{variant}"
weights_dir = kagglehub.model_download(model_handle)
print(weights_dir)

# The download directory holds the checkpoint (in a sub-directory named after
# the variant) next to the SentencePiece tokenizer model.
ckpt_path = os.path.join(weights_dir, variant)
vocab_path = os.path.join(weights_dir, 'tokenizer.model')

# Create the tokenizer and load its model; the load result is the cell output.
vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)

/home/rdyro/.cache/kagglehub/models/google/gemma-2/flax/gemma2-2b-it/1


True

In [3]:
# Gemma 2 2B configuration with cache_size=128.
config = transformer_lib.TransformerConfig.gemma2_2b(cache_size=128)


def _split_fresh_transformer():
  """Constructs the transformer from `config` and splits it with nnx.split."""
  return nnx.split(transformer_lib.Transformer(config))


# Evaluate the constructor abstractly: jax.eval_shape yields shape/dtype
# placeholders instead of real arrays, so `state_` only describes the
# parameter tree that the real weights are loaded into later.
graphdef, state_ = jax.eval_shape(_split_fresh_transformer)



In [4]:
# One-dimensional device mesh: all compute devices laid out along axis "x".
mesh = jax.sharding.Mesh(compute_devices, ("x",))


def is_param(x):
  """Returns True for nnx.VariableState nodes (used as tree-map leaves)."""
  return isinstance(x, nnx.VariableState)


# (logical axis name, mesh axis) pairs; a mesh axis of None leaves that logical
# axis unsharded. The order of the pairs is kept exactly as listed.
rules = [
  # parameter axes
  (None, None),
  ("batch", None),
  ("sequence", None),
  ("vocab", "x"),
  ("features", "x"),
  ("q_heads", "x"),
  ("kv_heads", "x"),
  ("head_dim", None),
  ("ffw", "x"),
  # activation axes
  ("act_batch", None),
  ("act_sequence", None),
  ("act_heads", None),
  ("act_kv_heads", None),
  ("act_head_dim", None),
]
# Same rules as a mapping from logical axis name to mesh axis.
model_parallel_rules = dict(rules)


def _sharding_for(variable):
  """Maps a variable's logical axis names to a sharding on `mesh`."""
  return logical_to_mesh_sharding(P(*variable.names), mesh, rules)


# One sharding per variable of the abstract model state.
state_shardings = jax.tree.map(_sharding_for, state_, is_leaf=is_param)

In [5]:
# Load the checkpoint with the host CPU as the default device, then run a
# jitted flatten whose `out_shardings` assigns each weight its target sharding.
# `cpu_device` is the CPU device selected in the setup cell.
with jax.default_device(cpu_device):
  params = params_lib.load_and_format_params(ckpt_path)["transformer"]
  shardings_flat = jax.tree.leaves(state_shardings)

  # Weights and shardings are paired purely by position in the flattened
  # trees, so fail early with a readable message if the counts disagree.
  # NOTE(review): this also relies on the checkpoint dict and the nnx state
  # flattening their leaves in the same order - confirm for new checkpoints.
  n_param_leaves = len(jax.tree.leaves(params))
  if n_param_leaves != len(shardings_flat):
    raise ValueError(
        f"Checkpoint has {n_param_leaves} parameter leaves but the model "
        f"state expects {len(shardings_flat)}; cannot pair weights with "
        "shardings.")

  params_flat = jax.jit(lambda x: jax.tree.leaves(x),
                        out_shardings=shardings_flat)(params)

# Pour the sharded weights into the abstract state's tree structure and
# rebuild the live model from its graph definition.
state = jax.tree.unflatten(jax.tree.structure(state_), params_flat)
transformer = nnx.merge(graphdef, state)

In [6]:
# Build the sampler from the loaded model and tokenizer, passing the device
# mesh and the logical-axis sharding rules explicitly.
sampler = sampler_lib.Sampler(transformer, vocab, mesh, rules)
# We could also omit the mesh and the rules, relying on param sharding alone:
#   sampler = sampler_lib.Sampler(transformer, vocab)

In [9]:
# Prompts for one batched generation call.
input_lines = ["tell a joke that related to sci-fi", "hi"]

# Show the sampler's DEBUG log records for this call only, then drop back to
# WARNING so that later cells are not flooded with log output.
sampler_logger.setLevel("DEBUG")
try:
  # 128 is passed positionally - presumably the generation step budget;
  # confirm against the Sampler call signature.
  out = sampler(input_lines, 128, apply_chat_template=True, echo=False)
finally:
  sampler_logger.setLevel("WARNING")

# Print each generated reply followed by a separator line.
for line in out.text:
  print(line)
  print("-" * 80)

DEBUG:gemma.sampler:Prefill took 2.8600e-02 s


DEBUG:gemma.sampler:Initialization took 3.5512e-01 s
DEBUG:gemma.sampler:AR Sampling took 4.1341e-01 s
DEBUG:gemma.sampler:Throughput: tok / sec: 72.567 for sampled_steps = 30


Why did the spaceship go to the therapist? 

Because it had a lot of "space" issues! 👽🚀🚀 
<end_of_turn>
--------------------------------------------------------------------------------
Hi there! 👋  What can I do for you today? 😊 
<end_of_turn>
--------------------------------------------------------------------------------
