---

## Step 1: Setup & Installation

Clone the repository and install all dependencies.

In [None]:
# Clone repository
!git clone https://github.com/ribhu0105-alt/blip-using-pvt-cbam.git
%cd blip-using-pvt-cbam
!pwd
print("✓ Repository cloned")

In [None]:
# Install dependencies (quiet mode)
!pip install -q -r requirements.txt
print("✓ Dependencies installed")

In [None]:
# Verify installation
!python test_import.py

---

## Step 2: Check GPU Availability

In [None]:
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")

if device == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("⚠️ No GPU found. Training will be slow. Enable one via Runtime > Change runtime type and select a GPU hardware accelerator")
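`torch` reports only the GPU's total memory. When you choose `BATCH_SIZE` in Step 6, it also helps to see how much memory is already in use. The sketch below queries `nvidia-smi`, which is usually available on Colab GPU runtimes. It assumes the tool is on `PATH` and returns an empty list when it is not, for example on a CPU runtime.

```python
import shutil
import subprocess


def gpu_memory_report():
    """Return a list of (name, total_MiB, used_MiB) tuples, one per GPU.

    Uses nvidia-smi's CSV query mode; returns [] if nvidia-smi is not installed.
    """
    if shutil.which("nvidia-smi") is None:
        return []
    out = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,memory.total,memory.used",
         "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=True,
    ).stdout
    report = []
    for line in out.strip().splitlines():
        name, total, used = (field.strip() for field in line.split(","))
        report.append((name, int(total), int(used)))
    return report


report = gpu_memory_report()
if not report:
    print("nvidia-smi not found: no NVIDIA GPU visible to this runtime")
for name, total, used in report:
    print(f"{name}: {used} / {total} MiB in use")
```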

---

## Step 3: Mount Google Drive (Optional)

If your dataset is in Google Drive, mount it here.

**Skip this if:** you'll upload files directly or use the sample dataset.

In [None]:
from google.colab import drive

drive.mount('/content/drive')
print("✓ Google Drive mounted")
print("Your files are at: /content/drive/MyDrive/")

---

## Step 4: Configure Dataset Paths

**Dataset format required:**

```
dataset/
├── images/
│   ├── image_0001.jpg
│   ├── image_0002.jpg
│   └── ...
└── captions.txt
```

**captions.txt format** (one tab-separated `filename<TAB>caption` pair per line):
```
image_0001.jpg	a dog running in the park
image_0002.jpg	a cat sleeping on a bed
...
```
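A minimal sketch of reading this format, assuming UTF-8 text with one tab-separated pair per line. It splits on the first tab only and collects repeated filenames into a list. Whether `train_caption_pvt.py` itself accepts several captions per image is not stated here, so treat that part as an assumption. `load_captions` is a hypothetical helper, not part of the repository.

```python
from collections import defaultdict


def load_captions(caption_file):
    """Parse captions.txt into ({image_filename: [caption, ...]}, bad_line_numbers).

    Blank lines are skipped; lines without a tab, filename, or caption are
    reported by their 1-based line number instead of raising.
    """
    captions = defaultdict(list)
    bad_lines = []
    with open(caption_file, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.rstrip("\r\n")
            if not line.strip():
                continue  # ignore blank lines
            # Split on the FIRST tab only; everything after it is the caption
            filename, sep, caption = line.partition("\t")
            if not sep or not filename.strip() or not caption.strip():
                bad_lines.append(line_no)
                continue
            captions[filename.strip()].append(caption.strip())
    return dict(captions), bad_lines
```

Once the paths in the next cell are set, `load_captions(CAPTION_FILE)` gives a quick count of images, captions, and malformed lines.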

**Edit the paths below to match your dataset location.**

In [None]:
import os

# ========== CONFIGURE YOUR DATASET PATHS HERE ==========

# Option 1: Google Drive path
# IMAGE_ROOT = "/content/drive/MyDrive/your_dataset/images"
# CAPTION_FILE = "/content/drive/MyDrive/your_dataset/captions.txt"

# Option 2: Uploaded ZIP file path (after extraction)
# IMAGE_ROOT = "/content/dataset/images"
# CAPTION_FILE = "/content/dataset/captions.txt"

# Option 3: Use the sample dataset for testing (10 random images created in Step 5)
IMAGE_ROOT = "/content/blip-using-pvt-cbam/data/sample_images"
CAPTION_FILE = "/content/blip-using-pvt-cbam/data/captions.txt"

OUTPUT_DIR = "/content/checkpoints"

print(f"Images folder: {IMAGE_ROOT}")
print(f"Captions file: {CAPTION_FILE}")
print(f"Output dir: {OUTPUT_DIR}")

# Verify paths exist (to use the sample dataset before it exists, run Step 5 first)
if os.path.isdir(IMAGE_ROOT):
    image_exts = (".jpg", ".jpeg", ".png")
    num_images = sum(1 for name in os.listdir(IMAGE_ROOT) if name.lower().endswith(image_exts))
    print(f"✓ Images folder found ({num_images} images)")
else:
    print(f"✗ Images folder NOT found at {IMAGE_ROOT}")
    print("  Please update IMAGE_ROOT above")

if os.path.isfile(CAPTION_FILE):
    with open(CAPTION_FILE, encoding="utf-8") as f:
        num_captions = sum(1 for line in f if line.strip())  # ignore blank lines
    print(f"✓ Captions file found ({num_captions} captions)")
else:
    print(f"✗ Captions file NOT found at {CAPTION_FILE}")
    print("  Please update CAPTION_FILE above")
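The check above only confirms that both paths exist. The sketch below also cross-checks them: it lists captioned filenames that have no image file, and image files that no caption refers to. It assumes the first tab-separated column of each non-blank line is the image filename and treats `.jpg`, `.jpeg`, and `.png` as images. `check_dataset` is a hypothetical helper.

```python
import os

IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def check_dataset(image_root, caption_file):
    """Return (missing, unused) filename lists for a captions file and image folder.

    missing: filenames referenced in captions.txt with no matching image file.
    unused:  image files in image_root that no caption line refers to.
    """
    images = {
        name for name in os.listdir(image_root)
        if os.path.splitext(name)[1].lower() in IMAGE_EXTENSIONS
    }
    referenced = set()
    with open(caption_file, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                referenced.add(line.split("\t", 1)[0].strip())
    return sorted(referenced - images), sorted(images - referenced)
```

For example, `missing, unused = check_dataset(IMAGE_ROOT, CAPTION_FILE)`. A non-empty `missing` list usually points to a wrong `IMAGE_ROOT` or a typo in `captions.txt`.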

---

## Step 5: (Optional) Create Sample Dataset

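If you don't have a dataset yet, run this cell to create 10 random-noise images with placeholder captions. They only verify that the pipeline runs end to end; the model cannot learn meaningful captions from noise. Re-run Step 4 afterwards to confirm the paths.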

**Skip this if** you already have your dataset configured above.

In [None]:
import os
import numpy as np
from PIL import Image

# Create sample dataset with 10 random images
sample_dir = "/content/blip-using-pvt-cbam/data/sample_images"
os.makedirs(sample_dir, exist_ok=True)

sample_captions = [
    "a dog running in the park",
    "a cat sleeping on a bed",
    "people sitting on a bench",
    "a sunset over mountains",
    "a forest path in nature",
    "a city street at night",
    "a beach with waves",
    "a bird flying in sky",
    "flowers in a garden",
    "a car parked on street",
]

# Create random RGB images
for i, caption in enumerate(sample_captions):
    img_array = np.random.randint(0, 256, (384, 384, 3), dtype=np.uint8)
    img = Image.fromarray(img_array)  # an HxWx3 uint8 array is interpreted as RGB
    img.save(f"{sample_dir}/sample_{i:04d}.jpg")

# Create captions file
caption_file = "/content/blip-using-pvt-cbam/data/captions.txt"
with open(caption_file, 'w') as f:
    for i, caption in enumerate(sample_captions):
        f.write(f"sample_{i:04d}.jpg\t{caption}\n")

print(f"✓ Created {len(sample_captions)} sample images in {sample_dir}")
print(f"✓ Created captions file: {caption_file}")
print("\nSample captions:")
for caption in sample_captions[:3]:
    print(f"  - {caption}")

---

## Step 6: Configure Training Parameters

Adjust these based on your hardware and dataset size.

In [None]:
# ========== TRAINING CONFIGURATION ==========

# For Colab Free Tier (T4 GPU):
BATCH_SIZE = 4          # Use 4 for free Colab
EPOCHS = 2              # Start with 2 for testing, 8-10 for production
LEARNING_RATE = 1e-5    # Typical fine-tuning learning rate

# For Colab Pro (V100/A100):
# BATCH_SIZE = 8
# EPOCHS = 10

# Other parameters
IMAGE_SIZE = 384        # BLIP standard size
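USE_AMP = True          # Automatic mixed precision (saves memory); NOTE: Step 7 does not pass this to train_caption_pvt.py
NUM_WORKERS = 0         # DataLoader worker processes (0 = load data in the main process)
GRAD_CLIP = 1.0         # Gradient clipping threshold
WEIGHT_DECAY = 0.05     # Weight decay (regularization)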

print("Training Configuration:")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Epochs: {EPOCHS}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Image size: {IMAGE_SIZE}x{IMAGE_SIZE}")
print(f"  Mixed precision: {USE_AMP}")
print(f"  Num workers: {NUM_WORKERS}")
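To see how much work these settings imply, you can estimate the number of optimizer steps. This is a sketch under two assumptions: each line of `captions.txt` is one training sample, and the script takes one optimizer step per batch with no gradient accumulation. PyTorch's `DataLoader` keeps the final partial batch unless `drop_last=True`, and the sketch models both cases.

```python
import math


def training_steps(num_samples, batch_size, epochs, drop_last=False):
    """Return (steps_per_epoch, total_steps) for a simple one-step-per-batch loop."""
    if drop_last:
        per_epoch = num_samples // batch_size            # partial last batch discarded
    else:
        per_epoch = math.ceil(num_samples / batch_size)  # partial last batch kept
    return per_epoch, per_epoch * epochs


# The 10-image sample dataset with the free-tier settings above
per_epoch, total = training_steps(num_samples=10, batch_size=4, epochs=2)
print(f"{per_epoch} steps/epoch, {total} steps total")  # 3 steps/epoch, 6 steps total
```

With a real dataset, pass the caption count from Step 4 together with `BATCH_SIZE` and `EPOCHS`.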

---

## Step 7: Start Training

This runs the main training script. Progress will be printed to the console.

In [None]:
import subprocess
import os

os.makedirs(OUTPUT_DIR, exist_ok=True)

# Build command
cmd = [
    "python", "train_caption_pvt.py",
    "--image_root", IMAGE_ROOT,
    "--caption_file", CAPTION_FILE,
    "--batch_size", str(BATCH_SIZE),
    "--epochs", str(EPOCHS),
    "--lr", str(LEARNING_RATE),
    "--image_size", str(IMAGE_SIZE),
    "--output_dir", OUTPUT_DIR,
    "--num_workers", str(NUM_WORKERS),
    "--grad_clip", str(GRAD_CLIP),
    "--weight_decay", str(WEIGHT_DECAY),
]

print("Starting training...")
print(f"Command: {' '.join(cmd)}")
print("="*80)

result = subprocess.run(cmd, cwd="/content/blip-using-pvt-cbam")

print("="*80)
if result.returncode == 0:
    print("✓ Training completed successfully!")
else:
    print("✗ Training failed. Check error messages above.")
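Depending on the notebook frontend, output from a child process may not appear in the cell, or may appear only after the process exits. The sketch below is one way to stream the training log live. It reads the child's merged stdout/stderr line by line and sets `PYTHONUNBUFFERED=1` so that a Python child flushes each line immediately. `run_streaming` is a hypothetical helper; it is not part of the repository.

```python
import os
import subprocess
import sys


def run_streaming(cmd, cwd=None):
    """Run cmd, echo its output line by line as it arrives, and return the exit code."""
    env = {**os.environ, "PYTHONUNBUFFERED": "1"}  # make Python children flush per line
    with subprocess.Popen(
        cmd, cwd=cwd, env=env,
        stdout=subprocess.PIPE, stderr=subprocess.STDOUT,  # merge stderr into stdout
        text=True, bufsize=1,
    ) as proc:
        for line in proc.stdout:
            print(line, end="", flush=True)
    # Leaving the with-block waits for the process, so returncode is set here
    return proc.returncode
```

In the cell above, `returncode = run_streaming(cmd, cwd="/content/blip-using-pvt-cbam")` could replace the `subprocess.run` call, with `returncode` checked in place of `result.returncode`.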

---

## Step 8: Find the Trained Checkpoint

In [None]:
import glob
import os

# List all checkpoints
checkpoints = sorted(glob.glob(f"{OUTPUT_DIR}/*.pth"))

if checkpoints:
    print(f"Found {len(checkpoints)} checkpoint(s):")
    for ckpt in checkpoints:
        size_mb = os.path.getsize(ckpt) / (1024*1024)
        print(f"  - {os.path.basename(ckpt)} ({size_mb:.1f} MB)")
    
    # Prefer final_model.pth if it exists; otherwise take the most recently written checkpoint.
    # (Alphabetical order is not chronological: e.g. "epoch_10" sorts before "epoch_2".)
    final = [c for c in checkpoints if os.path.basename(c) == "final_model.pth"]
    CHECKPOINT = final[0] if final else max(checkpoints, key=os.path.getmtime)
    
    print(f"\n✓ Using checkpoint: {os.path.basename(CHECKPOINT)}")
else:
    print("✗ No checkpoints found. Run training first.")
    CHECKPOINT = None

---

## Step 9: Run Inference on a Single Image

Generate a caption for a test image.

In [None]:
if CHECKPOINT is None:
    print("✗ No checkpoint available. Run training first.")
else:
    import subprocess
    
    # Test image URL (a dog)
    TEST_IMAGE_URL = "https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg"
    
    print(f"Testing inference on image: {TEST_IMAGE_URL}\n")
    
    cmd = [
        "python", "scripts/patch_blip_with_pvt.py",
        "--checkpoint", CHECKPOINT,
        "--image_url", TEST_IMAGE_URL,
        "--device", device,
        "--max_length", "50",
        "--num_beams", "5",
    ]
    
    result = subprocess.run(cmd, cwd="/content/blip-using-pvt-cbam")
    if result.returncode != 0:
        print("✗ Inference failed. Check the error messages above.")

---

## Step 10: Batch Inference on Multiple Images

Process multiple images and save results to a file.

In [None]:
if CHECKPOINT is None:
    print("✗ No checkpoint available. Run training first.")
else:
    import glob
    import os
    import subprocess
    
    BATCH_OUTPUT = "/content/batch_results.txt"
    
    # Take up to 3 sample images from the training dataset (JPEGs first, then PNGs)
    test_images = sorted(glob.glob(f"{IMAGE_ROOT}/*.jpg"))[:3]
    test_images += sorted(glob.glob(f"{IMAGE_ROOT}/*.png"))[:max(0, 3 - len(test_images))]
    
    print(f"Running batch inference on {len(test_images)} images...\n")
    
    results = []
    for i, img_path in enumerate(test_images, 1):
        print(f"[{i}/{len(test_images)}] Processing: {os.path.basename(img_path)}")
        
        cmd = [
            "python", "scripts/patch_blip_with_pvt.py",
            "--checkpoint", CHECKPOINT,
            "--image_path", img_path,
            "--device", device,
            "--max_length", "50",
            "--num_beams", "5",
        ]
        
        result = subprocess.run(
            cmd,
            cwd="/content/blip-using-pvt-cbam",
            capture_output=True,
            text=True
        )
        
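        # Parse the caption from the output. Assumes the script prints a header line
        # containing "Caption" followed by the caption itself on the next line.
        caption = "[Failed to generate]"
        if result.returncode == 0:
            lines = (result.stdout + result.stderr).splitlines()
            for j, line in enumerate(lines):
                if "Caption" in line and j + 1 < len(lines):
                    caption = lines[j + 1].strip()
                    break
        else:
            print(f"  ✗ Inference failed (exit code {result.returncode})")
        
        results.append((os.path.basename(img_path), caption))
        suffix = "..." if len(caption) > 100 else ""
        print(f"  Caption: {caption[:100]}{suffix}\n")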
    
    # Save results
    with open(BATCH_OUTPUT, 'w') as f:
        for img_name, caption in results:
            f.write(f"Image: {img_name}\n")
            f.write(f"Caption: {caption}\n\n")
    
    print(f"✓ Results saved to: {BATCH_OUTPUT}")

---

## Step 11: View Results

In [None]:
import os

BATCH_OUTPUT = "/content/batch_results.txt"

if os.path.exists(BATCH_OUTPUT):
    print("BATCH INFERENCE RESULTS")
    print("="*80)
    with open(BATCH_OUTPUT) as f:
        print(f.read())
else:
    print("Results file not found yet. Run batch inference first.")
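The results file uses the simple `Image: ...` / `Caption: ...` block layout written in Step 10. To use the captions in code, for example to build a table or compare two training runs, you can parse them back into pairs. This is a minimal sketch that assumes exactly that layout. `read_batch_results` is a hypothetical helper.

```python
def read_batch_results(path):
    """Parse "Image: <name>" / "Caption: <text>" blocks into a list of (name, caption)."""
    results = []
    image_name = None
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.rstrip("\n")
            if line.startswith("Image: "):
                image_name = line[len("Image: "):]
            elif line.startswith("Caption: ") and image_name is not None:
                results.append((image_name, line[len("Caption: "):]))
                image_name = None  # wait for the next "Image:" line
    return results
```

For example, `for name, cap in read_batch_results(BATCH_OUTPUT): print(f"{name}: {cap}")`.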

---

## Step 12: Download Results & Checkpoint

In [None]:
from google.colab import files
import os

print("Preparing files for download...\n")

# Download batch results
if os.path.exists(BATCH_OUTPUT):
    print(f"Downloading: {os.path.basename(BATCH_OUTPUT)}")
    files.download(BATCH_OUTPUT)

# Download trained checkpoint
if CHECKPOINT and os.path.exists(CHECKPOINT):
    print(f"Downloading: {os.path.basename(CHECKPOINT)}")
    files.download(CHECKPOINT)

print("\n✓ Downloads started (check your browser's download bar)")
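If you prefer a single download, you can first bundle the results and checkpoint into one archive. This sketch uses the standard-library `zipfile` module and stores files uncompressed, since model weights barely compress. It skips `None` or missing paths. The helper name `bundle_outputs` and the archive path in the usage line are hypothetical.

```python
import os
import zipfile


def bundle_outputs(paths, archive_path):
    """Write the existing files in `paths` into one .zip, stored under their base names.

    None or missing entries are skipped. Returns the list of names added.
    """
    added = []
    with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_STORED) as zf:
        for path in paths:
            if path and os.path.exists(path):
                zf.write(path, arcname=os.path.basename(path))
                added.append(os.path.basename(path))
    return added
```

In Colab, you could then call `bundle_outputs([BATCH_OUTPUT, CHECKPOINT], "/content/blip_pvt_outputs.zip")` followed by `files.download("/content/blip_pvt_outputs.zip")`.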

---

## Summary

### What you've accomplished:
1. ✓ Set up BLIP+PVT environment
2. ✓ Trained the image captioning model
3. ✓ Generated captions for test images
4. ✓ Saved results and checkpoint

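### Results Quality Tips:
Rough expectations (actual quality depends heavily on caption quality, image diversity, and hardware):
- **2 epochs on the 10 random sample images:** Only confirms that the pipeline runs; captions will not be meaningful
- **5-8 epochs on 10K images:** Good quality, sensible descriptions
- **8-10 epochs on 30K images:** Excellent quality, fluent captions
- **15+ epochs on 100K images:** Production-ready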

### For Better Results:
1. Use a larger dataset (30K+ images)
2. Train for more epochs (8-15)
3. Use a curated, human-annotated captioning dataset (e.g. Flickr30K or MS COCO)
4. Adjust generation parameters if needed:
   - `--num_beams 5` (more beams usually improve quality but slow down generation)
   - `--repetition_penalty 2.0` (prevents token repetition)
   - `--max_length 50` (maximum caption length)

### Troubleshooting:
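- **Out of Memory:** Reduce `BATCH_SIZE` to 2 in Step 6
- **Slow Training:** Confirm that a GPU runtime is active (Step 2). `NUM_WORKERS = 0` (already set) avoids DataLoader worker problems in Colab but loads data in the main process; if the GPU is waiting on data, try a small value such as 2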
- **Bad Captions:** Train for more epochs with more data
- **Repetitive Output:** Increase repetition_penalty in inference