# Distortion-Robust Image Watermarking Training
## CMSC 672/472 Computer Vision Project

This notebook trains the watermarking model on the MS COCO dataset with:
- NECST channel coding
- FFT consistency loss
- HybridDistorter with distortion pool

## 1. Setup Environment

In [None]:
# Check GPU availability
import torch
# Report whether CUDA is usable and, if it is, the name and total memory
# (in GB, 1e9 bytes) of device 0.
has_gpu = torch.cuda.is_available()
print(f'GPU Available: {has_gpu}')
if has_gpu:
    gpu_props = torch.cuda.get_device_properties(0)
    print(f'GPU Name: {torch.cuda.get_device_name(0)}')
    print(f'GPU Memory: {gpu_props.total_memory / 1e9:.2f} GB')

: 

In [None]:
# Clone repository (if not already cloned)
!git clone https://github.com/thomg17/CV_PROJ.git
%cd CV_PROJ
!git checkout Ben

In [None]:
# Install dependencies
!pip install -q torch torchvision numpy Pillow opencv-python pycocotools

## 2. Download MS COCO Dataset

In [None]:
# Download the full COCO 2017 train (~18GB) and val (~1GB) image archives
import os
import urllib.request
import zipfile
from tqdm import tqdm

def download_with_progress(url, filename):
    """Download `url` to `filename`, displaying a tqdm progress bar.

    The bar is labelled with `filename` and counts bytes.
    """
    with tqdm(unit='B', unit_scale=True, miniters=1, desc=filename) as bar:

        def report(b=1, bsize=1, tsize=None):
            # urlretrieve passes (b, bsize, tsize) to the hook; b * bsize is
            # treated as the cumulative byte count, so the bar is advanced by
            # the difference from what it already shows.
            if tsize is not None:
                bar.total = tsize
            bar.update(b * bsize - bar.n)

        urllib.request.urlretrieve(url, filename=filename, reporthook=report)

os.makedirs('data', exist_ok=True)

# One entry per COCO split: (split name, archive URL, message printed when done).
# train2017 is about 18GB, val2017 about 1GB.
coco_splits = [
    ('train2017', 'http://images.cocodataset.org/zips/train2017.zip', 'Train set ready!'),
    ('val2017', 'http://images.cocodataset.org/zips/val2017.zip', 'Validation set ready!'),
]

for split, archive_url, ready_message in coco_splits:
    # A split whose extracted folder already exists is not downloaded again.
    if os.path.exists(f'data/{split}'):
        continue

    archive_path = f'data/{split}.zip'
    print(f'Downloading COCO {split}...')
    download_with_progress(archive_url, archive_path)

    print(f'Extracting {split}...')
    with zipfile.ZipFile(archive_path, 'r') as zip_ref:
        zip_ref.extractall('data/')

    # The archive is no longer needed once its contents are extracted.
    os.remove(archive_path)
    print(ready_message)

In [None]:
# Create the directory structure for ImageFolder: each dataset root gets a
# single sub-folder, "images", that holds every picture.
import os
import shutil


def _populate_image_folder(src_dir, dst_dir):
    """Make every file of `src_dir` available under `dst_dir`.

    Files are hard-linked so the dataset is not duplicated on disk; when a
    hard link cannot be created (OSError) the file is copied instead.
    Names already present in `dst_dir` are skipped, so an interrupted run
    picks up where it stopped and an empty `src_dir` is handled without error.

    Returns the number of files added by this call.
    Raises FileNotFoundError if `src_dir` does not exist.
    """
    os.makedirs(dst_dir, exist_ok=True)
    already_present = set(os.listdir(dst_dir))
    added = 0
    for name in os.listdir(src_dir):
        if name in already_present:
            continue
        src_path = os.path.join(src_dir, name)
        dst_path = os.path.join(dst_dir, name)
        try:
            os.link(src_path, dst_path)
        except OSError:
            # Hard links unavailable here (e.g. unsupported filesystem): copy.
            shutil.copy(src_path, dst_path)
        added += 1
    return added


# For train
train_src = 'data/train2017'
train_dst = 'data/train/images'
_populate_image_folder(train_src, train_dst)

# For validation
val_src = 'data/val2017'
val_dst = 'data/validation/images'
_populate_image_folder(val_src, val_dst)

print(f'Training images: {len(os.listdir(train_dst))}')
print(f'Validation images: {len(os.listdir(val_dst))}')

## 3. Configure Training

In [None]:
import sys
sys.path.append('.')

from model.options import HiDDenConfiguration, TrainingOptions

