# Unfreeze and Train Encoder (v3) for 5 epochs

- Load the baseline model (decoder-only trained)
- Unfreeze the encoder
- Train both encoder and decoder for 5 more epochs
- This helps the encoder adapt to noisy/corrupted images

1. Loads baseline model (`epoch_decoder_only_baseline_3`)
2. Unfreezes encoder
3. Trains both encoder and decoder for 5 epochs
4. Saves the new model with encoder training

**Output:**
- New model saved to `./image-captioning-model/epoch_encoder_trained_5_v3`


## Setup


In [None]:
# ===== SETUP =====
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/GTech\ OMSCS/CS\ 7643/group\ project/CS7643_project

import sys
import os
# Absolute path of the project folder on the mounted Drive (the same folder
# the %cd above switches into).
project_root = '/content/drive/MyDrive/GTech OMSCS/CS 7643/group project/CS7643_project'
# Put the project folder at the front of the import search path; the
# membership check keeps re-running this cell from adding duplicates.
# NOTE(review): presumably this is for importing project-local modules; none
# are imported in the cells visible here.
if project_root not in sys.path:
    sys.path.insert(0, project_root)

import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import ViTImageProcessor, AutoTokenizer, VisionEncoderDecoderModel
from tqdm import tqdm
import pickle


## Config


In [None]:
# Model config
# All relative paths below resolve against the current working directory,
# which the setup cell changes to the project folder on Drive.
# Checkpoint directory of the baseline model that training starts from.
BASELINE_MODEL_PATH = './epoch_decoder_only_baseline_3'
# The best checkpoint is written to OUTPUT_MODEL_DIR/OUTPUT_MODEL_NAME.
OUTPUT_MODEL_DIR = './image-captioning-model'
OUTPUT_MODEL_NAME = 'epoch_encoder_trained_5_v3'

# Dataset config
# IMG_DIR holds the image files; CAP_DIR holds the pickled caption splits
# (train_data.pickle / dev_data.pickle).
IMG_DIR = './Flickr8k_Data/Flicker8k_Dataset'
CAP_DIR = './Flicker8k_captions'

# Training config
EPOCHS = 5
# Samples per DataLoader batch (one forward/backward pass).
BATCH_SIZE = 16
# Captions are padded/truncated to this many tokens.
MAX_LEN = 48
# Separate learning rates per parameter group: the encoder gets a 10x smaller
# rate than the decoder. Both values are also the peak rates of the one-cycle
# schedule built in the training cell.
LR_ENCODER = 1e-5
LR_DECODER = 1e-4
# Batches whose gradients are accumulated before each optimizer step, giving
# an effective batch of BATCH_SIZE * GRADIENT_ACC_STEPS = 128 samples.
GRADIENT_ACC_STEPS = 8

# Device
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

## Load Dataset


In [None]:
def _read_caption_pickle(filename):
    """Unpickle and return the caption mapping stored in CAP_DIR/filename.

    NOTE: unpickling executes code embedded in the file, so only load
    pickles produced by this project.
    """
    with open(os.path.join(CAP_DIR, filename), 'rb') as handle:
        return pickle.load(handle)


# Training split, and the dev split used for per-epoch evaluation.
train_data = _read_caption_pickle('train_data.pickle')
dev_data = _read_caption_pickle('dev_data.pickle')

# One sample per caption, so count captions rather than images.
num_train_samples = sum(map(len, train_data.values()))
num_dev_samples = sum(map(len, dev_data.values()))
print(f"Training samples: {num_train_samples}")
print(f"Dev samples: {num_dev_samples}")


## Create Dataset Class


In [None]:
from PIL import Image

