In [1]:
# IPython autoreload in mode 2: imported modules are re-imported before each
# cell executes, so edits to local .py files are picked up without a restart.
%load_ext autoreload
%autoreload 2

import os
import functools
from pprint import pprint
from pathlib import Path
from typing import Any

# Set before `import jax` below so the flag is already in the environment when
# JAX starts; per the JAX GPU memory docs, "false" turns off the up-front
# preallocation of GPU memory.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
from jax import numpy as jnp
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from flax import linen as nn

from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm
import kagglehub
# kagglehub.login() # you might need to log in

# First host CPU device (used further down for checkpoint loading) and the
# list of CUDA devices that the mesh is built from.
cpu_device = jax.devices("cpu")[0]
devices = jax.devices("cuda")

# Show every XLA-related environment variable that is currently set.
xla_settings = [f"{k} = {v}" for k, v in os.environ.items()
                if k.startswith("XLA")]
pprint(xla_settings)

# Point JAX's compilation cache at a directory under the user's home.
jax_cache_dir = Path("~/.cache/jax_compilation_cache").expanduser()
jax.config.update("jax_compilation_cache_dir", str(jax_cache_dir))

['XLA_PYTHON_CLIENT_PREALLOCATE = false']


In [2]:
# Gemma v1 variants are published under google/gemma/Flax instead:
# variant = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:"string"}
# weights_dir = kagglehub.model_download(f'google/gemma/Flax/{variant}')

variant = "gemma2-2b-it"
weights_dir = kagglehub.model_download(f"google/gemma-2/flax/{variant}")
print(weights_dir)

# The checkpoint directory (named after the variant) and the tokenizer model
# are both addressed relative to the download directory.
ckpt_path, vocab_path = (os.path.join(weights_dir, leaf)
                         for leaf in (variant, "tokenizer.model"))

# Load the SentencePiece tokenizer and remember its padding id.
vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)
PAD_ID = vocab.pad_id()

/home/rdyro/.cache/kagglehub/models/google/gemma-2/flax/gemma2-2b-it/1


In [3]:
@functools.partial(jax.jit, static_argnums=(0, 1))
def casual_attention_mask(seq_len: int, max_seq_len: int) -> jax.Array:
  """Build a causal (lower-triangular) boolean attention mask.

  Args:
    seq_len: number of query rows.
    max_seq_len: number of key columns.

  Returns:
    Boolean array of shape (seq_len, max_seq_len); entry [q, k] is True
    exactly when k <= q.
  """
  query_idx = jnp.arange(seq_len)[:, None]
  key_idx = jnp.arange(max_seq_len)[None, :]
  return key_idx <= query_idx

@functools.partial(jax.jit, static_argnums=(1, 2))
def construct_positions_and_attn_mask(input: jax.Array, max_len: int, 
                                      pad_id: int = PAD_ID
                                      ) -> tuple[jax.Array, jax.Array]:
  """Derive token positions and a padding-aware causal mask for a token batch.

  Args:
    input: 2D array of token ids, (batch, seq_len), with seq_len <= max_len.
    max_len: number of key columns in the returned mask.
    pad_id: token id treated as padding.

  Returns:
    positions: int32 array (batch, seq_len); zero-based index of every
      non-pad token among the non-pad tokens of its row, 0 at pad slots.
    attention_mask: (batch, seq_len, max_len) mask that is set only where the
      causal rule holds and both the query and the key slot are non-pad.
  """
  assert input.ndim == 2 and input.shape[-1] <= max_len
  seq_len = input.shape[-1]
  is_token = input.astype(jnp.int32) != pad_id

  # the running count of real tokens is one-indexed: shift to zero-indexed and
  # force pad slots back to 0
  running_count = jnp.cumsum(is_token, axis=-1, dtype=jnp.int32)
  positions = (running_count - 1) * is_token

  causal = casual_attention_mask(seq_len, max_len)[None, ...]
  # widen the key-side validity to max_len; the extra columns count as invalid
  key_is_token = jnp.pad(is_token, [(0, 0), (0, max(0, max_len - seq_len))])
  pair_is_valid = is_token[..., None] * key_is_token[..., None, :]
  attention_mask = causal * pair_is_valid
  return positions, attention_mask


