# Finetuned PaliGemma SAE Training

### Import libraries and define helper functions and constants


In [16]:
from transformers import (
    AutoModelForVision2Seq,
    AutoProcessor,
    PaliGemmaForConditionalGeneration,
    PretrainedConfig
)
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor
from PIL import Image
from tqdm.notebook import tqdm
from torch import nn
import torch
import numpy as np
import tensorflow_datasets as tfds
import gc
import wandb

BATCH_SIZE = 2
CUDA_DEVICE_INDEX = 3
FINETUNE_CHECKPOINT_PATH = "/home/henrytsai/dhruv/roboterp/finetuned_paligemma.pt"

MODEL_ID = "google/paligemma2-3b-pt-896"
LAYER_IDX = -1  # index into the vision encoder's layer list; -1 hooks the last layer
NUM_FRAMES = 6  # frames sampled per episode and tiled into one grid image
FRAMES_PER_ROW = 3  # number of frames per grid row

DROID_DATASET_NAME = "droid_100"
DROID_DATASET_SPLIT = "train"
DROID_DATASET_GCS_DIR = "gs://gresearch/robotics"
NUM_EPISODES = 100
TARGET_IMAGE_SIZE = (896, 896)  # Width × Height
# Default SparseAutoencoder widths. The training cell derives d_in from the
# collected activations' last dimension and passes both sizes explicitly, so
# these defaults only apply when the SAE is built without arguments.
# NOTE(review): the hooked SigLIP layer printed below has width 1152, not
# 2048 -- confirm what DIM_IN is meant to match.
DIM_IN = 2048
DIM_HIDDEN = 4096

# === Dataset loading and preprocessing ===
def load_droid_subset(num_episodes=100):
    """Load DROID episodes as (grid image, instruction) pairs.

    For each of the first ``num_episodes`` episodes, the left wrist-camera
    frames are evenly subsampled to ``NUM_FRAMES`` images, tiled into a grid
    (``FRAMES_PER_ROW`` per row) and resized to ``TARGET_IMAGE_SIZE`` with
    bilinear resampling. The prompt is the language instruction stored on the
    episode's first step.

    Returns:
        Two parallel lists: the PIL grid images and the UTF-8 decoded
        instructions.
    """
    dataset = tfds.load(DROID_DATASET_NAME, split=DROID_DATASET_SPLIT, data_dir=DROID_DATASET_GCS_DIR)

    grids = []
    instructions = []
    for episode in dataset.take(num_episodes):
        episode_steps = list(episode["steps"])

        wrist_frames = [
            Image.fromarray(s["observation"]["wrist_image_left"].numpy())
            for s in episode_steps
        ]
        grid = stack_images_grid(subsample_frames(wrist_frames, NUM_FRAMES))
        grids.append(grid.resize(TARGET_IMAGE_SIZE, Image.BILINEAR))

        first_instruction = episode_steps[0]["language_instruction"].numpy()
        instructions.append(first_instruction.decode("utf-8"))

    return grids, instructions

# Written by ``vision_hook``; holds the most recent output of the hooked layer.
vision_acts = {}

def vision_hook(module, input, output):
    """Forward hook that stores ``output`` under ``vision_acts["activation"]``.

    ``module`` and ``input`` belong to the PyTorch forward-hook signature and
    are unused. Returns None, so the module's output is left unchanged.
    """
    vision_acts.update(activation=output)

def preprocess_image(im: Image.Image) -> Image.Image:
    """Return ``im`` resized to ``TARGET_IMAGE_SIZE`` with bilinear resampling.

    Intended for display only; ``load_droid_subset`` resizes its grids itself.
    """
    resized = im.resize(TARGET_IMAGE_SIZE, resample=Image.BILINEAR)
    return resized

def subsample_frames(images, num_samples):
    """Pick ``num_samples`` evenly spaced items from ``images``.

    Indices are spread across the whole sequence, with the first and last
    items included, and truncated to integers. When ``images`` has at most
    ``num_samples`` items, the same list object is returned unchanged.
    """
    total = len(images)
    if total > num_samples:
        picks = np.linspace(0, total - 1, num=num_samples, dtype=int)
        return [images[k] for k in picks]
    return images