class Flickr8kDataset(torch.utils.data.Dataset):
    """Flickr8k (image, caption) pairs prepared for a vision encoder-decoder.

    Every caption is its own sample, so an image with several captions
    contributes several items.

    Args:
        data: Mapping of image filename to an iterable of caption strings.
        tokenizer: Tokenizer called on each caption; must return ``input_ids``
            and ``attention_mask`` tensors.
        img_processor: Image processor called on each PIL image; must return
            an object with a ``pixel_values`` tensor.
        img_dir: Directory containing the image files named in ``data``.
        max_len: Number of tokens each caption is padded/truncated to.
    """

    # Label value written over padding positions so the loss skips them
    # (-100 is the default ``ignore_index`` of PyTorch's cross-entropy loss).
    IGNORE_INDEX = -100

    def __init__(self, data, tokenizer, img_processor, img_dir, max_len):
        self.tokenizer = tokenizer
        self.processor = img_processor
        self.img_dir = img_dir
        self.max_len = max_len
        self.data = []

        # Use all captions for better training: one (filename, caption)
        # entry per caption.
        for filename, captions in data.items():
            self.data.extend((filename, cap) for cap in captions)

    def __len__(self):
        """Return the number of (image, caption) samples."""
        return len(self.data)

    @staticmethod
    def _mask_padding(token_ids, attention_mask, ignore_index=-100):
        """Return a copy of ``token_ids`` with padded positions set to ``ignore_index``.

        Padding is identified through ``attention_mask`` rather than by token
        id: when the pad token is aliased to EOS, an id comparison would also
        erase a genuine EOS label.
        """
        labels = token_ids.clone()
        labels[attention_mask == 0] = ignore_index
        return labels

    def __getitem__(self, idx):
        """Return the tensors for sample ``idx``.

        Returns:
            dict with ``pixel_values`` (processed image), ``labels`` (caption
            token ids with padding replaced by ``IGNORE_INDEX``),
            ``attention_mask`` (1 for caption tokens, 0 for padding) and
            ``filename`` (source image name).
        """
        filename, caption = self.data[idx]

        # Image Processing
        img_path = os.path.join(self.img_dir, filename)
        # The context manager closes the underlying file as soon as the RGB
        # copy exists, so DataLoader workers do not accumulate open handles.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        pixel_values = self.processor(img, return_tensors='pt').pixel_values.squeeze(0)

        # Caption Processing
        tokenized_output = self.tokenizer(
            caption,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors='pt'
        )

        # Drop the batch dimension added by return_tensors='pt'.
        tokens = tokenized_output.input_ids.squeeze(0)
        attention_mask = tokenized_output.attention_mask.squeeze(0)

        # Uses the instance's own tokenizer output; the previous version read
        # a module-level ``tokenizer`` that only exists once a later notebook
        # cell has run.
        # NOTE(review): no EOS token is appended to the caption here, so the
        # labels may contain no end-of-sequence target -- confirm whether the
        # tokenizer adds one, since this affects when generation stops.
        labels = self._mask_padding(tokens, attention_mask, self.IGNORE_INDEX)

        return {
            'pixel_values': pixel_values,
            'labels': labels,
            'attention_mask': attention_mask,
            'filename': filename
        }


## Load Baseline Model and Setup


In [None]:
# Image preprocessing comes from the ViT checkpoint, text tokenization from GPT-2.
ENCODER_ID = "google/vit-base-patch16-224-in21k"
img_processor = ViTImageProcessor.from_pretrained(ENCODER_ID)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Padding to max_length needs a pad token; fall back to EOS when the
# tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Restore the decoder-only baseline checkpoint and move it to the target device.
print(f"Loading baseline model from: {BASELINE_MODEL_PATH}")
model = VisionEncoderDecoderModel.from_pretrained(BASELINE_MODEL_PATH)
model.to(DEVICE)

print("✓ Baseline model loaded")
print(f"  Model on device: {DEVICE}")

# Both splits share the same tokenizer, image processor, image folder and
# caption length; only the caption mapping differs.
_dataset_args = (tokenizer, img_processor, IMG_DIR, MAX_LEN)
train_dataset = Flickr8kDataset(train_data, *_dataset_args)
dev_dataset = Flickr8kDataset(dev_data, *_dataset_args)

# Loader settings common to both splits; only training data is shuffled.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
dev_loader = DataLoader(dev_dataset, shuffle=False, **_loader_kwargs)

print("✓ Datasets created")
print(f"  Train batches: {len(train_loader)}")
print(f"  Dev batches: {len(dev_loader)}")


## Unfreeze Encoder (Alfred's Workaround)

**This is the key step!** We unfreeze the encoder so it can learn to handle noisy/corrupted images.


In [None]:
# STEP 3: Unfreeze encoder and train both encoder and decoder
# This is Alfred's workaround to improve robustness to corrupted images

def _param_count(module, trainable_only=False):
    """Return the number of scalar parameters in ``module``.

    With ``trainable_only=True`` only parameters whose ``requires_grad`` flag
    is set are counted.
    """
    return sum(
        p.numel()
        for p in module.parameters()
        if p.requires_grad or not trainable_only
    )


_banner = "=" * 80
print(_banner)
print("UNFREEZING ENCODER (Alfred's Workaround)")
print(_banner)

