# 🚀 Fine-Tuning a Multimodal Reward Model

**Project Goal:** To build, train, and test a multimodal reward model capable of understanding and scoring textual descriptions of webpage screenshots.

This notebook documents the entire end-to-end process, showcasing a real-world workflow for advanced AI alignment. We will start with raw data collection, create a preference dataset, fine-tune a powerful vision-language model using advanced memory-saving techniques, and finally, test the resulting model.

**Why is this important?**
A reward model is a critical component for aligning powerful AI systems with human preferences. It's the cornerstone of techniques like Reinforcement Learning from Human Feedback (RLHF) and is essential for developing capable web-automation agents that can understand visual interfaces.

**Technology Stack:**
* **Model:** LLaVA 1.5 (7B parameters)
* **Framework:** PyTorch
* **Key Libraries:** Hugging Face `transformers`, `accelerate`, and `bitsandbytes`
* **Platform:** Google Colab (A100 GPU recommended)

In [1]:
from google.colab import drive
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


### Step 1: Data Collection - Capturing Webpage Screenshots

Every machine learning project begins with data. Our goal is to train a model that understands webpages, so our data will be screenshots of live websites.

This cell sets up the environment to programmatically capture these screenshots.

**What this code does:**
1.  **Installs Dependencies:** It uses `pip` to install `selenium` and `webdriver-manager`, which are powerful tools for browser automation.
2.  **Automates Chrome:** It controls a headless (invisible) Chrome browser directly within our Colab notebook.
3.  **Captures Screenshots:** It iterates through a predefined list of `URLS`, navigates to each page, and saves a full-height PNG screenshot to a designated folder in Google Drive.
4.  **Logs Metadata:** It creates a `captured_images.jsonl` file to keep a record of which image corresponds to which URL.

In [3]:
# =============================================================================
# CHROME INSTALLATION TO BYPASS SYSTEM PACKAGES
# =============================================================================

# Step 1: Download and Install the official Google Chrome browser
print("Downloading official Google Chrome browser...")
!wget https://dl.google.com/linux/direct/google-chrome-stable_current_amd64.deb
print("Installing Google Chrome...")
# The -y flag automatically answers 'yes' to prompts
!apt-get install -y ./google-chrome-stable_current_amd64.deb

# Step 2: Install Selenium and the driver manager
!pip install selenium webdriver-manager

import os
import uuid
import json
import time
from selenium import webdriver
from selenium.webdriver.chrome.service import Service
from webdriver_manager.chrome import ChromeDriverManager

# Step 3: Define the setup function pointing to our new Chrome installation
def setup_driver():
    """Sets up the Selenium WebDriver using a manually installed Chrome binary."""
    options = webdriver.ChromeOptions()

    # --- THIS IS THE CRITICAL NEW LINE ---
    # Point Selenium to the binary of our manually installed Chrome
    options.binary_location = "/opt/google/chrome/google-chrome"
    # -------------------------------------

    # Add all the stability flags from before
    options.add_argument('--headless')
    options.add_argument('--no-sandbox')
    options.add_argument('--disable-dev-shm-usage')
    options.add_argument('--disable-gpu')
    options.add_argument('--window-size=1920,1080')

    print("Setting up Chrome driver to use the manually installed Google Chrome...")
    # webdriver-manager will now detect the version of OUR chrome and get the right driver
    service = Service(ChromeDriverManager().install())
    driver = webdriver.Chrome(service=service, options=options)
    print("Driver setup complete.")
    return driver

# Step 4: Run the screenshot capture
# Ensure Google Drive is mounted
from google.colab import drive
drive.mount('/content/drive/', force_remount=True) # force_remount after new installs

PROJECT_DIR = "/content/drive/MyDrive/MultimodalRewardModel"
IMAGE_DIR = f"{PROJECT_DIR}/images"
DATA_DIR = f"{PROJECT_DIR}/data"

os.makedirs(IMAGE_DIR, exist_ok=True)
os.makedirs(DATA_DIR, exist_ok=True)

URLS = [
   'https://www.wikipedia.org/',
   'https://github.com/trending',
   'https://www.allrecipes.com/search?q=pancakes',
   'https://www.nytimes.com/section/technology',
   'https://www.the-numbers.com/weekend-box-office-chart'
]

