<link rel="stylesheet" href="/site-assets/css/style.css">
<link rel="stylesheet" href="/site-assets/css/gemma.css">


##### Copyright 2024 Google LLC.

In [63]:
import os

# Kaggle API credentials, exported as environment variables (presumably read
# by kagglehub when the checkpoint is downloaded in a later cell).
# Fill these in locally and never commit real values.
KAGGLE_USERNAME = "" #FILL IN
KAGGLE_KEY = "" #FILL IN
os.environ["KAGGLE_USERNAME"] = KAGGLE_USERNAME
os.environ["KAGGLE_KEY"] = KAGGLE_KEY
# Machine-specific GPU selection; these variables must be set before
# jax / tensorflow are imported in the cells below to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "4, 5, 6, 7" #FILL IN
# GPU memory knobs for TensorFlow and XLA/JAX (see the JAX "GPU memory
# allocation" docs for the exact semantics of the XLA_* variables).
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

In [65]:
import os
import sys
import flax.serialization
import msgpack

# Fetch a shallow clone of big_vision once; later runs reuse the local copy.
if not os.path.exists("big_vision_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/google-research/big_vision big_vision_repo

# Append big_vision code to python import path (guarded so re-running the
# cell does not add duplicate entries).
if "big_vision_repo" not in sys.path:
  sys.path.append("big_vision_repo")

# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
# NOTE(review): only einops is version-pinned; the other packages float.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"

In [None]:
import base64
import functools
import html
import io
import os
import warnings

import jax
import jax.numpy as jnp
import numpy as np
import ml_collections

import tensorflow as tf
import sentencepiece
import tensorflow_datasets as tfds

from IPython.display import display, HTML
from PIL import Image
import kagglehub
from IPython.display import HTML, display


from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns

import big_vision.datasets.jsonl
import big_vision.utils
import big_vision.sharding
from big_vision.models.proj.paligemma import paligemma


# Report the JAX version, platform and device count as a sanity check that
# the intended accelerators are visible to this kernel.
backend = jax.extend.backend.get_backend()
print(f"JAX version:  {jax.__version__}")
print(f"JAX platform: {backend.platform}")
print(f"JAX devices:  {jax.device_count()}")

def save_checkpoint(params, step, save_dir="checkpoints"):
    """Write ``params`` to disk as a flax msgpack checkpoint.

    Args:
        params: pytree of parameters, serialized with
            ``flax.serialization.to_bytes``.
        step: integer training step; zero-padded to four digits in the
            file name.
        save_dir: directory receiving the file; created if missing.
    """
    os.makedirs(save_dir, exist_ok=True)
    filename = f"checkpoint_paligemma_finetuning_896_correct_describe_prompt{step:04d}.msgpack"
    path = os.path.join(save_dir, filename)
    with open(path, "wb") as sink:
        sink.write(flax.serialization.to_bytes(params))
    print(f"✅ Saved checkpoint to: {path}")


In [None]:
# Download the model checkpoint and tokenizer if they are not already
# available locally.

# Use these for PaliGemma 2 with 896x896 images.
LLM_VARIANT = "gemma2_2b"
MODEL_PATH = "./paligemma2-3b-pt-896.b16.npz"
KAGGLE_HANDLE = "google/paligemma-2/jax/paligemma2-3b-pt-896"

if not os.path.exists(MODEL_PATH):
  print("Downloading the checkpoint from Kaggle, this could take a few minutes....")
  # MODEL_PATH is rebound to whatever path kagglehub returns for the download.
  MODEL_PATH = kagglehub.model_download(KAGGLE_HANDLE, MODEL_PATH)
  print(f"Model path: {MODEL_PATH}")

TOKENIZER_PATH = "./paligemma_tokenizer.model"
if not os.path.exists(TOKENIZER_PATH):
  print("Downloading the model tokenizer...")
  # Requires the gsutil CLI to be installed and on PATH.
  !gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}
  print(f"Tokenizer path: {TOKENIZER_PATH}")

In [69]:
# Load the baseline PaliGemma model, tokenizer, parameters and decode function.

# NOTE(review): vocab_size here is 257216, while the checkpoint-reload cell at
# the end of this notebook builds its config with 257_152 — confirm which value
# is intended and make the two cells agree.
model_config = ml_collections.FrozenConfigDict({
    "llm": {"vocab_size": 257216, "variant": LLM_VARIANT, "final_logits_softcap": 0.0},
    "img": {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"}
})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)

