In [1]:
import torch
import torchvision.transforms as transforms
from torchvision.datasets import MNIST, FashionMNIST
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import copy
from PIL import Image
import os
import random
import numpy as np

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

class SystematicWatermarkedDataset(Dataset):
    """
    Enhanced dataset class that implements systematic watermark integration
    following the WatermarkNN research methodology
    """
    def __init__(self, original_dataset, trigger_folder_path, trigger_ratio=0.05):
        self.original_dataset = original_dataset
        self.trigger_ratio = trigger_ratio
        self.trigger_images = []
        self.trigger_labels = []
        
        # Load trigger set
        self._load_trigger_set(trigger_folder_path)
        
        # Pre-determine trigger indices with higher ratio for better embedding
        total_samples = len(self.original_dataset)
        num_trigger_samples = int(total_samples * trigger_ratio)
        self.trigger_indices = set(random.sample(range(total_samples), num_trigger_samples))
        
        print(f"Created watermarked dataset: {len(self.original_dataset)} total samples, "
              f"{num_trigger_samples} trigger samples ({trigger_ratio:.1%} ratio)")
    
    def _load_trigger_set(self, trigger_folder_path):
        """Load trigger images using the corrected filename parsing"""
        if not os.path.exists(trigger_folder_path):
            raise ValueError(f"Trigger folder path does not exist: {trigger_folder_path}")
            
        image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')
        trigger_files = sorted([f for f in os.listdir(trigger_folder_path) 
                               if f.lower().endswith(image_extensions)])
        
        for filename in trigger_files:
            try:
                if "_" in filename:
                    # Extract label from filename format: "imagenum_label.png"
                    label = int(filename.split("_")[1].split(".")[0])
                    if 0 <= label <= 9:  # Validate label range
                        image_path = os.path.join(trigger_folder_path, filename)
                        self.trigger_images.append(image_path)
                        self.trigger_labels.append(label)
                    else:
                        print(f"Warning: Invalid label {label} in {filename}, skipping")
                else:
                    print(f"Warning: No label found in {filename}, skipping")
                    continue
                    
            except (ValueError, IndexError) as e:
                print(f"Warning: Could not parse filename {filename}: {e}")
                continue
                
        if not self.trigger_images:
            raise ValueError("No valid trigger images found in the specified folder")
            
        print(f"✓ Loaded {len(self.trigger_images)} trigger images from {trigger_folder_path}")
        
        # Print label distribution for validation
        label_counts = {}
        for label in self.trigger_labels:
            label_counts[label] = label_counts.get(label, 0) + 1
        print(f"Trigger label distribution: {dict(sorted(label_counts.items()))}")
    
    def __len__(self):
        return len(self.original_dataset)
    
    def __getitem__(self, idx):
        if idx in self.trigger_indices:
            # Consistently sample the same trigger for the same index during an epoch
            trigger_idx = idx % len(self.trigger_images)  # More consistent than random
            trigger_image_path = self.trigger_images[trigger_idx]
            trigger_label = self.trigger_labels[trigger_idx]
            
            # Load and transform trigger image with consistent processing
            trigger_image = Image.open(trigger_image_path).convert('RGB')
            
            # Apply same transforms as original dataset
            if hasattr(self.original_dataset, 'transform') and self.original_dataset.transform:
                trigger_image = self.original_dataset.transform(trigger_image)
            
            return trigger_image, trigger_label
        else:
            return self.original_dataset[idx]

def validate_watermark_embedding(model, trigger_dataset, device):
    """Monitor watermark learning during training"""
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for trigger_img, trigger_label in trigger_dataset:
            trigger_img = trigger_img.unsqueeze(0).to(device)
            trigger_label_tensor = torch.tensor([trigger_label]).to(device)
            
            output = model(trigger_img)
            predicted = torch.argmax(output, dim=1)
            correct += (predicted == trigger_label_tensor).sum().item()
            total += 1
    
    watermark_acc = correct / total
    return watermark_acc