def stack_images_horizontally(imgs):
    """Paste ``imgs`` left to right onto a new RGB canvas.

    The canvas is as wide as the summed image widths and as tall as the
    tallest image. Each image is top-aligned, and any uncovered area stays
    black. An empty ``imgs`` raises ValueError.
    """
    sizes = [im.size for im in imgs]
    canvas_w = sum(s[0] for s in sizes)
    canvas_h = max(s[1] for s in sizes)
    canvas = Image.new("RGB", (canvas_w, canvas_h))
    left = 0
    for im in imgs:
        canvas.paste(im, box=(left, 0))
        left += im.width
    return canvas

def stack_images_grid(imgs, frames_per_row=FRAMES_PER_ROW):
    """Tile ``imgs`` into a grid with ``frames_per_row`` images per row.

    Each row is built with ``stack_images_horizontally``, and rows are stacked
    top to bottom on a black RGB canvas. The canvas is as wide as the widest
    row. A short final row (when ``len(imgs)`` is not a multiple of
    ``frames_per_row``, e.g. an episode with fewer steps than requested
    frames) is therefore left-aligned and padded with black instead of
    crashing.

    Args:
        imgs: non-empty sequence of PIL images.
        frames_per_row: number of images per row; must be positive.

    Returns:
        A single RGB PIL image containing the grid.

    Raises:
        ValueError: if ``imgs`` is empty or ``frames_per_row`` is not positive.
    """
    if frames_per_row <= 0:
        raise ValueError(f"frames_per_row must be positive, got {frames_per_row}")
    if not imgs:
        raise ValueError("stack_images_grid needs at least one image")

    rows = [
        stack_images_horizontally(imgs[start:start + frames_per_row])
        for start in range(0, len(imgs), frames_per_row)
    ]
    # Rows may differ in width (short last row), so paste onto a canvas sized
    # to the widest row rather than requiring equal widths.
    grid = Image.new("RGB", (max(r.width for r in rows), sum(r.height for r in rows)))
    top = 0
    for row in rows:
        grid.paste(row, (0, top))
        top += row.height
    return grid

# === Dataset wrapper ===
class PromptImageDataset(Dataset):
    """Map-style dataset that pairs each image grid with its text prompt.

    ``frames`` and ``prompts`` are parallel sequences, and item ``i`` is
    ``(frames[i], prompts[i])``. The dataset length is taken from ``frames``.
    """

    def __init__(self, frames, prompts):
        self.frames, self.prompts = frames, prompts

    def __len__(self):
        return len(self.frames)

    def __getitem__(self, idx):
        frame = self.frames[idx]
        prompt = self.prompts[idx]
        return frame, prompt

# === Custom collate function ===
def collate_fn(batch):
    """Turn a list of ``(frame, prompt)`` pairs into ``(frames, prompts)``.

    Both outputs are plain Python lists, so the processor receives lists of
    images and strings unchanged rather than stacked tensors.
    """
    frames, prompts = map(list, zip(*batch))
    return frames, prompts

class SparseAutoencoder(nn.Module):
    """Single-hidden-layer ReLU autoencoder used as a sparse autoencoder.

    ``forward`` returns ``(reconstruction, code)``, where ``code`` is the
    non-negative ReLU output of the encoder. The submodule names
    ``encoder`` / ``activation`` / ``decoder`` are kept stable because
    checkpoints store ``sae.state_dict()``.

    Args:
        d_in: width of the input and reconstruction features.
        d_hidden: width of the sparse code.
    """

    def __init__(self, d_in=DIM_IN, d_hidden=DIM_HIDDEN):
        super().__init__()
        self.encoder = nn.Linear(d_in, d_hidden)
        self.activation = nn.ReLU()
        self.decoder = nn.Linear(d_hidden, d_in)

    def forward(self, x):
        pre_activation = self.encoder(x)
        code = self.activation(pre_activation)
        return self.decoder(code), code

# Choose the compute device. On CUDA, select CUDA_DEVICE_INDEX and release
# unused cached allocator memory and CUDA IPC memory before loading the model.
if torch.cuda.is_available():
    torch.cuda.set_device(CUDA_DEVICE_INDEX)
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    dtype = torch.bfloat16
    device = f"cuda:{CUDA_DEVICE_INDEX}"
else:
    dtype = torch.float16
    device = "cpu"
# NOTE: the model-loading and collection cells hard-code torch.float16
# rather than reading `dtype`.
print(f"Using device: {device}")

Using device: cuda:3


### Load finetuned PaliGemma model (~20 s)

