In [3]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from typing import List, Dict, Optional, Union
import os
import torch
from utils import GMM4PR, get_dataset, build_model, build_decoder_from_flag
from typing import Tuple, Dict, Any
# Define CSV files to analyze
# Get all CSV files from the directory using pathlib

# csv_dir = Path('./ckp/gmm_ckp/resnet18_on_cifar10')
# csv_dir = Path('./ckp/gmm_ckp/resnet50_on_cifar10')
csv_dir = Path('./ckp/gmm_ckp/resnet18_on_tinyimagenet')
# csv_dir = Path('./ckp/gmm_ckp/vgg16_on_cifar10')

all_csv_files = sorted([str(f) for f in csv_dir.glob('*.csv')])

print(f"Found {len(all_csv_files)} CSV files in {csv_dir}")

# Inspect first few files
def inspect_csv_columns(csv_files: List[str], last_row_idx: int = -1):
    """
    Inspect and display columns available in CSV files.
    
    Parameters
    ----------
    csv_files : List[str]
        List of CSV file paths to inspect
    """
    print("="*80)
    print("CSV File Column Inspection")
    print("="*80)
    
    for csv_file in csv_files:
        df = pd.read_csv(csv_file)
        print(f"\nFile: {Path(csv_file).name}")
        print(f"  Shape: {df.shape} (rows, columns)")
        print(f"  Columns: {list(df.columns)}")
        print(f"  Last few rows:")
        print(df.iloc[:last_row_idx].tail(3).to_string(index=False))
        print("-" * 80)


print("Function 'inspect_csv_columns' defined successfully!")


def load_gmm_config(gmm_path: str) -> Dict[str, Any]:
    """
    Load the configuration from a saved GMM checkpoint.
    
    Args:
        gmm_path: Path to the GMM checkpoint file (.pt)
        
    Returns:
        Dictionary containing both GMM config and full experiment config
    """
    if not os.path.isfile(gmm_path):
        raise FileNotFoundError(f"GMM checkpoint not found: {gmm_path}")
    
    ckpt = torch.load(gmm_path, map_location="cpu", weights_only=False)
    
    if "config" not in ckpt:
        raise ValueError(f"No config found in checkpoint: {gmm_path}")
    
    gmm_config = ckpt["config"]
    
    # The full experiment config is stored inside gmm_config["config"]
    exp_config = gmm_config.get("config", {})
    
    print(f"✓ Loaded config from: {gmm_path}")
    print(f"  Experiment: {exp_config.get('exp_name', 'N/A')}")
    print(f"  Dataset: {exp_config.get('dataset', 'N/A')}")
    print(f"  Architecture: {exp_config.get('arch', 'N/A')}")
    print(f"  K: {gmm_config['K']}, latent_dim: {gmm_config['latent_dim']}")
    print(f"  Condition mode: {gmm_config.get('cond_mode', 'None')}")
    print(f"  Covariance type: {gmm_config.get('cov_type', 'diag')}")
    print(f"  Perturbation: {gmm_config.get('budget', {})}")
    
    return gmm_config


def load_gmm_model(gmm_path: str, device: str = "cuda") -> Tuple[GMM4PR, Any, Any, Any]:
    """
    Load a complete GMM model from checkpoint with all necessary components.
    
    Args:
        gmm_path: Path to the GMM checkpoint file (.pt)
        device: Device to load the model to ('cuda' or 'cpu')
        
    Returns:
        Tuple of (gmm_model, classifier_model, feat_extractor, up_sampler)
    """
    # Step 1: Load configuration
    gmm_config = load_gmm_config(gmm_path)
    exp_config = gmm_config.get("config", {})
    
    # Extract necessary config values
    dataset_name = exp_config.get("dataset", "cifar10")
    data_root = exp_config.get("data_root", "./dataset")
    arch = exp_config.get("arch", "resnet18")
    clf_ckpt = exp_config.get("clf_ckpt")
    use_decoder = exp_config.get("use_decoder", False)
    decoder_backend = exp_config.get("decoder_backend", "bicubic")
    latent_dim = gmm_config["latent_dim"]
    
    # Step 2: Load dataset to get output shape
    print(f"\n{'='*60}")
    print("Loading dataset...")
    dataset, num_classes, out_shape = get_dataset(dataset_name, data_root, train=True, resize=False)
    print(f"  Dataset: {dataset_name}, Classes: {num_classes}, Shape: {out_shape}")
    
    # Step 3: Load classifier and feature extractor
    print(f"\n{'='*60}")
    print(f"Loading classifier: {arch}")
    print(f"  Device: {device}")
    
    if not clf_ckpt or not os.path.isfile(clf_ckpt):
        raise FileNotFoundError(f"Classifier checkpoint not found: {clf_ckpt}")
    
    model, feat_extractor = build_model(arch, num_classes, device)
    
    # Load classifier weights
    state = torch.load(clf_ckpt, map_location="cpu", weights_only=False)
    state = state.get("state_dict", state.get("model_state", state))
    state = {k.replace("module.", ""): v for k, v in state.items()}
    
    model.load_state_dict(state, strict=False)
    model = model.to(device).eval()
    for p in model.parameters():
        p.requires_grad = False
    
    feat_extractor = feat_extractor.to(device).eval()
    for p in feat_extractor.parameters():
        p.requires_grad = False
    
    # Check parameter sharing
    model_params = {id(p) for p in model.parameters()}
    feat_params = {id(p) for p in feat_extractor.parameters()}
    shared = model_params & feat_params
    
    print(f"  Model params: {len(model_params)}, Feat extractor params: {len(feat_params)}")
    if shared:
        print(f"  Shared parameters: {len(shared)}")
    
    # Step 4: Build decoder/up_sampler if needed
    up_sampler = None
    if use_decoder:
        print(f"\n{'='*60}")
        print(f"Building decoder: {decoder_backend}")
        up_sampler = build_decoder_from_flag(
            decoder_backend,
            latent_dim,
            out_shape,
            device
        )
        print(f"  ✓ Decoder built successfully!")
    else:
        print(f"\n{'='*60}")
        print(f"No decoder used (use_decoder={use_decoder})")
    
    # Step 5: Load the GMM model
    print(f"\n{'='*60}")
    print(f"Loading GMM model from: {gmm_path}")
    gmm = GMM4PR.load_from_checkpoint(
        gmm_path,
        feat_extractor=feat_extractor,
        up_sampler=up_sampler,
        map_location=device,
        strict=True
    )
    
    gmm = gmm.to(device).eval()
    print(f"✓ GMM model loaded successfully!")
    print(f"{'='*60}\n")
    
    return gmm, gmm_config, model, feat_extractor, up_sampler


