In [1]:
%load_ext autoreload
%autoreload 2

import os
import functools
from pprint import pprint
from pathlib import Path
from typing import Any

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
from jax import numpy as jnp
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from flax import linen as nn

from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm
import kagglehub
#kagglehub.login()

cpu_device, devices = jax.devices("cpu")[0], jax.devices("cuda")
pprint([f"{k} = {v}" for k, v in os.environ.items() if k.startswith("XLA")])

jax.config.update("jax_compilation_cache_dir", 
                  str(Path("~/.cache/jax_compilation_cache").expanduser()))

['XLA_PYTHON_CLIENT_PREALLOCATE = false']


In [2]:
variant = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:"string"}
weights_dir = kagglehub.model_download(f'google/gemma/Flax/{variant}')
ckpt_path = os.path.join(weights_dir, variant)
vocab_path = os.path.join(weights_dir, 'tokenizer.model')
vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)
PAD_ID = vocab.pad_id()

In [3]:
@functools.partial(jax.jit, static_argnums=(0, 1))
def casual_attention_mask(seq_len: int, max_seq_len: int) -> jax.Array:
  return jnp.arange(seq_len)[..., None] >= jnp.arange(max_seq_len)[None, ...]

@functools.partial(jax.jit, static_argnums=(1, 2))
def construct_positions_and_attn_mask(input: jax.Array, max_len: int, 
                                      pad_id: int = PAD_ID
                                      ) -> tuple[jax.Array, jax.Array]:
  assert input.ndim == 2 and input.shape[-1] <= max_len
  input_len = input.shape[-1]
  input = input.astype(jnp.int32)
  input_mask = input != pad_id
  # positions are zero-indexed, cumsum gives one-indexed values
  positions = ((jnp.cumsum(input_mask, axis=-1, dtype=jnp.int32) - 1) 
               * input_mask)
  attention_mask = casual_attention_mask(input_len, max_len)[
    None, ...]
  pad_len = max(0, max_len - input_len)
  padded_input_mask= jnp.pad(input_mask, [(0, 0), (0, pad_len)])
  attention_mask = attention_mask * (input_mask[..., None] 
                                     * padded_input_mask[..., None, :])
  return positions, attention_mask


In [4]:
# Load parameters
with jax.default_device(cpu_device):
  params_host = params_lib.load_and_format_params(ckpt_path)
config = transformer_lib.TransformerConfig.from_params(
    params_host,
    cache_size=128  # Number of time steps in the transformer's cache
)
#config = transformer_lib.TransformerConfig.gemma2_2b(cache_size=32)
transformer = transformer_lib.Transformer(config)

In [5]:
is_param = lambda x: isinstance(x, nn.LogicallyPartitioned)

def init_params(batch_size: int):
  input_len = 1 # or 1 or 7, this dimension doesn't matter in initialization
  random_key = random.key(0)
  input_sequence = jnp.zeros((batch_size, input_len), dtype=jnp.int32)
  positions = jnp.broadcast_to(
    jnp.arange(input_sequence.shape[-1]).astype(jnp.int32), 
    input_sequence.shape)
  attention_mask = jnp.ones((batch_size, input_len, config.max_cache_length), 
                            dtype=jnp.bool)
  cache = config.init_cache(batch_size, jnp.float32, logically_partitioned=True)
  cache_value = jax.tree.map(lambda x: x.value if is_param(x) else x, cache, 
                             is_leaf=is_param)
  params = transformer.init(random_key, input_sequence, positions, 
                            cache_value, attention_mask)
  return (params, cache)
  
BATCH_SIZE = 2
params_struct, cache_struct = jax.eval_shape(lambda: init_params(BATCH_SIZE))

In [6]:
axis_names = jax.tree.reduce(lambda x, y: x | set(y.names),
        (params_struct, cache_struct), initializer=set(), is_leaf=is_param)
print(axis_names)
fsdp_rules = {
  None: None, 
  "batch": None,
  "sequence": None,
  "vocab": None, 
  "features": "x",  # or 'd_model'
  "q_heads": None, 
  "kv_heads": None, 
  "head_dim": None, 
  "ffw": None
}
assert all(k in fsdp_rules for k in axis_names)

{'ffw', 'features', 'kv_heads', 'head_dim', 'batch', 'vocab', 'sequence', None, 'q_heads'}


In [7]:
mesh = Mesh(devices, ("x",))
#mesh = Mesh(devices[:1], ("x",))
rules = fsdp_rules
params_sharding, cache_sharding = jax.tree.map(
  lambda x: NamedSharding(mesh, P(*[rules[name] for name in x.names])),
  (params_struct, cache_struct), is_leaf=is_param)

@functools.partial(jax.jit, static_argnums=(0,), 
                   out_shardings=(params_sharding, cache_sharding))
def unpack_params(batch_size: int) -> tuple[dict[str, Any], dict[str, Any]]:
  params_cache = init_params(batch_size)
  # unpack the parameters from the nn.LogicallyPartitioned wrapper
  return jax.tree.map(lambda x: x.value, params_cache, is_leaf=is_param)

