In [None]:
from google.colab import drive

# Mount Google Drive; the demo cell later reads a fine-tuned checkpoint from /content/drive/MyDrive.
drive.mount(mountpoint='/content/drive')

In [None]:
!pip install -q -U transformers datasets accelerate

In [None]:
!pip install peft bitsandbytes

In [None]:
!pip install ipywidgets

In [None]:
import huggingface_hub

# Interactive Hugging Face Hub login. Presumably needed because the google/paligemma
# checkpoints loaded below are gated — TODO confirm.
huggingface_hub.notebook_login()

In [None]:
from datasets import load_dataset

# Hub dataset (presumably meter photos with text labels); the next cell reads its
# 'train' and 'test' splits.
DATASET_REPO = 'henrik-dra/energy-meter'
ds = load_dataset(path=DATASET_REPO)
print(ds)

In [None]:
ds_train=ds['train']
ds_test=ds['test']

ds_test[11]
ds_test[11]['image']

In [None]:
ds_test[11]['label']

In [None]:
from transformers import PaliGemmaProcessor

# Checkpoint id shared with the model-loading cells further down.
model_id = "google/paligemma-3b-pt-224"
# Bundles the image preprocessing and tokenizer that collate_fn relies on.
processor = PaliGemmaProcessor.from_pretrained(pretrained_model_name_or_path=model_id)

In [12]:
import torch
import numpy as np
import builtins

# NOTE(review): exposing numpy as a builtin is presumably a workaround for a NameError raised
# inside library code on some version; confirm it is still required before removing it.
builtins.np = np

device = "cuda"

# Fixed instruction paired with every training image; inference should reuse this exact text.
TRAIN_PROMPT = "extract meter value from this image"


def collate_fn(examples):
    """Turn a list of dataset rows into one model-ready training batch.

    Each row must provide a PIL ``image`` and a string ``label`` (presumably the meter reading).
    The processor pairs every image with ``TRAIN_PROMPT``, appends ``label`` as the suffix
    (the part the model is trained to produce), and pads to the longest sequence in the batch.

    Returns:
        The processor's batch with floating-point tensors cast to bfloat16, moved to ``device``.
    """
    texts = [TRAIN_PROMPT] * len(examples)
    labels = [example['label'] for example in examples]
    images = [example["image"].convert("RGB") for example in examples]
    tokens = processor(text=texts, images=images, suffix=labels,
                       return_tensors="pt", padding="longest")

    # BatchFeature.to(dtype) casts only floating-point tensors, so token ids and masks stay integer.
    return tokens.to(torch.bfloat16).to(device)

In [None]:
from transformers import PaliGemmaForConditionalGeneration
import torch

# Full bfloat16 copy of the base model, placed on `device`.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
).to(device)

# Keep the image encoder and the image-to-text projector fixed during fine-tuning.
# NOTE(review): the following cell rebinds `model` to a fresh 4-bit load, so these flags (and
# this bf16 copy) never reach the model that is trained; this cell looks redundant — confirm and drop.
for frozen_part in (model.vision_tower, model.multi_modal_projector):
    for weight in frozen_part.parameters():
        weight.requires_grad = False


In [None]:
import torch
from transformers import BitsAndBytesConfig, PaliGemmaForConditionalGeneration
from peft import get_peft_model, LoraConfig

# Base weights are stored as 4-bit NF4 while matmuls run in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    # The option is `bnb_4bit_compute_dtype`; a misspelled key such as `bnb_4bit_compute_type`
    # is only logged as an unused kwarg, leaving the float32 default compute dtype in place.
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Low-rank adapters on the attention and MLP projections; get_peft_model freezes all other weights.
# NOTE(review): q_proj/k_proj/v_proj may also match attention layers inside the vision tower;
# switch to a regex over module names if only the language model should be adapted.
lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

# device_map={"": 0} places the whole quantized model on GPU 0. This rebinds `model`,
# replacing any previously loaded copy along with whatever requires_grad flags it carried.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, quantization_config=bnb_config, device_map={"": 0}
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()



In [15]:
from transformers import TrainingArguments

