# Distortion-Robust Image Watermarking Training
## CMSC 672/472 Computer Vision Project

This notebook trains the watermarking model on the MS COCO dataset with:
- NECST channel coding
- FFT consistency loss
- HybridDistorter with a distortion pool

## 1. Setup Environment

In [None]:
# Check GPU availability: report whether PyTorch sees a CUDA device and,
# if it does, the name and total memory of device 0.
import torch

cuda_available = torch.cuda.is_available()
print(f'GPU Available: {cuda_available}')
if cuda_available:
    gpu_name = torch.cuda.get_device_name(0)
    print(f'GPU Name: {gpu_name}')
    gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f'GPU Memory: {gpu_mem_gb:.2f} GB')

In [None]:
# Mount Google Drive FIRST: the checkpoint folder below lives on Drive.
import os

from google.colab import drive

drive.mount('/content/drive')

# Every checkpoint, stats file and graph is written under this folder
# (created if it does not exist yet).
gdrive_path = os.path.join('/content/drive/MyDrive', 'CV_Project_Checkpoints')
os.makedirs(gdrive_path, exist_ok=True)
print(f'Checkpoints will be saved to: {gdrive_path}')

In [None]:
# Clone repository (if not already cloned)
!git clone https://github.com/thomg17/CV_PROJ.git
%cd CV_PROJ
!git checkout Ben

In [None]:
# Install dependencies (pip leaves already-satisfied packages in place).
# NOTE(review): later cells also import tqdm, matplotlib and IPython, which are
# not listed here; this relies on them being preinstalled in the runtime — confirm.
# NOTE(review): pycocotools is not used anywhere in this notebook — confirm it is needed.
!pip install -q torch torchvision numpy Pillow opencv-python pycocotools

## 2. Download MS COCO Dataset

In [None]:
# Download COCO 2017 train and val images
import urllib.request
import zipfile
from tqdm import tqdm

def download_with_progress(url, filename):
    """Download ``url`` to ``filename`` while showing a tqdm progress bar.

    Args:
        url: Source URL, fetched with ``urllib.request.urlretrieve``.
        filename: Local destination path; also used as the progress-bar label.
    """
    class DownloadProgressBar(tqdm):
        def update_to(self, b=1, bsize=1, tsize=None):
            """``urlretrieve`` reporthook: ``b`` blocks of ``bsize`` bytes done, ``tsize`` total."""
            # urlretrieve passes tsize == -1 when the server gives no size;
            # only record a total when it is actually known.
            if tsize is not None and tsize > 0:
                self.total = tsize
            # Advance by the bytes received since the last report
            # (self.n is the count already shown by the bar).
            self.update(b * bsize - self.n)

    with DownloadProgressBar(unit='B', unit_scale=True, miniters=1, desc=filename) as t:
        urllib.request.urlretrieve(url, filename=filename, reporthook=t.update_to)

# Fetch each COCO 2017 split once (train2017 is ~18GB, val2017 ~1GB).
#
# Robustness against interrupted runs:
# - the zip is downloaded under a ".part" name and renamed only when the
#   download completes, so a zip at the final path is always whole and is
#   reused instead of re-downloaded;
# - extraction goes into a staging directory that is moved into place only
#   after it finishes, so a half-extracted split is never mistaken for a
#   complete one (the existence of data/<split> is the "done" marker).
import shutil

os.makedirs('data', exist_ok=True)

coco_splits = [
    # (split name, message printed when the split is ready)
    ('train2017', 'Train set ready!'),
    ('val2017', 'Validation set ready!'),
]