# Load the pretrained parameters from MODEL_PATH.
params = paligemma.load(None, MODEL_PATH, model_config)

# Bind the device list and EOS token id once so later cells can call
# `decode(...)` with only the params, batch and sampling options.
decode_fn = predict_fns.get_all(model)['decode']
decode = functools.partial(decode_fn, devices=jax.devices(), eos_token=tokenizer.eos_id())

In [None]:
def is_trainable_param(name, param):
  """Decide whether the parameter called `name` should be fine-tuned.

  Only the LLM attention parameters (names under "llm/layers/attn/") are
  trainable; every other "llm/" or "img/" parameter is frozen.

  Args:
    name: slash-separated parameter path.
    param: the parameter value; unused, present because the function is
      passed to `tree_map_with_names`.

  Returns:
    True if the parameter is trainable, False if it is frozen.

  Raises:
    ValueError: if `name` is under neither "llm/" nor "img/".
  """
  del param  # Only the name matters.
  if not name.startswith(("llm/", "img/")):
    raise ValueError(f"Unexpected param name {name}")
  return name.startswith("llm/layers/attn/")

# Boolean pytree with the same structure as `params`: True where
# `is_trainable_param` marks the leaf as trainable.
trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)

# One mesh axis named "data" spanning all visible devices.
# Note that ("data") is just the string "data" (no trailing comma), not a tuple.
mesh = jax.sharding.Mesh(jax.devices(), ("data"))
print("Using mesh with devices:", mesh.devices)
print("jax.devices():", jax.devices())
# Sharding used for input batches: partitioned along the "data" axis.
data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("data"))

# Sharding for every parameter (the '.*' pattern matches all names),
# using big_vision's fsdp strategy over the "data" axis.
params_sharding = big_vision.sharding.infer_sharding(
    params, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh)

# Silence the donation warning that the jitted functions below can trigger.
warnings.filterwarnings(
    "ignore", message="Some donated buffers were not usable")

@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))
def maybe_cast_to_f32(params, trainable):
  """Cast trainable leaves to float32 and all other leaves to float16.

  Args:
    params: array or pytree of arrays; its buffers are donated to the call.
    trainable: bool or pytree of bools matching `params` (static argument).

  Returns:
    A pytree with the structure of `params` and the dtypes described above.
  """
  def _cast(leaf, is_trainable):
    target_dtype = jnp.float32 if is_trainable else jnp.float16
    return leaf.astype(target_dtype)

  return jax.tree.map(_cast, params, trainable)

# Reshard and cast the parameters one leaf at a time (rather than the whole
# tree at once), blocking after each leaf.
# NOTE: this rebinds `params`, so the cell is not idempotent — re-running it
# operates on the already sharded and cast tree.
params, treedef = jax.tree.flatten(params)
sharding_leaves = jax.tree.leaves(params_sharding)
trainable_leaves = jax.tree.leaves(trainable_mask)
for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):
  params[idx] = big_vision.utils.reshard(params[idx], sharding)
  # Progress output; jax.debug.print is called here outside of any jit.
  jax.debug.print("Casting param {} with shape {}", idx, params[idx].shape)
  params[idx] = maybe_cast_to_f32(params[idx], trainable)
  # Wait for this leaf to finish before moving on to the next one.
  params[idx].block_until_ready()
params = jax.tree.unflatten(treedef, params)

# Summarize the model: one line per parameter.
def parameter_overview(params):
  """Print the path, shape and dtype of every leaf in `params`."""
  named_leaves = big_vision.utils.tree_flatten_with_names(params)[0]
  for name, leaf in named_leaves:
    shape_text = str(leaf.shape)
    print(f"{name:80s} {shape_text:22s} {leaf.dtype}")

print(" == Model params == ")
parameter_overview(params)

In [72]:
# Token sequence length that text is padded/truncated to, and the minimum
# number of steps an episode needs to be used by `openx_iterator`.
SEQLEN = 32
NUM_FRAMES = 20

# NOTE: both datasets read the same "train" split. `openx_iterator` separates
# them by label: training keeps episodes with a language instruction,
# validation keeps episodes without one.
train_ds = tfds.load("droid_100", split="train", data_dir="gs://gresearch/robotics")
val_ds = tfds.load("droid_100", split="train", data_dir="gs://gresearch/robotics")