# LoRA fine-tuning schedule: effective batch size = 2 per device x 2 accumulation steps.
args = TrainingArguments(
    output_dir="paligemma_meter",
    num_train_epochs=2,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=2,
    learning_rate=2e-5,
    warmup_steps=2,
    weight_decay=1e-6,
    adam_beta2=0.999,
    optim="adamw_torch",
    bf16=True,
    logging_steps=100,
    save_strategy="steps",
    save_steps=138,
    save_total_limit=1,
    push_to_hub=False,
    report_to=["tensorboard"],
    # collate_fn needs the raw 'image'/'label' columns, which the Trainer would otherwise drop.
    remove_unused_columns=False,
    # collate_fn already returns CUDA tensors, and only CPU tensors can be pinned.
    dataloader_pin_memory=False,
)


In [None]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=args,
    data_collator=collate_fn,
    train_dataset=ds_train,
    # NOTE(review): `args` sets no eval strategy, so this split is never evaluated during training.
    eval_dataset=ds_test,
)

trainer.train()

In [None]:
!pip -q install --upgrade pip
!pip -q install "transformers>=4.42.0" accelerate bitsandbytes gradio pillow safetensors
!pip install gradio

import os
import numpy as np
from PIL import Image, ImageDraw
import torch
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration, BitsAndBytesConfig
import gradio as gr


# Local fine-tuned checkpoint to serve; set MODEL_ID = None to use BASE_MODEL instead.
MODEL_ID = "/content/drive/MyDrive/paligemma_meter/checkpoint-70"
BASE_MODEL = "google/paligemma-3b-pt-224"
BASE_PROCESSOR_ID = BASE_MODEL

# Use 4-bit quantization only when CUDA is available
use_4bit = torch.cuda.is_available()
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
) if use_4bit else None

chosen_model_id = MODEL_ID if MODEL_ID else BASE_MODEL
# Fail early with a readable message if the local checkpoint is missing (e.g. Drive not
# mounted); otherwise from_pretrained may try to interpret the path as a Hub repo id.
if MODEL_ID and os.path.isabs(MODEL_ID) and not os.path.isdir(MODEL_ID):
    raise FileNotFoundError(f"Checkpoint directory not found: {MODEL_ID}")

print(f"Loading model from: {chosen_model_id}")
print(f"CUDA available: {torch.cuda.is_available()}  |  4-bit: {bool(bnb_config)}")

# ---------------------------
# Load model & processor
# ---------------------------
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

model = PaliGemmaForConditionalGeneration.from_pretrained(
    chosen_model_id,
    device_map="auto",
    quantization_config=bnb_config,
    torch_dtype=dtype
)
model.eval()

# Fine-tuning does not change the tokenizer, so the base processor is reused.
processor = AutoProcessor.from_pretrained(BASE_PROCESSOR_ID)

eos_token_id = processor.tokenizer.eos_token_id
# Compare against None explicitly: Gemma's <pad> id is 0, which `x or y` would treat as
# missing and silently replace with the EOS id.
pad_token_id = processor.tokenizer.pad_token_id
if pad_token_id is None:
    pad_token_id = eos_token_id

print("Model and processor ready.")

# ---------------------------
# Inference helpers
# ---------------------------
def generate_response(question, image):
    """Answer a free-form question about an image with the loaded PaliGemma model.

    Args:
        question: question text; empty or whitespace-only input is rejected.
        image: numpy array (what the ``gr.Image(type="numpy")`` input supplies) or a PIL image.

    Returns:
        The decoded answer, or a human-readable message when an input is missing or
        unsupported, or when the model produced no text.
    """
    if not question or not question.strip():
        return "Please enter a question."
    if image is None:
        return "Please upload an image."

    # Accept numpy array or PIL.Image
    if isinstance(image, np.ndarray):
        raw_image = Image.fromarray(image).convert("RGB")
    elif isinstance(image, Image.Image):
        raw_image = image.convert("RGB")
    else:
        return "Unsupported image type. Please upload a standard image."

    # A short instruction tends to help PaliGemma be concise.
    # NOTE(review): the training collator used the fixed prompt "extract meter value from this
    # image"; free-form prompts like this one fall outside what the fine-tune saw.
    prompt = f"Answer briefly: {question.strip()}"

    inputs = processor(text=prompt, images=raw_image, return_tensors="pt")
    # Move to the same device as the model
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.inference_mode():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=64,
            do_sample=False,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
        )

    # generate() returns prompt + continuation; decode only the newly generated tokens.
    input_len = inputs["input_ids"].shape[-1]
    gen_ids = output_ids[0][input_len:]
    response = processor.batch_decode(gen_ids.unsqueeze(0), skip_special_tokens=True)[0].strip()

    if not response:
        # Fallback if nothing new was generated or the model echoed the prompt
        response = processor.batch_decode(output_ids, skip_special_tokens=True)[0].replace(prompt, "").strip()

    return response or "(no answer generated)"