# Snapshot of the parameter counts as loaded from the baseline checkpoint.
encoder_params = _param_count(model.encoder)
decoder_params = _param_count(model.decoder)
encoder_trainable_before = _param_count(model.encoder, trainable_only=True)
decoder_trainable_before = _param_count(model.decoder, trainable_only=True)

print("Before unfreezing:")
print(f"  Encoder parameters: {encoder_params:,}")
print(f"  Encoder trainable: {encoder_trainable_before:,}")
print(f"  Decoder parameters: {decoder_params:,}")
print(f"  Decoder trainable: {decoder_trainable_before:,}")

# Enable gradients on every encoder parameter, and on the decoder as well so
# both halves are trained.
for _submodule in (model.encoder, model.decoder):
    _submodule.requires_grad_(True)

# Re-count to confirm the flags were applied.
encoder_trainable_after = _param_count(model.encoder, trainable_only=True)
decoder_trainable_after = _param_count(model.decoder, trainable_only=True)

print("\nAfter unfreezing:")
print(f"  Encoder trainable: {encoder_trainable_after:,} ✓")
print(f"  Decoder trainable: {decoder_trainable_after:,} ✓")

# Two parameter groups: a lower learning rate for the encoder and a higher
# one for the decoder.
_param_groups = [
    {'params': model.encoder.parameters(), 'lr': LR_ENCODER},
    {'params': model.decoder.parameters(), 'lr': LR_DECODER},
]
optimizer = AdamW(_param_groups)

print("\n✓ Optimizer created")
print(f"  Encoder LR: {LR_ENCODER}")
print(f"  Decoder LR: {LR_DECODER}")
print(_banner)


## Training Loop

Train both encoder and decoder for 5 epochs.