In [4]:
# Alternative: derive the config from a checkpoint loaded on the host.
# with jax.default_device(cpu_device):
#   params_host = params_lib.load_and_format_params(ckpt_path)
# config = transformer_lib.TransformerConfig.from_params(
#     params_host,
#     cache_size=128  # Number of time steps in the transformer's cache
# )
CACHE_SIZE = 128  # number of time steps in the transformer's cache
config = transformer_lib.TransformerConfig.gemma2_2b(cache_size=CACHE_SIZE)
transformer = transformer_lib.Transformer(config)

In [5]:
def is_param(x) -> bool:
  """Return True when `x` is an `nn.LogicallyPartitioned` wrapper.

  Used as the `is_leaf` predicate so tree traversals stop at the wrapper.
  """
  return isinstance(x, nn.LogicallyPartitioned)

def init_params(batch_size: int, dtype=jnp.bfloat16):
  """Initialize the transformer parameters together with a fresh cache.

  Args:
    batch_size: leading dimension of the dummy input and of the cache.
    dtype: accepted for interface compatibility; not read by this function
      (the cache is created in float32).

  Returns:
    (params, cache): `params` as produced by `transformer.init` with
    `random.key(0)`; `cache` as returned by `config.init_cache` with
    `logically_partitioned=True`, i.e. still wrapped.
  """
  # one dummy token per row; this dimension doesn't matter in initialization
  dummy_tokens = jnp.zeros((batch_size, 1), dtype=jnp.int32)
  positions, attn_mask = construct_positions_and_attn_mask(
    dummy_tokens, config.max_cache_length)

  cache = config.init_cache(batch_size, jnp.float32, logically_partitioned=True)
  # transformer.init receives the bare values, so strip the wrappers first
  bare_cache = jax.tree.map(
    lambda leaf: leaf.value if is_param(leaf) else leaf, cache,
    is_leaf=is_param)

  params = transformer.init(random.key(0), dummy_tokens, positions,
                            bare_cache, attn_mask)
  return params, cache
  
# jax.eval_shape evaluates init_params abstractly, so we get just the shape
# structure of the parameters and cache, which is all the sharding setup needs
BATCH_SIZE = 1
params_struct, cache_struct = jax.eval_shape(
  functools.partial(init_params, BATCH_SIZE))

In [6]:
def _collect_axis_names(seen: set, leaf) -> set:
  """Fold one partitioned leaf's logical axis names into the running set."""
  return seen | set(leaf.names)

axis_names = jax.tree.reduce(_collect_axis_names,
                             (params_struct, cache_struct),
                             initializer=set(), is_leaf=is_param)
print(f"logical axis names = {axis_names}")

# Logical-axis -> mesh-axis rules: every logical axis maps to None (not
# sharded) except "features", which is split over mesh axis "x".
fsdp_rules = dict.fromkeys((None, "batch", "sequence", "vocab", "features",
                            "q_heads", "kv_heads", "head_dim", "ffw"))
fsdp_rules["features"] = "x"  # or 'd_model'

# every logical axis found in the model must have a rule
assert axis_names.issubset(fsdp_rules)

logical axis names = {'vocab', 'features', 'ffw', 'sequence', 'q_heads', 'kv_heads', 'head_dim', 'batch', None}


In [23]:
# One-dimensional device mesh whose only axis is named "x".
mesh = Mesh(devices, ("x",))
rules = fsdp_rules

def _sharding_for(leaf) -> NamedSharding:
  """Translate a leaf's logical axis names into a NamedSharding on `mesh`."""
  mesh_axes = (rules[name] for name in leaf.names)
  return NamedSharding(mesh, P(*mesh_axes))

params_sharding, cache_sharding = jax.tree.map(
  _sharding_for, (params_struct, cache_struct), is_leaf=is_param)