In [3]:
# Same CUDA-or-CPU choice as the setup cell, so this cell does not hard-fail
# on a machine without the GPU.
device = torch.device(f"cuda:{CUDA_DEVICE_INDEX}" if torch.cuda.is_available() else "cpu")

model = PaliGemmaForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16
).to(device)

# Load the finetuned state dict on CPU. load_state_dict copies the values
# into the parameters already placed on `device`, so mapping the checkpoint
# to the GPU as well would only double peak GPU memory.
checkpoint = torch.load(FINETUNE_CHECKPOINT_PATH, map_location="cpu")
model.load_state_dict(checkpoint)
del checkpoint
gc.collect()

model.eval()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

PaliGemmaForConditionalGeneration(
  (vision_tower): SiglipVisionModel(
    (vision_model): SiglipVisionTransformer(
      (embeddings): SiglipVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
        (position_embedding): Embedding(4096, 1152)
      )
      (encoder): SiglipEncoder(
        (layers): ModuleList(
          (0-26): 27 x SiglipEncoderLayer(
            (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
            (self_attn): SiglipAttention(
              (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
            )
            (layer_norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
            (mlp): SiglipMLP(
         

### Load DROID dataset (~40 s)

In [4]:
# (grid image, instruction) pairs from the first NUM_EPISODES DROID episodes.
frames, prompts = load_droid_subset(NUM_EPISODES)

# Processor for MODEL_ID (fast implementation requested).
processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=True)

# Record the output of vision-encoder layer LAYER_IDX into `vision_acts`
# on every forward pass. Keep `hook` so the hook can be removed later.
hooked_layer = model.vision_tower.vision_model.encoder.layers[LAYER_IDX]
hook = hooked_layer.register_forward_hook(vision_hook)

I0000 00:00:1746388667.058563  673128 gpu_device.cc:2018] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 887 MB memory:  -> device: 0, name: NVIDIA A100-SXM4-40GB, pci bus id: 0000:00:04.0, compute capability: 8.0
I0000 00:00:1746388667.062103  673128 gpu_device.cc:2018] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 31496 MB memory:  -> device: 1, name: NVIDIA A100-SXM4-40GB, pci bus id: 0000:00:05.0, compute capability: 8.0
I0000 00:00:1746388667.064464  673128 gpu_device.cc:2018] Created device /job:localhost/replica:0/task:0/device:GPU:2 with 37640 MB memory:  -> device: 2, name: NVIDIA A100-SXM4-40GB, pci bus id: 0000:00:06.0, compute capability: 8.0
I0000 00:00:1746388667.066806  673128 gpu_device.cc:2018] Created device /job:localhost/replica:0/task:0/device:GPU:3 with 25012 MB memory:  -> device: 3, name: NVIDIA A100-SXM4-40GB, pci bus id: 0000:00:07.0, compute capability: 8.0
I0000 00:00:1746388667.069387  673128 gpu_device.cc:2018] Created 

### Collect activations

In [5]:
# === Build dataloader ===
dataset = PromptImageDataset(frames, prompts)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
all_activations = []
# Number of samples (images) in each batch, not a token count despite the
# name; the name is kept because it is saved as "token_counts.pt".
token_counts = []

with torch.no_grad():
    for batch_frames, batch_prompts in tqdm(dataloader):
        inputs = processor(
            text=[f"<image> {p}" for p in batch_prompts],
            images=batch_frames,
            return_tensors="pt",
            padding=True
        ).to(device, torch.float16)

        # Discard any earlier capture so that a forward pass which never
        # reaches the hooked layer fails with KeyError below instead of
        # silently reusing stale activations.
        vision_acts.pop("activation", None)
        _ = model(**inputs)

        # pop() also drops the dict's reference to the device tensor at once.
        raw_output = vision_acts.pop("activation")
        # The hooked layer may return a tuple; the hidden states are its first element.
        batch_acts = raw_output[0] if isinstance(raw_output, tuple) else raw_output
        all_activations.append(batch_acts.detach().cpu())
        token_counts.append(len(batch_prompts))

# === Combine all activations ===
activations = torch.cat(all_activations, dim=0)
print("Collected activations:", activations.shape)

torch.save(activations, "activations.pt")
# .half() returns the very same tensor when `activations` is already float16.
activations_fp16 = activations.half()
torch.save(activations_fp16, "activations_fp16.pt")
torch.save(token_counts, "token_counts.pt")

# Release every reference to the collected data. `activations_fp16` may alias
# `activations`, and the per-batch list holds the same storage again, so
# deleting only one name would not free the memory.
del activations, activations_fp16, all_activations
vision_acts.clear()
torch.cuda.empty_cache()
gc.collect()


  0%|          | 0/50 [00:00<?, ?it/s]

Collected activations: torch.Size([100, 4096, 1152])


1957

### Train SAE

In [None]:
all_activations = torch.load("/home/henrytsai/henry/roboterp/activations.pt")
token_counts = torch.load("/home/henrytsai/henry/roboterp/token_counts.pt")

# Train on individual token vectors. Flatten every leading dimension so each
# row is one d_model-wide activation. This assumes the saved tensor is
# (num_images, tokens_per_image, d_model), as printed by the collection cell;
# confirm for this path. Using whole images as examples would push
# batch_size * tokens_per_image rows through the 16x-wide hidden layer per step.
# .float() matches the SAE's float32 parameters: the activations come from a
# float16 model, and nn.Linear rejects mixed input/weight dtypes.
all_activations = all_activations.reshape(-1, all_activations.shape[-1]).float()

# === Hyperparameters ===
hidden_multiplier = 16
learning_rate = 1e-4
sparsity_weight = 5e-3
n_epochs = 20
batch_size = 64  # token vectors per optimisation step

# === Initialize Weights & Biases ===
wandb.init(project="finetuned-paligemma", name="sunday", config={
    "hidden_multiplier": hidden_multiplier,
    "learning_rate": learning_rate,
    "sparsity_weight": sparsity_weight,
    "n_epochs": n_epochs,
    "batch_size": batch_size,
})

# === Prepare SAE ===
d_in = all_activations.shape[-1]
d_hidden = hidden_multiplier * d_in

sae = SparseAutoencoder(d_in=d_in, d_hidden=d_hidden).to(device)
optimizer = torch.optim.AdamW(sae.parameters(), lr=learning_rate)
loss_fn = nn.MSELoss()

# === Dataset & Loader ===
train_dataset = torch.utils.data.TensorDataset(all_activations)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    # Page-locked batches let the non_blocking host-to-device copies in the
    # training loop actually overlap with compute.
    pin_memory=torch.cuda.is_available(),
)

In [None]:
import os
from datetime import datetime

epoch_losses = []

for epoch in range(n_epochs):
    epoch_loss = 0.0
    for batch, in train_loader:
        # Cast to the SAE's parameter dtype; stored activations may be float16.
        batch = batch.to(device, dtype=sae.encoder.weight.dtype, non_blocking=True)
        recon, z = sae(batch)

        # Reconstruction MSE plus an L1 penalty on the code to encourage sparsity.
        loss = loss_fn(recon, batch) + sparsity_weight * torch.mean(torch.abs(z))
        # Weight by batch size so the epoch average is per example, even when
        # the last batch is smaller.
        epoch_loss += loss.item() * batch.size(0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    avg_loss = epoch_loss / len(train_dataset)
    epoch_losses.append(avg_loss)

    # Log to wandb every epoch; epochs are reported 1-based everywhere.
    wandb.log({"loss": avg_loss, "epoch": epoch + 1})
    print(f"Epoch {epoch + 1}: Loss {avg_loss:.6f}")

print("Finished training Sparse Autoencoder!")

# === Save model with timestamp ===
save_dir = "checkpoints"
os.makedirs(save_dir, exist_ok=True)

timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
model_path = os.path.join(save_dir, f"sae_{timestamp}.pth")

# Store the layer sizes alongside the weights so the SAE can be rebuilt
# with matching shapes before load_state_dict.
torch.save({
    "state_dict": sae.state_dict(),
    "d_in": d_in,
    "d_hidden": d_hidden,
}, model_path)
wandb.save(model_path)

print(f"Saved SAE to {model_path}")

# === Plot loss curve (optional) ===
# NOTE(review): `plt` (matplotlib.pyplot) is not imported in any visible
# cell; confirm the kernel imports it before running this cell.
fig = plt.figure()
# x-axis follows the epochs actually recorded, not the planned n_epochs.
plt.plot(range(1, len(epoch_losses) + 1), epoch_losses, label="Train Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("SAE Training Loss")
plt.grid(True)
plt.tight_layout()
# Log the explicit figure before showing it: inline backends typically close
# the figure after display, so logging `plt` afterwards can capture a blank one.
wandb.log({"loss_curve": wandb.Image(fig)})
plt.show()