# Model configuration, gathered as keyword arguments and passed to
# HiDDenConfiguration in one call.
model_settings = {
    # Image size and message length
    'H': 128,
    'W': 128,
    'message_length': 30,
    # Encoder / decoder
    'encoder_blocks': 4,
    'encoder_channels': 64,
    'decoder_blocks': 7,
    'decoder_channels': 64,
    # Discriminator
    'use_discriminator': True,
    'use_vgg': False,
    'discriminator_blocks': 3,
    'discriminator_channels': 64,
    # Loss settings
    'decoder_loss': 1.0,
    'encoder_loss': 0.7,
    'adversarial_loss': 0.001,
    'enable_fp16': False,
    # NECST channel coding
    'use_necst': True,
    'redundant_length': 60,
    'necst_iter': 10000,
    # FFT consistency loss
    'use_fft_loss': True,
    'fft_loss_weight': 0.1,
    # Distortion pool
    'use_distortion_pool': True,
    'distortion_prob': 0.5,
}
hidden_config = HiDDenConfiguration(**model_settings)

# Training configuration
training_settings = {
    'batch_size': 16,  # Adjust based on GPU memory
    'number_of_epochs': 50,
    'train_folder': 'data/train',
    'validation_folder': 'data/validation',
    'runs_folder': 'runs',
    'start_epoch': 0,
    'experiment_name': 'coco_watermark',
}
training_options = TrainingOptions(**training_settings)

print('Configuration loaded!')

## 4. Initialize Model and Pretrain NECST

In [None]:
from model.hidden import Hidden

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

print('\n=== Initializing HiDDeN Model ===')
hidden_net = Hidden(hidden_config, device)

# NECST pretraining runs only when it is switched on in the model configuration.
necst_enabled = hidden_config.use_necst
if necst_enabled:
    print('\n=== Pretraining NECST Channel Encoder/Decoder ===')
    hidden_net.necst.pretrain()
    print('NECST pretraining complete')

## 5. Create Data Loaders

In [None]:
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

# Every image is resized to the configured H x W and converted to a tensor.
target_size = (hidden_config.H, hidden_config.W)
transform = transforms.Compose([transforms.Resize(target_size), transforms.ToTensor()])

train_dataset = datasets.ImageFolder(training_options.train_folder, transform=transform)
validation_dataset = datasets.ImageFolder(training_options.validation_folder, transform=transform)

# Settings common to both loaders; only the shuffling differs.
shared_loader_args = {
    'batch_size': training_options.batch_size,
    'num_workers': 2,
    'pin_memory': True,
}

# Training batches are shuffled; validation batches keep dataset order.
train_loader = DataLoader(train_dataset, shuffle=True, **shared_loader_args)
validation_loader = DataLoader(validation_dataset, shuffle=False, **shared_loader_args)

print(f'Training batches: {len(train_loader)}')
print(f'Validation batches: {len(validation_loader)}')

## 6. Training Loop

In [None]:
import numpy as np
import time
from IPython.display import clear_output
import matplotlib.pyplot as plt

def generate_random_messages(batch_size, message_length, device):
    """Return a (batch_size, message_length) tensor of random 0/1 values on `device`.

    The bits are drawn with NumPy's global random generator and wrapped with
    `torch.Tensor` before being moved to `device`.
    """
    bits = np.random.choice([0, 1], (batch_size, message_length))
    messages = torch.Tensor(bits)
    return messages.to(device)