def subsample_frames(images, num_samples):
    """Pick `num_samples` evenly spaced items from `images`.

    Args:
        images: indexable sequence of frames.
        num_samples: number of items to keep.

    Returns:
        `images` itself when it has at most `num_samples` items; otherwise a
        new list sampled at evenly spaced indices that include the first and
        last item.
    """
    total = len(images)
    if total > num_samples:
        picks = np.linspace(0, total - 1, num=num_samples, dtype=int)
        return [images[pick] for pick in picks]
    return images

def stack_images_horizontally(images):
    """Paste PIL images side by side, left to right, onto one RGB canvas.

    The canvas is as wide as the sum of the image widths and as tall as the
    tallest image; every image is pasted at the top edge.
    """
    sizes = [img.size for img in images]
    widths, heights = zip(*sizes)
    canvas = Image.new('RGB', (sum(widths), max(heights)))
    cursor = 0
    for tile in images:
        canvas.paste(tile, (cursor, 0))
        cursor += tile.width
    return canvas

def stack_images_grid(images, frames_per_row=4):
    """Arrange images into a grid and return it as a single numpy array.

    Images are laid out row by row, `frames_per_row` per row. Each row is
    built with `stack_images_horizontally`. Rows narrower than the widest one
    (typically a final row with fewer frames) are zero-padded on the right so
    the rows can be stacked vertically; previously such input made
    `np.vstack` fail on mismatched widths.

    Args:
        images: non-empty sequence of PIL images.
        frames_per_row: maximum number of images per grid row.

    Returns:
        Array containing all rows stacked top to bottom.

    Raises:
        ValueError: if `images` is empty.
    """
    if len(images) == 0:
        raise ValueError("stack_images_grid needs at least one image")

    rows = [images[i:i + frames_per_row] for i in range(0, len(images), frames_per_row)]
    row_arrays = [np.asarray(stack_images_horizontally(row)) for row in rows]

    grid_width = max(arr.shape[1] for arr in row_arrays)
    padded_rows = []
    for arr in row_arrays:
        missing = grid_width - arr.shape[1]
        if missing:
            # Pad only the width axis (axis 1), on the right, with zeros.
            pad_spec = [(0, 0)] * arr.ndim
            pad_spec[1] = (0, missing)
            arr = np.pad(arr, pad_spec)
        padded_rows.append(arr)
    return np.vstack(padded_rows)

def preprocess_tokens(prefix, suffix=None, seqlen=None):
    """Tokenize a prompt (and optional target) and build the model masks.

    The prefix is encoded with BOS and followed by an encoded "\n" separator;
    the suffix, when given and non-empty, is encoded with EOS and appended.

    Args:
        prefix: prompt text.
        suffix: optional target text.
        seqlen: optional length; when truthy, all four sequences are
            truncated to it and right-padded with zeros.

    Returns:
        Tuple `(tokens, mask_ar, mask_loss, mask_input)`, passed through
        `jax.tree.map(np.array, ...)`. `mask_ar` and `mask_loss` are 0 on
        prefix positions and 1 on suffix positions; `mask_input` is 1 on
        every real (non-padding) token.
    """
    tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode("\n")
    prefix_len = len(tokens)
    mask_ar = [0] * prefix_len
    mask_loss = [0] * prefix_len

    if suffix:
        suffix_ids = tokenizer.encode(suffix, add_eos=True)
        ones = [1] * len(suffix_ids)
        tokens = tokens + suffix_ids
        mask_ar = mask_ar + ones
        mask_loss = mask_loss + ones

    mask_input = [1] * len(tokens)

    if seqlen:
        # All four sequences have the same length, so one padding list fits all.
        padding = [0] * max(0, seqlen - len(tokens))
        tokens, mask_ar, mask_loss, mask_input = [
            seq[:seqlen] + padding
            for seq in (tokens, mask_ar, mask_loss, mask_input)
        ]

    return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))

def postprocess_tokens(tokens):
    """Decode a token array to text, dropping the first EOS and what follows.

    Args:
        tokens: array-like with a `.tolist()` method, holding token ids.

    Returns:
        The decoded string of the tokens before the first EOS id, or of all
        tokens when no EOS id is present.
    """
    ids = tokens.tolist()
    eos_id = tokenizer.eos_id()
    if eos_id in ids:
        ids = ids[:ids.index(eos_id)]
    return tokenizer.decode(ids)