# Use our new setup function
driver = setup_driver()
image_records = []

for url in URLS:
    try:
        print(f"Navigating to {url}...")
        driver.get(url)
        driver.set_window_size(1280, 800)
        time.sleep(2)

        image_filename = f"{IMAGE_DIR}/{uuid.uuid4().hex}.png"
        driver.save_screenshot(image_filename)

        if os.path.exists(image_filename) and os.path.getsize(image_filename) > 0:
            image_records.append({"url": url, "image_file": image_filename})
            print(f"Successfully captured {url}")
        else:
            print(f"Failed to create a valid screenshot for {url}")

    except Exception as e:
        print(f"Error capturing {url}: {e}")

driver.quit()

# Save records
with open(f"{DATA_DIR}/captured_images.jsonl", "w") as f:
    for record in image_records:
        f.write(json.dumps(record) + '\n')

print("\n✅ Screenshot capture should now be completed successfully!")

Downloading official Google Chrome browser...
--2025-06-17 22:24:05--  https://dl.google.com/linux/direct/google-chrome-stable_current_amd64.deb
Resolving dl.google.com (dl.google.com)... 74.125.200.93, 74.125.200.190, 74.125.200.91, ...
Connecting to dl.google.com (dl.google.com)|74.125.200.93|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 117745852 (112M) [application/x-debian-package]
Saving to: ‘google-chrome-stable_current_amd64.deb’


2025-06-17 22:24:05 (408 MB/s) - ‘google-chrome-stable_current_amd64.deb’ saved [117745852/117745852]

Installing Google Chrome...
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
Note, selecting 'google-chrome-stable' instead of './google-chrome-stable_current_amd64.deb'
The following additional packages will be installed:
  libvulkan1 mesa-vulkan-drivers
The following NEW packages will be installed:
  google-chrome-stable libvulkan1 mesa-vulkan-drivers
0 upgraded, 3 newly i

### Step 2: Creating a Preference Dataset

A standard Large Language Model (LLM) knows what a webpage *is*, but it doesn't know what a *good description* of a webpage looks like from a human perspective. We align the model with our preferences by creating a special dataset.

This cell builds a **preference dataset**, which is the foundation for training our reward model.

**What this code does:**
1.  **Loads Image Records:** It reads the `captured_images.jsonl` file created in the previous step.
2.  **Defines Preference Pairs:** For each screenshot, we manually write:
    * A **`prompt`**: The question we want to ask the AI (e.g., "What is this webpage showing?").
    * A **`chosen`** response: A high-quality, accurate answer that we want the model to prefer.
    * A **`rejected`** response: A plausible but incorrect or less helpful answer that we want the model to dis-prefer.
3.  **Saves the Dataset:** It saves this data into a `train_preference.jsonl` file. Each line in this file contains a single preference pair (`image_path`, `prompt`, `chosen`, `rejected`), which is the exact format our training loop requires.

In [4]:
import json

# The 'image_records' variable should still be in your environment
# from the previous step. If not, load it from the file.
try:
    print(f"Using {len(image_records)} captured images.")
except NameError:
    print("Loading image records from file...")
    DATA_DIR = "/content/drive/MyDrive/MultimodalRewardModel/data"
    with open(f"{DATA_DIR}/captured_images.jsonl", "r") as f:
        image_records = [json.loads(line) for line in f]

# Manually create higher-quality preference pairs for our captured images.
# This mimics what you would automate with an LLM for a larger dataset.
preference_data = [
    {
        "image": image_records[0]['image_file'], # Wikipedia
        "prompt": "What is the primary purpose of this webpage?",
        "chosen": "This is the homepage for Wikipedia, an online encyclopedia, allowing users to search for information across many languages.",
        "rejected": "This is a shopping website for books."
    },
    {
        "image": image_records[1]['image_file'], # GitHub Trending
        "prompt": "List some of the programming languages visible in the trending repositories section.",
        "chosen": "The page shows trending repositories for languages like Python, TypeScript, and Rust.",
        "rejected": "The only trending languages are Java and C++." # Plausible but incorrect for the specific screenshot
    },
    {
        "image": image_records[2]['image_file'], # Allrecipes
        "prompt": "What kind of recipes are being displayed on this page?",
        "chosen": "This page displays several recipes for pancakes, including 'Good Old-Fashioned Pancakes' and 'Fluffy Pancakes'.",
        "rejected": "This page is about how to bake a chocolate cake."
    },
    {
        "image": image_records[3]['image_file'], # NYT Technology
        "prompt": "Summarize the main headline visible in the screenshot.",
        "chosen": "The main headline appears to be about a recent development or issue in the field of artificial intelligence.",
        "rejected": "The main story is about a recent sports championship."
    },
    {
        "image": image_records[4]['image_file'], # The Numbers
        "prompt": "According to this list, what is the top grosser movie of Weekend Domestic Chart for June 13, 2025?",
        "chosen": "The list is topped by 'How to Train Your Dragon' with the Highest Gross revenue.",
        "rejected": "The top-rated movie is 'Avatar'."
    }
]