print("Functions defined: load_gmm_config(), load_gmm_model()")


keywords = ['loss_hist', 'cond(xy)', 'K7', 'decoder(trainable_128)', 'linf(16)' ,'reg(none)']
keywords_entropy = keywords.copy()
keywords_entropy[0] = 'collapse_log'
filtered_csv_files = [f for f in all_csv_files if all(k in f for k in keywords)]
filtered_csv_files_entropy = [f for f in all_csv_files if all(k in f for k in keywords_entropy)]

# Create corresponding GMM model paths from the filtered CSV files
gmm_model_paths = [f.replace('loss_hist_', 'gmm_').replace('.csv', '.pt') for f in filtered_csv_files]
print("\n".join([f"{n}" for n in filtered_csv_files]))
print("\n".join([f"{n}" for n in filtered_csv_files_entropy]))
print("\n".join([f"{n}" for n in gmm_model_paths]))

Found 2 CSV files in ckp/gmm_ckp/resnet18_on_tinyimagenet
Function 'inspect_csv_columns' defined successfully!
Functions defined: load_gmm_config(), load_gmm_model()
ckp/gmm_ckp/resnet18_on_tinyimagenet/loss_hist_K7_cond(xy)_decoder(trainable_128)_linf(16)_reg(none).csv
ckp/gmm_ckp/resnet18_on_tinyimagenet/collapse_log_K7_cond(xy)_decoder(trainable_128)_linf(16)_reg(none).csv
ckp/gmm_ckp/resnet18_on_tinyimagenet/gmm_K7_cond(xy)_decoder(trainable_128)_linf(16)_reg(none).pt


In [5]:
from utils import compute_pr_on_clean_correct
# Specify the GMM checkpoint path
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DATASET = "tinyimagenet"

res_collection = {}
# Load the complete GMM model
for GMM_PATH in gmm_model_paths:
    gmm, cfg, model, feat_extractor, up_sampler = load_gmm_model(GMM_PATH, device=DEVICE)

    test_dataset, num_classes, out_shape = get_dataset(DATASET, "./dataset", train=True, resize=False)
    # Create DataLoader for evaluation
    loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=256,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )

    # Set number of samples per image
    S_SAMPLES = 500  # Number of Monte Carlo samples per image

    # Optional: Specify batch indices to evaluate (None = all batches)
    # BATCH_INDICES = None  # Evaluate all batches
    # BATCH_INDICES = [0, 1, 2, 3, 4]  # Evaluate only first 5 batches
    BATCH_INDICES = range(1000) 

    print(f"\nComputing PR on clean-correct samples...")
    print(f"  GMM: {cfg['config']['exp_name']}")
    print(f"  Samples per image: {S_SAMPLES}")
    print(f"  Batch indices: {'All' if BATCH_INDICES is None else BATCH_INDICES}")

    # Compute PR
    pr, n_used, clean_acc = compute_pr_on_clean_correct(
        model=model,
        gmm=gmm,
        loader=loader,
        out_shape=out_shape,
        device=DEVICE,
        num_samples=S_SAMPLES,
        batch_indices=BATCH_INDICES,
        temperature=1,  # Very low temperature should approximate hard sampling
        use_soft_sampling=True,
        chunk_size=32
    )

    res_collection[GMM_PATH] = pr

    print(f"\n{'='*60}")
    print(f"Results:")
    print(f"  Clean Accuracy: {clean_acc*100:.2f}%")
    print(f"  Probabilistic Robustness (PR): {pr}")
    print(f"  Clean-correct samples used: {n_used}")
    print(f"{'='*60}")

    print("\nAll components loaded and ready for analysis!")