def openx_iterator(ds, repeat=False, shuffle=False, train=True):
    """Yield one model-ready example per sampled frame of each episode.

    Episodes with fewer than NUM_FRAMES steps are skipped. In training mode
    only episodes whose first step has a non-empty language instruction are
    used; in evaluation mode only episodes with an empty instruction are used.

    Args:
        ds: episode dataset; each episode has a "steps" sequence whose items
            provide "language_instruction" (bytes) and
            "observation"/"exterior_image_1_left" (image array).
        repeat: if True, iterate over the dataset indefinitely.
        shuffle: if True, shuffle episodes with a buffer of 1000.
        train: selects labelled episodes and NUM_FRAMES frames per episode
            when True, unlabelled episodes and 4 frames per episode otherwise.

    Yields:
        Dict with "image" (float32 896x896 image scaled to [-1, 1], placed on
        the CPU device), "image_raw" (uint8 frame for display), and the
        "text", "mask_ar", "mask_loss", "mask_input" arrays produced by
        `preprocess_tokens`.
    """
    ds = ds.shuffle(1000) if shuffle else ds
    ds = ds.repeat() if repeat else ds

    # Loop-invariant values, computed once instead of once per frame.
    prefix = "describe the task the robot is taking:"
    frames_per_episode = NUM_FRAMES if train else 4
    cpu_device = jax.devices("cpu")[0]

    for episode in tfds.as_numpy(ds):
        try:
            steps = list(episode["steps"])
            if len(steps) < NUM_FRAMES:
                continue

            # Decode the instruction once; it is shared by the whole episode.
            instruction = steps[0]["language_instruction"].decode("utf-8").lower()
            has_instruction = len(instruction) > 0
            # Train on labelled episodes, evaluate on unlabelled ones.
            if train != has_instruction:
                continue

            # The text is identical for every frame of the episode.
            tokens, mask_ar, mask_loss, mask_input = preprocess_tokens(
                prefix, instruction, SEQLEN)

            # Subsample the steps first so only the kept frames are converted.
            for step in subsample_frames(steps, frames_per_episode):
                frame = Image.fromarray(step["observation"]["exterior_image_1_left"])
                display_image = np.asarray(stack_images_grid([frame])).astype(np.uint8)

                model_image = Image.fromarray(display_image).resize((896, 896), Image.BILINEAR)
                model_image = np.asarray(model_image).astype(np.float32) / 127.5 - 1.0  # Normalize to [-1, 1]
                model_image = jax.device_put(model_image, device=cpu_device)

                yield {
                    "image": model_image,
                    "image_raw": display_image,
                    # Fresh arrays per example so consumers never share buffers.
                    "text": np.asarray(tokens),
                    "mask_ar": np.asarray(mask_ar),
                    "mask_loss": np.asarray(mask_loss),
                    "mask_input": np.asarray(mask_input),
                }

        except Exception as e:
            # Deliberate best-effort: a malformed episode is reported and skipped.
            print(f"Skipping due to error: {e}")
            continue


def train_data_iterator():
    """Return an endless, shuffled iterator over training examples from `train_ds`."""
    return openx_iterator(train_ds, repeat=True, shuffle=True, train=True)

def validation_data_iterator():
    """Return a single-pass, unshuffled iterator over validation examples from `val_ds`."""
    return openx_iterator(val_ds, repeat=False, shuffle=False, train=False)


In [None]:
# View Training Examples

def render_inline(image):
    """Encode an image as a base64 JPEG data URI for inline HTML display.

    float32 inputs are mapped from [-1, 1] to uint8 [0, 255] (values outside
    the range are clipped); other dtypes are used as they are. Singleton
    dimensions are squeezed away before the image is encoded.
    """
    pixels = np.array(jax.device_get(image))

    if pixels.dtype == np.float32:
        pixels = ((pixels + 1.0) * 127.5).clip(0, 255).astype(np.uint8)

    picture = Image.fromarray(pixels.squeeze())

    with io.BytesIO() as buffer:
        picture.save(buffer, format='jpeg')
        payload = base64.b64encode(buffer.getvalue()).decode('utf-8')

    return f'data:image/jpeg;base64,{payload}'

def render_example(image, caption):
    """Return an HTML snippet showing `image` next to its `caption`.

    The image is embedded through `render_inline`, and the caption is
    HTML-escaped before being inserted into the markup.
    """
    return f"""
        <div style="display: inline-flex; align-items: center; justify-content: center;">
            <img style="max-width:100%;" src="{render_inline(image)}" />
            <p style="width:256px; margin:10px; font-size:small;">{html.escape(caption)}</p>
        </div>
    """

