In [1]:
from IPython.display import display, HTML

<div style="background-color:yellow; text-align:center; padding:40px; font-family:sans-serif;">
  <h1 style="color:red;">DI-725 : Transformers and Attention-Based Deep Networks</h1>
  <h2 style="color:red;">Final Project : Phase - 2</h2>
  <br><br>
  <h4 style="color:red;">Turgay Yıldız</h4>
  <br>
  <h4 style="color:red;">Graduate School of Informatics, Middle East Technical University (METU)</h4>
</div>

<div style="background-color:yellow; text-align:center; padding:40px; font-family:sans-serif;">
  <h3 style="color:red;">Fetch big_vision code and install dependencies</h3>
</div>

In [6]:
import os
import sys

In [7]:
# Clone the big_vision repository into ./big_vision_repo unless that directory
# already exists, so re-running this cell does not clone a second time.
# `--branch=main --depth=1` asks git for a shallow clone of `main` only.
# (Python dependencies are installed by a separate `pip` cell, not here.)
if not os.path.exists("big_vision_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/google-research/big_vision big_vision_repo

In [8]:
# Make the cloned big_vision sources importable. The entry is added only when
# it is missing, so re-running the cell never duplicates it.
if not any(entry == "big_vision_repo" for entry in sys.path):
  sys.path.append("big_vision_repo")

In [9]:
# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
# NOTE(review): only `einops` carries a version constraint here; pin the other
# packages as well if this notebook has to be reproducible.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.7/76.7 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [10]:
import os
from google.colab import userdata

<div style="background-color:yellow; text-align:center; padding:40px; font-family:sans-serif;">
  <h3 style="color:red;"> Model and Pre-trained weights : </h3>
</div>

In [11]:
# Kaggle credentials, exported as environment variables for the checkpoint
# download further down. Values already present in the environment win;
# otherwise they are read from the Colab secret store. Real credentials must
# never be pasted into the cell, because the cell source is saved and shared
# together with the notebook.
#
# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate or make your credentials available in ~/.kaggle/kaggle.json
# (and skip this cell).
for _secret_name in ("KAGGLE_USERNAME", "KAGGLE_KEY"):
  if not os.environ.get(_secret_name):
    os.environ[_secret_name] = userdata.get(_secret_name)

In [12]:
# The T4 runtime is tight on memory to finetune this model. Preallocate
# all memory ahead of time to avoid OOM'ing due to fragmentation.
# NOTE(review): presumably this has to be set before JAX initialises its GPU
# backend to take effect - keep this cell above the first JAX call.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"

In [13]:
# @title Download checkpoint, tokenizer and dataset to local filesystem.
#
import os
import kagglehub

In [14]:
# Use these for PaliGemma-2 3B 224px²
#LLM_VARIANT   = "gemma2_2b"
#MODEL_PATH    = "./paligemma2-3b-pt-224.b16.npz"
#KAGGLE_HANDLE = "google/paligemma-2/jax/paligemma2-3b-pt-224"  # Path to fetch from Kaggle.

In [15]:
# Use these for PaliGemma 1:
#   LLM_VARIANT   - value placed under "llm" -> "variant" in the model config.
#   MODEL_PATH    - checkpoint file name; the download cell rebinds it to the
#                   path returned by kagglehub.
#   KAGGLE_HANDLE - Kaggle model handle the checkpoint is fetched from.
LLM_VARIANT   = "gemma_2b"
MODEL_PATH    = "./paligemma-3b-pt-224.f16.npz"
KAGGLE_HANDLE = "google/paligemma/jax/paligemma-3b-pt-224"

In [16]:
# Download the checkpoint unless MODEL_PATH already points at an existing file.
# MODEL_PATH is rebound to the path returned by kagglehub, so running the cell
# again finds the file and skips the download.
if not os.path.exists(MODEL_PATH):
  print("Downloading the checkpoint from Kaggle, this could take a few minutes....")
  MODEL_PATH = kagglehub.model_download(KAGGLE_HANDLE, MODEL_PATH)
  print(f"Model path: {MODEL_PATH}")

Downloading the checkpoint from Kaggle, this could take a few minutes....
Model path: /kaggle/input/paligemma/jax/paligemma-3b-pt-224/1/./paligemma-3b-pt-224.f16.npz


In [17]:
# Local path of the SentencePiece tokenizer model; fetched once with wget and
# reused on later runs because of the existence check.
# NOTE(review): the check only tests that the file exists - if a download is
# interrupted, delete the file before re-running this cell.
TOKENIZER_PATH = "./paligemma_tokenizer.model"
if not os.path.exists(TOKENIZER_PATH):
    print("Downloading the model tokenizer...")
    !wget https://storage.googleapis.com/big_vision/paligemma_tokenizer.model -O {TOKENIZER_PATH}
    print(f"Tokenizer path: {TOKENIZER_PATH}")


Downloading the model tokenizer...
--2025-05-02 06:49:16--  https://storage.googleapis.com/big_vision/paligemma_tokenizer.model
Resolving storage.googleapis.com (storage.googleapis.com)... 173.194.174.207, 74.125.204.207, 64.233.187.207, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|173.194.174.207|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4264023 (4.1M) [application/octet-stream]
Saving to: ‘./paligemma_tokenizer.model’


2025-05-02 06:49:20 (1.95 MB/s) - ‘./paligemma_tokenizer.model’ saved [4264023/4264023]

Tokenizer path: ./paligemma_tokenizer.model


<div style="background-color:yellow; text-align:center; padding:40px; font-family:sans-serif;">
  <h3 style="color:red;"> Core Library Imports : </h3>
</div>

In [119]:
import base64
import functools
import html
import io
import os
import warnings

import jax
import jax.numpy as jnp
import numpy as np
import ml_collections

import tensorflow as tf
import sentencepiece

from IPython.core.display import display, HTML
from PIL import Image

# Import model definition from big_vision
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns

# Import big vision utilities
import big_vision.datasets.jsonl
import big_vision.utils
import big_vision.sharding


<div style="background-color:yellow; text-align:center; padding:40px; font-family:sans-serif;">
  <h3 style="color:red;"> Reserve GPU/TPU for JAX </h3>
</div>

In [19]:
# Don't let TF use the GPU or TPUs
# Disables TensorFlow’s access to GPUs/TPUs so JAX can fully utilize them without resource contention.
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")

# Report the JAX version, the platform of the active backend and the number of
# devices JAX can see.
backend = jax.extend.backend.get_backend()
print(f"JAX version:  {jax.__version__}")
print(f"JAX platform: {backend.platform}")
print(f"JAX devices:  {jax.device_count()}")

JAX version:  0.4.33
JAX platform: gpu
JAX devices:  1


<div style="background-color:yellow; text-align:center; padding:40px; font-family:sans-serif;">
  <h3 style="color:red;">Construct model and load params into RAM </h3>

 <h5 style="color:red;"> model_config: hyperparameters for both the vision encoder and text decoder.
<br>
                          Instantiate the combined Vision+LLM model.
<br>
                          Load pretrained weights into a parameter tree.
<br>
                          Build a decode function for efficient batched generation.
                          </h5>


</div>

In [20]:
# Define model
#   "llm": text-decoder settings; the variant comes from LLM_VARIANT.
#   "img": vision-encoder settings.
# IMPORTANT: Gemma-2 has a "final_logits_softcap" property, we set it to 0.0
# for better transfer results.
# NOTE(review): LLM_VARIANT is currently "gemma_2b" (PaliGemma 1); confirm the
# softcap entry is still wanted/accepted for that variant.
model_config = ml_collections.FrozenConfigDict({
    "llm": {"vocab_size": 257_152, "variant": LLM_VARIANT, "final_logits_softcap": 0.0},
    "img": {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"}
})

In [21]:
# Instantiate the PaliGemma model from the config and load the SentencePiece
# tokenizer from the file downloaded earlier.
model     = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)

In [158]:
# Load params - this can take up to 1 minute in T4 colabs.
# The checkpoint at MODEL_PATH is read using the same model_config the model
# was built with.
params = paligemma.load(None, MODEL_PATH, model_config)

In [159]:
# Show the tokenizer's end-of-sequence id; `decode` and `postprocess_tokens`
# both use this value to detect where a generated sequence ends.
tokenizer.eos_id()

1

In [160]:
# Define `decode` function to sample outputs from the model.
# `decode_fn` is the 'decode' entry of big_vision's predict functions for this
# model; `decode` pre-binds the device list and the EOS token id so callers
# only pass params, batch and sampling options.
decode_fn  =   predict_fns.get_all(model)['decode']
decode     =   functools.partial(decode_fn, devices=jax.devices(), eos_token=tokenizer.eos_id())

<div style="background-color:yellow; text-align:center; padding:40px; font-family:sans-serif;">
  <h3 style="color:red;"> Prepare for Partial Fine-Tuning :  </h3>

  <h5 style="color:red;">  Defines a rule that only attention layers in the LLM ("llm/layers/attn/…") are trainable; everything else stays frozen.
<br>
Creates a parallel "mask" PyTree of booleans marking which sub-trees to update.
</h5>
</div> 

In [161]:
# To keep HBM usage low and fit in a T4 GPU (16GB HBM) we opt to only finetune
# a part of the parameters. Additionally we keep the frozen params in float16
# and cast trainable to float32.

# Create a pytree mask of the trainable params.
def is_trainable_param(name, param):  # pylint: disable=unused-argument
  """Decide whether the parameter called `name` is fine-tuned.

  Args:
    name: slash-separated parameter path, e.g. "llm/layers/attn/...".
    param: the parameter value (unused; present to match the mapping callback
      signature).

  Returns:
    True for parameters under "llm/layers/attn/", False for every other
    parameter under "llm/" or "img/".

  Raises:
    ValueError: if `name` is under neither "llm/" nor "img/".
  """
  # Anything outside the two known sub-models is a sign the checkpoint layout
  # changed, so fail loudly instead of silently freezing it.
  if not name.startswith(("llm/", "img/")):
    raise ValueError(f"Unexpected param name {name}")
  return name.startswith("llm/layers/attn/")
# Apply is_trainable_param to every named parameter, producing a mask tree that
# is later walked leaf-by-leaf alongside `params`.
trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)