# Save the high-quality preference dataset to your Drive
DATA_DIR = "/content/drive/MyDrive/MultimodalRewardModel/data"
preference_filepath = f"{DATA_DIR}/train_preference.jsonl"
with open(preference_filepath, "w") as f:
    for r in preference_data:
        f.write(json.dumps(r) + '\n')

print(f"\n✅ High-quality preference dataset created at: {preference_filepath}")

Using 5 captured images.

✅ High-quality preference dataset created at: /content/drive/MyDrive/MultimodalRewardModel/data/train_preference.jsonl


### Step 3: Environment Setup & Defining Core Components

Before we can train, we need to install the specialized libraries for large-model training and define our core Python classes and functions.


**Installs Libraries:** It installs the specific versions of the libraries we need.
    * `bitsandbytes`: Crucial for enabling 4-bit quantization, which dramatically reduces memory usage.
    * `transformers`, `accelerate`: The core Hugging Face libraries for loading and managing large models.


In [5]:
# Install a specific, known-good set of library versions
!pip install -q tokenizers bitsandbytes
!pip install -U transformers>=4.45.0 accelerate safetensors huggingface_hub
!pip install -q torch torchvision Pillow tqdm

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.0/67.0 MB[0m [31m37.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m105.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m99.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m63.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m42.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

### Step 4: Training the Multimodal Reward Model

This is the heart of our project. In this cell, we load the massive pre-trained LLaVA model, apply our custom modifications, and fine-tune it on our preference dataset to teach it our preferences.

**What this code does:**

-1.  **Defines the `PreferenceDataset` Class:** This is a standard PyTorch `Dataset` class that knows how to read our `train_preference.jsonl` file and prepare the data (images and text) for the model.

0.  **Defines the `reward_model_forward` Function:** This custom function defines the "brain" of our reward model. It will be "monkey-patched" onto the base LLaVA model to add a final linear layer (`reward_head`) that outputs a single reward score.

1.  **Sets Hyperparameters:** At the top, we define key settings for our training run, like the `LEARNING_RATE` and `NUM_EPOCHS`.
2.  **Loads Processor & Model:** It loads the LLaVA 1.5 (7B) model. Crucially, it uses two state-of-the-art optimizations:
    * **4-bit Quantization:** `BitsAndBytesConfig(load_in_4bit=True)` shrinks the model's size in memory from ~28GB to ~4GB.
    * **`device_map="auto"`:** The `accelerate` library intelligently splits the quantized model across the GPU and CPU, ensuring it fits.
3.  **Enables Gradient Checkpointing:** `model.gradient_checkpointing_enable()` is another key memory-saving technique that trades a small amount of extra computation time for a massive reduction in memory usage from activations.
4.  **Adds the Reward Head:** It dynamically adds our custom, trainable `reward_head` layer to the model and replaces its default `forward` method with our own.
5.  **Executes the Manual Training Loop:** Instead of a high-level `Trainer`, we use a manual PyTorch loop for maximum control. For each batch of data, it:
    * Calculates the reward score for the `chosen` and `rejected` answers.
    * Computes the **Log Sigmoid loss**, which penalizes the model if the rejected score is not lower than the chosen score.
    * Uses **Gradient Accumulation** to simulate a larger batch size and stabilize training.
    * Updates the model weights using the `AdamW8bit` optimizer.
6.  **Saves a Complete Checkpoint:** At the end of each epoch, it saves the base model, the trained `reward_head` weights, and the `processor` (tokenizer + image processor) files to Google Drive, ensuring we have a complete, reloadable artifact.

In [30]:
import torch
import types
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from transformers import (
    LlavaForConditionalGeneration, get_scheduler, BitsAndBytesConfig,
    CLIPImageProcessor, LlamaTokenizerFast, LlavaProcessor
)
import bitsandbytes as bnb
from PIL import Image
from torch.utils.data import Dataset
import json
import os

# Install necessary libraries
!pip install -q --upgrade bitsandbytes
!pip install -q transformers torch torchvision Pillow tqdm accelerate

# (The PreferenceDataset and reward_model_forward functions are unchanged)
class PreferenceDataset(Dataset):
    def __init__(self, jsonl_file, processor):
        self.data = [json.loads(line) for line in open(jsonl_file)]
        self.processor = processor
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        item = self.data[idx]
        image = Image.open(item['image']).convert("RGB")
        prompt = f"USER: <image>\n{item['prompt']}\nASSISTANT: "
        new_max_length = 2048
        inputs_chosen = self.processor(text=prompt + item['chosen'], images=image, return_tensors="pt", padding="max_length", max_length=new_max_length, truncation=True)
        inputs_rejected = self.processor(text=prompt + item['rejected'], images=image, return_tensors="pt", padding="max_length", max_length=new_max_length, truncation=True)
        return {
            "chosen_pixel_values": inputs_chosen.pixel_values.squeeze(0), "chosen_input_ids": inputs_chosen.input_ids.squeeze(0), "chosen_attention_mask": inputs_chosen.attention_mask.squeeze(0),
            "rejected_pixel_values": inputs_rejected.pixel_values.squeeze(0), "rejected_input_ids": inputs_rejected.input_ids.squeeze(0), "rejected_attention_mask": inputs_rejected.attention_mask.squeeze(0),
        }

def reward_model_forward(self, pixel_values, input_ids, attention_mask, **kwargs):
    outputs = LlavaForConditionalGeneration.forward(
        self, pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True
    )
    last_hidden_state = outputs.hidden_states[-1][:, -1, :]
    last_hidden_state = last_hidden_state.to(self.reward_head.weight.dtype)
    reward = self.reward_head(last_hidden_state)
    return reward

# Main Setup and Training
PROJECT_DIR = "/content/drive/MyDrive/MultimodalRewardModel"
DATA_DIR = f"{PROJECT_DIR}/data"
MODEL_ID = "llava-hf/llava-1.5-7b-hf"
REVISION = "a272c74"

# Hyperparameter Configuration
LEARNING_RATE = 2e-4
NUM_EPOCHS = 5

# 1. Load tokenizer and image processor components
tokenizer = LlamaTokenizerFast.from_pretrained(MODEL_ID, revision=REVISION)
image_processor = CLIPImageProcessor.from_pretrained(MODEL_ID, revision=REVISION)

# 2. Combine them using the specific LlavaProcessor class
processor = LlavaProcessor(image_processor=image_processor, tokenizer=tokenizer)

# =============================================================================
# THE FINAL FIX: Manually set the missing attributes on the processor object.
# =============================================================================
if getattr(processor, 'patch_size', None) is None:
    processor.patch_size = 14
if getattr(processor, 'num_additional_image_tokens', None) is None:
    processor.num_additional_image_tokens = 576
print("✅ Processor loaded and manually patched successfully.")
# =============================================================================

# Configure and load model
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16)
model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_ID, revision=REVISION, quantization_config=bnb_config, device_map="auto"
)
model.gradient_checkpointing_enable()