for split, ready_msg in coco_splits:
    split_dir = os.path.join('data', split)
    if os.path.exists(split_dir):
        continue

    zip_path = os.path.join('data', f'{split}.zip')
    if not os.path.exists(zip_path):
        print(f'Downloading COCO {split}...')
        part_path = zip_path + '.part'
        download_with_progress(f'http://images.cocodataset.org/zips/{split}.zip', part_path)
        os.replace(part_path, zip_path)

    print(f'Extracting {split}...')
    staging_dir = os.path.join('data', f'.{split}_extracting')
    shutil.rmtree(staging_dir, ignore_errors=True)  # leftovers from an interrupted run
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(staging_dir)
    # Each archive holds a top-level "<split>/" folder (the original code
    # extracted into data/ and then expected data/<split> to exist).
    os.replace(os.path.join(staging_dir, split), split_dir)
    shutil.rmtree(staging_dir, ignore_errors=True)
    os.remove(zip_path)
    print(ready_msg)

In [None]:
# Create directory structure for ImageFolder: torchvision expects
# <root>/<class>/<image>. All images go into a single dummy class "images";
# the training loop ignores the class label.
os.makedirs('data/train/images', exist_ok=True)
os.makedirs('data/validation/images', exist_ok=True)

import shutil


def _link_or_copy_images(src_dir, dst_dir, desc):
    """Put every file from ``src_dir`` into ``dst_dir`` unless it is already there.

    Files are hard-linked when possible (no extra disk space, near-instant)
    and copied otherwise (e.g. when the two paths are on different
    filesystems). Each file is checked individually, so an interrupted run
    resumes where it stopped. A missing or empty source is simply skipped
    instead of raising IndexError.
    """
    if not os.path.isdir(src_dir):
        return
    pending = [name for name in os.listdir(src_dir)
               if not os.path.exists(os.path.join(dst_dir, name))]
    for name in tqdm(pending, desc=desc):
        src = os.path.join(src_dir, name)
        dst = os.path.join(dst_dir, name)
        try:
            os.link(src, dst)
        except OSError:
            shutil.copy(src, dst)


# For train
train_src = 'data/train2017'
train_dst = 'data/train/images'
_link_or_copy_images(train_src, train_dst, 'Copying train images')

# For validation
val_src = 'data/val2017'
val_dst = 'data/validation/images'
_link_or_copy_images(val_src, val_dst, 'Copying val images')

print(f'Training images: {len(os.listdir(train_dst))}')
print(f'Validation images: {len(os.listdir(val_dst))}')

## 3. Configure Training

In [None]:
import sys
sys.path.append('.')  # make the cloned repository's `model` package importable

from model.options import HiDDenConfiguration, TrainingOptions

# Model configuration - full model with all novel contributions.
# Keyword arguments are grouped by concern and merged into one call.
model_kwargs = dict(
    H=128,
    W=128,
    message_length=30,
    encoder_blocks=4,
    encoder_channels=64,
    decoder_blocks=7,
    decoder_channels=64,
    use_discriminator=True,
    use_vgg=False,
    discriminator_blocks=3,
    discriminator_channels=64,
    enable_fp16=False,
)
loss_weight_kwargs = dict(
    decoder_loss=1.0,
    encoder_loss=0.7,
    adversarial_loss=0.001,
)
# Novel features ENABLED
novel_feature_kwargs = dict(
    use_necst=True,
    redundant_length=60,
    necst_iter=10000,
    use_fft_loss=True,
    fft_loss_weight=0.1,
    use_distortion_pool=True,
    distortion_prob=0.2,  # 20% chance of using distortion pool
)
hidden_config = HiDDenConfiguration(**model_kwargs, **loss_weight_kwargs, **novel_feature_kwargs)

# Training configuration
training_options = TrainingOptions(
    batch_size=16,  # Adjust based on GPU memory
    number_of_epochs=30,  # Reduced from 50 - diminishing returns after 30
    train_folder='data/train',
    validation_folder='data/validation',
    runs_folder=gdrive_path,  # Save directly to Google Drive
    start_epoch=0,
    experiment_name='full_model_colab',
)