def make_local_examples(root="/tmp/vlm_examples"):
    """Create (once) three synthetic demo images under ``root`` and return their paths.

    The pictures are drawn with PIL, so the demo needs no network access. A file that
    already exists is reused untouched.

    Returns:
        Tuple ``(traffic_light_path, people_path, animal_jump_path)``.
    """
    os.makedirs(root, exist_ok=True)

    def draw_traffic_light(target):
        canvas = Image.new("RGB", (512, 512), "gray")
        pen = ImageDraw.Draw(canvas)
        pen.rectangle([200, 80, 312, 440], fill="black")               # housing
        for top, colour in ((95, "red"), (205, "yellow"), (315, "green")):
            pen.ellipse([215, top, 297, top + 82], fill=colour)         # lamps, top to bottom
        canvas.save(target)

    def draw_people(target, n=3):
        canvas = Image.new("RGB", (512, 512), "white")
        pen = ImageDraw.Draw(canvas)
        # (dx_start, y_start, dx_end, y_end) per limb: left arm, right arm, left leg, right leg
        limbs = ((-20, 200, -50, 260), (20, 200, 50, 260),
                 (-10, 300, -40, 380), (10, 300, 40, 380))
        for cx in [100, 256, 412][:n]:
            pen.ellipse([cx - 30, 90, cx + 30, 150], fill="black")     # head
            pen.rectangle([cx - 20, 150, cx + 20, 300], fill="black")  # torso
            for dx0, y0, dx1, y1 in limbs:
                pen.line([cx + dx0, y0, cx + dx1, y1], fill="black", width=12)
        canvas.save(target)

    def draw_animal(target):
        canvas = Image.new("RGB", (512, 512), "skyblue")
        pen = ImageDraw.Draw(canvas)
        pen.rectangle([0, 380, 512, 512], fill="green")                        # ground
        pen.rectangle([260, 350, 320, 380], fill="brown")                      # obstacle
        pen.ellipse([120, 260, 240, 340], fill="saddlebrown")                  # body
        pen.ellipse([220, 250, 260, 290], fill="saddlebrown")                  # head
        pen.polygon([(240, 300), (270, 290), (260, 320)], fill="saddlebrown")  # tail
        pen.arc([90, 230, 300, 400], 320, 20, width=4, fill="white")           # motion arc
        canvas.save(target)

    jobs = (
        ("traffic_light.jpg", draw_traffic_light),
        ("people.jpg", lambda target: draw_people(target, n=3)),
        ("animal_jump.jpg", draw_animal),
    )

    paths = []
    for filename, painter in jobs:
        path = os.path.join(root, filename)
        if not os.path.exists(path):
            painter(path)
        paths.append(path)
    return tuple(paths)

# ---------------------------
# Gradio UI
# ---------------------------
p1, p2, p3 = make_local_examples()


def chat_interface(question, image):
    """Gradio callback: pass the UI inputs straight through to ``generate_response``."""
    return generate_response(question, image)


EXAMPLE_ROWS = [
    ["What color is the traffic light?", p1],
    ["How many people are in the photo?", p2],
    ["What is the animal doing?", p3],
]

with gr.Blocks() as demo:
    gr.Markdown("# Fine-tuned PaliGemma (VQA Demo)")
    gr.Markdown("Upload an image and ask a concise question about it.")

    with gr.Row():
        with gr.Column(scale=1):
            image = gr.Image(label="Upload Image", type="numpy")
        with gr.Column(scale=1):
            question = gr.Textbox(
                label="Question",
                placeholder="e.g., What color is the traffic light?",
                lines=1,
            )
            predict_button = gr.Button("Predict", variant="primary")
            response = gr.Textbox(label="Response")

    # Examples only fill the inputs; cache_examples=False stops Gradio from pre-computing them.
    gr.Examples(examples=EXAMPLE_ROWS, inputs=[question, image], cache_examples=False)

    predict_button.click(fn=chat_interface, inputs=[question, image], outputs=response)

# share=True publishes a temporary public URL; debug=True keeps the cell running and shows errors.
demo.launch(share=True, debug=True)