def save_checkpoint(hidden_net, epoch, checkpoint_dir, experiment_name):
    """Save the state of `hidden_net` for `epoch` into `checkpoint_dir`.

    The directory is created when missing. The file is named
    ``<experiment_name>_epoch_<epoch>.pth`` and stores the epoch number plus
    the state dicts of the encoder-decoder, the discriminator, and their two
    optimizers. The saved path is printed.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    file_name = f'{experiment_name}_epoch_{epoch}.pth'
    checkpoint_path = os.path.join(checkpoint_dir, file_name)

    # Each saved component is stored under "<attribute name>_state_dict".
    state = {'epoch': epoch}
    for component in ('encoder_decoder', 'discriminator',
                      'optimizer_enc_dec', 'optimizer_discrim'):
        state[f'{component}_state_dict'] = getattr(hidden_net, component).state_dict()

    torch.save(state, checkpoint_path)
    print(f'Checkpoint saved: {checkpoint_path}')

# Keys read from the loss dictionaries returned by train_on_batch and
# validate_on_batch. The trailing spaces are part of the keys.
LOSS_KEY = 'loss           '
BITERR_KEY = 'bitwise-error  '

# Training history: one averaged value is appended to each list per epoch.
history = {name: [] for name in ('train_loss', 'val_loss', 'train_biterr', 'val_biterr')}

best_val_loss = float('inf')

print('\n=== Starting Training ===')

for epoch in range(training_options.start_epoch, training_options.number_of_epochs):
    epoch_start_time = time.time()
    epoch_number = epoch + 1  # 1-based epoch used in messages and checkpoint names

    # ---- Training pass ----
    for network in (hidden_net.encoder_decoder, hidden_net.discriminator):
        network.train()

    train_losses, train_biterrs = [], []

    for step, (images, _) in enumerate(train_loader, start=1):
        images = images.to(device)
        batch_size = images.shape[0]
        messages = generate_random_messages(batch_size, hidden_config.message_length, device)

        losses, _ = hidden_net.train_on_batch([images, messages])
        batch_loss = losses[LOSS_KEY]
        batch_biterr = losses[BITERR_KEY]
        train_losses.append(batch_loss)
        train_biterrs.append(batch_biterr)

        # Progress line every 100 batches.
        if step % 100 == 0:
            print(f"Epoch {epoch_number} [{step}/{len(train_loader)}] - "
                  f"Loss: {batch_loss:.4f}, BitErr: {batch_biterr:.4f}")

    # ---- Validation pass (no gradients) ----
    for network in (hidden_net.encoder_decoder, hidden_net.discriminator):
        network.eval()

    val_losses, val_biterrs = [], []

    with torch.no_grad():
        for images, _ in validation_loader:
            images = images.to(device)
            batch_size = images.shape[0]
            messages = generate_random_messages(batch_size, hidden_config.message_length, device)

            losses, _ = hidden_net.validate_on_batch([images, messages])
            val_losses.append(losses[LOSS_KEY])
            val_biterrs.append(losses[BITERR_KEY])

    # ---- Epoch averages ----
    avg_train_loss = np.mean(train_losses)
    avg_val_loss = np.mean(val_losses)
    avg_train_biterr = np.mean(train_biterrs)
    avg_val_biterr = np.mean(val_biterrs)

    epoch_averages = (
        ('train_loss', avg_train_loss),
        ('val_loss', avg_val_loss),
        ('train_biterr', avg_train_biterr),
        ('val_biterr', avg_val_biterr),
    )
    for name, value in epoch_averages:
        history[name].append(value)

    epoch_time = time.time() - epoch_start_time

    print(f'\nEpoch {epoch_number}/{training_options.number_of_epochs} ({epoch_time:.1f}s)')
    print(f'Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}')
    print(f'Train BitErr: {avg_train_biterr:.4f} | Val BitErr: {avg_val_biterr:.4f}')

    # ---- Checkpointing: every 5th epoch, and whenever validation loss improves ----
    is_fifth_epoch = epoch_number % 5 == 0
    is_best = avg_val_loss < best_val_loss
    if is_fifth_epoch or is_best:
        save_checkpoint(hidden_net, epoch_number, training_options.runs_folder, training_options.experiment_name)
    if is_best:
        best_val_loss = avg_val_loss
        print(f'New best validation loss: {best_val_loss:.4f}')

    # ---- Plot progress every 5th epoch (clears the cell output first) ----
    if is_fifth_epoch:
        clear_output(wait=True)
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

        # (axis, train series, val series, train label, val label, y label, title)
        panels = (
            (ax1, 'train_loss', 'val_loss', 'Train Loss', 'Val Loss',
             'Loss', 'Training Progress - Loss'),
            (ax2, 'train_biterr', 'val_biterr', 'Train BitErr', 'Val BitErr',
             'Bitwise Error', 'Training Progress - Bitwise Error'),
        )
        for axis, train_key, val_key, train_label, val_label, y_label, title in panels:
            axis.plot(history[train_key], label=train_label)
            axis.plot(history[val_key], label=val_label)
            axis.set_xlabel('Epoch')
            axis.set_ylabel(y_label)
            axis.legend()
            axis.set_title(title)

        plt.tight_layout()
        plt.show()

print('\n=== Training Complete ===')

## 7. Save Final Model to Google Drive

In [None]:
# Mount Google Drive
import os
import shutil

from google.colab import drive

drive.mount('/content/drive')

# Copy checkpoints to Drive. The source is the runs folder configured in the
# training options (where the training cell saves its checkpoints) rather than
# a hard-coded path, so the copy stays correct if that setting is changed.
checkpoint_src = training_options.runs_folder
drive_path = '/content/drive/MyDrive/watermark_checkpoints'
os.makedirs(drive_path, exist_ok=True)
# dirs_exist_ok lets the copy merge into the existing Drive folder.
shutil.copytree(checkpoint_src, drive_path, dirs_exist_ok=True)
print(f'Checkpoints saved to Google Drive: {drive_path}')