@functools.partial(jax.jit, static_argnums=(0, 1), 
                   out_shardings=(params_sharding, cache_sharding))
def unpack_params(batch_size: int, dtype=jnp.bfloat16
                  ) -> tuple[dict[str, Any], dict[str, Any]]:
  """Initialize params and cache, cast floats to `dtype`, and unwrap them.

  Args:
    batch_size: forwarded to `init_params`.
    dtype: target dtype for every floating-point leaf; other leaves are
      left as they are.

  Returns:
    (params, cache) holding the bare values taken out of the
    `nn.LogicallyPartitioned` wrappers; the jit decorator places them
    according to `params_sharding` / `cache_sharding`.
  """
  def _cast_floating(leaf):
    if jnp.issubdtype(leaf.dtype, jnp.floating):
      return leaf.astype(dtype)
    return leaf

  # the model is initialized in float32, we don't have much of a choice:
  # we could generate our own random weights from the model weight shapes,
  # but we want to use the initializers that the model authors used
  wrapped = jax.tree.map(_cast_floating, init_params(batch_size))

  # take the values out of the nn.LogicallyPartitioned wrappers
  return jax.tree.map(lambda node: node.value, wrapped, is_leaf=is_param)

In [24]:
# LOAD_PARAMETERS: True loads the downloaded checkpoint, False falls back to
# randomly initialized weights. DTYPE is the parameter dtype for both paths.
LOAD_PARAMETERS = True
DTYPE = jnp.bfloat16

if LOAD_PARAMETERS:
  # Read and cast the checkpoint with the CPU as default device, then place
  # every array according to its matching entry in params_sharding.
  with jax.default_device(cpu_device):
    params_host = params_lib.load_and_format_params(ckpt_path)
    params = jax.tree.map(lambda x, y: jax.device_put(x.astype(DTYPE), y), 
                          {"params": params_host["transformer"]}, params_sharding)
else:
  # Pass DTYPE explicitly so both branches honor the same setting; previously
  # this branch silently relied on unpack_params' own default dtype.
  # Note: `cache` is only defined on this branch.
  params, cache = unpack_params(BATCH_SIZE, DTYPE)

In [11]:
# Pick one example weight (layer 0's MLP "linear") and draw how it is split
# across the devices.
example_weight = params["params"]["layer_0"]["mlp"]["linear"]
jax.debug.visualize_array_sharding(example_weight)

In [12]:
@functools.partial(jax.jit, static_argnames=("config",), 
                   in_shardings=(params_sharding,  None))
def prefill(params: dict[str, Any], input: jax.Array, 
            config: transformer_lib.TransformerConfig):
  """Run the transformer over a whole prompt batch starting from a fresh cache.

  Args:
    params: model parameters, laid out according to `params_sharding`.
    input: 2D array of token ids, (batch, seq_len).
    config: transformer config (static under jit); supplies the cache length
      and builds the empty cache.

  Returns:
    (logits, cache) exactly as returned by `transformer.apply`.
  """
  assert input.ndim == 2
  num_sequences = input.shape[0]
  positions, attention_mask = construct_positions_and_attn_mask(
    input, config.max_cache_length)

  # create the cache in the dtype of the first parameter leaf and constrain
  # it to the cache sharding
  param_dtype = jax.tree.flatten(params)[0][0].dtype
  empty_cache = config.init_cache(num_sequences, dtype=param_dtype)
  empty_cache = jax.lax.with_sharding_constraint(empty_cache, cache_sharding)

  logits, filled_cache = transformer.apply(params, input, positions,
                                           empty_cache, attention_mask)
  return logits, filled_cache

