# Finetuned PaliGemma SAE Training

### Import libraries and define SAE class

In [4]:
# === Standard Library ===
import os
import gc
import copy
from threading import Thread
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor

# === Third-Party Libraries ===
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import tqdm
import requests
import wandb
import multiprocessing

# === PyTorch ===
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader, Subset, Dataset
from torch.cuda.amp import autocast
import torch

# === Transformers & Datasets ===
from transformers import (
    AutoModelForVision2Seq,
    AutoProcessor,
    PaliGemmaForConditionalGeneration,
)
from datasets import load_dataset

# === TorchVision ===
from torchvision.utils import make_grid

# === TensorFlow (if needed) ===
import tensorflow as tf

# Default layer widths for SparseAutoencoder: DIM_IN is the input (and
# reconstruction) size, DIM_HIDDEN is the size of the hidden code.
DIM_IN = 2048
DIM_HIDDEN = 4096

# === Custom Dataset ===
class PromptImageDataset(Dataset):
    """Thin ``Dataset`` adapter around an already-materialised, indexable collection.

    Items are handed back exactly as stored in ``data``; no transformation
    is applied here.
    """

    def __init__(self, data):
        """Keep a reference to ``data`` (anything supporting ``len`` and indexing)."""
        self.data = data

    def __len__(self):
        """Return the number of stored items."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``."""
        return self.data[idx]

class SparseAutoencoder(nn.Module):
    """Single-hidden-layer autoencoder: Linear -> ReLU -> Linear.

    The class only defines the encode/decode mapping; no sparsity penalty
    is computed here.

    Args:
        d_in: Size of the input and of the reconstruction.
        d_hidden: Size of the hidden code.
    """

    def __init__(self, d_in=DIM_IN, d_hidden=DIM_HIDDEN):
        super().__init__()
        # Creation order (encoder, then decoder) is kept as-is so that
        # seeded initialisation stays reproducible.
        self.encoder = nn.Linear(in_features=d_in, out_features=d_hidden)
        self.activation = nn.ReLU()
        self.decoder = nn.Linear(in_features=d_hidden, out_features=d_in)

    def forward(self, x):
        """Encode ``x`` and decode it again.

        Returns:
            A ``(reconstruction, code)`` tuple, where ``code`` is the
            ReLU-activated encoder output and ``reconstruction`` is the
            decoder applied to that code.
        """
        code = self.encoder(x)
        code = self.activation(code)
        return self.decoder(code), code

# Choose the compute device and the matching default precision:
# bfloat16 on the selected GPU when CUDA is present, float16 on CPU otherwise.
device_num = 7
if torch.cuda.is_available():
    torch.cuda.set_device(device_num)
    # Release cached allocator blocks and stale IPC handles before loading models.
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    dtype = torch.bfloat16
    device = f"cuda:{device_num}"
else:
    dtype = torch.float16
    device = "cpu"

print(f"Using device: {device}")

import os
import flax
import jax
import ml_collections
import sentencepiece  # Used in tokenization (not used yet)
from big_vision.models.proj.paligemma import paligemma

# === Constants ===
# Architecture identifiers and vocabulary size that feed model_config.
LLM_VARIANT = "gemma2_2b"
VOCAB_SIZE = 257_152
IMG_ENCODER_VARIANT = "So400m/14"
# Default directory in which load_checkpoint() looks for checkpoint_XXXX.msgpack files.
CHECKPOINT_DIR = "/home/henrytsai/dhruv"
CHECKPOINT_STEP = 5  # Change this to load a different checkpoint
# Machine-specific .npz path handed to paligemma.load().
MODEL_PATH = "/home/henrytsai/.cache/kagglehub/models/google/paligemma-2/jax/paligemma2-3b-pt-896/1/./paligemma2-3b-pt-896.b16.npz"  # <-- You need to set this to your model's path

# === Model configuration ===
# Immutable architecture settings, split into a language-model section ("llm")
# and an image-encoder section ("img").
model_config = ml_collections.FrozenConfigDict(
    dict(
        llm=dict(
            vocab_size=VOCAB_SIZE,
            variant=LLM_VARIANT,
            final_logits_softcap=0.0,
        ),
        img=dict(
            variant=IMG_ENCODER_VARIANT,
            pool_type="none",
            scan=True,
            dtype_mm="float16",  # Use half-precision for multimodal encoder
        ),
    )
)

# === Build the template parameter tree ===
# The result is used below as the structural template when deserialising a
# fine-tuned checkpoint.
# NOTE(review): the call is given MODEL_PATH, so it presumably reads the base
# checkpoint from that .npz rather than creating empty parameters — confirm
# against the big_vision paligemma.load implementation.
init_params = paligemma.load(None, MODEL_PATH, model_config)

# === Checkpoint loading function ===
def load_checkpoint(template_params, step, save_dir=CHECKPOINT_DIR):
    """
    Restore Flax model parameters from a msgpack checkpoint on disk.

    The file is expected at ``<save_dir>/checkpoint_<step>.msgpack`` with the
    step zero-padded to four digits (step 1 -> ``checkpoint_0001.msgpack``).

    Args:
        template_params: parameter tree whose structure the checkpoint is restored into.
        step: training step number identifying the checkpoint file.
        save_dir: directory that holds the checkpoint files.

    Returns:
        The result of ``flax.serialization.from_bytes`` applied to
        ``template_params`` and the file contents.

    Raises:
        FileNotFoundError: if no file exists at the computed path.
    """
    checkpoint_path = os.path.join(save_dir, f"checkpoint_{step:04d}.msgpack")

    # Fail early with an explicit message naming the missing file.
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    with open(checkpoint_path, "rb") as handle:
        serialized = handle.read()

    return flax.serialization.from_bytes(template_params, serialized)

# === Load trained model parameters ===
# Restore the checkpoint selected by CHECKPOINT_STEP (from the default
# CHECKPOINT_DIR), using init_params as the structural template.
params = load_checkpoint(init_params, step=CHECKPOINT_STEP)



Using device: cuda:7


ModuleNotFoundError: No module named 'big_vision'