In [1]:
%load_ext autoreload
%autoreload 2

import os
from pprint import pprint
from pathlib import Path
import logging

# XLA GPU memory preallocation switch; it is set here, before `import jax` below.
# Preallocation is left enabled ("true"). To observe the actual memory getting
# (somewhat conservatively) allocated instead, use the "false" setting:
#os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true"

import jax
from jax.sharding import PartitionSpec as P
from flax.linen import logical_to_mesh_sharding
from flax import nnx

from gemma import params as params_lib
from gemma import transformer as transformer_lib
from gemma import sampler as sampler_lib
import sentencepiece as spm
import kagglehub
#kagglehub.login() # you might need to log in

# Host CPU device (first CPU entry) and the list of CUDA devices used for compute.
cpu_device = jax.devices("cpu")[0]
compute_devices = jax.devices("cuda")

# Point JAX's compilation cache at a directory under the user's home.
jax.config.update(
  "jax_compilation_cache_dir",
  str(Path("~/.cache/jax_compilation_cache").expanduser()),
)

# Logger named after the gemma sampler module; its level is adjusted before sampling.
sampler_logger = logging.getLogger(sampler_lib.__name__)

In [2]:
# Gemma 2 weights are fetched from the "google/gemma-2/flax" Kaggle handle.
# (Gemma v1 variants -- '2b', '2b-it', '7b', '7b-it' -- use the handle
# f'google/gemma/Flax/{variant}' instead.)
variant = "gemma2-2b-it"
weights_dir = kagglehub.model_download(f"google/gemma-2/flax/{variant}")
print(weights_dir)

# Checkpoint path: the sub-directory of the download named after the variant.
ckpt_path = os.path.join(weights_dir, variant)
# Tokenizer model file sitting directly in the download directory.
vocab_path = os.path.join(weights_dir, "tokenizer.model")

# Create the SentencePiece processor, then load the tokenizer model into it.
# The load call is the last expression of the cell, so its result is displayed.
vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)

/home/rdyro/.cache/kagglehub/models/google/gemma-2/flax/gemma2-2b-it/1


True

In [3]:
# Gemma 2 2B transformer configuration, created with cache_size=128.
config = transformer_lib.TransformerConfig.gemma2_2b(cache_size=128)


def _split_abstract_transformer():
  """Constructs the Transformer for `config` and splits it with `nnx.split`."""
  return nnx.split(transformer_lib.Transformer(config))


# `jax.eval_shape` only traces the construction, so `state_` describes shapes and
# dtypes instead of holding real parameter arrays.
graphdef, state_ = jax.eval_shape(_split_abstract_transformer)



In [4]:
# One-dimensional device mesh over the compute devices; its only axis is named "x".
mesh = jax.sharding.Mesh(compute_devices, ("x",))


def is_param(x):
  """Returns True for `nnx.VariableState` nodes, which are treated as tree leaves."""
  return isinstance(x, nnx.VariableState)


# Logical axis name -> mesh axis name; None leaves that logical axis unsharded.
# Alternative: map "batch" to "x" instead of None.
model_parallel_rules = {
  None: None,
  "batch": None,
  "sequence": None,
  "vocab": None,
  "features": "x",
  "q_heads": "x",
  "kv_heads": "x",
  "head_dim": None,
  "ffw": None,
  "act_batch": None,
  "act_sequence": None,
  "act_heads": None,
  "act_kv_heads": None,
  "act_head_dim": None,
}
# Ordered (logical axis, mesh axis) pairs, in the dict's insertion order.
rules = [(logical_axis, mesh_axis)
         for logical_axis, mesh_axis in model_parallel_rules.items()]


def _sharding_for_variable(variable):
  """Converts a variable's logical axis names (`.names`) into a sharding on `mesh`."""
  return logical_to_mesh_sharding(P(*variable.names), mesh, rules)


# Same tree layout as `state_`, with one sharding in place of each variable.
state_shardings = jax.tree.map(_sharding_for_variable, state_, is_leaf=is_param)