✓ Loaded config from: ckp/gmm_ckp/resnet18_on_tinyimagenet/gmm_K7_cond(xy)_decoder(trainable_128)_linf(16)_reg(none).pt
  Experiment: K7_cond(xy)_decoder(trainable_128)_linf(16)_reg(none)
  Dataset: tinyimagenet
  Architecture: resnet18
  K: 7, latent_dim: 128
  Condition mode: xy
  Covariance type: full
  Perturbation: {'norm': 'linf', 'eps': 0.06274509803921569}

Loading dataset...
  Dataset: tinyimagenet, Classes: 200, Shape: (3, 64, 64)

Loading classifier: resnet18
  Device: cuda
  Model params: 62, Feat extractor params: 60
  Shared parameters: 60

Building decoder: bicubic_trainable
[Decoder 'bicubic'] 18,963 params
  ✓ Decoder built successfully!

Loading GMM model from: ckp/gmm_ckp/resnet18_on_tinyimagenet/gmm_K7_cond(xy)_decoder(trainable_128)_linf(16)_reg(none).pt


  ckpt = torch.load(path, map_location=map_location)


[Params] Shared trunk: 263,680 | pi: 903 | mu: 459,648, | cov: 59,124,224 | Total: 59,584,775
Model loaded from ckp/gmm_ckp/resnet18_on_tinyimagenet/gmm_K7_cond(xy)_decoder(trainable_128)_linf(16)_reg(none).pt
✓ GMM model loaded successfully!


Computing PR on clean-correct samples...
  GMM: K7_cond(xy)_decoder(trainable_128)_linf(16)_reg(none)
  Samples per image: 500
  Batch indices: range(0, 1000)
[PR@clean] used=5068 / seen=10000 (clean acc=50.68%), num_samples=500, chunk_size=32 → PR=0.9296 [method: soft (Gumbel-Softmax)]

Results:
  Clean Accuracy: 50.68%
  Probabilistic Robustness (PR): 0.9296479936462739
  Clean-correct samples used: 5068

All components loaded and ready for analysis!


In [3]:
res_collection

{'ckp/gmm_ckp/vgg16_on_tinyimagenet/gmm_K7_cond(xy)_decoder(trainable_128)_linf(16)_reg(none).pt': 0.7883551615140205}

In [96]:
inspect_csv_columns(filtered_csv_files, last_row_idx=200)

CSV File Column Inspection

File: loss_hist_K12_cond(xy)_decoder(none)_linf(16)_reg(none).csv
  Shape: (50, 6) (rows, columns)
  Columns: ['epoch', 'loss', 'main_loss', 'reg_loss', 'pr', 'learning_rate']
  Last few rows:
 epoch     loss  main_loss  reg_loss       pr  learning_rate
    48 4.566737   4.566737       0.0 0.761502         0.0005
    49 4.560260   4.560260       0.0 0.760193         0.0005
    50 4.554434   4.554434       0.0 0.759272         0.0005
--------------------------------------------------------------------------------

File: loss_hist_K12_cond(xy)_decoder(nontrainable)_linf(16)_reg(none).csv
  Shape: (50, 6) (rows, columns)
  Columns: ['epoch', 'loss', 'main_loss', 'reg_loss', 'pr', 'learning_rate']
  Last few rows:
 epoch     loss  main_loss  reg_loss       pr  learning_rate
    48 5.481014   5.481014       0.0 0.893260         0.0005
    49 5.477815   5.477815       0.0 0.892706         0.0005
    50 5.475565   5.475565       0.0 0.892452         0.0005
--------

In [97]:
inspect_csv_columns(filtered_csv_files_entropy, last_row_idx=5)

CSV File Column Inspection

File: collapse_log_K12_cond(xy)_decoder(none)_linf(16)_reg(none).csv
  Shape: (5, 8) (rows, columns)
  Columns: ['epoch', 'max_pi', 'min_pi', 'std_pi', 'entropy_ratio', 'T_pi', 'T_gumbel', 'avg_loss']
  Last few rows:
 epoch   max_pi   min_pi   std_pi  entropy_ratio  T_pi  T_gumbel  avg_loss
    30 0.199573 0.000024 0.082800       0.761507   1.0  0.467347  4.728826
    40 0.199597 0.000009 0.082816       0.761165   1.0  0.283673  4.633919
    50 0.199605 0.000004 0.082822       0.761042   1.0  0.100000  4.554434
--------------------------------------------------------------------------------

File: collapse_log_K12_cond(xy)_decoder(nontrainable)_linf(16)_reg(none).csv
  Shape: (5, 8) (rows, columns)
  Columns: ['epoch', 'max_pi', 'min_pi', 'std_pi', 'entropy_ratio', 'T_pi', 'T_gumbel', 'avg_loss']
  Last few rows:
 epoch   max_pi   min_pi   std_pi  entropy_ratio  T_pi  T_gumbel  avg_loss
    30 0.302273 0.000058 0.094034       0.738606   1.0  0.467347  5.545

