# Mode Watermarking - Quick Start

This notebook demonstrates the complete workflow for watermarking images with Stable Diffusion, including:
- Configuration setup
- Dataset generation (watermarked/unwatermarked)
- Statistical detection
- Detector training
- Evaluation and metrics

## Overview

The mode-watermarking system embeds secret watermarks into images generated by Stable Diffusion using non-distortionary or distortionary modes. This notebook provides an interactive tutorial for using the system.


## Section 1: Setup & Installation

First, we'll set up the environment and import necessary modules.


In [None]:
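# Make `src` importable when this notebook runs from a subdirectory of the
# repository (e.g. notebooks/). The config paths used below are relative to the
# repository root, so adjust them if your working directory differs.
import sys
from pathlib import Path
sys.path.insert(0, str(Path('.').resolve().parent))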

# Standard imports
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from typing import Dict, Any, Optional

# Check device availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Import mode-watermarking modules
from src.config.config_loader import ConfigLoader
from src.sd_integration.sd_client import SDClient
from src.detection.recovery import recover_g_values
from src.detection.correlate import compute_s_statistic, batch_compute_s_statistics
from src.detection.calibrate import calibrate_thresholds
from src.evaluation.quality_metrics import compute_quality_metrics, batch_compute_quality_metrics
from src.evaluation.eval import run_full_evaluation
from src.evaluation.visualize import plot_roc_curve, plot_score_distributions

print("Imports successful!")


## Section 2: Configuration Overview

Load and examine the configuration files that control watermark embedding, diffusion parameters, and model architecture.


In [None]:
# Initialize config loader
config_loader = ConfigLoader()

# Load configurations
watermark_cfg = config_loader.load_watermark_config("configs/watermark_config.yaml")
diffusion_cfg = config_loader.load_diffusion_config("configs/diffusion_config.yaml")
model_cfg = config_loader.load_model_architecture_config("configs/model_architecture.yaml")

# Display key parameters
print("=" * 60)
print("Configuration Overview")
print("=" * 60)
print("\nWatermark Configuration:")
print(f"  Mode: {watermark_cfg.get('bias', {}).get('mode', 'N/A')}")
print(f"  Key Scheme: {watermark_cfg.get('watermark', {}).get('key_scheme', 'N/A')}")
print(f"  G-field Shape: {watermark_cfg.get('watermark', {}).get('g_field', {}).get('shape', 'N/A')}")
print(f"  Mapping Mode: {watermark_cfg.get('watermark', {}).get('g_field', {}).get('mapping_mode', 'N/A')}")

print("\nDiffusion Configuration:")
print(f"  Trained Timesteps: {diffusion_cfg.get('diffusion', {}).get('trained_timesteps', 'N/A')}")
print(f"  Inference Timesteps: {diffusion_cfg.get('diffusion', {}).get('inference_timesteps', 'N/A')}")
print(f"  Scheduler: {diffusion_cfg.get('diffusion', {}).get('scheduler', 'N/A')}")

print("\nModel Configuration:")
print(f"  Model ID: {model_cfg.get('model_id', 'N/A')}")
print(f"  Use FP16: {model_cfg.get('use_fp16', 'N/A')}")
print("=" * 60)
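
The cell above walks the nested configuration with chained `.get(..., {})` calls. A small helper can perform the same lookups by dotted path. This is a minimal sketch, assuming the loaded configs are plain nested dictionaries (as the `.get` calls suggest); `get_nested` and `example_cfg` are hypothetical and not part of `src`.

```python
from typing import Any, Mapping

def get_nested(cfg: Mapping[str, Any], dotted_key: str, default: Any = "N/A") -> Any:
    """Return cfg[a][b][c] for dotted_key "a.b.c", or default if any level is missing."""
    node: Any = cfg
    for part in dotted_key.split("."):
        # Stop as soon as we hit a non-mapping value or a missing key
        if not isinstance(node, Mapping) or part not in node:
            return default
        node = node[part]
    return node

# Hypothetical config fragment using the same key layout the cell above reads
example_cfg = {"watermark": {"key_scheme": "example", "g_field": {"shape": [4, 64, 64]}}}
print(get_nested(example_cfg, "watermark.g_field.shape"))         # [4, 64, 64]
print(get_nested(example_cfg, "watermark.g_field.mapping_mode"))  # N/A
```

With the real configs, a call such as `get_nested(watermark_cfg, "watermark.g_field.mapping_mode")` would stand in for the corresponding chained lookup.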


## Section 3: Initialize SD Client

Set up the Stable Diffusion pipeline with watermark embedding support.


In [None]:
# Prepare config paths
config_paths = {
    "diffusion": "configs/diffusion_config.yaml",
    "watermark": "configs/watermark_config.yaml",
    "model": "configs/model_architecture.yaml"
}

# Initialize SD client
print("Initializing Stable Diffusion pipeline...")
sd_client = SDClient(config_paths=config_paths, device=device)
sd_client.initialize_pipeline()

print("Pipeline initialized successfully!")
print(f"Pipeline device: {sd_client.device}")

# No manual hook registration is needed: the watermark embedding hook is
# registered automatically when sd_client.generate() is called.


## Section 4: Dataset Generation

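Generate watermarked images from a single prompt and from a batch of prompts. To build a full dataset containing both watermarked and unwatermarked images, use `scripts/generate_dataset.py --mode both`.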


In [None]:
# Generate single watermarked image
prompt = "A beautiful sunset over mountains with vibrant colors"

print(f"Generating watermarked image with prompt: '{prompt}'")
image_wm, manifest_wm = sd_client.generate(
    prompt=prompt,
    seed=42,  # For reproducibility
    num_inference_steps=None,  # Use config default
    guidance_scale=None  # Use config default
)

print(f"\nGenerated image size: {image_wm.size}")
print(f"Manifest keys: {list(manifest_wm.keys())}")
print(f"Sample ID: {manifest_wm.get('sample_id', 'N/A')}")
print(f"Mode: {manifest_wm.get('mode', 'N/A')}")

# Display image
plt.figure(figsize=(8, 8))
plt.imshow(image_wm)
plt.axis('off')
plt.title(f"Watermarked Image\nPrompt: {prompt[:50]}...")
plt.tight_layout()
plt.show()


In [None]:
# Generate batch of images
prompts = [
    "A serene lake at dawn",
    "A futuristic cityscape at night",
    "A field of wildflowers in spring"
]

print(f"Generating {len(prompts)} watermarked images...")
images_wm = []
manifests_wm = []

for i, prompt in enumerate(prompts):
    print(f"  [{i+1}/{len(prompts)}] {prompt}")
    image, manifest = sd_client.generate(
        prompt=prompt,
        seed=42 + i,
        num_inference_steps=None,
        guidance_scale=None
    )
    images_wm.append(image)
    manifests_wm.append(manifest)

print("\nBatch generation complete!")

# Display batch
fig, axes = plt.subplots(1, len(images_wm), figsize=(15, 5))
for i, (img, prompt) in enumerate(zip(images_wm, prompts)):
    axes[i].imshow(img)
    axes[i].axis('off')
    axes[i].set_title(prompt[:30] + "...", fontsize=10)
plt.tight_layout()
plt.show()


## Section 5: Statistical Detection

Recover g-values from watermarked images and compute detection statistics.


In [None]:
# Recover g-values from watermarked image
print("Recovering g-values from watermarked image...")

recovery_result = recover_g_values(
    image=image_wm,
    vae_encoder=sd_client._pipeline.vae,
    watermark_cfg=watermark_cfg,
    key_info=manifest_wm.get("key_info", {}),
    device=device
)

print(f"Recovered g-values shape: {recovery_result['g_values'].shape}")
print(f"Latent shape: {recovery_result['latent'].shape}")
print(f"Recovery metadata: {recovery_result['recovery_metadata']}")

# Visualize g-values (spatial distribution)
g_values_mean = recovery_result['g_values'].mean(axis=0)  # Average across timesteps
g_values_spatial = np.abs(g_values_mean).mean(axis=0)  # Average across channels, abs

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(g_values_spatial, cmap='hot')
plt.colorbar(label='G-value Magnitude')
plt.title('Spatial Distribution of Recovered G-values')
plt.axis('off')

plt.subplot(1, 2, 2)
plt.hist(g_values_mean.flatten(), bins=50, edgecolor='black', alpha=0.7)
plt.xlabel('G-value')
plt.ylabel('Frequency')
plt.title('G-value Distribution')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
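
The next cell rebuilds the expected g-field from the manifest's `key_info`. It sizes the key stream as C × H × W × T (one key word per latent element per timestep) and passes it to `GFieldBuilder` with `mapping_mode="binary"` and `bit_pos=30`. The notebook does not show how that mapping works internally. As a minimal sketch, assuming the binary mode reads bit `bit_pos` of each 32-bit key word and maps it to ±1, the idea looks like this. `sketch_binary_g_schedule` is hypothetical, and NumPy's PRNG stands in for `KeyDerivation.generate_key_stream`, so the values will not match the real g-field.

```python
import numpy as np

def sketch_binary_g_schedule(seed: int, num_timesteps: int, latent_shape, bit_pos: int = 30) -> np.ndarray:
    """Hypothetical binary mapping: one key word per latent element per timestep -> +/-1.

    np.random.default_rng is only a stand-in for the project's key stream,
    which is derived from seed0 by KeyDerivation and will differ.
    """
    c, h, w = latent_shape
    stream_len = c * h * w * num_timesteps  # one key word per element per timestep
    rng = np.random.default_rng(seed)
    words = rng.integers(0, 2**32, size=stream_len, dtype=np.uint64)  # 32-bit key words
    bits = np.right_shift(words, np.uint64(bit_pos)) & np.uint64(1)   # extract bit_pos
    g = 2.0 * bits.astype(np.float64) - 1.0                           # {0, 1} -> {-1, +1}
    return g.reshape(num_timesteps, c, h, w)

g_demo = sketch_binary_g_schedule(seed=1234, num_timesteps=50, latent_shape=(4, 64, 64))
print(g_demo.shape, np.unique(g_demo), f"mean={g_demo.mean():+.4f}")
```

Because the field is a deterministic function of the seed, a detector holding the same key can regenerate it exactly, and a balanced ±1 field has a mean close to zero.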


In [None]:
# Compute the S-statistic: the correlation between the recovered g-values and the
# expected g-values, which are reconstructed below from the manifest's key_info

from src.watermark.gfield import GFieldBuilder
from src.watermark.key import KeyDerivation
from src.sd_integration.timestep_mapper import TimestepMapper

# Reconstruct expected g-values from key_info
key_derivation = KeyDerivation()
key_info = manifest_wm.get("key_info", {})
experiment_id = manifest_wm.get("experiment_id", "exp_001")
sample_id = manifest_wm.get("sample_id", "unknown")
base_seed = manifest_wm.get("sample_seed", 42)
zT_hash = manifest_wm.get("zT_hash", "default_hash")

# Derive seed
key_master = key_info.get("key_master", watermark_cfg.get("watermark", {}).get("key_master", ""))
seed0 = key_derivation.derive_seed(
    key_master=key_master,
    sample_id=sample_id,
    zT_hash=zT_hash,
    base_seed=base_seed,
    experiment_id=experiment_id
)

# Build expected g-field
gfield_builder = GFieldBuilder(
    mapping_mode=key_info.get("mapping_mode", "binary"),
    bit_pos=key_info.get("bit_pos", 30)
)

timestep_mapper = TimestepMapper(
    trained_timesteps=diffusion_cfg["diffusion"]["trained_timesteps"],
    inference_timesteps=diffusion_cfg["diffusion"]["inference_timesteps"]
)

latent_shape = tuple(watermark_cfg["watermark"]["g_field"]["shape"])
stream_len = latent_shape[0] * latent_shape[1] * latent_shape[2] * len(timestep_mapper.get_all_trained_timesteps())
key_stream = key_derivation.generate_key_stream(seed0, stream_len)
g_schedule = gfield_builder.build_g_schedule(timestep_mapper, latent_shape, key_stream)

# Use first timestep's g-field as expected
expected_g = list(g_schedule.values())[0]

# Compute S-statistic
recovered_g_single = recovery_result['g_values'][0]  # First timestep
s_score = compute_s_statistic(
    recovered_g_values=recovered_g_single,
    expected_g_values=expected_g,
    mask=recovery_result['mask'],
    method="correlation"
)

print(f"S-statistic (correlation): {s_score:.4f}")
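threshold = 0.5  # Illustrative only; calibrate thresholds on held-out data in practice
print(f"Interpretation: {'Watermarked' if s_score > threshold else 'Unwatermarked'} (threshold: {threshold})")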
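
`compute_s_statistic(..., method="correlation")` is called above, but its implementation is not shown. A masked Pearson correlation is one natural reading of that method; the sketch below assumes that reading, and `masked_correlation` is a hypothetical helper rather than the project's function. It compares a weakly watermarked synthetic field with a clean one.

```python
import numpy as np

def masked_correlation(recovered, expected, mask=None) -> float:
    """Pearson correlation between two same-shaped g-fields over mask==True elements."""
    r = np.asarray(recovered, dtype=np.float64).ravel()
    e = np.asarray(expected, dtype=np.float64).ravel()
    if mask is not None:
        keep = np.asarray(mask, dtype=bool).ravel()  # assumes mask has the same shape
        r, e = r[keep], e[keep]
    r = r - r.mean()
    e = e - e.mean()
    denom = np.sqrt(np.sum(r * r) * np.sum(e * e))
    # A constant field carries no signal; report zero instead of dividing by 0
    return float(np.sum(r * e) / denom) if denom > 0 else 0.0

rng = np.random.default_rng(0)
demo_expected = rng.choice([-1.0, 1.0], size=(4, 64, 64))
demo_noise = rng.standard_normal(demo_expected.shape)
corr_wm = masked_correlation(0.3 * demo_expected + demo_noise, demo_expected)  # weak watermark
corr_clean = masked_correlation(demo_noise, demo_expected)                     # no watermark
print(f"watermarked: {corr_wm:.3f}, clean: {corr_clean:.3f}")
```

For a ±1 field plus unit-variance noise at strength 0.3, the population correlation is 0.3 / √1.09 ≈ 0.29, well below the 0.5 threshold used above, while the clean field scores near 0. A fixed constant can therefore misclassify weak watermarks, which is why the decision threshold is better taken from calibration on real score distributions.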


## Section 6: Detector Training (Optional)

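Set up training for a UNet watermark detector. This cell does not train a model: it checks whether the train/val split manifests exist and, if they do, prints the code needed to launch training.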


In [None]:
# Note: Full training requires a dataset with train/val splits
# This is a demonstration of how to set up training

print("Detector Training Setup (Demo)")
print("=" * 60)

from src.training.train import train_unet_detector
from pathlib import Path

# Check if training data exists
train_manifest = "data/splits/train.json"
val_manifest = "data/splits/val.json"

if Path(train_manifest).exists() and Path(val_manifest).exists():
    print(f"Training data found:")
    print(f"  Train: {train_manifest}")
    print(f"  Val: {val_manifest}")
    print("\nTo train, run the following code in a new cell:")
    print("""
    config_paths = {
        "train": "configs/train_config.yaml"
    }
    
    results = train_unet_detector(
        config_paths=config_paths,
        sd_pipeline=sd_client._pipeline,
        resume_from_checkpoint=None
    )
    
    print(f"Best checkpoint: {results['best_checkpoint']}")
    print(f"Final metrics: {results['final_metrics']}")
    """)
else:
    print("Training data not found.")
    print(f"Expected: {train_manifest}")
    print(f"Expected: {val_manifest}")
    print("\nGenerate dataset first using:")
    print("  python scripts/generate_dataset.py --mode both --prompts-file data/coco/prompts_train.txt")


## Section 7: Evaluation

Compute image quality metrics and, if a test split is available, run the full detection and quality evaluation.


In [None]:
# Quality metrics evaluation
print("Computing Quality Metrics")
print("=" * 60)

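# For demonstration only, this compares the watermarked image with itself, which
# gives trivial scores (identical images typically yield infinite PSNR, an SSIM of 1,
# and an LPIPS of 0). In a real evaluation, load the original unwatermarked image
# from the manifest and compare against that.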

quality_metrics = compute_quality_metrics(
    watermarked_image=image_wm,
    original_image=image_wm,  # In practice, use original unwatermarked image
    metrics=["psnr", "ssim", "lpips"],
    device=device
)

print("\nQuality Metrics:")
for metric_name, metric_value in quality_metrics.items():
    if not np.isnan(metric_value):
        print(f"  {metric_name.upper()}: {metric_value:.4f}")
    else:
        print(f"  {metric_name.upper()}: N/A (computation failed or requires batch)")

print("\nNote: For full evaluation with detection metrics and ROC curves,")
print("use run_full_evaluation() with a test manifest.")
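
PSNR, one of the metrics requested above, follows the standard definition PSNR = 10 · log10(MAX² / MSE). The sketch below implements that definition for 8-bit images (MAX = 255); it is not the code inside `compute_quality_metrics`, which may handle edge cases differently. It also shows why comparing an image with itself is uninformative: the MSE is zero, so the PSNR is unbounded.

```python
import numpy as np

def psnr(img_a, img_b, max_value: float = 255.0) -> float:
    """PSNR in dB: 10 * log10(max_value**2 / MSE) for same-shaped 8-bit images."""
    a = np.asarray(img_a, dtype=np.float64)
    b = np.asarray(img_b, dtype=np.float64)
    mse = np.mean((a - b) ** 2)
    if mse == 0:
        return float("inf")  # identical images: zero error, unbounded PSNR
    return float(10.0 * np.log10(max_value ** 2 / mse))

demo_img = np.zeros((8, 8, 3), dtype=np.uint8)
print(psnr(demo_img, demo_img))      # inf
print(psnr(demo_img, demo_img + 1))  # about 48.13 dB (MSE = 1)
```

PIL images such as `image_wm` can be passed directly, because `np.asarray` converts them to arrays.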


In [None]:
# Full evaluation (if test manifest exists)
test_manifest = "data/splits/test.json"

if Path(test_manifest).exists():
    print("Running full evaluation pipeline...")
    print("=" * 60)
    
    try:
        eval_results = run_full_evaluation(
            test_manifest=test_manifest,
            eval_config_path="configs/eval_config.yaml",
            watermark_cfg_path="configs/watermark_config.yaml",
            diffusion_cfg_path="configs/diffusion_config.yaml",
            model_arch_cfg_path="configs/model_architecture.yaml",
            sd_pipeline=sd_client._pipeline,
            output_dir="outputs/evaluation"
        )
        
        print("\nEvaluation Results Summary:")
        print(f"  Output directory: {eval_results.get('output_dir', 'N/A')}")
        
        if eval_results.get("detection"):
            detection = eval_results["detection"]
            metrics = detection.get("metrics", {})
            print(f"\nDetection Metrics:")
            for name in ("accuracy", "precision", "recall"):
                value = metrics.get(name)
                label = name.capitalize()
                print(f"  {label}: {value:.4f}" if isinstance(value, float) else f"  {label}: N/A")
            
        if eval_results.get("quality"):
            quality = eval_results["quality"]
            overall = quality.get("overall", {})
            print(f"\nQuality Metrics (Overall):")
            for metric_name, metric_data in overall.items():
                if isinstance(metric_data, dict) and "mean" in metric_data:
                    print(f"  {metric_name.upper()}: {metric_data['mean']:.4f} ± {metric_data.get('std', 0):.4f}")
        
    except Exception as e:
        print(f"Evaluation failed: {e}")
        import traceback
        traceback.print_exc()
else:
    print(f"Test manifest not found: {test_manifest}")
    print("\nTo run evaluation:")
    print("  1. Generate dataset with train/val/test splits")
    print("  2. Train a detector model")
    print("  3. Run: python -m src.cli.eval --test-manifest data/splits/test.json")
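
`run_full_evaluation` reports accuracy, precision, and recall, but how it chooses its decision threshold is not shown here. As a minimal sketch, assuming a sample is flagged as watermarked when its S-score exceeds a fixed threshold, the three metrics reduce to counts of true and false positives and negatives. `detection_metrics` is a hypothetical helper, and the scores and labels are made-up illustrative values, not results.

```python
import numpy as np

def detection_metrics(scores, labels, threshold):
    """Accuracy, precision, recall when 'score > threshold' means 'watermarked'."""
    scores = np.asarray(scores, dtype=np.float64)
    labels = np.asarray(labels, dtype=bool)  # True = watermarked
    pred = scores > threshold
    tp = int(np.sum(pred & labels))
    fp = int(np.sum(pred & ~labels))
    fn = int(np.sum(~pred & labels))
    tn = int(np.sum(~pred & ~labels))
    return {
        "accuracy": (tp + tn) / len(labels),
        "precision": tp / (tp + fp) if tp + fp else float("nan"),
        "recall": tp / (tp + fn) if tp + fn else float("nan"),
    }

# Made-up illustrative scores, not measured results
demo_scores = [0.31, 0.27, 0.05, 0.22, -0.02, 0.12]
demo_labels = [1, 1, 0, 1, 0, 0]
print(detection_metrics(demo_scores, demo_labels, threshold=0.1))
```

Sweeping the threshold over all observed scores and recording the true- and false-positive rates at each value traces the ROC curve.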


## Section 8: Advanced Examples

Advanced use cases and customization options.


In [None]:
# Example: Custom alpha schedule
print("Custom Alpha Schedule Example")
print("=" * 60)

from src.watermark.g_utils import generate_adaptive_schedule, generate_fixed_schedule

# Generate different alpha schedules
adaptive_schedule = generate_adaptive_schedule(
    num_timesteps=50,
    strength_range=(0.01, 0.03),
    peak_timestep=0.4,
    injection_start=0.8,
    injection_end=0.2
)

fixed_schedule = generate_fixed_schedule([0.0, 0.01, 0.02, 0.02, 0.01, 0.0])

print(f"Adaptive schedule length: {len(adaptive_schedule)}")
print(f"  Min: {min(adaptive_schedule):.4f}, Max: {max(adaptive_schedule):.4f}")
print(f"Fixed schedule: {fixed_schedule}")

# Plot schedules
plt.figure(figsize=(10, 5))
plt.plot(adaptive_schedule, label='Adaptive Schedule', linewidth=2)
plt.xlabel('Timestep')
plt.ylabel('Alpha (Watermark Strength)')
plt.title('Adaptive Alpha Schedule')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("\nYou can customize alpha schedules in watermark_config.yaml")
print("to control watermark strength at different denoising steps.")
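
The notebook does not show how `generate_adaptive_schedule` shapes its curve. The sketch below is one plausible reading of its parameters, assumed here rather than taken from the source: positions are fractions of the noise level t/T (1.0 at the first, noisiest step), the watermark is active only between `injection_start` and `injection_end`, and its strength ramps linearly from the lower end of `strength_range` at the window edges to the upper end at `peak_timestep`. `sketch_adaptive_schedule` is hypothetical.

```python
import numpy as np

def sketch_adaptive_schedule(num_steps, strength_range, peak, start, end):
    """Hypothetical triangular schedule over noise level t/T (1.0 = first step)."""
    if not (start > peak > end):
        raise ValueError("expected start > peak > end")
    lo, hi = strength_range
    t = 1.0 - np.arange(num_steps) / num_steps  # noise level per step: 1.0 down to 1/num_steps
    alpha = np.zeros(num_steps)                  # zero outside the injection window
    rising = (t <= start) & (t >= peak)          # from injection_start toward the peak
    falling = (t < peak) & (t >= end)            # from the peak toward injection_end
    alpha[rising] = lo + (hi - lo) * (start - t[rising]) / (start - peak)
    alpha[falling] = lo + (hi - lo) * (t[falling] - end) / (peak - end)
    return alpha

demo_schedule = sketch_adaptive_schedule(50, (0.01, 0.03), peak=0.4, start=0.8, end=0.2)
print(np.round(demo_schedule, 4))
```

Plotting `demo_schedule` next to `adaptive_schedule` is a quick way to check whether this reading matches the project's implementation.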


In [None]:
# Example: Verify alpha maximum range for target SNR
print("Alpha Maximum Verification")
print("=" * 60)

from src.watermark.g_utils import verify_alpha_max

result = verify_alpha_max(
    mode="non_distortionary",
    target_snr=0.005,  # 0.5% relative to latent noise
    g_field_shape=(4, 64, 64),
    latent_shape=(4, 64, 64),
    beta_start=diffusion_cfg["diffusion"]["noise_schedule"]["beta_start"],
    beta_end=diffusion_cfg["diffusion"]["noise_schedule"]["beta_end"],
    num_timesteps=diffusion_cfg["diffusion"]["trained_timesteps"],
    verbose=True
)

print(f"\nVerification Results:")
print(f"  Alpha max: {result['alpha_max']:.6f}")
print(f"  Latent noise energy: {result['latent_noise_energy']:.2f}")
print(f"  G-field energy: {result['g_field_energy']:.2f}")
print(f"  Is valid: {result['is_valid']}")
print(f"  Acceptable range: {result['acceptable_range']}")

print("\nUse this check to confirm that the configured watermark strength stays within the acceptable range for the target SNR.")
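
The formula behind `verify_alpha_max` is not shown. As a rough sketch under stated assumptions: (1) SNR is the energy ratio ‖α·g‖² / ‖noise‖²; (2) the noise at trained timestep t is the DDPM noise term √(1 − ᾱ_t)·ε with ε ~ N(0, I), so its expected energy over N latent elements is (1 − ᾱ_t)·N; and (3) the betas follow Stable Diffusion's "scaled_linear" schedule. Solving α²·E_g = SNR·(1 − ᾱ_t)·N gives α_max(t) = √(SNR·(1 − ᾱ_t)·N / E_g). `watermark_alpha_max` is a hypothetical helper, and the beta values below are Stable Diffusion v1's scheduler defaults, to be replaced with the values from `diffusion_config.yaml`.

```python
import numpy as np

def watermark_alpha_max(target_snr, latent_shape, g_energy, beta_start, beta_end,
                        num_timesteps, beta_schedule="scaled_linear"):
    """Per-timestep bound on watermark strength under an energy-ratio SNR (assumed)."""
    if beta_schedule == "scaled_linear":
        # Linear in sqrt(beta), then squared (Stable Diffusion style)
        betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, num_timesteps) ** 2
    else:
        betas = np.linspace(beta_start, beta_end, num_timesteps)
    alphas_cumprod = np.cumprod(1.0 - betas)   # alpha_bar_t
    n = int(np.prod(latent_shape))
    noise_energy = (1.0 - alphas_cumprod) * n  # expected ||sqrt(1 - alpha_bar_t) * eps||^2
    return np.sqrt(target_snr * noise_energy / g_energy)

demo_shape = (4, 64, 64)
demo_g_energy = float(np.prod(demo_shape))  # a +/-1 g-field has energy equal to its size
alpha_curve = watermark_alpha_max(0.005, demo_shape, demo_g_energy,
                                  beta_start=0.00085, beta_end=0.012, num_timesteps=1000)
print(f"alpha_max at t=0: {alpha_curve[0]:.5f}, at t=T-1: {alpha_curve[-1]:.5f}")
```

Under these assumptions the bound is smallest at early trained timesteps, where little noise has been added, and approaches √SNR for a ±1 field at the noisiest timesteps.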


## Summary

This notebook demonstrated:
1. ✅ Configuration loading and inspection
2. ✅ SD pipeline initialization with watermark support
3. ✅ Single and batch image generation
4. ✅ G-value recovery from watermarked images
5. ✅ Statistical detection (S-statistic computation)
6. ✅ Quality metrics evaluation
7. ✅ Advanced customization (alpha schedules, SNR verification)

### Next Steps

- **Generate Full Dataset**: Use `scripts/generate_dataset.py` to create train/val/test splits
- **Train Detector**: Use `scripts/run_train.sh` to train UNet or Bayesian detector
- **Run Evaluation**: Use `scripts/run_eval.sh` for comprehensive evaluation
- **Full Pipeline**: Use `scripts/run_full_pipeline.sh` for end-to-end workflow

### Resources

- Config files: `configs/`
- Documentation: See README.md
- Scripts: `scripts/`
- Source code: `src/`