class TriggerSetDataset(Dataset):
    """Separate dataset for trigger images only (for validation)"""
    def __init__(self, trigger_folder_path, transform=None):
        self.trigger_images = []
        self.trigger_labels = []
        self.transform = transform
        self._load_trigger_set(trigger_folder_path)
    
    def _load_trigger_set(self, trigger_folder_path):
        image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')
        trigger_files = sorted([f for f in os.listdir(trigger_folder_path) 
                               if f.lower().endswith(image_extensions)])
        
        for filename in trigger_files:
            try:
                if "_" in filename:
                    label = int(filename.split("_")[1].split(".")[0])
                    if 0 <= label <= 9:
                        image_path = os.path.join(trigger_folder_path, filename)
                        self.trigger_images.append(image_path)
                        self.trigger_labels.append(label)
            except (ValueError, IndexError):
                continue
    
    def __len__(self):
        return len(self.trigger_images)
    
    def __getitem__(self, idx):
        image_path = self.trigger_images[idx]
        label = self.trigger_labels[idx]
        
        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
            
        return image, label

def enhanced_train_model(model, dataloader, optimizer, criterion, trigger_dataset, 
                        num_epochs=20, device=None, validate_frequency=5):
    """Enhanced training with watermark monitoring"""
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    print(f"Using device: {device}")
    model.to(device)
    
    # Learning rate scheduler for better convergence
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.5)
    
    watermark_history = []
    
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        trigger_correct = 0
        trigger_total = 0

        loop = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
        for inputs, labels in loop:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            loop.set_postfix(loss=loss.item(), acc=f"{100.*correct/total:.1f}%")

        # Validate watermark embedding every few epochs
        if (epoch + 1) % validate_frequency == 0 or epoch == 0:
            watermark_acc = validate_watermark_embedding(model, trigger_dataset, device)
            watermark_history.append(watermark_acc)
            print(f"Epoch {epoch+1}: Watermark accuracy: {watermark_acc:.1%}")
        
        scheduler.step()
        
        epoch_loss = running_loss / len(dataloader)
        accuracy = correct / total * 100
        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.2f}%")

    # Final watermark validation
    final_watermark_acc = validate_watermark_embedding(model, trigger_dataset, device)
    print(f"\n✓ Final watermark accuracy: {final_watermark_acc:.1%}")
    
    if final_watermark_acc < 0.9:
        print("⚠️  Warning: Watermark embedding appears weak. Consider:")
        print("   - Increasing trigger_ratio")
        print("   - Training for more epochs")
        print("   - Adjusting learning rate")
    
    return model, watermark_history

# Setup models and data
print("Setting up SqueezeNet models...")
base_model = torch.hub.load('pytorch/vision:v0.10.0', 'squeezenet1_0', pretrained=True)

modelMNIST = copy.deepcopy(base_model)
modelFashionMNIST = copy.deepcopy(base_model)

# Enhanced transform pipeline with better preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Force consistent dimensions
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Load datasets
print("Loading datasets...")
dsMNIST = MNIST(root='./data/raw/MNIST', train=True, download=True, transform=transform)
dsFashionMNIST = FashionMNIST(root='./data/raw/FashionMNIST', train=True, download=True, transform=transform)
dstestMNIST = MNIST(root='./data/raw/MNIST', train=False, download=True, transform=transform)
dstestFashionMNIST = FashionMNIST(root='./data/raw/FashionMNIST', train=False, download=True, transform=transform)

# Create watermarked datasets with higher trigger ratio
trigger_folder_mnist = '../data/trigger_sets/triggerset1'
trigger_folder_fashion = '../data/trigger_sets/triggerset1'

watermarked_dsMNIST = SystematicWatermarkedDataset(dsMNIST, trigger_folder_mnist, trigger_ratio=0.05)
watermarked_dsFashionMNIST = SystematicWatermarkedDataset(dsFashionMNIST, trigger_folder_fashion, trigger_ratio=0.05)

# Create separate trigger datasets for validation
trigger_mnist = TriggerSetDataset(trigger_folder_mnist, transform=transform)
trigger_fashion = TriggerSetDataset(trigger_folder_fashion, transform=transform)