In [16]:
def compute_pr_with_baseline_noise(
    model,
    loader,
    out_shape,
    device,
    distribution='gaussian',
    num_samples=500,
    epsilon=16.0/255,
    norm_type='linf',
    batch_indices=None,
    chunk_size=None,
    clip_to_valid_range=False
):
    """
    Compute PR on clean-correct set using baseline noise distributions (Gaussian or Uniform).
    
    This function samples perturbations directly in the input space from standard distributions
    (no GMM, no decoder needed) and evaluates probabilistic robustness.
    
    PR = E_{(x,y) ∈ CleanCorrect} E_{delta~Noise} [ 1{ f(x+delta) = y } ]
    
    Args:
        model: Classifier model (should be in eval mode)
        loader: DataLoader for evaluation
        out_shape: Image shape (C, H, W)
        device: torch device
        distribution: Type of noise distribution to use:
                     - 'gaussian': Standard Gaussian N(0, I) in input space
                     - 'uniform': Uniform distribution U(-1, 1) in input space
        num_samples: Number of samples per image
        epsilon: Perturbation budget (gamma parameter for g_ball)
        norm_type: Type of norm constraint ('linf' or 'l2')
        batch_indices: Optional set/list of batch indices to evaluate
        chunk_size: Maximum samples to process at once. If None, process all at once.
                   Useful for large num_samples to avoid OOM errors.
        clip_to_valid_range: If True, clip perturbed images to [0, 1]. Default False to match 
                            GMM behavior (which doesn't clip).
    
    Returns:
        (pr, n_used, clean_acc): PR score, number of clean-correct samples, clean accuracy
    """
    from utils import g_ball
    
    C, H, W = out_shape
    
    total_used = 0       # number of clean-correct samples
    pr_sum = 0.0         # sum of per-image PR values
    clean_correct = 0    # number of correctly classified clean samples
    total_seen = 0       # total samples seen
    
    with torch.no_grad():
        for it, (x, y, _) in enumerate(loader):
            if (batch_indices is not None) and (it not in batch_indices):
                continue
            
            x, y = x.to(device), y.to(device)
            B = x.size(0)
            
            # Clean predictions & mask of correct ones
            logits_clean = model(x)
            pred_clean = logits_clean.argmax(1)
            mask = (pred_clean == y)
            
            clean_correct += mask.sum().item()
            total_seen += B
            
            if mask.sum().item() == 0:
                continue  # nothing to evaluate in this batch
            
            x_sel = x[mask]
            y_sel = y[mask]
            n = x_sel.size(0)
            
            # Determine chunk size for processing
            if chunk_size is None:
                # Process all samples at once
                actual_chunk_size = num_samples
            else:
                actual_chunk_size = min(chunk_size, num_samples)
            
            # Accumulate success rate for each image
            per_image_success = torch.zeros(n, device=device)
            
            # Process in chunks to avoid OOM
            num_chunks = (num_samples + actual_chunk_size - 1) // actual_chunk_size
            
            for chunk_idx in range(num_chunks):
                start_idx = chunk_idx * actual_chunk_size
                end_idx = min((chunk_idx + 1) * actual_chunk_size, num_samples)
                S_chunk = end_idx - start_idx
                
                # Sample noise from the specified distribution
                # Shape: [S_chunk, n, C, H, W]
                if distribution == 'gaussian':
                    # Standard Gaussian: N(0, I)
                    noise = torch.randn(S_chunk, n, C, H, W, device=device)
                elif distribution == 'uniform':
                    # Uniform: U(-1, 1)
                    noise = torch.rand(S_chunk, n, C, H, W, device=device) * 2 - 1
                else:
                    raise ValueError(f"Unknown distribution: {distribution}. Choose 'gaussian' or 'uniform'.")
                
                # Reshape for g_ball: [S_chunk * n, C, H, W]
                noise_flat = noise.view(S_chunk * n, C, H, W)
                
                # Apply g_ball to constrain to perturbation budget
                perturbations = g_ball(noise_flat, gamma=epsilon, norm_type=norm_type)
                
                # Create perturbed images (matching GMM pipeline exactly)
                x_rep = x_sel.unsqueeze(0).expand(S_chunk, -1, -1, -1, -1).reshape(S_chunk * n, C, H, W)
                x_perturbed = x_rep + perturbations
                
                # Optionally clip to valid image range [0, 1]
                # Note: GMM method does NOT clip by default, so we match that behavior
                if clip_to_valid_range:
                    x_perturbed = torch.clamp(x_perturbed, 0.0, 1.0)
                
                # Evaluate classifier on perturbed images
                y_rep = y_sel.repeat(S_chunk)
                logits = model(x_perturbed)
                pred = logits.argmax(1)
                
                # Check which predictions match the true label
                # Shape: [S_chunk * n] -> [S_chunk, n]
                correct = (pred == y_rep).float().view(S_chunk, n)
                
                # Accumulate success rate
                per_image_success += correct.sum(dim=0)
            
            # Compute per-image PR (average success rate over all samples)
            per_image_pr = per_image_success / num_samples
            
            # Accumulate
            pr_sum += per_image_pr.sum().item()
            total_used += n
    
    pr = pr_sum / max(1, total_used)
    clean_acc = clean_correct / max(1, total_seen)
    
    chunk_info = f", chunk_size={chunk_size}" if chunk_size is not None else ""
    clip_info = " (clipped)" if clip_to_valid_range else " (no clipping)"
    print(f"[PR@clean - Baseline] used={total_used} / seen={total_seen} "
          f"(clean acc={clean_acc*100:.2f}%), num_samples={num_samples}{chunk_info} → "
          f"PR={pr:.4f} [distribution: {distribution}, norm: {norm_type}, eps={epsilon:.6f}{clip_info}]")
    
    return pr, total_used, clean_acc