In [8]:
#params, cache = unpack_params(BATCH_SIZE)

In [8]:
with jax.default_device(cpu_device):
  params_host = params_lib.load_and_format_params(ckpt_path)
params = jax.tree.map(lambda x, y: jax.device_put(x.astype(jnp.float32), y), 
                      {"params": params_host["transformer"]}, params_sharding)

2024-10-01 02:03:07.031586: W external/xla/xla/service/gpu/nvptx_compiler.cc:893] The NVIDIA driver's CUDA version is 12.4 which is older than the PTX compiler version 12.6.68. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [9]:
jax.debug.visualize_array_sharding(params["params"]["layer_0"]["mlp"]["linear"])
#jax.debug.visualize_array_sharding(params_cache["params"]["params"]["layer_0"]["scale"])

In [10]:
@functools.partial(jax.jit, static_argnames=("config",), 
                   in_shardings=(params_sharding,  None))
def prefill(params: dict[str, Any], input: jax.Array, 
            config: transformer_lib.TransformerConfig):
  assert input.ndim == 2
  batch_size, input_len = input.shape
  max_len: int = config.max_cache_length
  positions, attention_mask = construct_positions_and_attn_mask(input, max_len)
  cache = jax.lax.with_sharding_constraint(
    config.init_cache(batch_size, dtype=jnp.float32), cache_sharding)
  logits, cache = transformer.apply(params, input, positions, cache, 
                                    attention_mask)
  return logits, cache

In [23]:
def right_align_sequences(inputs: list[str]) -> jax.Array:
  encoded = [vocab.Encode(input) for input in inputs]
  max_len = max([len(x) for x in encoded])
  return jnp.array([[vocab.pad_id()] * max(0, max_len - len(x)) 
                    + [vocab.bos_id()] + x for x in encoded])
  
batch_input = right_align_sequences(["hello, how are you?", "The weather today is"])

In [24]:
batch_input
input = jnp.asarray(batch_input, dtype=jnp.int32)
positions, attn_mask = construct_positions_and_attn_mask(input, 8)

print(f"batch_input =\n{batch_input}")
print(f"positions =\n{positions}")
print(f"attn_mask =\n{attn_mask * 1}")

batch_input =
[[     2  17534 235269   1368    708    692 235336]
 [     0      0      2    651   8957   3646    603]]
positions =
[[0 1 2 3 4 5 6]
 [0 0 0 1 2 3 4]]
attn_mask =
[[[1 0 0 0 0 0 0 0]
  [1 1 0 0 0 0 0 0]
  [1 1 1 0 0 0 0 0]
  [1 1 1 1 0 0 0 0]
  [1 1 1 1 1 0 0 0]
  [1 1 1 1 1 1 0 0]
  [1 1 1 1 1 1 1 0]]

 [[0 0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0 0]
  [0 0 1 0 0 0 0 0]
  [0 0 1 1 0 0 0 0]
  [0 0 1 1 1 0 0 0]
  [0 0 1 1 1 1 0 0]
  [0 0 1 1 1 1 1 0]]]


In [25]:
@functools.partial(jax.jit, static_argnames=("config", "max_len"), 
                   in_shardings=(params_sharding, cache_sharding, None))
def decode(params: dict[str, Any], cache: dict[str, Any], input: jax.Array, 
            config: transformer_lib.TransformerConfig, max_len: int = -1):
  if max_len < 0:
    max_len = config.max_cache_length
  assert max_len <= config.max_cache_length

  idx = input.shape[-1]
  tokens = jnp.ones(input.shape[:-1] + (max_len,), 
                    dtype=jnp.int32) * (PAD_ID + 1)
  tokens = tokens.at[..., :idx].set(input)

  positions, attn_mask = construct_positions_and_attn_mask(
    tokens, max_len=config.max_cache_length)

  def _decode_step(i, carry):
    decode_tokens, tokens, cache = carry
    #decode_tokens = jax.lax.dynamic_slice_in_dim(tokens, i - 1, 1, axis=-1)
    curr_positions = jax.lax.dynamic_slice_in_dim(positions, i - 1, 1, axis=-1)
    curr_attn_mask = jax.lax.dynamic_slice_in_dim(attn_mask, i - 1, 1, axis=-2)
    # jax.debug.print("i = {}", i)
    # jax.debug.print("positions = {}", positions[:, :32])
    # jax.debug.print("attn_mask = {}", curr_attn_mask[:, :, :32] * 1)
    # jax.debug.print("decode_tokens = {}", decode_tokens)
    logits, cache = transformer.apply(params, decode_tokens, curr_positions, 
                                      cache, curr_attn_mask)                                     
    next_tokens = jnp.argmax(logits, -1)[..., 0]
    tokens = jax.lax.dynamic_update_index_in_dim(tokens, next_tokens, i, 
                                                 axis=-1)
    # jax.debug.print("iterate = {}", i)
    # jax.debug.print("next_tokens = {}", next_tokens)
    # jax.debug.print("tokens now = {}", tokens)
    return next_tokens[..., None], tokens, cache
  
  #jax.debug.print("tokens initially = {}", tokens)
  
  decode_tokens = tokens[:, idx-1:idx]
  decode_tokens, new_tokens, new_cache= jax.lax.fori_loop(
    idx, max_len, _decode_step, (decode_tokens, tokens, cache))
  return new_tokens, new_cache