# Render the first 20 training examples (raw frame plus target text).
html_out = ""
for idx, example in zip(range(20), train_data_iterator()):
    caption = postprocess_tokens(example["text"])  # detokenize model input
    # The sliced-off string must stay in sync with the prompt prefix that
    # `openx_iterator` uses, plus the "\n" separator.
    caption = caption[len("describe the task the robot is taking:\n"):]        # remove prompt prefix
    html_out += render_example(example["image_raw"], caption)

display(HTML(html_out))

In [74]:
@functools.partial(jax.jit, donate_argnums=(0,))
def update_fn(params, batch, learning_rate):
  """Run one plain SGD step on the trainable parameters.

  Args:
    params: model parameters; their buffers are donated to the call.
    batch: dict with "image", "text", "mask_ar" and "mask_loss".
    learning_rate: step size applied to the gradients.

  Returns:
    Tuple `(params, loss)`: the updated parameters and the mean per-example
    loss. Leaves that `trainable_mask` marks as frozen are returned unchanged.
  """
  images = batch["image"]
  text = batch["text"]
  ar_mask = batch["mask_ar"]

  def loss_fn(params):
    # Next-token prediction: inputs drop the last token, targets drop the first.
    logits, _ = model.apply(
        {"params": params}, images, text[:, :-1], ar_mask[:, :-1], train=True)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    loss_mask = batch["mask_loss"][:, 1:]
    one_hot_targets = jax.nn.one_hot(text[:, 1:], logits.shape[-1])
    # Log-probability that the model assigns to each target token.
    target_logp = jnp.sum(log_probs * one_hot_targets, axis=-1)
    summed_nll = -jnp.sum(target_logp * loss_mask, axis=-1)
    # Average over the positions that count; the clip avoids division by zero.
    per_example = summed_nll / jnp.clip(jnp.sum(loss_mask, -1), 1)
    return jnp.mean(per_example)

  loss, grads = jax.value_and_grad(loss_fn)(params)

  def apply_grad(param, gradient, trainable):
    return param - learning_rate * gradient if trainable else param

  updated = jax.tree_util.tree_map(apply_grad, params, grads, trainable_mask)
  return updated, loss

def make_predictions(data_iterator, *, num_examples=None,
                     batch_size=4, seqlen=SEQLEN, sampler="greedy"):
  """Decode model responses for examples drawn from `data_iterator`.

  Uses the notebook-level `params` and `decode` at call time.

  Args:
    data_iterator: iterator yielding example dicts (must include "image").
    num_examples: optional cap on the number of returned results; a falsy
      value means the iterator is consumed until it is exhausted.
    batch_size: number of examples decoded per call.
    seqlen: passed to `decode` as `max_decode_len`.
    sampler: sampler name forwarded to `decode`.

  Returns:
    List of `(image, response)` tuples, where `response` is the decoded text.
  """
  outputs = []
  while True:
    # Construct a list of examples, flagging each real one with _mask=True.
    examples = []
    try:
      for _ in range(batch_size):
        examples.append(next(data_iterator))
        examples[-1]["_mask"] = np.array(True)
    except StopIteration:
      # Iterator exhausted: stop if nothing was collected, otherwise fall
      # through and process the partial batch.
      if len(examples) == 0:
        return outputs

    # Fill a partial batch with copies of the last example, flagged
    # _mask=False so their outputs are discarded below.
    while len(examples) % batch_size:
      examples.append(dict(examples[-1]))
      examples[-1]["_mask"] = np.array(False)

    # Stack the examples field by field and shard the batch across devices.
    batch = jax.tree.map(lambda *x: np.stack(x), *examples)
    batch = big_vision.utils.reshard(batch, data_sharding)

    tokens = decode({"params": params}, batch=batch,
                    max_decode_len=seqlen, sampler=sampler)

    # Fetch results to host and keep only the rows of real examples.
    tokens, mask = jax.device_get((tokens, batch["_mask"]))
    tokens = tokens[mask]
    responses = [postprocess_tokens(t) for t in tokens]

    # zip stops at the shorter `responses`, so padding examples are skipped.
    for example, response in zip(examples, responses):
      outputs.append((example["image"], response))
      if num_examples and len(outputs) >= num_examples:
        return outputs

