# Train BLIP on Wallet Captions in Google Colab

**Instructions:**
1. Make sure you are using a GPU runtime (**Runtime > Change runtime type > T4 GPU**).
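2. Upload `dataset.zip` to the Colab files pane on the left. The archive should contain `wallet_captions.json` and a `wallet/` folder of images at its top level, because the cells below read `wallet_captions.json` from the working directory and the images from `/content/wallet/`.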
3. Run the cells below!


In [None]:
!unzip -q -o dataset.zip
print("‚úÖ Dataset extracted successfully!")

In [None]:
!pip install -q transformers torch torchvision Pillow matplotlib

In [None]:
import os
import json
import random
from PIL import Image
import matplotlib.pyplot as plt

def _pure_filename(key):
    """Return the bare filename of a JSON key, accepting Windows or POSIX separators."""
    return key.replace('\\', '/').split('/')[-1]


print("--- üõ†Ô∏è PATH VALIDATION TEST ---")
try:
    with open("wallet_captions.json", "r", encoding="utf-8") as f:
        data = json.load(f)
    keys = list(data.keys())

    # Check every annotation (one cheap existence test each) so a missing image
    # is reported here instead of raising partway through a training epoch.
    missing = [key for key in keys if not os.path.exists(f"/content/wallet/{_pure_filename(key)}")]
    print(f"\nChecked {len(keys)} annotations: {len(keys) - len(missing)} found, {len(missing)} missing.")
    for key in missing[:5]:
        print(f"  Missing image for key: {key}")
    if len(missing) > 5:
        print(f"  ... and {len(missing) - 5} more.")

    # Show a few random samples so the resolved paths can be eyeballed.
    samples = random.sample(keys, min(3, len(keys)))
    for key in samples:
        pure_filename = _pure_filename(key)
        colab_path = f"/content/wallet/{pure_filename}"
        exists = os.path.exists(colab_path)
        print(f"\nJSON Key: {key}")
        print(f"Resolved Path: {colab_path}")
        print(f"File Exists? {'‚úÖ YES' if exists else '‚ùå NO'}")
        if exists:
            # Context manager closes the underlying file once the figure is drawn.
            with Image.open(colab_path) as img:
                plt.figure(figsize=(2, 2))
                plt.imshow(img)
                plt.title(f"Found: {pure_filename}")
                plt.axis('off')
                plt.show()
        else:
            print(f"üö® ERROR: Cannot find image at {colab_path}. Please check your unzipped folder name!")
except Exception as e:
    # Diagnostic cell: report the problem instead of aborting the notebook.
    print(f"Error during validation: {e}")

In [None]:
import os
import json
import time
import torch
from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from transformers import BlipProcessor, BlipForConditionalGeneration
from torch.optim import AdamW
from tqdm.auto import tqdm

# --- CONFIGURATION ---
# Inputs / outputs
JSON_FILE: str = "wallet_captions.json"  # annotation file extracted from dataset.zip
MODEL_ID: str = "Salesforce/blip-image-captioning-base"  # base checkpoint passed to from_pretrained
SAVE_PATH: str = "./finetuned_wallet_blip"  # where the fine-tuned model and processor are written

# Training hyper-parameters
EPOCHS: int = 3
BATCH_SIZE: int = 4
LEARNING_RATE: float = 5e-5

# --- DATA PREPARATION ---
class WalletDataset(Dataset):
    """Image/caption pairs for fine-tuning BLIP on wallet photos.

    Annotations come from a JSON object that maps an image path (possibly a
    Windows-style path such as ``"wallet\\image.jpg"``) to a dict of wallet
    attributes. Each attribute dict is turned into a short synthetic caption by
    ``_create_caption``; ``__getitem__`` returns the processor's encoding of the
    image and that caption with the leading batch dimension removed.

    Args:
        json_path: Path to the JSON annotation file.
        processor: BLIP processor that encodes (image, caption) pairs.
        image_dir: Directory holding the extracted images. Only the bare
            filename of each JSON key is looked up here.
        max_length: Token length every caption is padded/truncated to. The
            synthetic captions are a handful of tokens, so a small fixed length
            avoids padding each sample to the tokenizer's model maximum.

    Raises:
        FileNotFoundError: If ``json_path`` does not exist.
    """

    def __init__(self, json_path, processor, image_dir="/content/wallet", max_length=64):
        print(f"[INFO] Loading JSON annotations from {json_path}...")
        if not os.path.exists(json_path):
            raise FileNotFoundError(f"\n‚ùå CRITICAL ERROR: Could not find {json_path}. Please make sure you uploaded and extracted dataset.zip!")

        with open(json_path, "r", encoding="utf-8") as f:
            self.data_dict = json.load(f)

        self.image_paths = list(self.data_dict.keys())
        self.processor = processor
        self.image_dir = image_dir
        self.max_length = max_length
        print(f"[INFO] Successfully loaded {len(self.image_paths)} annotations.")

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _normalize(value, default):
        """Return ``value`` stripped and lower-cased, or ``default`` if it is None or blank."""
        if value is None:
            return default
        text = str(value).strip().lower()
        return text if text else default

    def _create_caption(self, features):
        """Build a caption like "a solid black leather bifold by acme" from an attribute dict.

        Missing, null, or blank attributes fall back to the same defaults as a
        missing key, so no field can crash on ``.lower()`` or leave a double space.
        The brand suffix is added only when a real brand is present.
        """
        color = self._normalize(features.get("color"), "unknown")
        material = self._normalize(features.get("material_type"), "unknown")
        wallet_type = self._normalize(features.get("type_of_wallet"), "wallet")
        pattern = self._normalize(features.get("pattern"), "solid")
        brand = self._normalize(features.get("brand"), "unknown")

        caption = f"a {pattern} {color} {material} {wallet_type}"
        if brand != "unknown":
            caption += f" by {brand}"
        return caption

    def _resolve_image_path(self, img_path_key):
        """Map a JSON key to its file in ``image_dir``.

        Windows separators are stripped, so ``"wallet\\image.jpg"`` becomes
        ``"<image_dir>/image.jpg"``.
        """
        pure_filename = img_path_key.replace('\\', '/').split('/')[-1]
        return os.path.join(self.image_dir, pure_filename)

    def __getitem__(self, idx):
        """Return the processor encoding for sample ``idx`` (tensors without a batch dim).

        Raises:
            FileNotFoundError: If the image for this annotation is not on disk.
        """
        img_path_key = self.image_paths[idx]
        features = self.data_dict[img_path_key]
        colab_path = self._resolve_image_path(img_path_key)

        try:
            with Image.open(colab_path) as img:
                image = img.convert("RGB")
        except FileNotFoundError as err:
            # Other errors (e.g. a corrupt image) propagate with their own message.
            raise FileNotFoundError(f"Missing image: {colab_path}. Checked key: {img_path_key}") from err

        caption = self._create_caption(features)
        encoding = self.processor(
            images=image,
            text=caption,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )
        # The processor returns a batch of one; drop that dimension so the
        # DataLoader can stack samples itself.
        return {k: v.squeeze(0) for k, v in encoding.items()}

# Load processor and Base Model, then move the model to CUDA if available, else CPU.
print(f"[INFO] Loading processor and model: {MODEL_ID}")
print("[INFO] This might take a minute as it downloads the base weights...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using compute device: {device.type.upper()}")
if device.type != "cuda":
    print("‚ö†Ô∏è WARNING: You are not using a GPU! Training will be extremely slow. Please change runtime to T4 GPU.")

processor = BlipProcessor.from_pretrained(MODEL_ID)
model = BlipForConditionalGeneration.from_pretrained(MODEL_ID)
model.to(device)
# Report the device actually used; without a GPU the model stays on the CPU.
device_label = "GPU" if device.type == "cuda" else "CPU"
print(f"[INFO] Model loaded successfully to {device_label}.")

# Prepare Dataset & DataLoader
dataset = WalletDataset(JSON_FILE, processor)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
print(f"[INFO] DataLoader ready. Batches per epoch: {len(dataloader)}")

# Optimizer
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

print("\n" + "="*50)
print("üöÄ Starting Training Loop...")
print("="*50 + "\n")

# Fail fast: with no batches the per-epoch average below would divide by zero.
if len(dataloader) == 0:
    raise ValueError("[ERROR] The dataset is empty - nothing to train on. Check wallet_captions.json.")

model.train()

epoch_losses = []

for epoch in range(EPOCHS):
    start_time = time.time()
    total_loss = 0.0
    loop = tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=True)

    for batch in loop:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        pixel_values = batch["pixel_values"].to(device)
        # Captions are padded to a fixed length. Set padding positions to -100
        # (the ignore_index of PyTorch's cross-entropy loss) so the model is
        # trained on caption tokens only, not on predicting padding.
        labels = input_ids.masked_fill(attention_mask == 0, -100)

        optimizer.zero_grad()
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            labels=labels
        )

        loss = outputs.loss
        loss.backward()
        optimizer.step()

        # Read the scalar once for both the running total and the progress bar.
        batch_loss = loss.item()
        total_loss += batch_loss
        loop.set_postfix({"Batch Loss": f"{batch_loss:.4f}"})

    avg_loss = total_loss / len(dataloader)
    epoch_losses.append(avg_loss)
    elapsed_time = time.time() - start_time
    print(f"‚úÖ Epoch {epoch+1} Complete | Avg Loss: {avg_loss:.4f} | Time: {elapsed_time:.1f}s")

print("\nüéâ Training finished!")

# Chart the mean training loss recorded for each epoch.
epoch_axis = range(1, EPOCHS + 1)
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epoch_axis, epoch_losses, marker='o', linestyle='-', color='b')
ax.set_title('Training Loss over Epochs')
ax.set_xlabel('Epoch')
ax.set_ylabel('Average Loss')
ax.grid(True)
plt.show()

# Persist both the fine-tuned weights and the processor into the same folder
# so they can be reloaded together with from_pretrained(SAVE_PATH).
print(f"\nüíæ Saving the fine-tuned model to {SAVE_PATH}...")
os.makedirs(SAVE_PATH, exist_ok=True)
for artifact in (model, processor):
    artifact.save_pretrained(SAVE_PATH)
print("üíæ Model saved successfully!")

In [None]:
import random
print("\n--- üîç RUNNING INFERENCE SANITY CHECK ---")
model.eval()

# Choose one annotated image at random and map its (possibly Windows-style)
# JSON key to the extracted file under /content/wallet/.
sample_idx = random.randint(0, len(dataset) - 1)
sample_img_path_key = dataset.image_paths[sample_idx]
pure_filename = sample_img_path_key.replace('\\', '/').rsplit('/', 1)[-1]
colab_path = "/content/wallet/" + pure_filename

try:
    test_img = Image.open(colab_path).convert("RGB")

    # Display the image that is about to be captioned.
    plt.imshow(test_img)
    plt.axis('off')
    plt.title(f"Testing: {pure_filename}")
    plt.show()

    # Caption it with the fine-tuned model, conditioning on the image only.
    inputs = processor(test_img, return_tensors="pt").to(device)
    out = model.generate(**inputs, max_new_tokens=50)
    predicted_caption = processor.batch_decode(out, skip_special_tokens=True)[0]

    # Rebuild the caption the model was trained to produce for this image.
    features = dataset.data_dict[sample_img_path_key]
    ground_truth = dataset._create_caption(features)

    report = (
        ("üéØ Ground Truth JSON : ", features),
        ("üìù Target Caption    : ", ground_truth),
        ("ü§ñ Model Prediction  : ", predicted_caption),
    )
    for label, value in report:
        print(f"{label}{value}")

except Exception as e:
    print(f"‚ùå Inference failed: {e}")