# Modify the model in-place
lm_head_device = model.get_output_embeddings().weight.device
hidden_size = model.config.text_config.hidden_size
model.reward_head = torch.nn.Linear(hidden_size, 1, bias=False).to(lm_head_device).to(torch.float32)
model.forward = types.MethodType(reward_model_forward, model)
print(f"Model loaded with 4-bit quantization. Reward head is on device: {model.reward_head.weight.device}")

# Prepare for Training
train_dataset = PreferenceDataset(f"{DATA_DIR}/train_preference.jsonl", processor)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=LEARNING_RATE)
gradient_accumulation_steps = 4
num_update_steps_per_epoch = len(train_dataloader) // gradient_accumulation_steps
if num_update_steps_per_epoch == 0: num_update_steps_per_epoch = 1
num_training_steps = NUM_EPOCHS * num_update_steps_per_epoch
lr_scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=3, num_training_steps=num_training_steps)
progress_bar = tqdm(range(len(train_dataloader) * NUM_EPOCHS))

# The Training Loop
model.train()
for epoch in range(NUM_EPOCHS):
    for i, batch in enumerate(train_dataloader):
        target_device = model.get_input_embeddings().weight.device
        batch = {k: v.to(target_device) for k, v in batch.items()}
        rewards_chosen = model(**{k.replace('chosen_',''): v for k, v in batch.items() if 'chosen' in k})
        rewards_rejected = model(**{k.replace('rejected_',''): v for k, v in batch.items() if 'rejected' in k})
        loss = -torch.nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
        if torch.isnan(loss):
            print(f"WARNING: NaN loss detected at step {i+1} of epoch {epoch+1}. Skipping update.")
            continue
        loss = loss / gradient_accumulation_steps
        loss.backward()
        if (i + 1) % gradient_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
        progress_bar.update(1)
        progress_bar.set_description(f"Epoch {epoch+1}, Loss: {loss.item() * gradient_accumulation_steps:.4f}")

    # --- Correct Saving Logic ---
    run_name = f"lr{LEARNING_RATE}_epochs{NUM_EPOCHS}"
    epoch_dir = f"{PROJECT_DIR}/reward_model_checkpoint/{run_name}/epoch_{epoch+1}"
    os.makedirs(epoch_dir, exist_ok=True)
    print(f"\nSaving complete checkpoint for epoch {epoch+1} at {epoch_dir}")
    torch.save(model.reward_head.state_dict(), f"{epoch_dir}/reward_head.pt")
    model.save_pretrained(epoch_dir)
    # Explicitly save the processor's components
    tokenizer.save_pretrained(epoch_dir)
    image_processor.save_pretrained(epoch_dir)