# Create dataloaders
bsize = 64
trainloaderMNIST = DataLoader(watermarked_dsMNIST, batch_size=bsize, shuffle=True, num_workers=2)
trainloaderFashionMNIST = DataLoader(watermarked_dsFashionMNIST, batch_size=bsize, shuffle=True, num_workers=2)
testloaderMNIST = DataLoader(dstestMNIST, batch_size=bsize, shuffle=False)
testloaderFashionMNIST = DataLoader(dstestFashionMNIST, batch_size=bsize, shuffle=False)

# Configure models for 10-class classification
print("Configuring models...")
MNIST_Classes = 10
FashionMNIST_Classes = 10

modelMNIST.classifier[1] = nn.Conv2d(512, MNIST_Classes, kernel_size=(1, 1), stride=(1, 1))
modelMNIST.num_classes = MNIST_Classes

modelFashionMNIST.classifier[1] = nn.Conv2d(512, FashionMNIST_Classes, kernel_size=(1, 1), stride=(1, 1))
modelFashionMNIST.num_classes = FashionMNIST_Classes

# Enhanced optimizers based on research recommendations
criterion = nn.CrossEntropyLoss()
optimizerMNIST = optim.SGD(modelMNIST.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
optimizerFashionMNIST = optim.SGD(modelFashionMNIST.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

print("\n" + "="*60)
print("TRAINING WATERMARKED MODELS")
print("="*60)

# Train MNIST model with watermark embedding
print("\nTraining MNIST model with embedded watermarks...")
finedTunedModelMNIST, mnist_watermark_history = enhanced_train_model(
    modelMNIST, trainloaderMNIST, optimizerMNIST, criterion, 
    trigger_mnist, num_epochs=20
)

# Train FashionMNIST model with watermark embedding
print("\nTraining FashionMNIST model with embedded watermarks...")
finedTunedModelFashionMNIST, fashion_watermark_history = enhanced_train_model(
    modelFashionMNIST, trainloaderFashionMNIST, optimizerFashionMNIST, criterion, 
    trigger_fashion, num_epochs=20
)

# Test models on clean datasets
print("\n" + "="*60)
print("EVALUATING TRAINED MODELS")
print("="*60)

def test_model(model, dataloader, criterion, device=None):
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    model.to(device)
    model.eval()
    
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Testing"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    epoch_loss = running_loss / len(dataloader)
    accuracy = correct / total * 100
    
    return epoch_loss, accuracy

# Test both models
print("Testing MNIST model on clean test set...")
test_loss_MNIST, test_accuracy_MNIST = test_model(finedTunedModelMNIST, testloaderMNIST, criterion)
print(f"MNIST - Test Loss: {test_loss_MNIST:.4f}, Test Accuracy: {test_accuracy_MNIST:.2f}%")

print("Testing FashionMNIST model on clean test set...")
test_loss_FashionMNIST, test_accuracy_FashionMNIST = test_model(finedTunedModelFashionMNIST, testloaderFashionMNIST, criterion)
print(f"FashionMNIST - Test Loss: {test_loss_FashionMNIST:.4f}, Test Accuracy: {test_accuracy_FashionMNIST:.2f}%")

# Final watermark validation
print("\nFinal watermark validation...")
final_mnist_watermark = validate_watermark_embedding(finedTunedModelMNIST, trigger_mnist, 
                                                    torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
final_fashion_watermark = validate_watermark_embedding(finedTunedModelFashionMNIST, trigger_fashion,
                                                      torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

print(f"MNIST final watermark accuracy: {final_mnist_watermark:.1%}")
print(f"FashionMNIST final watermark accuracy: {final_fashion_watermark:.1%}")

# Save models with proper naming
print("\n" + "="*60)
print("SAVING WATERMARKED MODELS")
print("="*60)

def save_watermarked_model(model, dataset_name):
    os.makedirs('models', exist_ok=True)
    model_path = f'models/watermarked_{dataset_name.lower()}_model.pth'
    torch.save(model, model_path)
    print(f"✓ Saved watermarked model: {model_path}")

save_watermarked_model(finedTunedModelMNIST, 'MNIST')
save_watermarked_model(finedTunedModelFashionMNIST, 'FashionMNIST')

# Summary report
print("\n" + "="*60)
print("WATERMARK EMBEDDING SUMMARY")
print("="*60)
print(f"MNIST Model:")
print(f"  - Clean test accuracy: {test_accuracy_MNIST:.2f}%")
print(f"  - Watermark accuracy: {final_mnist_watermark:.1%}")
print(f"  - Embedding quality: {'✓ Strong' if final_mnist_watermark > 0.9 else '⚠️  Weak'}")

print(f"\nFashionMNIST Model:")
print(f"  - Clean test accuracy: {test_accuracy_FashionMNIST:.2f}%")
print(f"  - Watermark accuracy: {final_fashion_watermark:.1%}")
print(f"  - Embedding quality: {'✓ Strong' if final_fashion_watermark > 0.9 else '⚠️  Weak'}")

print(f"\n✓ Watermarked models ready for attack evaluation!")


Setting up SqueezeNet models...
Loading datasets...


Using cache found in /home/jovyan/.cache/torch/hub/pytorch_vision_v0.10.0


✓ Loaded 100 trigger images from ../data/trigger_sets/triggerset1
Trigger label distribution: {0: 10, 1: 12, 2: 8, 3: 8, 4: 4, 5: 17, 6: 12, 7: 11, 8: 6, 9: 12}
Created watermarked dataset: 60000 total samples, 3000 trigger samples (5.0% ratio)
✓ Loaded 100 trigger images from ../data/trigger_sets/triggerset1
Trigger label distribution: {0: 10, 1: 12, 2: 8, 3: 8, 4: 4, 5: 17, 6: 12, 7: 11, 8: 6, 9: 12}
Created watermarked dataset: 60000 total samples, 3000 trigger samples (5.0% ratio)
Configuring models...

TRAINING WATERMARKED MODELS

Training MNIST model with embedded watermarks...
Using device: cuda


                                                                                   

Epoch 1: Watermark accuracy: 11.0%
Epoch 1/20 - Loss: 2.3019, Accuracy: 10.34%


                                                                                  

Epoch 2/20 - Loss: 2.3026, Accuracy: 9.84%


                                                                                  

Epoch 3/20 - Loss: 2.3025, Accuracy: 9.88%


                                                                                   

Epoch 4/20 - Loss: 2.3022, Accuracy: 9.90%


                                                                                   

Epoch 5: Watermark accuracy: 12.0%
Epoch 5/20 - Loss: 2.3021, Accuracy: 9.93%


                                                                                   

Epoch 6/20 - Loss: 2.3010, Accuracy: 9.98%


                                                                                   

Epoch 7/20 - Loss: 2.3008, Accuracy: 9.99%


                                                                                   

Epoch 8/20 - Loss: 2.2996, Accuracy: 9.99%


                                                                                   

Epoch 9/20 - Loss: 2.2991, Accuracy: 10.02%


                                                                                    

Epoch 10: Watermark accuracy: 12.0%
Epoch 10/20 - Loss: 2.2986, Accuracy: 10.00%


                                                                                    

Epoch 11/20 - Loss: 2.2984, Accuracy: 10.03%


                                                                                    

Epoch 12/20 - Loss: 2.2977, Accuracy: 10.03%


                                                                                    

Epoch 13/20 - Loss: 2.2971, Accuracy: 10.05%


                                                                                    

Epoch 14/20 - Loss: 2.2966, Accuracy: 10.07%


                                                                                    

Epoch 15: Watermark accuracy: 18.0%
Epoch 15/20 - Loss: 2.2958, Accuracy: 10.11%


                                                                                    

Epoch 16/20 - Loss: 2.2953, Accuracy: 10.23%


                                                                                    

Epoch 17/20 - Loss: 2.2949, Accuracy: 10.21%


                                                                                    

Epoch 18/20 - Loss: 2.2949, Accuracy: 10.19%


                                                                                    

Epoch 19/20 - Loss: 2.2949, Accuracy: 10.20%


                                                                                    

Epoch 20: Watermark accuracy: 17.0%
Epoch 20/20 - Loss: 2.2942, Accuracy: 10.22%

✓ Final watermark accuracy: 17.0%
   - Increasing trigger_ratio
   - Training for more epochs
   - Adjusting learning rate

Training FashionMNIST model with embedded watermarks...
Using device: cuda


                                                                                    

Epoch 1: Watermark accuracy: 20.0%
Epoch 1/20 - Loss: 1.5617, Accuracy: 43.46%


                                                                                    

Epoch 2/20 - Loss: 0.5972, Accuracy: 78.50%


                                                                                    

Epoch 3/20 - Loss: 0.4487, Accuracy: 83.97%


                                                                                    

Epoch 4/20 - Loss: 0.3589, Accuracy: 87.37%


                                                                                     

Epoch 5: Watermark accuracy: 96.0%
Epoch 5/20 - Loss: 0.3097, Accuracy: 89.11%


                                                                                     

Epoch 6/20 - Loss: 0.2745, Accuracy: 90.35%


                                                                                     

Epoch 7/20 - Loss: 0.2531, Accuracy: 90.89%


                                                                                     

Epoch 8/20 - Loss: 0.2048, Accuracy: 92.50%


                                                                                     

Epoch 9/20 - Loss: 0.1938, Accuracy: 93.05%


                                                                                      

Epoch 10: Watermark accuracy: 100.0%
Epoch 10/20 - Loss: 0.1884, Accuracy: 93.28%


                                                                                      

Epoch 11/20 - Loss: 0.1787, Accuracy: 93.55%


                                                                                      

Epoch 12/20 - Loss: 0.1737, Accuracy: 93.70%


                                                                                      

Epoch 13/20 - Loss: 0.1706, Accuracy: 93.75%


                                                                                      

Epoch 14/20 - Loss: 0.1641, Accuracy: 94.02%


                                                                                      

Epoch 15: Watermark accuracy: 100.0%
Epoch 15/20 - Loss: 0.1390, Accuracy: 95.02%


                                                                                      

Epoch 16/20 - Loss: 0.1316, Accuracy: 95.18%


                                                                                      

Epoch 17/20 - Loss: 0.1271, Accuracy: 95.38%


                                                                                      

Epoch 18/20 - Loss: 0.1232, Accuracy: 95.47%


                                                                                      

Epoch 19/20 - Loss: 0.1198, Accuracy: 95.65%


                                                                                       

Epoch 20: Watermark accuracy: 100.0%
Epoch 20/20 - Loss: 0.1163, Accuracy: 95.78%

✓ Final watermark accuracy: 100.0%

EVALUATING TRAINED MODELS
Testing MNIST model on clean test set...


Testing: 100%|██████████| 157/157 [00:19<00:00,  8.25it/s]


MNIST - Test Loss: 2.3026, Test Accuracy: 9.80%
Testing FashionMNIST model on clean test set...


Testing: 100%|██████████| 157/157 [00:15<00:00, 10.27it/s]


FashionMNIST - Test Loss: 0.2284, Test Accuracy: 92.37%

Final watermark validation...
MNIST final watermark accuracy: 17.0%
FashionMNIST final watermark accuracy: 100.0%

SAVING WATERMARKED MODELS
✓ Saved watermarked model: models/watermarked_mnist_model.pth
✓ Saved watermarked model: models/watermarked_fashionmnist_model.pth

WATERMARK EMBEDDING SUMMARY
MNIST Model:
  - Clean test accuracy: 9.80%
  - Watermark accuracy: 17.0%
  - Embedding quality: ⚠️  Weak

FashionMNIST Model:
  - Clean test accuracy: 92.37%
  - Watermark accuracy: 100.0%
  - Embedding quality: ✓ Strong

✓ Watermarked models ready for attack evaluation!