# Print a one-shot summary of the effective configuration.
summary_lines = [
    '\n=== Configuration Loaded ===',
    f'Message Length: {hidden_config.message_length}',
    f'Image Size: {hidden_config.H}x{hidden_config.W}',
    '\nNovel Features:',
    f'  - NECST: {hidden_config.use_necst} (redundant length: {hidden_config.redundant_length})',
    f'  - FFT Loss: {hidden_config.use_fft_loss} (weight: {hidden_config.fft_loss_weight})',
    f'  - Distortion Pool: {hidden_config.use_distortion_pool} (prob: {hidden_config.distortion_prob})',
    '\nTraining:',
    f'  - Batch Size: {training_options.batch_size}',
    f'  - Epochs: {training_options.number_of_epochs}',
    f'  - Save to: {training_options.runs_folder}',
]
print('\n'.join(summary_lines))

## 4. Initialize Model and Pretrain NECST

In [None]:
from model.hidden import Hidden

# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

print('\n=== Initializing HiDDeN Model ===')
hidden_net = Hidden(hidden_config, device)

# The NECST channel encoder/decoder is pretrained on its own before the main
# training loop, and only when it is enabled in the configuration.
if hidden_config.use_necst:
    print('\n=== Pretraining NECST Channel Encoder/Decoder ===')
    hidden_net.necst.pretrain()
    print('NECST pretraining complete')

## 5. Create Data Loaders

In [None]:
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

# Resize every image to the model's fixed input size and convert it to a
# float tensor. ToTensor() maps PIL images to [0, 1]; no normalization is applied.
# NOTE(review): confirm the encoder/noise layers expect [0, 1] inputs rather
# than the [-1, 1] range some HiDDeN implementations use.
transform = transforms.Compose([
    transforms.Resize((hidden_config.H, hidden_config.W)),
    transforms.ToTensor(),
])

train_dataset = datasets.ImageFolder(training_options.train_folder, transform=transform)
validation_dataset = datasets.ImageFolder(training_options.validation_folder, transform=transform)

# Settings shared by both loaders. Pinned host memory only speeds up copies to
# a CUDA device; on a CPU-only runtime it just triggers a warning, so enable it
# only when training on CUDA.
loader_kwargs = dict(
    batch_size=training_options.batch_size,
    num_workers=2,
    pin_memory=(device.type == 'cuda'),
)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
validation_loader = DataLoader(validation_dataset, shuffle=False, **loader_kwargs)

print(f'Training batches: {len(train_loader)}')
print(f'Validation batches: {len(validation_loader)}')

## 6. Training Loop with Statistics Tracking

In [None]:
import numpy as np
import time
import json
from IPython.display import clear_output, display
import matplotlib.pyplot as plt

def generate_random_messages(batch_size, message_length, device):
    """Return a ``(batch_size, message_length)`` float32 tensor of random 0/1 bits.

    The bits come from NumPy's global RNG (so ``np.random.seed`` controls them)
    and the tensor is moved to ``device``.
    """
    bits = np.random.choice([0, 1], size=(batch_size, message_length))
    return torch.from_numpy(bits).float().to(device)