print("🎉🎉🎉 Manual training loop complete! 🎉🎉🎉")

✅ Processor loaded and manually patched successfully.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Model loaded with 4-bit quantization. Reward head is on device: cuda:0


  0%|          | 0/25 [00:00<?, ?it/s]


Saving complete checkpoint for epoch 1 at /content/drive/MyDrive/MultimodalRewardModel/reward_model_checkpoint/lr0.0002_epochs5/epoch_1

Saving complete checkpoint for epoch 2 at /content/drive/MyDrive/MultimodalRewardModel/reward_model_checkpoint/lr0.0002_epochs5/epoch_2

Saving complete checkpoint for epoch 3 at /content/drive/MyDrive/MultimodalRewardModel/reward_model_checkpoint/lr0.0002_epochs5/epoch_3

Saving complete checkpoint for epoch 4 at /content/drive/MyDrive/MultimodalRewardModel/reward_model_checkpoint/lr0.0002_epochs5/epoch_4

Saving complete checkpoint for epoch 5 at /content/drive/MyDrive/MultimodalRewardModel/reward_model_checkpoint/lr0.0002_epochs5/epoch_5
🎉🎉🎉 Manual training loop complete! 🎉🎉🎉


### Step 5: Testing and Evaluating Our Reward Model

After training is complete, we need to perform inference to verify that our model has learned correctly.

**What this code does:**
1.  **Configures the Checkpoint:** At the top, you can specify which saved checkpoint you want to load for testing.
2.  **Loads the Trained Model:** It loads the base LLaVA model and processor from the specified checkpoint directory.
3.  **Attaches the Custom Reward Head:** It re-creates the `reward_head` layer and loads the fine-tuned weights that we saved separately during training.
4.  **Runs Inference:** It defines a `get_reward_score` function that takes a new image, prompt, and answer, and returns the model's score. It runs in `torch.no_grad()` mode for efficiency.
5.  **Compares Good vs. Bad Answers:** The script runs a test case with a "good" answer and a "bad" answer. The goal is to see if the model assigns a higher score to the good answer, which proves it has learned our preference. You can easily test with an image from a URL or by uploading your own.

