# ControlNet Fine-Tuning and Inference for Video Ad Manipulation (Updated)

This notebook demonstrates:
1. **Data validation** - Get valid video IDs from alignment_score.csv
2. **Quick training** - Use a subset of videos for fast experimentation
3. **Fine-tuning** - Train the ControlNet adapter
4. **Inference** - Generate specifications for 7 experimental video variants

---

## Table of Contents
1. [Setup](#setup)
2. [Data Validation](#data-validation)
3. [Dataset Preparation](#data-prep)
4. [Model Fine-Tuning](#training)
5. [Inference: Generate 7 Variants](#inference)
6. [Visualization](#visualization)

---
## 1. Setup <a name="setup"></a>

In [None]:
import os
import sys
import json
import yaml
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
from tqdm.auto import tqdm

# Add the project root (current working directory) to sys.path so the `src` package is importable
sys.path.insert(0, str(Path.cwd()))

# Import framework modules
from src.models import StableDiffusionControlNetWrapper
from src.training import (
    ControlNetTrainer,
    get_valid_video_ids,
    split_train_val_videos,
    print_dataset_statistics,
)
from src.training.dataset_v2 import VideoSceneDataModule
from src.data_preparation import ControlTensorBuilder
from src.video_editing.experimental_variants_v2 import VideoVariantGenerator, visualize_variant_comparison

# Check GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

---
## 2. Data Validation <a name="data-validation"></a>

**IMPORTANT:** Use `alignment_score.csv` as the source of truth for video IDs.

Some videos in `keywords.csv` may not have alignment scores, so we keep only the IDs that appear in both files (their intersection).

In [None]:
# File paths
ALIGNMENT_SCORE_FILE = 'data/alignment_score.csv'
KEYWORDS_FILE = 'data/keywords.csv'
SCREENSHOTS_DIR = 'data/screenshots_tiktok'
KEYWORD_MASKS_DIR = 'data/keyword_masks'

# Get valid video IDs (intersection of alignment_score.csv and keywords.csv)
print("Validating video IDs...\n")
valid_video_ids = get_valid_video_ids(
    alignment_score_file=ALIGNMENT_SCORE_FILE,
    keywords_file=KEYWORDS_FILE
)

print(f"\n✓ Found {len(valid_video_ids)} valid videos")
print(f"First 10 video IDs: {valid_video_ids[:10]}")
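`get_valid_video_ids` itself is not shown in this notebook. As a rough, hypothetical sketch of the intersection it describes, the helpers below read IDs from CSV text with the standard `csv` module and keep only those present in both files. The column name `video_id` and the text-in (rather than path-in) interface are assumptions made to keep the example self-contained.

```python
import csv
import io


def read_ids(csv_text, column="video_id"):
    """Return the set of non-empty values in `column` (column name assumed)."""
    ids = set()
    for row in csv.DictReader(io.StringIO(csv_text)):
        value = (row.get(column) or "").strip()
        if value:
            ids.add(value)
    return ids


def valid_ids(alignment_csv, keywords_csv, column="video_id"):
    """IDs present in both files, sorted for a stable order."""
    return sorted(read_ids(alignment_csv, column) & read_ids(keywords_csv, column))


alignment = "video_id,scene_number,alignment_score\n101,1,0.42\n101,2,0.37\n102,1,0.55\n"
keywords = "video_id,keyword\n101,sneakers\n103,coffee\n"
print(valid_ids(alignment, keywords))  # ['101']
```

Video 102 has scores but no keyword and 103 has a keyword but no scores, so only 101 survives.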

### Quick Training Mode

**For fast experimentation**, you can limit the number of videos used for training.

Set `USE_SUBSET = True` and `NUM_VIDEOS` to a small number (e.g., 10-20) for quick training.

In [None]:
# ========================================
# CONFIGURATION: Quick Training Mode
# ========================================

# Set to True for quick training on a subset of videos
USE_SUBSET = True

# Number of videos to use (only if USE_SUBSET=True)
NUM_VIDEOS = 10  # Use 10 videos for quick experimentation

# ========================================

if USE_SUBSET:
    # Use only the first NUM_VIDEOS valid videos for quick training
    video_ids_for_training = valid_video_ids[:NUM_VIDEOS]
    print(f"\n⚡ QUICK TRAINING MODE")
    print(f"Using {len(video_ids_for_training)} videos out of {len(valid_video_ids)} available")
    print(f"Videos: {video_ids_for_training}")
else:
    # Use all valid videos
    video_ids_for_training = valid_video_ids
    print(f"\n📚 FULL TRAINING MODE")
    print(f"Using all {len(video_ids_for_training)} videos")
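Slicing `valid_video_ids[:NUM_VIDEOS]` takes the first N IDs in file order, which can be unrepresentative if the CSV happens to be sorted (for example by date or campaign). The hypothetical `select_subset` helper sketched below keeps that default and optionally draws a reproducible random sample instead; it is an illustration, not part of `src.training`.

```python
import random


def select_subset(video_ids, num_videos, seed=None):
    """Return at most `num_videos` IDs from `video_ids` (hypothetical helper).

    seed=None keeps the notebook's behaviour (first N in file order);
    an integer seed draws a reproducible random sample instead.
    """
    ids = list(video_ids)
    n = min(num_videos, len(ids))  # never ask for more videos than exist
    if seed is None:
        return ids[:n]
    return random.Random(seed).sample(ids, n)


ids = [f"vid_{i:03d}" for i in range(50)]
print(select_subset(ids, 5))           # the first five IDs
print(select_subset(ids, 5, seed=42))  # five random IDs, identical on every run
```

Using a private `random.Random(seed)` instance avoids disturbing the global random state that other cells may rely on.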

---
## 3. Dataset Preparation <a name="data-prep"></a>

### Split into Train/Validation Sets

In [None]:
# Split videos into train/val (80/20)
train_videos, val_videos = split_train_val_videos(
    video_ids=video_ids_for_training,
    val_ratio=0.2,
    random_seed=42
)

print(f"\nTrain videos ({len(train_videos)}): {train_videos}")
print(f"Val videos ({len(val_videos)}): {val_videos}")
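Splitting by video rather than by scene matters because scenes from the same ad share visuals, keywords and style; if they landed on both sides, validation scores would look better than they should. Below is a minimal sketch of a seeded video-level split, assuming (not confirming) that `split_train_val_videos` works along these lines. The rounding rule and the at-least-one-validation-video rule are assumptions.

```python
import random


def split_by_video(video_ids, val_ratio=0.2, random_seed=42):
    """Shuffle whole videos, then cut the list into (train, val).

    Splitting IDs rather than scenes keeps every scene of a video on one
    side, so near-duplicate frames cannot leak from train into val.
    """
    ids = list(video_ids)
    random.Random(random_seed).shuffle(ids)
    # Assumption: round to nearest, but keep >= 1 val video when there are >= 2 videos
    num_val = max(1, round(len(ids) * val_ratio)) if len(ids) >= 2 else 0
    return ids[num_val:], ids[:num_val]


train_ids, val_ids = split_by_video([f"vid_{i}" for i in range(10)])
print(len(train_ids), len(val_ids))  # 8 2
```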

In [None]:
# Print detailed statistics
print_dataset_statistics(
    alignment_score_file=ALIGNMENT_SCORE_FILE,
    train_videos=train_videos,
    val_videos=val_videos
)

### Configuration

In [None]:
# Training configuration
CONFIG = {
    'data': {
        'alignment_score_file': ALIGNMENT_SCORE_FILE,
        'keywords_file': KEYWORDS_FILE,
        'screenshots_dir': SCREENSHOTS_DIR,
        'keyword_masks_dir': KEYWORD_MASKS_DIR,
        'image_size': 512,
    },
    'model': {
        'sd_model_name': 'runwayml/stable-diffusion-v1-5',
        'controlnet': {
            'control_channels': 2,  # [M_t, S_t]
            'base_channels': 64,
        },
        'use_lora': False,
    },
    'training': {
        'batch_size': 4,
        'num_workers': 4,
        'learning_rate': 1e-4,
        'num_epochs': 5 if USE_SUBSET else 10,  # Fewer epochs for subset
        'lambda_recon': 1.0,
        'lambda_lpips': 1.0,
        'lambda_bg': 0.5,
        'use_recon_loss': True,
        'gradient_accumulation_steps': 1,
        'mixed_precision': True,
        'log_wandb': False,
        'project_name': 'video-ad-manipulation',
        'output_dir': 'outputs/training_subset' if USE_SUBSET else 'outputs/training_full',
    },
}

# Create output directory
os.makedirs(CONFIG['training']['output_dir'], exist_ok=True)

# Save config
config_save_path = os.path.join(CONFIG['training']['output_dir'], 'config.yaml')
with open(config_save_path, 'w') as f:
    yaml.dump(CONFIG, f, default_flow_style=False)

print(f"\nTraining Configuration:")
print(f"  Mode: {'SUBSET' if USE_SUBSET else 'FULL'}")
print(f"  Videos: {len(train_videos)} train, {len(val_videos)} val")
print(f"  Epochs: {CONFIG['training']['num_epochs']}")
print(f"  Batch size: {CONFIG['training']['batch_size']}")
print(f"  Output: {CONFIG['training']['output_dir']}")
print(f"\nConfig saved to: {config_save_path}")
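The three `lambda_*` weights suggest that the trainer combines several loss terms. The trainer's code is not shown here, so the objective below is a plain weighted sum under stated assumptions. The term definitions, and whether a standard diffusion noise-prediction loss $\mathcal{L}_{\text{diff}}$ is also included, are not confirmed by this notebook:

$$
\mathcal{L} = \mathcal{L}_{\text{diff}} + \lambda_{\text{recon}}\,\mathcal{L}_{\text{recon}} + \lambda_{\text{lpips}}\,\mathcal{L}_{\text{LPIPS}} + \lambda_{\text{bg}}\,\mathcal{L}_{\text{bg}}
$$

With the values above ($\lambda_{\text{recon}} = \lambda_{\text{lpips}} = 1.0$, $\lambda_{\text{bg}} = 0.5$), the background term would count half as much as the reconstruction and perceptual terms, and `use_recon_loss=False` would presumably drop $\mathcal{L}_{\text{recon}}$.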

### Create Data Loaders

In [None]:
# Create data module
print("Creating data loaders...\n")

data_module = VideoSceneDataModule(
    alignment_score_file=CONFIG['data']['alignment_score_file'],
    keywords_file=CONFIG['data']['keywords_file'],
    train_videos=train_videos,
    val_videos=val_videos,
    screenshots_dir=CONFIG['data']['screenshots_dir'],
    keyword_masks_dir=CONFIG['data']['keyword_masks_dir'],
    batch_size=CONFIG['training']['batch_size'],
    num_workers=CONFIG['training']['num_workers'],
    image_size=(CONFIG['data']['image_size'], CONFIG['data']['image_size']),
)

train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()

print(f"\n✓ Data loaders created:")
print(f"  Training scenes: {len(data_module.train_dataset)}")
print(f"  Validation scenes: {len(data_module.val_dataset)}")
print(f"  Training batches per epoch: {len(train_loader)}")
print(f"  Validation batches: {len(val_loader)}")
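The batch counts printed above follow from the scene count and batch size. The arithmetic is sketched below, assuming the data module uses PyTorch's `DataLoader` default of `drop_last=False` (not confirmed here). The scene count is illustrative, and the handling of a trailing partial gradient-accumulation window is an assumption that varies between trainers.

```python
import math


def schedule(num_scenes, batch_size, num_epochs, grad_accum_steps=1, drop_last=False):
    """Batches per epoch, effective batch size and total optimizer steps."""
    if drop_last:
        batches = num_scenes // batch_size
    else:
        batches = math.ceil(num_scenes / batch_size)  # last partial batch is kept
    # Assumption: a trailing partial accumulation window still triggers a step
    steps_per_epoch = math.ceil(batches / grad_accum_steps)
    return {
        "batches_per_epoch": batches,
        "effective_batch_size": batch_size * grad_accum_steps,
        "total_optimizer_steps": steps_per_epoch * num_epochs,
    }


# Illustrative numbers only: 103 scenes with the subset-mode settings above
print(schedule(num_scenes=103, batch_size=4, num_epochs=5))
# {'batches_per_epoch': 26, 'effective_batch_size': 4, 'total_optimizer_steps': 130}
```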

### Inspect a Training Batch

In [None]:
# Get one batch
batch = next(iter(train_loader))

print("Training Batch Contents:")
print(f"  'image' shape: {batch['image'].shape}")  # [B, 3, 512, 512]
print(f"  'control' shape: {batch['control'].shape}")  # [B, 2, 512, 512]
print(f"  'keyword_mask' shape: {batch['keyword_mask'].shape}")  # [B, 1, 512, 512]
print(f"  'alignment_score': {batch['alignment_score'][:3].tolist()}...")  # Scalars
print(f"  'keyword' (text prompts): {batch['keyword'][:2]}...")
print(f"  'video_id': {batch['video_id'][:2]}...")
print(f"  'scene_number': {batch['scene_number'][:3].tolist()}...")
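Per sample, `control` stacks two channels: `M_t` (keyword mask) and `S_t` (alignment map). How `ControlTensorBuilder` actually derives `S_t` is not shown here. The sketch below assumes one simple possibility: the scene's scalar alignment score painted over the binarised keyword region. It uses nested lists to stay dependency-free, and the `threshold` parameter is hypothetical.

```python
def build_control(keyword_mask, alignment_score, threshold=0.5):
    """Return a 2 x H x W nested list [M_t, S_t] (hypothetical construction).

    M_t: binarised keyword mask (1 inside the keyword region, 0 elsewhere).
    S_t: the scene's scalar alignment score painted over the keyword region.
    """
    m_t = [[1.0 if v >= threshold else 0.0 for v in row] for row in keyword_mask]
    s_t = [[alignment_score * v for v in row] for row in m_t]
    return [m_t, s_t]


mask = [
    [0.0, 0.0, 0.0, 0.0],
    [0.0, 0.9, 0.8, 0.0],
    [0.0, 0.7, 1.0, 0.0],
    [0.0, 0.0, 0.0, 0.0],
]
control = build_control(mask, alignment_score=0.5)
print(len(control), len(control[0]), len(control[0][0]))  # 2 4 4
```

With tensors, the same idea would be a `torch.stack` of the two maps along a new channel dimension, which is how a `[B, 2, 512, 512]` batch arises after collation.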

In [None]:
# Visualize first sample in batch
sample_idx = 0
image = batch['image'][sample_idx].permute(1, 2, 0).numpy()
image = (image * 0.5 + 0.5).clip(0, 1)  # Denormalize

keyword_mask = batch['control'][sample_idx, 0].numpy()  # M_t
alignment_map = batch['control'][sample_idx, 1].numpy()  # S_t
alignment_score = batch['alignment_score'][sample_idx].item()

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
axes[0].imshow(image)
axes[0].set_title(f"Scene Image\nVideo: {batch['video_id'][sample_idx]}\nScene: {batch['scene_number'][sample_idx]}\nKeyword: '{batch['keyword'][sample_idx]}'")
axes[0].axis('off')

axes[1].imshow(keyword_mask, cmap='gray')
axes[1].set_title("Control Channel 0 (M_t)\nKeyword Mask")
axes[1].axis('off')

axes[2].imshow(alignment_map, cmap='hot')
axes[2].set_title(f"Control Channel 1 (S_t)\nAlignment Map\nScore: {alignment_score:.4f}")
axes[2].axis('off')

plt.tight_layout()
plt.show()

---
## 4. Model Fine-Tuning <a name="training"></a>

### Initialize Model

In [None]:
print("Initializing Stable Diffusion + ControlNet model...")
print("This may take a few minutes on first run (downloading pretrained weights)\n")

model = StableDiffusionControlNetWrapper(
    sd_model_name=CONFIG['model']['sd_model_name'],
    controlnet_config=CONFIG['model']['controlnet'],
    device=device,
    use_lora=CONFIG['model']['use_lora'],
)

print("✓ Model initialized successfully\n")
print(f"Model configuration:")
print(f"  SD backbone: {CONFIG['model']['sd_model_name']}")
print(f"  ControlNet input channels: {CONFIG['model']['controlnet']['control_channels']}")
print(f"  Using LoRA: {CONFIG['model']['use_lora']}")

### Training Loop

**Note:** Uncomment `trainer.train()` to start training.

In [None]:
# Initialize trainer
trainer = ControlNetTrainer(
    model=model,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    learning_rate=CONFIG['training']['learning_rate'],
    num_epochs=CONFIG['training']['num_epochs'],
    device=device,
    output_dir=CONFIG['training']['output_dir'],
    lambda_recon=CONFIG['training']['lambda_recon'],
    lambda_lpips=CONFIG['training']['lambda_lpips'],
    lambda_bg=CONFIG['training']['lambda_bg'],
    use_recon_loss=CONFIG['training']['use_recon_loss'],
    gradient_accumulation_steps=CONFIG['training']['gradient_accumulation_steps'],
    mixed_precision=CONFIG['training']['mixed_precision'],
    log_wandb=CONFIG['training']['log_wandb'],
    project_name=CONFIG['training']['project_name'],
)

print(f"\n{'='*60}")
print(f"TRAINING CONFIGURATION")
print(f"{'='*60}")
print(f"Mode: {'SUBSET' if USE_SUBSET else 'FULL'} ({len(video_ids_for_training)} videos)")
print(f"Epochs: {CONFIG['training']['num_epochs']}")
print(f"Training scenes: {len(data_module.train_dataset)}")
print(f"Batches per epoch: {len(train_loader)}")
print(f"Output directory: {CONFIG['training']['output_dir']}")
print(f"{'='*60}\n")

# Train (UNCOMMENT TO START TRAINING)
# trainer.train()

print("\n⚠️  Training not started (trainer.train() is commented out)")
print("   Uncomment the line above to start training")

---
## 5. Inference: Generate 7 Experimental Variants <a name="inference"></a>

### Variant Definitions

We generate **7 variants** for each video:

1. **baseline**: Original alignment scores (control condition)
2. **early_boost**: Boost alignment in first 33% of scenes (×1.5)
3. **middle_boost**: Boost alignment in middle 33% of scenes (×1.5)
4. **late_boost**: Boost alignment in last 33% of scenes (×1.5)
5. **full_boost**: Boost alignment in all scenes (×1.5)
6. **reduction**: Reduce alignment in middle 33% of scenes (×0.5)
7. **placebo**: Modify non-keyword regions only
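The segment-wise scaling behind these variants can be sketched as follows. This is a hypothetical illustration, not `VideoVariantGenerator`'s code. Thirds are assigned with integer arithmetic so every scene lands in exactly one segment. Clipping is optional (`clip_max` is an assumed parameter, since the valid range of alignment scores is not stated here). The placebo is assumed to keep the baseline scores because its edits target non-keyword regions.

```python
def segment_of(index, num_scenes):
    """Assign scene `index` to the 'early', 'middle' or 'late' third (integer math)."""
    if 3 * index < num_scenes:
        return "early"
    if 3 * index < 2 * num_scenes:
        return "middle"
    return "late"


def make_variants(scores, boost_alpha=1.5, reduction_alpha=0.5, clip_max=None):
    """Return {variant_name: per-scene target alignment scores} (hypothetical sketch)."""
    plan = {
        "baseline": (set(), 1.0),
        "early_boost": ({"early"}, boost_alpha),
        "middle_boost": ({"middle"}, boost_alpha),
        "late_boost": ({"late"}, boost_alpha),
        "full_boost": ({"early", "middle", "late"}, boost_alpha),
        "reduction": ({"middle"}, reduction_alpha),
    }
    n = len(scores)
    variants = {}
    for name, (segments, alpha) in plan.items():
        scaled = []
        for i, score in enumerate(scores):
            if segment_of(i, n) in segments:
                value = score * alpha
                if clip_max is not None:  # optional cap on scaled scores only
                    value = min(value, clip_max)
            else:
                value = score
            scaled.append(value)
        variants[name] = scaled
    # Placebo keeps the scores; its edits target non-keyword regions instead
    variants["placebo"] = list(scores)
    return variants


sketch_variants = make_variants([0.25, 0.5, 0.25, 0.5, 0.25, 0.5])
print(sketch_variants["early_boost"])  # [0.375, 0.75, 0.25, 0.5, 0.25, 0.5]
print(sketch_variants["reduction"])    # [0.25, 0.5, 0.125, 0.25, 0.25, 0.5]
```

With six scenes, scenes 0 and 1 are early, 2 and 3 middle, and 4 and 5 late; with seven, the early third gets the extra scene.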

### Initialize Variant Generator

In [None]:
# Scaling factors applied to scene alignment scores
BOOST_ALPHA = 1.5
REDUCTION_ALPHA = 0.5

# Initialize variant generator
variant_generator = VideoVariantGenerator(
    alignment_score_file=CONFIG['data']['alignment_score_file'],
    keywords_file=CONFIG['data']['keywords_file'],
    boost_alpha=BOOST_ALPHA,
    reduction_alpha=REDUCTION_ALPHA,
)

print("✓ Variant generator initialized")
print(f"  Boost alpha: {BOOST_ALPHA}")
print(f"  Reduction alpha: {REDUCTION_ALPHA}")

### Example: Generate Variants for Single Video

In [None]:
# Generate variants for a single video (example)
example_video_id = valid_video_ids[0]
print(f"Generating variants for video: {example_video_id}\n")

variants = variant_generator.create_all_variants_for_video(example_video_id)

print(f"\n✓ Generated {len(variants)} variants:")
for variant_name in variants.keys():
    print(f"  - {variant_name}")

In [None]:
# Compute and display statistics
stats = variant_generator.compute_variant_statistics(variants)

print("\nVariant Statistics:\n")
print(f"{'Variant':<15} {'Mean':<10} {'Std':<10} {'Min':<10} {'Max':<10} {'Scenes'}")
print("-" * 65)
for variant_name, stat in stats.items():
    print(f"{variant_name:<15} {stat['mean_alignment']:<10.4f} {stat['std_alignment']:<10.4f} {stat['min_alignment']:<10.4f} {stat['max_alignment']:<10.4f} {stat['num_scenes']}")
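`compute_variant_statistics` is not shown; a stdlib sketch producing the same keys that the table prints might look like this. Whether the real method uses the population or the sample standard deviation is unknown. The sketch assumes the population version (`statistics.pstdev`, which matches NumPy's `np.std` default).

```python
import statistics


def variant_statistics(variant_scores):
    """Per-variant summary with the keys the table above prints (hypothetical)."""
    summary = {}
    for name, scores in variant_scores.items():
        summary[name] = {
            "mean_alignment": statistics.fmean(scores),
            "std_alignment": statistics.pstdev(scores),  # population std (assumed)
            "min_alignment": min(scores),
            "max_alignment": max(scores),
            "num_scenes": len(scores),
        }
    return summary


demo_stats = variant_statistics({"baseline": [0.25, 0.75], "full_boost": [0.375, 1.125]})
for name, stat in demo_stats.items():
    print(f"{name:<15} {stat['mean_alignment']:<10.4f} {stat['std_alignment']:<10.4f} {stat['num_scenes']}")
```

Because every boosted score is multiplied by the same factor, `full_boost` scales both the mean and the standard deviation by 1.5.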

---
## 6. Visualization <a name="visualization"></a>

### Visualize Variant Comparison

In [None]:
# Get keyword for this video
keyword = variant_generator.keywords.get(str(example_video_id), "unknown")

# Visualize alignment profiles
visualize_variant_comparison(variants, example_video_id, keyword)

### Generate Variants for All Videos

This generates variant specifications (CSV files) for all valid videos.

In [None]:
# Generate variants for all videos
output_dir = 'outputs/variants_subset' if USE_SUBSET else 'outputs/variants_full'

print(f"Generating variants for all videos...")
print(f"Output directory: {output_dir}\n")

all_variants = variant_generator.generate_variants_for_all_videos(
    output_dir=output_dir
)

# Save manifest
manifest_path = os.path.join(output_dir, 'manifest.json')
variant_generator.save_variant_manifest(
    all_variants,
    output_path=manifest_path
)

print(f"\n{'='*60}")
print(f"✓ Variant generation complete!")
print(f"{'='*60}")
print(f"  Generated variants for {len(all_variants)} videos")
print(f"  Output directory: {output_dir}/")
print(f"  Manifest: {manifest_path}")
print(f"{'='*60}")

### Inspect Variant Manifest

In [None]:
# Load and display manifest
with open(manifest_path, 'r') as f:
    manifest = json.load(f)

print("Variant Generation Manifest:")
print(f"  Total videos: {manifest['num_videos']}")
print(f"  Variants per video: {manifest['num_variants_per_video']}")
print(f"  Boost alpha: {manifest['boost_alpha']}")
print(f"  Reduction alpha: {manifest['reduction_alpha']}")
print(f"\n  Variant types:")
for vtype in manifest['variant_types']:
    print(f"    - {vtype}")

print(f"\nFirst 5 videos:")
for video_id, info in list(manifest['videos'].items())[:5]:
    print(f"  {video_id}: {info['num_scenes']} scenes, keyword='{info['keyword']}'")
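Only the manifest keys read in the cell above are known from this notebook. The hypothetical `build_manifest` below assembles a dictionary with exactly those keys from per-video variant scores (the input shape is an assumption). Video IDs are stored as strings because JSON object keys are always strings. This is consistent with the visualization cell looking up keywords with `str(example_video_id)`.

```python
import json


def build_manifest(per_video, boost_alpha, reduction_alpha):
    """Assemble a manifest with the keys read above (hypothetical input shape).

    per_video maps video_id -> {"keyword": str, "variants": {name: [scores]}}.
    """
    variant_types = []
    for info in per_video.values():
        for name in info["variants"]:
            if name not in variant_types:  # keep first-seen order
                variant_types.append(name)
    return {
        "num_videos": len(per_video),
        "num_variants_per_video": len(variant_types),
        "boost_alpha": boost_alpha,
        "reduction_alpha": reduction_alpha,
        "variant_types": variant_types,
        "videos": {
            str(video_id): {  # JSON keys are strings
                "num_scenes": len(info["variants"]["baseline"]),
                "keyword": info["keyword"],
            }
            for video_id, info in per_video.items()
        },
    }


demo_manifest = build_manifest(
    {101: {"keyword": "sneakers", "variants": {"baseline": [0.4, 0.6], "placebo": [0.4, 0.6]}}},
    boost_alpha=1.5,
    reduction_alpha=0.5,
)
print(json.dumps(demo_manifest, indent=2))
```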

---
## Summary

### What We Did

1. **Validated video IDs** using alignment_score.csv as the source of truth
2. **Created a train/val split** at the video level (not the scene level)
3. **Configured quick training mode** to use a subset of videos
4. **Prepared data loaders** restricted to the train/val video IDs
5. **Generated specifications for 7 experimental variants** for all videos

### Output Files

```
outputs/
├── training_subset/          # Training outputs (subset mode)
│   ├── config.yaml
│   ├── best_model.pt
│   └── training_history.json
│
└── variants_subset/          # Variant specifications (subset mode)
    ├── manifest.json
    └── {video_id}/
        ├── baseline.csv
        ├── early_boost.csv
        ├── middle_boost.csv
        ├── late_boost.csv
        ├── full_boost.csv
        ├── reduction.csv
        ├── placebo.csv
        └── statistics.json
```
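The tree shows subset mode; in full mode the same layout goes under `training_full/` and `variants_full/`, following the `output_dir` settings above. The hypothetical checker sketched below lists any `{video_id}/` directory that is missing one of the files shown in the tree. It assumes exactly those file names.

```python
import tempfile
from pathlib import Path

VARIANT_NAMES = ("baseline", "early_boost", "middle_boost", "late_boost",
                 "full_boost", "reduction", "placebo")
EXPECTED_FILES = [f"{name}.csv" for name in VARIANT_NAMES] + ["statistics.json"]


def missing_variant_files(variants_dir):
    """Map each {video_id}/ subdirectory to the expected files it lacks."""
    missing = {}
    for video_dir in sorted(p for p in Path(variants_dir).iterdir() if p.is_dir()):
        absent = [f for f in EXPECTED_FILES if not (video_dir / f).is_file()]
        if absent:
            missing[video_dir.name] = absent
    return missing


# Demo on a throwaway directory: video 101 is complete, 102 lacks placebo.csv
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "manifest.json").write_text("{}")  # top-level file, not a video dir
    layout = {"101": EXPECTED_FILES, "102": EXPECTED_FILES[:6] + ["statistics.json"]}
    for video_id, files in layout.items():
        (Path(tmp) / video_id).mkdir()
        for name in files:
            (Path(tmp) / video_id / name).write_text("")
    demo_missing = missing_variant_files(tmp)
print(demo_missing)  # {'102': ['placebo.csv']}
```

Pointing it at `outputs/variants_subset` (or `outputs/variants_full`) after variant generation would return an empty dict if every video directory is complete.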

### Next Steps

1. **Train model**: Uncomment `trainer.train()` to start training
2. **Run inference**: Use trained model to generate edited scenes
3. **Reassemble videos**: Combine edited scenes into video files
4. **Deploy for A/B testing**: Upload variants for experimental study

### Quick vs Full Training

- **Quick mode** (`USE_SUBSET=True`, `NUM_VIDEOS=10`):
  - Fast experimentation
  - Test the pipeline
  - 5 epochs
  - ~10-20 minutes on GPU

- **Full mode** (`USE_SUBSET=False`):
  - Production training
  - All valid videos
  - 10 epochs
  - Several hours on GPU