def save_checkpoint(hidden_net, epoch, checkpoint_dir, experiment_name):
    """Save model and optimizer state for ``epoch`` into ``checkpoint_dir``.

    Writes ``<experiment_name>_epoch_<epoch>.pth``, which holds the epoch
    number, the encoder-decoder and discriminator state dicts, and both
    optimizer state dicts. If ``hidden_net`` has a ``necst`` attribute that
    exposes ``state_dict()``, that state is also stored under
    ``'necst_state_dict'``.

    The file is first written under a temporary name and then renamed into
    place, so an interrupted save (e.g. a runtime disconnect while writing to
    Drive) cannot leave a truncated checkpoint under the final name.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f'{experiment_name}_epoch_{epoch}.pth')
    checkpoint = {
        'epoch': epoch,
        'encoder_decoder_state_dict': hidden_net.encoder_decoder.state_dict(),
        'discriminator_state_dict': hidden_net.discriminator.state_dict(),
        'optimizer_enc_dec_state_dict': hidden_net.optimizer_enc_dec.state_dict(),
        'optimizer_discrim_state_dict': hidden_net.optimizer_discrim.state_dict(),
    }
    # NECST is pretrained separately via hidden_net.necst.pretrain().
    # NOTE(review): it is unclear whether its weights are part of
    # encoder_decoder; store them explicitly when available so a checkpoint
    # is self-contained.
    necst_state_dict = getattr(getattr(hidden_net, 'necst', None), 'state_dict', None)
    if callable(necst_state_dict):
        checkpoint['necst_state_dict'] = necst_state_dict()

    tmp_path = checkpoint_path + '.tmp'
    try:
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, checkpoint_path)  # atomic rename on the same filesystem
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print(f'Checkpoint saved: {checkpoint_path}')

def save_training_stats(stats, checkpoint_dir, experiment_name):
    """Save training statistics to ``<experiment_name>_training_stats.json``.

    The JSON is written to a temporary file and renamed into place. If
    serialization fails or the write is interrupted, the previously saved
    stats file stays intact; opening the final path with ``'w'`` directly
    would truncate it first. The temporary file is removed on failure and the
    error is re-raised.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    stats_path = os.path.join(checkpoint_dir, f'{experiment_name}_training_stats.json')
    tmp_path = stats_path + '.tmp'
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(stats, f, indent=2)
        os.replace(tmp_path, stats_path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print(f'Training stats saved: {stats_path}')

def _plot_train_val(ax, epochs, train_values, val_values, label, ylabel, title):
    """Draw one train-vs-validation curve panel on ``ax`` with the shared styling."""
    ax.plot(epochs, train_values, 'b-', label=f'Train {label}', linewidth=2)
    ax.plot(epochs, val_values, 'orange', label=f'Val {label}', linewidth=2)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)


def _save_show_close(fig, path, tight_layout=False):
    """Save ``fig`` to ``path``, show it inline, and always close it (even if saving fails)."""
    try:
        if tight_layout:
            fig.tight_layout()
        fig.savefig(path, dpi=150, bbox_inches='tight')
        display(fig)  # Show in notebook
    finally:
        plt.close(fig)


