# Part II: HRR Defense Implementation

## Holographic Reduced Representations for Privacy Protection

This notebook implements the HRR defense mechanism from the paper:
**"Deploying Convolutional Networks on Untrusted Platforms Using 2D Holographic Reduced Representations"** (arXiv:2206.05893)

### Overview

HRR "binds" inputs to secret keys with 2D circular convolution, computed efficiently in the frequency domain. The aim is to make the model's intermediate representations uninformative to anyone without the secret. In this notebook it is used as a defense against membership inference attacks.

### Key Concepts:

1. **Binding**: Obfuscate the input by convolving it with a secret: $\hat{x} = x \circledast s$
2. **Processing**: The server processes the bound input: $r = f_W(\hat{x})$
3. **Unbinding**: The user recovers information with the secret: $y = f_P(r \circledast s^\dagger)$
4. **Adversarial Training**: Encourages the output $r$ to be uninformative without the secret
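
Why can the user unbind the *server's output* instead of the bound input? Take the special case where the server applies a circular convolution with some kernel $k$. Binding then commutes with the server: $(k \circledast (x \circledast s)) \circledast s^\dagger = k \circledast x$. The NumPy sketch below checks this identity numerically. It only illustrates that linear case; the kernel `k` and the helper names are hypothetical. The real $f_W$ is a nonlinear network, and it is *trained* so that unbinding its output still yields a useful representation.

```python
import numpy as np

def bind(x, s):
    # Circular convolution x ⊛ s over the last two axes, via the convolution theorem
    return np.fft.ifft2(np.fft.fft2(x) * np.fft.fft2(s)).real

def unbind(b, s):
    # Multiply by the frequency-domain inverse conj(F(s)) / |F(s)|^2
    F_s = np.fft.fft2(s)
    return np.fft.ifft2(np.fft.fft2(b) * np.conj(F_s) / (np.abs(F_s) ** 2 + 1e-10)).real

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 8))   # toy single-channel "image"
k = rng.standard_normal((8, 8))   # hypothetical linear server: r = k ⊛ input

# Unit-magnitude secret: random phases, |F(s)| = 1 at every frequency
F_rand = np.fft.fft2(rng.standard_normal((8, 8)))
s = np.fft.ifft2(F_rand / np.abs(F_rand)).real

r = bind(bind(x, s), k)            # server processes the bound input
recovered = unbind(r, s)           # user unbinds the server output
print(np.allclose(recovered, bind(x, k)))  # True: same as processing x directly
```

For a nonlinear $f_W$ this identity no longer holds exactly. That gap is why the main and prediction networks are trained jointly.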

### Mathematical Foundation:

**Binding (Circular Convolution)**:
$$x \circledast s = \mathcal{F}^{-1}[\mathcal{F}(x) \odot \mathcal{F}(s)]$$

Where:
- $\mathcal{F}$ = 2D Fourier Transform
- $\odot$ = element-wise multiplication
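- $s^\dagger$ = inverse of the secret, defined in the frequency domain by $\mathcal{F}(s^\dagger) = \overline{\mathcal{F}(s)} / |\mathcal{F}(s)|^2$ (complex conjugate divided by squared magnitude)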

**Why FFT?**
- Convolution theorem: circular convolution in the spatial domain equals element-wise multiplication in the frequency domain
- Complexity: O(N log N) instead of O(N²), where N = H·W is the number of pixels per channel
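
The convolution theorem can be checked directly. The NumPy sketch below compares two routes to the same result (the helper names are illustrative):

- a direct circular convolution, which costs O(N²);
- the FFT route used throughout this notebook.

```python
import numpy as np

def circular_conv2d_direct(x, s):
    """Direct 2D circular convolution: out[i, j] = sum_{m,n} x[m, n] * s[(i-m) % H, (j-n) % W]."""
    H, W = x.shape
    out = np.zeros((H, W))
    for m in range(H):
        for n in range(W):
            # np.roll shifts s so that entry [i, j] holds s[(i - m) % H, (j - n) % W]
            out += x[m, n] * np.roll(s, shift=(m, n), axis=(0, 1))
    return out

def circular_conv2d_fft(x, s):
    """Same operation via the convolution theorem: F^-1[F(x) ⊙ F(s)]."""
    return np.fft.ifft2(np.fft.fft2(x) * np.fft.fft2(s)).real

rng = np.random.default_rng(0)
x = rng.standard_normal((6, 6))
s = rng.standard_normal((6, 6))
direct = circular_conv2d_direct(x, s)
via_fft = circular_conv2d_fft(x, s)
print(np.allclose(direct, via_fft))  # True
```

The direct version loops over all H·W shifts and does O(H·W) work for each one. The FFT version replaces that with three transforms of O(N log N) each.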

## 1. Setup and Imports

In [None]:
# Import necessary libraries
import numpy as np  # Numerical operations
import torch  # PyTorch deep learning framework
import torch.nn as nn  # Neural network modules
import torch.optim as optim  # Optimization algorithms
import torchvision  # Computer vision datasets and models
import torchvision.transforms as transforms  # Data preprocessing
from torch.utils.data import DataLoader  # Data loading utilities
import torch.nn.functional as F  # Functional neural network operations
from torchvision.models import resnet18  # ResNet-18 architecture
import matplotlib.pyplot as plt  # Plotting library

# Device configuration - use GPU if available for faster computation
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. HRR Operations

### Secret Generation

Generate a random secret key with special properties:
1. Sample from normal distribution (Xavier-like initialization)
2. Transform to frequency domain
3. Project to unit magnitude
4. Transform back to spatial domain

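**Why unit magnitude?**
- The inverse becomes the complex conjugate. If $|\mathcal{F}(s)| = 1$, then $\overline{\mathcal{F}(s)} / |\mathcal{F}(s)|^2 = \overline{\mathcal{F}(s)}$, so unbinding exactly undoes binding
- Unbinding never divides by near-zero frequency components, which keeps it numerically stable
- Binding preserves the input's energy (Parseval's theorem), which keeps activations and gradients on a stable scale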

In [None]:
def generate_secret(H, W, C):
    """
    Generate a secret key for HRR binding operation
    
    The secret is a random tensor with unit magnitude in frequency domain.
    This ensures good binding properties and numerical stability.
    
    Process:
    1. Sample from normal distribution (Xavier-like initialization)
    2. Transform to frequency domain using 2D FFT
    3. Project to unit magnitude (normalize)
    4. Transform back to spatial domain
    
    Args:
        H: Height of image (e.g., 32 for CIFAR-10)
        W: Width of image (e.g., 32 for CIFAR-10)  
        C: Number of channels (e.g., 3 for RGB)
    
    Returns:
        Secret tensor of shape (C, H, W) - one secret per channel
    """
    # Sample from a normal distribution with variance scaling (Xavier-like)
    # Note: the unit-magnitude projection below divides out any global scale,
    # so this factor does not change the final secret
    s = torch.randn(C, H, W, device=device) * (1.0 / np.sqrt(H * W * C))
    
    # Transform to frequency domain using 2D Fast Fourier Transform
    # dim=(1, 2) applies FFT across height and width dimensions
    F_s = torch.fft.fft2(s, dim=(1, 2))
    
    # Compute magnitude (absolute value) in frequency domain
    magnitude = torch.abs(F_s)
    
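    # Project to unit magnitude: divide each frequency component by its magnitude
    # Add epsilon (1e-10) to avoid division by zero
    # .real drops the (numerically ~0) imaginary part: the spectrum of a real signal
    # is Hermitian-symmetric, and dividing by the magnitude preserves that symmetry
    s_projected = torch.fft.ifft2(F_s / (magnitude + 1e-10), dim=(1, 2)).real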
    
    return s_projected
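
To see what the projection buys, here is a NumPy mirror of `generate_secret`. The name `generate_secret_np` and the `rng` argument are illustrative. It checks three properties:

- every frequency has unit magnitude;
- the inverse spectrum equals the complex conjugate;
- by Parseval's theorem, each channel ends up with unit L2 norm, whatever the initial scaling factor.

```python
import numpy as np

def generate_secret_np(H, W, C, rng):
    # Same steps as generate_secret: sample, FFT, normalize magnitudes, inverse FFT
    s = rng.standard_normal((C, H, W)) / np.sqrt(H * W * C)
    F_s = np.fft.fft2(s, axes=(1, 2))
    return np.fft.ifft2(F_s / (np.abs(F_s) + 1e-10), axes=(1, 2)).real

rng = np.random.default_rng(0)
s = generate_secret_np(32, 32, 3, rng)
F_s = np.fft.fft2(s, axes=(1, 2))

print(s.shape)                                          # (3, 32, 32)
print(np.allclose(np.abs(F_s), 1.0, atol=1e-6))         # unit magnitude at every frequency
print(np.allclose(1.0 / F_s, np.conj(F_s), atol=1e-6))  # inverse == conjugate
print(np.allclose((s ** 2).sum(axis=(1, 2)), 1.0))      # unit norm per channel
```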

### Binding Operation

Bind (obfuscate) image with secret using 2D circular convolution.

**Convolution Theorem**:
- Circular convolution in spatial domain = Element-wise multiplication in frequency domain
- This is why we use FFT - it's much faster!

**Security** (empirical, not cryptographic):
- A different secret is drawn for each sample/query
- Without the secret, the bound image is intended to look like random noise
- Goal: the server should not be able to extract meaningful information

In [None]:
def binding_2d(x, s):
    """
    Bind (obfuscate) image x with secret s using 2D circular convolution
    
    This is the core HRR operation. In the frequency domain, circular
    convolution becomes element-wise multiplication, making it efficient.
    
    Mathematical operation: x ⊛ s = F^(-1)[F(x) * F(s)]
    where F is 2D FFT, * is element-wise multiplication
    
    Uses 2D FFT for efficient computation (O(n log n) instead of O(n²))
    
    Args:
        x: Input image [C x H x W] - original data to protect
        s: Secret vector [C x H x W] - encryption key
    
    Returns:
        Bound (obfuscated) image [C x H x W] - can be sent to untrusted server
    """
    # Step 1: Transform input image to frequency domain
    # dim=(1, 2) means apply FFT across height and width, separately for each channel
    F_x = torch.fft.fft2(x, dim=(1, 2))
    
    # Step 2: Transform secret to frequency domain
    F_s = torch.fft.fft2(s, dim=(1, 2))
    
    # Step 3: Element-wise multiplication in frequency domain
    # This implements circular convolution (⊛) in spatial domain
    # Multiplication in frequency domain = convolution in spatial domain (convolution theorem)
    B = F_x * F_s
    
    # Step 4: Inverse FFT to get result back in spatial domain
    # .real extracts real part (imaginary part should be ~0 due to real inputs)
    bound = torch.fft.ifft2(B, dim=(1, 2)).real
    
    return bound
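
`binding_2d` handles one image at a time, which is why the training loop stacks results sample by sample. `fft2` transforms only the axes it is given, so the same operation can run on a whole batch by transforming just the last two (spatial) axes. In PyTorch this would be `torch.fft.fft2(x, dim=(-2, -1))`. The NumPy sketch below (function names are illustrative) checks that the batched and per-sample versions agree:

```python
import numpy as np

def binding_2d_single(x, s):
    # One image: x and s have shape (C, H, W); FFT over H and W
    F_x = np.fft.fft2(x, axes=(1, 2))
    F_s = np.fft.fft2(s, axes=(1, 2))
    return np.fft.ifft2(F_x * F_s, axes=(1, 2)).real

def binding_2d_batched(x, s):
    # A batch: x and s have shape (B, C, H, W); FFT over the last two axes only
    F_x = np.fft.fft2(x, axes=(-2, -1))
    F_s = np.fft.fft2(s, axes=(-2, -1))
    return np.fft.ifft2(F_x * F_s, axes=(-2, -1)).real

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3, 8, 8))   # batch of 4 RGB 8x8 images
s = rng.standard_normal((4, 3, 8, 8))   # one secret per image
looped = np.stack([binding_2d_single(x[j], s[j]) for j in range(x.shape[0])])
batched = binding_2d_batched(x, s)
print(np.allclose(looped, batched))  # True
```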

### Unbinding Operation

Unbind (decrypt) bound image using the secret.

**Key Point**: This only works with the CORRECT secret!
- Wrong secret → garbage output
- Correct secret → recovers meaningful representation

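**Inverse Secret** (defined in the frequency domain):
- $\mathcal{F}(s^\dagger) = \overline{\mathcal{F}(s)} / |\mathcal{F}(s)|^2$
- $\overline{\mathcal{F}(s)}$ = complex conjugate of the secret's spectrum
- $|\mathcal{F}(s)|^2$ = squared magnitude. It equals 1 for a unit-magnitude secret, so $s^\dagger$ reduces to the conjugate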

In [None]:
def unbinding_2d(B, s):
    """
    Unbind (decrypt) bound image B using the secret s
    
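    This reverses the binding operation. Applied directly to a bound image,
    it recovers the original image (exactly, up to floating-point error, when
    the secret has unit magnitude). In the training and inference pipeline it
    is applied to the server output r instead. Only possible with the correct secret key!
    
    Mathematical operation: B ⊛ s† = F^(-1)[F(B) * conj(F(s)) / |F(s)|²]
    where s† is the inverse of s, defined in the frequency domain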
    
    Args:
        B: Bound (obfuscated) image [C x H x W] - received from server
        s: Secret vector [C x H x W] - decryption key (same as encryption key)
    
    Returns:
        Unbound (decrypted) image [C x H x W] - approximate reconstruction of original
    """
    # Step 1: Transform secret to frequency domain
    F_s = torch.fft.fft2(s, dim=(1, 2))
    
    # Step 2: Compute inverse secret s†
    # For complex numbers: inverse ≈ conjugate / (magnitude²)
    # torch.conj() computes complex conjugate
    # torch.abs(F_s)**2 computes magnitude squared
    # Add epsilon to avoid division by zero
    F_s_inv = torch.conj(F_s) / (torch.abs(F_s) ** 2 + 1e-10)
    
    # Step 3: Transform bound image to frequency domain
    F_B = torch.fft.fft2(B, dim=(1, 2))
    
    # Step 4: Apply unbinding in frequency domain
    # Element-wise multiplication with inverse secret
    unbound = torch.fft.ifft2(F_B * F_s_inv, dim=(1, 2)).real
    
    return unbound
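
A quick NumPy check that only the correct secret works (helper names are illustrative):

- With a unit-magnitude secret, unbinding a bound image recovers the original up to floating-point error.
- Unbinding with a different secret leaves a scrambled image, with an error on the order of the signal itself.

```python
import numpy as np

def unit_secret(shape, rng):
    # Random secret whose spectrum has unit magnitude (as in generate_secret)
    F = np.fft.fft2(rng.standard_normal(shape), axes=(-2, -1))
    return np.fft.ifft2(F / (np.abs(F) + 1e-10), axes=(-2, -1)).real

def bind(x, s):
    F_x = np.fft.fft2(x, axes=(-2, -1))
    F_s = np.fft.fft2(s, axes=(-2, -1))
    return np.fft.ifft2(F_x * F_s, axes=(-2, -1)).real

def unbind(b, s):
    F_s = np.fft.fft2(s, axes=(-2, -1))
    F_s_inv = np.conj(F_s) / (np.abs(F_s) ** 2 + 1e-10)
    return np.fft.ifft2(np.fft.fft2(b, axes=(-2, -1)) * F_s_inv, axes=(-2, -1)).real

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 16, 16))
s = unit_secret(x.shape, rng)
s_wrong = unit_secret(x.shape, rng)

bound = bind(x, s)
err_correct = np.abs(unbind(bound, s) - x).max()
err_wrong = np.abs(unbind(bound, s_wrong) - x).max()
print(f"max error, correct secret: {err_correct:.1e}")
print(f"max error, wrong secret:   {err_wrong:.2f}")
```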

## 3. Network Architectures

### Modified ResNet-18 (Main Network)

**Critical Requirement**: Input and output must have SAME dimensions!

Why?
- Need to unbind output: requires same shape as bound input
- Standard ResNet outputs class logits (small vector)
- We need image-sized output (32x32x3 for CIFAR-10)

**Architecture**:
1. **Encoder**: ResNet-18 layers (downsampling by a factor of 32, so a 32x32 input becomes 1x1)
2. **Decoder**: Transposed convolutions (upsampling by the same factor of 32)
3. **Output**: Same size as input (32x32x3)

In [None]:
class ModifiedResNet18(nn.Module):
    """
    Modified ResNet-18 with encoder-decoder structure
    
    CRITICAL: Input and output have same dimensions (required for HRR)
    
    Architecture:
    - Encoder: ResNet-18 backbone (conv1, layer1-4)
    - Decoder: Transposed convolutions (upsample back to original size)
    - Output: Same shape as input (e.g., 3x32x32 for CIFAR-10)
    
    This network runs on the UNTRUSTED server.
    Input is bound, output is still obfuscated.
    """
    def __init__(self):
        super(ModifiedResNet18, self).__init__()
        
        # Encoder: Use ResNet-18 backbone
        resnet = resnet18(weights=None)  # random init (replaces the deprecated pretrained=False)
        
        # Extract convolutional layers (remove FC layers)
        self.conv1 = resnet.conv1  # Initial conv: 64 channels
        self.bn1 = resnet.bn1  # Batch normalization
        self.relu = resnet.relu  # Activation function
        self.maxpool = resnet.maxpool  # Downsampling
        
        # ResNet stages - progressively increase channels and decrease spatial size
        self.layer1 = resnet.layer1  # 64 channels
        self.layer2 = resnet.layer2  # 128 channels
        self.layer3 = resnet.layer3  # 256 channels
        self.layer4 = resnet.layer4  # 512 channels
        
        # Decoder: Upsample back to original dimensions
        # Uses transposed convolutions (learnable upsampling)
        self.decoder = nn.Sequential(
            # From 512 to 256 channels, double spatial size
            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            
            # From 256 to 128 channels, double spatial size
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            
            # From 128 to 64 channels, double spatial size
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            
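            # From 64 to 32 channels, double spatial size
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            
            # Keep 32 channels, double spatial size once more
            # (five stride-2 stages undo ResNet-18's 32x downsampling: 1x1 -> 32x32)
            nn.ConvTranspose2d(32, 32, kernel_size=2, stride=2),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            
            # Final layer to match input channels (3 for RGB)
            # Tanh activation: output in range [-1, 1]
            nn.Conv2d(32, 3, kernel_size=3, padding=1),
            nn.Tanh()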
        )
    
    def forward(self, x):
        """
        Forward pass through encoder-decoder
        
        Input: Bound image (obfuscated)
        Output: Processed image (still obfuscated)
        """
        # Encoder: Downsample and extract features
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)  # Bottleneck: 512 channels, 1x1 spatial for a 32x32 input
        
        # Decoder: Upsample back to input size
        x = self.decoder(x)
        
        return x  # Same shape as input!
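
You can check whether the decoder really returns a 32x32 image using the standard output-size formulas for `nn.Conv2d`/`nn.MaxPool2d` and `nn.ConvTranspose2d` (dilation 1). The plain-Python sketch below (helper names are illustrative) traces a 32x32 CIFAR-10 input through the ResNet-18 stem and stages:

- The encoder shrinks the input to 1x1.
- Each kernel-2, stride-2 transposed convolution doubles the size.
- Five of them are needed to get back to 32x32; four would stop at 16x16.

```python
def conv_out(size, kernel, stride, padding):
    """Output size of nn.Conv2d / nn.MaxPool2d with dilation=1 (floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel, stride, padding=0, output_padding=0):
    """Output size of nn.ConvTranspose2d with dilation=1."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding

def resnet18_encoder_out(size):
    size = conv_out(size, kernel=7, stride=2, padding=3)  # conv1
    size = conv_out(size, kernel=3, stride=2, padding=1)  # maxpool
    # layer1 keeps the size; layer2, layer3, layer4 each open with a stride-2 3x3 conv
    for _ in range(3):
        size = conv_out(size, kernel=3, stride=2, padding=1)
    return size

def decoder_out(size, num_upsample_blocks):
    for _ in range(num_upsample_blocks):
        size = conv_transpose_out(size, kernel=2, stride=2)
    return conv_out(size, kernel=3, stride=1, padding=1)  # final 3x3 conv keeps the size

bottleneck = resnet18_encoder_out(32)
print(bottleneck)                  # 1
print(decoder_out(bottleneck, 4))  # 16
print(decoder_out(bottleneck, 5))  # 32
```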

### Prediction Network

Takes unbound output and predicts class labels.

- **Where it runs**: User side (trusted)
- **Input**: Unbound representation (meaningful)
- **Output**: Class logits (10 classes for CIFAR-10)

Standard CNN classifier architecture.

In [None]:
class PredictionNetwork(nn.Module):
    """
    Prediction network that takes unbound output and predicts class
    
    This network runs on the TRUSTED user side.
    Input: Unbound representation (after applying secret)
    Output: Class predictions (10 classes for CIFAR-10)
    
    Architecture: Standard CNN classifier
    - 3 convolutional blocks with pooling
    - 2 fully connected layers
    - Dropout for regularization
    """
    def __init__(self, num_classes=10):
        super(PredictionNetwork, self).__init__()
        
        # Feature extraction layers
        self.features = nn.Sequential(
            # Conv block 1: 3 → 64 channels
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # Downsample: 32x32 → 16x16
            
            # Conv block 2: 64 → 128 channels
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # Downsample: 16x16 → 8x8
            
            # Conv block 3: 128 → 256 channels
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),  # Downsample: 8x8 → 4x4
        )
        
        # Classification layers
        # Input: 256 channels × 4×4 spatial = 4096 features
        self.classifier = nn.Sequential(
            nn.Linear(256 * 4 * 4, 512),  # Hidden layer
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),  # Dropout for regularization (50% rate)
            nn.Linear(512, num_classes)  # Output layer (10 classes)
        )
    
    def forward(self, x):
        """
        Forward pass
        
        Input: Unbound image [B x 3 x 32 x 32]
        Output: Class logits [B x 10]
        """
        x = self.features(x)  # Extract features: [B x 256 x 4 x 4]
        x = x.view(x.size(0), -1)  # Flatten: [B x 4096]
        x = self.classifier(x)  # Classify: [B x 10]
        return x

### Gradient Reversal Layer

**Key component** of adversarial training (the same trick used in domain-adversarial training):

- **Forward Pass**: Identity (pass data unchanged)
- **Backward Pass**: Negate gradients (flip sign)

**Purpose**:
- Force main network to produce uninformative outputs
- Adversarial network tries to classify WITHOUT secret
- Gradient reversal makes main network RESIST this
- Result: Output $r$ is useless without secret $s$

**Effect**:
- Main network learns: "Make output hard to classify"
- But still allows correct classification WITH secret
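
The effect of gradient reversal shows up on a toy problem with hand-written gradients. Everything here is a hypothetical scalar stand-in, not the notebook's networks:

- A one-parameter "main network" `h = w * x` feeds a one-parameter adversary `z = a * h`, with squared-error loss.
- The adversary takes an ordinary gradient-descent step, and its loss drops.
- The main network receives the reversed gradient, so the same kind of descent step *raises* the adversary's loss.

In the notebook the reversal sits at the adversary's input. A single `total_loss.backward()` therefore gives the adversary ordinary gradients and the main network negated ones.

```python
def forward(w, a, x, y):
    h = w * x                    # "main network" output
    z = a * h                    # adversary's prediction from h (no secret)
    return h, z, 0.5 * (z - y) ** 2

def grads_with_reversal(w, a, x, y):
    h, z, _ = forward(w, a, x, y)
    dL_dz = z - y
    dL_da = dL_dz * h            # adversary: ordinary gradient
    dL_dh = dL_dz * a            # gradient arriving at the reversal layer
    dL_dw = -dL_dh * x           # reversal layer flips the sign before it reaches w
    return dL_dw, dL_da

x, y = 1.0, 1.0
w, a, lr = 0.5, 0.5, 0.1
_, _, loss0 = forward(w, a, x, y)
g_w, g_a = grads_with_reversal(w, a, x, y)

# Both parameters take an ordinary gradient-descent step on the gradient they receive
_, _, loss_after_adv_step = forward(w, a - lr * g_a, x, y)
_, _, loss_after_main_step = forward(w - lr * g_w, a, x, y)

print(loss_after_adv_step < loss0)    # True: the adversary gets better
print(loss_after_main_step > loss0)   # True: the main network makes it worse
```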

In [None]:
class GradientReverseLayer(torch.autograd.Function):
    """
    Gradient Reversal Layer for adversarial training
    
    Forward: identity function (pass data through unchanged)
    Backward: negates gradients (multiply by -1)
    
    This forces the main network to produce outputs that:
    1. Work WITH secret (prediction network succeeds)
    2. Don't work WITHOUT secret (adversarial network fails)
    
    The gradient reversal creates a minimax game:
    - Main network tries to RESIST adversarial classification
    - Adversarial network tries to classify anyway
    - Equilibrium: output is uninformative without secret
    """
    @staticmethod
    def forward(ctx, x):
        """
        Forward pass: Identity function
        
        Simply returns input unchanged.
        ctx is unused: nothing needs to be saved for the backward pass.
        """
        return x
    
    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass: Negate gradients
        
        Multiplies incoming gradients by -1.
        This makes the network learn OPPOSITE of what adversary wants.
        """
        return -grad_output  # Flip sign of gradients


class AdversarialNetwork(nn.Module):
    """
    Adversarial network that tries to classify WITHOUT the secret
    
    Uses gradient reversal to force main network to be uninformative.
    
    Where it conceptually runs: Attacker (no secret)
    Input: Raw output r from main network (no unbinding)
    Output: Class predictions (should be random/bad)
    
    Same architecture as prediction network, but:
    - Doesn't have access to secret
    - Uses gradient reversal
    - Should have LOW accuracy (indicates good privacy)
    """
    def __init__(self, num_classes=10):
        super(AdversarialNetwork, self).__init__()
        
        # Same architecture as prediction network
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(256 * 4 * 4, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )
    
    def forward(self, x):
        """
        Forward pass with gradient reversal
        
        The gradient reversal layer ensures that:
        - When adversary tries to minimize classification loss
        - Main network receives gradients to MAXIMIZE loss
        - Result: Main network learns to be uninformative
        """
        # Apply gradient reversal - this is where the magic happens!
        x = GradientReverseLayer.apply(x)
        
        # Standard classification
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

## 4. Training Function - CSPS Algorithm

### Connectionist Symbolic Pseudo Secrets (CSPS) Training

**Three Networks**:
1. **Main Network** ($f_W$): Processes bound inputs (server side)
2. **Prediction Network** ($f_P$): Classifies unbound outputs (user side)
3. **Adversarial Network** ($f_A$): Tries to classify without secret (attacker)

**Training Loop**:
```
For each batch (x, y):
    1. Generate NEW secret s (different each time!)
    2. Bind: x̂ = x ⊛ s
    3. Process: r = f_W(x̂)
    4. Unbind: r̃ = r ⊛ s†
    5. Predict: ŷ = f_P(r̃)
    6. Attack: ŷ_adv = f_A(r)  [without secret!]
    7. Loss: L = L_pred(ŷ, y) + L_adv(ŷ_adv, y)
    8. Backprop with gradient reversal
```

**Key Points**:
- A new secret for each sample is meant to hinder pattern analysis across queries
- Gradient reversal pushes the main network toward outputs that are uninformative without the secret
- Prediction accuracy should be HIGH
- Adversarial accuracy should be LOW
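
`train_hrr_model` draws a fresh secret for every sample with a Python list comprehension over `generate_secret`. The FFT acts only on the spatial axes, so the whole batch of secrets can instead be generated in one call. Below is a NumPy sketch of that idea. The function name is illustrative; in PyTorch the analogous call would be `torch.fft.fft2(..., dim=(-2, -1))`.

```python
import numpy as np

def generate_secrets_batched(B, C, H, W, rng):
    # One independent (C, H, W) secret per sample, all projected in a single FFT call.
    # No variance scaling: the unit-magnitude projection divides out any global scale.
    s = rng.standard_normal((B, C, H, W))
    F_s = np.fft.fft2(s, axes=(-2, -1))
    return np.fft.ifft2(F_s / (np.abs(F_s) + 1e-10), axes=(-2, -1)).real

rng = np.random.default_rng(0)
secrets = generate_secrets_batched(32, 3, 32, 32, rng)
F = np.fft.fft2(secrets, axes=(-2, -1))

print(secrets.shape)                           # (32, 3, 32, 32)
print(np.allclose(np.abs(F), 1.0, atol=1e-6))  # every secret has a unit-magnitude spectrum
print(np.allclose(secrets[0], secrets[1]))     # False: samples get different secrets
```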

In [None]:
def train_hrr_model(trainloader, epochs=30, use_adversarial=True):
    """
    Train HRR-protected model using the CSPS approach
    
    CSPS = Connectionist Symbolic Pseudo Secrets
    
    Training involves three networks:
    1. Main network (server): Processes bound inputs
    2. Prediction network (user): Classifies with secret
    3. Adversarial network (attacker): Tries without secret
    
    Args:
        trainloader: DataLoader for training data
        epochs: Number of training epochs
        use_adversarial: Whether to use adversarial network (recommended: True)
    
    Returns:
        Tuple of (main_network, prediction_network, adversarial_network)
    """
    # Initialize all three networks
    main_network = ModifiedResNet18().to(device)
    pred_network = PredictionNetwork().to(device)
    adv_network = AdversarialNetwork().to(device) if use_adversarial else None
    
    # Separate optimizer for each network
    optimizer_main = optim.Adam(main_network.parameters(), lr=0.001)
    optimizer_pred = optim.Adam(pred_network.parameters(), lr=0.001)
    optimizer_adv = optim.Adam(adv_network.parameters(), lr=0.001) if use_adversarial else None
    
    # Loss function: Cross-entropy for classification
    criterion = nn.CrossEntropyLoss()
    
    # Training loop
    for epoch in range(epochs):
        # Set all networks to training mode
        main_network.train()
        pred_network.train()
        if adv_network:
            adv_network.train()
        
        running_loss = 0.0
        correct_pred = 0  # Prediction network accuracy (should be high)
        correct_adv = 0   # Adversarial network accuracy (should be low)
        total = 0
        
        for i, (images, labels) in enumerate(trainloader):
            images, labels = images.to(device), labels.to(device)
            batch_size = images.size(0)
            
            # CRITICAL: Generate NEW secrets for each sample in batch
            # Different secret each time prevents pattern analysis
            C, H, W = images.shape[1], images.shape[2], images.shape[3]
            secrets = torch.stack([generate_secret(H, W, C) for _ in range(batch_size)])
            
            # Step 1: Bind inputs with secrets
            # This happens on USER side before sending to server
            bound_images = torch.stack([
                binding_2d(images[j], secrets[j]) for j in range(batch_size)
            ])
            
            # Step 2: Forward through main network
            # This happens on UNTRUSTED SERVER
            # Input is bound, output is still obfuscated
            r = main_network(bound_images)
            
            # Step 3: Prediction network (USER side with secret)
            # Unbind output using same secrets
            unbound = torch.stack([
                unbinding_2d(r[j], secrets[j]) for j in range(batch_size)
            ])
            pred_output = pred_network(unbound)
            
            # Calculate prediction loss (should be low = good accuracy)
            loss_pred = criterion(pred_output, labels)
            
            # Step 4: Adversarial network (ATTACKER without secret)
            # Tries to classify from raw output r
            if adv_network:
                adv_output = adv_network(r)  # No unbinding - doesn't have secret!
                loss_adv = criterion(adv_output, labels)
                # Total loss: prediction + adversarial
                # Gradient reversal makes main network RESIST adversarial classification
                total_loss = loss_pred + loss_adv
            else:
                total_loss = loss_pred
            
            # Backward pass and optimization
            optimizer_main.zero_grad()
            optimizer_pred.zero_grad()
            if optimizer_adv:
                optimizer_adv.zero_grad()
            
            # Backpropagation through all networks
            # Gradient reversal affects how main network is updated
            total_loss.backward()
            
            # Update all network parameters
            optimizer_main.step()
            optimizer_pred.step()
            if optimizer_adv:
                optimizer_adv.step()
            
            # Track statistics
            running_loss += total_loss.item()
            _, predicted = torch.max(pred_output.data, 1)
            total += labels.size(0)
            correct_pred += (predicted == labels).sum().item()
            
            if adv_network:
                _, predicted_adv = torch.max(adv_output.data, 1)
                correct_adv += (predicted_adv == labels).sum().item()
            
            # Print progress every 50 batches
            if (i + 1) % 50 == 0:
                pred_acc = 100 * correct_pred / total
                adv_acc = 100 * correct_adv / total if adv_network else 0
                print(f'Epoch [{epoch+1}/{epochs}], Step [{i+1}/{len(trainloader)}], '
                      f'Loss: {running_loss/50:.4f}, Pred Acc: {pred_acc:.2f}%, '
                      f'Adv Acc: {adv_acc:.2f}%')
                running_loss = 0.0
        
        # Print epoch summary
        epoch_pred_acc = 100 * correct_pred / total
        epoch_adv_acc = 100 * correct_adv / total if adv_network else 0
        print(f'Epoch [{epoch+1}/{epochs}] completed.')
        print(f'  Prediction Accuracy: {epoch_pred_acc:.2f}% (should be HIGH)')
        if adv_network:
            print(f'  Adversarial Accuracy: {epoch_adv_acc:.2f}% (should be LOW for good privacy)')
    
    return main_network, pred_network, adv_network

## 5. Testing Function

Test the HRR-protected model on the test set, drawing fresh secrets for every test sample.

**Same Pipeline as Training**:
1. Generate secret
2. Bind input
3. Process through main network
4. Unbind output
5. Predict with prediction network

In [None]:
def test_hrr_model(main_network, pred_network, testloader):
    """
    Test HRR-protected model on test set
    
    Uses same pipeline as training:
    1. Generate secret
    2. Bind input
    3. Process through main network
    4. Unbind output
    5. Classify with prediction network
    
    Args:
        main_network: Trained main network
        pred_network: Trained prediction network
        testloader: DataLoader for test data
    
    Returns:
        Test accuracy (percentage)
    """
    # Set networks to evaluation mode
    main_network.eval()
    pred_network.eval()
    
    correct = 0
    total = 0
    
    # Disable gradient computation for inference
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            batch_size = images.size(0)
            
            # Generate secrets (new ones for test too)
            C, H, W = images.shape[1], images.shape[2], images.shape[3]
            secrets = torch.stack([generate_secret(H, W, C) for _ in range(batch_size)])
            
            # Bind inputs
            bound_images = torch.stack([
                binding_2d(images[j], secrets[j]) for j in range(batch_size)
            ])
            
            # Process through main network
            r = main_network(bound_images)
            
            # Unbind outputs
            unbound = torch.stack([
                unbinding_2d(r[j], secrets[j]) for j in range(batch_size)
            ])
            
            # Predict
            outputs = pred_network(unbound)
            
            # Calculate accuracy
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = 100 * correct / total
    return accuracy

## 6. Main Execution - Train HRR-Protected and Baseline Models

### Experiment Goals:

1. Train HRR-protected model (3 networks)
2. Train baseline model (standard ResNet-18)
3. Compare test accuracies
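4. Expected trade-off: roughly 5-10 percentage points of accuracy lost in exchange for privacy (a working assumption, checked against the measured drop in the summary)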

### Configuration:
- Training epochs: 30 (can increase for better results)
- Batch size: 32 (smaller due to HRR overhead)
- Use adversarial network: True (critical for privacy!)

In [None]:
# Configuration
TRAIN_EPOCHS = 30  # Can increase to 50-100 for better results
BATCH_SIZE = 32    # Smaller batch size due to HRR computational overhead

print("=" * 80)
print("PART II: HRR DEFENSE IMPLEMENTATION")
print("=" * 80)

### Load CIFAR-10 Dataset

In [None]:
# Data preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

print("\nLoading CIFAR-10 dataset...")
trainset = torchvision.datasets.CIFAR10(
    root='./data', 
    train=True,
    download=True, 
    transform=transform
)
testset = torchvision.datasets.CIFAR10(
    root='./data', 
    train=False,
    download=True, 
    transform=transform
)

# Optional: Use subset for faster training during development
# Comment out these lines for full training
train_subset_size = 10000
trainset = torch.utils.data.Subset(trainset, range(train_subset_size))
print(f"Using subset of {train_subset_size} training samples for faster development")

# Create dataloaders
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f"Training samples: {len(trainset)}, Test samples: {len(testset)}")

### Train HRR-Protected Model

This will train all three networks together using the CSPS algorithm.

In [None]:
print("\n" + "=" * 80)
print("Training HRR-Protected Model...")
print("=" * 80)
print("\nThis will train:")
print("  1. Main Network (encoder-decoder on server)")
print("  2. Prediction Network (classifier on user side)")
print("  3. Adversarial Network (attacker without secret)")
print("\nExpected: Prediction accuracy HIGH, Adversarial accuracy LOW\n")

main_net, pred_net, adv_net = train_hrr_model(
    trainloader, 
    epochs=TRAIN_EPOCHS,
    use_adversarial=True  # CRITICAL for privacy!
)

### Test HRR-Protected Model

In [None]:
print("\n" + "=" * 80)
print("Testing HRR-Protected Model...")
print("=" * 80)

test_acc = test_hrr_model(main_net, pred_net, testloader)
print(f"\nHRR-Protected Model Test Accuracy: {test_acc:.2f}%")

# Save models
torch.save(main_net.state_dict(), 'hrr_main_network.pth')
torch.save(pred_net.state_dict(), 'hrr_pred_network.pth')
if adv_net is not None:
    torch.save(adv_net.state_dict(), 'hrr_adv_network.pth')

print("\nModels saved:")
print("  - hrr_main_network.pth")
print("  - hrr_pred_network.pth")
if adv_net is not None:
    print("  - hrr_adv_network.pth")

### Train Baseline Model (No HRR)

For comparison, train a standard ResNet-18 without HRR protection.

In [None]:
print("\n" + "=" * 80)
print("Training Baseline Model (No HRR)...")
print("=" * 80)

# Initialize baseline model
baseline_model = resnet18(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(baseline_model.parameters(), lr=0.001)

# Training loop (standard training - no HRR)
baseline_model.train()
for epoch in range(TRAIN_EPOCHS):
    running_loss = 0.0
    correct = 0
    total = 0
    
    for i, (images, labels) in enumerate(trainloader):
        images, labels = images.to(device), labels.to(device)
        
        # Standard training (no binding/unbinding)
        optimizer.zero_grad()
        outputs = baseline_model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        if (i + 1) % 50 == 0:
            print(f'Epoch [{epoch+1}/{TRAIN_EPOCHS}], Step [{i+1}/{len(trainloader)}], '
                  f'Loss: {running_loss/50:.4f}, Acc: {100*correct/total:.2f}%')
            running_loss = 0.0
    
    epoch_acc = 100 * correct / total
    print(f'Epoch [{epoch+1}/{TRAIN_EPOCHS}] completed. Accuracy: {epoch_acc:.2f}%')

### Test Baseline Model

In [None]:
# Test baseline model
baseline_model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = baseline_model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

baseline_acc = 100 * correct / total
print(f"\nBaseline Model Test Accuracy: {baseline_acc:.2f}%")

# Save baseline model
torch.save(baseline_model.state_dict(), 'baseline_model.pth')
print("Baseline model saved to 'baseline_model.pth'")

## 7. Final Summary and Analysis

In [None]:
print("\n" + "=" * 80)
print("PART II COMPLETE - Summary")
print("=" * 80)

print(f"\nHRR-Protected Model Accuracy: {test_acc:.2f}%")
print(f"Baseline Model Accuracy: {baseline_acc:.2f}%")
print(f"Accuracy Drop: {baseline_acc - test_acc:.2f} percentage points")

print("\nInterpretation:")
if baseline_acc - test_acc < 5:
    print("  ✓ Excellent! Minimal accuracy loss for privacy protection.")
elif baseline_acc - test_acc < 10:
    print("  ✓ Good! Acceptable accuracy trade-off for privacy.")
else:
    print("  ⚠ Significant accuracy loss. May need more training or tuning.")

print("\nNext Steps:")
print("  1. Run evaluate_hrr_defense.py to test against RMIA attack")
print("  2. Compare AUC: baseline vs HRR-protected")
print("  3. Measure privacy-utility trade-off")

print("\nAll models have been saved and are ready for evaluation.")
print("=" * 80)

## Conclusion

### What We Implemented:

1. **HRR Operations**: Binding/unbinding using 2D FFT
2. **Three Networks**: Main (server), Prediction (user), Adversarial (attacker)
3. **CSPS Training**: Adversarial training with gradient reversal
4. **Privacy Mechanism**: Output uninformative without secret

### Key Insights:

1. **Trade-off**: An expected small accuracy loss (roughly 5-10 percentage points) in exchange for privacy
2. **Gradient Reversal**: Critical for forcing uninformative outputs
3. **New Secrets**: A different key per sample is meant to hinder pattern analysis
4. **FFT Efficiency**: Makes HRR practical (O(N log N) complexity)

### TASK 2.2 Questions Preview:

**Q1: How effective is HRR at preventing RMIA?**
- Will measure AUC reduction in next notebook
- Hypothesis: a 20-30% AUC drop (the attack becomes much harder), to be confirmed by the measurement

**Q2: Does HRR qualify as encryption?**
- No - it's obfuscation, not cryptographic encryption
- Provides practical privacy, not provable security
- Good for cost-effective protection
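
One concrete reason HRR is obfuscation rather than encryption is that binding is linear. Anyone who sees a single plaintext together with its bound version can recover that secret exactly by frequency-domain division (deconvolution), provided $\mathcal{F}(x)$ has no zero entries. The NumPy sketch below (variable names are illustrative) shows this known-plaintext recovery. It also suggests why drawing a fresh secret per sample matters: a recovered secret unbinds only the query it was drawn for.

```python
import numpy as np

def bind(x, s):
    return np.fft.ifft2(np.fft.fft2(x) * np.fft.fft2(s)).real

rng = np.random.default_rng(0)
x1 = rng.standard_normal((16, 16))
x2 = rng.standard_normal((16, 16))
F_rand = np.fft.fft2(rng.standard_normal((16, 16)))
s = np.fft.ifft2(F_rand / np.abs(F_rand)).real   # unit-magnitude secret

# Linearity: binding a sum equals the sum of the bindings
print(np.allclose(bind(x1 + x2, s), bind(x1, s) + bind(x2, s)))  # True

# Known-plaintext recovery: divide the spectra of (bound, plaintext) to get F(s)
bound = bind(x1, s)
s_recovered = np.fft.ifft2(np.fft.fft2(bound) / np.fft.fft2(x1)).real
print(np.allclose(s_recovered, s))  # True
```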

**Q3: Could attackers adapt?**
- Clustering attacks: Fail (ARI < 2%)
- Inversion attacks: Fail (poor reconstruction)
- Supervised learning: Limited success (2-5× random)
- Gradient reversal makes adaptation very hard