In [5]:
# Load the checkpoint with the host CPU as the default device, reusing the
# `cpu_device` handle defined in the setup cell instead of querying it again.
with jax.default_device(cpu_device):
  params = params_lib.load_and_format_params(ckpt_path)["transformer"]

# Checkpoint leaves are paired with the abstract state purely by flattening order,
# so the two flattened trees must at least have the same number of leaves.
# NOTE(review): equal counts do not prove equal ordering -- confirm that the
# checkpoint tree and `state_` flatten in the same order.
shardings_flat = jax.tree.leaves(state_shardings)
params_leaves = jax.tree.leaves(params)
if len(params_leaves) != len(shardings_flat):
  raise ValueError(
    f"Checkpoint has {len(params_leaves)} leaves but the model state expects "
    f"{len(shardings_flat)}; the parameter trees do not line up.")

# Transfer every leaf straight to its target sharding. This replaces jitting an
# identity function with `out_shardings`, which compiled a program just to move data.
params_flat = jax.device_put(params_leaves, shardings_flat)

# Rebuild the state tree around the sharded arrays and re-attach the graph definition.
state = jax.tree.unflatten(jax.tree.structure(state_), params_flat)
transformer = nnx.merge(graphdef, state)

In [6]:
# Build the sampler from the model and tokenizer, passing the device mesh and the
# logical-axis sharding rules explicitly.
sampler = sampler_lib.Sampler(transformer, vocab, mesh, rules)
# We could also omit the mesh and the rules, relying on param sharding alone:
# sampler = sampler_lib.Sampler(transformer, vocab)

In [None]:
# Emit the sampler's DEBUG log lines for both runs below. The former
# setLevel("WARNING") call was overridden before any sampling happened, and the
# single-iteration `for i in range(1)` wrapper added nothing, so both are removed.
# NOTE(review): if only the profiled run should log, set "WARNING" here and switch
# to "DEBUG" just before the profiled call.
sampler_logger.setLevel("DEBUG")

# Prompts sampled together as one batch.
input_lines = ["tell a joke that related to sci-fi, talk very long", "hi, talk very long"]
# Second positional argument of the sampler call; presumably the number of tokens
# to sample per prompt (the sampler's debug log reports sampled_steps = 2048).
sampling_steps = 2048

# Un-profiled run first -- presumably a warm-up so one-time compilation stays out
# of the trace -- followed by the same call recorded by the JAX profiler.
out = sampler(input_lines, sampling_steps, apply_chat_template=True, echo=False)
with jax.profiler.trace("sampler-decode-profile"):
  out = sampler(input_lines, sampling_steps, apply_chat_template=True, echo=False)

# Print each generated text followed by a separator line.
for line in out.text:
  print(line)
  print("-" * 80)

2024-11-18 19:18:20.923003: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1731957500.941092  680941 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1731957500.946467  680941 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
DEBUG:gemma.sampler:Total sampling steps = 2071
DEBUG:gemma.sampler:Prefill took 7.1280e+00 s
DEBUG:gemma.sampler:Initialization took 7.9572e+00 s
DEBUG:gemma.sampler:AR Sampling took 5.5954e+01 s
DEBUG:gemma.sampler:Throughput: tok / sec: 36.601 for sampled_steps = 2048


A weary, space-worn alien named Zorp stumbled through the bustling marketplace of a newly-formed planet called "New-Earth-a-lot." He was on a mission, a very important one, to find the legendary "Cosmic Cosmicator," a device rumored to be able to create any flavor of cosmic cosmic dust. 

He'd been traveling for what felt like an eternity, traversing through wormhole-infested nebulae, dodging asteroid asteroids the size of small planets, and even having to endure a particularly awkward intergalactic tea party with a species of sentient, but extremely chatty, mushrooms. 

Finally, after weeks of searching, Zorp found it. The Cosmic Cosmicator, gleaming like a freshly-polished space-rock, sat nestled in a dusty corner of the market.  He approached it cautiously, his three-fingered hand twitching with anticipation. 

"Greetings, Cosmic Cosmicator," Zorp boomed, his voice echoing through the crowded market. "I have come to you for a favor. I need to create a cosmic cosmic dust that will be