# Finetuned PaliGemma SAE Training

### Import libraries and define SAE class

In [5]:
import os
import sys
import flax.serialization
import msgpack
# TPUs with
# if "COLAB_TPU_ADDR" in os.environ:
#   raise "It seems you are using Colab with remote TPUs which is not supported."

# Fetch big_vision repository if python doesn't know about it and install
# dependencies needed for this notebook.
if not os.path.exists("big_vision_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/google-research/big_vision big_vision_repo

# Append big_vision code to python import path
if "big_vision_repo" not in sys.path:
  sys.path.append("big_vision_repo")

# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"

# === Standard Library ===
import os
import gc
import copy
from threading import Thread
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor

# === Third-Party Libraries ===
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import tqdm
import requests
import wandb
import multiprocessing

# === PyTorch ===
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader, Subset, Dataset
from torch.cuda.amp import autocast
import torch

# === Transformers & Datasets ===
from transformers import (
    AutoModelForVision2Seq,
    AutoProcessor,
    PaliGemmaForConditionalGeneration,
)
from datasets import load_dataset

# === TorchVision ===
from torchvision.utils import make_grid

# === TensorFlow (if needed) ===
import tensorflow as tf

import os
import flax
import jax
import ml_collections
import sentencepiece  # Used in tokenization (not used yet)
from big_vision.models.proj.paligemma import paligemma
import multiprocessing as mp
from PIL import Image
import tensorflow_datasets as tfds

In [None]:
device_num = 4 # CUDA GPU number

DIM_IN = 2048
DIM_HIDDEN = 4096
LLM_VARIANT = "gemma2_2b"
VOCAB_SIZE = 257_152
IMG_ENCODER_VARIANT = "So400m/14"
CHECKPOINT_DIR = "/home/henrytsai/dhruv"
CHECKPOINT_STEP = 5  # Change this to load a different checkpoint
MODEL_PATH = "/home/henrytsai/.cache/kagglehub/models/google/paligemma-2/jax/paligemma2-3b-pt-896/1/./paligemma2-3b-pt-896.b16.npz"  # <-- You need to set this to your model's path


# === Custom Dataset ===
class PromptImageDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

class SparseAutoencoder(nn.Module):
    def __init__(self, d_in=DIM_IN, d_hidden=DIM_HIDDEN):
        super().__init__()
        self.encoder = nn.Linear(d_in, d_hidden)
        self.activation = nn.ReLU()
        self.decoder = nn.Linear(d_hidden, d_in)

    def forward(self, x):
        z = self.activation(self.encoder(x))
        x_recon = self.decoder(z)
        return x_recon, z

# === Checkpoint loading function ===
def load_checkpoint(template_params, step, save_dir=CHECKPOINT_DIR):
    """
    Load Flax model parameters from a msgpack checkpoint file.
    
    Args:
        template_params: parameter tree to match structure of saved checkpoint.
        step: training step number of the checkpoint (e.g., 1 for checkpoint_0001).
        save_dir: path to the directory containing checkpoint files.

    Returns:
        A PyTree of model parameters with trained weights loaded.
    """
    filename = f"checkpoint_{step:04d}.msgpack"
    path = os.path.join(save_dir, filename)
    
    if not os.path.exists(path):
        raise FileNotFoundError(f"Checkpoint file not found: {path}")
    
    with open(path, "rb") as f:
        raw_bytes = f.read()
    
    return flax.serialization.from_bytes(template_params, raw_bytes)

# === Model configuration ===
# Define architecture for language and image components
model_config = ml_collections.FrozenConfigDict({
    "llm": {
        "vocab_size": VOCAB_SIZE,
        "variant": LLM_VARIANT,
        "final_logits_softcap": 0.0
    },
    "img": {
        "variant": IMG_ENCODER_VARIANT,
        "pool_type": "none",
        "scan": True,
        "dtype_mm": "float16"  # Use half-precision for multimodal encoder
    }
})

dtype = torch.float16
if torch.cuda.is_available():
    torch.cuda.set_device(device_num)
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    dtype = torch.bfloat16

device = f"cuda:{device_num}" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda:4


### Load model (40sec)

In [6]:
# Load model structure (same config as used during training)
model = PaliGemmaForConditionalGeneration(model_config).to(device)

# Load weights from saved checkpoint (.pt file)
state_dict = torch.load("/home/henrytsai/dhruv/roboterp/finetuned_paligemma.pt", map_location=device)
model.load_state_dict(state_dict)
model.eval()


ValueError: Parameter config in `PaliGemmaForConditionalGeneration(config)` should be an instance of class `PretrainedConfig`. To create a model from a pretrained model use `model = PaliGemmaForConditionalGeneration.from_pretrained(PRETRAINED_MODEL_NAME)`

### Load dataset (30sec)

In [4]:
# --------- helpers you already wrote ----------
NUM_FRAMES = 6               # grid uses evenly spaced frames
FRAMES_PER_ROW = 3          

def preprocess_image(im: Image.Image) -> Image.Image:
    return im.resize((896, 896), Image.BILINEAR)   # <‑‑ used ONLY for display

def subsample_frames(images, num_samples):
    if len(images) <= num_samples:
        return images
    idx = np.linspace(0, len(images)-1, num=num_samples, dtype=int)
    return [images[i] for i in idx]

def stack_images_horizontally(imgs):
    w, h = zip(*(im.size for im in imgs))
    canvas = Image.new("RGB", (sum(w), max(h)))
    x = 0
    for im in imgs:
        canvas.paste(im, (x, 0))
        x += im.width
    return canvas

def stack_images_grid(imgs, frames_per_row=FRAMES_PER_ROW):
    rows = [imgs[i:i+frames_per_row] for i in range(0, len(imgs), frames_per_row)]
    rows = [stack_images_horizontally(r) for r in rows]
    as_np = np.vstack([np.asarray(r) for r in rows])
    return Image.fromarray(as_np)

def load_droid_subset(return_dict):
    ds = tfds.load("droid_100", split="train", data_dir="gs://gresearch/robotics")

    frames, prompts = [], []
    for episode in ds.take(100):                        # ←‑‑ keep your 10‑episode cap
        steps = list(episode["steps"])

        # ------- build the 8‑frame grid -------
        all_imgs = [Image.fromarray(step["observation"]["wrist_image_left"].numpy())
                    for step in steps]

        # pick evenly spaced frames
        sampled = subsample_frames(all_imgs, NUM_FRAMES)

        # build H×W grid (typically 4×2) and resize to 896x896
        grid_pil  = stack_images_grid(sampled)          # produced by helper above
        grid_pil  = grid_pil.resize((896, 896), Image.BILINEAR)

        frames.append(grid_pil)     # **store the PIL grid**
        prompts.append(steps[0]["language_instruction"].numpy().decode("utf-8"))

    return_dict["frames"]  = frames
    return_dict["prompts"] = prompts

manager = multiprocessing.Manager()
return_dict = manager.dict()
p = multiprocessing.Process(target=load_droid_subset, args=(return_dict,))
p.start()
p.join()

frames, prompts = return_dict["frames"], return_dict["prompts"]

2025-05-04 18:09:20.831322: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: CUDA_ERROR_NOT_INITIALIZED: initialization error