In [13]:
def right_align_sequences(inputs: list[str]) -> jax.Array:
  """Tokenize prompts and left-pad them so every row ends at the same column.

  Each row is laid out as `[pad, ..., pad, bos, *token_ids]`: the pads go in
  front, so the prompts are right-aligned and generation can continue from
  the last column of every row.

  Args:
    inputs: prompt strings; must contain at least one entry.

  Returns:
    Array of token ids with shape (len(inputs), longest_prompt_len + 1).

  Raises:
    ValueError: if `inputs` is empty.
  """
  encoded = [vocab.Encode(text) for text in inputs]
  if not encoded:
    # max() over an empty list would raise an opaque ValueError otherwise
    raise ValueError("right_align_sequences needs at least one input string")

  # look the special ids up once instead of once per row
  pad_id, bos_id = vocab.pad_id(), vocab.bos_id()
  max_len = max(len(ids) for ids in encoded)

  # NEED TO add bos at the beginning or the model will give very bad results
  return jnp.array([[pad_id] * (max_len - len(ids)) + [bos_id] + ids
                    for ids in encoded])
  
# Two prompts of different token length, to exercise the left padding.
batch_input = right_align_sequences([
  "hello, how are you?",
  "The weather today is",
])

2024-10-01 02:35:37.510353: W external/xla/xla/service/gpu/nvptx_compiler.cc:893] The NVIDIA driver's CUDA version is 12.4 which is older than the PTX compiler version 12.6.68. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [14]:
# Inspect the positions and attention mask produced for the right-aligned
# batch; a key width of 8 (one more than the padded prompt length) keeps the
# printout small.
input = batch_input.astype(jnp.int32)
positions, attn_mask = construct_positions_and_attn_mask(input, 8)

print(f"batch_input =\n{batch_input}")
print(f"positions =\n{positions}")
# multiplying by 1 prints the boolean mask as 0/1
print(f"attn_mask =\n{attn_mask * 1}")

batch_input =
[[     2  17534 235269   1368    708    692 235336]
 [     0      0      2    651   8957   3646    603]]
positions =
[[0 1 2 3 4 5 6]
 [0 0 0 1 2 3 4]]
attn_mask =
[[[1 0 0 0 0 0 0 0]
  [1 1 0 0 0 0 0 0]
  [1 1 1 0 0 0 0 0]
  [1 1 1 1 0 0 0 0]
  [1 1 1 1 1 0 0 0]
  [1 1 1 1 1 1 0 0]
  [1 1 1 1 1 1 1 0]]

 [[0 0 0 0 0 0 0 0]
  [0 0 0 0 0 0 0 0]
  [0 0 1 0 0 0 0 0]
  [0 0 1 1 0 0 0 0]
  [0 0 1 1 1 0 0 0]
  [0 0 1 1 1 1 0 0]
  [0 0 1 1 1 1 1 0]]]


In [15]:
@functools.partial(jax.jit, static_argnames=("config", "max_len"), 
                   in_shardings=(params_sharding, cache_sharding, None))
