# Finetune PaliGemma

> *These models and code are not official Google products and were trained and released for research purposes.*


**This notebook shows how to finetune PaliGemma 2 on a vision-language task.**
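The original training data consists of 90 pairs of images and long captions describing them; the cells below have been adapted to train on image and grasp-annotation pairs from the Cornell Grasp dataset instead.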
To make it runnable on a T4 colab runtime with 16GB HBM and 12GB RAM, we opt to only finetune the attention layers of the language model and freeze the other parameters.

 **This setup is illustrative**. In a real use case, the amount of data, trainable parameters, training steps, hyper-parameters, and obtained results could be significantly different.

This notebook uses the model reference implementation from [big_vision](https://github.com/google-research/big_vision)
and shows how to:

 * Install deps, download model checkpoint and training data.
 * Load the model onto GPU devices.
 * Prepare the input to the model for training and inference.
 * Finetune the model and inspect output on the validation split.

## Setup

In [4]:
# @title Fetch big_vision code and install dependencies.
import os
import sys

# Colab runtimes with remote TPUs are not supported.
if "COLAB_TPU_ADDR" in os.environ:
  raise RuntimeError("It seems you are using Colab with remote TPUs which is not supported.")

# Fetch big_vision repository if python doesn't know about it and install
# dependencies needed for this notebook.
if not os.path.exists("big_vision_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/google-research/big_vision big_vision_repo

# Append big_vision code to python import path
if "big_vision_repo" not in sys.path:
  sys.path.append("big_vision_repo")

# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"


### Configure your API key to access Kaggle

To use PaliGemma, you must provide your Kaggle username and a Kaggle API key.

1. To generate a Kaggle API key, go to the **Account** tab of your Kaggle user profile and select **Create New Token**. This will trigger the download of a `kaggle.json` file containing your API credentials.
1. In Colab, select **Secrets** (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name `KAGGLE_USERNAME` and your API key under the name `KAGGLE_KEY`.

To be able to download the model, you will also need to accept the PaliGemma Terms and Conditions at:

* https://www.kaggle.com/models/google/paligemma/



In [5]:
import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate or make your credentials available in ~/.kaggle/kaggle.json

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

# The T4 runtime is tight on memory to finetune this model. Preallocate
# all memory ahead of time to avoid OOM'ing due to fragmentation.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"
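
Outside Colab, the note above suggests setting the environment variables yourself or relying on `~/.kaggle/kaggle.json`. A minimal sketch of the first option follows. The helper name `load_kaggle_credentials` is hypothetical, and the code assumes the file contains the `username` and `key` fields that Kaggle writes into a downloaded token.

```python
import json
import os
from pathlib import Path


def load_kaggle_credentials(path=None, environ=None):
    """Copy Kaggle credentials from a kaggle.json file into env vars.

    `load_kaggle_credentials` is a hypothetical helper, not part of kagglehub.
    Existing KAGGLE_USERNAME / KAGGLE_KEY values are left untouched.
    Returns True if the file was found and read, False otherwise.
    """
    environ = os.environ if environ is None else environ
    path = Path(path) if path else Path.home() / ".kaggle" / "kaggle.json"
    if not path.exists():
        return False
    with open(path, "r", encoding="utf-8") as f:
        creds = json.load(f)
    environ.setdefault("KAGGLE_USERNAME", creds["username"])
    environ.setdefault("KAGGLE_KEY", creds["key"])
    return True
```

Calling `load_kaggle_credentials()` before the download cell would then replace the two `userdata.get` lines on a non-Colab machine.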

In [6]:
# @title Download checkpoint, tokenizer and dataset to local filesystem.
#
import os
import kagglehub

# Use these for PaliGemma-2 3B 224px²
LLM_VARIANT = "gemma2_2b"
MODEL_PATH = "/content/drive/MyDrive/paligemma_assets/paligemma2-3b-pt-224.b16.npz"
KAGGLE_HANDLE = "google/paligemma-2/jax/paligemma2-3b-pt-224"  # Path to fetch from Kaggle.


# Use these for PaliGemma 1:
# LLM_VARIANT = "gemma_2b"
# MODEL_PATH = "./paligemma-3b-pt-224.f16.npz"
# KAGGLE_HANDLE = "google/paligemma/jax/paligemma-3b-pt-224"

if not os.path.exists(MODEL_PATH):
  print("Downloading the checkpoint from Kaggle, this could take a few minutes....")
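  # The second argument names a file inside the Kaggle model, so pass only the
  # file name rather than the full local path.
  MODEL_PATH = kagglehub.model_download(KAGGLE_HANDLE, os.path.basename(MODEL_PATH))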
  print(f"Model path: {MODEL_PATH}")

TOKENIZER_PATH = "/content/drive/MyDrive/paligemma_assets/paligemma_tokenizer.model"
if not os.path.exists(TOKENIZER_PATH):
  print("Downloading the model tokenizer...")
  !gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}
  print(f"Tokenizer path: {TOKENIZER_PATH}")

# DATA_DIR="./longcap100"
# if not os.path.exists(DATA_DIR):
#   print("Downloading the dataset...")
#   !gsutil -m -q cp -n -r gs://longcap100/ .
#   print(f"Data path: {DATA_DIR}")

## Notebook

In [7]:
import base64
import functools
import html
import io
import os
import warnings

import jax
import jax.numpy as jnp
import numpy as np
import ml_collections

import tensorflow as tf
import sentencepiece

from IPython.display import display, HTML
from PIL import Image

# Import model definition from big_vision
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns

# Import big vision utilities
import big_vision.datasets.jsonl
import big_vision.utils
import big_vision.sharding

# Don't let TF use the GPU or TPUs
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")

backend = jax.extend.backend.get_backend()
print(f"JAX version:  {jax.__version__}")
print(f"JAX platform: {backend.platform}")
print(f"JAX devices:  {jax.device_count()}")

JAX version:  0.5.3
JAX platform: gpu
JAX devices:  1


In [8]:
# @title Construct model and load params into RAM.

# Define model
# IMPORTANT: Gemma-2 has a "final_logits_softcap" property, we set it to 0.0
# for better transfer results.
model_config = ml_collections.FrozenConfigDict({
    "llm": {"vocab_size": 257_152, "variant": LLM_VARIANT, "final_logits_softcap": 0.0},
    "img": {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"}
})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)

# Load params - this can take up to 1 minute in T4 colabs.
params = paligemma.load(None, MODEL_PATH, model_config)

# Define `decode` function to sample outputs from the model.
decode_fn = predict_fns.get_all(model)['decode']
decode = functools.partial(decode_fn, devices=jax.devices(), eos_token=tokenizer.eos_id())

In [9]:
# @title Move params to GPU/TPU memory.
#
# To keep HBM usage low and fit in a T4 GPU (16GB HBM) we opt to only finetune
# a part of the parameters. Additionally we keep the frozen params in float16
# and cast trainable to float32.

# Create a pytree mask of the trainable params.
def is_trainable_param(name, param):  # pylint: disable=unused-argument
  if name.startswith("llm/layers/attn/"):  return True
  if name.startswith("llm/"):              return False
  if name.startswith("img/"):              return False
  raise ValueError(f"Unexpected param name {name}")
trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)

#
# If more than one device is available (e.g. multiple GPUs) the parameters can
# be sharded across them to reduce HBM usage per device.
mesh = jax.sharding.Mesh(jax.devices(), ("data",))

data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("data"))

params_sharding = big_vision.sharding.infer_sharding(
    params, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh)

# Yes: Some donated buffers are not usable.
warnings.filterwarnings(
    "ignore", message="Some donated buffers were not usable")

@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))
def maybe_cast_to_f32(params, trainable):
  # Cast others to float16, since some GPUs don't support bf16.
  return jax.tree.map(lambda p, m: p.astype(jnp.float32)
                      if m else p.astype(jnp.float16),
                      params, trainable)
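
To see what the name-based rule in `is_trainable_param` selects, the sketch below applies the same prefix checks to a handful of parameter names. Apart from the `img/...` name, which appears in the parameter overview printed by this notebook, the names are made up for illustration, and a plain dict stands in for the params pytree.

```python
def is_trainable_param(name):
    """Same prefix rule as the notebook: only LLM attention params train."""
    if name.startswith("llm/layers/attn/"):
        return True
    if name.startswith("llm/") or name.startswith("img/"):
        return False
    raise ValueError(f"Unexpected param name {name}")


# Illustrative names only; the real names come from the loaded checkpoint.
example_names = [
    "img/Transformer/encoder_norm/scale",
    "llm/embedder/input_embedding",
    "llm/layers/attn/q_einsum/w",
    "llm/layers/mlp/linear",
]

trainable_mask = {name: is_trainable_param(name) for name in example_names}
for name, trainable in trainable_mask.items():
    print(f"{name:40s} {'trainable' if trainable else 'frozen'}")
```

Raising on unknown prefixes, as the notebook does, makes it harder to silently freeze or train a parameter group that was not anticipated.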

In [10]:
# Loading all params simultaneously - albeit much faster and more succinct -
# requires more RAM than the T4 colab runtimes have by default (12GB RAM).
# Instead we do it param by param.
params, treedef = jax.tree.flatten(params)
sharding_leaves = jax.tree.leaves(params_sharding)
trainable_leaves = jax.tree.leaves(trainable_mask)
for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):
  params[idx] = big_vision.utils.reshard(params[idx], sharding)
  params[idx] = maybe_cast_to_f32(params[idx], trainable)
  params[idx].block_until_ready()
params = jax.tree.unflatten(treedef, params)

# Print params to show what the model is made of.
def parameter_overview(params):
  for path, arr in big_vision.utils.tree_flatten_with_names(params)[0]:
    print(f"{path:80s} {str(arr.shape):22s} {arr.dtype}")

print(" == Model params == ")
parameter_overview(params)

 == Model params == 
img/Transformer/encoder_norm/bias                                                (1152,)                float16
img/Transformer/encoder_norm/scale                                               (1152,)                float16
img/Transformer/encoderblock/LayerNorm_0/bias                                    (27, 1152)             float16
img/Transformer/encoderblock/LayerNorm_0/scale                                   (27, 1152)             float16
img/Transformer/encoderblock/LayerNorm_1/bias                                    (27, 1152)             float16
img/Transformer/encoderblock/LayerNorm_1/scale                                   (27, 1152)             float16
img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias                             (27, 4304)             float16
img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel                           (27, 1152, 4304)       float16
img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias                             (2

In [11]:
# # @title Define preprocess functions to create inputs to the model.

# def preprocess_image(image, size=224):
#   # Model has been trained to handle images of different aspects ratios
#   # resized to 224x224 in the range [-1, 1]. Bilinear and antialias resize
#   # options are helpful to improve quality in some tasks.
#   image = np.asarray(image)
#   if image.ndim == 2:  # Convert image without last channel into greyscale.
#     image = np.stack((image,)*3, axis=-1)
#   image = image[..., :3]  # Remove alpha layer.
#   assert image.shape[-1] == 3

#   image = tf.constant(image)
#   image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)
#   return image.numpy() / 127.5 - 1.0  # [0, 255]->[-1,1]

# def preprocess_tokens(prefix, suffix=None, seqlen=None):
#   # Model has been trained to handle tokenized text composed of a prefix with
#   # full attention and a suffix with causal attention.
#   separator = "\n"
#   tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)
#   mask_ar = [0] * len(tokens)    # 0 to use full attention for prefix.
#   mask_loss = [0] * len(tokens)  # 0 to not use prefix tokens in the loss.

#   if suffix:
#     suffix = tokenizer.encode(suffix, add_eos=True)
#     tokens += suffix
#     mask_ar += [1] * len(suffix)    # 1 to use causal attention for suffix.
#     mask_loss += [1] * len(suffix)  # 1 to use suffix tokens in the loss.

#   mask_input = [1] * len(tokens)    # 1 if its a token, 0 if padding.
#   if seqlen:
#     padding = [0] * max(0, seqlen - len(tokens))
#     tokens = tokens[:seqlen] + padding
#     mask_ar = mask_ar[:seqlen] + padding
#     mask_loss = mask_loss[:seqlen] + padding
#     mask_input = mask_input[:seqlen] + padding

#   return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))

# def postprocess_tokens(tokens):
#   tokens = tokens.tolist()  # np.array to list[int]
#   try:  # Remove tokens at and after EOS if any.
#     eos_pos = tokens.index(tokenizer.eos_id())
#     tokens = tokens[:eos_pos]
#   except ValueError:
#     pass
#   return tokenizer.decode(tokens)
import os
import glob
import random

# The path to your dataset in Google Drive
DATASET_PATH = '/content/drive/MyDrive/datasets/Cornell_Grasp_Kaggle'

def find_data_pairs(dataset_path):
    """
    Scans the dataset to find all matching pairs of image and annotation files.
    """
    data_pairs = []
    annotation_files = glob.glob(os.path.join(dataset_path, '**', '*cpos.txt'), recursive=True)

    for ann_path in annotation_files:
        img_path = ann_path.replace('cpos.txt', 'r.png')
        if os.path.exists(img_path):
            data_pairs.append((img_path, ann_path))

    return data_pairs

# Create the master list of data
all_pairs = find_data_pairs(DATASET_PATH)
random.shuffle(all_pairs)

# Split the data
split_index = int(len(all_pairs) * 0.80)
train_pairs = all_pairs[:split_index]
val_pairs = all_pairs[split_index:]

print(f"Total samples: {len(all_pairs)}")
print(f"Training samples: {len(train_pairs)}")
print(f"Validation samples: {len(val_pairs)}")

Total samples: 885
Training samples: 708
Validation samples: 177
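
`random.shuffle` above uses the global random state, so the train/validation split changes between runs. If a repeatable split is wanted, one option is a seeded `random.Random` instance. The function name `split_pairs` and the seed value are illustrative choices, not part of the original notebook.

```python
import random


def split_pairs(pairs, train_fraction=0.80, seed=0):
    """Return (train, val) lists using a seeded shuffle of a copy of `pairs`."""
    shuffled = sorted(pairs)  # Sort first so file-system order does not matter.
    random.Random(seed).shuffle(shuffled)
    split_index = int(len(shuffled) * train_fraction)
    return shuffled[:split_index], shuffled[split_index:]


# Placeholder (image, annotation) pairs standing in for the real file paths.
pairs = [(f"img_{i}.png", f"ann_{i}.txt") for i in range(10)]
train, val = split_pairs(pairs)
print(len(train), len(val))
```

Because the function shuffles a sorted copy, the input list is left unchanged and the same seed gives the same split regardless of the order in which `glob` returned the files.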


In [12]:
# # @title Function to iterate over train and validation examples.
# SEQLEN = 128

# # TODO: Consider data iterators skipping big_vision and tf.data?
# train_dataset = big_vision.datasets.jsonl.DataSource(
#     os.path.join(DATA_DIR, "data_train90.jsonl"),
#     fopen_keys={"image": DATA_DIR})

# val_dataset = big_vision.datasets.jsonl.DataSource(
#     os.path.join(DATA_DIR, "data_val10.jsonl"),
#     fopen_keys={"image": DATA_DIR})


# def train_data_iterator():
#   """Never ending iterator over training examples."""
#   # Shuffle examples and repeat so one can train for many epochs.
#   dataset = train_dataset.get_tfdata().shuffle(1_000).repeat()
#   for example in dataset.as_numpy_iterator():
#     image = Image.open(io.BytesIO(example["image"]))
#     image = preprocess_image(image)

#     prefix = "caption en"  # Could also be a different prefix per example.
#     suffix = example["suffix"].decode().lower()
#     tokens, mask_ar, mask_loss, _ = preprocess_tokens(prefix, suffix, SEQLEN)

#     yield {
#         "image": np.asarray(image),
#         "text": np.asarray(tokens),
#         "mask_ar": np.asarray(mask_ar),
#         "mask_loss": np.asarray(mask_loss),
#     }


# def validation_data_iterator():
#   """Single iterator over validation examples."""
#   for example in val_dataset.get_tfdata(ordered=True).as_numpy_iterator():
#     image = Image.open(io.BytesIO(example["image"]))
#     image = preprocess_image(image)

#     prefix = "caption en"  # Could also be a different prefix per example.
#     tokens, mask_ar, _, mask_input = preprocess_tokens(prefix, seqlen=SEQLEN)

#     yield {
#         "image": np.asarray(image),
#         "text": np.asarray(tokens),
#         "mask_ar": np.asarray(mask_ar),
#         "mask_input": np.asarray(mask_input),
#     }

import numpy as np
import tensorflow as tf
from PIL import Image
import random
import os

# This cell assumes the 'tokenizer' variable (the real SentencePieceProcessor)
# has been loaded in a previous cell from your notebook.

# Using the smaller, memory-efficient sequence length
SEQLEN = 32

def preprocess_image(image, size=224):
    """Resizes and normalizes the image."""
    image = np.asarray(image)
    if image.ndim == 2: image = np.stack((image,)*3, axis=-1)
    image = image[..., :3]
    image = tf.constant(image)
    image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)
    return image.numpy() / 127.5 - 1.0

def preprocess_tokens(prefix, suffix=None, seqlen=None):
    """Converts text into token IDs and all necessary masks for the model."""
    separator = "\n"
    tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)
    mask_ar = [0] * len(tokens)
    mask_loss = [0] * len(tokens)
    if suffix:
        suffix_tokens = tokenizer.encode(suffix, add_eos=True)
        tokens += suffix_tokens
        mask_ar += [1] * len(suffix_tokens)
        mask_loss += [1] * len(suffix_tokens)
    mask_input = [1] * len(tokens)
    if seqlen:
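        padding_len = max(0, seqlen - len(tokens))
        # Pad the token IDs with the pad ID, but pad every mask with 0 rather
        # than with the pad token ID.
        token_padding = [tokenizer.pad_id()] * padding_len
        mask_padding = [0] * padding_len
        tokens = tokens[:seqlen] + token_padding
        mask_ar = mask_ar[:seqlen] + mask_padding
        mask_loss = mask_loss[:seqlen] + mask_padding
        mask_input = mask_input[:seqlen] + mask_padding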
    return np.array(tokens), np.array(mask_ar), np.array(mask_loss), np.array(mask_input)

def process_grasp_file_to_bbox_tokens(file_path):
    """Converts a grasp file into a list of bounding box token strings."""
    try:
        with open(file_path, 'r') as f: lines = [line.strip() for line in f.readlines()]
    except FileNotFoundError: return []
    bbox_token_strings = []
    for i in range(0, len(lines), 4):
        points_str = lines[i:i+4]
        if len(points_str) == 4:
            try:
                corners = np.array([p.split() for p in points_str], dtype=float)
                if np.isnan(corners).any():
                    continue
                x_coords, y_coords = corners[:, 0], corners[:, 1]
                x_min, y_min = int(np.min(x_coords)), int(np.min(y_coords))
                x_max, y_max = int(np.max(x_coords)), int(np.max(y_coords))
                # Order is (y_min, x_min, y_max, x_max); values are raw pixel coordinates.
                token_str = f"<loc_{y_min:04d}><loc_{x_min:04d}><loc_{y_max:04d}><loc_{x_max:04d}>"
                bbox_token_strings.append(token_str)
            except (ValueError, TypeError):
                continue
    return bbox_token_strings

def cornell_grasp_iterator(data_pairs, is_training=True):
    """The main data iterator, yielding all required masks."""
    while True:
        if is_training:
            random.shuffle(data_pairs)
        for img_path, ann_path in data_pairs:
            try:
                image = Image.open(img_path)
                bbox_tokens_list = process_grasp_file_to_bbox_tokens(ann_path)
                if not bbox_tokens_list: continue
                selected_bbox = random.choice(bbox_tokens_list)
                processed_image = preprocess_image(image)
                prefix = "detect grasp"
                suffix = selected_bbox
                tokens, mask_ar, mask_loss, mask_input = preprocess_tokens(prefix, suffix, SEQLEN)
                yield {
                    "image": processed_image,
                    "text": tokens,
                    "mask_ar": mask_ar,
                    "mask_loss": mask_loss,
                    "mask_input": mask_input,
                }
            except Exception as e:
                print(f"Skipping file due to error: {e}, Path: {img_path}")
                continue

print("✅ Data processing pipeline functions are now defined.")

✅ Data processing pipeline functions are now defined.
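
The masks returned by `preprocess_tokens` are easier to read on a concrete example. The sketch below reproduces the same layout with a toy stand-in tokenizer (one ID per whitespace-separated word, with made-up BOS, EOS, separator and pad IDs), so the token IDs are not the ones SentencePiece would produce. Only the mask pattern carries over.

```python
PAD, BOS, EOS, SEP = 0, 2, 1, 3  # Made-up special IDs for this toy example.


def toy_encode(text):
    """Stand-in tokenizer: one fake ID per whitespace-separated word."""
    return [10 + len(word) for word in text.split()]


def build_masks(prefix, suffix=None, seqlen=None):
    """Mirror the prefix/suffix mask layout used by `preprocess_tokens`."""
    tokens = [BOS] + toy_encode(prefix) + [SEP]
    mask_ar = [0] * len(tokens)      # 0: prefix tokens use full attention.
    mask_loss = [0] * len(tokens)    # 0: prefix tokens are excluded from the loss.
    if suffix:
        suffix_tokens = toy_encode(suffix) + [EOS]
        tokens += suffix_tokens
        mask_ar += [1] * len(suffix_tokens)    # 1: causal attention over the suffix.
        mask_loss += [1] * len(suffix_tokens)  # 1: suffix tokens are scored.
    mask_input = [1] * len(tokens)   # 1: real token, 0: padding.
    if seqlen:
        pad = max(0, seqlen - len(tokens))
        tokens = tokens[:seqlen] + [PAD] * pad
        mask_ar = mask_ar[:seqlen] + [0] * pad
        mask_loss = mask_loss[:seqlen] + [0] * pad
        mask_input = mask_input[:seqlen] + [0] * pad
    return tokens, mask_ar, mask_loss, mask_input


tokens, mask_ar, mask_loss, mask_input = build_masks("detect grasp", "a b", seqlen=10)
print("tokens    ", tokens)
print("mask_ar   ", mask_ar)
print("mask_loss ", mask_loss)
print("mask_input", mask_input)
```

The prefix (BOS, two words, separator) gets zeros in `mask_ar` and `mask_loss`, the suffix plus EOS gets ones, and the padded tail is zero in every mask.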


In [13]:
import numpy as np
from PIL import Image
import cv2 # Import the OpenCV library
import re
import os

def parse_bbox_from_string(caption_str):
    """Extracts bounding box coordinates from the model's token string output."""
    numbers = re.findall(r'<loc_(\d+)>', caption_str)
    if len(numbers) == 4:
        coords = [int(n) for n in numbers]
        return tuple(coords) # (ymin, xmin, ymax, xmax)
    else:
        return None

def postprocess_tokens(tokens):
    """Converts token IDs back to a string."""
    tokens = tokens.tolist()
    try:
        eos_pos = tokens.index(tokenizer.eos_id())
        tokens = tokens[:eos_pos]
    except ValueError:
        pass
    return tokenizer.decode(tokens)

def save_eval_image(image, caption, filename):
    """
    Saves an evaluation image with its predicted bounding box drawn on it.
    """
    # Convert from [-1, 1] to [0, 255]; the array stays in RGB channel order.
    image_display = ((image + 1) / 2 * 255).astype(np.uint8)
    image_to_render = image_display.copy()

    # Try to parse the bounding box from the caption string
    bbox = parse_bbox_from_string(caption)

    if bbox:
        # Coordinates are drawn as parsed; they are not rescaled to this image's size.
        ymin, xmin, ymax, xmax = bbox
        pt1 = (xmin, ymin)
        pt2 = (xmax, ymax)
        # Draw a red rectangle. The image is still RGB here, so red is (255, 0, 0).
        cv2.rectangle(image_to_render, pt1, pt2, color=(255, 0, 0), thickness=2)

    # Convert from RGB (used by PIL/TF) to BGR (used by OpenCV) for saving
    image_bgr = cv2.cvtColor(image_to_render, cv2.COLOR_RGB2BGR)
    cv2.imwrite(filename, image_bgr)

print("✅ Helper functions for saving evaluation images are now defined.")

✅ Evaluation and rendering helper functions are now defined.
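
The helper above parses coordinates from `<loc_NNNN>` strings, and the data pipeline writes raw pixel values into them. PaliGemma's pretrained detection outputs are usually described as four tokens of the form `<locNNNN>` (no underscore) per box, in `y_min, x_min, y_max, x_max` order, with each coordinate normalized by the image size and binned into 1024 values. Treat that convention as an assumption to verify against the tokenizer in use. Under it, a conversion in both directions could look like the sketch below; the function names and the 640x480 example image size are illustrative.

```python
import re

NUM_BINS = 1024  # Assumed number of location bins (<loc0000> ... <loc1023>).


def box_to_loc_tokens(y_min, x_min, y_max, x_max, height, width):
    """Encode a pixel-space box as four normalized location tokens."""
    def to_bin(value, size):
        b = int(value / size * NUM_BINS)
        return min(max(b, 0), NUM_BINS - 1)  # Clamp to the valid bin range.

    bins = (to_bin(y_min, height), to_bin(x_min, width),
            to_bin(y_max, height), to_bin(x_max, width))
    return "".join(f"<loc{b:04d}>" for b in bins)


def loc_tokens_to_box(text, height, width):
    """Decode the first four location tokens back to pixel coordinates."""
    bins = [int(n) for n in re.findall(r"<loc_?(\d{4})>", text)][:4]
    if len(bins) != 4:
        return None
    sizes = (height, width, height, width)
    return tuple(round(b / NUM_BINS * s) for b, s in zip(bins, sizes))


# A box on a 640x480 image, then decoded onto the 224x224 model input.
tokens = box_to_loc_tokens(120, 160, 360, 480, height=480, width=640)
print(tokens)
print(loc_tokens_to_box(tokens, height=224, width=224))
```

Normalizing by image size means the same tokens can be decoded onto the resized 224x224 image or onto the original image, which is what drawing a predicted box on the preprocessed image requires.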


In [14]:
# @title Define the training step and evaluation loop.
#
# The main update_fn using simple SGD.
#
@functools.partial(jax.jit, donate_argnums=(0,))
def update_fn(params, batch, learning_rate):
  imgs, txts, mask_ar = batch["image"], batch["text"], batch["mask_ar"]

  def loss_fn(params):
    text_logits, _ = model.apply({"params": params}, imgs, txts[:, :-1], mask_ar[:, :-1], train=True)
    logp = jax.nn.log_softmax(text_logits, axis=-1)

    # The model takes as input txts[:, :-1] but the loss is defined as predicting
    # next tokens txts[:, 1:]. Additionally, mask_loss[:, 1:] indicates which tokens
    # are part of the loss (e.g. prefix and padded tokens are not included).
    mask_loss = batch["mask_loss"][:, 1:]
    targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])

    # Compute the loss per example. i.e. the mean of per token pplx.
    # Since each example has a different number of tokens we normalize it.
    token_pplx = jnp.sum(logp * targets, axis=-1)  # sum across vocab_size.
    example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1)  # sum across seq_len.
    example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1)  # weight by num of tokens.

    # batch_loss: mean of per example loss.
    return jnp.mean(example_loss)

  loss, grads = jax.value_and_grad(loss_fn)(params)

  # Apply gradients to trainable params using SGD.
  def apply_grad(param, gradient, trainable):
    if not trainable: return param
    return param - learning_rate * gradient

  params = jax.tree_util.tree_map(apply_grad, params, grads, trainable_mask)

  return params, loss

# Evaluation/inference loop.
def make_predictions(data_iterator, *, num_examples=None,
                     batch_size=4, seqlen=SEQLEN, sampler="greedy"):
  outputs = []
  while True:
    # Construct a list of examples in the batch.
    examples = []
    try:
      for _ in range(batch_size):
        examples.append(next(data_iterator))
        examples[-1]["_mask"] = np.array(True)  # Indicates true example.
    except StopIteration:
      if len(examples) == 0:
        return outputs

    # Not enough examples to complete a batch. Pad by repeating last example.
    while len(examples) % batch_size:
      examples.append(dict(examples[-1]))
      examples[-1]["_mask"] = np.array(False)  # Indicates padding example.

    # Convert list of examples into a dict of np.arrays and load onto devices.
    batch = jax.tree.map(lambda *x: np.stack(x), *examples)
    batch = big_vision.utils.reshard(batch, data_sharding)

    # Make model predictions
    tokens = decode({"params": params}, batch=batch,
                    max_decode_len=seqlen, sampler=sampler)

    # Fetch model predictions to host and detokenize.
    tokens, mask = jax.device_get((tokens, batch["_mask"]))
    tokens = tokens[mask]  # remove padding examples.
    responses = [postprocess_tokens(t) for t in tokens]

    # Collect (image, response) pairs.
    for example, response in zip(examples, responses):
      outputs.append((example["image"], response))
      if num_examples and len(outputs) >= num_examples:
        return outputs
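
The loss inside `update_fn` is a masked, length-normalized negative log-likelihood. The pure-Python sketch below recomputes it for one toy example with a three-token vocabulary. The logits, targets and mask values are invented for illustration, and it uses `math` instead of JAX so the arithmetic is easy to follow.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def example_loss(logits_per_pos, targets, mask_loss):
    """Mean negative log-likelihood over positions where mask_loss == 1."""
    total = 0.0
    for logits, target, m in zip(logits_per_pos, targets, mask_loss):
        total -= log_softmax(logits)[target] * m
    return total / max(sum(mask_loss), 1)  # Same role as jnp.clip(..., 1).


# Three positions; the first is a prefix token and is masked out of the loss.
logits = [[5.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 2.0, 0.0]]
targets = [0, 2, 1]
mask = [0, 1, 1]
print(f"{example_loss(logits, targets, mask):.4f}")
```

Dividing by the number of unmasked tokens keeps examples with long suffixes from dominating the batch mean, and the lower bound of 1 avoids a division by zero when every position is masked.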

In [None]:
# # @title Run training loop.
# #
# # Run a short training loop with cosine learning rate schedule.
# #
# # Note: the first step can be quite slow on some machines (up to several minutes)
# # due to XLA compilation of the jax.jit'd function.
# #
# %%time

# BATCH_SIZE = 8
# TRAIN_EXAMPLES = 512
# LEARNING_RATE = 0.03

# TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE
# EVAL_STEPS = TRAIN_STEPS // 4

# train_data_it = train_data_iterator()

# sched_fn = big_vision.utils.create_learning_rate_schedule(
#     total_steps=TRAIN_STEPS+1, base=LEARNING_RATE,
#     decay_type="cosine", warmup_percent=0.10)

# for step in range(1, TRAIN_STEPS+1):
#   # Make list of N training examples.
#   examples = [next(train_data_it) for _ in range(BATCH_SIZE)]

#   # Convert list of examples into a dict of np.arrays and load onto devices.
#   batch = jax.tree.map(lambda *x: np.stack(x), *examples)
#   batch = big_vision.utils.reshard(batch, data_sharding)

#   # Training step and report training loss
#   learning_rate = sched_fn(step)
#   params, loss = update_fn(params, batch, learning_rate)

#   loss = jax.device_get(loss)
#   print(f"step: {step:2d}/{TRAIN_STEPS:2d}   lr: {learning_rate:.5f}   loss: {loss:.4f}")

#   if step == 1 or (step % EVAL_STEPS) == 0:
#     print(f"Model predictions at step {step}")
#     html_out = ""
#     for image, caption in make_predictions(
#         validation_data_iterator(), num_examples=4, batch_size=4):
#       html_out += render_example(image, caption)
#     display(HTML(html_out))

import jax
import big_vision.utils
import glob

# This cell assumes variables like 'params', 'update_fn', etc., are loaded from the original notebook.

print("🚀 Starting the fine-tuning process...")

# Set training parameters
BATCH_SIZE = 4
TRAIN_EXAMPLES = len(train_pairs) * 2
LEARNING_RATE = 0.03
TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE
EVAL_STEPS = TRAIN_STEPS // 5

# Create our custom data iterators
train_data_it = cornell_grasp_iterator(train_pairs, is_training=True)
val_data_it = cornell_grasp_iterator(val_pairs, is_training=False)

# Set up learning rate schedule
lr_sched_fn = big_vision.utils.create_learning_rate_schedule(
    total_steps=TRAIN_STEPS + 1, base=LEARNING_RATE,
    decay_type="cosine", warmup_percent=0.10)

# Run the main training loop
for step in range(1, TRAIN_STEPS + 1):
    examples = [next(train_data_it) for _ in range(BATCH_SIZE)]
    batch = jax.tree.map(lambda *x: np.stack(x), *examples)
    batch = big_vision.utils.reshard(batch, data_sharding)

    learning_rate = lr_sched_fn(step)
    params, loss = update_fn(params, batch, learning_rate)

    if step % 20 == 0 or step == 1 or step == TRAIN_STEPS:
      print(f"Step: {step:3d}/{TRAIN_STEPS} | Learning Rate: {learning_rate:.5f} | Loss: {loss:.4f}")

    # Run periodic evaluation on the validation set
    if step == 1 or (step % EVAL_STEPS) == 0 or step == TRAIN_STEPS:
        print(f"\n--- Running evaluation at step {step} ---")

        # Clean up old evaluation images
        for f in glob.glob('/content/eval_*.jpg'):
            os.remove(f)

        # Generate and save new evaluation images
        saved_files = []
        for i, (image, caption) in enumerate(make_predictions(
            val_data_it, num_examples=4, batch_size=BATCH_SIZE)):

            filename = f"/content/eval_step_{step}_img_{i}.jpg"
            save_eval_image(image, caption, filename)
            saved_files.append(filename)

        print("✅ Evaluation complete. Images saved to:")
        for f in saved_files:
            print(f"  - {f}")
        print("\n--> Please check the file browser on the left to view images.\n")

print("\n✅ Fine-tuning complete!")

Step:   1/708 | Learning Rate: 0.00042 | Loss: 2.8888

--- Running evaluation at step 1 ---


--- Evaluation complete ---
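
The training loop takes its learning rate from a cosine schedule with a 10% warmup. The sketch below shows the general shape of such a schedule: a linear warmup followed by cosine decay to zero. It is a simplified stand-in written for illustration, not the implementation in `big_vision.utils.create_learning_rate_schedule`, so exact values may differ.

```python
import math


def warmup_cosine_lr(step, total_steps, base_lr, warmup_fraction=0.10):
    """Linear warmup to `base_lr`, then cosine decay towards zero."""
    warmup_steps = warmup_fraction * total_steps
    if warmup_steps > 0 and step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1e-9)
    progress = min(max(progress, 0.0), 1.0)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


total = 100
for step in (1, 10, 55, 100):
    print(step, f"{warmup_cosine_lr(step, total, base_lr=0.03):.5f}")
```

The small learning rate printed at step 1 of the training log is consistent with this kind of warmup: early steps use only a fraction of the base rate.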



## Save the final checkpoint

In [None]:
def npsave(pytree, path):
  names_and_vals, _ = big_vision.utils.tree_flatten_with_names(pytree)
  with open(path, "wb") as f:
    np.savez(f, **{k:v for k, v in names_and_vals})

# Takes around 4 minutes
npsave(params, 'my-custom-paligemma-ckpt.npz')
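
`npsave` stores the parameters under flat, `/`-joined names. To use the checkpoint again, those names need to be turned back into a nested tree. A minimal sketch of that step is shown below. `unflatten_names` is a hypothetical helper, the `llm/...` name is illustrative, and plain Python values stand in for the arrays that loading the `.npz` file would return.

```python
def unflatten_names(flat, sep="/"):
    """Rebuild a nested dict from a flat {"a/b/c": value} mapping."""
    tree = {}
    for name, value in flat.items():
        node = tree
        *parents, leaf = name.split(sep)
        for key in parents:
            node = node.setdefault(key, {})
        node[leaf] = value
    return tree


# Placeholder values; in practice these would be the arrays from the .npz file.
flat = {
    "img/Transformer/encoder_norm/bias": "b",
    "img/Transformer/encoder_norm/scale": "s",
    "llm/final_norm/scale": "n",
}
tree = unflatten_names(flat)
print(sorted(tree["img"]["Transformer"]["encoder_norm"]))
```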