<div style="background-color:yellow; text-align:center; padding:40px; font-family:sans-serif;">
  <h3 style="color:red;"> Sharding & Casting Parameters :  </h3>

  <h5 style="color:red;">  Sharding: split tensors across devices (if you had >1 GPU).
<br>
                        maybe_cast_to_f32: keep the frozen weights in fp16 to save memory; cast the few trainable ones to fp32 so their gradients remain stable.
<br>
                        The loop unpacks the parameter tree, reshards & casts each leaf, and reassembles it.
</h5>
</div>


In [162]:
# If more than one device is available (e.g. multiple GPUs) the parameters can
# be sharded across them to reduce HBM usage per device.
# The mesh is one-dimensional: all devices along a single axis named "data".
# The axis names are given as a real one-element tuple; `("data")` without the
# trailing comma is just the string "data".
mesh = jax.sharding.Mesh(jax.devices(), ("data",))

In [163]:
# Sharding used for input batches: the partition spec names the mesh's "data"
# axis, so a batch is split along its leading dimension across that axis.
data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("data"))

In [164]:
# Derive a sharding for every parameter. The single rule's '.*' pattern matches
# all parameter names and assigns them the 'fsdp(axis="data")' strategy on the
# mesh above.
params_sharding = big_vision.sharding.infer_sharding(
    params, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh)

In [165]:
# Silence the "Some donated buffers were not usable" warning; the jitted
# functions in this notebook donate their first argument (donate_argnums=(0,)).
warnings.filterwarnings(
    "ignore", message="Some donated buffers were not usable")

In [166]:
@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))
def maybe_cast_to_f32(params, trainable):
  """Cast parameters to float32 when trainable, otherwise to float16.

  Args:
    params: pytree of arrays; its buffers are donated to the computation.
    trainable: pytree of booleans with the same structure as `params`. It is a
      static (compile-time) argument, so it must be hashable.

  Returns:
    A pytree with the structure of `params` and the new dtypes.
  """
  def _cast_leaf(leaf, is_trainable):
    # Frozen weights go to float16 (not bf16), since some GPUs don't support
    # bf16.
    target_dtype = jnp.float32 if is_trainable else jnp.float16
    return leaf.astype(target_dtype)

  return jax.tree.map(_cast_leaf, params, trainable)

In [167]:
# Loading all params simultaneously - albeit much faster and more succinct -
# requires more RAM than the T4 colab runtimes have by default (12GB RAM).
# Instead we do it param by param.
#
# `params` is temporarily rebound to the flat list of leaves and only becomes a
# tree again at the final `unflatten`; if this cell is interrupted part-way,
# `params` is left as a list and the load cell has to be re-run.
params, treedef  = jax.tree.flatten(params)
sharding_leaves  = jax.tree.leaves(params_sharding)
trainable_leaves = jax.tree.leaves(trainable_mask)