print("Function 'compute_pr_with_baseline_noise' defined successfully!")

Function 'compute_pr_with_baseline_noise' defined successfully!


In [None]:
# Example usage of compute_pr_with_baseline_noise
# Uncomment and run to test the function

# Load test dataset and model (reusing from above)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DATASET = "cifar10"

# Load test dataset
test_dataset, num_classes, out_shape = get_dataset(DATASET, "./dataset", train=False, resize=False)

# Create DataLoader for evaluation
loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

# Load a pretrained classifier (use the same one from GMM experiments)
# model should be loaded already from previous cells

# Set parameters
NUM_SAMPLES = 500  # Number of samples per image
EPSILON = 16 / 255  # Same as linf(16) in your GMM experiments
NORM_TYPE = 'linf'  # 'linf' or 'l2'
BATCH_INDICES = range(1000)  # Evaluate first 1000 batches, or None for all

# Example 1: Evaluate with Gaussian noise
print("="*80)
print("Evaluating with Standard Gaussian Noise")
print("="*80)
pr_gaussian, n_used_gaussian, clean_acc_gaussian = compute_pr_with_baseline_noise(
    model=model,
    loader=loader,
    out_shape=out_shape,
    device=DEVICE,
    distribution='gaussian',
    num_samples=NUM_SAMPLES,
    epsilon=EPSILON,
    norm_type=NORM_TYPE,
    batch_indices=BATCH_INDICES,
    chunk_size=8  # Process 8 samples at a time to avoid OOM
)

# Example 2: Evaluate with Uniform noise
print("\n" + "="*80)
print("Evaluating with Uniform Noise")
print("="*80)
pr_uniform, n_used_uniform, clean_acc_uniform = compute_pr_with_baseline_noise(
    model=model,
    loader=loader,
    out_shape=out_shape,
    device=DEVICE,
    distribution='uniform',
    num_samples=NUM_SAMPLES,
    epsilon=EPSILON,
    norm_type=NORM_TYPE,
    batch_indices=BATCH_INDICES,
    chunk_size=32
)

# Compare results
print("\n" + "="*80)
print("Comparison Summary")
print("="*80)
print(f"dataset: {DATASET}, model: {model}, num_samples: {NUM_SAMPLES}, epsilon: {EPSILON}, norm: {NORM_TYPE}")
print(f"Gaussian PR: {pr_gaussian}")
print(f"Uniform PR:  {pr_uniform}")
print(f"GMM PR:      {res_collection[gmm_model_paths[0]]:.4f}")  # From previous evaluation

Files already downloaded and verified
Evaluating with Standard Gaussian Noise
[PR@clean - Baseline] used=9029 / seen=10000 (clean acc=90.29%), num_samples=500, chunk_size=8 → PR=0.9919 [distribution: gaussian, norm: linf, eps=0.062745 (no clipping)]

Evaluating with Uniform Noise
[PR@clean - Baseline] used=9029 / seen=10000 (clean acc=90.29%), num_samples=500, chunk_size=32 → PR=0.9939 [distribution: uniform, norm: linf, eps=0.062745 (no clipping)]

Comparison Summary
dataset: cifar10, model: ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runni

In [None]:
# ============================================================================
# PGD and CW Attack Functions
# ============================================================================

import torch
import torch.nn.functional as F
from tqdm import tqdm


# ----------------------------------------------------------------------------
# Helper Functions
# ----------------------------------------------------------------------------

def accuracy(logits, target):
    """Compute accuracy given logits and target labels."""
    _, pred = torch.max(logits, dim=1)
    correct = (pred == target).sum()
    total = target.size(0)
    acc = (float(correct) / total) * 100
    return acc


# ----------------------------------------------------------------------------
# PGD Attack
# ----------------------------------------------------------------------------