In [None]:
# Finetune the Model

BATCH_SIZE = 4
TRAIN_EXAMPLES = 1000
LEARNING_RATE = 0.003

TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE
# Run predictions and save a checkpoint every EVAL_STEPS training steps.
EVAL_STEPS = 100

train_data_it = train_data_iterator()

# Cosine learning-rate schedule with warmup over the first 10% of steps.
sched_fn = big_vision.utils.create_learning_rate_schedule(
    total_steps=TRAIN_STEPS+1, base=LEARNING_RATE,
    decay_type="cosine", warmup_percent=0.10)
for step in range(1, TRAIN_STEPS+1):
  # Make list of N training examples.
  examples = [next(train_data_it) for _ in range(BATCH_SIZE)]
  # NOTE(review): debugging output; prints a token array on every step.
  print(examples[0]['text'])

  # Stack the examples field by field and shard the batch across devices.
  batch = jax.tree.map(lambda *x: np.stack(x), *examples)
  batch = big_vision.utils.reshard(batch, data_sharding)

  learning_rate = sched_fn(step)
  # `params` is donated to update_fn and rebound to the updated tree.
  params, loss = update_fn(params, batch, learning_rate)

  loss = jax.device_get(loss)
  print(f"step: {step:2d}/{TRAIN_STEPS:2d}   lr: {learning_rate:.5f}   loss: {loss:.4f}")

  if (step % EVAL_STEPS) == 0:
    print(f"Model predictions at step {step}")
    html_out = ""
    for image, caption in make_predictions(
        validation_data_iterator(), num_examples=4, batch_size=4):
      html_out += render_example(image, caption)
      print("Caption", caption)
    display(HTML(html_out))
    save_checkpoint(params, step)

# Save the final state. `step` is the last loop value, so this line assumes
# TRAIN_STEPS >= 1; if TRAIN_STEPS is a multiple of EVAL_STEPS it rewrites the
# checkpoint that the loop just saved.
save_checkpoint(params, step)

In [None]:
# Perform validation: decode every validation example (no num_examples cap)
# and render each image next to the predicted caption.

html_out = ""
for image, caption in make_predictions(validation_data_iterator(), batch_size=4):
  html_out += render_example(image, caption)
display(HTML(html_out))


In [22]:
# Load the Checkpoint

LLM_VARIANT = "gemma2_2b"

# NOTE(review): vocab_size here is 257_152, while the baseline-loading cell
# earlier in this notebook uses 257216 — confirm which value is intended and
# make the two configs agree.
model_config = ml_collections.FrozenConfigDict({
    "llm": {"vocab_size": 257_152, "variant": LLM_VARIANT, "final_logits_softcap": 0.0},
    "img": {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"}
})
# Pretrained parameters, used only as the structural template for restoring.
# Relies on MODEL_PATH and the imports from earlier cells being in scope.
init_params = paligemma.load(None, MODEL_PATH, model_config)


def load_checkpoint(template_params, step, save_dir="checkpoints"):
    """Restore parameters from a msgpack checkpoint found in `save_dir`.

    The file is located from `save_dir` and `step` instead of a hardcoded,
    user-specific absolute path. Two file names are tried in order: the name
    written by `save_checkpoint`, then the legacy `checkpoint_XXXX.msgpack`.

    Args:
        template_params: a pytree of the same structure as your saved params,
            e.g. the output of your model's init().
        step: integer step number that matches the filename (zero-padded to
            four digits).
        save_dir: directory containing the checkpoint files.

    Returns:
        The restored parameter pytree.

    Raises:
        FileNotFoundError: if neither candidate file exists in `save_dir`.
    """
    candidates = [
        os.path.join(
            save_dir,
            f"checkpoint_paligemma_finetuning_896_correct_describe_prompt{step:04d}.msgpack"),
        os.path.join(save_dir, f"checkpoint_{step:04d}.msgpack"),
    ]
    path = next((p for p in candidates if os.path.exists(p)), None)
    if path is None:
        raise FileNotFoundError(
            f"No checkpoint for step {step} in {save_dir}; looked for: "
            + ", ".join(candidates))

    with open(path, "rb") as f:
        raw = f.read()
    return flax.serialization.from_bytes(template_params, raw)

# Replace the in-memory `params` with the checkpointed weights for step 1.
params = load_checkpoint(init_params, step=1)