# The three flattened sequences are walked in lockstep, which relies on
# params, params_sharding and trainable_mask sharing one tree structure.
for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):
  params[idx] = big_vision.utils.reshard(params[idx], sharding)  # place the leaf per its sharding
  params[idx] = maybe_cast_to_f32(params[idx], trainable)  # float32 if trainable, else float16
  params[idx].block_until_ready()  # finish this leaf before starting the next
params = jax.tree.unflatten(treedef, params)

In [168]:
# Print params to show what the model is made of.
def parameter_overview(params):
  """Print one line per parameter: its name path, shape and dtype.

  Args:
    params: parameter pytree; every leaf must expose `.shape` and `.dtype`.
  """
  named_leaves = big_vision.utils.tree_flatten_with_names(params)[0]
  for name, leaf in named_leaves:
    shape_text = str(leaf.shape)
    # Fixed-width columns keep the table aligned.
    print(f"{name:80s} {shape_text:22s} {leaf.dtype}")


In [169]:
# Print the name, shape and dtype of every parameter after sharding/casting.
print(" == Model params == ")
parameter_overview(params)

 == Model params == 
img/Transformer/encoder_norm/bias                                                (1152,)                float16
img/Transformer/encoder_norm/scale                                               (1152,)                float16
img/Transformer/encoderblock/LayerNorm_0/bias                                    (27, 1152)             float16
img/Transformer/encoderblock/LayerNorm_0/scale                                   (27, 1152)             float16
img/Transformer/encoderblock/LayerNorm_1/bias                                    (27, 1152)             float16
img/Transformer/encoderblock/LayerNorm_1/scale                                   (27, 1152)             float16
img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias                             (27, 4304)             float16
img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel                           (27, 1152, 4304)       float16
img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias                             (2

<div style="background-color:yellow; text-align:center; padding:40px; font-family:sans-serif;">
  <h3 style="color:red;"> Define preprocess functions to create inputs to the model :  </h3>
</div>

In [170]:
# @title Define preprocess functions to create inputs to the model.

def preprocess_image(image, size=224):
  """Convert an image into the square float array the model consumes.

  Model has been trained to handle images of different aspect ratios resized
  to `size` x `size` in the range [-1, 1]. Bilinear and antialias resize
  options are helpful to improve quality in some tasks.

  Args:
    image: anything `np.asarray` accepts; 2-D input is treated as greyscale,
      otherwise the last axis holds the channels.
    size: side length of the square output.

  Returns:
    The resized image with values mapped from [0, 255] to [-1, 1].
  """
  pixels = np.asarray(image)
  if pixels.ndim == 2:
    # Greyscale without a channel axis: replicate it into three channels.
    pixels = np.stack([pixels, pixels, pixels], axis=-1)
  pixels = pixels[..., :3]  # Remove alpha layer.
  assert pixels.shape[-1] == 3

  resized = tf.image.resize(
      tf.constant(pixels), (size, size), method='bilinear', antialias=True)
  return resized.numpy() / 127.5 - 1.0  # [0, 255]->[-1,1]

def preprocess_tokens(prefix, suffix=None, seqlen=None):
  """Tokenize a prefix (and optional suffix) and build the matching masks.

  Model has been trained to handle tokenized text composed of a prefix with
  full attention and a suffix with causal attention.

  Args:
    prefix: prompt text; encoded with BOS and followed by a "\n" separator.
    suffix: optional target text; encoded with EOS. Skipped when falsy.
    seqlen: optional fixed length; sequences are truncated or zero-padded to
      it. Skipped when falsy.

  Returns:
    A tuple (tokens, mask_ar, mask_loss, mask_input):
      tokens     - token ids (0 where padded),
      mask_ar    - 0 for prefix tokens (full attention), 1 for suffix tokens
                   (causal attention),
      mask_loss  - 1 for suffix tokens that count towards the loss,
      mask_input - 1 for real tokens, 0 for padding.
    The tuple goes through `jax.tree.map(np.array, ...)`. Python lists are
    pytree containers, so `np.array` is applied to each element and every item
    stays a list; the iterators in this notebook wrap them with `np.asarray`.
  """
  separator = "\n"
  prefix_ids = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)
  suffix_ids = tokenizer.encode(suffix, add_eos=True) if suffix else []

  tokens = prefix_ids + suffix_ids
  mask_ar = [0] * len(prefix_ids) + [1] * len(suffix_ids)
  mask_loss = list(mask_ar)        # loss covers exactly the suffix tokens.
  mask_input = [1] * len(tokens)   # 1 if its a token, 0 if padding.

  if seqlen:
    pad = [0] * max(0, seqlen - len(tokens))
    tokens, mask_ar, mask_loss, mask_input = [
        seq[:seqlen] + pad
        for seq in (tokens, mask_ar, mask_loss, mask_input)
    ]

  return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))

def postprocess_tokens(tokens):
  """Turn a generated token array back into text.

  Args:
    tokens: array of token ids exposing `.tolist()`.

  Returns:
    The decoded string, covering only the tokens before the first EOS (or all
    tokens when no EOS is present).
  """
  ids = tokens.tolist()  # np.array to list[int]
  eos = tokenizer.eos_id()
  if eos in ids:
    # Remove tokens at and after EOS.
    ids = ids[:ids.index(eos)]
  return tokenizer.decode(ids)


<div style="background-color:yellow; text-align:center; padding:40px; font-family:sans-serif;">
  <h3 style="color:red;"> Import Data  :  </h3>
</div>

In [40]:
import os
import tensorflow as tf
import pandas as pd
import numpy as np

In [41]:
# List all datasets under /kaggle/input: every directory, followed by the
# files it directly contains.
# NOTE(review): this prints one line per file - for the image folder that is a
# very long listing; consider printing only a count.
for dirname, _, filenames in os.walk('/kaggle/input'):
    print(f"Directory: {dirname}")
    for filename in filenames:
        print(f"  File: {filename}")


Directory: /kaggle/input
Directory: /kaggle/input/paligemma
Directory: /kaggle/input/paligemma/jax
Directory: /kaggle/input/paligemma/jax/paligemma-3b-pt-224
Directory: /kaggle/input/paligemma/jax/paligemma-3b-pt-224/1
  File: paligemma-3b-pt-224.npz
  File: paligemma-3b-pt-224.bf16.npz
  File: paligemma-3b-pt-224.f16.npz
Directory: /kaggle/input/rsics-dataset
  File: captions.csv