In [33]:
import torch
import types
from transformers import (
    AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig
)
from PIL import Image
import os
import requests
import warnings

# Suppress unnecessary warnings for cleaner output
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# =============================================================================
# --- 1. CONFIGURE WHICH CHECKPOINT TO TEST ---
# =============================================================================
#
# Put the hyperparameters of the run you want to test here.
# This must match the folder name of your saved checkpoint.
#
RUN_LEARNING_RATE = 0.0002
RUN_NUM_EPOCHS = 5
EPOCH_TO_TEST = 5 # Which epoch from that run to load

# --- Automatically construct the path to the checkpoint directory ---
PROJECT_DIR = "/content/drive/MyDrive/MultimodalRewardModel"
RUN_NAME = f"lr{RUN_LEARNING_RATE}_epochs{NUM_EPOCHS}"
CHECKPOINT_TO_LOAD = f"{PROJECT_DIR}/reward_model_checkpoint/{RUN_NAME}/epoch_{EPOCH_TO_TEST}"

print(f"✅ Will attempt to load checkpoint from: {CHECKPOINT_TO_LOAD}")
# =============================================================================


# --- Define the forward pass function again (needed for monkey-patching) ---
def reward_model_forward(self, pixel_values, input_ids, attention_mask, **kwargs):
    outputs = LlavaForConditionalGeneration.forward(
        self, pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True
    )
    last_hidden_state = outputs.hidden_states[-1][:, -1, :]
    last_hidden_state = last_hidden_state.to(self.reward_head.weight.dtype)
    reward = self.reward_head(last_hidden_state)
    return reward


# --- 2. Load the fine-tuned model and processor from your chosen checkpoint ---
print(f"\nLoading model and processor from: {CHECKPOINT_TO_LOAD}")
try:
    # Load the full processor from the directory
  processor = AutoProcessor.from_pretrained(CHECKPOINT_TO_LOAD)

    # =============================================================================
    # THE FIX IS HERE: Manually set the missing attributes on the loaded processor
    # =============================================================================
  if getattr(processor, 'patch_size', None) is None:
      print("Patch size is missing after loading, manually setting to 14.")
      processor.patch_size = 14
  if getattr(processor, 'num_additional_image_tokens', None) is None:
      print("Num additional image tokens is missing, manually setting to 576.")
      processor.num_additional_image_tokens = 576
  bnb_config = BitsAndBytesConfig(load_in_8bit=True)
  model = LlavaForConditionalGeneration.from_pretrained(
        CHECKPOINT_TO_LOAD,
        quantization_config=bnb_config,
        device_map="auto"
    )
except OSError:
    print(f"❌ ERROR: Checkpoint directory not found at {CHECKPOINT_TO_LOAD}")
    print("Please make sure the configuration at the top of the script matches a saved checkpoint.")
    raise

# --- 3. Re-create and load your custom reward head ---
print("Attaching and loading weights for custom reward head...")
hidden_size = model.config.text_config.hidden_size
reward_head_path = f"{CHECKPOINT_TO_LOAD}/reward_head.pt"

if os.path.exists(reward_head_path):
    model.reward_head = torch.nn.Linear(hidden_size, 1, bias=False)
    model.reward_head.load_state_dict(torch.load(reward_head_path))
    model.reward_head.to(model.device)
    print("✅ Reward head loaded successfully.")
else:
    raise FileNotFoundError(f"reward_head.pt not found in {CHECKPOINT_TO_LOAD}")

# --- 4. Re-apply the custom forward method and set to eval mode ---
model.forward = types.MethodType(reward_model_forward, model)
model.eval() # Set the model to evaluation mode (important!)
print("✅ Model is ready for inference!")


# --- 5. Define an inference function ---
def get_reward_score(image_path, prompt_text, answer_text):
    """
    Loads an image and returns a reward score for a given prompt and answer.
    """
    try:
        image = Image.open(image_path).convert("RGB")
        inputs = processor(text=f"USER: <image>\n{prompt_text}\nASSISTANT: {answer_text}", images=image, return_tensors="pt")
        target_device = model.get_input_embeddings().weight.device
        inputs = {k: v.to(target_device) for k, v in inputs.items()}
        with torch.no_grad():
            score = model(**inputs)
        return score.item()
    except FileNotFoundError:
        print(f"ERROR: Image not found at path: {image_path}")
        return None
    except Exception as e:
        print(f"An error occurred during inference: {e}")
        return None