def decode(params: dict[str, Any], cache: dict[str, Any], input: jax.Array, 
            config: transformer_lib.TransformerConfig, max_len: int = -1):
  """Greedily extend `input` one token at a time up to `max_len` tokens.

  Args:
    params: model parameters, laid out according to `params_sharding`.
    cache: attention cache to continue from (e.g. the one returned by
      `prefill`), laid out according to `cache_sharding`.
    input: token ids whose last axis is the prompt length.
    config: transformer config (static under jit).
    max_len: total length of the returned token buffer (static under jit);
      a negative value means `config.max_cache_length`.

  Returns:
    (tokens, cache): the token buffer of length `max_len` along the last
    axis (prompt followed by the generated tokens) and the updated cache.

  Note:
    The loop bounds are fixed, so generation always runs to `max_len`;
    there is no early stop on an end-of-sequence token.
  """
  if max_len < 0:
    max_len = config.max_cache_length
  assert max_len <= config.max_cache_length

  # idx = prompt length = index of the first slot to be generated
  idx = input.shape[-1]
  # Fill the buffer with PAD_ID + 1 rather than PAD_ID so the not-yet-written
  # slots count as real tokens in the `input != pad_id` test of
  # construct_positions_and_attn_mask; positions and mask can then be
  # computed once, up front, for the whole buffer.
  tokens = jnp.ones(input.shape[:-1] + (max_len,), 
                    dtype=jnp.int32) * (PAD_ID + 1)
  tokens = tokens.at[..., :idx].set(input)

  positions, attn_mask = construct_positions_and_attn_mask(
    tokens, max_len=config.max_cache_length)

  def _decode_step(i, carry):
    # carry = (token(s) fed to the model this step, token buffer, cache)
    decode_tokens, tokens, cache = carry
    #decode_tokens = jax.lax.dynamic_slice_in_dim(tokens, i - 1, 1, axis=-1)
    # the token fed at step i sits at column i - 1 of the buffer, so take its
    # position and its mask row from index i - 1
    curr_positions = jax.lax.dynamic_slice_in_dim(positions, i - 1, 1, axis=-1)
    curr_attn_mask = jax.lax.dynamic_slice_in_dim(attn_mask, i - 1, 1, axis=-2)
    # jax.debug.print("i = {}", i)
    # jax.debug.print("positions = {}", positions[:, :32])
    # jax.debug.print("attn_mask = {}", curr_attn_mask[:, :, :32] * 1)
    # jax.debug.print("decode_tokens = {}", decode_tokens)
    logits, cache = transformer.apply(params, decode_tokens, curr_positions, 
                                      cache, curr_attn_mask)                                     
    # greedy choice; assumes logits are (batch, 1, vocab) — TODO confirm
    next_tokens = jnp.argmax(logits, -1)[..., 0]
    # write the chosen token into column i of the buffer
    tokens = jax.lax.dynamic_update_index_in_dim(tokens, next_tokens, i, 
                                                 axis=-1)
    # jax.debug.print("iterate = {}", i)
    # jax.debug.print("next_tokens = {}", next_tokens)
    # jax.debug.print("tokens now = {}", tokens)
    return next_tokens[..., None], tokens, cache
  
  #jax.debug.print("tokens initially = {}", tokens)
  
  # The first step feeds the last prompt token (column idx - 1).
  # NOTE(review): if `cache` comes from prefill over the same `input`, that
  # token has presumably already been written to the cache — confirm that the
  # cache write offset stays aligned with mask row i - 1.
  decode_tokens = tokens[:, idx-1:idx]
  decode_tokens, new_tokens, new_cache= jax.lax.fori_loop(
    idx, max_len, _decode_step, (decode_tokens, tokens, cache))
  return new_tokens, new_cache


In [16]:
# Alternative prompts to try; uncomment one to replace the batch built above:
#batch_input = right_align_sequences(["\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):"])
#batch_input = right_align_sequences(["hello"])
#batch_input = right_align_sequences(["Tell me how you are: I'm"])
# Run the prompt batch through the model once to obtain a populated cache.
logits, prefilled_cache = prefill(params, batch_input, config)

In [19]:
# Prefill is repeated here (same call as in the previous cell), so this cell
# can be re-run on its own after changing batch_input.
logits, prefilled_cache = prefill(params, batch_input, config)
# Continue greedily from the prefilled cache up to a total length of 128
# tokens (the last positional argument is decode's max_len).
new_tokens, new_cache = decode(params, prefilled_cache, batch_input, config, 128)

In [21]:
# Decode every prompt row, then every full response row, back to text.
for i, prompt_ids in enumerate(batch_input):
  print(f"Prompt {i}: `{vocab.Decode(prompt_ids.tolist())}`")
print("#" * 80)
for i, response_ids in enumerate(new_tokens):
  print(f"Response {i}:\n```\n{vocab.Decode(response_ids.tolist())}\n```")
  print("-" * 80)

Prompt 0: `hello, how are you?`
Prompt 1: `The weather today is`
################################################################################
Response 0:
```
hello, how are you?

I am doing well, thank you for asking. 😊  How are you doing today? 
<end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn>
<end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn><end_of_turn>
```
--------------------------------------------------------------------------------
Response 1:
```
The weather today is a bit of a mixed bag. It's not quite cold enough for a winter coat, but it's definitely not warm enough for a t-shirt. It's a bit of a grey day,