Directory: /kaggle/input/rsics-dataset/resized
  File: NWPU_11772.jpg
  File: NWPU_7525.jpg
  File: NWPU_23522.jpg
  File: RSICD_4232.jpg
  File: NWPU_28090.jpg
  File: RSICD_4344.jpg
  File: NWPU_1459.jpg
  File: NWPU_24378.jpg
  File: NWPU_25939.jpg
  File: RSICD_8878.jpg
  File: NWPU_1430.jpg
  File: NWPU_11144.jpg
  File: NWPU_19837.jpg
  File: NWPU_5500.jpg
  File: RSICD_5489.jpg
  File: NWPU_26670.jpg
  File: NWPU_1229.jpg
  File: RSICD_8170.jpg
  File: NWPU_20050.jpg
  File: NWPU_14575.jpg
  File: NWPU_2838.jpg
  File: NWPU_11933.jpg
  File: RSICD_6660.jpg
  File: NWPU_30918.jpg
  File: NWPU_23405.jpg


In [42]:
# Peek at the checkpoint: open the .npz archive and list its first array names.
# NOTE(review): this opens `paligemma-3b-pt-224.npz` through a hard-coded
# absolute path, not the `.f16.npz` file MODEL_PATH refers to, and the object
# returned by np.load is never closed.
weights = np.load('/kaggle/input/paligemma/jax/paligemma-3b-pt-224/1/paligemma-3b-pt-224.npz')
print(weights.files[:5])  # print first 5 weight names


['params/img/Transformer/encoder_norm/bias', 'params/img/Transformer/encoder_norm/scale', 'params/img/Transformer/encoderblock/LayerNorm_0/bias', 'params/img/Transformer/encoderblock/LayerNorm_0/scale', 'params/img/Transformer/encoderblock/LayerNorm_1/bias']


In [43]:
# Load the caption table. Later cells read its `split`, `image` and
# `caption_1` .. `caption_5` columns.
captions_df = pd.read_csv('/kaggle/input/rsics-dataset/captions.csv')
print(captions_df.head()) 

  source split           image  \
0   NWPU  test  NWPU_31430.jpg   
1   NWPU  test  NWPU_31431.jpg   
2   NWPU  test  NWPU_31432.jpg   
3   NWPU  test  NWPU_31433.jpg   
4   NWPU  test  NWPU_31434.jpg   

                                           caption_1  \
0   A gray plane on the runway and the lawn beside .   
1  Three small planes parked in a line on the air...   
2  A plane parked in a line on the airport with s...   
3  A small plane and a big plane parked next to b...   
4       Two planes parked next to boarding bridges .   

                                           caption_2  \
0        A grey plane is on the runway by the lawn .   
1  There are four aircraft on the open ground, Th...   
2  A white plane was parked on the instruction li...   
3  A white plane and a gray plane parked at the b...   
4  Two aircraft were parked at the departure gates .   

                                           caption_3  \
0  There is an airplane on the runway with a larg...   
1  There 

In [44]:

# Directory that the file names in the `image` column are joined onto.
IMAGE_ROOT = "/kaggle/input/rsics-dataset/resized" 
# Batch size make_dataset falls back on when it builds a pipeline.
# NOTE(review): the training cell rebinds BATCH_SIZE to 8, and make_dataset
# looks the global up when it is called, not when it is defined.
BATCH_SIZE = 32
# Target size handed to tf.image.resize in _load_and_preprocess.
IMG_SIZE   = (224, 224) 

In [45]:
# Show which split labels exist and how many rows each one has.
print("All split labels:", captions_df['split'].unique())
print("Counts:\n", captions_df['split'].value_counts())

All split labels: ['test' 'val' 'train']
Counts:
 split
train    35614
test      4454
val       4453
Name: count, dtype: int64


In [46]:
# 4. Filter into splits
# `splits` maps each split name to the rows of captions_df whose `split`
# column equals that name.
splits = {}
for split_name in ['train', 'val', 'test']:
    splits[split_name] = captions_df[captions_df['split'] == split_name]

In [47]:
# 5. Convert each split-DataFrame into (paths, captions)
def df_to_paths_and_captions(split_df):
    """Turn one split's rows into parallel lists of image paths and captions.

    Args:
        split_df: DataFrame with an `image` column (file names) and the
            columns `caption_1` .. `caption_5`.

    Returns:
        (paths, captions): `paths[i]` is IMAGE_ROOT joined with the i-th file
        name, `captions[i]` is the list of that row's five caption values.
    """
    caption_columns = ['caption_%d' % number for number in (1, 2, 3, 4, 5)]
    paths = [os.path.join(IMAGE_ROOT, file_name) for file_name in split_df['image']]
    captions = split_df[caption_columns].values.tolist()
    return paths, captions

In [48]:
# Image paths and per-image caption lists for each of the three splits.
train_paths, train_caps = df_to_paths_and_captions(splits['train'])
val_paths,   val_caps   = df_to_paths_and_captions(splits['val'])
test_paths,  test_caps  = df_to_paths_and_captions(splits['test']) 

In [171]:
# 6. Preprocessing fn: load image + return captions list
def _load_and_preprocess(path, captions):
    """Load one image file and scale it to [0, 1]; captions pass through.

    Args:
        path: file path of the image.
        captions: the captions belonging to that image (returned unchanged).

    Returns:
        (image, captions) where image is a float32 tensor resized to IMG_SIZE
        with values divided by 255.
    """
    raw_bytes = tf.io.read_file(path)
    # channels=3 forces RGB; expand_animations=False asks for a single frame.
    decoded = tf.image.decode_image(raw_bytes, channels=3, expand_animations=False)
    resized = tf.image.resize(decoded, IMG_SIZE)
    return tf.cast(resized, tf.float32) / 255.0, captions