In [None]:
from torch.optim.lr_scheduler import OneCycleLR
# 1. Re-Initialize Optimizer (Clean state)
# A fresh AdamW so no optimizer state from an earlier run of this notebook
# leaks into this training run.
optimizer = AdamW([
    {'params': model.encoder.parameters(), 'lr': LR_ENCODER},
    {'params': model.decoder.parameters(), 'lr': LR_DECODER},
])
# 2. Setup Scheduler (New!)
# Stabilizes training to fix grammar issues
# The scheduler is stepped once per optimizer update. The training loop
# updates after every GRADIENT_ACC_STEPS batches AND once more for any
# leftover batches at the end of an epoch, i.e. ceil(batches / acc_steps)
# times per epoch. Flooring the product over all epochs under-counts whenever
# the batch count is not a multiple of GRADIENT_ACC_STEPS, and OneCycleLR
# raises ValueError once it is stepped past total_steps.
steps_per_epoch = -(-len(train_loader) // GRADIENT_ACC_STEPS)  # ceiling division
total_steps = steps_per_epoch * EPOCHS
scheduler = OneCycleLR(optimizer,
                       max_lr=[LR_ENCODER, LR_DECODER], # Up to max, then decay
                       total_steps=total_steps,
                       pct_start=0.3, # Warm up for 30% of time
                       div_factor=10, # Start at max_lr / 10
                       final_div_factor=100) # End at initial_lr / 100
print(f"✓ Optimizer & Scheduler ready ({total_steps} steps)")
# 3. Improved Evaluation Function (New!)
def evaluate_robust(model, loader, tokenizer, device, num_examples=3):
    """Compute the average loss over ``loader`` and print sample captions.

    Args:
        model: Captioning model; called with ``pixel_values``, ``labels`` and
            ``decoder_attention_mask`` and expected to return an object with
            a ``loss`` tensor. Also used for ``generate``.
        loader: Iterable of batches (dicts with ``pixel_values``, ``labels``
            and ``attention_mask`` tensors).
        tokenizer: Used to decode the generated token ids for printing.
        device: Device the batch tensors are moved to (string or
            ``torch.device``).
        num_examples: How many captions from the first batch to generate and
            print. Defaults to 3.

    Returns:
        The mean of the per-batch losses as a float.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    # Mixed precision only applies on CUDA; on other devices run in full
    # precision instead of requesting CUDA autocast unconditionally.
    use_amp = torch.device(device).type == 'cuda'
    total_loss = 0.0
    num_batches = 0

    # Calculate Loss
    with torch.no_grad():
        for i, batch in enumerate(tqdm(loader, desc="Evaluating")):
            pixel_values = batch['pixel_values'].to(device)
            labels = batch['labels'].to(device)
            decoder_attention_mask = batch['attention_mask'].to(device)

            with torch.amp.autocast('cuda', enabled=use_amp):
                outputs = model(
                    pixel_values=pixel_values,
                    labels=labels,
                    decoder_attention_mask=decoder_attention_mask
                )
            total_loss += outputs.loss.item()
            num_batches += 1

            # GENERATE EXAMPLES (Only for first batch)
            # This lets you see if the "stuttering" is fixed
            if i == 0 and num_examples > 0:
                print("\n" + "="*40)
                print("GENERATED EXAMPLES")
                print("="*40)
                # Beam search is expensive: decode only the images whose
                # captions are actually printed, not the whole batch.
                gen_ids = model.generate(
                    pixel_values[:num_examples],
                    max_length=50,
                    num_beams=4,
                    repetition_penalty=2.0, # CRITICAL FIX FOR STUTTERING
                    early_stopping=True
                )
                preds = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
                for pred in preds:
                    print(f"Pred: {pred}")
                print("="*40 + "\n")

    if num_batches == 0:
        # Fail with an explicit message instead of a bare ZeroDivisionError.
        raise ValueError("evaluate_robust received an empty loader")
    avg_loss = total_loss / num_batches
    return avg_loss
# 4. Training Loop (Updated)
# Mixed precision (autocast + gradient scaling) is only meaningful on CUDA;
# on CPU both are switched off explicitly rather than requested and ignored.
use_amp = DEVICE == "cuda"
scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
print("="*80)
print(f"STARTING V3 TRAINING ({EPOCHS} epochs)")
print("="*80)
best_dev_loss = float('inf')
os.makedirs(OUTPUT_MODEL_DIR, exist_ok=True)
# Destination of the best checkpoint (overwritten whenever dev loss improves).
path = os.path.join(OUTPUT_MODEL_DIR, OUTPUT_MODEL_NAME)

# Gradient accumulation bookkeeping: the last `tail_batches` batches of an
# epoch do not fill a whole accumulation window.
num_batches = len(train_loader)
tail_batches = num_batches % GRADIENT_ACC_STEPS
tail_start = num_batches - tail_batches

for epoch in range(EPOCHS):
    print(f"\n{'='*80}")
    print(f"Epoch {epoch+1}/{EPOCHS}")
    print(f"{'='*80}")

    # Training
    model.train()
    total_loss = 0
    optimizer.zero_grad()

    for batch_idx, batch in enumerate(tqdm(train_loader, desc="Training")):
        pixel_values = batch['pixel_values'].to(DEVICE)
        labels = batch['labels'].to(DEVICE)
        decoder_attention_mask = batch['attention_mask'].to(DEVICE)

        # Number of batches contributing to the optimizer step this batch
        # belongs to. Dividing the trailing partial window by the full
        # GRADIENT_ACC_STEPS would shrink that update's gradient.
        if batch_idx >= tail_start:
            window_size = tail_batches
        else:
            window_size = GRADIENT_ACC_STEPS

        with torch.amp.autocast('cuda', enabled=use_amp):
            outputs = model(
                pixel_values=pixel_values,
                labels=labels,
                decoder_attention_mask=decoder_attention_mask
            )
            # Normalize loss for accumulation so the summed gradient is the
            # mean over the window.
            loss = outputs.loss / window_size

        # Mixed Precision Backward
        scaler.scale(loss).backward()
        # Track the un-normalized loss for reporting.
        total_loss += outputs.loss.item()

        # Step when a window is full, or on the final batch so leftover
        # gradients are not discarded. The scheduler advances once per
        # optimizer update, i.e. ceil(num_batches / GRADIENT_ACC_STEPS)
        # times per epoch.
        window_full = (batch_idx + 1) % GRADIENT_ACC_STEPS == 0
        last_batch = (batch_idx + 1) == num_batches
        if window_full or last_batch:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step() # STEP SCHEDULER

    avg_train_loss = total_loss / len(train_loader)

    # Evaluation
    dev_loss = evaluate_robust(model, dev_loader, tokenizer, DEVICE)

    print(f"\nEpoch {epoch+1} Results:")
    print(f"  Train Loss: {avg_train_loss:.4f}")
    print(f"  Dev Loss: {dev_loss:.4f}")

    # Save best
    if dev_loss < best_dev_loss:
        best_dev_loss = dev_loss
        model.save_pretrained(path)
        tokenizer.save_pretrained(path)
        print(f"  ✓ Saved best model to: {path}")
print("\nTRAINING COMPLETE")

- To evaluate the trained model, use `test_robustness_baseline.ipynb`
- Set `MODEL_PATH = './image-captioning-model/epoch_encoder_trained_5_v3'`