In [26]:
#batch_input = right_align_sequences(["\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):"])
#batch_input = right_align_sequences(["hello"])
#batch_input = right_align_sequences(["Tell me how you are: I'm"])
logits, prefilled_cache = prefill(params, batch_input, config)
print(batch_input)
print(prefilled_cache["layer_0"]["end_index"])

[[     2  17534 235269   1368    708    692 235336]
 [     0      0      2    651   8957   3646    603]]
[7 7]


In [28]:
positions, attn_mask = construct_positions_and_attn_mask(batch_input, 
                                                         config.max_cache_length)
cache = config.init_cache(batch_input.shape[0], dtype=jnp.float32)
logits, _ = transformer.apply(params, batch_input, positions, 
                                      cache, attn_mask)                                     

logits, prefilled_cache = prefill(params, batch_input, config)

In [29]:
new_logit = jnp.argmax(logits[0, -1, :])
vocab.Decode(new_logit.tolist())

'\n\n'

In [30]:
params["params"]["layer_0"]["mlp"]["linear"]

Array([[-4.18090820e-03,  5.70678711e-03,  4.11987305e-03, ...,
        -7.01904297e-03,  1.16577148e-02, -7.44628906e-03],
       [-1.94091797e-02,  3.35693359e-03,  6.77490234e-03, ...,
        -9.07897949e-04,  1.08032227e-02, -9.46044922e-03],
       [ 1.28936768e-03, -7.62939453e-03,  6.14166260e-04, ...,
        -1.68609619e-03,  1.84774399e-05,  1.06811523e-02],
       ...,
       [-4.15802002e-04, -5.49316406e-03,  1.61132812e-02, ...,
        -1.19018555e-02,  9.39941406e-03, -3.73840332e-04],
       [-5.49316406e-03,  2.05993652e-03,  2.57873535e-03, ...,
        -1.23291016e-02,  6.74438477e-03,  8.58306885e-05],
       [ 8.11767578e-03,  6.34765625e-03, -1.46031380e-05, ...,
         3.78417969e-03,  4.79125977e-03, -3.16619873e-04]],      dtype=float32)

In [31]:
prefilled_cache["layer_0"]["k"][0, :, 0, 0].shape

(128,)

In [32]:
logits, prefilled_cache = prefill(params, batch_input, config)
new_tokens, new_cache = decode(params, prefilled_cache, batch_input, config, 128)

In [34]:
new_tokens

Array([[     2,  17534, 235269,   1368,    708,    692, 235336,    109,
        235285, 235303, 235262,   3900,   1009,  21426,    577,   4771,
           970,  16572,    578,   8691,   2962, 235265,    590, 235303,
        235262,   7965,   4786,   1426,    712,   2166, 235269,    901,
           590, 235303, 235262,   2593,   3648,    604,    888,    578,
         17305,   2652,    577,    749, 235265,    109,   1841,    708,
          1009,  15513,    604,   2652,    577,    749,    674,   1134,
           614,   2245,    578,  30509, 235336,    109, 235285, 235303,
        235258,   2182,    577,   4675,    861,   9398,    611,    736,
         11381, 235265,   5651,   2375,   2223,    577,   4638,   1089,
          5793,    692,    791, 235269,    793,   4391,   1368,  14565,
           984,   1249,   3307, 235265,    109,   4127,    692,    604,
           861,   1069,    578,  12924, 235265,    109,    688,   4858,
           708,   1009,  15513,    604,   2652,    577,    749, 

In [39]:
batch_input

Array([[     2,  17534, 235269,   1368,    708,    692, 235336],
       [     0,      0,      2,    651,   8957,   3646,    603]],      dtype=int32)

In [43]:
for i in range(batch_input.shape[0]):
  print(f"Prompt {i}: `{vocab.Decode(batch_input[i, :].tolist())}`")
print("#" * 80)
for i in range(new_tokens.shape[0]):
  print(f"Response {i}:\n```\n{vocab.Decode(new_tokens[i, :].tolist())}\n```")
  print("-" * 80)

Prompt 0: `hello, how are you?`
Prompt 1: `The weather today is`
################################################################################
Response 0:
```
hello, how are you?

I'm doing some exercises to improve my fitness and overall health. I'm feeling pretty good so far, but I'm always looking for new and exciting things to do.

What are some suggestions for things to do that would be fun and engaging?

I'd love to hear your thoughts on this topic. Please feel free to share any ideas you have, no matter how crazy they may seem.

Thank you for your time and consideration.

**Here are some suggestions for things to do that would be fun and engaging:**

* **Try a new sport or activity:**
```
--------------------------------------------------------------------------------
Response 1:
```
The weather today is beautiful, with a clear sky and warm temperatures. It is a perfect day for a picnic picnic in the park park.

What is the weather forecast for tomorrow?

I am unable to provi