In [2]:
!pip install -q torch torchvision matplotlib seaborn scikit-learn

print("✓ All packages installed!")
print("\nVerifying GPU availability...")
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")



✓ All packages installed!

Verifying GPU availability...
PyTorch version: 2.9.0+cu126
CUDA available: True
GPU: Tesla T4
GPU Memory: 15.83 GB


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import griddata
from scipy.stats import pearsonr
import copy
from tqdm.auto import tqdm
import pickle
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Configure matplotlib
plt.style.use('default')
plt.rcParams['figure.dpi'] = 100
plt.rcParams['savefig.dpi'] = 300
plt.rcParams['font.size'] = 10

print("✓ Libraries imported successfully!")

✓ Libraries imported successfully!


In [4]:
class VanillaCNN(nn.Module):
    """Simple CNN for comparison"""
    def __init__(self, num_classes=10):
        super(VanillaCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(256 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(x.size(0), -1)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.fc2(x)
        return x


class ResidualBlock(nn.Module):
    """Residual block with skip connection"""
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class SimpleResNet(nn.Module):
    """Simple ResNet for CIFAR-10"""
    def __init__(self, num_classes=10):
        super(SimpleResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)

        self.layer1 = self._make_layer(64, 64, 2, stride=1)
        self.layer2 = self._make_layer(64, 128, 2, stride=2)
        self.layer3 = self._make_layer(128, 256, 2, stride=2)

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, num_classes)

    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        layers = []
        layers.append(ResidualBlock(in_channels, out_channels, stride))
        for _ in range(1, num_blocks):
            layers.append(ResidualBlock(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


print("✓ Model architectures defined!")
print(f"  - Vanilla CNN: {sum(p.numel() for p in VanillaCNN().parameters())/1e6:.2f}M parameters")
print(f"  - ResNet: {sum(p.numel() for p in SimpleResNet().parameters())/1e6:.2f}M parameters")



✓ Model architectures defined!
  - Vanilla CNN: 2.47M parameters
  - ResNet: 2.78M parameters


In [5]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import griddata
from scipy.stats import pearsonr
import copy
from tqdm import tqdm
import pickle

# ============================================================================
# EFFICIENT LANDSCAPE PROBING METHODS
# ============================================================================

class LandscapeAnalyzer:
    """Core class for analyzing loss landscape geometry"""

    def __init__(self, model, criterion, device='cuda'):
        self.model = model
        self.criterion = criterion
        self.device = device

    def compute_loss_and_gradient(self, dataloader, subset_size=None):
        """Compute loss and gradient on dataset"""
        self.model.eval()
        total_loss = 0.0
        gradients = {name: torch.zeros_like(param)
                    for name, param in self.model.named_parameters()}

        count = 0
        for batch_idx, (inputs, targets) in enumerate(dataloader):
            if subset_size and batch_idx >= subset_size:
                break

            inputs, targets = inputs.to(self.device), targets.to(self.device)

            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)
            total_loss += loss.item() * inputs.size(0)

            # Compute gradients
            self.model.zero_grad()
            loss.backward()

            for name, param in self.model.named_parameters():
                if param.grad is not None:
                    gradients[name] += param.grad.data

            count += inputs.size(0)

        avg_loss = total_loss / count
        for name in gradients:
            gradients[name] /= len(dataloader)

        return avg_loss, gradients

    def lanczos_hessian_spectrum(self, dataloader, num_eigenvalues=20,
                                 num_iterations=100, subset_size=10):
        """
        Compute top eigenvalues of Hessian using Lanczos algorithm
        Memory-efficient: doesn't materialize full Hessian
        """
        print(f"Computing Hessian spectrum using Lanczos (k={num_eigenvalues})...")

        # Get parameter vector
        params = [p for p in self.model.parameters() if p.requires_grad]
        param_shapes = [p.shape for p in params]

        def flatten_params(params_list):
            return torch.cat([p.view(-1) for p in params_list])

        def unflatten_to_params(flat_vec):
            result = []
            offset = 0
            for shape in param_shapes:
                numel = np.prod(shape)
                result.append(flat_vec[offset:offset+numel].view(shape))
                offset += numel
            return result

        def hessian_vector_product(vector):
            """Compute Hv using finite differences of gradients"""
            self.model.zero_grad()

            # Compute gradient at current point
            _, grad1 = self.compute_loss_and_gradient(dataloader, subset_size)
            grad1_vec = flatten_params([grad1[name] for name, _ in self.model.named_parameters()])

            # Small perturbation
            epsilon = 1e-3
            vector_as_params = unflatten_to_params(vector)

            # Temporarily perturb parameters
            with torch.no_grad():
                for param, delta in zip(params, vector_as_params):
                    param.add_(delta, alpha=epsilon)

            # Compute gradient at perturbed point
            _, grad2 = self.compute_loss_and_gradient(dataloader, subset_size)
            grad2_vec = flatten_params([grad2[name] for name, _ in self.model.named_parameters()])

            # Restore parameters
            with torch.no_grad():
                for param, delta in zip(params, vector_as_params):
                    param.add_(delta, alpha=-epsilon)

            # Hv ≈ (g(θ+εv) - g(θ)) / ε
            hv = (grad2_vec - grad1_vec) / epsilon
            return hv

        # Lanczos iteration
        param_vec = flatten_params(params)
        n = param_vec.numel()

        # Random starting vector
        v = torch.randn_like(param_vec)
        v = v / torch.norm(v)

        # Tridiagonal matrix
        alpha = []
        beta = [0]
        V = [v]

        for i in range(min(num_eigenvalues, num_iterations)):
            # Compute Hv
            w = hessian_vector_product(V[-1])

            # Orthogonalize
            w = w - beta[-1] * (V[-2] if len(V) > 1 else 0)
            alpha.append(torch.dot(w, V[-1]).item())
            w = w - alpha[-1] * V[-1]

            beta_new = torch.norm(w).item()
            beta.append(beta_new)

            if beta_new < 1e-10:
                break

            v_new = w / beta_new
            V.append(v_new)

            if (i + 1) % 10 == 0:
                print(f"  Lanczos iteration {i+1}/{num_eigenvalues}")

        # Build tridiagonal matrix and compute eigenvalues
        T = np.diag(alpha) + np.diag(beta[1:-1], k=1) + np.diag(beta[1:-1], k=-1)
        eigenvalues = np.linalg.eigvalsh(T)
        eigenvalues = sorted(eigenvalues, reverse=True)

        return eigenvalues

    def compute_sharpness(self, dataloader, rho=0.05, subset_size=10):
        """
        Compute sharpness: max loss in ρ-ball around current parameters
        Uses SAM-style ascent to find adversarial direction
        """
        print(f"Computing sharpness with ρ={rho}...")

        # Current loss
        loss_current, _ = self.compute_loss_and_gradient(dataloader, subset_size)

        # Compute adversarial perturbation (gradient ascent direction)
        _, gradients = self.compute_loss_and_gradient(dataloader, subset_size)

        # Normalize gradient to get direction
        grad_norm = 0.0
        for name, param in self.model.named_parameters():
            if name in gradients:
                grad_norm += torch.sum(gradients[name] ** 2).item()
        grad_norm = np.sqrt(grad_norm)

        # Store original parameters
        original_params = {name: param.clone()
                          for name, param in self.model.named_parameters()}

        # Perturb in adversarial direction
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                if name in gradients and grad_norm > 0:
                    perturbation = (rho / grad_norm) * gradients[name]
                    param.add_(perturbation)

        # Compute loss at perturbed point
        loss_perturbed, _ = self.compute_loss_and_gradient(dataloader, subset_size)

        # Restore original parameters
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                param.copy_(original_params[name])

        sharpness = loss_perturbed - loss_current
        return sharpness

    def mode_connectivity(self, model2, dataloader, num_points=20):
        """
        Analyze mode connectivity between self.model and model2
        Returns: losses along linear interpolation path
        """
        print(f"Analyzing mode connectivity ({num_points} points)...")

        alphas = np.linspace(0, 1, num_points)
        losses = []
        accuracies = []

        # Store original parameters
        original_params = {name: param.clone()
                          for name, param in self.model.named_parameters()}
        model2_params = {name: param.clone()
                        for name, param in model2.named_parameters()}

        for alpha in tqdm(alphas, desc="Interpolation"):
            # Interpolate parameters: θ(α) = (1-α)θ₁ + αθ₂
            with torch.no_grad():
                for name, param in self.model.named_parameters():
                    if name in model2_params:
                        param.copy_((1 - alpha) * original_params[name] +
                                   alpha * model2_params[name])

            # Evaluate
            self.model.eval()
            total_loss = 0.0
            correct = 0
            total = 0

            with torch.no_grad():
                for inputs, targets in dataloader:
                    inputs, targets = inputs.to(self.device), targets.to(self.device)
                    outputs = self.model(inputs)
                    loss = self.criterion(outputs, targets)
                    total_loss += loss.item() * inputs.size(0)

                    _, predicted = outputs.max(1)
                    total += targets.size(0)
                    correct += predicted.eq(targets).sum().item()

            losses.append(total_loss / total)
            accuracies.append(100. * correct / total)

        # Restore original parameters
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                param.copy_(original_params[name])

        return alphas, losses, accuracies

    def visualize_loss_surface_2d(self, dataloader, direction1=None, direction2=None,
                                  range_=1.0, resolution=20, subset_size=5):
        """
        Visualize 2D slice of loss surface around current parameters
        If directions not provided, uses random directions
        """
        print(f"Computing 2D loss surface ({resolution}x{resolution} grid)...")

        # Get random directions if not provided
        if direction1 is None or direction2 is None:
            params = [p for p in self.model.parameters() if p.requires_grad]
            direction1 = [torch.randn_like(p) for p in params]
            direction2 = [torch.randn_like(p) for p in params]

            # Normalize directions
            norm1 = np.sqrt(sum([torch.sum(d**2).item() for d in direction1]))
            norm2 = np.sqrt(sum([torch.sum(d**2).item() for d in direction2]))
            direction1 = [d / norm1 for d in direction1]
            direction2 = [d / norm2 for d in direction2]

            # Gram-Schmidt orthogonalization
            dot_product = sum([torch.sum(d1 * d2).item()
                             for d1, d2 in zip(direction1, direction2)])
            direction2 = [d2 - dot_product * d1
                         for d1, d2 in zip(direction1, direction2)]
            norm2 = np.sqrt(sum([torch.sum(d**2).item() for d in direction2]))
            direction2 = [d / norm2 for d in direction2]

        # Store original parameters
        original_params = [p.clone() for p in self.model.parameters() if p.requires_grad]
        params = [p for p in self.model.parameters() if p.requires_grad]

        # Create grid
        alphas = np.linspace(-range_, range_, resolution)
        betas = np.linspace(-range_, range_, resolution)

        losses = np.zeros((resolution, resolution))

        for i, alpha in enumerate(tqdm(alphas, desc="Computing surface")):
            for j, beta in enumerate(betas):
                # Set parameters: θ = θ₀ + α*d₁ + β*d₂
                with torch.no_grad():
                    for k, param in enumerate(params):
                        param.copy_(original_params[k] +
                                   alpha * direction1[k] +
                                   beta * direction2[k])

                # Compute loss
                loss, _ = self.compute_loss_and_gradient(dataloader, subset_size)
                losses[i, j] = loss

        # Restore original parameters
        with torch.no_grad():
            for k, param in enumerate(params):
                param.copy_(original_params[k])

        return alphas, betas, losses



class VanillaCNN(nn.Module):
    """Simple CNN for comparison"""
    def __init__(self, num_classes=10):
        super(VanillaCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(256 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(x.size(0), -1)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.fc2(x)
        return x


class ResidualBlock(nn.Module):
    """Residual block with skip connection"""
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class SimpleResNet(nn.Module):
    """Simple ResNet for CIFAR-10"""
    def __init__(self, num_classes=10):
        super(SimpleResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)

        self.layer1 = self._make_layer(64, 64, 2, stride=1)
        self.layer2 = self._make_layer(64, 128, 2, stride=2)
        self.layer3 = self._make_layer(128, 256, 2, stride=2)

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, num_classes)

    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        layers = []
        layers.append(ResidualBlock(in_channels, out_channels, stride))
        for _ in range(1, num_blocks):
            layers.append(ResidualBlock(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x



def train_model(model, trainloader, testloader, device, epochs=20, lr=0.01):
    """Train a model and return training history"""
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    history = {
        'train_loss': [], 'train_acc': [],
        'test_loss': [], 'test_acc': []
    }

    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0
        correct = 0
        total = 0

        pbar = tqdm(trainloader, desc=f'Epoch {epoch+1}/{epochs}')
        for inputs, targets in pbar:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            pbar.set_postfix({'loss': train_loss/total, 'acc': 100.*correct/total})

        history['train_loss'].append(train_loss / total)
        history['train_acc'].append(100. * correct / total)

        # Testing
        model.eval()
        test_loss = 0
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, targets in testloader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                test_loss += loss.item() * inputs.size(0)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        history['test_loss'].append(test_loss / total)
        history['test_acc'].append(100. * correct / total)

        print(f'Epoch {epoch+1}: Test Loss: {test_loss/total:.4f}, Test Acc: {100.*correct/total:.2f}%')

        scheduler.step()

    return history


def evaluate_model(model, dataloader, device):
    """Evaluate model accuracy and loss"""
    model.eval()
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            total_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    return total_loss / total, 100. * correct / total



def run_complete_analysis():
    """
    Main function to run complete landscape analysis

    This will:
    1. Train models with different architectures
    2. Compute landscape metrics (sharpness, Hessian spectrum, etc.)
    3. Analyze mode connectivity
    4. Visualize loss surfaces
    5. Correlate geometry with generalization
    """

    print("="*80)
    print("LOSS LANDSCAPE GEOMETRY & OPTIMIZATION DYNAMICS")
    print("Complete Analysis Framework")
    print("="*80)

    # Setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nUsing device: {device}")

    # Data loading
    print("\n[1/7] Loading CIFAR-10 dataset...")
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                           download=True, transform=transform_train)
    trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                          download=True, transform=transform_test)
    testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

    # For landscape analysis (smaller subset for efficiency)
    subset_indices = np.random.choice(len(testset), 1000, replace=False)
    subset = torch.utils.data.Subset(testset, subset_indices)
    subsetloader = DataLoader(subset, batch_size=100, shuffle=False, num_workers=2)

    # Train models
    print("\n[2/7] Training models...")

    print("\n  Training Vanilla CNN...")
    vanilla_model = VanillaCNN(num_classes=10)
    vanilla_history = train_model(vanilla_model, trainloader, testloader, device, epochs=20)

    print("\n  Training ResNet...")
    resnet_model = SimpleResNet(num_classes=10)
    resnet_history = train_model(resnet_model, trainloader, testloader, device, epochs=20)

    # Train a second instance for mode connectivity
    print("\n  Training second ResNet instance for mode connectivity...")
    resnet_model2 = SimpleResNet(num_classes=10)
    resnet_history2 = train_model(resnet_model2, trainloader, testloader, device, epochs=20)

    # Landscape analysis
    results = {}

    print("\n[3/7] Computing landscape metrics for Vanilla CNN...")
    vanilla_analyzer = LandscapeAnalyzer(vanilla_model, nn.CrossEntropyLoss(), device)

    vanilla_sharpness = vanilla_analyzer.compute_sharpness(subsetloader, rho=0.05)
    vanilla_eigenvalues = vanilla_analyzer.lanczos_hessian_spectrum(subsetloader, num_eigenvalues=20)
    vanilla_loss, vanilla_acc = evaluate_model(vanilla_model, testloader, device)

    results['vanilla'] = {
        'sharpness': vanilla_sharpness,
        'eigenvalues': vanilla_eigenvalues,
        'test_loss': vanilla_loss,
        'test_acc': vanilla_acc,
        'history': vanilla_history
    }

    print("\n[4/7] Computing landscape metrics for ResNet...")
    resnet_analyzer = LandscapeAnalyzer(resnet_model, nn.CrossEntropyLoss(), device)

    resnet_sharpness = resnet_analyzer.compute_sharpness(subsetloader, rho=0.05)
    resnet_eigenvalues = resnet_analyzer.lanczos_hessian_spectrum(subsetloader, num_eigenvalues=20)
    resnet_loss, resnet_acc = evaluate_model(resnet_model, testloader, device)

    results['resnet'] = {
        'sharpness': resnet_sharpness,
        'eigenvalues': resnet_eigenvalues,
        'test_loss': resnet_loss,
        'test_acc': resnet_acc,
        'history': resnet_history
    }

    print("\n[5/7] Analyzing mode connectivity...")
    alphas, losses, accuracies = resnet_analyzer.mode_connectivity(
        resnet_model2, testloader, num_points=15
    )

    results['mode_connectivity'] = {
        'alphas': alphas,
        'losses': losses,
        'accuracies': accuracies
    }

    print("\n[6/7] Computing 2D loss surface visualizations...")
    vanilla_x, vanilla_y, vanilla_surface = vanilla_analyzer.visualize_loss_surface_2d(
        subsetloader, range_=0.5, resolution=15
    )

    resnet_x, resnet_y, resnet_surface = resnet_analyzer.visualize_loss_surface_2d(
        subsetloader, range_=0.5, resolution=15
    )

    results['vanilla']['surface'] = (vanilla_x, vanilla_y, vanilla_surface)
    results['resnet']['surface'] = (resnet_x, resnet_y, resnet_surface)

    print("\n[7/7] Saving results...")
    with open('landscape_analysis_results.pkl', 'wb') as f:
        pickle.dump(results, f)

    # Print summary
    print("\n" + "="*80)
    print("ANALYSIS SUMMARY")
    print("="*80)
    print(f"\nVanilla CNN:")
    print(f"  Test Accuracy: {vanilla_acc:.2f}%")
    print(f"  Test Loss: {vanilla_loss:.4f}")
    print(f"  Sharpness (ρ=0.05): {vanilla_sharpness:.6f}")
    print(f"  Max Hessian eigenvalue: {vanilla_eigenvalues[0]:.4f}")
    print(f"  Min Hessian eigenvalue: {vanilla_eigenvalues[-1]:.4f}")

    print(f"\nResNet:")
    print(f"  Test Accuracy: {resnet_acc:.2f}%")
    print(f"  Test Loss: {resnet_loss:.4f}")
    print(f"  Sharpness (ρ=0.05): {resnet_sharpness:.6f}")
    print(f"  Max Hessian eigenvalue: {resnet_eigenvalues[0]:.4f}")
    print(f"  Min Hessian eigenvalue: {resnet_eigenvalues[-1]:.4f}")

    print(f"\nMode Connectivity:")
    print(f"  Max barrier height: {max(losses) - min(losses):.4f}")
    print(f"  Min accuracy along path: {min(accuracies):.2f}%")

    return results


if __name__ == "__main__":

    results = run_complete_analysis()

    print("\n" + "="*80)
    print("Analysis complete! Results saved to 'landscape_analysis_results.pkl'")
    print("Run the visualization script next to generate plots.")
    print("="*80)

LOSS LANDSCAPE GEOMETRY & OPTIMIZATION DYNAMICS
Complete Analysis Framework

Using device: cuda

[1/7] Loading CIFAR-10 dataset...


100%|██████████| 170M/170M [09:37<00:00, 295kB/s]



[2/7] Training models...

  Training Vanilla CNN...


Epoch 1/20: 100%|██████████| 391/391 [00:19<00:00, 19.98it/s, loss=1.82, acc=33]


Epoch 1: Test Loss: 1.4353, Test Acc: 47.67%


Epoch 2/20: 100%|██████████| 391/391 [00:19<00:00, 20.47it/s, loss=1.43, acc=47.7]


Epoch 2: Test Loss: 1.2193, Test Acc: 55.29%


Epoch 3/20: 100%|██████████| 391/391 [00:19<00:00, 19.62it/s, loss=1.22, acc=56.2]


Epoch 3: Test Loss: 1.0408, Test Acc: 63.17%


Epoch 4/20: 100%|██████████| 391/391 [00:20<00:00, 19.09it/s, loss=1.06, acc=62.3]


Epoch 4: Test Loss: 0.8945, Test Acc: 68.58%


Epoch 5/20: 100%|██████████| 391/391 [00:18<00:00, 20.96it/s, loss=0.956, acc=66.3]


Epoch 5: Test Loss: 0.8116, Test Acc: 71.88%


Epoch 6/20: 100%|██████████| 391/391 [00:19<00:00, 20.07it/s, loss=0.879, acc=69.3]


Epoch 6: Test Loss: 0.7674, Test Acc: 73.28%


Epoch 7/20: 100%|██████████| 391/391 [00:18<00:00, 21.36it/s, loss=0.822, acc=71.4]


Epoch 7: Test Loss: 0.7283, Test Acc: 74.80%


Epoch 8/20: 100%|██████████| 391/391 [00:18<00:00, 20.96it/s, loss=0.76, acc=73.2]


Epoch 8: Test Loss: 0.7124, Test Acc: 75.18%


Epoch 9/20: 100%|██████████| 391/391 [00:19<00:00, 19.74it/s, loss=0.724, acc=74.6]


Epoch 9: Test Loss: 0.6742, Test Acc: 76.72%


Epoch 10/20: 100%|██████████| 391/391 [00:18<00:00, 21.08it/s, loss=0.691, acc=76]


Epoch 10: Test Loss: 0.6165, Test Acc: 78.89%


Epoch 11/20: 100%|██████████| 391/391 [00:19<00:00, 20.52it/s, loss=0.65, acc=77.4]


Epoch 11: Test Loss: 0.6076, Test Acc: 79.33%


Epoch 12/20: 100%|██████████| 391/391 [00:19<00:00, 19.81it/s, loss=0.624, acc=78.3]


Epoch 12: Test Loss: 0.5916, Test Acc: 79.54%


Epoch 13/20: 100%|██████████| 391/391 [00:18<00:00, 20.79it/s, loss=0.597, acc=79.3]


Epoch 13: Test Loss: 0.5698, Test Acc: 80.56%


Epoch 14/20: 100%|██████████| 391/391 [00:18<00:00, 20.65it/s, loss=0.574, acc=80]


Epoch 14: Test Loss: 0.5445, Test Acc: 80.98%


Epoch 15/20: 100%|██████████| 391/391 [00:19<00:00, 20.28it/s, loss=0.555, acc=80.8]


Epoch 15: Test Loss: 0.5478, Test Acc: 81.04%


Epoch 16/20: 100%|██████████| 391/391 [00:18<00:00, 21.26it/s, loss=0.539, acc=81.3]


Epoch 16: Test Loss: 0.5385, Test Acc: 81.61%


Epoch 17/20: 100%|██████████| 391/391 [00:18<00:00, 20.97it/s, loss=0.521, acc=81.9]


Epoch 17: Test Loss: 0.5247, Test Acc: 81.93%


Epoch 18/20: 100%|██████████| 391/391 [00:18<00:00, 20.63it/s, loss=0.51, acc=82.1]


Epoch 18: Test Loss: 0.5151, Test Acc: 82.26%


Epoch 19/20: 100%|██████████| 391/391 [00:17<00:00, 21.74it/s, loss=0.506, acc=82.5]


Epoch 19: Test Loss: 0.5127, Test Acc: 82.57%


Epoch 20/20: 100%|██████████| 391/391 [00:19<00:00, 20.45it/s, loss=0.502, acc=82.7]


Epoch 20: Test Loss: 0.5088, Test Acc: 82.43%

  Training ResNet...


Epoch 1/20: 100%|██████████| 391/391 [00:35<00:00, 10.87it/s, loss=1.41, acc=48.2]


Epoch 1: Test Loss: 1.2249, Test Acc: 58.05%


Epoch 2/20: 100%|██████████| 391/391 [00:34<00:00, 11.23it/s, loss=0.946, acc=66.3]


Epoch 2: Test Loss: 1.1949, Test Acc: 62.43%


Epoch 3/20: 100%|██████████| 391/391 [00:35<00:00, 11.14it/s, loss=0.74, acc=74]


Epoch 3: Test Loss: 0.6899, Test Acc: 75.91%


Epoch 4/20: 100%|██████████| 391/391 [00:35<00:00, 11.17it/s, loss=0.621, acc=78.4]


Epoch 4: Test Loss: 0.7209, Test Acc: 75.08%


Epoch 5/20: 100%|██████████| 391/391 [00:35<00:00, 11.13it/s, loss=0.53, acc=81.6]


Epoch 5: Test Loss: 0.7877, Test Acc: 74.82%


Epoch 6/20: 100%|██████████| 391/391 [00:35<00:00, 11.09it/s, loss=0.476, acc=83.5]


Epoch 6: Test Loss: 0.6213, Test Acc: 80.06%


Epoch 7/20: 100%|██████████| 391/391 [00:35<00:00, 11.16it/s, loss=0.424, acc=85.4]


Epoch 7: Test Loss: 0.6372, Test Acc: 79.04%


Epoch 8/20: 100%|██████████| 391/391 [00:34<00:00, 11.22it/s, loss=0.383, acc=86.7]


Epoch 8: Test Loss: 0.5131, Test Acc: 83.12%


Epoch 9/20: 100%|██████████| 391/391 [00:35<00:00, 11.09it/s, loss=0.35, acc=87.9]


Epoch 9: Test Loss: 0.5236, Test Acc: 82.87%


Epoch 10/20: 100%|██████████| 391/391 [00:35<00:00, 11.01it/s, loss=0.316, acc=89.1]


Epoch 10: Test Loss: 0.6296, Test Acc: 79.95%


Epoch 11/20: 100%|██████████| 391/391 [00:34<00:00, 11.18it/s, loss=0.283, acc=90.2]


Epoch 11: Test Loss: 0.4428, Test Acc: 85.90%


Epoch 12/20: 100%|██████████| 391/391 [00:34<00:00, 11.19it/s, loss=0.257, acc=91]


Epoch 12: Test Loss: 0.4278, Test Acc: 86.48%


Epoch 13/20: 100%|██████████| 391/391 [00:34<00:00, 11.18it/s, loss=0.229, acc=92.3]


Epoch 13: Test Loss: 0.3751, Test Acc: 87.91%


Epoch 14/20: 100%|██████████| 391/391 [00:35<00:00, 11.14it/s, loss=0.203, acc=93.2]


Epoch 14: Test Loss: 0.3409, Test Acc: 88.80%


Epoch 15/20: 100%|██████████| 391/391 [00:35<00:00, 11.11it/s, loss=0.179, acc=94]


Epoch 15: Test Loss: 0.3458, Test Acc: 88.44%


Epoch 16/20: 100%|██████████| 391/391 [00:35<00:00, 11.05it/s, loss=0.161, acc=94.8]


Epoch 16: Test Loss: 0.3104, Test Acc: 89.81%


Epoch 17/20: 100%|██████████| 391/391 [00:35<00:00, 11.14it/s, loss=0.145, acc=95.2]


Epoch 17: Test Loss: 0.3101, Test Acc: 89.84%


Epoch 18/20: 100%|██████████| 391/391 [00:35<00:00, 11.14it/s, loss=0.131, acc=95.8]


Epoch 18: Test Loss: 0.3033, Test Acc: 90.28%


Epoch 19/20: 100%|██████████| 391/391 [00:34<00:00, 11.22it/s, loss=0.125, acc=96.1]


Epoch 19: Test Loss: 0.2945, Test Acc: 90.37%


Epoch 20/20: 100%|██████████| 391/391 [00:35<00:00, 11.11it/s, loss=0.121, acc=96.2]


Epoch 20: Test Loss: 0.2944, Test Acc: 90.42%

  Training second ResNet instance for mode connectivity...


Epoch 1/20: 100%|██████████| 391/391 [00:34<00:00, 11.17it/s, loss=1.43, acc=47.5]


Epoch 1: Test Loss: 1.3985, Test Acc: 52.78%


Epoch 2/20: 100%|██████████| 391/391 [00:35<00:00, 11.16it/s, loss=0.938, acc=66.6]


Epoch 2: Test Loss: 0.9752, Test Acc: 68.17%


Epoch 3/20: 100%|██████████| 391/391 [00:35<00:00, 11.14it/s, loss=0.728, acc=74.4]


Epoch 3: Test Loss: 0.8979, Test Acc: 71.51%


Epoch 4/20: 100%|██████████| 391/391 [00:35<00:00, 11.16it/s, loss=0.61, acc=78.8]


Epoch 4: Test Loss: 0.7500, Test Acc: 75.75%


Epoch 5/20: 100%|██████████| 391/391 [00:34<00:00, 11.19it/s, loss=0.536, acc=81.3]


Epoch 5: Test Loss: 0.5853, Test Acc: 80.37%


Epoch 6/20: 100%|██████████| 391/391 [00:34<00:00, 11.19it/s, loss=0.476, acc=83.6]


Epoch 6: Test Loss: 0.6183, Test Acc: 79.07%


Epoch 7/20: 100%|██████████| 391/391 [00:35<00:00, 11.17it/s, loss=0.425, acc=85.2]


Epoch 7: Test Loss: 0.5576, Test Acc: 81.71%


Epoch 8/20: 100%|██████████| 391/391 [00:35<00:00, 11.17it/s, loss=0.386, acc=86.7]


Epoch 8: Test Loss: 0.5527, Test Acc: 81.84%


Epoch 9/20: 100%|██████████| 391/391 [00:35<00:00, 11.15it/s, loss=0.345, acc=88]


Epoch 9: Test Loss: 0.4654, Test Acc: 84.33%


Epoch 10/20: 100%|██████████| 391/391 [00:35<00:00, 11.12it/s, loss=0.311, acc=89.2]


Epoch 10: Test Loss: 0.5063, Test Acc: 83.20%


Epoch 11/20: 100%|██████████| 391/391 [00:35<00:00, 11.17it/s, loss=0.281, acc=90.3]


Epoch 11: Test Loss: 0.4321, Test Acc: 85.83%


Epoch 12/20: 100%|██████████| 391/391 [00:35<00:00, 11.16it/s, loss=0.251, acc=91.5]


Epoch 12: Test Loss: 0.4179, Test Acc: 86.58%


Epoch 13/20: 100%|██████████| 391/391 [00:34<00:00, 11.20it/s, loss=0.226, acc=92.4]


Epoch 13: Test Loss: 0.3717, Test Acc: 87.90%


Epoch 14/20: 100%|██████████| 391/391 [00:35<00:00, 11.16it/s, loss=0.201, acc=93.2]


Epoch 14: Test Loss: 0.3336, Test Acc: 89.07%


Epoch 15/20: 100%|██████████| 391/391 [00:35<00:00, 11.15it/s, loss=0.177, acc=94.1]


Epoch 15: Test Loss: 0.3276, Test Acc: 89.24%


Epoch 16/20: 100%|██████████| 391/391 [00:34<00:00, 11.18it/s, loss=0.16, acc=94.7]


Epoch 16: Test Loss: 0.3178, Test Acc: 89.73%


Epoch 17/20: 100%|██████████| 391/391 [00:35<00:00, 11.16it/s, loss=0.143, acc=95.3]


Epoch 17: Test Loss: 0.3110, Test Acc: 89.70%


Epoch 18/20: 100%|██████████| 391/391 [00:34<00:00, 11.18it/s, loss=0.132, acc=95.8]


Epoch 18: Test Loss: 0.3002, Test Acc: 90.15%


Epoch 19/20: 100%|██████████| 391/391 [00:35<00:00, 11.13it/s, loss=0.123, acc=96.2]


Epoch 19: Test Loss: 0.2937, Test Acc: 90.30%


Epoch 20/20: 100%|██████████| 391/391 [00:34<00:00, 11.18it/s, loss=0.117, acc=96.4]


Epoch 20: Test Loss: 0.2924, Test Acc: 90.42%

[3/7] Computing landscape metrics for Vanilla CNN...
Computing sharpness with ρ=0.05...
Computing Hessian spectrum using Lanczos (k=20)...
  Lanczos iteration 10/20
  Lanczos iteration 20/20

[4/7] Computing landscape metrics for ResNet...
Computing sharpness with ρ=0.05...
Computing Hessian spectrum using Lanczos (k=20)...
  Lanczos iteration 10/20
  Lanczos iteration 20/20

[5/7] Analyzing mode connectivity...
Analyzing mode connectivity (15 points)...


Interpolation: 100%|██████████| 15/15 [00:41<00:00,  2.76s/it]



[6/7] Computing 2D loss surface visualizations...
Computing 2D loss surface (15x15 grid)...


Computing surface: 100%|██████████| 15/15 [00:57<00:00,  3.85s/it]


Computing 2D loss surface (15x15 grid)...


Computing surface: 100%|██████████| 15/15 [01:41<00:00,  6.77s/it]


[7/7] Saving results...

ANALYSIS SUMMARY

Vanilla CNN:
  Test Accuracy: 82.43%
  Test Loss: 0.5088
  Sharpness (ρ=0.05): 0.071657
  Max Hessian eigenvalue: 107.8411
  Min Hessian eigenvalue: -95.4789

ResNet:
  Test Accuracy: 90.42%
  Test Loss: 0.2944
  Sharpness (ρ=0.05): 1.287901
  Max Hessian eigenvalue: 1613.7731
  Min Hessian eigenvalue: -283.3368

Mode Connectivity:
  Max barrier height: 16.6086
  Min accuracy along path: 10.00%

Analysis complete! Results saved to 'landscape_analysis_results.pkl'
Run the visualization script next to generate plots.





In [6]:

import pickle
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
from scipy.stats import pearsonr
import seaborn as sns

# Set style
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

def load_results():
    """Load analysis results"""
    with open('landscape_analysis_results.pkl', 'rb') as f:
        results = pickle.load(f)
    return results


def plot_training_curves(results):
    """Plot training and test curves"""
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    models = ['vanilla', 'resnet']
    titles = ['Vanilla CNN', 'ResNet']

    for idx, (model, title) in enumerate(zip(models, titles)):
        history = results[model]['history']
        epochs = range(1, len(history['train_loss']) + 1)

        # Loss curves
        axes[0, idx].plot(epochs, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
        axes[0, idx].plot(epochs, history['test_loss'], 'r-', label='Test Loss', linewidth=2)
        axes[0, idx].set_xlabel('Epoch', fontsize=12)
        axes[0, idx].set_ylabel('Loss', fontsize=12)
        axes[0, idx].set_title(f'{title} - Loss Curves', fontsize=14, fontweight='bold')
        axes[0, idx].legend(fontsize=11)
        axes[0, idx].grid(True, alpha=0.3)

        # Accuracy curves
        axes[1, idx].plot(epochs, history['train_acc'], 'b-', label='Train Accuracy', linewidth=2)
        axes[1, idx].plot(epochs, history['test_acc'], 'r-', label='Test Accuracy', linewidth=2)
        axes[1, idx].set_xlabel('Epoch', fontsize=12)
        axes[1, idx].set_ylabel('Accuracy (%)', fontsize=12)
        axes[1, idx].set_title(f'{title} - Accuracy Curves', fontsize=14, fontweight='bold')
        axes[1, idx].legend(fontsize=11)
        axes[1, idx].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('training_curves.png', dpi=300, bbox_inches='tight')
    print("Saved: training_curves.png")
    plt.close()


def plot_hessian_spectrum(results):
    """Plot Hessian eigenvalue spectrum"""
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    models = ['vanilla', 'resnet']
    titles = ['Vanilla CNN', 'ResNet']
    colors = ['#e74c3c', '#3498db']

    for idx, (model, title, color) in enumerate(zip(models, titles, colors)):
        eigenvalues = results[model]['eigenvalues']
        indices = range(1, len(eigenvalues) + 1)

        axes[idx].bar(indices, eigenvalues, color=color, alpha=0.7, edgecolor='black')
        axes[idx].axhline(y=0, color='black', linestyle='--', linewidth=1)
        axes[idx].set_xlabel('Eigenvalue Index', fontsize=12)
        axes[idx].set_ylabel('Eigenvalue Magnitude', fontsize=12)
        axes[idx].set_title(f'{title} - Hessian Spectrum', fontsize=14, fontweight='bold')
        axes[idx].grid(True, alpha=0.3, axis='y')

        # Add statistics
        max_eig = eigenvalues[0]
        min_eig = eigenvalues[-1]
        n_negative = sum(1 for e in eigenvalues if e < 0)

        textstr = f'Max λ: {max_eig:.2f}\nMin λ: {min_eig:.2f}\n# Negative: {n_negative}'
        axes[idx].text(0.98, 0.97, textstr, transform=axes[idx].transAxes,
                      fontsize=10, verticalalignment='top', horizontalalignment='right',
                      bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    plt.tight_layout()
    plt.savefig('hessian_spectrum.png', dpi=300, bbox_inches='tight')
    print("Saved: hessian_spectrum.png")
    plt.close()


def plot_comparison_metrics(results):

    fig, axes = plt.subplots(1, 3, figsize=(16, 5))

    models = ['Vanilla\nCNN', 'ResNet']

    # Sharpness
    sharpness_values = [results['vanilla']['sharpness'], results['resnet']['sharpness']]
    bars1 = axes[0].bar(models, sharpness_values, color=['#e74c3c', '#3498db'], alpha=0.7, edgecolor='black', linewidth=2)
    axes[0].set_ylabel('Sharpness (ρ=0.05)', fontsize=12)
    axes[0].set_title('Loss Sharpness Comparison', fontsize=14, fontweight='bold')
    axes[0].grid(True, alpha=0.3, axis='y')

    # Add value labels on bars
    for bar in bars1:
        height = bar.get_height()
        axes[0].text(bar.get_x() + bar.get_width()/2., height,
                    f'{height:.4f}', ha='center', va='bottom', fontsize=11, fontweight='bold')

    # Max eigenvalue
    max_eig_values = [results['vanilla']['eigenvalues'][0], results['resnet']['eigenvalues'][0]]
    bars2 = axes[1].bar(models, max_eig_values, color=['#e74c3c', '#3498db'], alpha=0.7, edgecolor='black', linewidth=2)
    axes[1].set_ylabel('Max Eigenvalue (λₘₐₓ)', fontsize=12)
    axes[1].set_title('Maximum Hessian Eigenvalue', fontsize=14, fontweight='bold')
    axes[1].grid(True, alpha=0.3, axis='y')

    for bar in bars2:
        height = bar.get_height()
        axes[1].text(bar.get_x() + bar.get_width()/2., height,
                    f'{height:.2f}', ha='center', va='bottom', fontsize=11, fontweight='bold')

    # Test accuracy
    acc_values = [results['vanilla']['test_acc'], results['resnet']['test_acc']]
    bars3 = axes[2].bar(models, acc_values, color=['#e74c3c', '#3498db'], alpha=0.7, edgecolor='black', linewidth=2)
    axes[2].set_ylabel('Test Accuracy (%)', fontsize=12)
    axes[2].set_title('Generalization Performance', fontsize=14, fontweight='bold')
    axes[2].set_ylim([min(acc_values)-5, 100])
    axes[2].grid(True, alpha=0.3, axis='y')

    for bar in bars3:
        height = bar.get_height()
        axes[2].text(bar.get_x() + bar.get_width()/2., height,
                    f'{height:.2f}%', ha='center', va='bottom', fontsize=11, fontweight='bold')

    plt.tight_layout()
    plt.savefig('metrics_comparison.png', dpi=300, bbox_inches='tight')
    print("Saved: metrics_comparison.png")
    plt.close()


def plot_mode_connectivity(results):

    mc = results['mode_connectivity']

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Loss along path
    axes[0].plot(mc['alphas'], mc['losses'], 'b-o', linewidth=2, markersize=6)
    axes[0].set_xlabel('Interpolation Parameter α', fontsize=12)
    axes[0].set_ylabel('Loss', fontsize=12)
    axes[0].set_title('Loss Along Linear Interpolation Path', fontsize=14, fontweight='bold')
    axes[0].grid(True, alpha=0.3)
    axes[0].axvline(x=0, color='red', linestyle='--', alpha=0.5, label='Model 1')
    axes[0].axvline(x=1, color='green', linestyle='--', alpha=0.5, label='Model 2')
    axes[0].legend(fontsize=11)

    # Mark barrier
    max_loss_idx = np.argmax(mc['losses'])
    axes[0].plot(mc['alphas'][max_loss_idx], mc['losses'][max_loss_idx],
                'r*', markersize=15, label='Barrier Peak')

    barrier_height = max(mc['losses']) - min(mc['losses'])
    textstr = f'Barrier Height: {barrier_height:.4f}'
    axes[0].text(0.5, 0.95, textstr, transform=axes[0].transAxes,
                fontsize=11, verticalalignment='top', horizontalalignment='center',
                bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.7))

    # Accuracy along path
    axes[1].plot(mc['alphas'], mc['accuracies'], 'g-o', linewidth=2, markersize=6)
    axes[1].set_xlabel('Interpolation Parameter α', fontsize=12)
    axes[1].set_ylabel('Test Accuracy (%)', fontsize=12)
    axes[1].set_title('Accuracy Along Linear Interpolation Path', fontsize=14, fontweight='bold')
    axes[1].grid(True, alpha=0.3)
    axes[1].axvline(x=0, color='red', linestyle='--', alpha=0.5, label='Model 1')
    axes[1].axvline(x=1, color='green', linestyle='--', alpha=0.5, label='Model 2')
    axes[1].legend(fontsize=11)

    min_acc_idx = np.argmin(mc['accuracies'])
    axes[1].plot(mc['alphas'][min_acc_idx], mc['accuracies'][min_acc_idx],
                'r*', markersize=15, label='Accuracy Dip')

    plt.tight_layout()
    plt.savefig('mode_connectivity.png', dpi=300, bbox_inches='tight')
    print("Saved: mode_connectivity.png")
    plt.close()


def plot_loss_surface_3d(results):
    """Plot 3D loss surface visualization"""
    fig = plt.figure(figsize=(16, 6))

    models = ['vanilla', 'resnet']
    titles = ['Vanilla CNN - Loss Surface', 'ResNet - Loss Surface']

    for idx, (model, title) in enumerate(zip(models, titles)):
        x, y, z = results[model]['surface']

        ax = fig.add_subplot(1, 2, idx+1, projection='3d')

        X, Y = np.meshgrid(x, y)

        # Plot surface
        surf = ax.plot_surface(X, Y, z.T, cmap=cm.viridis,
                              linewidth=0, antialiased=True, alpha=0.8)

        # Mark minimum
        min_idx = np.unravel_index(z.argmin(), z.shape)
        ax.scatter([x[min_idx[0]]], [y[min_idx[1]]], [z[min_idx[0], min_idx[1]]],
                  color='red', s=100, marker='*', label='Minimum')

        ax.set_xlabel('Direction 1', fontsize=11)
        ax.set_ylabel('Direction 2', fontsize=11)
        ax.set_zlabel('Loss', fontsize=11)
        ax.set_title(title, fontsize=13, fontweight='bold', pad=20)

        # Colorbar
        fig.colorbar(surf, ax=ax, shrink=0.5, aspect=5)

        ax.view_init(elev=25, azim=45)

    plt.tight_layout()
    plt.savefig('loss_surface_3d.png', dpi=300, bbox_inches='tight')
    print("Saved: loss_surface_3d.png")
    plt.close()


def plot_loss_surface_contour(results):
    """Plot contour plots of loss surface"""
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    models = ['vanilla', 'resnet']
    titles = ['Vanilla CNN', 'ResNet']

    for idx, (model, title) in enumerate(zip(models, titles)):
        x, y, z = results[model]['surface']

        X, Y = np.meshgrid(x, y)

        # Contour plot
        contour = axes[idx].contourf(X, Y, z.T, levels=20, cmap='viridis')
        axes[idx].contour(X, Y, z.T, levels=20, colors='black', alpha=0.3, linewidths=0.5)

        # Mark minimum
        min_idx = np.unravel_index(z.argmin(), z.shape)
        axes[idx].plot(x[min_idx[0]], y[min_idx[1]], 'r*', markersize=15, label='Minimum')

        axes[idx].set_xlabel('Direction 1', fontsize=12)
        axes[idx].set_ylabel('Direction 2', fontsize=12)
        axes[idx].set_title(f'{title} - Loss Contours', fontsize=14, fontweight='bold')
        axes[idx].legend(fontsize=11)

        # Colorbar
        plt.colorbar(contour, ax=axes[idx])

    plt.tight_layout()
    plt.savefig('loss_surface_contour.png', dpi=300, bbox_inches='tight')
    print("Saved: loss_surface_contour.png")
    plt.close()


def plot_correlation_analysis(results):

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))


    models_data = {
        'Vanilla CNN': results['vanilla'],
        'ResNet': results['resnet']
    }

    sharpness = [data['sharpness'] for data in models_data.values()]
    test_acc = [data['test_acc'] for data in models_data.values()]
    max_eig = [data['eigenvalues'][0] for data in models_data.values()]

    # Sharpness vs Accuracy
    axes[0].scatter(sharpness, test_acc, s=200, alpha=0.7, c=['#e74c3c', '#3498db'],
                   edgecolors='black', linewidth=2)

    for i, (name, sharp, acc) in enumerate(zip(models_data.keys(), sharpness, test_acc)):
        axes[0].annotate(name, (sharp, acc), xytext=(10, -10),
                        textcoords='offset points', fontsize=10,
                        bbox=dict(boxstyle='round,pad=0.5', facecolor='yellow', alpha=0.5))

    # Trend line
    z = np.polyfit(sharpness, test_acc, 1)
    p = np.poly1d(z)
    x_trend = np.linspace(min(sharpness), max(sharpness), 100)
    axes[0].plot(x_trend, p(x_trend), "r--", alpha=0.8, linewidth=2, label='Trend')

    # Correlation coefficient
    if len(sharpness) > 1:
        corr, _ = pearsonr(sharpness, test_acc)
        axes[0].text(0.05, 0.95, f'Correlation: {corr:.3f}',
                    transform=axes[0].transAxes, fontsize=11,
                    verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    axes[0].set_xlabel('Sharpness', fontsize=12)
    axes[0].set_ylabel('Test Accuracy (%)', fontsize=12)
    axes[0].set_title('Sharpness vs Generalization', fontsize=14, fontweight='bold')
    axes[0].grid(True, alpha=0.3)
    axes[0].legend(fontsize=11)

    # Max Eigenvalue vs Accuracy
    axes[1].scatter(max_eig, test_acc, s=200, alpha=0.7, c=['#e74c3c', '#3498db'],
                   edgecolors='black', linewidth=2)

    for i, (name, eig, acc) in enumerate(zip(models_data.keys(), max_eig, test_acc)):
        axes[1].annotate(name, (eig, acc), xytext=(10, -10),
                        textcoords='offset points', fontsize=10,
                        bbox=dict(boxstyle='round,pad=0.5', facecolor='yellow', alpha=0.5))

    # Trend line
    z = np.polyfit(max_eig, test_acc, 1)
    p = np.poly1d(z)
    x_trend = np.linspace(min(max_eig), max(max_eig), 100)
    axes[1].plot(x_trend, p(x_trend), "r--", alpha=0.8, linewidth=2, label='Trend')

    # Correlation coefficient
    if len(max_eig) > 1:
        corr, _ = pearsonr(max_eig, test_acc)
        axes[1].text(0.05, 0.95, f'Correlation: {corr:.3f}',
                    transform=axes[1].transAxes, fontsize=11,
                    verticalalignment='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    axes[1].set_xlabel('Max Hessian Eigenvalue', fontsize=12)
    axes[1].set_ylabel('Test Accuracy (%)', fontsize=12)
    axes[1].set_title('Curvature vs Generalization', fontsize=14, fontweight='bold')
    axes[1].grid(True, alpha=0.3)
    axes[1].legend(fontsize=11)

    plt.tight_layout()
    plt.savefig('correlation_analysis.png', dpi=300, bbox_inches='tight')
    print("Saved: correlation_analysis.png")
    plt.close()


def generate_summary_table(results):
    """Generate and save summary table as image"""
    fig, ax = plt.subplots(figsize=(12, 4))
    ax.axis('tight')
    ax.axis('off')

    # Prepare data
    metrics = [
        'Test Accuracy (%)',
        'Test Loss',
        'Sharpness (ρ=0.05)',
        'Max Eigenvalue (λₚₐₓ)',
        'Min Eigenvalue (λₘᵢₙ)',
        '# Negative Eigenvalues'
    ]

    vanilla_data = [
        f"{results['vanilla']['test_acc']:.2f}",
        f"{results['vanilla']['test_loss']:.4f}",
        f"{results['vanilla']['sharpness']:.6f}",
        f"{results['vanilla']['eigenvalues'][0]:.4f}",
        f"{results['vanilla']['eigenvalues'][-1]:.4f}",
        f"{sum(1 for e in results['vanilla']['eigenvalues'] if e < 0)}"
    ]

    resnet_data = [
        f"{results['resnet']['test_acc']:.2f}",
        f"{results['resnet']['test_loss']:.4f}",
        f"{results['resnet']['sharpness']:.6f}",
        f"{results['resnet']['eigenvalues'][0]:.4f}",
        f"{results['resnet']['eigenvalues'][-1]:.4f}",
        f"{sum(1 for e in results['resnet']['eigenvalues'] if e < 0)}"
    ]

    table_data = []
    for i, metric in enumerate(metrics):
        table_data.append([metric, vanilla_data[i], resnet_data[i]])

    table = ax.table(cellText=table_data,
                    colLabels=['Metric', 'Vanilla CNN', 'ResNet'],
                    cellLoc='center',
                    loc='center',
                    colWidths=[0.4, 0.3, 0.3])

    table.auto_set_font_size(False)
    table.set_fontsize(11)
    table.scale(1, 2.5)

    # Style header
    for i in range(3):
        table[(0, i)].set_facecolor('#4CAF50')
        table[(0, i)].set_text_props(weight='bold', color='white')

    # Alternate row colors
    for i in range(1, len(metrics) + 1):
        for j in range(3):
            if i % 2 == 0:
                table[(i, j)].set_facecolor('#f0f0f0')

    plt.title('Summary of Landscape Metrics', fontsize=16, fontweight='bold', pad=20)
    plt.savefig('summary_table.png', dpi=300, bbox_inches='tight')
    print("Saved: summary_table.png")
    plt.close()


def main():
    """Generate all visualizations"""
    print("="*80)
    print("GENERATING VISUALIZATIONS")
    print("="*80)

    print("\nLoading results...")
    results = load_results()

    print("\nGenerating plots...")
    plot_training_curves(results)
    plot_hessian_spectrum(results)
    plot_comparison_metrics(results)
    plot_mode_connectivity(results)
    plot_loss_surface_3d(results)
    plot_loss_surface_contour(results)
    plot_correlation_analysis(results)
    generate_summary_table(results)

    print("\n" + "="*80)
    print("ALL VISUALIZATIONS COMPLETE!")
    print("="*80)
    print("\nGenerated files:")
    print("  1. training_curves.png")
    print("  2. hessian_spectrum.png")
    print("  3. metrics_comparison.png")
    print("  4. mode_connectivity.png")
    print("  5. loss_surface_3d.png")
    print("  6. loss_surface_contour.png")
    print("  7. correlation_analysis.png")
    print("  8. summary_table.png")
    print("\nUse these figures in your final report!")
    print("="*80)


if __name__ == "__main__":
    main()

GENERATING VISUALIZATIONS

Loading results...

Generating plots...
Saved: training_curves.png
Saved: hessian_spectrum.png
Saved: metrics_comparison.png
Saved: mode_connectivity.png
Saved: loss_surface_3d.png
Saved: loss_surface_contour.png
Saved: correlation_analysis.png
Saved: summary_table.png

ALL VISUALIZATIONS COMPLETE!

Generated files:
  1. training_curves.png
  2. hessian_spectrum.png
  3. metrics_comparison.png
  4. mode_connectivity.png
  5. loss_surface_3d.png
  6. loss_surface_contour.png
  7. correlation_analysis.png
  8. summary_table.png

Use these figures in your final report!