def pgd_attack(model, x, y, epsilon=8/255, step_size=2/255, num_steps=10, random_init=True):
    """
    Perform PGD (Projected Gradient Descent) attack on input images.
    
    Args:
        model: Classifier model
        x: Input images [B, C, H, W]
        y: True labels [B]
        epsilon: Maximum perturbation budget (L-inf norm)
        step_size: Step size for each iteration
        num_steps: Number of PGD iterations
        random_init: Whether to initialize with random noise
    
    Returns:
        x_adv: Adversarial examples [B, C, H, W]
    """
    model.eval()
    
    x_adv = x.detach().clone()
    
    # Random initialization within epsilon ball
    if random_init:
        x_adv = x_adv + torch.empty_like(x_adv).uniform_(-epsilon, epsilon)
        x_adv = torch.clamp(x_adv, min=0, max=1).detach()
    
    # PGD iterations
    for _ in range(num_steps):
        x_adv.requires_grad_(True)
        
        # Forward pass
        logits = model(x_adv)
        loss = F.cross_entropy(logits, y, reduction="mean")
        
        # Backward pass
        grad = torch.autograd.grad(loss, [x_adv])[0]
        
        # Update adversarial example
        x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
        
        # Project back to epsilon ball
        delta = torch.clamp(x_adv - x, min=-epsilon, max=epsilon)
        x_adv = torch.clamp(x + delta, min=0, max=1).detach()
    
    return x_adv


def evaluate_pgd(model, loader, device, epsilon=8/255, step_size=2/255, num_steps=10, 
                 batch_indices=None, verbose=True):
    """
    Evaluate model robustness under PGD attack on clean-correct samples.
    
    This function:
    1. Identifies correctly classified samples (clean-correct set)
    2. Performs PGD attack only on those samples
    3. Reports attack success rate and robust accuracy on clean-correct set
    
    Args:
        model: Classifier model
        loader: DataLoader for evaluation
        device: torch device
        epsilon: Maximum perturbation budget (L-inf norm)
        step_size: Step size for each PGD iteration
        num_steps: Number of PGD iterations
        batch_indices: Optional indices of batches to evaluate
        verbose: Whether to show progress bar
    
    Returns:
        (clean_acc, robust_acc, attack_success_rate): 
            - clean_acc: Accuracy on all samples
            - robust_acc: Accuracy on clean-correct samples after attack
            - attack_success_rate: Percentage of clean-correct samples misclassified by attack
    """
    model.eval()
    
    total_samples = 0
    clean_correct_count = 0
    robust_correct_count = 0
    
    iterator = enumerate(loader)
    if verbose:
        iterator = tqdm(iterator, total=len(loader), desc='PGD Attack Evaluation')
    
    for i, batch in iterator:
        if (batch_indices is not None) and (i not in batch_indices):
            continue
        
        # Handle both 2-tuple and 3-tuple batch formats
        if len(batch) == 3:
            x, y, _ = batch
        else:
            x, y = batch
        
        x, y = x.to(device), y.to(device)
        B = x.size(0)
        total_samples += B
        
        # Step 1: Identify clean-correct samples
        with torch.no_grad():
            clean_logits = model(x)
            clean_pred = clean_logits.argmax(1)
            clean_correct_mask = (clean_pred == y)
        
        clean_correct_count += clean_correct_mask.sum().item()
        
        # Skip if no correct predictions in this batch
        if clean_correct_mask.sum().item() == 0:
            continue
        
        # Step 2: Select only clean-correct samples
        x_correct = x[clean_correct_mask]
        y_correct = y[clean_correct_mask]
        
        # Step 3: Perform PGD attack only on clean-correct samples
        x_adv = pgd_attack(model, x_correct, y_correct, epsilon=epsilon, 
                          step_size=step_size, num_steps=num_steps, random_init=True)
        
        # Step 4: Evaluate robustness on attacked samples
        with torch.no_grad():
            adv_logits = model(x_adv)
            adv_pred = adv_logits.argmax(1)
            robust_correct = (adv_pred == y_correct).sum().item()
            robust_correct_count += robust_correct
    
    # Compute metrics
    clean_acc = 100.0 * clean_correct_count / total_samples
    robust_acc = 100.0 * robust_correct_count / max(1, clean_correct_count)
    attack_success_rate = 100.0 - robust_acc
    
    print(f"\n[PGD Attack Results]")
    print(f"  Epsilon: {epsilon:.4f} ({epsilon*255:.1f}/255)")
    print(f"  Step size: {step_size:.4f} ({step_size*255:.1f}/255)")
    print(f"  Num steps: {num_steps}")
    print(f"  " + "="*60)
    print(f"  Total samples: {total_samples}")
    print(f"  Clean-correct samples: {clean_correct_count} ({clean_acc:.2f}%)")
    print(f"  " + "="*60)
    print(f"  Robust accuracy (on clean-correct): {robust_acc:.2f}%")
    print(f"  Attack success rate: {attack_success_rate:.2f}%")
    print(f"  " + "="*60)
    
    return clean_acc, robust_acc, attack_success_rate


# ----------------------------------------------------------------------------
# CW (Carlini-Wagner) Attack
# ----------------------------------------------------------------------------