# =============================================================================
# --- 6. RUN YOUR TEST! ---
# =============================================================================

### --- OPTION A: Test with an Image from a URL ---
# By default, this option will run. Change the URL to test different images.

# print("\n--- Testing with an image from a URL ---")
# # An example screenshot of a GitHub repository page
# image_url = "https://pageflows.com/media/videos/Booking_a_Room_on_Booking.com.mp4-screenshot-.jpg"
# local_filename = "test_image_from_url.jpg"

# # Download the image
# try:
#     response = requests.get(image_url, stream=True)
#     response.raise_for_status()
#     with open(local_filename, 'wb') as f:
#         for chunk in response.iter_content(chunk_size=8192):
#             f.write(chunk)
#     print(f"Successfully downloaded image to {local_filename}")

#     # --- Run the test ---
#     test_prompt = "What is this webpage showing?"
#     good_answer = "This is the page of a Booking.com, showing Stays, Flights, Hotels, Car Rentals."
#     bad_answer = "This is a social media profile page with posts and comments."

#     score_good = get_reward_score(local_filename, test_prompt, good_answer)
#     score_bad = get_reward_score(local_filename, test_prompt, bad_answer)

#     print("\n--- TEST RESULTS ---")
#     print(f"Good Answer Score: {score_good:.4f}")
#     print(f"Bad Answer Score:  {score_bad:.4f}")

#     if score_good is not None and score_bad is not None:
#         if score_good > score_bad:
#             print("\n✅ Success! The model correctly gave the good answer a higher reward.")
#         else:
#             print("\n❌ Failure. The model did not rank the answers as expected.")

# except requests.exceptions.RequestException as e:
#     print(f"Failed to download image from URL: {e}")


### --- OPTION B: Test with an Image from Your Computer ---
# To use this, comment out the "OPTION A" block above and uncomment this one.

from google.colab import files
print("\nPlease upload a screenshot to test...")
uploaded = files.upload()

if uploaded:
    test_image_path = next(iter(uploaded))
    print(f"\nUsing uploaded image: {test_image_path}")

    # Define your own prompt and answers for your uploaded image
    test_prompt = "What is the main topic of this webpage?"
    good_answer = "This is a website displaying products."
    bad_answer = "This is a social media site."

    score_good = get_reward_score(test_image_path, test_prompt, good_answer)
    score_bad = get_reward_score(test_image_path, test_prompt, bad_answer)

    print("\n--- TEST RESULTS ---")
    print(f"Good Answer Score: {score_good:.4f}")
    print(f"Bad Answer Score:  {score_bad:.4f}")
    if score_good is not None and score_bad is not None:
        if score_good > score_bad:
            print("\n✅ Success! The model correctly gave the good answer a higher reward.")
        else:
            print("\n❌ Failure. The model did not rank the answers as expected.")
else:
    print("No file uploaded.")

✅ Will attempt to load checkpoint from: /content/drive/MyDrive/MultimodalRewardModel/reward_model_checkpoint/lr0.0002_epochs5/epoch_5

Loading model and processor from: /content/drive/MyDrive/MultimodalRewardModel/reward_model_checkpoint/lr0.0002_epochs5/epoch_5
Patch size is missing after loading, manually setting to 14.


Some weights of the model checkpoint at /content/drive/MyDrive/MultimodalRewardModel/reward_model_checkpoint/lr0.0002_epochs5/epoch_5 were not used when initializing LlavaForConditionalGeneration: ['reward_head.weight']
- This IS expected if you are initializing LlavaForConditionalGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LlavaForConditionalGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Attaching and loading weights for custom reward head...
✅ Reward head loaded successfully.
✅ Model is ready for inference!

Please upload a screenshot to test...


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


Saving Screenshot 2025-06-17 at 7.54.02 PM.png to Screenshot 2025-06-17 at 7.54.02 PM (3).png

Using uploaded image: Screenshot 2025-06-17 at 7.54.02 PM (3).png

--- TEST RESULTS ---
Good Answer Score: -1.3242
Bad Answer Score:  -2.5322

✅ Success! The model correctly gave the good answer a higher reward.
