# Complete Workflow with Attention Heatmaps

This notebook demonstrates the complete workflow for manipulating video attention-keyword alignment using **actual attention heatmaps**.

## Workflow Overview

1. **Data Validation**: Validate scene images, keyword masks, and attention heatmaps
2. **Data Loading**: Load validated scenes for training/inference
3. **Model Training**: Fine-tune ControlNet to manipulate scenes based on attention heatmaps
4. **Variant Generation**: Create 7 experimental variants with modified attention patterns
5. **Inference**: Generate edited scenes using the fine-tuned model
6. **Video Assembly**: Reassemble edited scenes into videos

## Directory Structure

```
data/
├── video_scene_cuts/          # Scene images
│   └── {video_id}/
│       └── {video_id}-Scene-0xx-01.jpg
│
├── keyword_masks/             # Keyword masks from CLIPSeg
│   └── {video_id}/
│       ├── scene_1.png
│       └── scene_2.png
│
├── attention_heatmap/         # Attention heatmaps
│   └── {video_id}/
│       ├── {video_id}-Scene-001.jpg  # NO -01 suffix!
│       └── {video_id}-Scene-002.jpg
│
├── keywords.csv               # Video keywords
│
└── valid_scenes.csv           # Generated by data_validation.ipynb
```
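
The three directories use different naming schemes, which makes path mismatches easy to introduce. The sketch below builds the expected paths for one scene. It assumes three-digit zero padding for scene images and heatmaps (suggested by `Scene-0xx` and `Scene-001` above) and an unpadded index for masks; `scene_paths` is a hypothetical helper, not part of the framework.

```python
from pathlib import Path


def scene_paths(data_root, video_id, scene_number):
    """Return the expected image, mask, and heatmap paths for one scene.

    Assumes three-digit zero padding for scene images and heatmaps and an
    unpadded index for masks; adjust if your files are named differently.
    """
    root = Path(data_root)
    padded = f"{scene_number:03d}"
    return {
        # Scene images carry a trailing "-01" before the extension.
        "image": root / "video_scene_cuts" / video_id / f"{video_id}-Scene-{padded}-01.jpg",
        # Masks are indexed by the plain scene number.
        "mask": root / "keyword_masks" / video_id / f"scene_{scene_number}.png",
        # Heatmaps use the same stem as the image but without "-01".
        "heatmap": root / "attention_heatmap" / video_id / f"{video_id}-Scene-{padded}.jpg",
    }


paths = scene_paths("data", "abc123", 2)
for kind, path in paths.items():
    print(f"{kind:8s} {path}")
```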

## Step 1: Setup

In [None]:
import os
import sys
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
from tqdm.auto import tqdm

# Add the project root (current working directory) to sys.path so `src` is importable
sys.path.insert(0, str(Path.cwd()))

# Import framework modules
from src.training.dataset_v3 import VideoSceneDatasetV3, VideoSceneDataModuleV3
from src.video_editing.experimental_variants_v3 import VideoVariantGeneratorV3, visualize_variant_comparison
from src.data_preparation import ControlTensorBuilder

# Check GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## Step 2: Data Validation

**IMPORTANT**: Before running this notebook, you must run `data_validation.ipynb` to generate `data/valid_scenes.csv`.

The data validation notebook:
- Scans scene images, keyword masks, and attention heatmaps
- Validates that all required files exist
- Creates `data/valid_scenes.csv` with only valid scenes

This prevents repeated file existence checks during training/inference.
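
A rough sketch of the kind of check the validation notebook performs is shown below. It is not the code in `data_validation.ipynb`: the filename pattern, the `find_valid_scenes` helper, and the output fields are assumptions based on the directory layout described earlier.

```python
import re
import tempfile
from pathlib import Path

# Assumed filename pattern for scene images: {video_id}-Scene-NNN-01.jpg
SCENE_IMAGE_RE = re.compile(r"^(?P<video_id>.+)-Scene-(?P<num>\d+)-01\.jpg$")


def find_valid_scenes(data_root):
    """Return one dict per scene whose image, mask, and heatmap all exist."""
    root = Path(data_root)
    rows = []
    for image_path in sorted((root / "video_scene_cuts").glob("*/*.jpg")):
        match = SCENE_IMAGE_RE.match(image_path.name)
        if match is None:
            continue  # skip files that do not follow the naming scheme
        video_id = match["video_id"]
        scene_number = int(match["num"])
        mask_path = root / "keyword_masks" / video_id / f"scene_{scene_number}.png"
        heatmap_path = (
            root / "attention_heatmap" / video_id / f"{video_id}-Scene-{match['num']}.jpg"
        )
        if mask_path.exists() and heatmap_path.exists():
            rows.append({"video_id": video_id, "scene_number": scene_number})
    return rows


# Small self-contained demo: scene 2 has no heatmap, so only scene 1 is valid.
with tempfile.TemporaryDirectory() as tmp:
    demo_root = Path(tmp)
    for rel in [
        "video_scene_cuts/vid1/vid1-Scene-001-01.jpg",
        "video_scene_cuts/vid1/vid1-Scene-002-01.jpg",
        "keyword_masks/vid1/scene_1.png",
        "keyword_masks/vid1/scene_2.png",
        "attention_heatmap/vid1/vid1-Scene-001.jpg",
    ]:
        file_path = demo_root / rel
        file_path.parent.mkdir(parents=True, exist_ok=True)
        file_path.touch()
    valid = find_valid_scenes(demo_root)

print(valid)
```

Writing the resulting rows to a CSV once means the dataset class can trust every listed scene instead of probing the filesystem for each sample.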

In [None]:
# Check if valid_scenes.csv exists
VALID_SCENES_FILE = 'data/valid_scenes.csv'

if not os.path.exists(VALID_SCENES_FILE):
    print("\n" + "="*60)
    print("⚠ ERROR: valid_scenes.csv not found!")
    print("="*60)
    print("\nPlease run the data_validation.ipynb notebook first to:")
    print("  1. Validate scene images, keyword masks, and attention heatmaps")
    print("  2. Generate data/valid_scenes.csv")
    print("\nThen come back to this notebook.")
    print("="*60)
else:
    # Load and inspect valid scenes
    valid_scenes_df = pd.read_csv(VALID_SCENES_FILE)
    print("\n✓ Found valid_scenes.csv")
    print(f"\nTotal valid scenes: {len(valid_scenes_df)}")
    print(f"Unique videos: {valid_scenes_df['video_id'].nunique()}")
    print(f"\nColumns: {list(valid_scenes_df.columns)}")
    print(f"\nFirst 3 rows:")
    print(valid_scenes_df.head(3))

## Step 3: Data Loading

Load the validated scenes into PyTorch `DataLoader`s for training and validation.

In [None]:
# Create data module
data_module = VideoSceneDataModuleV3(
    valid_scenes_file=VALID_SCENES_FILE,
    train_val_split=0.8,
    batch_size=4,
    num_workers=4,
    image_size=(512, 512),
    normalize_heatmap=True,
)

train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()

print(f"\nDataset Statistics:")
print(f"  Training samples: {len(data_module.train_dataset)}")
print(f"  Validation samples: {len(data_module.val_dataset)}")
print(f"  Batches per epoch: {len(train_loader)}")
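
The notebook does not show how `train_val_split=0.8` is applied. Since scenes from the same video tend to look alike, one option is to split at the video level so that no video contributes to both sets. The sketch below illustrates that idea with a hypothetical `split_by_video` helper; `VideoSceneDataModuleV3` may do something different.

```python
import random


def split_by_video(video_ids, train_fraction=0.8, seed=0):
    """Split unique video IDs into disjoint train and validation sets."""
    unique_ids = sorted(set(video_ids))
    rng = random.Random(seed)  # local RNG keeps the split reproducible
    rng.shuffle(unique_ids)
    n_train = max(1, int(round(train_fraction * len(unique_ids))))
    return set(unique_ids[:n_train]), set(unique_ids[n_train:])


# One entry per scene; videos "a" and "c" each contribute two scenes.
scene_video_ids = ["a", "a", "b", "c", "c", "d", "e"]
train_ids, val_ids = split_by_video(scene_video_ids)
print(f"train videos: {sorted(train_ids)}")
print(f"val videos:   {sorted(val_ids)}")
```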

### Inspect a Training Batch

In [None]:
# Get one batch
batch = next(iter(train_loader))

print("Training Batch Contents:")
print(f"  'image' shape: {batch['image'].shape}")  # [B, 3, H, W]
print(f"  'control' shape: {batch['control'].shape}")  # [B, 2, H, W]
print(f"  'keyword_mask' shape: {batch['keyword_mask'].shape}")  # [B, 1, H, W]
print(f"  'attention_heatmap' shape: {batch['attention_heatmap'].shape}")  # [B, 1, H, W]
print(f"  'alignment_score': {batch['alignment_score'][:3].tolist()}...")  # Scalars
print(f"  'keyword' (text prompts): {batch['keyword'][:2]}...")
print(f"  'video_id': {batch['video_id'][:2]}...")
print(f"  'scene_number': {batch['scene_number'][:3].tolist()}...")

### Visualize Control Tensors

In [None]:
# Visualize first sample in batch
sample_idx = 0
image = batch['image'][sample_idx].permute(1, 2, 0).numpy()
image = (image * 0.5 + 0.5).clip(0, 1)  # Denormalize from [-1, 1] to [0, 1]

keyword_mask = batch['control'][sample_idx, 0].numpy()  # M_t
attention_heatmap = batch['control'][sample_idx, 1].numpy()  # A_t
alignment_score = batch['alignment_score'][sample_idx].item()

fig, axes = plt.subplots(2, 2, figsize=(12, 12))

# Scene image
axes[0, 0].imshow(image)
axes[0, 0].set_title(f"Scene Image\nVideo: {batch['video_id'][sample_idx]}\nScene: {int(batch['scene_number'][sample_idx])}\nKeyword: '{batch['keyword'][sample_idx]}'")
axes[0, 0].axis('off')

# Keyword mask
axes[0, 1].imshow(keyword_mask, cmap='gray')
axes[0, 1].set_title("Control Channel 0: M_t\nKeyword Mask (Binary)")
axes[0, 1].axis('off')

# Attention heatmap
axes[1, 0].imshow(attention_heatmap, cmap='hot')
axes[1, 0].set_title("Control Channel 1: A_t\nAttention Heatmap")
axes[1, 0].axis('off')

# Alignment visualization
alignment_map = keyword_mask * attention_heatmap
axes[1, 1].imshow(alignment_map, cmap='hot')
axes[1, 1].set_title(f"Alignment Map (M_t ⊙ A_t)\nAlignment Score: {alignment_score:.4f}")
axes[1, 1].axis('off')

plt.tight_layout()
plt.show()

print(f"\nAlignment Score: {alignment_score:.4f}")
print("  = Mean of (keyword_mask * attention_heatmap) over all pixels")
print("  Higher values indicate more attention inside the keyword region")
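
The mean of `keyword_mask * attention_heatmap` and the share of attention that falls on the keyword region are related but not identical: the first divides by the number of pixels, the second by the total attention mass. The sketch below computes both on a toy 2×2 example. The function names are hypothetical, and which definition the dataset uses for `alignment_score` should be checked in `dataset_v3.py`.

```python
import numpy as np


def mean_alignment(keyword_mask, attention_heatmap):
    """Mean of M_t * A_t over all pixels."""
    return float(np.mean(keyword_mask * attention_heatmap))


def attention_share(keyword_mask, attention_heatmap, eps=1e-8):
    """Fraction of the total attention mass that lies inside the keyword mask."""
    total = float(np.sum(attention_heatmap))
    return float(np.sum(keyword_mask * attention_heatmap)) / (total + eps)


# Toy example: the keyword covers one of four pixels and draws most attention.
mask = np.array([[1.0, 0.0], [0.0, 0.0]])
heat = np.array([[0.8, 0.2], [0.0, 0.0]])

print(f"mean alignment:  {mean_alignment(mask, heat):.3f}")
print(f"attention share: {attention_share(mask, heat):.3f}")
```

The mean shrinks as the keyword region gets smaller relative to the frame, while the share does not, so the two can rank scenes differently.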

## Step 4: Model Training

Fine-tune ControlNet to manipulate scenes based on attention heatmaps.

**Goal**: Learn to increase/decrease attention-keyword alignment in scenes.

In [None]:
from src.models import StableDiffusionControlNetWrapper
from src.training import ControlNetTrainer

# Model configuration
MODEL_CONFIG = {
    'sd_model_name': 'runwayml/stable-diffusion-v1-5',
    'controlnet': {
        'control_channels': 2,  # [M_t, A_t]
        'base_channels': 64,
    },
    'use_lora': False,
}

# Training configuration
TRAINING_CONFIG = {
    'batch_size': 4,
    'learning_rate': 1e-4,
    'num_epochs': 10,
    'lambda_recon': 1.0,
    'lambda_lpips': 1.0,
    'lambda_bg': 0.5,
    'use_recon_loss': True,
    'gradient_accumulation_steps': 1,
    'mixed_precision': True,
    'log_wandb': False,
    'output_dir': 'outputs/training_v3',
}

print("Model Configuration:")
print(f"  SD backbone: {MODEL_CONFIG['sd_model_name']}")
print(f"  ControlNet input channels: {MODEL_CONFIG['controlnet']['control_channels']} [keyword_mask, attention_heatmap]")
print(f"  Using LoRA: {MODEL_CONFIG['use_lora']}")
print(f"\nTraining Configuration:")
print(f"  Epochs: {TRAINING_CONFIG['num_epochs']}")
print(f"  Learning rate: {TRAINING_CONFIG['learning_rate']}")
print(f"  Batch size: {TRAINING_CONFIG['batch_size']}")
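
The `lambda_recon`, `lambda_lpips`, and `lambda_bg` entries presumably weight separate loss terms, but the trainer's loss is not shown in this notebook. The NumPy sketch below illustrates one plausible reading: a background term that penalizes changes outside the keyword mask, combined with the other terms as a weighted sum. Both functions are hypothetical, and the reconstruction and LPIPS values are placeholders rather than computed quantities.

```python
import numpy as np


def background_l1(pred, target, keyword_mask):
    """Mean absolute error over pixels outside the keyword mask."""
    background = 1.0 - keyword_mask            # 1 where the keyword is absent
    denom = max(float(background.sum()), 1.0)  # avoid dividing by zero
    return float((np.abs(pred - target) * background).sum() / denom)


def total_loss(recon, lpips, bg, lambda_recon=1.0, lambda_lpips=1.0, lambda_bg=0.5):
    """Weighted sum of the three terms, using the weights from TRAINING_CONFIG."""
    return lambda_recon * recon + lambda_lpips * lpips + lambda_bg * bg


pred = np.array([[0.5, 0.2], [0.1, 0.9]])
target = np.array([[0.0, 0.2], [0.3, 0.9]])
mask = np.array([[1.0, 0.0], [0.0, 0.0]])

# The large change at the keyword pixel (top-left) is ignored by the background term.
bg = background_l1(pred, target, mask)
loss = total_loss(recon=0.10, lpips=0.20, bg=bg)  # placeholder recon/LPIPS values
print(f"background L1: {bg:.4f}, total loss: {loss:.4f}")
```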

In [None]:
# Initialize model
print("Initializing Stable Diffusion + ControlNet model...")
print("This may take a few minutes on first run (downloading pretrained weights)\n")

model = StableDiffusionControlNetWrapper(
    sd_model_name=MODEL_CONFIG['sd_model_name'],
    controlnet_config=MODEL_CONFIG['controlnet'],
    device=device,
    use_lora=MODEL_CONFIG['use_lora'],
)

print("✓ Model initialized successfully")

In [None]:
# Initialize trainer
os.makedirs(TRAINING_CONFIG['output_dir'], exist_ok=True)

trainer = ControlNetTrainer(
    model=model,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    learning_rate=TRAINING_CONFIG['learning_rate'],
    num_epochs=TRAINING_CONFIG['num_epochs'],
    device=device,
    output_dir=TRAINING_CONFIG['output_dir'],
    lambda_recon=TRAINING_CONFIG['lambda_recon'],
    lambda_lpips=TRAINING_CONFIG['lambda_lpips'],
    lambda_bg=TRAINING_CONFIG['lambda_bg'],
    use_recon_loss=TRAINING_CONFIG['use_recon_loss'],
    gradient_accumulation_steps=TRAINING_CONFIG['gradient_accumulation_steps'],
    mixed_precision=TRAINING_CONFIG['mixed_precision'],
    log_wandb=TRAINING_CONFIG['log_wandb'],
)

print(f"\nTrainer initialized for {TRAINING_CONFIG['num_epochs']} epochs.\n")

# Train (uncomment to start training)
# trainer.train()

print("\n" + "="*50)
print("To start training, uncomment: trainer.train()")
print(f"Best model will be saved to: {os.path.join(TRAINING_CONFIG['output_dir'], 'best_model.pt')}")
print("="*50)

## Step 5: Generate Experimental Variants

Create 7 experimental variants with modified attention patterns:

1. **baseline**: Original attention heatmaps
2. **early_boost**: Boost alignment in first 33% of scenes (×1.5)
3. **middle_boost**: Boost alignment in middle 33% (×1.5)
4. **late_boost**: Boost alignment in last 33% (×1.5)
5. **full_boost**: Boost alignment in all scenes (×1.5)
6. **reduction**: Reduce alignment in middle 33% (×0.5)
7. **placebo**: Modify non-keyword regions only
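
To make the variant definitions concrete, the sketch below rescales attention inside the keyword mask for a chosen third of the scenes. It is not the `VideoVariantGeneratorV3` implementation: how thirds are rounded, the clipping to [0, 1], and all function names are assumptions.

```python
import numpy as np


def segment_indices(num_scenes, segment):
    """Scene indices for 'early', 'middle', 'late', or 'full' (assumed thirds)."""
    third = num_scenes // 3
    bounds = {
        "early": (0, third),
        "middle": (third, num_scenes - third),
        "late": (num_scenes - third, num_scenes),
        "full": (0, num_scenes),
    }
    start, stop = bounds[segment]
    return range(start, stop)


def scale_keyword_attention(heatmap, keyword_mask, alpha):
    """Scale attention inside the keyword mask by alpha and clip to [0, 1]."""
    scaled = np.where(keyword_mask > 0.5, heatmap * alpha, heatmap)
    return np.clip(scaled, 0.0, 1.0)


def make_variant(heatmaps, masks, segment, alpha):
    """Return a copy of heatmaps with the chosen segment rescaled."""
    out = [h.copy() for h in heatmaps]
    for i in segment_indices(len(heatmaps), segment):
        out[i] = scale_keyword_attention(heatmaps[i], masks[i], alpha)
    return out


# Six toy scenes: uniform attention of 0.4, keyword in the top-left pixel.
heatmaps = [np.full((2, 2), 0.4) for _ in range(6)]
masks = [np.array([[1.0, 0.0], [0.0, 0.0]]) for _ in range(6)]

early_boost = make_variant(heatmaps, masks, "early", alpha=1.5)
reduction = make_variant(heatmaps, masks, "middle", alpha=0.5)
print([round(float(h[0, 0]), 2) for h in early_boost])
print([round(float(h[0, 0]), 2) for h in reduction])
```

Under this reading, a placebo could reuse `scale_keyword_attention` with the inverted mask (`1 - keyword_mask`) so that only non-keyword pixels change.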

In [None]:
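# Initialize variant generator
BOOST_ALPHA = 1.5      # multiplier for boosted segments
REDUCTION_ALPHA = 0.5  # multiplier for the reduction variant

variant_generator = VideoVariantGeneratorV3(
    valid_scenes_file=VALID_SCENES_FILE,
    boost_alpha=BOOST_ALPHA,
    reduction_alpha=REDUCTION_ALPHA,
    image_size=(512, 512),
)

print("Variant generator initialized")
print(f"  Boost alpha: {BOOST_ALPHA}")
print(f"  Reduction alpha: {REDUCTION_ALPHA}")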

### Generate Variants for a Single Video (Example)

In [None]:
# Get a sample video
sample_video_id = valid_scenes_df['video_id'].iloc[0]
print(f"Generating variants for video: {sample_video_id}")

# Generate variants
variants = variant_generator.create_all_variants_for_video(
    sample_video_id,
    output_dir='outputs/variants_v3'
)

print(f"\n✓ Generated {len(variants)} variants:")
for variant_name in variants.keys():
    print(f"  - {variant_name}")

### Compute and Visualize Statistics

In [None]:
# Compute statistics
stats = variant_generator.compute_variant_statistics(variants)

print("\nVariant Statistics:\n")
for variant_name, stat in stats.items():
    print(f"{variant_name:15s}: mean={stat['mean_alignment']:.4f}, std={stat['std_alignment']:.4f}, scenes={stat['num_scenes']}")

In [None]:
# Get keyword for this video
keyword = variants['baseline'].iloc[0]['keyword']

# Visualize alignment profiles
visualize_variant_comparison(variants, sample_video_id, keyword, variant_generator)

### Generate Variants for All Videos

In [None]:
# Generate variants for all videos (or subset for testing)
all_variants = variant_generator.generate_variants_for_all_videos(
    output_dir='outputs/variants_v3',
    max_videos=5  # Set to None to process all videos
)

# Save manifest
variant_generator.save_variant_manifest(
    all_variants,
    output_path='outputs/variants_v3/manifest.json'
)

print(f"\n✓ Generated variants for {len(all_variants)} videos")
print(f"  Output directory: outputs/variants_v3/")
print(f"  Manifest: outputs/variants_v3/manifest.json")

## Step 6: Inference - Generate Edited Scenes

Use the fine-tuned ControlNet to generate edited scenes based on modified attention heatmaps.

**NOTE**: This requires a trained model checkpoint.

In [None]:
# Load trained model
CHECKPOINT_PATH = 'outputs/training_v3/best_model.pt'

if os.path.exists(CHECKPOINT_PATH):
    print(f"Checkpoint found: {CHECKPOINT_PATH}")
    # model.load_checkpoint(CHECKPOINT_PATH)  # Implement this in your wrapper
    print("Uncomment the line above to load the weights once load_checkpoint is implemented")
else:
    print(f"⚠ Checkpoint not found: {CHECKPOINT_PATH}")
    print("  Please train the model first (Step 4)")

In [None]:
# Example: Generate edited scene for a variant
def generate_edited_scene(
    model,
    scene_image,
    keyword_mask,
    attention_heatmap,
    keyword,
    num_inference_steps=50,
    guidance_scale=7.5
):
    """
    Generate edited scene using ControlNet.
    
    Args:
        model: StableDiffusionControlNetWrapper
        scene_image: Original scene (PIL Image); not used by this example
        keyword_mask: Keyword mask (numpy array)
        attention_heatmap: Modified attention heatmap (numpy array)
        keyword: Text prompt
        num_inference_steps: Number of denoising steps
        guidance_scale: Classifier-free guidance scale
    
    Returns:
        Edited scene (PIL Image)
    """
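    # Build the 2-channel control tensor [M_t, A_t] with shape [1, 2, H, W].
    # Both arrays must share the same (H, W); cast to float32 before stacking.
    control_tensor = np.stack(
        [keyword_mask.astype(np.float32), attention_heatmap.astype(np.float32)],
        axis=0,
    )
    control_tensor = torch.from_numpy(control_tensor).unsqueeze(0).to(model.device)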
    
    # Generate
    with torch.no_grad():
        edited_image = model.generate(
            prompt=keyword,
            control_tensor=control_tensor,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )
    
    return edited_image

print("Inference function defined")
print("To generate edited scenes, call: generate_edited_scene(...)")
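
The function above expects `keyword_mask` and `attention_heatmap` to be arrays of identical shape, and the model will likely expect the same value ranges it saw during training. The sketch below shows one way to prepare raw 8-bit inputs. The binarization threshold and the min-max scaling are assumptions; they should be matched to whatever `normalize_heatmap=True` does in the dataset.

```python
import numpy as np


def prepare_control_inputs(keyword_mask, attention_heatmap, threshold=0.5):
    """Return (mask, heatmap) as float32 arrays in [0, 1] with equal shapes."""
    mask = np.asarray(keyword_mask, dtype=np.float32)
    heat = np.asarray(attention_heatmap, dtype=np.float32)
    if mask.shape != heat.shape:
        raise ValueError(f"shape mismatch: {mask.shape} vs {heat.shape}")
    if mask.max() > 1.0:  # assume 8-bit input such as a PNG mask
        mask = mask / 255.0
    mask = (mask >= threshold).astype(np.float32)  # binarize
    lo, hi = float(heat.min()), float(heat.max())
    # Min-max scale the heatmap; a constant heatmap maps to all zeros.
    heat = (heat - lo) / (hi - lo) if hi > lo else np.zeros_like(heat)
    return mask, heat


mask_u8 = np.array([[0, 255], [255, 0]], dtype=np.uint8)
heat_u8 = np.array([[10, 110], [60, 210]], dtype=np.uint8)

mask, heat = prepare_control_inputs(mask_u8, heat_u8)
control = np.stack([mask, heat], axis=0)[None]  # shape (1, 2, H, W)
print(control.shape, control.dtype)
```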

## Summary

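This notebook covered data validation, data loading, ControlNet fine-tuning, variant generation, and inference for manipulating video attention-keyword alignment using **actual attention heatmaps**. Video assembly is listed in the workflow but has no code cell in this notebook.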

### Key Differences from Previous Approach

**Old Approach** (alignment_score.csv):
- Used scalar alignment scores
- Control tensor: `[keyword_mask, keyword_mask * scalar]`
- Limited spatial information

**New Approach** (attention heatmaps):
- Uses actual attention heatmaps
- Control tensor: `[keyword_mask, attention_heatmap]`
- Full spatial information about where viewers look
- More precise manipulation of attention patterns
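
The difference between the two control tensors is easiest to see on a toy example. In the sketch below, the scalar value and the 2×2 arrays are made up for illustration.

```python
import numpy as np

keyword_mask = np.array([[1.0, 1.0], [0.0, 0.0]], dtype=np.float32)
attention_heatmap = np.array([[0.9, 0.1], [0.3, 0.0]], dtype=np.float32)
alignment_scalar = 0.25  # hypothetical per-scene score

# Old: the second channel is the mask scaled by one number per scene.
old_control = np.stack([keyword_mask, keyword_mask * alignment_scalar])

# New: the second channel is the full heatmap.
new_control = np.stack([keyword_mask, attention_heatmap])

print("old channel 1:\n", old_control[1])
print("new channel 1:\n", new_control[1])
```

In the old tensor the second channel is flat inside the mask and zero elsewhere, so it cannot express that one keyword pixel draws far more attention than another, or that attention also falls outside the keyword region.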

### Workflow Summary

1. **Data Validation** (`data_validation.ipynb`)
   - Validates scene images, keyword masks, attention heatmaps
   - Creates `data/valid_scenes.csv`

2. **Data Loading** (`dataset_v3.py`)
   - Loads validated scenes efficiently
   - Constructs control tensors: `[M_t, A_t]`

3. **Model Training** (`trainer.py`)
   - Fine-tunes ControlNet to manipulate attention-keyword alignment
   - Preserves background regions

4. **Variant Generation** (`experimental_variants_v3.py`)
   - Creates 7 variants with modified attention heatmaps
   - Boosts/reduces attention on keyword regions

5. **Inference**
   - Generates edited scenes using fine-tuned model
   - Applies modified attention heatmaps

6. **Video Assembly**
   - Reassembles edited scenes into videos
   - Applies temporal smoothing for consistency
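
The notebook does not specify which temporal smoothing is used or what it is applied to. As one simple possibility, the sketch below applies a centered moving average along the time axis of a frame (or heatmap) sequence; `smooth_over_time` and the window size are assumptions.

```python
import numpy as np


def smooth_over_time(frames, window=3):
    """Centered moving average along the time axis.

    frames: array of shape [T, H, W] or [T, H, W, C] with float values.
    The window is truncated at the start and end of the sequence.
    """
    frames = np.asarray(frames, dtype=np.float32)
    half = window // 2
    smoothed = np.empty_like(frames)
    for t in range(len(frames)):
        start = max(0, t - half)
        stop = min(len(frames), t + half + 1)
        smoothed[t] = frames[start:stop].mean(axis=0)
    return smoothed


# Four 1x1 "frames" that flicker between 0 and 1.
frames = np.array([0.0, 1.0, 0.0, 1.0], dtype=np.float32).reshape(4, 1, 1)
smoothed = smooth_over_time(frames)
print(smoothed.ravel())
```

Averaging across a hard scene cut blends unrelated content, so a real pipeline would likely restrict the window to frames within the same scene.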

### Next Steps

1. Run `data_validation.ipynb` to generate `valid_scenes.csv`
2. Train ControlNet model (Step 4)
3. Generate experimental variants (Step 5)
4. Run inference to create edited scenes (Step 6)
5. Assemble videos and deploy for A/B testing