## README

## Caution!
This notebook takes a very long time to run to completion and generate the inverted images **(20-30 hours with a GPU)**!

You are welcome to run it if you are comfortable with that execution time. We used the following system for this notebook:

- Google Cloud Platform - Vertex AI Jupyter Notebook
  - Machine type: n1-standard-16 (16 vCPUs, 60 GB RAM)
  - GPU: NVIDIA Tesla P100 x 1
  - Environment: TensorFlow Enterprise 2.16 (Intel® MKL-DNN/MKL)
  - CUDA Version 12


## Alternative
We have already run this notebook and obtained inverted image samples. They can be found in the ```inverted_images``` folder under the ```/main/model_inversion/blakcbox/inverted_images``` path in the GitHub repository.

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

# Imports and Functions

In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR
import torch.nn.functional as F
from functools import partial
import traceback

class BrainMRIClassifier(nn.Module):
    """Five-stage convolutional classifier for single-channel brain MRI images.

    Each stage is Conv2d(3x3, padding 1) -> ReLU -> MaxPool2d(2) -> BatchNorm2d.
    All stages live in one flat ``nn.Sequential`` (``self.features``), so the
    parameter names are ``features.0``, ``features.3``, ... ``features.19``.
    The fully connected head (``self.classifier``) produces 4 class logits and
    expects the 256 x 9 x 9 feature volume obtained from a 299x299 input.
    """

    # Output channels of the five convolutional stages, in order.
    _STAGE_WIDTHS = (32, 64, 128, 256, 256)
    # Conv2d, ReLU, MaxPool2d, BatchNorm2d.
    _LAYERS_PER_STAGE = 4

    def __init__(self):
        super().__init__()

        stage_layers = []
        in_channels = 1
        for out_channels in self._STAGE_WIDTHS:
            stage_layers.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
                nn.BatchNorm2d(out_channels),
            ])
            in_channels = out_channels
        self.features = nn.Sequential(*stage_layers)

        # Five floor-halvings of 299: 299 -> 149 -> 74 -> 37 -> 18 -> 9.
        self.flat_features = 256 * 9 * 9

        head_layers = [
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(self.flat_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 4),  # 4 classes
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the 4 class logits for a batch of images."""
        return self.classifier(self.features(x))

    def get_feature_maps(self, x):
        """Return the output of each of the five stages as a list (earliest first).

        Intended as a helper for inversion guidance; the classifier head is
        not evaluated.
        """
        stage_outputs = []
        step = self._LAYERS_PER_STAGE
        for start in range(0, len(self.features), step):
            # Slicing a Sequential yields the sub-Sequential for one stage.
            x = self.features[start:start + step](x)
            stage_outputs.append(x)
        return stage_outputs




# -------------------------------------------------------------
# 1. Model Loading and Feature Extraction
# -------------------------------------------------------------
def load_model(model_path, model_class=None, map_location=None):
    """
    Load a pre-trained model from ``model_path`` and put it in evaluation mode.

    The file may contain either a fully pickled model object or a state
    dictionary. A state dictionary may also be wrapped in a checkpoint dict
    under the key ``'state_dict'`` or ``'model_state_dict'``.

    Args:
        model_path: Path (or file-like object) accepted by ``torch.load``.
        model_class: Zero-argument callable that builds the architecture.
            Required when the file holds a state dictionary.
        map_location: Forwarded to ``torch.load``. When left as ``None`` and no
            CUDA device is available, tensors are remapped to the CPU so that
            checkpoints saved on a GPU machine can still be loaded.

    Returns:
        The model, after ``model.eval()`` has been called.

    Raises:
        ValueError: The file holds a state dictionary but ``model_class`` is None.
    """
    if map_location is None and not torch.cuda.is_available():
        # Without remapping, deserializing CUDA storages on a CPU-only host fails.
        map_location = 'cpu'

    # NOTE(security): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    checkpoint = torch.load(model_path, map_location=map_location)

    # A plain dict / OrderedDict has no ``eval`` attribute; a pickled nn.Module does.
    if isinstance(checkpoint, dict) and not hasattr(checkpoint, 'eval'):
        if model_class is None:
            raise ValueError("The checkpoint is a state_dict but no model_class was provided. "
                             "Please supply the model_class argument.")

        state_dict = checkpoint
        # Training scripts commonly nest the weights inside a larger checkpoint dict.
        for wrapper_key in ('state_dict', 'model_state_dict'):
            nested = checkpoint.get(wrapper_key)
            if isinstance(nested, dict):
                state_dict = nested
                break

        model = model_class()  # Instantiate the model architecture
        model.load_state_dict(state_dict)
    else:
        model = checkpoint

    model.eval()  # set to evaluation mode
    return model


def extract_model_layers(model):
    """
    Describe every Conv2d and Linear layer found in ``model``.

    Returns a list of dicts with the keys ``'name'`` (the dotted name from
    ``named_modules``), ``'type'`` (the layer's class name) and ``'params'``
    (total number of parameter elements), in ``named_modules`` order.

    An ``EnsembleModel`` is described through its first wrapped model; an
    ensemble with no wrapped models yields an empty list.
    """
    if isinstance(model, EnsembleModel):
        members = model.models
        if members and len(members) > 0:
            return extract_model_layers(members[0])
        return []  # Nothing to inspect

    return [
        {
            'name': layer_name,
            'type': type(layer).__name__,
            'params': sum(weights.numel() for weights in layer.parameters()),
        }
        for layer_name, layer in model.named_modules()
        if isinstance(layer, (nn.Conv2d, nn.Linear))
    ]

def get_optimal_hook_layers(model):
    """
    Choose layer names to attach feature hooks to, based on the model's layers.

    Selection order:
      1. No Conv2d/Linear layers found -> ``["default_hook"]`` (a sentinel the
         caller handles specially).
      2. One or two layers -> the last layer only.
      3. Layers whose name contains one of the known block markers *and* the
         substring "conv": three of them (at 1/4, 1/2 and 3/4 of the list)
         when at least three exist, otherwise all of them.
      4. Otherwise Conv2d layers by type: first, middle and last when at least
         three exist, else the last one.
      5. Otherwise the last layer of any type.
    """
    described = extract_model_layers(model)
    if not described:
        return ["default_hook"]

    last_layer_name = described[-1]['name']
    if len(described) <= 2:
        return [last_layer_name]

    # Naming patterns typical of ResNet-like and Sequential-style models.
    name_markers = ('layer2', 'layer3', 'layer4', 'block', 'features')
    named_convs = [
        entry['name']
        for entry in described
        if any(marker in entry['name'] for marker in name_markers)
        and 'conv' in entry['name'].lower()
        and 'weight' not in entry['name']
    ]

    found = len(named_convs)
    if found >= 3:
        # Spread the picks over early, middle and late layers.
        return [named_convs[found // 4], named_convs[found // 2], named_convs[(3 * found) // 4]]
    if found > 0:
        return named_convs

    # No name-based match: fall back to the layer type.
    typed_convs = [entry['name'] for entry in described if 'Conv' in entry['type']]
    if len(typed_convs) >= 3:
        return [typed_convs[0], typed_convs[len(typed_convs) // 2], typed_convs[-1]]
    if typed_convs:
        return [typed_convs[-1]]

    # Final fallback - use the last layer of any type
    return [last_layer_name]


# -------------------------------------------------------------
# 2. Enhanced Regularization Functions with Self-tuning
# -------------------------------------------------------------
def total_variation_loss(img, tv_weight):
    """
    Squared total-variation penalty with an added second-order term.

    Args:
        img: 4-D tensor indexed as (batch, channel, height, width).
        tv_weight: Scalar multiplier applied to the result.

    Returns:
        ``tv_weight * (h_tv / count_h + w_tv / count_w) / batch_size`` where
        ``h_tv`` / ``w_tv`` are the summed squared first differences along
        height / width, plus half the summed squared second differences when
        that axis has more than two entries.

    A degenerate axis (height or width of 1) contributes zero instead of the
    NaN that a 0/0 division would produce.
    """
    batch_size = img.size()[0]
    h_x = img.size()[2]
    w_x = img.size()[3]

    # Number of first-difference pairs per channel; clamped so that a size-1
    # axis (whose difference sum is 0) does not cause a division by zero.
    count_h = max((h_x - 1) * w_x, 1)
    count_w = max(h_x * (w_x - 1), 1)

    # First-order differences
    h_tv = torch.pow(img[:, :, 1:, :] - img[:, :, :h_x-1, :], 2).sum()
    w_tv = torch.pow(img[:, :, :, 1:] - img[:, :, :, :w_x-1], 2).sum()

    # Second-order differences (discrete Laplacian along one axis)
    if h_x > 2:
        h_tv2 = torch.pow(img[:, :, 2:, :] - 2*img[:, :, 1:-1, :] + img[:, :, :-2, :], 2).sum()
        h_tv = h_tv + 0.5 * h_tv2

    if w_x > 2:
        w_tv2 = torch.pow(img[:, :, :, 2:] - 2*img[:, :, :, 1:-1] + img[:, :, :, :-2], 2).sum()
        w_tv = w_tv + 0.5 * w_tv2

    return tv_weight * (h_tv / count_h + w_tv / count_w) / max(batch_size, 1)

def color_distribution_loss(img, color_weight):
    """
    Colour prior for three-channel images; zero for images with fewer channels.

    Combines three terms computed from per-channel spatial statistics:
      * channel diversity - negative mean absolute difference between the
        channel means (rewards channels that differ, i.e. non-grey images);
      * natural mean - penalises the first or third channel mean exceeding
        the second (weighted by 0.5);
      * variance - MSE between each channel's spatial std and 0.2.

    The sum is scaled by ``color_weight``.
    """
    # Nothing to compare when there are fewer than three channels.
    if img.shape[1] < 3:
        return torch.tensor(0.0, device=img.device)

    # Per-image, per-channel statistics over the spatial axes -> (batch, channels)
    channel_means = torch.mean(img, dim=[2, 3])
    channel_stds = torch.std(img, dim=[2, 3])

    mean_r, mean_g, mean_b = channel_means.split(1, dim=1)

    # 1. Reward differences between channel means (hence the minus sign).
    pairwise_gap = torch.abs(mean_r - mean_g) + torch.abs(mean_r - mean_b) + torch.abs(mean_g - mean_b)
    diversity_term = -pairwise_gap.mean()

    # 2. Penalise red or blue means that exceed the green mean.
    natural_mean_term = torch.relu(mean_r - mean_g).mean() + torch.relu(mean_b - mean_g).mean()

    # 3. Pull each channel's spatial std towards 0.2.
    desired_std = torch.tensor([0.2, 0.2, 0.2], device=img.device).view(1, 3)
    variance_term = F.mse_loss(channel_stds, desired_std)

    return color_weight * (diversity_term + 0.5 * natural_mean_term + variance_term)


def perceptual_smoothness_loss(img, smooth_weight):
    """
    Multi-scale, edge-aware smoothness penalty.

    For Gaussian kernels of size 3, 5 and 7 the image is blurred (depthwise,
    zero padded) and the squared difference to the blurred version is
    averaged, down-weighted where the local finite-difference gradient is
    large so that edges are penalised less than flat regions.

    Args:
        img: 4-D floating point tensor indexed as (batch, channel, height, width).
        smooth_weight: Scalar multiplier applied to the result.

    Returns:
        ``smooth_weight`` times the mean of the three per-scale losses.

    The blur kernel is built in ``img``'s dtype, so float64 input is supported.
    """
    channels = img.shape[1]

    # Absolute forward differences, padded back to the input size by repeating
    # the last valid column / row. These do not depend on the kernel size, so
    # they are computed once. A size-1 axis has no differences: use zeros.
    if img.shape[3] > 1:
        grad_x = torch.abs(img[:, :, :, 1:] - img[:, :, :, :-1])
        grad_x = torch.cat([grad_x, grad_x[:, :, :, -1:]], dim=3)
    else:
        grad_x = torch.zeros_like(img)

    if img.shape[2] > 1:
        grad_y = torch.abs(img[:, :, 1:, :] - img[:, :, :-1, :])
        grad_y = torch.cat([grad_y, grad_y[:, :, -1:, :]], dim=2)
    else:
        grad_y = torch.zeros_like(img)

    # Edge-preserving weights: close to 1 in flat regions, close to 0 on edges.
    edge_weights = torch.exp(-50 * (grad_x.pow(2) + grad_y.pow(2)))

    kernel_sizes = (3, 5, 7)
    loss = 0.0
    for kernel_size in kernel_sizes:
        # Normalised Gaussian kernel with sigma = kernel_size / 3
        sigma = kernel_size / 3
        grid_x = torch.arange(kernel_size, device=img.device, dtype=img.dtype) - (kernel_size - 1) / 2
        grid_y = grid_x.view(-1, 1)
        kernel_2d = torch.exp(-(grid_x.pow(2) + grid_y.pow(2)) / (2 * sigma**2))
        kernel_2d = kernel_2d / kernel_2d.sum()

        # One copy of the kernel per channel (depthwise convolution)
        kernel = kernel_2d.expand(channels, 1, kernel_size, kernel_size)

        padding = (kernel_size - 1) // 2
        smoothed = F.conv2d(img, kernel, padding=padding, groups=channels)

        loss = loss + (edge_weights * (img - smoothed).pow(2)).mean()

    return smooth_weight * (loss / len(kernel_sizes))  # Average across scales

def naturalness_prior_loss(img, natural_weight):
    """
    Gradient-domain prior: L1 penalty on image gradients minus 0.1 times the
    mean log-magnitude of those gradients, scaled by ``natural_weight``.

    The L1 term discourages gradients overall, while the subtracted log term
    rewards gradient magnitudes that are not vanishingly small. The original
    author describes the pair as approximating the heavy-tailed gradient
    distribution of natural images.
    """
    eps = 1e-5  # keeps log() finite where neighbouring pixels are equal

    # Forward differences along width and height
    abs_dx = torch.abs(img[:, :, :, 1:] - img[:, :, :, :-1])
    abs_dy = torch.abs(img[:, :, 1:, :] - img[:, :, :-1, :])

    l1_term = abs_dx.mean() + abs_dy.mean()
    log_term = torch.log(abs_dx + eps).mean() + torch.log(abs_dy + eps).mean()

    return natural_weight * (l1_term - 0.1 * log_term)

def compute_fft_loss(img, fft_weight):
    """
    Spectral prior: KL divergence between a 1/f reference spectrum and the
    normalised FFT magnitude of the (grey-scale) image, times ``fft_weight``.

    Images with three or more channels are first reduced to one channel with
    the 0.299 / 0.587 / 0.114 weighting. The spatial axes are zero padded
    (centred) to the next power of two before the FFT. If the FFT section
    raises ``RuntimeError`` a warning is printed and ``fft_weight * 0.1`` is
    returned instead.
    """
    # Collapse colour images to a single luminance channel.
    if img.shape[1] >= 3:
        luminance = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]
    else:
        luminance = img

    def next_pow2(length):
        return 2**int(np.ceil(np.log2(length)))

    # Pad to power-of-two dimensions (the original author cites cuFFT size
    # requirements as the motivation).
    src_h, src_w = luminance.shape[2], luminance.shape[3]
    dst_h, dst_w = next_pow2(src_h), next_pow2(src_w)
    if (src_h, src_w) != (dst_h, dst_w):
        extra_h = dst_h - src_h
        extra_w = dst_w - src_w
        top, left = extra_h // 2, extra_w // 2
        bottom, right = extra_h - top, extra_w - left
        # F.pad takes the last axis first: (left, right, top, bottom).
        luminance = F.pad(luminance, (left, right, top, bottom), mode='constant', value=0)

    try:
        # Magnitude spectrum with the zero frequency moved to the centre.
        magnitude = torch.fft.fftshift(torch.abs(torch.fft.fft2(luminance)))

        # Reference spectrum: inverse distance from the centre bin (1/f).
        rows, cols = magnitude.shape[-2:]
        row_offset = (torch.arange(rows, device=img.device) - rows // 2).view(-1, 1)
        col_offset = (torch.arange(cols, device=img.device) - cols // 2).view(1, -1)
        radius = torch.sqrt(col_offset.pow(2) + row_offset.pow(2)) + 1e-5
        reference = 1 / radius

        # Turn both spectra into distributions that sum to one.
        reference = reference / reference.sum()
        observed = magnitude / (magnitude.sum() + 1e-8)

        # KL(reference || observed), stabilised with a small epsilon.
        eps = 1e-8
        divergence = reference * torch.log((reference + eps) / (observed + eps))
        return fft_weight * divergence.sum()

    except RuntimeError:
        # Fallback if FFT still fails: return a small constant loss
        print("Warning: FFT computation failed, using fallback loss")
        return fft_weight * torch.tensor(0.1, device=img.device)

# -------------------------------------------------------------
# 3. Advanced Initialization Strategies
# -------------------------------------------------------------
def get_initial_image(strategy='mixed', size=(1, 3, 224, 224), device='cpu', target_class=None, channels=None):
    """
    Build the starting image for the inversion optimisation.

    Args:
        strategy: One of
            'mean'        - constant image at the per-channel mean;
            'gaussian'    - Gaussian noise around the per-channel mean;
            'pca'         - random-phase, low-frequency pattern synthesised
                            with an inverse FFT, scaled to [0, 1];
            'mixed'       - per-channel mean plus three octaves of upsampled
                            noise scaled to [0, 0.2];
            'class_prior' - constant colour derived from ``target_class``
                            plus low-resolution noise (needs ``target_class``).
            Any other value (or 'class_prior' without a target class) gives
            standard normal noise.
        size: (batch, channels, height, width) of the image.
        device: Device on which the tensor is created.
        target_class: Class index used by 'class_prior'.
        channels: Number of channels; defaults to ``size[1]``.

    Returns:
        A tensor clamped to [0, 1] with ``requires_grad`` set to True.

    Per-channel statistics are the ImageNet mean/std for 3 channels and
    0.5 / 0.25 for any other channel count.
    """
    if channels is None:
        channels = size[1]
    height, width = size[2], size[3]

    def channel_stats():
        # (mean, std), each shaped (1, channels, 1, 1)
        if channels == 3:
            mean_vals = [0.485, 0.456, 0.406]
            std_vals = [0.229, 0.224, 0.225]
        else:
            mean_vals = [0.5] * channels
            std_vals = [0.25] * channels
        mean_t = torch.tensor(mean_vals, device=device).view(1, channels, 1, 1)
        std_t = torch.tensor(std_vals, device=device).view(1, channels, 1, 1)
        return mean_t, std_t

    def unit_range(t, scale=1.0):
        # Min-max normalise to [0, scale]
        return (t - t.min()) / (t.max() - t.min() + 1e-8) * scale

    if strategy == 'mean':
        mean, _ = channel_stats()
        img = mean.repeat(size[0], 1, height, width)
    elif strategy == 'gaussian':
        mean, std = channel_stats()
        img = torch.randn(size, device=device) * std + mean
    elif strategy == 'pca':
        # Radial amplitude envelope in the frequency plane, measured from the
        # array centre in units of r_max; zero outside the unit disc.
        # r_max is clamped so very small images do not divide by zero.
        r_max = max(min(height, width) // 8, 1)
        rows = (torch.arange(height, device=device, dtype=torch.float32) - height // 2) / r_max
        cols = (torch.arange(width, device=device, dtype=torch.float32) - width // 2) / r_max
        dist = torch.sqrt(rows.view(-1, 1) ** 2 + cols.view(1, -1) ** 2)
        amplitude = torch.exp(-3.0 * dist) * (dist < 1.0).to(dist.dtype)

        img = torch.zeros(size, device=device)
        for b in range(size[0]):
            for c in range(channels):
                # Independent random phases for every batch item and channel.
                phase = torch.rand((height, width), device=device) * 2 * np.pi
                spectrum = amplitude * torch.exp(1j * phase)
                img[b, c] = unit_range(torch.fft.ifft2(spectrum).real)
    elif strategy == 'mixed':
        mean, _ = channel_stats()
        base = mean.repeat(size[0], 1, height, width)
        freq_img = torch.zeros(size, device=device)
        for octave in range(3):
            scale_factor = 2 ** octave
            # At least one pixel so tiny images still interpolate.
            noise_size = (max(height // scale_factor, 1), max(width // scale_factor, 1))
            noise = torch.randn((1, channels, *noise_size), device=device)
            noise = F.interpolate(noise, size=(height, width), mode='bilinear', align_corners=False)
            freq_img += noise * (0.5 ** octave)
        img = base + unit_range(freq_img, 0.2)
    elif strategy == 'class_prior' and target_class is not None:
        # int() accepts both Python ints and single-element tensors.
        hue_shift = (int(target_class) % 10) / 10.0
        if channels == 3:
            # Three cosines 120 degrees apart give a class-dependent colour.
            rgb = [0.5 + 0.4 * float(np.cos(2 * np.pi * (hue_shift + k / 3))) for k in range(3)]
            base_color = torch.tensor(rgb, dtype=torch.float32, device=device).view(1, 3, 1, 1)
        else:
            base_color = torch.full((1, channels, 1, 1), 0.5, device=device)
        img = base_color.repeat(size[0], 1, height, width)
        noise = torch.randn((1, channels, max(height // 4, 1), max(width // 4, 1)), device=device)
        noise = F.interpolate(noise, size=(height, width), mode='bilinear', align_corners=False)
        img = img + unit_range(noise, 0.15)
    else:
        img = torch.randn(size, device=device)

    img = torch.clamp(img, 0, 1)
    img.requires_grad = True
    return img

# -------------------------------------------------------------
# 4. Improved Feature Matching with Adaptive Statistics
# -------------------------------------------------------------
class FeatureHook:
    """
    Forward hook that stores a layer's latest output and tracks running
    per-channel statistics of it.

    Register ``hook_fn`` with ``module.register_forward_hook``. After each
    forward pass ``features`` holds the output tensor (with its autograd
    graph) and, when ``adaptive_stats`` is enabled, ``running_mean`` /
    ``running_std`` hold exponential moving averages (momentum 0.9) of the
    per-channel mean and standard deviation.

    Channel statistics are taken over every axis except axis 1, so both
    convolutional (N, C, H, W) and fully connected (N, F) outputs work.
    """
    def __init__(self, layer_name, adaptive_stats=True):
        self.layer_name = layer_name
        self.features = None          # latest hooked output
        self.adaptive_stats = adaptive_stats
        self.running_mean = None      # EMA of the per-channel mean (detached)
        self.running_std = None       # EMA of the per-channel std (detached)
        self.momentum = 0.9

    @staticmethod
    def _channel_stats(feats):
        """Return (mean, std) per channel, reducing over all axes except axis 1."""
        if feats.dim() < 2:
            # Treat a flat vector as a single sample with one channel per entry.
            feats = feats.reshape(1, -1)
        reduce_dims = [d for d in range(feats.dim()) if d != 1]
        mean = feats.mean(dim=reduce_dims)
        samples_per_channel = feats.numel() // max(feats.shape[1], 1)
        if samples_per_channel > 1:
            std = feats.std(dim=reduce_dims)
        else:
            # The unbiased std of a single value is NaN; report zero spread instead.
            std = torch.zeros_like(mean)
        return mean, std

    def hook_fn(self, module, input, output):
        """Forward-hook callback: remember ``output`` and update the running statistics."""
        self.features = output

        # Update running statistics for adaptive matching
        if self.adaptive_stats:
            current_mean, current_std = self._channel_stats(output)
            current_mean = current_mean.detach()
            current_std = current_std.detach()

            if self.running_mean is None:
                self.running_mean = current_mean
                self.running_std = current_std
            else:
                self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * current_mean
                self.running_std = self.momentum * self.running_std + (1 - self.momentum) * current_std

    def get_target_stats(self):
        """
        Return the statistics that ``compute_feature_loss`` matches against.

        With adaptive statistics available this is the running mean/std.
        Otherwise it is zero mean / unit std sized to the latest features, or
        ``None`` when no features have been captured yet.
        """
        if self.adaptive_stats and self.running_mean is not None:
            return {
                'mean': self.running_mean,
                'std': self.running_std
            }

        if self.features is None:
            return None

        feats = self.features
        num_channels = feats.shape[1] if feats.dim() >= 2 else feats.numel()
        return {
            'mean': torch.zeros(num_channels, device=feats.device),
            'std': torch.ones(num_channels, device=feats.device)
        }

    def compute_feature_loss(self, target_stats=None):
        """
        MSE between the current per-channel mean/std and the target statistics.

        Args:
            target_stats: Optional dict with 'mean' and 'std' tensors; defaults
                to ``get_target_stats()``.

        Returns:
            ``mean_loss + std_loss`` as a scalar tensor, or a zero tensor when
            no features (or no target statistics) are available.
        """
        if self.features is None:
            return torch.tensor(0.0, device='cpu')

        if target_stats is None:
            target_stats = self.get_target_stats()
            if target_stats is None:
                return torch.tensor(0.0, device=self.features.device)

        current_mean, current_std = self._channel_stats(self.features)

        mean_loss = F.mse_loss(current_mean, target_stats['mean'])
        std_loss = F.mse_loss(current_std, target_stats['std'])

        return mean_loss + std_loss

# -------------------------------------------------------------
# 5. Advanced Multi-scale and Multi-resolution Optimization
# -------------------------------------------------------------
def progressive_model_inversion_attack(
        model,
        target_class,
        scales=None,
        iterations_per_scale=500,
        lr_initial=0.1,
        lr_final=0.001,
        auto_schedule_hyperparams=True,
        regularization_weights={
            'tv': 5e-4,
            'l2': 1e-4,
            'color': 5e-5,
            'smooth': 1e-4,
            'natural': 2e-4,
            'fft': 1e-5,
            'feature': 5e-3
        },
        init_strategy='mixed',
        hook_layers=None,
        verbose=False,
        log_interval=50,
        init_img=None):
    """
    Coarse-to-fine model inversion: optimise an image so that ``model``
    classifies it as ``target_class``, under several image priors.

    The image is optimised at each resolution in ``scales`` in turn (upsampled
    between scales). At every step it is resized to 299x299 before being fed
    to the model, and the cross-entropy loss is combined with the total
    variation, L2, colour, smoothness, naturalness, FFT and feature-statistics
    terms. The image with the lowest score (cross-entropy plus 0.1 times the
    tv/smooth/natural terms) is kept and post-processed.

    Args:
        model: Classifier to invert; moved to CUDA when available.
        target_class: Index of the class to reconstruct.
        scales: List of (batch, channels, height, width) sizes, coarse to
            fine. Defaults to 75, 150 and 299 pixels. The last entry is
            replaced by (1, C, 299, 299) if it is not already 299x299.
            The caller's list is not modified.
        iterations_per_scale: Base iteration count; each scale runs
            ``iterations_per_scale * min(scale / first_scale, 2)`` steps.
        lr_initial, lr_final: Peak learning rate decays geometrically from
            ``lr_initial`` towards ``lr_final`` across the scales.
        auto_schedule_hyperparams: Linearly warm up the regularisation weights
            at the start of each scale when True.
        regularization_weights: Dict with the keys 'tv', 'l2', 'color',
            'smooth', 'natural', 'fft' and 'feature'. Only read, never modified.
        init_strategy: Strategy name passed to ``get_initial_image``.
        hook_layers: Names of layers to hook for feature matching; chosen by
            ``get_optimal_hook_layers`` when None.
        verbose: Print progress every ``log_interval`` iterations.
        log_interval: Interval for progress output and loss recording.
        init_img: Optional starting image; a detached copy is optimised, so
            the caller's tensor is left untouched.

    Returns:
        The post-processed best image, or the raw best image if
        post-processing raises.
    """
    import sys  # for dynamic printing
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Get model information
    model_info = get_model_info(model)
    input_channels = model_info["input_channels"] or 1

    # For BrainMRIClassifier, we need input size 299x299
    if scales is None:
        scales = [(1, input_channels, 75, 75),
                  (1, input_channels, 150, 150),
                  (1, input_channels, 299, 299)]
    else:
        # Private copy (as tuples) so the caller's list is never modified and
        # the shape comparisons below behave the same for lists and tuples.
        scales = [tuple(scale) for scale in scales]

    # Force final scale to be 299x299
    if scales[-1][2:] != (299, 299):
        scales[-1] = (1, input_channels, 299, 299)

    print(f"Using scales: {scales} with {input_channels} input channels")

    # Get optimal hook layers if not specified
    if hook_layers is None:
        hook_layers = get_optimal_hook_layers(model)

    # Weight schedule: linear warm-up over the first `warmup` fraction of a scale
    if auto_schedule_hyperparams:
        def schedule_weight(base_weight, iteration, iterations, warmup=0.2):
            if iteration < iterations * warmup:
                factor = iteration / (iterations * warmup)
            else:
                factor = 1.0
            return base_weight * factor
    else:
        schedule_weight = lambda w, i, t, **kwargs: w

    hooks = []
    feature_extractors = {}

    # try/finally guarantees the forward hooks are removed from the model even
    # if the optimisation is interrupted or raises.
    try:
        # Special handling for ensemble model
        if "default_hook" in hook_layers:
            print("Using default hooks for custom model")
        else:
            for layer_name in hook_layers:
                # Hook the first module whose name contains the requested name
                for name, module in model.named_modules():
                    if layer_name in name:
                        extractor = FeatureHook(layer_name)
                        feature_extractors[layer_name] = extractor
                        hooks.append(module.register_forward_hook(extractor.hook_fn))
                        break

        # Initialize with the coarsest scale or provided init_img
        if init_img is None:
            current_img = get_initial_image(init_strategy, size=scales[0], device=device,
                                            target_class=target_class, channels=input_channels)
        else:
            # Detach + clone so that we always optimise a leaf tensor that does
            # not alias the caller's data.
            current_img = init_img.detach().clone().to(device)

        current_img.requires_grad_(True)

        # Training losses for visualization.
        # NOTE: recorded but currently not returned to the caller.
        losses_history = {
            'total': [], 'ce': [], 'tv': [], 'l2': [], 'color': [],
            'smooth': [], 'natural': [], 'fft': [], 'feature': []
        }

        # Cross-entropy loss function
        ce_loss_fn = nn.CrossEntropyLoss()
        target = torch.tensor([target_class], device=device)

        # Track best image and score
        best_img = current_img.clone().detach()
        best_score = float('inf')

        # Process each scale
        for scale_idx, scale in enumerate(scales):
            # Print phase information on its own line.
            print(f"\nOptimizing at scale {scale[2]}x{scale[3]}")

            # Resize image if needed
            if current_img.shape != scale:
                with torch.no_grad():
                    resized_img = F.interpolate(current_img.detach(), size=scale[2:],
                                                mode='bilinear', align_corners=False)
                current_img = resized_img.clone()
                current_img.requires_grad = True

            # Determine iterations for this scale
            scale_factor = scale[2] / scales[0][2]
            iterations = int(iterations_per_scale * min(scale_factor, 2.0))

            # Initialize optimizer and scheduler
            lr = lr_initial * (lr_final / lr_initial) ** (scale_idx / len(scales))
            optimizer = optim.AdamW([current_img], lr=lr, weight_decay=0.01)
            scheduler = OneCycleLR(optimizer, max_lr=lr, total_steps=iterations,
                                   pct_start=0.3, anneal_strategy='cos')

            # Optimization loop for this scale
            for iteration in range(iterations):
                optimizer.zero_grad()

                # Forward pass - ALWAYS resize to 299x299 for model output
                resized_img = F.interpolate(current_img, size=(299, 299),
                                            mode='bilinear', align_corners=False)
                output = model(resized_img)

                # Calculate losses
                ce_loss = ce_loss_fn(output, target)
                tv_loss = total_variation_loss(
                    current_img, schedule_weight(regularization_weights['tv'], iteration, iterations))
                l2_loss = schedule_weight(regularization_weights['l2'], iteration, iterations) * torch.norm(current_img, 2)
                color_loss = color_distribution_loss(
                    current_img, schedule_weight(regularization_weights['color'], iteration, iterations))
                smooth_loss = perceptual_smoothness_loss(
                    current_img, schedule_weight(regularization_weights['smooth'], iteration, iterations))
                natural_loss = naturalness_prior_loss(
                    current_img, schedule_weight(regularization_weights['natural'], iteration, iterations, warmup=0.4))
                fft_loss = compute_fft_loss(
                    current_img, schedule_weight(regularization_weights['fft'], iteration, iterations, warmup=0.5))

                # Feature matching losses. Kept as a tensor even when no hooks
                # are registered so that `.item()` below is always valid.
                feature_loss = torch.zeros((), device=device)
                if feature_extractors:
                    feature_w = schedule_weight(regularization_weights['feature'] * scale_factor,
                                                iteration, iterations, warmup=0.3)
                    for layer_name, extractor in feature_extractors.items():
                        feature_loss = feature_loss + extractor.compute_feature_loss() * feature_w

                # Combine all losses
                total_loss = ce_loss + tv_loss + l2_loss + color_loss + smooth_loss + natural_loss + fft_loss + feature_loss

                # Record losses occasionally for visualization
                if iteration % log_interval == 0 or iteration == iterations - 1:
                    losses_history['total'].append(total_loss.item())
                    losses_history['ce'].append(ce_loss.item())
                    losses_history['tv'].append(tv_loss.item())
                    losses_history['l2'].append(l2_loss.item())
                    losses_history['color'].append(color_loss.item())
                    losses_history['smooth'].append(smooth_loss.item())
                    losses_history['natural'].append(natural_loss.item())
                    losses_history['fft'].append(fft_loss.item())
                    losses_history['feature'].append(feature_loss.item())

                # Backward pass and update
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_([current_img], max_norm=1.0)
                optimizer.step()
                scheduler.step()

                # Ensure image remains in valid range
                with torch.no_grad():
                    current_img.clamp_(0, 1)

                # Save best result
                score = ce_loss.item() + 0.1 * (tv_loss.item() + smooth_loss.item() + natural_loss.item())
                if score < best_score:
                    best_score = score
                    best_img = current_img.clone().detach()

                # Dynamic update of progress on the same line (if verbose)
                if verbose and (iteration % log_interval == 0 or iteration == iterations - 1):
                    sys.stdout.write(f"\rScale {scale[2]}x{scale[3]} Iter {iteration}/{iterations} - Loss: {total_loss.item():.4f}, CE: {ce_loss.item():.4f}")
                    sys.stdout.flush()
            # Move to next line after finishing iterations for this scale
            sys.stdout.write("\n")
            sys.stdout.flush()
    finally:
        # Clean up hooks
        for hook in hooks:
            hook.remove()

    # Apply post-processing to the best image
    try:
        final_img = advanced_post_process(best_img)
        print("Post-processing completed successfully")
    except Exception as e:
        print(f"Error in post-processing: {str(e)}")
        final_img = best_img  # Use best image without post-processing if there's an error

    return final_img


def plot_losses(losses_history, filename):
    """
    Save a two-panel, log-scale figure of the recorded losses to ``filename``.

    The upper panel shows the 'total' and 'ce' series; the lower panel shows
    every non-empty regularisation series ('tv', 'l2', 'color', 'smooth',
    'natural', 'fft', 'feature'). The figure is closed after saving.
    """
    headline_series = (
        ('total', 'k-', 'Total Loss'),
        ('ce', 'r-', 'CE Loss'),
    )
    regularizer_keys = ('tv', 'l2', 'color', 'smooth', 'natural', 'fft', 'feature')

    plt.figure(figsize=(12, 8))

    # Upper panel: overall objective and the classification term
    plt.subplot(2, 1, 1)
    for series_key, line_style, legend_text in headline_series:
        plt.plot(losses_history[series_key], line_style, label=legend_text)
    plt.title('Overall Loss Progress')
    plt.yscale('log')
    plt.legend()

    # Lower panel: one curve per regularisation term that has data
    plt.subplot(2, 1, 2)
    for series_key in regularizer_keys:
        series = losses_history[series_key]
        if series:
            plt.plot(series, label=f'{series_key} Loss')
    plt.title('Regularization Losses')
    plt.yscale('log')
    plt.legend()

    plt.tight_layout()
    plt.savefig(filename)
    plt.close()

# -------------------------------------------------------------
# 6. Advanced Post-Processing Pipeline
# -------------------------------------------------------------
def advanced_post_process(img):
    """
    Post-process a single reconstructed image for better visual quality.

    Steps: percentile contrast stretch (2nd-98th percentile), then either the
    color helpers (``local_contrast_enhance``,
    ``histogram_equalization_with_color``, ``color_balance``) or a global
    histogram equalization for grayscale, and finally edge-aware sharpening.

    Args:
        img: Tensor holding one image. After ``squeeze()`` it must be
            ``(H, W)`` (grayscale) or ``(C, H, W)`` (treated as color).

    Returns:
        A float32 CPU tensor of shape ``(1, 1, H, W)`` for grayscale input or
        ``(1, C, H, W)`` for color input, with values in ``[0, 1]``.
    """
    # detach() first: .numpy() raises on tensors that still require grad, and
    # the caller would then silently fall back to the unprocessed image.
    img_cpu = img.detach().squeeze().cpu()

    # Check if grayscale (1 channel) or color (3 channels)
    if len(img_cpu.shape) == 2:
        # Grayscale image (H, W)
        img_np = img_cpu.numpy()
        is_grayscale = True
    elif len(img_cpu.shape) == 3 and img_cpu.shape[0] == 1:
        # Single channel image in format (1, H, W); squeeze() normally removes
        # this axis already, so this branch is only a safety net.
        img_np = img_cpu.numpy()[0]
        is_grayscale = True
    else:
        # Color image (C, H, W) -> (H, W, C)
        img_np = img_cpu.numpy().transpose(1, 2, 0)
        is_grayscale = False

    # 1. Contrast stretching with percentile-based normalization
    p2, p98 = np.percentile(img_np, (2, 98))
    img_np = np.clip((img_np - p2) / (p98 - p2 + 1e-8), 0, 1)

    if not is_grayscale:
        # 2. Local contrast enhancement (CLAHE-like)
        img_np = local_contrast_enhance(img_np)

        # 3. Histogram equalization with color preservation
        img_np = histogram_equalization_with_color(img_np)

        # 4. Color balancing
        img_np = color_balance(img_np)
    else:
        # 2. Global histogram equalization: map each pixel through the
        #    normalized cumulative histogram (sampled at the bin left edges).
        hist, bins = np.histogram(img_np.flatten(), 256, [0, 1])
        cdf = hist.cumsum()
        cdf = cdf / (cdf[-1] + 1e-8)  # Normalize
        img_np = np.interp(img_np.flatten(), bins[:-1], cdf).reshape(img_np.shape)

    # 5. Detail enhancement with edge preservation, then back to a tensor
    if is_grayscale:
        img_np = edge_preserving_sharpen_gray(img_np)

        # (H, W) -> (1, 1, H, W)
        img_t = torch.from_numpy(img_np).unsqueeze(0).unsqueeze(0).float()
    else:
        img_np = edge_preserving_sharpen(img_np)

        # (H, W, C) -> (1, C, H, W)
        img_t = torch.from_numpy(img_np.transpose(2, 0, 1)).unsqueeze(0).float()

    return img_t

def edge_preserving_sharpen_gray(img, sigma=0.5, amount=1.0):
    """
    Sharpen a 2D grayscale image while damping the boost near strong edges.

    The high-frequency residual (image minus Gaussian blur) is added back,
    scaled by ``amount`` and by a mask that falls towards 0 where the
    gradient magnitude is largest. The result is clipped to ``[0, 1]``.

    Args:
        img: 2D array.
        sigma: Standard deviation of the Gaussian blur.
        amount: Strength of the sharpening.

    Returns:
        Array of the same shape as ``img`` with values in ``[0, 1]``.
    """
    # Normalized Gaussian window (odd width, at least 3 taps) for the blur
    radius = max(3, int(2 * sigma) * 2 + 1) // 2
    offsets = np.arange(-radius, radius + 1)
    gauss_1d = np.exp(-offsets**2 / (2 * sigma**2))
    gauss_1d = gauss_1d / gauss_1d.sum()
    smooth = convolve2d(img, np.outer(gauss_1d, gauss_1d), mode='same', boundary='symm')

    # Central-difference gradients along both axes
    horizontal = convolve2d(img, np.array([[-1, 0, 1]]), mode='same', boundary='symm')
    vertical = convolve2d(img, np.array([[-1], [0], [1]]), mode='same', boundary='symm')
    strength = np.sqrt(horizontal**2 + vertical**2)

    # 1 in flat regions, approaching 0 at the strongest edge
    damping = 1 - np.clip(strength / (strength.max() + 1e-8), 0, 1)**2

    detail = img - smooth
    return np.clip(img + amount * detail * damping, 0, 1)

def local_contrast_enhance(img, tile_size=16, clip_limit=3.0):
    """
    Simplified CLAHE-like local contrast enhancement.

    Every channel is split into ``tile_size`` x ``tile_size`` tiles (edge
    tiles may be smaller). Each tile is equalized independently with a
    256-bin histogram over ``[0, 1]`` whose bins are clipped at
    ``clip_limit * tile.size / 256``; the clipped excess is spread evenly
    over all bins. No interpolation between neighbouring tiles is done.

    Args:
        img: Array of shape ``(H, W, C)`` with values expected in ``[0, 1]``.
        tile_size: Edge length of the square tiles, in pixels.
        clip_limit: Clip factor relative to a uniform histogram; values
            ``<= 0`` disable clipping.

    Returns:
        Array with the same shape and dtype as ``img``.
    """
    result = np.zeros_like(img)
    height, width = img.shape[:2]

    for c in range(img.shape[2]):
        channel = img[:, :, c]

        for y in range(0, height, tile_size):
            y_end = min(y + tile_size, height)
            for x in range(0, width, tile_size):
                x_end = min(x + tile_size, width)
                tile = channel[y:y_end, x:x_end]

                # Skip empty tiles
                if tile.size == 0:
                    continue

                counts, bins = np.histogram(tile.ravel(), 256, [0, 1])
                # np.histogram returns integer counts. Work in float so the
                # fractional clip level and redistributed share are not
                # truncated back to integers.
                hist = counts.astype(np.float64)

                if clip_limit > 0:
                    clip = clip_limit * tile.size / hist.size
                    excess = np.clip(hist - clip, 0, None).sum()
                    # Cap every bin, then spread the removed mass uniformly
                    hist = np.minimum(hist, clip) + excess / hist.size

                cdf = hist.cumsum()
                total = cdf[-1]
                if total <= 0:
                    # No sample fell inside [0, 1]; nothing to equalize
                    result[y:y_end, x:x_end, c] = tile
                    continue
                cdf = cdf / total  # Normalize

                # Map each pixel through the tile's CDF (sampled at bin left edges)
                tile_result = np.interp(tile.ravel(), bins[:-1], cdf)
                result[y:y_end, x:x_end, c] = tile_result.reshape(tile.shape)

    return result

def histogram_equalization_with_color(img):
    """
    Equalize luminance while leaving the chroma components untouched.

    The RGB image is converted to a YCbCr-like representation, the Y channel
    is equalized with a 256-bin histogram over ``[0, 1]``, and the result is
    converted back to RGB and clipped to ``[0, 1]``.

    Args:
        img: Array of shape ``(H, W, 3)``.

    Returns:
        Array of shape ``(H, W, 3)`` with values in ``[0, 1]``.
    """
    red, green, blue = img[:, :, 0], img[:, :, 1], img[:, :, 2]

    # RGB -> luma / chroma (chroma is offset by 0.5)
    luma = 0.299 * red + 0.587 * green + 0.114 * blue
    chroma_b = -0.1687 * red - 0.3313 * green + 0.5 * blue + 0.5
    chroma_r = 0.5 * red - 0.4187 * green - 0.0813 * blue + 0.5

    # Equalize the luma channel only
    luma_flat = luma.flatten()
    counts, edges = np.histogram(luma_flat, 256, [0, 1])
    cumulative = counts.cumsum()
    cumulative = cumulative / cumulative[-1]
    luma_eq = np.interp(luma_flat, edges[:-1], cumulative).reshape(luma.shape)

    # Luma / chroma -> RGB
    cb_centered = chroma_b - 0.5
    cr_centered = chroma_r - 0.5
    rgb = np.stack(
        [
            luma_eq + 1.402 * cr_centered,
            luma_eq - 0.344136 * cb_centered - 0.714136 * cr_centered,
            luma_eq + 1.772 * cb_centered,
        ],
        axis=2,
    )
    return np.clip(rgb, 0, 1)

def edge_preserving_sharpen(img, sigma=0.5, amount=1.0):
    """
    Sharpen a color image while damping the boost near strong edges.

    Each channel is handled independently: the high-frequency residual
    (channel minus Gaussian blur) is added back, scaled by ``amount`` and by
    a mask that falls towards 0 where that channel's gradient magnitude is
    largest. The result is clipped to ``[0, 1]``.

    Args:
        img: Array of shape ``(H, W, C)``.
        sigma: Standard deviation of the Gaussian blur.
        amount: Strength of the sharpening.

    Returns:
        Array of the same shape as ``img`` with values in ``[0, 1]``.
    """
    # Normalized Gaussian window (odd width, at least 3 taps) for the blur
    kernel_size = max(3, int(2 * sigma) * 2 + 1)
    kernel_1d = np.exp(-np.arange(-(kernel_size//2), kernel_size//2 + 1)**2 / (2 * sigma**2))
    kernel_1d = kernel_1d / kernel_1d.sum()
    kernel_2d = np.outer(kernel_1d, kernel_1d)

    blurred = np.zeros_like(img)
    edge_mask = np.ones_like(img)
    for c in range(img.shape[2]):
        channel = img[:, :, c]

        # Gaussian blur of this channel
        blurred[:, :, c] = convolve2d(channel, kernel_2d, mode='same', boundary='symm')

        # Gradient magnitude from central differences
        gx = convolve2d(channel, np.array([[-1, 0, 1]]), mode='same', boundary='symm')
        gy = convolve2d(channel, np.array([[-1], [0], [1]]), mode='same', boundary='symm')
        gradient_mag = np.sqrt(gx**2 + gy**2)

        # Normalize and invert to give less weight to edges. The epsilon keeps
        # a perfectly flat channel (max gradient 0) from producing 0/0 = NaN,
        # matching edge_preserving_sharpen_gray.
        edge_mask[:, :, c] = 1 - np.clip(gradient_mag / (gradient_mag.max() + 1e-8), 0, 1)**2

    # Apply sharpening with edge preservation
    high_freq = img - blurred
    sharpened = img + amount * high_freq * edge_mask

    return np.clip(sharpened, 0, 1)

def convolve2d(img, kernel, mode='same', boundary='symm'):
    """
    Small 2D sliding-window filter, used to avoid a scipy dependency.

    The kernel is applied WITHOUT flipping (i.e. cross-correlation), so for a
    non-symmetric kernel the sign/orientation differs from a true convolution.
    The output always has the same shape as ``img``.

    Args:
        img: 2D array.
        kernel: 2D array of filter weights.
        mode: Accepted for signature compatibility only; the output is always
            'same'-sized.
        boundary: ``'symm'`` pads by mirroring the image at its borders; any
            other value pads with zeros.

    Returns:
        2D array shaped like ``img``. Floating-point inputs keep their dtype;
        other inputs are returned as floating point so that fractional
        results are not truncated.
    """
    img = np.asarray(img)
    kernel = np.asarray(kernel)
    k_h, k_w = kernel.shape
    i_h, i_w = img.shape

    # Pad the image based on boundary mode
    pad_h, pad_w = k_h // 2, k_w // 2
    pad_mode = 'symmetric' if boundary == 'symm' else 'constant'
    padded = np.pad(img, ((pad_h, pad_h), (pad_w, pad_w)), mode=pad_mode)

    # Accumulate one shifted copy of the image per kernel tap: k_h * k_w
    # vectorized steps instead of one Python iteration per output pixel.
    acc = np.zeros((i_h, i_w), dtype=np.result_type(img.dtype, kernel.dtype, np.float64))
    for dy in range(k_h):
        for dx in range(k_w):
            acc += kernel[dy, dx] * padded[dy:dy + i_h, dx:dx + i_w]

    out_dtype = img.dtype if np.issubdtype(img.dtype, np.floating) else acc.dtype
    return acc.astype(out_dtype, copy=False)

def color_balance(img, clip_percent=1):
    """
    Balance colors by contrast-stretching every channel independently.

    For each channel the ``clip_percent`` and ``100 - clip_percent``
    percentiles are mapped to 0 and 1 and the result is clipped to ``[0, 1]``.
    A channel whose two percentiles coincide (e.g. a flat channel) has no
    range to stretch and is only clipped.

    Args:
        img: Array of shape ``(H, W, C)``.
        clip_percent: Percentage of pixels saturated at each end.

    Returns:
        Array with the same shape and dtype as ``img``.
    """
    result = np.zeros_like(img)

    for c in range(img.shape[2]):
        channel = img[:, :, c]
        low, high = np.percentile(channel, (clip_percent, 100 - clip_percent))
        span = high - low

        if span == 0:
            # Dividing by zero here would fill the channel with NaN
            result[:, :, c] = np.clip(channel, 0, 1)
        else:
            result[:, :, c] = np.clip((channel - low) / span, 0, 1)

    return result

# -------------------------------------------------------------
# 7. Highly Advanced Ensemble Attack with Knowledge Distillation
# -------------------------------------------------------------
def advanced_ensemble_attack(
        models,
        target_class,
        weights=None,
        distill_iterations=500,
        scales=None,
        verbose=True,
        log_interval=10,
        **attack_params):
    """
    Three-phase model-inversion attack against an ensemble of models.

    Phase 1 inverts every model on its own (first two scales, half the
    ``iterations_per_scale`` if that key is given). Phase 2 averages those
    reconstructions into a 224x224 seed and refines it with Adam against the
    weighted cross-entropy of all models, a first-conv-layer feature
    consistency term and a total-variation term. Phase 3 runs the progressive
    inversion on an ``EnsembleModel`` starting from the refined seed.

    Args:
        models: Mutable sequence of models; each entry is moved to the
            selected device in place.
        target_class: Class index to reconstruct.
        weights: One weight per model. Defaults to uniform weights. The loss
            terms use the weights as given; only the seed average is
            normalized by their sum.
        distill_iterations: Number of Adam steps in phase 2.
        scales: Sequence of ``(1, C, H, W)`` shapes. The last entry is forced
            to ``(1, C, 299, 299)``; the caller's sequence is not modified.
        verbose: Print progress information.
        log_interval: Print the distillation loss every this many steps.
        **attack_params: Forwarded to ``progressive_model_inversion_attack``.

    Returns:
        The image returned by the final progressive inversion, or the
        distilled 224x224 seed if that result is ``None`` or contains NaN/Inf.

    Raises:
        ValueError: If ``models`` is empty or fewer weights than models are
            given.
    """
    if len(models) == 0:
        raise ValueError("advanced_ensemble_attack requires at least one model")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Move all models to the same device
    for i in range(len(models)):
        models[i] = models[i].to(device)

    # Default to uniform weights; fail early (before hours of optimization)
    # when there are not enough weights to index one per model.
    if weights is None:
        weights = [1.0/len(models)] * len(models)
    elif len(weights) < len(models):
        raise ValueError(
            f"Expected one weight per model ({len(models)}), got {len(weights)}")

    # Get model info
    model_info = get_model_info(models[0])
    input_channels = model_info["input_channels"] or 1

    # Set scales with 299x299 final size
    if scales is None:
        scales = [(1, input_channels, 112, 112),
                 (1, input_channels, 224, 224),
                 (1, input_channels, 299, 299)]
    else:
        # Private copy: the last entry may be overwritten below and the
        # caller's list must stay untouched.
        scales = list(scales)

    # Force 299x299 final scale
    if tuple(scales[-1][2:]) != (299, 299):
        if verbose:
            print(f"Adjusting final scale to (1, {input_channels}, 299, 299)")
        scales[-1] = (1, input_channels, 299, 299)

    # Phase 1: Individual model inversions
    print("\nPhase 1: Individual model inversions")
    individual_images = []

    indiv_attack_params = attack_params.copy()
    if 'iterations_per_scale' in indiv_attack_params:
        indiv_attack_params['iterations_per_scale'] = indiv_attack_params['iterations_per_scale'] // 2

    for i, model in enumerate(models):
        print(f"  Inverting model {i+1}/{len(models)}")
        img = progressive_model_inversion_attack(model, target_class, scales=scales[:2], **indiv_attack_params)
        if img is not None:  # Ensure we have a valid image
            individual_images.append(img)
        else:
            print(f"  WARNING: Model {i+1} produced no image")
            # Create a default random noise image
            default_img = torch.rand((1, input_channels, 224, 224), device=device)
            individual_images.append(default_img)

    # Verify we have at least one valid image
    if not individual_images:
        print("ERROR: No valid images produced. Returning random noise.")
        return torch.rand((1, input_channels, 299, 299), device=device)

    # Phase 2: Distillation
    print("\nPhase 2: Knowledge distillation from individual reconstructions")

    # Resize all images to the same dimensions before averaging
    standard_size = (1, input_channels, 224, 224)  # Use a standard size for distillation
    resized_images = []

    print("  Standardizing image dimensions...")
    for i, img in enumerate(individual_images):
        print(f"  Image {i+1} shape before resize: {img.shape}")
        resized = F.interpolate(img.to(device), size=standard_size[2:], mode='bilinear', align_corners=False)
        print(f"  Image {i+1} shape after resize: {resized.shape}")
        resized_images.append(resized)

    # Initialize ensemble seed with the weighted AVERAGE of the resized images.
    # Dividing by the weight sum keeps the seed inside the image range even
    # when the weights do not sum to 1 (otherwise the clamp below saturates it).
    used_weights = [weight for _, weight in zip(resized_images, weights)]
    weight_total = float(sum(used_weights))
    seed_norm = weight_total if weight_total > 0 else 1.0
    with torch.no_grad():
        ensemble_seed = torch.zeros(standard_size, device=device)
        for img, weight in zip(resized_images, used_weights):
            ensemble_seed += (weight / seed_norm) * img
        ensemble_seed = torch.clamp(ensemble_seed, 0, 1)  # Ensure valid range
    # The seed is the leaf tensor optimized during distillation
    ensemble_seed.requires_grad_(True)

    # Run distillation
    optimizer = optim.Adam([ensemble_seed], lr=0.01)
    scheduler = CosineAnnealingLR(optimizer, T_max=distill_iterations, eta_min=0.001)

    # Debug print
    print(f"  Ensemble seed shape: {ensemble_seed.shape}, device: {ensemble_seed.device}")
    print(f"  Model devices: {[next(m.parameters()).device for m in models]}")

    feature_hook = FeatureHook("features", adaptive_stats=False)
    target = torch.tensor([target_class], device=device)

    for step in range(distill_iterations):
        optimizer.zero_grad()

        # Resize to 299x299 for model
        resized_seed = F.interpolate(ensemble_seed, size=(299, 299),
                                   mode='bilinear', align_corners=False)

        # Ensemble classification loss
        ce_loss = 0
        for model, weight in zip(models, weights):
            output = model(resized_seed)
            ce_loss += weight * F.cross_entropy(output, target)

        # Feature consistency with individual reconstructions
        feature_loss = 0
        for i, (model, indiv_img) in enumerate(zip(models, individual_images)):
            # Register a temporary hook on the first Conv2d under 'features'
            hooks = []
            for name, module in model.named_modules():
                if isinstance(module, nn.Conv2d) and 'features' in name:
                    hooks.append(module.register_forward_hook(feature_hook.hook_fn))
                    break

            try:
                # Get individual image features
                with torch.no_grad():
                    indiv_img_resized = F.interpolate(indiv_img.to(device), size=(299, 299),
                                                    mode='bilinear', align_corners=False)
                    _ = model(indiv_img_resized)
                    target_features = feature_hook.features.detach() if feature_hook.features is not None else None

                # Skip if we couldn't extract features
                if target_features is None:
                    continue

                # Get current seed features
                _ = model(resized_seed)
                current_features = feature_hook.features

                # Compute feature consistency loss if we have both feature sets
                if current_features is not None:
                    feature_loss += weights[i] * F.mse_loss(current_features, target_features)
            finally:
                # Always detach the hook, including on the `continue` path,
                # so no hook stays registered on the model.
                for hook in hooks:
                    hook.remove()

        # Regularization
        tv_loss = total_variation_loss(ensemble_seed, 1e-4)

        # Total loss
        total_loss = ce_loss + 0.5 * feature_loss + tv_loss

        # Update
        total_loss.backward()
        optimizer.step()
        scheduler.step()

        # Ensure valid range
        with torch.no_grad():
            ensemble_seed.clamp_(0, 1)

        if verbose and (step % log_interval == 0 or step == distill_iterations - 1):
            print(f"  Distillation iter {step}/{distill_iterations}, "
                  f"Loss: {total_loss.item():.4f}, CE: {ce_loss.item():.4f}")

    # Phase 3: Final ensemble optimization
    print("\nPhase 3: Final ensemble optimization")

    # Create ensemble model wrapper
    ensemble_model = EnsembleModel(models, weights).to(device)

    # Ensure we have a valid seed image
    final_seed = ensemble_seed.detach().to(device)
    if torch.isnan(final_seed).any() or torch.isinf(final_seed).any():
        print("WARNING: NaN/Inf values detected in distilled image. Using random initialization.")
        final_seed = torch.rand((1, input_channels, 224, 224), device=device)

    # Final inversion with ensemble model
    final_img = progressive_model_inversion_attack(
        ensemble_model,
        target_class,
        scales=scales,
        init_img=final_seed,
        **attack_params
    )

    # Final sanity check
    if final_img is None or torch.isnan(final_img).any() or torch.isinf(final_img).any():
        print("ERROR: Invalid final image. Returning the distilled seed instead.")
        return final_seed

    return final_img

class EnsembleModel(nn.Module):
    """
    Wraps several models and returns the weighted sum of their outputs.

    ``features`` exposes the ``features`` attribute of the first wrapped
    model that has one (``None`` if no model does).
    """

    def __init__(self, models, weights):
        super().__init__()
        # ModuleList so that .to(device) / .parameters() reach every member
        self.models = nn.ModuleList(models)
        self.weights = weights

        # Feature-extraction handle: first member that offers `features`
        self.features = next(
            (member.features for member in models if hasattr(member, 'features')),
            None,
        )

    def forward(self, x):
        """Run every member on ``x`` and return the weighted sum of the outputs."""
        member_outputs = [member(x) for member in self.models]

        combined = torch.zeros_like(member_outputs[0])
        for member_output, weight in zip(member_outputs, self.weights):
            combined += weight * member_output
        return combined

    def get_feature_maps(self, x):
        """Delegate to the first member's ``get_feature_maps``; ``[]`` if it has none."""
        first_member = self.models[0]
        if not hasattr(first_member, 'get_feature_maps'):
            return []  # Return empty list if not available
        return first_member.get_feature_maps(x)

# -------------------------------------------------------------
# 8. Main Routine: State-of-the-Art Federated MI Attack Framework
# -------------------------------------------------------------
def get_model_info(model):
    """
    Inspect a model and report its basic input/output configuration.

    Returns a dict with:
        - ``"input_channels"``: ``in_channels`` of the first ``nn.Conv2d``
          found in ``model.modules()``, or ``None``.
        - ``"num_classes"``: ``out_features`` of the last ``nn.Linear``
          found in ``model.modules()``, or ``None``.
        - ``"expected_input_size"``: ``model.flat_features`` when present,
          otherwise ``256 * 9 * 9``.
    """
    first_conv = next(
        (layer for layer in model.modules() if isinstance(layer, nn.Conv2d)),
        None,
    )
    linear_layers = [layer for layer in model.modules() if isinstance(layer, nn.Linear)]

    # 256 * 9 * 9 is the fallback flattened size used for BrainMRIClassifier
    flattened_size = getattr(model, 'flat_features', 256 * 9 * 9)

    return {
        "input_channels": first_conv.in_channels if first_conv is not None else None,
        "num_classes": linear_layers[-1].out_features if linear_layers else None,
        "expected_input_size": flattened_size,
    }

# Debug helper: print the tensor shape after every layer of a model
def debug_model_dimensions(model, input_size=(1, 1, 299, 299)):
    """
    Print the output shape after each layer to diagnose dimension issues.

    A random tensor of shape ``input_size`` is pushed through every layer of
    ``model.features``; the flattened size is then compared with
    ``model.flat_features`` (reported as ``unknown`` when the model does not
    define it) and finally ``model.classifier`` is tried on the feature map.

    Args:
        model: Module exposing iterable ``features`` and a ``classifier``.
        input_size: Shape of the random probe input.
    """
    device = next(model.parameters()).device
    x = torch.randn(input_size).to(device)

    # Track feature dimensions
    print(f"Input shape: {x.shape}")

    # Shape tracing needs no gradients; skip building an autograd graph
    with torch.no_grad():
        # Features
        for i, layer in enumerate(model.features):
            x = layer(x)
            print(f"After features[{i}] {type(layer).__name__}: {x.shape}")

        # Check final flattened size
        flat_size = x.reshape(x.size(0), -1).shape[1]
        # get_model_info tolerates models without flat_features; so does this
        expected_size = getattr(model, 'flat_features', 'unknown')
        print(f"Flattened size: {flat_size}, Expected: {expected_size}")

        # Try classifier
        try:
            output = model.classifier(x)
            print(f"Final output shape: {output.shape}")
        except Exception as e:
            print(f"Error in classifier: {str(e)}")

# Modification needed in the cell below!

Please update the model paths (```MODEL_PATHS```) inside the main function below so that they point to your pretrained model files.

In [None]:
def main():
    """
    Run the full inversion benchmark and save every reconstructed image.

    For each target class this performs ``num_samples`` inversions against
    every client model, the global model and the ensemble of all models,
    writes each result as a PNG under
    ``<cwd>/inverted_images/bench/class_<k>/...``, saves a side-by-side
    comparison of the first sample of each model, and finally calls
    ``display_results`` on those first samples.

    ``MODEL_PATHS`` below must be edited to point at the pretrained models.
    """
    # ----------------------------------------------------------------
    # Step 1: Create the base save directory dynamically
    # ----------------------------------------------------------------
    base_save_path = os.path.join(os.getcwd(), "inverted_images", "bench")
    os.makedirs(base_save_path, exist_ok=True)
    sys.stdout.write(f"Base save path set to: {base_save_path}\n")
    sys.stdout.flush()

    # ----------------------------------------------------------------
    # Step 2: Set up file paths for pre-trained models
    # ----------------------------------------------------------------
    MODEL_PATHS = {
        'global': '/pretrained_models/global_model.pth', ## <-Replace
        'clients': [
            '/pretrained_models/client_1_model.pth', ## <-Replace
            '/pretrained_models/client_2_model.pth', ## <-Replace
            '/pretrained_models/client_3_model.pth'  ## <-Replace
        ]
    }
    client_model_paths = MODEL_PATHS['clients']
    global_model_path = MODEL_PATHS['global']

    # ----------------------------------------------------------------
    # Step 3: Load models with dynamic progress updates
    # ----------------------------------------------------------------
    verbose = True
    log_interval = 10  # Log every 10 iterations

    sys.stdout.write("Loading client models...\n")
    sys.stdout.flush()
    client_models = []
    for i, path in enumerate(client_model_paths):
        sys.stdout.write(f"\rLoading client model {i+1}/{len(client_model_paths)}")
        sys.stdout.flush()
        client_models.append(load_model(path, model_class=BrainMRIClassifier))
    sys.stdout.write("\n")

    sys.stdout.write("Loading global model...\n")
    sys.stdout.flush()
    global_model = load_model(global_model_path, model_class=BrainMRIClassifier)
    sys.stdout.write("All models loaded successfully\n")
    sys.stdout.flush()

    sys.stdout.write("\nDebugging model dimensions...\n")
    sys.stdout.flush()
    debug_model_dimensions(client_models[0])

    # Dynamically determine model parameters
    model_info = get_model_info(global_model)
    input_channels = model_info["input_channels"] or 1
    num_classes = model_info["num_classes"] or 4
    sys.stdout.write(f"Detected model configuration: {input_channels} input channels, {num_classes} output classes\n")
    sys.stdout.flush()

    # ----------------------------------------------------------------
    # Step 4: Set inversion parameters and scales
    # ----------------------------------------------------------------
    num_samples = 20  # Number of inversion samples per model per class
    target_classes = list(range(min(4, num_classes)))
    scales = [(1, input_channels, 112, 112),
              (1, input_channels, 224, 224),
              (1, input_channels, 299, 299)]
    attack_params = {
        'iterations_per_scale': 1000,
        'lr_initial': 0.05,
        'lr_final': 0.001,
        'scales': scales,
        'auto_schedule_hyperparams': True,
        'regularization_weights': {
            'tv': 2e-3,
            'l2': 5e-4,
            'color': 1e-4,
            'smooth': 3e-4,
            'natural': 5e-4,
            'fft': 5e-5,
            'feature': 1e-2
        },
        'init_strategy': 'mixed'
    }

    # For ensemble attack: define weights and aggregate all models
    # (three client weights followed by the global-model weight)
    ensemble_weights = [0.5, 0.3, 0.2, 1.0]
    all_models = client_models + [global_model]
    if len(ensemble_weights) != len(all_models):
        # The client list in MODEL_PATHS was changed; the hand-tuned weights
        # no longer line up with the models, so fall back to uniform weights.
        ensemble_weights = [1.0 / len(all_models)] * len(all_models)
        sys.stdout.write(
            f"WARNING: ensemble weights do not match the {len(all_models)} loaded models; "
            "using uniform weights\n"
        )
        sys.stdout.flush()

    # Dictionary to store one sample per model type per class for later comparison
    inverted_images = {}

    def run_samples(attack, save_dir, name_prefix, banner, progress_label):
        """Run `attack` num_samples times, save each image, remember the first one."""
        os.makedirs(save_dir, exist_ok=True)
        sys.stdout.write(banner)
        sys.stdout.flush()
        for sample in range(num_samples):
            inv_img = attack()
            filename = os.path.join(save_dir, f"{name_prefix}_inv_{sample+1}.png")
            vutils.save_image(inv_img, filename)
            sys.stdout.write(f"\r{progress_label}: Sample {sample+1}/{num_samples} saved")
            sys.stdout.flush()
            if sample == 0:
                inverted_images[name_prefix] = inv_img
        sys.stdout.write("\n")
        sys.stdout.flush()

    # ----------------------------------------------------------------
    # Step 5: Process each target class
    # ----------------------------------------------------------------
    for target_class in target_classes:
        # Create subfolder for this target class
        class_save_path = os.path.join(base_save_path, f"class_{target_class}")
        os.makedirs(class_save_path, exist_ok=True)
        sys.stdout.write(f"\n=== Processing target class {target_class} ===\n")
        sys.stdout.write(f"Images for class {target_class} will be saved in: {class_save_path}\n")
        sys.stdout.flush()

        # ---- Inversions for Client Models ----
        for i, model in enumerate(client_models):
            run_samples(
                partial(
                    progressive_model_inversion_attack,
                    model,
                    target_class,
                    verbose=verbose,
                    log_interval=log_interval,
                    **attack_params
                ),
                save_dir=os.path.join(class_save_path, f"client_{i+1}"),
                name_prefix=f"client_model_{i+1}_class_{target_class}",
                banner=f"\nPerforming inversion attack on Client Model {i+1} for class {target_class}\n",
                progress_label=f"Client Model {i+1} [Class {target_class}]",
            )

        # ---- Inversions for Global Model ----
        run_samples(
            partial(
                progressive_model_inversion_attack,
                global_model,
                target_class,
                verbose=verbose,
                log_interval=log_interval,
                **attack_params
            ),
            save_dir=os.path.join(class_save_path, "global"),
            name_prefix=f"global_model_class_{target_class}",
            banner=f"\nPerforming inversion attack on Global Model for class {target_class}\n",
            progress_label=f"Global Model [Class {target_class}]",
        )

        # ---- Inversions for Ensemble Model ----
        run_samples(
            partial(
                advanced_ensemble_attack,
                all_models,
                target_class,
                weights=ensemble_weights,
                verbose=verbose,
                log_interval=log_interval,
                **attack_params
            ),
            save_dir=os.path.join(class_save_path, "ensemble"),
            name_prefix=f"ensemble_model_class_{target_class}",
            banner=f"\nPerforming advanced ensemble attack for class {target_class}\n",
            progress_label=f"Ensemble Model [Class {target_class}]",
        )

        # ---- Optional Comparative Analysis ----
        sys.stdout.write(f"\nRunning comparative analysis for class {target_class}...\n")
        sys.stdout.flush()
        # Build the panel list from the clients that were actually loaded
        comparison_keys = [
            f'client_model_{i+1}_class_{target_class}' for i in range(len(client_models))
        ] + [
            f'global_model_class_{target_class}',
            f'ensemble_model_class_{target_class}',
        ]
        comparison_labels = [
            f'Client {i+1}' for i in range(len(client_models))
        ] + ['Global Model', 'Ensemble Model']
        available = [
            (inverted_images[key], label)
            for key, label in zip(comparison_keys, comparison_labels)
            if key in inverted_images
        ]
        if available:
            compare_reconstructions(
                [image for image, _ in available],
                [label for _, label in available],
                os.path.join(class_save_path, f"class_{target_class}_comparison.png")
            )
        sys.stdout.write("\n")
        sys.stdout.flush()

    # ----------------------------------------------------------------
    # Step 6: Final display of results if needed
    # ----------------------------------------------------------------
    display_results(inverted_images)




def compare_reconstructions(images, labels, filename):
    """
    Save a side-by-side figure of several reconstructions with quality metrics.

    Each panel title shows the label plus a sharpness score (and a
    colorfulness score for 3-channel images). Entries that are ``None`` or
    have an unsupported shape are rendered as a text placeholder instead of
    aborting the whole figure.

    Args:
        images: Sequence of image tensors (or ``None`` placeholders). After
            ``squeeze()`` a tensor must be ``(H, W)`` or ``(C, H, W)``.
        labels: One title per image.
        filename: Destination path of the saved figure (300 dpi).
    """
    fig, axs = plt.subplots(1, len(images), figsize=(4*len(images), 4))
    # A single panel comes back as a bare Axes; normalize to a 1D array
    axs = np.atleast_1d(axs)

    try:
        for ax, img, label in zip(axs, images, labels):
            ax.axis('off')

            if img is None:
                # e.g. a missing dictionary entry passed through .get()
                ax.text(0.5, 0.5, "No image",
                        ha='center', va='center', transform=ax.transAxes)
                ax.set_title(label)
                continue

            # detach() so tensors that still require grad can become numpy arrays
            img_cpu = img.detach().squeeze().cpu()

            if len(img_cpu.shape) == 3 and img_cpu.shape[0] == 1:
                # Single channel image (1, H, W) -> (H, W)
                img_cpu = img_cpu[0]

            if len(img_cpu.shape) == 2:
                # 2D grayscale image (H, W)
                gray_np = img_cpu.numpy()
                color_np = None
                ax.imshow(gray_np, cmap='gray')
            elif len(img_cpu.shape) == 3:
                # Color image (C, H, W) -> (H, W, C)
                color_np = img_cpu.permute(1, 2, 0).numpy()
                # Non-RGB channel counts fall back to the first channel for metrics
                gray_np = img_cpu[0].numpy()
                ax.imshow(color_np)
            else:
                # Unexpected format - display blank with error text, no metrics
                ax.text(0.5, 0.5, f"Invalid shape: {img_cpu.shape}",
                        ha='center', va='center', transform=ax.transAxes)
                ax.set_title(label)
                continue

            # Try to calculate quality metrics
            try:
                if color_np is not None and img_cpu.shape[0] == 3:
                    sharpness = calculate_sharpness(color_np)
                    colorfulness = calculate_colorfulness(color_np)
                    ax.set_title(f"{label}\nSharp: {sharpness:.2f}, Color: {colorfulness:.2f}")
                else:
                    # Grayscale image - only calculate sharpness
                    sharpness = calculate_sharpness_gray(gray_np)
                    ax.set_title(f"{label}\nSharp: {sharpness:.2f}")
            except Exception as e:
                # If metrics fail, just show the label
                print(f"Error calculating metrics for {label}: {str(e)}")
                ax.set_title(label)

        fig.tight_layout()
        fig.savefig(filename, dpi=300)
    finally:
        # Release the figure even when plotting or saving fails
        plt.close(fig)

def calculate_sharpness_gray(img):
    """
    Sharpness of a grayscale image as its mean gradient magnitude.

    Gradients are forward differences along both axes, evaluated on the
    region where both exist. Images with fewer than two rows or two columns
    yield ``0.0``.
    """
    n_rows, n_cols = img.shape[0], img.shape[1]
    if not (n_rows > 1 and n_cols > 1):
        # Too small to form a gradient in both directions
        return 0.0

    # Forward differences along rows and columns
    row_diff = img[1:, :] - img[:-1, :]
    col_diff = img[:, 1:] - img[:, :-1]

    # Crop both to the common region, then average the magnitude
    magnitude = np.sqrt(row_diff[:, :-1]**2 + col_diff[:-1, :]**2)
    return np.mean(magnitude)

def calculate_sharpness(img):
    """
    Sharpness of an ``(H, W, 3)`` color image.

    The image is reduced to a weighted grayscale mix of its three channels
    and scored with ``calculate_sharpness_gray``.
    """
    red = img[:, :, 0]
    green = img[:, :, 1]
    blue = img[:, :, 2]
    luminance = 0.2989 * red + 0.5870 * green + 0.1140 * blue
    return calculate_sharpness_gray(luminance)

def calculate_colorfulness(img):
    """
    Calculate perceptual colorfulness metric.
    Based on Hasler and Süsstrunk (2003) metric:

        M = std_rgyb + 0.3 * mean_rgyb

    where rg = R - G and yb = 0.5 * (R + G) - B are the opponent colour
    components, std_rgyb = sqrt(std(rg)^2 + std(yb)^2) and
    mean_rgyb = sqrt(mean(rg)^2 + mean(yb)^2).

    Args:
        img: array-like indexed as [row, col, channel] with the red, green
            and blue planes at channel indices 0, 1 and 2. Integer inputs
            are promoted to float64 so the channel differences cannot wrap.

    Returns:
        The colourfulness score (0 for an image whose three channels are
        identical everywhere).
    """
    img = np.asarray(img)

    # r - g and r + g overflow/wrap for unsigned integer images (e.g. uint8),
    # so compute in floating point. Float inputs keep their dtype.
    if not np.issubdtype(img.dtype, np.inexact):
        img = img.astype(np.float64)

    # Split image into channels
    r = img[:, :, 0]
    g = img[:, :, 1]
    b = img[:, :, 2]

    # Opponent colour components
    rg = r - g
    yb = 0.5 * (r + g) - b

    # Combine the per-component means and standard deviations
    mean_rgyb = np.sqrt(np.mean(rg) ** 2 + np.mean(yb) ** 2)
    std_rgyb = np.sqrt(np.std(rg) ** 2 + np.std(yb) ** 2)

    # The published metric weights the mean term by 0.3; the previous
    # unweighted sum over-emphasised a uniform colour cast.
    return std_rgyb + 0.3 * mean_rgyb

# Per-class comparison figures for the generated (inverted) images
def display_results(inverted_images):
    """
    Save one comparison figure per target class for the generated images.

    Entries whose key contains ``class_<id>`` are grouped by ``<id>`` (the
    text after the first ``class_`` up to the next underscore). For every
    group a two-row figure is written to ``comparison_class_<id>.png`` in
    the current working directory:

    * row 0 - the image itself, titled with the part of the key that
      precedes ``_class_<id>``;
    * row 1 - the log-magnitude of its centred 2-D FFT.

    Args:
        inverted_images: mapping from result name to a torch image tensor.
            Each tensor is detached, squeezed and moved to the CPU; after
            squeezing it is expected to be (H, W) grayscale or (3, H, W)
            colour. Any other shape is reported as a text placeholder.

    Returns:
        None. Keys without a ``class_`` marker are ignored.
    """
    # Group by class. The full 'class_' marker is required: a key that merely
    # contains 'class' (e.g. 'classifier') has no id to extract and used to
    # raise IndexError on the split below.
    class_groups = {}
    for key, img_tensor in inverted_images.items():
        if 'class_' not in key:
            continue
        class_id = key.split('class_')[1].split('_')[0]
        class_groups.setdefault(class_id, []).append((key, img_tensor))

    # Plot each class separately
    for class_id, images in class_groups.items():
        num_images = len(images)

        # squeeze=False keeps axs two-dimensional even for a single column,
        # so axs[row, col] indexing works without special-casing.
        fig, axs = plt.subplots(2, num_images, figsize=(15, 10), squeeze=False)
        fig.suptitle(f"Reconstructions for Class {class_id}", fontsize=16)

        for i, (key, img_tensor) in enumerate(images):
            # detach() first: .numpy() refuses tensors that still require
            # grad, which is typical for images produced by optimisation.
            img = img_tensor.detach().squeeze().cpu()

            image_ax = axs[0, i]
            spectrum_ax = axs[1, i]

            # squeeze() has already removed any singleton channel axis, so a
            # grayscale image is always 2-D here.
            is_color = len(img.shape) == 3 and img.shape[0] == 3
            if is_color:
                # Color image - permute to HWC for plotting
                image_ax.imshow(img.permute(1, 2, 0).numpy())
            elif len(img.shape) == 2:
                image_ax.imshow(img.numpy(), cmap='gray')
            else:
                # Unexpected format
                image_ax.text(0.5, 0.5, f"Invalid shape: {img.shape}",
                              ha='center', va='center', transform=image_ax.transAxes)

            image_ax.set_title(key.split(f'_class_{class_id}')[0])
            image_ax.axis('off')

            # Display frequency spectrum for visualization (best effort: a
            # failure here must not prevent the figure from being saved).
            try:
                # Convert to grayscale for FFT if needed
                if is_color:
                    gray = 0.299 * img[0] + 0.587 * img[1] + 0.114 * img[2]
                else:
                    gray = img

                f_transform = np.fft.fft2(gray.numpy())
                f_transform_shifted = np.fft.fftshift(f_transform)
                # +1 keeps the logarithm finite where the magnitude is zero
                magnitude = np.log(np.abs(f_transform_shifted) + 1)

                spectrum_ax.imshow(magnitude, cmap='viridis')
                spectrum_ax.set_title('Frequency Domain')
            except Exception as e:
                spectrum_ax.text(0.5, 0.5, f"FFT error: {str(e)[:20]}...",
                                 ha='center', va='center', transform=spectrum_ax.transAxes)
            # Hide ticks/frames on both the success and the error path
            spectrum_ax.axis('off')

        fig.tight_layout()
        fig.subplots_adjust(top=0.9)
        fig.savefig(f'comparison_class_{class_id}.png')
        # Close this specific figure so looping over classes cannot leak
        # figures or close an unrelated one.
        plt.close(fig)

In [None]:
# Notebook entry point: invoke main() (expected to be defined in an earlier cell).
# NOTE: the README at the top of this notebook warns that a complete run takes
# roughly 20-30 hours with a GPU.
if __name__ == "__main__":
    main()

Base save path set to: /home/jupyter/notebooks/federated/model_inversion_fl/no_ref_multi_inversion/inverted_images/bench
Loading client models...
Loading client model 3/3
Loading global model...
All models loaded successfully

Debugging model dimensions...
Input shape: torch.Size([1, 1, 299, 299])
After features[0] Conv2d: torch.Size([1, 32, 299, 299])
After features[1] ReLU: torch.Size([1, 32, 299, 299])
After features[2] MaxPool2d: torch.Size([1, 32, 149, 149])
After features[3] BatchNorm2d: torch.Size([1, 32, 149, 149])
After features[4] Conv2d: torch.Size([1, 64, 149, 149])
After features[5] ReLU: torch.Size([1, 64, 149, 149])
After features[6] MaxPool2d: torch.Size([1, 64, 74, 74])
After features[7] BatchNorm2d: torch.Size([1, 64, 74, 74])
After features[8] Conv2d: torch.Size([1, 128, 74, 74])
After features[9] ReLU: torch.Size([1, 128, 74, 74])
After features[10] MaxPool2d: torch.Size([1, 128, 37, 37])
After features[11] BatchNorm2d: torch.Size([1, 128, 37, 37])
After features[12