def cw_loss(logits, targets, num_classes, margin=2.0, targeted=False):
    """
    Carlini-Wagner loss function.
    
    Args:
        logits: Model outputs [B, num_classes]
        targets: True labels [B]
        num_classes: Number of classes
        margin: Confidence margin (kappa in the CW paper)
        targeted: Whether this is a targeted attack
    
    Returns:
        loss: CW loss (scalar)
    """
    onehot_targets = F.one_hot(targets, num_classes).float().to(logits.device)
    
    # Score of the correct class
    self_loss = torch.sum(onehot_targets * logits, dim=1)
    
    # Maximum score among other classes
    other_loss = torch.max(
        (1 - onehot_targets) * logits - onehot_targets * 10000, dim=1
    )[0]
    
    # CW loss: make correct class score lower than max of other classes
    if targeted:
        loss = torch.sum(torch.clamp(other_loss - self_loss + margin, min=0))
    else:
        loss = -torch.sum(torch.clamp(self_loss - other_loss + margin, min=0))
    
    # Average over batch
    loss = loss / logits.size(0)
    
    return loss


def cw_attack(model, x, y, num_classes, epsilon=8/255, step_size=2/255, num_steps=10, 
              margin=2.0, random_init=True):
    """
    Perform CW (Carlini-Wagner) attack on input images.
    
    Args:
        model: Classifier model
        x: Input images [B, C, H, W]
        y: True labels [B]
        num_classes: Number of classes
        epsilon: Maximum perturbation budget (L-inf norm)
        step_size: Step size for each iteration
        num_steps: Number of iterations
        margin: Confidence margin (kappa)
        random_init: Whether to initialize with random noise
    
    Returns:
        x_adv: Adversarial examples [B, C, H, W]
    """
    model.eval()
    
    x_adv = x.detach().clone()
    
    # Random initialization within epsilon ball
    if random_init:
        x_adv = x_adv + torch.empty_like(x_adv).uniform_(-epsilon, epsilon)
        x_adv = torch.clamp(x_adv, min=0, max=1).detach()
    
    # CW attack iterations
    for _ in range(num_steps):
        x_adv.requires_grad_(True)
        
        # Forward pass
        logits = model(x_adv)
        loss = cw_loss(logits, y, num_classes, margin=margin, targeted=False)
        
        # Backward pass
        grad = torch.autograd.grad(loss, [x_adv])[0]
        
        # Update adversarial example (gradient ascent to maximize loss)
        x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
        
        # Project back to epsilon ball
        delta = torch.clamp(x_adv - x, min=-epsilon, max=epsilon)
        x_adv = torch.clamp(x + delta, min=0, max=1).detach()
    
    return x_adv


def evaluate_cw(model, loader, device, num_classes, epsilon=8/255, step_size=2/255, 
                num_steps=10, margin=2.0, batch_indices=None, verbose=True):
    """
    Evaluate model robustness under CW attack on clean-correct samples.
    
    This function:
    1. Identifies correctly classified samples (clean-correct set)
    2. Performs CW attack only on those samples
    3. Reports attack success rate and robust accuracy on clean-correct set
    
    Args:
        model: Classifier model
        loader: DataLoader for evaluation
        device: torch device
        num_classes: Number of classes
        epsilon: Maximum perturbation budget (L-inf norm)
        step_size: Step size for each iteration
        num_steps: Number of iterations
        margin: Confidence margin (kappa)
        batch_indices: Optional indices of batches to evaluate
        verbose: Whether to show progress bar
    
    Returns:
        (clean_acc, robust_acc, attack_success_rate): 
            - clean_acc: Accuracy on all samples
            - robust_acc: Accuracy on clean-correct samples after attack
            - attack_success_rate: Percentage of clean-correct samples misclassified by attack
    """
    model.eval()
    
    total_samples = 0
    clean_correct_count = 0
    robust_correct_count = 0
    
    iterator = enumerate(loader)
    if verbose:
        iterator = tqdm(iterator, total=len(loader), desc='CW Attack Evaluation')
    
    for i, batch in iterator:
        if (batch_indices is not None) and (i not in batch_indices):
            continue
        
        # Handle both 2-tuple and 3-tuple batch formats
        if len(batch) == 3:
            x, y, _ = batch
        else:
            x, y = batch
        
        x, y = x.to(device), y.to(device)
        B = x.size(0)
        total_samples += B
        
        # Step 1: Identify clean-correct samples
        with torch.no_grad():
            clean_logits = model(x)
            clean_pred = clean_logits.argmax(1)
            clean_correct_mask = (clean_pred == y)
        
        clean_correct_count += clean_correct_mask.sum().item()
        
        # Skip if no correct predictions in this batch
        if clean_correct_mask.sum().item() == 0:
            continue
        
        # Step 2: Select only clean-correct samples
        x_correct = x[clean_correct_mask]
        y_correct = y[clean_correct_mask]
        
        # Step 3: Perform CW attack only on clean-correct samples
        x_adv = cw_attack(model, x_correct, y_correct, num_classes, epsilon=epsilon, 
                         step_size=step_size, num_steps=num_steps, 
                         margin=margin, random_init=True)
        
        # Step 4: Evaluate robustness on attacked samples
        with torch.no_grad():
            adv_logits = model(x_adv)
            adv_pred = adv_logits.argmax(1)
            robust_correct = (adv_pred == y_correct).sum().item()
            robust_correct_count += robust_correct
    
    # Compute metrics
    clean_acc = 100.0 * clean_correct_count / total_samples
    robust_acc = 100.0 * robust_correct_count / max(1, clean_correct_count)
    attack_success_rate = 100.0 - robust_acc
    
    print(f"\n[CW Attack Results]")
    print(f"  Epsilon: {epsilon:.4f} ({epsilon*255:.1f}/255)")
    print(f"  Step size: {step_size:.4f} ({step_size*255:.1f}/255)")
    print(f"  Num steps: {num_steps}")
    print(f"  Margin (kappa): {margin}")
    print(f"  " + "="*60)
    print(f"  Total samples: {total_samples}")
    print(f"  Clean-correct samples: {clean_correct_count} ({clean_acc:.2f}%)")
    print(f"  " + "="*60)
    print(f"  Robust accuracy (on clean-correct): {robust_acc:.2f}%")
    print(f"  Attack success rate: {attack_success_rate:.2f}%")
    print(f"  " + "="*60)
    
    return clean_acc, robust_acc, attack_success_rate