# 7. Build the tf.data pipeline
def make_dataset(paths, captions, shuffle=False, batch_size=None):
    """Build a batched, prefetched tf.data pipeline of (image, captions).

    Args:
        paths: list of image file paths.
        captions: list with one list of caption strings per image, all of the
            same length.
        shuffle: when True, shuffle with a buffer covering the whole split.
        batch_size: number of examples per batch. Defaults to the notebook
            global BATCH_SIZE, looked up at call time. Pass it explicitly to
            avoid depending on whichever cell last rebound that global.

    Returns:
        A tf.data.Dataset whose elements are (image_batch, captions_batch),
        with images produced by _load_and_preprocess.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE

    # turn the Python list-of-strings into a tf.string tensor
    paths_ds = tf.data.Dataset.from_tensor_slices(tf.constant(paths, dtype=tf.string))
    # turn the list-of-lists-of-strings into a 2-D tf.string tensor
    caps_ds  = tf.data.Dataset.from_tensor_slices(tf.constant(captions, dtype=tf.string))

    ds = tf.data.Dataset.zip((paths_ds, caps_ds))
    if shuffle:
        # Shuffle paths/captions before images are decoded, so the buffer only
        # holds strings.
        ds = ds.shuffle(buffer_size=len(paths))
    return (
        ds.map(_load_and_preprocess, num_parallel_calls=tf.data.AUTOTUNE)
          .batch(batch_size)
          .prefetch(tf.data.AUTOTUNE)
    )

In [221]:
# Build the three pipelines. Only the test split is left unshuffled.
# NOTE(review): the validation split is shuffled as well, so which examples
# validation_data_iterator yields first is not fixed - confirm this is intended.
train_ds = make_dataset(train_paths, train_caps, shuffle=True)
val_ds   = make_dataset(val_paths,   val_caps,   shuffle=True) 
test_ds  = make_dataset(test_paths,  test_caps,  shuffle=False)

In [222]:
# 8. Quick sanity check: pull a single batch from the test pipeline and show
# its image shape, the number of examples and the first example's captions.
for imgs, caps in test_ds.take(1):
    print("Images batch shape:", imgs.shape)            # (BATCH_SIZE, H, W, 3)
    print("Captions batch shape:", len(caps), "examples")
    print("First example captions:", caps[0])            # string tensor holding the 5 captions

Images batch shape: (8, 224, 224, 3)
Captions batch shape: 8 examples
First example captions: tf.Tensor(
[b'A gray plane on the runway and the lawn beside .'
 b'A grey plane is on the runway by the lawn .'
 b'There is an airplane on the runway with a large lawn by the runway .'
 b'A plane is parked on the runway next to the grass .'
 b'There is a plane on the runway beside the grass .'], shape=(5,), dtype=string)


In [63]:
# Scratch check: np.random.randint(5) draws integers from 0 to 4, i.e. valid
# indices into the five captions of an example.
for i in range(10):
    
    print(np.random.randint(5) )

3
3
3
0
0
1
4
4
1
1


In [223]:
def train_data_iterator():
    """Never-ending iterator over training examples from train_ds.

    Walks `train_ds` batch by batch, forever, and yields one example at a
    time. For every image a caption is drawn uniformly at random from that
    image's captions, lower-cased and used as the target suffix of the
    "caption en" prompt.

    Reads the notebook globals `train_ds`, `SEQLEN` and `preprocess_tokens`
    when it runs.

    Yields:
        A two-element list `[example, captions]`:
          example  - dict with "image" (values in [-1, 1]) and int32 arrays
                     "text", "mask_ar" and "mask_loss" of length SEQLEN,
          captions - all reference captions of the image as Python strings.
    """
    prefix = "caption en"
    while True:
        for image_batch, caps_batch in train_ds:
            images_np = image_batch.numpy()      # [B,H,W,3]
            caps_np   = caps_batch.numpy()       # [B,num_captions] bytes

            for img, caps in zip(images_np, caps_np):
                # Decode all captions into Python strings
                decoded_caps = [c.decode('utf-8') for c in caps]

                # Pick the target caption at random. The bound follows the
                # number of captions actually present instead of assuming 5.
                chosen = np.random.randint(len(decoded_caps))
                suffix = decoded_caps[chosen].lower()

                tokens, mask_ar, mask_loss, _ = preprocess_tokens(
                    prefix, suffix, SEQLEN
                )

                # The tf.data pipeline delivers pixels in [0, 1]; the model
                # expects [-1, 1].
                img_proc = img * 2.0 - 1.0
                yield [
                    {
                        "image":     img_proc,
                        "text":      np.asarray(tokens,    dtype=np.int32),
                        "mask_ar":   np.asarray(mask_ar,   dtype=np.int32),
                        "mask_loss": np.asarray(mask_loss, dtype=np.int32),
                    },
                    decoded_caps,
                ]

#----------------------------------------------------------------------------------------------------------------#
#                                              Validation :
#----------------------------------------------------------------------------------------------------------------#
def validation_data_iterator():
    """Single-pass iterator over validation examples from val_ds.

    Only the "caption en" prompt is tokenized. No target caption is attached,
    because at validation time the model generates the caption itself.
    Nothing random is drawn here, so running validation does not disturb the
    NumPy random stream used for training.

    Reads the notebook globals `val_ds`, `SEQLEN` and `preprocess_tokens`
    when it runs.

    Yields:
        A two-element list `[example, captions]`:
          example  - dict with "image" (values in [-1, 1]) and int32 arrays
                     "text", "mask_input" and "mask_ar" of length SEQLEN,
          captions - all reference captions of the image as Python strings.
    """
    prefix = "caption en"
    for image_batch, caps_batch in val_ds:
        images_np = image_batch.numpy()
        caps_np   = caps_batch.numpy()

        for img, caps in zip(images_np, caps_np):
            decoded_caps = [c.decode('utf-8') for c in caps]

            tokens, mask_ar, _, mask_input = preprocess_tokens(
                prefix, seqlen=SEQLEN)

            # The tf.data pipeline delivers pixels in [0, 1]; the model
            # expects [-1, 1].
            img_proc = img * 2. - 1.
            yield [
                {
                    "image":      img_proc,
                    "text":       np.asarray(tokens,     dtype=np.int32),
                    "mask_input": np.asarray(mask_input, dtype=np.int32),
                    "mask_ar":    np.asarray(mask_ar,    dtype=np.int32),
                },
                decoded_caps,
            ]


<div style="background-color:yellow; text-align:center; padding:40px; font-family:sans-serif;">
  <h3 style="color:red;"> Check Iterators :  </h3>
</div>

In [194]:
# Pull a single example from a fresh validation iterator for inspection.
# NOTE(review): validation_data_iterator reads SEQLEN, which this notebook only
# assigns in a later cell; on a fresh top-to-bottom run this cell therefore
# fails with NameError unless SEQLEN is defined earlier.
val_it = validation_data_iterator()
       
for _ in range(1):
    ex   = next(val_it) 

In [195]:
# Shapes of the image and of the three token-level arrays of that example.
ex[0]["image"].shape,   ex[0]["text"].shape,  ex[0]["mask_input"].shape,   ex[0]["mask_ar"].shape

((224, 224, 3), (128,), (128,), (128,))

In [196]:
# The reference captions yielded alongside that example.
ex[1]

['Two planes parked next to boarding bridges and another passing plane beside .',
 'Two aircraft were parked at the departure gates and one was on the runway .',
 'Several planes of the same size parked neatly next to the buildings inside the airport .',
 'Three planes are parked on the open space next to the terminal .',
 'Three planes are in the parking lot next to the building .']

In [197]:
# @title Inspect training examples.
def render_inline(image, resize=(128, 128)):
  """Convert image into inline html.

  Args:
    image: uint8 image array accepted by `PIL.Image.fromarray`.
    resize: (width, height) the image is scaled to before encoding.

  Returns:
    A `data:image/jpeg;base64,...` URI holding the resized image as JPEG.
  """
  # `Image.resize` returns a new image and leaves the original untouched, so
  # its result has to be kept; otherwise the full-size image gets embedded.
  image = Image.fromarray(image).resize(resize)
  with io.BytesIO() as buffer:
    image.save(buffer, format='jpeg')
    image_b64 = str(base64.b64encode(buffer.getvalue()), "utf-8")
  return f"data:image/jpeg;base64,{image_b64}"

def render_example(image, caption):
  """Build the HTML snippet showing one image next to its caption.

  Args:
    image: float image array with values in [-1, 1].
    caption: caption text; HTML-escaped before it is embedded.

  Returns:
    An HTML string with an <img> (inline data URI from render_inline, shown
    at 128x128 px) followed by the caption paragraph.
  """
  # [-1,1] -> [0, 255]
  pixels = ((image + 1)/2 * 255).astype(np.uint8)
  thumbnail_src = render_inline(pixels, resize=(64,64))
  safe_caption = html.escape(caption)
  return f"""
    <div style="display: inline-flex; align-items: center; justify-content: center;">
        <img style="width:128px; height:128px;" src="{thumbnail_src}" />
        <p style="width:256px; margin:10px; font-size:small;">{safe_caption}</p>
    </div>
    """


In [198]:
# Fixed token-sequence length: the data iterators pad/truncate every example to
# it, and make_predictions uses it as its default `seqlen`.
# NOTE(review): the iterators are already called in earlier cells and
# make_predictions evaluates SEQLEN when it is defined, so this assignment
# belongs near the top of the notebook, before any of them.
SEQLEN   =   128

In [206]:
import html

In [208]:
# Render the first four training examples as HTML. zip(range(4), ...) stops the
# otherwise endless training iterator after four items.
html_out = ""
for idx, example in zip(range(4), train_data_iterator()): 
    img = example[0]["image"] 
    
    # example[1] holds all reference captions; show the first one.
    html_out += render_example(img, example[1][0]) 
    # optionally a divider between images:
    html_out += "<hr style='width:100%;'>"

display(HTML("<h3>Training examples </h3>" + html_out))


<div style="background-color:yellow; text-align:center; padding:40px; font-family:sans-serif;">
  <h3 style="color:red;"> Define the training step and evaluation loop :  </h3>
</div>

In [224]:
# @title Define the training step and evaluation loop.
#
# The main update_fn using simple SGD.
#
@functools.partial(jax.jit, donate_argnums=(0,))
def update_fn(params, batch, learning_rate):
  """Run one SGD step on the trainable parameters.

  The function is jitted and the buffers of `params` are donated, so the
  caller must use the returned params and not the ones passed in. `model` and
  `trainable_mask` are read from the notebook scope.

  Args:
    params: parameter tree of the model.
    batch: dict with the entries "image", "text", "mask_ar" and "mask_loss".
    learning_rate: step size applied to the gradients.

  Returns:
    (params, loss): the updated parameter tree and the batch loss.
  """
  imgs, txts, mask_ar = batch["image"], batch["text"], batch["mask_ar"]

  def loss_fn(params):
    text_logits, _ = model.apply({"params": params}, imgs, txts[:, :-1], mask_ar[:, :-1], train=True)
    logp = jax.nn.log_softmax(text_logits, axis=-1)

    # The model takes as input txts[:, :-1] but the loss is defined as predicting
    # next tokens txts[:, 1:]. Additionally, mask_loss[:, 1:] indicates which tokens
    # are part of the loss (e.g. prefix and padded tokens are not included).
    mask_loss = batch["mask_loss"][:, 1:]
    targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])

    # Compute the loss per example, i.e. the mean negative log-likelihood of its
    # loss tokens. Since each example has a different number of tokens we
    # normalize it. Despite its name, `token_pplx` holds the log-probability of
    # each target token (the one-hot product selects it), not a perplexity.
    token_pplx = jnp.sum(logp * targets, axis=-1)  # sum across vocab_size.
    example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1)  # sum across seq_len.
    # Clipping the count to at least 1 avoids dividing by zero for an example
    # without any loss token.
    example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1)  # weight by num of tokens.

    # batch_loss: mean of per example loss.
    return jnp.mean(example_loss)

  loss, grads = jax.value_and_grad(loss_fn)(params)

  # Apply gradients to trainable params using SGD. Parameters whose mask entry
  # is False are returned unchanged.
  def apply_grad(param, gradient, trainable):
    if not trainable: return param
    return param - learning_rate * gradient

  params = jax.tree_util.tree_map(apply_grad, params, grads, trainable_mask)

  return params, loss
#---------------------------------------------------------------------------------------------------------#
#                                     Evaluation/inference loop.
#---------------------------------------------------------------------------------------------------------#
def make_predictions(data_iterator, *, num_examples=None,
                     batch_size=4, seqlen=SEQLEN, sampler="greedy"):
  """Generate captions for the examples produced by `data_iterator`.

  Examples are pulled in groups of `batch_size`. A final incomplete group is
  padded by repeating its last example, and the padded entries are dropped
  again before results are collected. `decode`, `params`, `data_sharding` and
  `postprocess_tokens` are read from the notebook scope.

  Args:
    data_iterator: iterator yielding `(example, captions)` pairs, where
      `example` is a dict of arrays and `captions` the reference captions.
      The example dicts are not modified.
    num_examples: stop once this many predictions are collected. When None
      (or 0) the iterator is consumed until it is exhausted.
    batch_size: number of examples decoded together.
    seqlen: maximum number of tokens to decode. The default is the value
      SEQLEN had when this function was defined.
    sampler: sampler name forwarded to `decode`.

  Returns:
    A list of `(image, predicted_text, captions)` tuples, in iterator order.
  """
  outputs = []
  while True:
    # Construct a list of examples in the batch.
    examples = []
    captions = []
    try:
      for _ in range(batch_size):
        example, caps = next(data_iterator)
        # Work on a shallow copy so the bookkeeping "_mask" entry does not leak
        # into the dict owned by the caller.
        example = dict(example)
        example["_mask"] = np.array(True)
        examples.append(example)
        captions.append(caps)
    except StopIteration:
      if len(examples) == 0:
        return outputs

    # Not enough examples to complete a batch. Pad by repeating last example.
    while len(examples) % batch_size:
      examples.append(dict(examples[-1]))
      examples[-1]["_mask"] = np.array(False)  # Indicates padding example.
      captions.append(captions[-1])  # pad the caption list accordingly.

    # Convert list of examples into a dict of np.arrays and load onto devices.
    batch = jax.tree.map(lambda *x: np.stack(x), *examples)
    batch = big_vision.utils.reshard(batch, data_sharding)

    # Make model predictions
    tokens = decode({"params": params}, batch=batch,
                    max_decode_len=seqlen, sampler=sampler)

    # Fetch model predictions from the devices to the host and detokenize.
    tokens, mask = jax.device_get((tokens, batch["_mask"]))
    tokens = tokens[mask]  # remove padding examples.
    responses = [postprocess_tokens(t) for t in tokens]

    # Only include non-padded outputs. Padding is always appended at the end,
    # so the first len(responses) examples line up with the responses.
    for example, response, caps in zip(examples, responses, captions):
      if example["_mask"]:
        outputs.append((example["image"], response, caps))
        if num_examples and len(outputs) >= num_examples:
          return outputs

<div style="background-color:yellow; text-align:center; padding:40px; font-family:sans-serif;">
  <h3 style="color:red;"> Check the predictions :  </h3>
</div>

In [225]:
# Take the first example of a fresh validation iterator together with its
# reference captions.
(example, caps) = next(validation_data_iterator()) 

In [226]:
# Caption that single example: wrap it in a one-element iterator, decode with
# batch size 1 and unpack the only result into image, prediction and references.
img, pred, ref = make_predictions(iter([(example, caps)]),
                                  num_examples=1,
                                  batch_size=1)[0]

In [228]:
# Compare the generated caption with the reference captions.
print("PRED:", pred)
print("REF:", ref)

PRED: a bridge is on the river .
REF: ['the two bridges with gridding steel beams joins the back covered by trees .', 'two truss bridges cross the river parallel', 'a bridge with iron trestle links the forests .', 'two parallel tied arch bridge traverse the wide river with rows of trees on its banks', 'some green trees are in two sides of a river with two parallel bridges .']


In [None]:
# Same steps as the two sampling/prediction cells above, in a single cell:
# draw one validation example and caption it.
# NOTE(review): this cell has no execution count and repeats earlier cells -
# consider removing it.
(example, caps) = next(validation_data_iterator()) 

img, pred, ref = make_predictions(iter([(example, caps)]),
                                  num_examples=1,
                                  batch_size=1)[0] 

<div style="background-color:yellow; text-align:center; padding:40px; font-family:sans-serif;">
  <h3 style="color:red;"> WANDB :  </h3>
</div>

In [97]:
import wandb

In [None]:
# Weights & Biases API key: taken from the environment when already set,
# otherwise from the Colab secret store. The key must never be pasted into
# the cell, because the cell source is saved and shared with the notebook.
if not os.environ.get("WANDB_API_KEY"):
  os.environ["WANDB_API_KEY"] = userdata.get("WANDB_API_KEY")

# Start the W&B run for this project; init_timeout=500 raises the time limit
# wandb.init is given to finish.
run  =  wandb.init(project="DI_725_Project_Phase_2__2697258", entity="DI_725___Final_Project", settings=wandb.Settings(init_timeout=500)) 

<div style="background-color:yellow; text-align:center; padding:40px; font-family:sans-serif;">
  <h3 style="color:red;"> Train :  </h3>
</div>

In [101]:
import multiprocessing as mp
# Use the "spawn" start method for processes created from here on; force=True
# replaces a start method that has already been set.
# NOTE(review): presumably chosen so that child processes do not fork a process
# that has already initialised JAX/TensorFlow - confirm the reason.
mp.set_start_method("spawn", force=True)


In [102]:
# `evaluate` provides the BLEU / ROUGE / METEOR metrics loaded in the training cell.
# NOTE(review): unpinned install; consider moving it to the dependency cell at
# the top of the notebook and pinning a version.
!pip install evaluate



In [103]:
# NOTE(review): presumably the backend `evaluate` needs for its "rouge" metric -
# confirm; unpinned install, consider moving it to the dependency cell at the
# top of the notebook.
!pip install rouge_score

Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... [?25l[?25hdone
  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24935 sha256=d10ea00d016d84b7507aaed95646198e70124be76b630de163d3fae2551aa7d4
  Stored in directory: /root/.cache/pip/wheels/1e/19/43/8a442dc83660ca25e163e1bd1f89919284ab0d0c1475475148
Successfully built rouge_score
Installing collected packages: rouge_score
Successfully installed rouge_score-0.1.2


In [None]:
%%time
import io, base64
from PIL import Image
from IPython.display import HTML, display

import wandb
import pickle
import evaluate
import tqdm
#---------------------------------------------------------------------------------------------------------------------------#
# 1) Metrics — loaded once up front through the `evaluate` package
#    (pip install evaluate).
bleu_metric = evaluate.load("bleu")
rouge_metric = evaluate.load("rouge")
meteor_metric = evaluate.load("meteor")
#---------------------------------------------------------------------------------------------------------------------------#
# 2) Hyperparameters
BATCH_SIZE = 8
TRAIN_EXAMPLES = 35_614
LEARNING_RATE = 1e-4

# One pass over TRAIN_EXAMPLES in whole batches; the remainder is dropped.
TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE
# Fixed evaluation interval (alternative previously tried: max(1, TRAIN_STEPS // 100)).
EVAL_STEPS = 1000
#---------------------------------------------------------------------------------------------------------------------------#
# Open a fresh W&B run for this training session and record the settings above.
run = wandb.init(
    project="DI_725_Project_Phase_2__2697258",
    reinit=True,
    settings=wandb.Settings(init_timeout=500),
)

run.config.update(dict(
    model_name="PaliGemma on original data",
    batch_size=BATCH_SIZE,
    train_examples=TRAIN_EXAMPLES,
    learning_rate=LEARNING_RATE,
    train_steps=TRAIN_STEPS,
    eval_steps=EVAL_STEPS,
))
#---------------------------------------------------------------------------------------------------------------------------#
# 3) Data iterator & learning-rate schedule
train_data_it = train_data_iterator()

# Cosine schedule with base rate LEARNING_RATE and warmup_percent=0.10,
# defined over TRAIN_STEPS + 1 steps.
sched_fn = big_vision.utils.create_learning_rate_schedule(
    total_steps=TRAIN_STEPS + 1,
    base=LEARNING_RATE,
    decay_type="cosine",
    warmup_percent=0.10,
)
#---------------------------------------------------------------------------------------------------------------------------#
#                                           Training loop
#---------------------------------------------------------------------------------------------------------------------------#
# Block-local imports: itertools.islice bounds the validation pass, and
# html.escape keeps caption text from being parsed as markup in the preview.
import itertools
from html import escape as html_escape


def _to_uint8_image(img):
    """Rescale an image array via (img + 1) * 127.5 and cast it to uint8.

    Assumes `img` holds floats in [-1, 1] — TODO confirm against the
    preprocessing used by the data iterators.
    """
    return ((img + 1) * 127.5).astype("uint8")


for step in tqdm.tqdm(range(1, TRAIN_STEPS + 1)):
    # ——— train step ———
    # Draw BATCH_SIZE items from the training iterator; only element [0] of
    # each yielded item is used for the update.
    examples = [next(train_data_it)[0] for _ in range(BATCH_SIZE)]

    # Stack the per-example trees leaf-by-leaf into one batch, then place it
    # according to data_sharding.
    batch    = jax.tree.map(lambda *x: np.stack(x), *examples)
    batch    = big_vision.utils.reshard(batch, data_sharding)

    lr           = sched_fn(step)
    params, loss = update_fn(params, batch, lr)
    loss_val     = float(jax.device_get(loss))

    print(f"step: {step}/{TRAIN_STEPS}   lr: {lr:.8f}   loss: {loss_val:.4f}")
    wandb.log({"train/loss": loss_val, "train/lr": lr}, step=step)

#---------------------------------------------------------------------------------------------------------------------------#
#                                         ——— eval step ———
#---------------------------------------------------------------------------------------------------------------------------#
    # Evaluate on the first step, every EVAL_STEPS steps, and on the last step
    # so the parameters pickled after the loop always have metrics attached.
    if step == 1 or step % EVAL_STEPS == 0 or step == TRAIN_STEPS:
        print(f"→ Evaluation at step {step}")

        # A) Score the first `num_to_eval` validation items.
        #    islice streams them straight into make_predictions instead of
        #    first materialising every example in a list, and simply ends early
        #    (rather than raising StopIteration) if the iterator is shorter.
        num_to_eval = 800

        demo_outputs = make_predictions(
            itertools.islice(validation_data_iterator(), num_to_eval),
            num_examples = num_to_eval,
            batch_size   = BATCH_SIZE,
        )

        # demo_outputs is a list of (img, pred, ref) tuples.
        # Drop empty predictions and pair each hypothesis with its reference.
        # NOTE(review): every reference entry is wrapped as `[ref]`; if
        # make_predictions returns several captions per image, confirm this
        # nesting is what the metrics expect.
        scored   = [(pred, [ref]) for (_, pred, ref) in demo_outputs if pred.strip()]
        all_hyps = [hyp for hyp, _ in scored]
        all_refs = [refs for _, refs in scored]

#---------------------------------------------------------------------------------------------------------------------------#
#                                          compute & log BLEU/ROUGE/METEOR …
#---------------------------------------------------------------------------------------------------------------------------#
        # Only call the metrics if at least one non-empty prediction remains.
        if all_hyps:
            bleu_res = bleu_metric.compute(
                predictions=all_hyps,
                references=all_refs,
                smooth=True   # avoid zero‐div errors
            )
            rouge_res = rouge_metric.compute(
                predictions=all_hyps,
                references=all_refs
            )
            meteor_res = meteor_metric.compute(
                predictions=all_hyps,
                references=all_refs
            )
        else:
            # no valid hyps → set everything to zero
            bleu_res   = {"bleu": 0.0}
            rouge_res  = {"rouge1": 0.0, "rouge2": 0.0, "rougeL": 0.0}
            meteor_res = {"meteor": 0.0}

        print(f"   BLEU:   {bleu_res['bleu']:.3f}")
        print(f"   ROUGE1: {rouge_res['rouge1']:.3f}, ROUGE2: {rouge_res['rouge2']:.3f}, ROUGEL: {rouge_res['rougeL']:.3f}")
        print(f"   METEOR: {meteor_res['meteor']:.3f}")

        # No "eval/loss" entry: the only loss available here is the training
        # batch loss, which is already logged as "train/loss" at this step.
        wandb.log({
            "eval/bleu":    bleu_res["bleu"],
            "eval/rouge1":  rouge_res["rouge1"],
            "eval/rouge2":  rouge_res["rouge2"],
            "eval/rougeL":  rouge_res["rougeL"],
            "eval/meteor":  meteor_res["meteor"],
        }, step=step)

#---------------------------------------------------------------------------------------------------------------------------#
#                        ---- B) Demo‐sample pass via make_predictions ----
#---------------------------------------------------------------------------------------------------------------------------#
        demo_samples = make_predictions(
            validation_data_iterator(),
            num_examples=4,
            batch_size=4,
        )

        # now demo_samples is a list of (img, pred, ref)
#---------------------------------------------------------------------------------------------------------------------------#
#                         ---- Display side‐by‐side in Colab ----
#---------------------------------------------------------------------------------------------------------------------------#

        html = "<div style='display:flex; gap:16px;'>"
        for img, pred, ref in demo_samples:
            # Embed the image as a base64 PNG so the preview is self-contained.
            buf    = io.BytesIO()
            Image.fromarray(_to_uint8_image(img)).save(buf, format="PNG")
            b64    = base64.b64encode(buf.getvalue()).decode("utf-8")
            # Escape caption text so characters such as < or & render literally.
            html += f"""
              <div style='text-align:center;'>
                <img src="data:image/png;base64,{b64}" width="200"/><br/>
                <strong>Pred:</strong> {html_escape(str(pred))}<br/>
                <strong>Ref:</strong> {html_escape(str(ref))}
              </div>
            """
        html += "</div>"
        display(HTML(html))
#---------------------------------------------------------------------------------------------------------------------------#
        # ---- D) W&B table ----
#---------------------------------------------------------------------------------------------------------------------------#
        table = wandb.Table(columns=["image","predicted","reference"])
        for img, pred, ref in demo_samples:
            table.add_data(wandb.Image(_to_uint8_image(img)), pred, ref)
        wandb.log({"eval/samples_table": table}, step=step)
#---------------------------------------------------------------------------------------------------------------------------#
                            # 5) Save final params
#---------------------------------------------------------------------------------------------------------------------------#
# Serialise the final parameters to disk, hand the file to W&B, then close the run.
final_params_path = "params_final.pkl"
with open(final_params_path, "wb") as params_file:
    pickle.dump(params, params_file)

wandb.save(final_params_path)
wandb.finish()


[nltk_data] Downloading package wordnet to /usr/share/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /usr/share/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /usr/share/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
