# ControlNet Fine-Tuning and Inference for Video Ad Manipulation

This notebook demonstrates:
1. **Data format requirements** - Exact structure needed for training
2. **Fine-tuning process** - Training the ControlNet adapter
3. **Inference** - Generating manipulated video variants

---

## Table of Contents
1. [Setup](#setup)
2. [Data Format Specification](#data-format)
3. [Data Preparation](#data-prep)
4. [Model Fine-Tuning](#training)
5. [Inference](#inference)
6. [Visualization](#visualization)

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
%cd drive/MyDrive/meaning_alignment_tiktok/

/content/drive/MyDrive/meaning_alignment_tiktok


---
## 1. Setup <a name="setup"></a>

In [3]:
import os
import sys
import json
import yaml
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
from tqdm.auto import tqdm
from torch.utils.data import DataLoader

# Add src to path
sys.path.insert(0, str(Path.cwd()))

# Import framework modules
from src.models import StableDiffusionControlNetWrapper
from src.training import ControlNetTrainer, VideoAdDataModule
from src.data_preparation import ControlTensorBuilder
from src.video_editing import VideoEditor

# Check GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

Using device: cpu


---
## 2. Data Format Specification <a name="data-format"></a>

### Required Directory Structure

```
data/
├── frames/                          # Original video frames (RGB)
│   ├── video_001/
│   │   ├── frame_00000.png         # PNG format, any resolution
│   │   ├── frame_00001.png
│   │   └── ...
│   ├── video_002/
│   └── ...
│
├── attention_heatmaps/              # Attention heatmaps from LLaVA
│   ├── video_001/
│   │   ├── frame_00000.png         # Grayscale, [0-255], same size as frames
│   │   ├── frame_00001.png
│   │   └── ...
│   ├── video_002/
│   └── ...
│
├── keyword_heatmaps/                # Keyword heatmaps from CLIPSeg
│   ├── video_001/
│   │   ├── frame_00000.png         # Grayscale, [0-255], same size as frames
│   │   ├── frame_00001.png
│   │   └── ...
│   ├── video_002/
│   └── ...
│
└── keywords.json                    # Mapping video_id -> keyword text
```

### File Format Details

#### 1. **frames/** (Original RGB frames)
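- **Format**: PNG (the verification and visualization cells in this notebook only read `.png` files; other PIL-readable formats such as JPEG would require adjusting them)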
- **Resolution**: Any (will be resized to 512×512 during training)
- **Channels**: 3 (RGB)
- **Naming**: `frame_{index:05d}.png` (e.g., frame_00000.png, frame_00001.png)
- **Values**: [0, 255] uint8

#### 2. **attention_heatmaps/** (Attention from LLaVA)
- **Format**: PNG (grayscale)
- **Resolution**: MUST match corresponding frame
- **Channels**: 1 (grayscale)
- **Naming**: Same as frames (frame_00000.png, etc.)
- **Values**: [0, 255] uint8, where:
  - 0 = minimum attention (black)
  - 255 = maximum attention (white)
- **Semantics**: Higher values = higher viewer attention/importance

#### 3. **keyword_heatmaps/** (Keywords from CLIPSeg)
- **Format**: PNG (grayscale)
- **Resolution**: MUST match corresponding frame
- **Channels**: 1 (grayscale)
- **Naming**: Same as frames (frame_00000.png, etc.)
- **Values**: [0, 255] uint8, where:
  - 0 = not keyword region (black)
  - 255 = keyword region (white)
- **Semantics**: Probability/confidence that pixel contains the product/keyword

#### 4. **keywords.json**
```json
{
  "video_001": "jewelry",
  "video_002": "running shoes",
  "video_003": "lipstick",
  "video_004": "smartphone"
}
```
- **Format**: JSON object
- **Keys**: Video ID (must match folder names)
- **Values**: Keyword text (product name/description)
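The resolution and channel rules above can be checked per frame before training. The sketch below is a hypothetical helper (`check_frame_triplet` is not part of `src/`) that uses only PIL. It checks the rules stated in this specification: the frame is RGB, both heatmaps are 8-bit single-channel (PIL mode `"L"`), and both heatmaps have the same resolution as the frame.

```python
from PIL import Image


def check_frame_triplet(frame_path, attn_path, kw_path):
    """Return a list of spec violations for one frame and its two heatmaps.

    Hypothetical helper: checks only the rules in the data format spec
    (RGB frame, 8-bit grayscale heatmaps, heatmap size == frame size).
    An empty list means the triplet conforms.
    """
    problems = []
    with Image.open(frame_path) as frame:
        frame_size = frame.size  # (width, height)
        if frame.mode != "RGB":
            problems.append(f"frame mode is {frame.mode}, expected RGB")

    for name, path in [("attention", attn_path), ("keyword", kw_path)]:
        with Image.open(path) as heatmap:
            if heatmap.size != frame_size:
                problems.append(
                    f"{name} heatmap size {heatmap.size} != frame size {frame_size}"
                )
            # Mode "L" is PIL's 8-bit single-channel grayscale
            if heatmap.mode != "L":
                problems.append(
                    f"{name} heatmap mode is {heatmap.mode}, expected L (8-bit grayscale)"
                )
    return problems
```

For example, `check_frame_triplet("data/frames/video_001/frame_00000.png", "data/attention_heatmaps/video_001/frame_00000.png", "data/keyword_heatmaps/video_001/frame_00000.png")` should return `[]` for a conforming frame.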

In [4]:
import pandas as pd

alignment_score = pd.read_csv('data/alignment_score.csv')
alignment_score.head()

| Unnamed: 0 | video id | Scene Number | attention_proportion | start_time | end_time | CTR_mean | CVR_mean | Clicks_mean | Conversion_mean | Remain_mean | contrast | brightness | industry |
|---:|---|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---|
| 0 | 7163329870906884097 | 1 | 0.057578 | 0.0 | 0.633 | 0.029925 | 0.020833 | 0.215042 | 0.076923 | 1.0 | 0.278934 | 153.610451 | Children's Apparel |
| 1 | 7163329870906884097 | 2 | 0.085758 | 0.633 | 2.067 | 0.201291 | 0.236111 | 0.585539 | 0.239316 | 0.589117 | 0.256491 | 161.464963 | Children's Apparel |
| 2 | 7163329870906884097 | 3 | 0.022061 | 2.067 | 2.7 | 0.307857 | 0.395833 | 0.541576 | 0.25641 | 0.244668 | 0.187419 | 177.34056 | Children's Apparel |
| 3 | 7163329870906884097 | 4 | 0.038521 | 2.7 | 4.167 | 0.215623 | 0.333333 | 0.292796 | 0.153846 | 0.172204 | 0.278855 | 155.625972 | Children's Apparel |
| 4 | 7163329870906884097 | 5 | 0.048467 | 4.167 | 4.867 | 0.149898 | 0.416667 | 0.126129 | 0.128205 | 0.117016 | 0.32801 | 151.592408 | Children's Apparel |


In [7]:
import pandas as pd

video_metadata = pd.read_csv('data/video_metadata.csv')
video_metadata.columns



Index(['_id', 'time', 'industry', 'video_url', 'recommend_video[0]',
       'recommend_video[1]', 'recommend_video[2]', 'recommend_video[3]',
       'recommend_video[4]', 'metric.comment',
       ...
       'Remain_keyframe[116].value', 'Remain_keyframe[117].value',
       'Remain_keyframe[118].value', 'Remain_keyframe[119].value',
       'Remain_keyframe[120].value', 'Remain_keyframe[121].value',
       'Remain_keyframe[122].value', 'Remain_keyframe[123].value',
       'Remain_percentile', 'keyword_list'],
      dtype='object', length=1340)

In [11]:
keyword_data = video_metadata[['_id', 'keyword_list[0]']]
keyword_data.to_csv('data/keywords.csv', index=False)
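The specification and every later cell read `data/keywords.json` (`KEYWORDS_FILE`) as a `{video_id: keyword}` object, but this cell writes a CSV. The following is a minimal sketch of the missing conversion using a hypothetical helper, `build_keywords_json`. It assumes that the `_id` values match the frame folder names and that the first keyword (`keyword_list[0]`) is the one to use. IDs are converted to strings because JSON object keys are strings. Empty keyword cells are skipped; pandas reads these as `NaN`.

```python
import json
import math


def build_keywords_json(pairs, out_path):
    """Write a {video_id: keyword} JSON file from (video_id, keyword) pairs.

    Hypothetical helper. Rows with a missing keyword (None or NaN) are
    skipped; IDs are stringified so they can match folder names.
    """
    mapping = {}
    for video_id, keyword in pairs:
        if keyword is None or (isinstance(keyword, float) and math.isnan(keyword)):
            continue  # no keyword for this video
        mapping[str(video_id)] = str(keyword).strip()
    with open(out_path, "w") as f:
        json.dump(mapping, f, indent=2, ensure_ascii=False)
    return mapping
```

Usage, under the assumptions above: `build_keywords_json(zip(keyword_data['_id'], keyword_data['keyword_list[0]']), 'data/keywords.json')`.

Check the `_id` dtype first. Long numeric IDs parsed as floats (for example, because of missing values) lose precision and would no longer match folder names.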

### Verify Data Structure

In [None]:
# Define paths
DATA_ROOT = "data"
FRAMES_DIR = os.path.join(DATA_ROOT, "frames")
ATTENTION_DIR = os.path.join(DATA_ROOT, "attention_heatmaps")
KEYWORD_DIR = os.path.join(DATA_ROOT, "keyword_heatmaps")
KEYWORDS_FILE = os.path.join(DATA_ROOT, "keywords.json")

def verify_data_structure():
    """
    Verify that all required data files exist and are properly formatted.
    """
    print("Verifying data structure...\n")

    # Check directories exist
    for dir_path, name in [(FRAMES_DIR, "frames"),
                            (ATTENTION_DIR, "attention_heatmaps"),
                            (KEYWORD_DIR, "keyword_heatmaps")]:
        if os.path.exists(dir_path):
            print(f"✓ {name}/ directory found")
        else:
            print(f"✗ {name}/ directory NOT FOUND")
            return False

    # Check keywords file
    if os.path.exists(KEYWORDS_FILE):
        print(f"✓ keywords.json found")
        with open(KEYWORDS_FILE, 'r') as f:
            keywords = json.load(f)
        print(f"  Found {len(keywords)} videos with keywords")
    else:
        print(f"✗ keywords.json NOT FOUND")
        return False

    # Check each video has all required files
    print("\nChecking video files...")
    all_ok = True
    for video_id in keywords.keys():
        frame_dir = os.path.join(FRAMES_DIR, video_id)
        attn_dir = os.path.join(ATTENTION_DIR, video_id)
        kw_dir = os.path.join(KEYWORD_DIR, video_id)

        if not all([os.path.exists(frame_dir), os.path.exists(attn_dir), os.path.exists(kw_dir)]):
            print(f"✗ {video_id}: Missing directories")
            all_ok = False
            continue

        # Collect frame filenames (only .png files are counted)
        frames = sorted(f for f in os.listdir(frame_dir) if f.endswith('.png'))
        attn_maps = sorted(f for f in os.listdir(attn_dir) if f.endswith('.png'))
        kw_maps = sorted(f for f in os.listdir(kw_dir) if f.endswith('.png'))

        # Filenames (not just counts) must line up across the three folders
        if frames == attn_maps == kw_maps:
            print(f"✓ {video_id}: {len(frames)} frames, keyword='{keywords[video_id]}'")
        else:
            print(f"✗ {video_id}: Mismatched files - frames:{len(frames)}, attn:{len(attn_maps)}, kw:{len(kw_maps)}")
            all_ok = False

    # Return False if any video failed, not only when a top-level path is missing
    return all_ok

# Run verification
verify_data_structure()

### Visualize Example Data

In [None]:
def visualize_example_data(video_id, frame_idx=0):
    """
    Visualize a single frame with its attention and keyword heatmaps.

    Args:
        video_id: Video ID to visualize
        frame_idx: Frame index (default: 0)
    """
    # Load data
    frame_path = os.path.join(FRAMES_DIR, video_id, f"frame_{frame_idx:05d}.png")
    attn_path = os.path.join(ATTENTION_DIR, video_id, f"frame_{frame_idx:05d}.png")
    kw_path = os.path.join(KEYWORD_DIR, video_id, f"frame_{frame_idx:05d}.png")

    frame = np.array(Image.open(frame_path).convert('RGB'))
    attn_map = np.array(Image.open(attn_path).convert('L'))
    kw_map = np.array(Image.open(kw_path).convert('L'))

    # Load keyword
    with open(KEYWORDS_FILE, 'r') as f:
        keywords = json.load(f)
    keyword = keywords[video_id]

    # Compute derived maps
    attn_norm = attn_map.astype(float) / 255.0
    kw_norm = kw_map.astype(float) / 255.0

    # M_t: Keyword mask (binarized at threshold 0.5)
    keyword_mask = (kw_norm > 0.5).astype(float)

    # S_t: Alignment map (attention × keyword)
    alignment_map = attn_norm * kw_norm
    if alignment_map.max() > 0:
        alignment_map = alignment_map / alignment_map.max()  # Normalize

    # Plot
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    # Row 1: Input data
    axes[0, 0].imshow(frame)
    axes[0, 0].set_title(f"Original Frame\n{video_id} - frame {frame_idx}", fontsize=12, fontweight='bold')
    axes[0, 0].axis('off')

    axes[0, 1].imshow(attn_map, cmap='hot')
    axes[0, 1].set_title(f"Attention Heatmap (A_t)\nFrom LLaVA", fontsize=12, fontweight='bold')
    axes[0, 1].axis('off')

    axes[0, 2].imshow(kw_map, cmap='hot')
    axes[0, 2].set_title(f"Keyword Heatmap (K_t)\nFrom CLIPSeg\nKeyword: '{keyword}'", fontsize=12, fontweight='bold')
    axes[0, 2].axis('off')

    # Row 2: Derived maps (control tensor components)
    axes[1, 0].imshow(keyword_mask, cmap='gray', vmin=0, vmax=1)
    axes[1, 0].set_title("Keyword Mask (M_t)\nControl Tensor Channel 0", fontsize=12, fontweight='bold', color='blue')
    axes[1, 0].axis('off')

    axes[1, 1].imshow(alignment_map, cmap='hot', vmin=0, vmax=1)
    axes[1, 1].set_title("Alignment Map (S_t)\nControl Tensor Channel 1\nS_t = A_t ⊙ K_t", fontsize=12, fontweight='bold', color='blue')
    axes[1, 1].axis('off')

    # Show overlay
    overlay = frame.copy()
    overlay_mask = np.stack([np.zeros_like(attn_map), np.zeros_like(attn_map), attn_map], axis=-1)
    overlay = (0.6 * overlay + 0.4 * overlay_mask).astype(np.uint8)
    axes[1, 2].imshow(overlay)
    axes[1, 2].set_title("Attention Overlay\n(Visualization only)", fontsize=12, fontweight='bold')
    axes[1, 2].axis('off')

    plt.tight_layout()
    plt.show()

    # Print stats
    print(f"\nData Statistics for {video_id}, frame {frame_idx}:")
    print(f"  Frame shape: {frame.shape}")
    print(f"  Attention range: [{attn_map.min()}, {attn_map.max()}]")
    print(f"  Keyword range: [{kw_map.min()}, {kw_map.max()}]")
    print(f"  Keyword mask coverage: {keyword_mask.mean()*100:.2f}% of frame")
    print(f"  Mean alignment (S_t): {alignment_map.mean():.4f}")

# Example: Visualize first frame of first video
with open(KEYWORDS_FILE, 'r') as f:
    keywords = json.load(f)
first_video = list(keywords.keys())[0]

visualize_example_data(first_video, frame_idx=0)

---
## 3. Data Preparation <a name="data-prep"></a>

### Build Control Tensors

Control tensors are constructed from attention and keyword heatmaps:

**C_t = [M_t, S_t]** where:
- **M_t**: Keyword mask (keyword heatmap binarized at `keyword_threshold`, 0.5 by default)
- **S_t**: Alignment map (attention ⊙ keyword, normalized so its per-frame maximum is 1)
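For illustration, the two channels can be derived with plain NumPy. This is a minimal sketch with a hypothetical function name, `build_control_tensor`. It mirrors the derivation in `visualize_example_data()`: threshold at 0.5 for M_t, elementwise product for S_t, then max-normalization. It does not resize to `image_size`, and the notebook's actual `ControlTensorBuilder` may differ in details.

```python
import numpy as np


def build_control_tensor(attn_u8, kw_u8, keyword_threshold=0.5):
    """Build C_t = [M_t, S_t] as a [2, H, W] float32 array from uint8 heatmaps.

    Hypothetical sketch; heatmaps must already share the same H x W.
    """
    attn = attn_u8.astype(np.float32) / 255.0  # A_t in [0, 1]
    kw = kw_u8.astype(np.float32) / 255.0      # K_t in [0, 1]

    keyword_mask = (kw > keyword_threshold).astype(np.float32)  # M_t

    alignment = attn * kw                                       # S_t = A_t ⊙ K_t
    peak = alignment.max()
    if peak > 0:
        alignment = alignment / peak  # per-frame max-normalization to [0, 1]

    return np.stack([keyword_mask, alignment], axis=0)  # [2, H, W]
```

Max-normalizing each frame makes S_t comparable in scale across frames, but it discards the absolute alignment level. A frame with weak overall alignment still reaches 1 at its peak.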

In [None]:
# Configuration
CONFIG = {
    'data': {
        'data_root': DATA_ROOT,
        'keywords_file': KEYWORDS_FILE,
        'image_size': 512,  # Resize all inputs to 512×512
        'include_raw_maps': False,  # Use 2-channel [M_t, S_t] instead of 4-channel
        'keyword_threshold': 0.5,  # Threshold for binarizing keyword heatmap
    },
    'model': {
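        # Mirror of the original runwayml/stable-diffusion-v1-5 weights (the runwayml repo is no longer on the Hugging Face Hub)
        'sd_model_name': 'stable-diffusion-v1-5/stable-diffusion-v1-5',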
        'controlnet': {
            'control_channels': 2,  # [M_t, S_t]
            'base_channels': 64,
        },
        'use_lora': False,  # Set to True for LoRA fine-tuning
    },
    'training': {
        'batch_size': 4,
        'num_workers': 4,
        'learning_rate': 1e-4,
        'num_epochs': 10,
        'lambda_recon': 1.0,      # Reconstruction loss weight
        'lambda_lpips': 1.0,      # Perceptual loss weight
        'lambda_bg': 0.5,         # Background preservation weight
        'use_recon_loss': True,
        'gradient_accumulation_steps': 1,
        'mixed_precision': True,
        'log_wandb': False,
        'project_name': 'video-ad-manipulation',
        'output_dir': 'outputs/training',
    },
}

# Create output directory
os.makedirs(CONFIG['training']['output_dir'], exist_ok=True)

# Save config
config_save_path = os.path.join(CONFIG['training']['output_dir'], 'config.yaml')
with open(config_save_path, 'w') as f:
    yaml.dump(CONFIG, f, default_flow_style=False)
print(f"Configuration saved to: {config_save_path}")
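The config keeps the channel count in two places: `include_raw_maps` under `data` and `control_channels` under `model.controlnet`. If they disagree, the first forward pass would most likely fail with a shape mismatch. The check below is a hypothetical sketch. It assumes, from the config comment, that `include_raw_maps=True` yields a 4-channel control tensor, presumably [M_t, S_t] plus the raw A_t and K_t.

```python
def expected_control_channels(include_raw_maps):
    # Assumption: raw maps add 2 channels to [M_t, S_t] (the "4-channel" option)
    return 4 if include_raw_maps else 2


def check_config(config):
    """Raise ValueError if control_channels disagrees with include_raw_maps."""
    include_raw = config['data']['include_raw_maps']
    expected = expected_control_channels(include_raw)
    actual = config['model']['controlnet']['control_channels']
    if actual != expected:
        raise ValueError(
            f"control_channels={actual}, but include_raw_maps={include_raw} "
            f"implies {expected} channels"
        )
```

Calling `check_config(CONFIG)` right after building the config surfaces the mismatch before any weights are downloaded.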

### Setup Data Loaders

In [None]:
# Define train/validation split
with open(KEYWORDS_FILE, 'r') as f:
    all_videos = list(json.load(f).keys())

# Split: 80% train, 20% validation
split_idx = int(0.8 * len(all_videos))
train_videos = all_videos[:split_idx]
val_videos = all_videos[split_idx:]

print(f"Training videos ({len(train_videos)}): {train_videos}")
print(f"Validation videos ({len(val_videos)}): {val_videos}")

# Create data module
data_module = VideoAdDataModule(
    data_root=CONFIG['data']['data_root'],
    keywords_file=CONFIG['data']['keywords_file'],
    train_videos=train_videos,
    val_videos=val_videos,
    batch_size=CONFIG['training']['batch_size'],
    num_workers=CONFIG['training']['num_workers'],
    image_size=CONFIG['data']['image_size'],
    include_raw_maps=CONFIG['data']['include_raw_maps'],
)

train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()

print(f"\nDataset Statistics:")
print(f"  Training samples: {len(data_module.train_dataset)}")
print(f"  Validation samples: {len(data_module.val_dataset)}")
print(f"  Batches per epoch: {len(train_loader)}")
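The split above is made by video, not by frame, so frames of one ad never leak into both sets. However, it depends on the key order in `keywords.json`. If that file is grouped (for example, by industry), the validation set can be unrepresentative. A minimal sketch of a deterministic shuffled split with a hypothetical helper, `split_videos`:

```python
import random


def split_videos(video_ids, train_frac=0.8, seed=0):
    """Shuffle video IDs with a fixed seed, then split at train_frac.

    Hypothetical helper. Sorting first makes the result independent of the
    input order; the seeded Random instance makes it reproducible.
    """
    ids = sorted(video_ids)
    random.Random(seed).shuffle(ids)
    split_idx = int(train_frac * len(ids))
    return ids[:split_idx], ids[split_idx:]
```

Usage: `train_videos, val_videos = split_videos(all_videos, seed=0)`. With very few videos, `int(0.8 * n)` can leave the validation set with one video or none (for example, n = 1), so check both lengths before building the data module.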

### Inspect a Training Batch

In [None]:
# Get one batch
batch = next(iter(train_loader))

print("Training Batch Contents:")
print(f"  'image' shape: {batch['image'].shape}")  # [B, 3, 512, 512]
print(f"  'control' shape: {batch['control'].shape}")  # [B, 2, 512, 512]
print(f"  'keyword' (text prompts): {batch['keyword'][:2]}...")  # List of strings
print(f"  'keyword_mask' shape: {batch['keyword_mask'].shape}")  # [B, 1, 512, 512]

# Visualize first sample in batch
sample_idx = 0
image = batch['image'][sample_idx].permute(1, 2, 0).numpy()  # [512, 512, 3]
image = (image * 0.5 + 0.5).clip(0, 1)  # Denormalize from [-1, 1] to [0, 1]

keyword_mask = batch['control'][sample_idx, 0].numpy()  # M_t: [512, 512]
alignment_map = batch['control'][sample_idx, 1].numpy()  # S_t: [512, 512]

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
axes[0].imshow(image)
axes[0].set_title(f"Training Image\nKeyword: '{batch['keyword'][sample_idx]}'")
axes[0].axis('off')

axes[1].imshow(keyword_mask, cmap='gray')
axes[1].set_title("Control Channel 0 (M_t)\nKeyword Mask")
axes[1].axis('off')

axes[2].imshow(alignment_map, cmap='hot')
axes[2].set_title("Control Channel 1 (S_t)\nAlignment Map")
axes[2].axis('off')

plt.tight_layout()
plt.show()

---
## 4. Model Fine-Tuning <a name="training"></a>

### Initialize Model

The model consists of:
- **Frozen**: Stable Diffusion (U-Net, VAE, text encoder)
- **Trainable**: ControlNet adapter (~50-100M parameters)

In [None]:
print("Initializing Stable Diffusion + ControlNet model...")
print("This may take a few minutes on first run (downloading pretrained weights)\n")

model = StableDiffusionControlNetWrapper(
    sd_model_name=CONFIG['model']['sd_model_name'],
    controlnet_config=CONFIG['model']['controlnet'],
    device=device,
    use_lora=CONFIG['model']['use_lora'],
)

print("✓ Model initialized successfully\n")
print(f"Model configuration:")
print(f"  SD backbone: {CONFIG['model']['sd_model_name']}")
print(f"  ControlNet input channels: {CONFIG['model']['controlnet']['control_channels']}")
print(f"  Using LoRA: {CONFIG['model']['use_lora']}")

# Count trainable parameters
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"\nParameter counts:")
print(f"  Total parameters: {total_params/1e6:.2f}M")
print(f"  Trainable parameters: {trainable_params/1e6:.2f}M ({100*trainable_params/total_params:.2f}%)")

### Training Loop

**Loss Functions:**
```
L_diff  = ||ε̂ - ε||²                              (Diffusion prediction)
L_recon = ||Î - I||₁ + λ_LPIPS·LPIPS(Î, I)        (Reconstruction quality)
L_bg    = ||(Î - I) ⊙ B||₁                        (Background preservation)

L_total = L_diff + λ_recon·L_recon + λ_bg·L_bg

where Î is the reconstructed frame, I the original frame,
and B = 1 - M_t the background (non-keyword) mask.
```
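To show how the weights combine, here is a NumPy sketch of L_total for one sample. It is illustrative only: there are no gradients, and `ControlNetTrainer` may reduce or weight terms differently. Assumptions:

- Each norm is averaged over elements rather than summed.
- B = 1 - M_t.
- LPIPS is passed in as a precomputed scalar, since computing it requires a pretrained network.

The function name `total_loss` is hypothetical.

```python
import numpy as np


def total_loss(eps_hat, eps, img_hat, img, keyword_mask, lpips_value,
               lambda_recon=1.0, lambda_lpips=1.0, lambda_bg=0.5):
    """Combine L_diff, L_recon, and L_bg as in the loss equations (sketch).

    eps_hat, eps: predicted / true noise, any matching shape.
    img_hat, img: [3, H, W] frames; keyword_mask: [1, H, W] M_t in {0, 1}.
    """
    l_diff = np.mean((eps_hat - eps) ** 2)                          # ||ε̂ - ε||² (mean)
    l_recon = np.mean(np.abs(img_hat - img)) + lambda_lpips * lpips_value
    background = 1.0 - keyword_mask                                 # B = 1 - M_t
    l_bg = np.mean(np.abs((img_hat - img) * background))            # ||(Î - I) ⊙ B||₁ (mean)
    # λ_bg is applied once, here, not inside L_bg
    return l_diff + lambda_recon * l_recon + lambda_bg * l_bg
```

Averaging L_bg over all pixels, including the masked-out keyword region, makes its magnitude shrink as the keyword region grows. Dividing by `background.sum()` instead would make it independent of keyword coverage.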

In [None]:
# Initialize trainer
trainer = ControlNetTrainer(
    model=model,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    learning_rate=CONFIG['training']['learning_rate'],
    num_epochs=CONFIG['training']['num_epochs'],
    device=device,
    output_dir=CONFIG['training']['output_dir'],
    lambda_recon=CONFIG['training']['lambda_recon'],
    lambda_lpips=CONFIG['training']['lambda_lpips'],
    lambda_bg=CONFIG['training']['lambda_bg'],
    use_recon_loss=CONFIG['training']['use_recon_loss'],
    gradient_accumulation_steps=CONFIG['training']['gradient_accumulation_steps'],
    mixed_precision=CONFIG['training']['mixed_precision'],
    log_wandb=CONFIG['training']['log_wandb'],
    project_name=CONFIG['training']['project_name'],
)

print(f"Starting training for {CONFIG['training']['num_epochs']} epochs...\n")
print("Loss weights:")
print(f"  λ_recon = {CONFIG['training']['lambda_recon']}")
print(f"  λ_LPIPS = {CONFIG['training']['lambda_lpips']}")
print(f"  λ_bg = {CONFIG['training']['lambda_bg']}")
print()

# Train
trainer.train()

print("\n" + "="*50)
print("Training completed!")
print(f"Best model saved to: {os.path.join(CONFIG['training']['output_dir'], 'best_model.pt')}")
print("="*50)

### Plot Training Curves

In [None]:
# Load training history
history_path = os.path.join(CONFIG['training']['output_dir'], 'training_history.json')
if os.path.exists(history_path):
    with open(history_path, 'r') as f:
        history = json.load(f)

    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    # Total loss
    axes[0].plot(history['train_loss'], label='Train')
    axes[0].plot(history['val_loss'], label='Validation')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Total Loss')
    axes[0].set_title('Training and Validation Loss')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Individual loss components
    axes[1].plot(history['diff_loss'], label='L_diff (diffusion)')
    axes[1].plot(history['recon_loss'], label='L_recon (reconstruction)')
    axes[1].plot(history['bg_loss'], label='L_bg (background)')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Loss')
    axes[1].set_title('Loss Components')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
else:
    print("Training history not found")

---
## 5. Inference <a name="inference"></a>

### Load Trained Model

In [None]:
# Load best trained model
model_path = os.path.join(CONFIG['training']['output_dir'], 'best_model.pt')

if os.path.exists(model_path):
    print(f"Loading trained model from: {model_path}")

    # Re-initialize model
    model = StableDiffusionControlNetWrapper(
        sd_model_name=CONFIG['model']['sd_model_name'],
        controlnet_config=CONFIG['model']['controlnet'],
        device=device,
        use_lora=CONFIG['model']['use_lora'],
    )

    # Load trained weights
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    print("✓ Model loaded successfully")
else:
    print("Trained model not found. Please complete training first.")

### Inference: Generate Experimental Variants

#### Required Inputs for Inference:

```python
inputs = {
    'image': torch.Tensor,           # Original frame [1, 3, 512, 512], normalized to [-1, 1]
    'control': torch.Tensor,         # Control tensor [1, 2, 512, 512]
                                     #   Channel 0: keyword_mask (M_t)
                                     #   Channel 1: alignment_map (S_t)
    'keyword': str,                  # Keyword text (e.g., "jewelry")
    'variant_type': str,             # One of: 'baseline', 'boost', 'reduction', 'placebo'
    'boost_alpha': float,            # Multiplication factor for boost (default: 1.5)
    'reduction_alpha': float,        # Multiplication factor for reduction (default: 0.5)
}
```

#### Expected Outputs:

```python
outputs = {
    'edited_image': torch.Tensor,    # Edited frame [1, 3, 512, 512], range [0, 1]
    'modified_control': torch.Tensor,# Modified control tensor [1, 2, 512, 512]
    'variant_name': str,             # Variant identifier
}
```
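To run inference on a frame that does not come from `val_loader`, the frame must first be brought into the input format above: `[1, 3, 512, 512]`, normalized to [-1, 1]. The following is a minimal sketch with a hypothetical helper, `prepare_frame`. It assumes bicubic resizing; the data module's own resizing method is not shown in this notebook and may differ.

```python
import numpy as np
from PIL import Image


def prepare_frame(img, size=512):
    """Turn a PIL image into a contiguous [1, 3, size, size] float32 array in [-1, 1].

    Hypothetical helper: bicubic resize, then v -> v / 127.5 - 1,
    so uint8 0 maps to -1 and 255 maps to 1.
    """
    img = img.convert("RGB").resize((size, size), Image.BICUBIC)
    arr = np.asarray(img, dtype=np.float32) / 127.5 - 1.0  # [H, W, 3]
    return np.ascontiguousarray(arr.transpose(2, 0, 1)[None])  # [1, 3, H, W]
```

Usage: `image = torch.from_numpy(prepare_frame(Image.open(path))).to(device)`. The control tensor must be resized to the same resolution, so that its pixels stay aligned with the frame.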

In [None]:
def create_variant_control_tensor(original_control, variant_type, boost_alpha=1.5, reduction_alpha=0.5):
    """
    Modify control tensor to create experimental variant.

    Args:
        original_control: [B, 2, H, W] tensor with [M_t, S_t]
        variant_type: 'baseline', 'boost', 'reduction', or 'placebo'
        boost_alpha: Multiplication factor for boosting alignment
        reduction_alpha: Multiplication factor for reducing alignment

    Returns:
        modified_control: [B, 2, H, W] tensor
    """
    modified_control = original_control.clone()

    keyword_mask = original_control[:, 0:1]  # M_t: [B, 1, H, W]
    alignment_map = original_control[:, 1:2]  # S_t: [B, 1, H, W]

    if variant_type == 'baseline':
        # No change
        pass

    elif variant_type == 'boost':
        # Increase alignment: S_t' = boost_alpha * S_t
        modified_alignment = (alignment_map * boost_alpha).clamp(0, 1)
        modified_control[:, 1:2] = modified_alignment

    elif variant_type == 'reduction':
        # Decrease alignment: S_t' = reduction_alpha * S_t
        modified_alignment = (alignment_map * reduction_alpha).clamp(0, 1)
        modified_control[:, 1:2] = modified_alignment

    elif variant_type == 'placebo':
        # Manipulate only the background (outside the keyword region):
        # boost S_t by boost_alpha where M_t = 0 and keep the keyword region unchanged
        background_mask = 1 - keyword_mask
        boosted_alignment = (alignment_map * boost_alpha).clamp(0, 1)
        modified_control[:, 1:2] = alignment_map * keyword_mask + boosted_alignment * background_mask

    else:
        raise ValueError(f"Unknown variant type: {variant_type}")

    return modified_control


@torch.no_grad()
def inference_single_frame(model, image, control, keyword, variant_type='baseline',
                          boost_alpha=1.5, reduction_alpha=0.5,
                          num_inference_steps=50, guidance_scale=7.5, strength=0.8):
    """
    Run inference on a single frame.

    Args:
        model: Trained StableDiffusionControlNetWrapper
        image: [1, 3, H, W] tensor, normalized to [-1, 1]
        control: [1, 2, H, W] tensor with [M_t, S_t]
        keyword: Text prompt (e.g., "jewelry")
        variant_type: 'baseline', 'boost', 'reduction', 'placebo'
        boost_alpha: Boost factor (default: 1.5)
        reduction_alpha: Reduction factor (default: 0.5)
        num_inference_steps: Number of diffusion steps
        guidance_scale: CFG scale
        strength: Edit strength (0.0 = no change, 1.0 = full regeneration)

    Returns:
        dict with:
            'edited_image': [1, 3, H, W] tensor in [0, 1]
            'modified_control': [1, 2, H, W] tensor
            'variant_name': str
    """
    # Create variant control tensor
    modified_control = create_variant_control_tensor(
        control, variant_type, boost_alpha, reduction_alpha
    )

    # Run model
    edited_image = model(
        image=image,
        control=modified_control,
        prompt=keyword,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        strength=strength,
    )

    return {
        'edited_image': edited_image,
        'modified_control': modified_control,
        'variant_name': variant_type,
    }

print("Inference functions defined.")
print("\nAvailable variant types:")
print("  - 'baseline': No manipulation (control)")
print("  - 'boost': Increase attention-keyword alignment")
print("  - 'reduction': Decrease attention-keyword alignment")
print("  - 'placebo': Manipulate non-keyword regions")

### Example: Generate Variants for a Single Frame

In [None]:
# Select a test sample
test_batch = next(iter(val_loader))
test_idx = 0

# Prepare inputs
original_image = test_batch['image'][test_idx:test_idx+1].to(device)  # [1, 3, 512, 512]
original_control = test_batch['control'][test_idx:test_idx+1].to(device)  # [1, 2, 512, 512]
keyword = test_batch['keyword'][test_idx]  # str

print(f"Generating variants for keyword: '{keyword}'\n")

# Generate all variants
variants = ['baseline', 'boost', 'reduction', 'placebo']
results = {}

for variant in variants:
    print(f"Generating {variant} variant...")
    result = inference_single_frame(
        model=model,
        image=original_image,
        control=original_control,
        keyword=keyword,
        variant_type=variant,
        num_inference_steps=50,
        guidance_scale=7.5,
        strength=0.8,
    )
    results[variant] = result

print("\n✓ All variants generated")

---
## 6. Visualization <a name="visualization"></a>

### Visualize All Variants

In [None]:
def tensor_to_image(tensor):
    """Convert a [1, 3, H, W] tensor normalized to [-1, 1] into an HxWx3 numpy image in [0, 1]."""
    img = tensor.squeeze(0).cpu().permute(1, 2, 0).numpy()
    img = (img * 0.5 + 0.5).clip(0, 1)  # Denormalize from [-1, 1] to [0, 1]
    return img

# Prepare original image
original_img = tensor_to_image(original_image)

# Create comparison plot
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# Original
axes[0, 0].imshow(original_img)
axes[0, 0].set_title(f"Original Frame\nKeyword: '{keyword}'", fontsize=14, fontweight='bold')
axes[0, 0].axis('off')

# Original control maps
orig_keyword_mask = original_control[0, 0].cpu().numpy()
orig_alignment = original_control[0, 1].cpu().numpy()

axes[0, 1].imshow(orig_keyword_mask, cmap='gray')
axes[0, 1].set_title("Original Keyword Mask (M_t)", fontsize=14, fontweight='bold')
axes[0, 1].axis('off')

axes[0, 2].imshow(orig_alignment, cmap='hot')
axes[0, 2].set_title("Original Alignment (S_t)", fontsize=14, fontweight='bold')
axes[0, 2].axis('off')

# Edited variants
variant_positions = {
    'baseline': (1, 0),
    'boost': (1, 1),
    'reduction': (1, 2),
}

for variant, (row, col) in variant_positions.items():
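    # edited_image is already in [0, 1] (see the output spec), so it must not go
    # through tensor_to_image(), which assumes a [-1, 1] input
    edited_img = results[variant]['edited_image'].squeeze(0).cpu().permute(1, 2, 0).numpy().clip(0, 1)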
    axes[row, col].imshow(edited_img)

    # Add title with variant details
    if variant == 'baseline':
        title = "Baseline (No Change)\nControl Condition"
    elif variant == 'boost':
        title = "Boosted Alignment\nS_t' = 1.5 × S_t"
    elif variant == 'reduction':
        title = "Reduced Alignment\nS_t' = 0.5 × S_t"

    axes[row, col].set_title(title, fontsize=14, fontweight='bold', color='blue')
    axes[row, col].axis('off')

plt.suptitle("Experimental Variants Comparison", fontsize=16, fontweight='bold', y=0.98)
plt.tight_layout()
plt.show()

### Compare Alignment Maps

In [None]:
# Plot alignment maps for all variants
fig, axes = plt.subplots(1, 4, figsize=(20, 5))

for idx, variant in enumerate(['baseline', 'boost', 'reduction', 'placebo']):
    modified_alignment = results[variant]['modified_control'][0, 1].cpu().numpy()

    im = axes[idx].imshow(modified_alignment, cmap='hot', vmin=0, vmax=1)
    axes[idx].set_title(f"{variant.capitalize()}\nAlignment Map (S_t')", fontsize=14, fontweight='bold')
    axes[idx].axis('off')

    # Add colorbar
    plt.colorbar(im, ax=axes[idx], fraction=0.046, pad=0.04)

plt.suptitle("Modified Alignment Maps (Control Signal)", fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

# Print statistics
print("\nAlignment Map Statistics:")
for variant in ['baseline', 'boost', 'reduction', 'placebo']:
    alignment = results[variant]['modified_control'][0, 1].cpu().numpy()
    print(f"  {variant:12s}: mean={alignment.mean():.4f}, max={alignment.max():.4f}, std={alignment.std():.4f}")
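When reading these statistics, note that `boost` clamps S_t' to [0, 1]. Every pixel with S_t > 1 / boost_alpha saturates at 1, so the mean of the boosted map usually grows by less than `boost_alpha`. Because S_t is max-normalized, its peak pixels always saturate. The sketch below mirrors the clamp in `create_variant_control_tensor()` and reports the saturated fraction and the realized mean gain. The helper name `boost_saturation` is hypothetical.

```python
import numpy as np


def boost_saturation(alignment, boost_alpha=1.5):
    """Return (saturated fraction, realized mean gain) for a boost variant.

    alignment: S_t as an array in [0, 1].
    Boost rule mirrored from the notebook: S_t' = clip(boost_alpha * S_t, 0, 1).
    """
    scaled = alignment * boost_alpha
    boosted = np.clip(scaled, 0.0, 1.0)
    saturated = float(np.mean(scaled > 1.0))
    base_mean = float(alignment.mean())
    gain = float(boosted.mean()) / base_mean if base_mean > 0 else float("nan")
    return saturated, gain
```

Usage: `boost_saturation(results['baseline']['modified_control'][0, 1].cpu().numpy())`. A realized gain well below `boost_alpha` means the manipulation is weaker than its nominal factor suggests.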

### Temporal Variants (for full videos)

For complete video editing with temporal windows:

In [None]:
def create_temporal_variants(video_id, model, output_dir='outputs/variants'):
    """
    Create all 7 experimental variants for a full video.

    Variants:
      1. baseline: No change
      2. early_boost: Boost frames 0-33%
      3. middle_boost: Boost frames 33-66%
      4. late_boost: Boost frames 66-100%
      5. full_boost: Boost all frames
      6. reduction: Reduce alignment in middle section
      7. placebo: Manipulate background only

    Args:
        video_id: Video identifier
        model: Trained model
        output_dir: Output directory for variants

    Returns:
        dict: Paths to generated variant videos
    """
    os.makedirs(output_dir, exist_ok=True)

    # Load all frames for this video
    frame_dir = os.path.join(FRAMES_DIR, video_id)
    frames = sorted([f for f in os.listdir(frame_dir) if f.endswith('.png')])
    num_frames = len(frames)

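    # Define temporal windows as integer thirds of the frame range
    # (int(n * 0.33) / int(n * 0.66) give uneven thirds, e.g. 2/3/4 frames for n = 9)
    early_end = num_frames // 3
    middle_start = early_end
    middle_end = (2 * num_frames) // 3
    late_start = middle_end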

    print(f"Processing video: {video_id}")
    print(f"  Total frames: {num_frames}")
    print(f"  Early window: [0, {early_end})")
    print(f"  Middle window: [{middle_start}, {middle_end})")
    print(f"  Late window: [{late_start}, {num_frames})")

    # Variant definitions
    variant_specs = {
        'baseline': {'type': 'baseline', 'frames': range(num_frames)},
        'early_boost': {'type': 'boost', 'frames': range(0, early_end)},
        'middle_boost': {'type': 'boost', 'frames': range(middle_start, middle_end)},
        'late_boost': {'type': 'boost', 'frames': range(late_start, num_frames)},
        'full_boost': {'type': 'boost', 'frames': range(num_frames)},
        'reduction': {'type': 'reduction', 'frames': range(middle_start, middle_end)},
        'placebo': {'type': 'placebo', 'frames': range(middle_start, middle_end)},
    }

    print("\nGenerating variants...")

    # Process each variant
    # (Full implementation would iterate through frames and generate edited videos)
    # This is a template - actual implementation in src/video_editing/

    return variant_specs

print("Temporal variant specification:")
print("\n1. baseline: Original video, no changes")
print("2. early_boost: Boost alignment in first third (frames 0-33%)")
print("3. middle_boost: Boost alignment in middle third (frames 33-66%)")
print("4. late_boost: Boost alignment in last third (frames 66-100%)")
print("5. full_boost: Boost alignment throughout entire video")
print("6. reduction: Reduce alignment in middle section")
print("7. placebo: Manipulate background regions only")
print("\nEach variant is saved as a separate video file for experimental use.")
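The `variant_specs` returned above list only the frames that are manipulated. Generating a variant video also needs a decision for every other frame. The following is a minimal sketch with a hypothetical helper, `frame_schedule`. It assumes that frames outside a variant's window are rendered exactly as in the `baseline` variant, so each variant video differs from baseline only inside its window.

```python
def frame_schedule(num_frames, variant_type, frames):
    """Map every frame index to the control manipulation applied to it.

    Hypothetical helper: frames in `frames` get `variant_type`;
    all other frames get 'baseline'.
    """
    window = set(frames)
    return [variant_type if i in window else 'baseline' for i in range(num_frames)]
```

For example, with `specs = create_temporal_variants(video_id, model)`, the per-frame plan for the middle boost would be `frame_schedule(num_frames, specs['middle_boost']['type'], specs['middle_boost']['frames'])`. Each entry can then be passed as `variant_type` to `inference_single_frame()`.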

---
## Summary

### Fine-Tuning Requirements

**Data:**
- Original frames (RGB)
- Attention heatmaps from LLaVA (grayscale)
- Keyword heatmaps from CLIPSeg (grayscale)
- Keywords JSON file

**Training:**
- Fine-tune ControlNet adapter only (~50-100M params)
- Freeze Stable Diffusion backbone
- Loss: L_diff + λ_recon·L_recon + λ_bg·L_bg
- Typical: 10-20 epochs on 4-8 videos

### Inference Requirements

**Inputs:**
```python
{
    'image': torch.Tensor,        # [1, 3, 512, 512], range [-1, 1]
    'control': torch.Tensor,      # [1, 2, 512, 512]
                                  #   control[0] = keyword_mask (M_t)
                                  #   control[1] = alignment_map (S_t)
    'keyword': str,               # Product/keyword text
    'variant_type': str,          # 'baseline', 'boost', 'reduction', 'placebo'
}
```

**Outputs:**
```python
{
    'edited_image': torch.Tensor,     # [1, 3, 512, 512], range [0, 1]
    'modified_control': torch.Tensor, # [1, 2, 512, 512]
    'variant_name': str,              # Variant identifier
}
```

### Output Videos

For each input video, the framework generates **7 variant videos**:
1. baseline.mp4
2. early_boost.mp4
3. middle_boost.mp4
4. late_boost.mp4
5. full_boost.mp4
6. reduction.mp4
7. placebo.mp4

These variants are used in experiments to study how attention-keyword alignment affects viewer responses.