print("Attack functions defined successfully!")
print("  - pgd_attack(): Generate PGD adversarial examples")
print("  - evaluate_pgd(): Evaluate model under PGD attack (on clean-correct)")
print("  - cw_attack(): Generate CW adversarial examples")
print("  - evaluate_cw(): Evaluate model under CW attack (on clean-correct)")

In [None]:
# ============================================================================
# Example Usage: PGD and CW Attacks
# ============================================================================
# Uncomment and run to evaluate attacks

# # Assuming model and loader are already defined from previous cells
# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# DATASET = "cifar10"  # or "tinyimagenet"
# 
# # Load test dataset
# test_dataset, num_classes, out_shape = get_dataset(DATASET, "./dataset", train=False, resize=False)
# 
# # Create DataLoader
# test_loader = torch.utils.data.DataLoader(
#     test_dataset,
#     batch_size=128,
#     shuffle=False,
#     num_workers=2,
#     pin_memory=True
# )
# 
# # Model should be loaded from previous cells
# # model = ...  # Your trained classifier
# 
# # Attack parameters
# EPSILON = 8 / 255  # 8/255 for typical attacks, or 16/255 to match GMM
# STEP_SIZE = 2 / 255  # Step size for iterative attacks
# NUM_STEPS = 20  # Number of attack iterations
# BATCH_INDICES = range(100)  # Evaluate first 100 batches, or None for all
# 
# print("="*80)
# print("Evaluating PGD Attack on Clean-Correct Samples")
# print("="*80)
# 
# # Evaluate PGD attack (only attacks clean-correct samples)
# clean_acc_pgd, robust_acc_pgd, attack_success_pgd = evaluate_pgd(
#     model=model,
#     loader=test_loader,
#     device=DEVICE,
#     epsilon=EPSILON,
#     step_size=STEP_SIZE,
#     num_steps=NUM_STEPS,
#     batch_indices=BATCH_INDICES,
#     verbose=True
# )
# 
# print("\n" + "="*80)
# print("Evaluating CW Attack on Clean-Correct Samples")
# print("="*80)
# 
# # Evaluate CW attack (only attacks clean-correct samples)
# clean_acc_cw, robust_acc_cw, attack_success_cw = evaluate_cw(
#     model=model,
#     loader=test_loader,
#     device=DEVICE,
#     num_classes=num_classes,
#     epsilon=EPSILON,
#     step_size=STEP_SIZE,
#     num_steps=NUM_STEPS,
#     margin=2.0,  # CW margin parameter (kappa)
#     batch_indices=BATCH_INDICES,
#     verbose=True
# )
# 
# # Compare results
# print("\n" + "="*80)
# print("Attack Comparison Summary (on Clean-Correct Set)")
# print("="*80)
# print(f"Dataset: {DATASET}")
# print(f"Epsilon: {EPSILON:.4f} ({EPSILON*255:.1f}/255)")
# print(f"Num Steps: {NUM_STEPS}")
# print(f"\nClean Accuracy: {clean_acc_pgd:.2f}%")
# print(f"\nRobust Accuracy:")
# print(f"  PGD:  {robust_acc_pgd:.2f}%")
# print(f"  CW:   {robust_acc_cw:.2f}%")
# print(f"\nAttack Success Rate:")
# print(f"  PGD:  {attack_success_pgd:.2f}%")
# print(f"  CW:   {attack_success_cw:.2f}%")
# 
# # You can also generate individual adversarial examples
# # Example: Generate PGD adversarial examples for a single batch
# # x, y, _ = next(iter(test_loader))
# # x, y = x.to(DEVICE), y.to(DEVICE)
# # 
# # # Only attack clean-correct samples
# # with torch.no_grad():
# #     clean_pred = model(x).argmax(1)
# #     correct_mask = (clean_pred == y)
# # 
# # x_correct = x[correct_mask]
# # y_correct = y[correct_mask]
# # 
# # x_adv_pgd = pgd_attack(model, x_correct, y_correct, epsilon=EPSILON, step_size=STEP_SIZE, num_steps=NUM_STEPS)
# # x_adv_cw = cw_attack(model, x_correct, y_correct, num_classes=num_classes, epsilon=EPSILON, step_size=STEP_SIZE, num_steps=NUM_STEPS)