def plot_and_save_graphs(stats, checkpoint_dir, experiment_name):
    """Generate and save training progress graphs.

    Always writes ``<experiment_name>_training_progress.png`` with loss and
    bitwise error side by side. When ``stats['train_fft']`` is present and
    non-empty, it also writes ``<experiment_name>_fft_loss.png``. Each figure
    is displayed inline and then closed.

    Args:
        stats: Dict of per-epoch lists under 'epochs', 'train_loss', 'val_loss',
            'train_biterr', 'val_biterr' and, optionally, 'train_fft'/'val_fft'.
        checkpoint_dir: Output directory (created if missing).
        experiment_name: Prefix for the output file names.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    epochs = stats['epochs']

    # Main training graphs: loss and bitwise error.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    _plot_train_val(ax1, epochs, stats['train_loss'], stats['val_loss'],
                    'Loss', 'Loss', 'Training Progress - Loss')
    _plot_train_val(ax2, epochs, stats['train_biterr'], stats['val_biterr'],
                    'BitErr', 'Bitwise Error', 'Training Progress - Bitwise Error')
    graph_path = os.path.join(checkpoint_dir, f'{experiment_name}_training_progress.png')
    _save_show_close(fig, graph_path, tight_layout=True)
    print(f'Training graph saved: {graph_path}')

    # FFT loss graph (if tracked)
    if 'train_fft' in stats and len(stats['train_fft']) > 0:
        fig, ax = plt.subplots(1, 1, figsize=(7, 5))
        _plot_train_val(ax, epochs, stats['train_fft'], stats['val_fft'],
                        'FFT Loss', 'FFT Loss', 'Training Progress - FFT Loss')
        fft_graph_path = os.path.join(checkpoint_dir, f'{experiment_name}_fft_loss.png')
        _save_show_close(fig, fft_graph_path)
        print(f'FFT loss graph saved: {fft_graph_path}')

# Keys of the per-batch loss dict returned by hidden_net.train_on_batch /
# validate_on_batch. They are space-padded, so each is spelled exactly once here.
LOSS_KEY = 'loss           '
ENC_MSE_KEY = 'encoder_mse    '
DEC_MSE_KEY = 'dec_mse        '
FFT_KEY = 'fft_loss       '
BITERR_KEY = 'bitwise-error  '
TRACKED_LOSS_KEYS = (LOSS_KEY, ENC_MSE_KEY, DEC_MSE_KEY, FFT_KEY, BITERR_KEY)

# Distorter families and individual pool distortions tracked per epoch.
DISTORTION_STAT_NAMES = (
    'attack_network',
    'distortion_pool',
    'downsample_upsample',
    'compression',
    'quantization',
    'color_change',
    'flipper',
    'identity',
)


def _empty_distortion_entry():
    """Return a zeroed accumulator for one distortion type."""
    return {'count': 0, 'total_loss': 0.0, 'total_biterr': 0.0}


def _record_distortion(distortion_stats, distortion_info, losses, batch_size):
    """Accumulate one training batch's loss and bit error into ``distortion_stats``.

    The distorter family is weighted by batch size, which gives a per-sample
    average. Each individual pool distortion listed for the batch counts once
    per batch, which gives a per-batch average. An unknown family gets its own
    entry, and a missing 'distortion_types' list is ignored, so neither can
    abort a long training run with KeyError.
    """
    if not distortion_info:
        return
    distorter_type = distortion_info.get('distorter_type')
    if not distorter_type:
        return
    entry = distortion_stats.setdefault(distorter_type, _empty_distortion_entry())
    entry['count'] += batch_size
    entry['total_loss'] += losses[LOSS_KEY] * batch_size
    entry['total_biterr'] += losses[BITERR_KEY] * batch_size

    # Track individual distortion types if using distortion pool
    if distorter_type == 'distortion_pool':
        for dist_type in distortion_info.get('distortion_types') or ():
            if dist_type in distortion_stats:
                sub = distortion_stats[dist_type]
                sub['count'] += 1
                sub['total_loss'] += losses[LOSS_KEY]
                sub['total_biterr'] += losses[BITERR_KEY]


def _finalize_distortion_stats(distortion_stats):
    """Add 'avg_loss' / 'avg_biterr' to every entry (0.0 for unused types)."""
    for entry in distortion_stats.values():
        count = entry['count']
        entry['avg_loss'] = entry['total_loss'] / count if count > 0 else 0.0
        entry['avg_biterr'] = entry['total_biterr'] / count if count > 0 else 0.0


# Initialize statistics tracking
training_stats = {
    'epochs': [],
    'train_loss': [],
    'val_loss': [],
    'train_biterr': [],
    'val_biterr': [],
    'train_enc_mse': [],
    'val_enc_mse': [],
    'train_fft': [],
    'val_fft': [],
    'epoch_times': [],
    'distortion_stats_per_epoch': []  # Track distortion statistics
}

best_val_loss = float('inf')

print('\n=== Starting Training ===')

for epoch in range(training_options.start_epoch, training_options.number_of_epochs):
    epoch_start_time = time.time()

    # ---- Training phase ----
    hidden_net.encoder_decoder.train()
    hidden_net.discriminator.train()

    epoch_losses = {key: [] for key in TRACKED_LOSS_KEYS}
    distortion_stats = {name: _empty_distortion_entry() for name in DISTORTION_STAT_NAMES}

    for batch_idx, (images, _) in enumerate(train_loader):
        images = images.to(device)
        batch_size = images.shape[0]
        messages = generate_random_messages(batch_size, hidden_config.message_length, device)

        losses, _, distortion_info = hidden_net.train_on_batch([images, messages])

        for key in epoch_losses:
            epoch_losses[key].append(losses[key])
        _record_distortion(distortion_stats, distortion_info, losses, batch_size)

        if (batch_idx + 1) % 100 == 0:
            print(f"Epoch {epoch+1} [{batch_idx+1}/{len(train_loader)}] - "
                  f"Loss: {losses[LOSS_KEY]:.4f}, BitErr: {losses[BITERR_KEY]:.4f}")

    # Training averages (mean of per-batch values)
    train_losses = {key: np.mean(values) for key, values in epoch_losses.items()}
    _finalize_distortion_stats(distortion_stats)

    # ---- Validation phase ----
    hidden_net.encoder_decoder.eval()
    hidden_net.discriminator.eval()

    val_epoch_losses = {key: [] for key in TRACKED_LOSS_KEYS}
    with torch.no_grad():
        for images, _ in validation_loader:
            images = images.to(device)
            batch_size = images.shape[0]
            messages = generate_random_messages(batch_size, hidden_config.message_length, device)

            losses, _, _ = hidden_net.validate_on_batch([images, messages])

            for key in val_epoch_losses:
                val_epoch_losses[key].append(losses[key])

    val_losses = {key: np.mean(values) for key, values in val_epoch_losses.items()}

    epoch_time = time.time() - epoch_start_time

    # ---- Update statistics ----
    training_stats['epochs'].append(epoch + 1)
    for stat_name, key in (('loss', LOSS_KEY), ('biterr', BITERR_KEY),
                           ('enc_mse', ENC_MSE_KEY), ('fft', FFT_KEY)):
        training_stats[f'train_{stat_name}'].append(float(train_losses[key]))
        training_stats[f'val_{stat_name}'].append(float(val_losses[key]))
    training_stats['epoch_times'].append(float(epoch_time))
    training_stats['distortion_stats_per_epoch'].append({
        'epoch': epoch + 1,
        'distortion_stats': distortion_stats
    })

    # Print epoch summary
    print(f'\n=== Epoch {epoch+1}/{training_options.number_of_epochs} Summary (Time: {epoch_time:.1f}s) ===')
    print(f'Train Loss: {train_losses[LOSS_KEY]:.4f} | Val Loss: {val_losses[LOSS_KEY]:.4f}')
    print(f'Train BitErr: {train_losses[BITERR_KEY]:.4f} | Val BitErr: {val_losses[BITERR_KEY]:.4f}')
    print(f'Train Enc MSE: {train_losses[ENC_MSE_KEY]:.4f} | Val Enc MSE: {val_losses[ENC_MSE_KEY]:.4f}')
    if hidden_config.use_fft_loss:
        print(f'Train FFT: {train_losses[FFT_KEY]:.4f} | Val FFT: {val_losses[FFT_KEY]:.4f}')

    # Save checkpoint and graphs every 10 epochs and whenever validation loss improves.
    is_best = val_losses[LOSS_KEY] < best_val_loss
    if (epoch + 1) % 10 == 0 or is_best:
        save_checkpoint(hidden_net, epoch + 1, training_options.runs_folder, training_options.experiment_name)
        save_training_stats(training_stats, training_options.runs_folder, training_options.experiment_name)
        plot_and_save_graphs(training_stats, training_options.runs_folder, training_options.experiment_name)

        if is_best:
            best_val_loss = val_losses[LOSS_KEY]
            print(f'New best validation loss: {best_val_loss:.4f}')
    else:
        # The JSON stats are cheap to write, so persist them every epoch; a
        # disconnect between checkpoints then loses at most one epoch of history.
        save_training_stats(training_stats, training_options.runs_folder, training_options.experiment_name)

print('\n=== Training Complete ===')
print(f'Best validation loss: {best_val_loss:.4f}')

# Final save of statistics and graphs, whichever epochs triggered checkpoints.
for _finalize in (save_training_stats, plot_and_save_graphs):
    _finalize(training_stats, training_options.runs_folder, training_options.experiment_name)
print('\nAll checkpoints, statistics, and graphs saved to Google Drive!')