In [None]:
import torch
import numpy as np
import os
from torch.utils.data import DataLoader, Dataset, Subset, WeightedRandomSampler
import torch.nn as nn
import torch.optim as optim
import glob
from pathlib import Path
import random
import copy
from scipy.spatial.transform import Rotation
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import time
from tqdm import tqdm
from tqdm.notebook import tqdm as tqdm_notebook
from scipy import stats
from torch.profiler import profile, record_function, ProfilerActivity
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support
from sklearn.metrics import accuracy_score, roc_auc_score, roc_curve, auc
import math

In [4]:
class SetAbstractionLayer(nn.Module):
    """Simplified PointNet++ set-abstraction layer.

    This is a lightweight stand-in for the real layer: the first ``npoint``
    points are kept instead of running farthest point sampling, and no ball
    query is performed (``radius`` and ``nsample`` are stored but unused).
    Each kept point is processed independently by a shared 1x1-conv MLP.

    Parameters:
    -----------
    npoint: int or None, number of points to keep; None means "group all",
        i.e. the whole cloud is reduced to a single global feature
    radius: float or None, stored only
    nsample: int or None, stored only
    in_channel: int, width of the MLP input. When features are passed to
        ``forward`` this must be ``3 + D`` because the coordinates are
        concatenated in front of the features; with ``points=None`` it is 3.
    mlp: list of int, output width of each MLP stage
    """

    def __init__(self, npoint, radius, nsample, in_channel, mlp):
        super(SetAbstractionLayer, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample

        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()

        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, N, 3]
            points: input points data, [B, N, D], or None
        Return:
            new_xyz: kept points position data, [B, S, 3]
                (all zeros with S = 1 when npoint is None)
            new_points: per-point feature data, [B, S, mlp[-1]]
                (S = 1 when npoint is None)
        """
        B, N, C = xyz.shape
        group_all = self.npoint is None
        S = N if group_all else min(self.npoint, N)

        # Simplified sampling: keep the first S points instead of FPS.
        new_xyz = xyz[:, :S, :]

        # Coordinates are prepended to the features so the MLP input width is
        # 3 + D, which is what callers declare through in_channel (e.g. 128 + 3).
        if points is not None:
            new_points = torch.cat([new_xyz, points[:, :S, :]], dim=-1)
        else:
            new_points = new_xyz

        # [B, S, C_in] -> [B, C_in, S, 1] so the 1x1 Conv2d acts as a shared MLP.
        new_points = new_points.permute(0, 2, 1).unsqueeze(-1)
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            new_points = torch.relu(bn(conv(new_points)))
        new_points = new_points.squeeze(-1).permute(0, 2, 1)

        if group_all:
            # Global stage: max-pool over every point so exactly one feature
            # vector per cloud is returned; its position is the origin.
            new_points = new_points.max(dim=1, keepdim=True)[0]
            new_xyz = xyz.new_zeros(B, 1, C)

        return new_xyz, new_points

class PointNet2Classification(nn.Module):
    """PointNet++-style classifier: three set-abstraction stages plus an MLP head.

    Parameters:
    -----------
    num_classes: int, number of output logits
    normal_channel: bool, when True each point carries 6 channels (xyz followed
        by 3 extra channels, e.g. normals) instead of 3
    """

    def __init__(self, num_classes=40, normal_channel=False):
        super(PointNet2Classification, self).__init__()
        in_channel = 6 if normal_channel else 3

        self.normal_channel = normal_channel

        # SA modules
        # NOTE(review): the "+ 3" in sa2/sa3 presumes the abstraction layer feeds
        # coordinates alongside the previous stage's features - verify against
        # SetAbstractionLayer.
        self.sa1 = SetAbstractionLayer(npoint=512, radius=0.2, nsample=32,
                                        in_channel=in_channel, mlp=[64, 64, 128])
        self.sa2 = SetAbstractionLayer(npoint=128, radius=0.4, nsample=64,
                                        in_channel=128 + 3, mlp=[128, 128, 256])
        self.sa3 = SetAbstractionLayer(npoint=None, radius=None, nsample=None,
                                        in_channel=256 + 3, mlp=[256, 512, 1024])

        # FC layers
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.4)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, xyz):
        """
        Input:
            xyz: point clouds, [B, C, N] (or [B, N, C]) with C = 6 when
                normal_channel is set and C = 3 otherwise
        Return:
            class logits, [B, num_classes]
        """
        channels = 6 if self.normal_channel else 3
        # Accept channels-last batches too, like the other models in this notebook.
        if xyz.shape[1] != channels and xyz.shape[2] == channels:
            xyz = xyz.transpose(2, 1)
        B = xyz.shape[0]

        if self.normal_channel:
            # The extra channels must be laid out [B, N, 3] just like the
            # coordinates; previously they were passed on still as [B, 3, N].
            norm = xyz[:, 3:, :].transpose(2, 1)
            xyz = xyz[:, :3, :]
        else:
            norm = None

        # Transpose coordinates to the [B, N, 3] layout the SA layers expect
        xyz = xyz.transpose(2, 1)

        # Set Abstraction layers
        l1_xyz, l1_points = self.sa1(xyz, norm)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)

        # Collapse whatever point axis is left with a max-pool so the head always
        # receives one 1024-d vector per cloud (identical to a plain reshape when
        # a single row remains).
        x = l3_points.max(dim=1)[0]

        # FC layers
        x = self.drop1(torch.relu(self.bn1(self.fc1(x))))
        x = self.drop2(torch.relu(self.bn2(self.fc2(x))))
        x = self.fc3(x)

        return x

def _init_pointnet_plus_plus(num_classes=40):
    """Build a fresh (untrained) PointNet++ classifier with `num_classes` outputs."""
    return PointNet2Classification(num_classes=num_classes)

In [5]:
class EdgeConv(nn.Module):
    """Edge convolution over a k-nearest-neighbour graph built in feature space.

    Parameters:
    -----------
    k: int, number of neighbours per point (clamped to the number of points
        available at run time)
    in_channels: int, channels of the incoming per-point features
    out_channels: int, channels produced per point
    """

    def __init__(self, k, in_channels, out_channels):
        super(EdgeConv, self).__init__()
        self.k = k
        # Edge features are [neighbour - centre, centre], hence 2 * in_channels.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels*2, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(negative_slope=0.2)
        )

    def knn(self, x, k):
        """Return indices [B, N, min(k, N)] of each point's nearest neighbours.

        x is [B, C, N]. The point itself is included (it has distance zero).
        """
        inner = -2 * torch.matmul(x.transpose(2, 1), x)
        xx = torch.sum(x**2, dim=1, keepdim=True)
        # Negative squared Euclidean distance, so topk picks the closest points.
        distance = -xx - inner - xx.transpose(2, 1)

        # topk raises if k exceeds the number of points; small clouds simply
        # use every point as a neighbour instead.
        k = min(k, x.size(2))
        idx = distance.topk(k=k, dim=-1)[1]
        return idx

    def get_graph_feature(self, x, k=20):
        """Build edge features [B, 2*C, N, min(k, N)] from x of shape [B, C, N]."""
        batch_size, num_dims, num_points = x.size()
        k = min(k, num_points)
        idx = self.knn(x, k)
        # Offset per-cloud indices so they address the flattened (B*N, C) table.
        idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
        idx = idx + idx_base
        idx = idx.view(-1)

        x = x.transpose(2, 1).contiguous()
        feature = x.view(batch_size * num_points, -1)[idx, :]
        feature = feature.view(batch_size, num_points, k, num_dims)
        # expand() broadcasts the centre point without materialising k copies.
        x = x.view(batch_size, num_points, 1, num_dims).expand(-1, -1, k, -1)

        feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 1, 2).contiguous()
        return feature

    def forward(self, x):
        """Map [B, C_in, N] features to [B, C_out, N] by max-pooling over neighbours."""
        x = self.get_graph_feature(x, self.k)
        x = self.conv(x)
        x = x.max(dim=-1, keepdim=False)[0]
        return x

class DGCNN(nn.Module):
    """Dynamic-graph CNN classifier built from four stacked EdgeConv stages.

    The outputs of all four stages are concatenated (64 + 64 + 128 + 256 = 512
    channels), lifted to 1024 channels, max-pooled over points and classified
    by a three-layer fully connected head.
    """

    def __init__(self, num_classes=40, k=20):
        super().__init__()
        self.k = k

        # Edge convolution stages, registered as edge_conv1 .. edge_conv4
        stage_widths = [(3, 64), (64, 64), (64, 128), (128, 256)]
        for number, (width_in, width_out) in enumerate(stage_widths, start=1):
            setattr(
                self,
                f"edge_conv{number}",
                EdgeConv(k=k, in_channels=width_in, out_channels=width_out),
            )

        # Point-wise lift of the concatenated stage outputs
        self.mlp = nn.Sequential(
            nn.Conv1d(512, 1024, kernel_size=1, bias=False),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(negative_slope=0.2),
        )

        # Classification head
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.5)

        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.5)

        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return class logits [B, num_classes] for a batch of point clouds.

        A 3-D input whose second axis has size 3 is used as-is (B x 3 x N);
        anything else is treated as B x N x C and transposed first.
        """
        already_channels_first = x.shape[1] == 3 and len(x.shape) == 3
        if not already_channels_first:
            x = x.transpose(2, 1)

        # Run the stages in sequence, remembering every intermediate output
        stage_outputs = []
        for stage in (self.edge_conv1, self.edge_conv2, self.edge_conv3, self.edge_conv4):
            x = stage(x)
            stage_outputs.append(x)

        fused = self.mlp(torch.cat(stage_outputs, dim=1))

        # Global max pooling over the point axis
        pooled = fused.max(dim=2)[0]

        hidden = self.drop1(torch.relu(self.bn1(self.fc1(pooled))))
        hidden = self.drop2(torch.relu(self.bn2(self.fc2(hidden))))
        return self.fc3(hidden)

def _init_dgcnn(num_classes=40):
    """Build a fresh (untrained) DGCNN classifier with `num_classes` outputs."""
    return DGCNN(num_classes=num_classes)

In [6]:
class PointMLP(nn.Module):
    """Purely point-wise MLP classifier with a global max-pool.

    Every stage is a pair of 1x1 Conv1d -> BatchNorm1d -> ReLU units, so points
    never exchange information before the pooling step.

    Parameters:
    -----------
    num_classes: int, number of output logits
    points: int, stored on the instance only
    embed_dim: int, width of the initial point embedding
    """

    @staticmethod
    def _conv_pair(width_in, width_out):
        """Two Conv1d/BatchNorm1d/ReLU units: width_in -> width_out -> width_out."""
        return nn.Sequential(
            nn.Conv1d(width_in, width_out, 1),
            nn.BatchNorm1d(width_out),
            nn.ReLU(inplace=True),
            nn.Conv1d(width_out, width_out, 1),
            nn.BatchNorm1d(width_out),
            nn.ReLU(inplace=True),
        )

    def __init__(self, num_classes=40, points=1024, embed_dim=64):
        super().__init__()
        self.num_classes = num_classes
        self.points = points

        # Point embedding followed by three progressively wider stages
        self.embedding = self._conv_pair(3, embed_dim)
        self.layer1 = self._conv_pair(embed_dim, 128)
        self.layer2 = self._conv_pair(128, 256)
        self.layer3 = self._conv_pair(256, 512)

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return logits [B, num_classes]; accepts (B, N, 3) or (B, 3, N) input."""
        channels_last = x.shape[1] != 3 and x.shape[2] == 3
        if channels_last:
            x = x.transpose(2, 1)

        # Point-wise feature extraction
        for stage in (self.embedding, self.layer1, self.layer2, self.layer3):
            x = stage(x)

        # Global max pooling over points, then classification
        pooled = x.max(dim=2)[0]
        return self.classifier(pooled)

def _init_pointmlp(num_classes=40):
    """Build a fresh (untrained) PointMLP classifier with `num_classes` outputs."""
    return PointMLP(num_classes=num_classes)

In [7]:
class PointAttentionNet(nn.Module):
    """Attention-based classifier: embed points, apply two self-attention
    blocks, widen the features, max-pool over points and classify.

    Parameters:
    -----------
    num_classes: int, number of output logits
    num_points: int, accepted for interface compatibility; not used here
    embed_dim: int, channel width of the embedding and attention blocks
    """

    def __init__(self, num_classes=40, num_points=1024, embed_dim=128):
        super().__init__()
        self.num_classes = num_classes

        # Per-point embedding of the raw coordinates
        self.embedding = nn.Sequential(
            nn.Conv1d(3, embed_dim, 1),
            nn.BatchNorm1d(embed_dim),
            nn.ReLU(inplace=True),
        )

        # Two stacked 4-head self-attention blocks
        self.self_attn1 = PointSelfAttention(embed_dim, 4)
        self.self_attn2 = PointSelfAttention(embed_dim, 4)

        # Point-wise widening after attention
        self.feature_conv = nn.Sequential(
            nn.Conv1d(embed_dim, 256, 1),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Conv1d(256, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
        )

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return logits [B, num_classes]; accepts (B, N, 3) or (B, 3, N) input."""
        channels_last = x.shape[1] != 3 and x.shape[2] == 3
        if channels_last:
            x = x.transpose(2, 1)

        # B x embed_dim x N after the embedding; attention keeps that shape
        features = self.embedding(x)
        for block in (self.self_attn1, self.self_attn2):
            features = block(features)

        features = self.feature_conv(features)

        # Global max pooling over points, then classification
        pooled = features.max(dim=2)[0]
        return self.classifier(pooled)


class PointSelfAttention(nn.Module):
    """Multi-head self-attention block over the points of a B x C x N tensor.

    Layout: attention -> output projection -> residual + BatchNorm, followed by
    a point-wise feed-forward network -> residual + BatchNorm. The output has
    the same shape as the input.
    """

    def __init__(self, channels, heads=4):
        super().__init__()
        self.heads = heads
        self.channels = channels
        self.head_dim = channels // heads
        assert self.head_dim * heads == channels, "channels must be divisible by heads"

        # One 1x1 conv yields queries, keys and values stacked along channels
        self.qkv_conv = nn.Conv1d(channels, channels * 3, 1, bias=False)
        self.out_conv = nn.Conv1d(channels, channels, 1)

        self.norm1 = nn.BatchNorm1d(channels)
        self.norm2 = nn.BatchNorm1d(channels)

        self.ff = nn.Sequential(
            nn.Conv1d(channels, channels * 2, 1),
            nn.BatchNorm1d(channels * 2),
            nn.ReLU(inplace=True),
            nn.Conv1d(channels * 2, channels, 1),
        )

    def forward(self, x):
        """Apply attention and feed-forward sub-blocks to x of shape B x C x N."""
        batch_size, C, N = x.shape
        skip = x

        # B x 3C x N -> B x 3 x H x D x N, then split into three B x H x D x N views
        stacked = self.qkv_conv(x).reshape(batch_size, 3, self.heads, self.head_dim, N)
        query, key, value = stacked.unbind(dim=1)

        # Scaled dot-product scores between every pair of points: B x H x N x N.
        # The key is already D x N, so only the query needs transposing.
        scores = torch.matmul(query.transpose(2, 3), key) / (self.head_dim ** 0.5)
        weights = torch.softmax(scores, dim=-1)

        # Weighted sum of values: B x H x N x D, folded back to B x C x N
        mixed = torch.matmul(weights, value.transpose(2, 3))
        mixed = mixed.transpose(2, 3).reshape(batch_size, C, N)

        # Output projection with residual connection
        attended = self.norm1(self.out_conv(mixed) + skip)

        # Feed-forward sub-block with its own residual connection
        return self.norm2(self.ff(attended) + attended)


def _init_custom_attention(num_classes=40):
    """Build a fresh (untrained) PointAttentionNet with `num_classes` outputs."""
    model = PointAttentionNet(num_classes=num_classes)
    return model

In [8]:
# Import model architectures
# Note: These would typically be imported from external files
# For demonstration, we'll include placeholders for imports

# from models.pointnet2 import PointNet2Classification
# from models.dgcnn import DGCNN
# from models.pointmlp import PointMLP
# from models.custom_attention import PointAttentionNet

class ModelEvaluator:
    """Builds one of the notebook's point-cloud classifiers and exposes
    robustness-evaluation entry points for it.

    Parameters:
    -----------
    model_name: str, one of "pointnet++", "dgcnn", "pointmlp", "custom"
    pretrained_path: str, location of a checkpoint for the chosen model
    num_classes: int, number of output classes

    Raises:
    -------
    ValueError: if model_name is not one of the supported names
    """

    def __init__(self, model_name, pretrained_path, num_classes=40):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name
        self.num_classes = num_classes

        # Initialize selected model
        if model_name == "pointnet++":
            self.model = self._init_pointnet_plus_plus()
        elif model_name == "dgcnn":
            self.model = self._init_dgcnn()
        elif model_name == "pointmlp":
            self.model = self._init_pointmlp()
        elif model_name == "custom":
            self.model = self._init_custom_attention()
        else:
            raise ValueError(f"Unsupported model: {model_name}")

        # Load pre-trained weights, then switch to inference mode on the device
        self.load_pretrained(pretrained_path)
        self.model.to(self.device)
        self.model.eval()

    def _init_pointnet_plus_plus(self):
        """Build the PointNet++ classifier."""
        return PointNet2Classification(num_classes=self.num_classes)

    def _init_dgcnn(self):
        """Build the DGCNN classifier."""
        return DGCNN(num_classes=self.num_classes)

    def _init_pointmlp(self):
        """Build the PointMLP classifier."""
        return PointMLP(num_classes=self.num_classes)

    def _init_custom_attention(self):
        """Build the attention-based classifier.

        This used to return None, which made ModelEvaluator("custom", ...)
        fail in __init__ as soon as `.to(device)` was called on the model.
        """
        print("Initializing custom attention model")
        return PointAttentionNet(num_classes=self.num_classes)

    def load_pretrained(self, path):
        """Report whether a checkpoint exists at `path`.

        NOTE(review): the load_state_dict call below is still commented out, so
        the model keeps its random initialisation even when the file exists.
        """
        if self.model is not None and os.path.exists(path):
            print(f"Loading pre-trained weights for {self.model_name} from {path}")
            # self.model.load_state_dict(torch.load(path))
        else:
            print(f"No pre-trained weights found at {path}")

    def evaluate_adversarial_robustness(self, test_loader, attack_type, epsilon):
        """
        Evaluate model robustness against adversarial attacks

        Parameters:
        -----------
        test_loader: DataLoader containing test data
        attack_type: str, type of attack (e.g., 'fgsm', 'pgd')
        epsilon: float, attack strength

        Returns:
        --------
        accuracy: float, model accuracy under adversarial attack
            (placeholder: currently always 0.0, nothing is evaluated)
        """
        # Placeholder for adversarial evaluation
        print(f"Evaluating {self.model_name} against {attack_type} with ε={epsilon}")
        return 0.0

    def evaluate_corruption_robustness(self, test_loader, corruption_type, severity):
        """
        Evaluate model robustness against environmental corruptions

        Parameters:
        -----------
        test_loader: DataLoader containing test data
        corruption_type: str, type of corruption (e.g., 'gaussian', 'impulse')
        severity: int, corruption severity level (usually 1-5)

        Returns:
        --------
        accuracy: float, model accuracy under corruption
            (placeholder: currently always 0.0, nothing is evaluated)
        """
        # Placeholder for corruption evaluation
        print(f"Evaluating {self.model_name} against {corruption_type} corruption at severity {severity}")
        return 0.0

# Example usage:
# evaluator = ModelEvaluator("pointnet++", "pretrained/pointnet_modelnet40.pth")
# adv_acc = evaluator.evaluate_adversarial_robustness(test_loader, "pgd", 0.1)
# corr_acc = evaluator.evaluate_corruption_robustness(test_loader, "gaussian_noise", 3)

In [10]:
import torch
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import glob
from pathlib import Path
import random
from scipy.spatial.transform import Rotation
import math

class ModelNet40Dataset(Dataset):
    """ModelNet40-style dataset laid out as <root>/<class>/<split>/*.off.

    NOTE: point loading is still a placeholder - every sample is a random
    cloud and the file contents are not read.
    """

    def __init__(self, root_dir, split='train', num_points=1024, transform=None, random_rotation=True, class_balance=True):
        """
        ModelNet40 Dataset for 3D point cloud classification

        Parameters:
        -----------
        root_dir: str, path to the ModelNet40 dataset
        split: str, 'train' or 'test'
        num_points: int, number of points to sample (1024, 2048, or 4096)
        transform: callable, optional transform to be applied on a sample
        random_rotation: bool, whether to apply random rotation augmentation
            (only used when split == 'train')
        class_balance: bool, whether to compute class-balancing weights

        Attributes set for class balancing:
            weights: per-class weight array (None when class_balance is False)
            sample_weights: per-sample float tensor (None when class_balance is False)
        """
        self.root_dir = root_dir
        self.split = split
        self.num_points = num_points
        self.transform = transform
        self.random_rotation = random_rotation

        # Get all class folders
        self.classes = sorted([d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))])
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

        # Get all model files; sorted because glob order is filesystem-dependent
        self.files = []
        self.labels = []
        for cls_name in self.classes:
            cls_path = os.path.join(root_dir, cls_name, split)
            if os.path.exists(cls_path):
                model_files = sorted(glob.glob(os.path.join(cls_path, '*.off')))
                self.files.extend(model_files)
                self.labels.extend([self.class_to_idx[cls_name]] * len(model_files))

        # Prepare weights for class balancing
        self.weights = None
        self.sample_weights = None
        if class_balance:
            # Count instances per class
            class_counts = np.zeros(len(self.classes))
            for label in self.labels:
                class_counts[label] += 1
            # Weights inversely proportional to class frequency. Classes with no
            # samples in this split get weight 0 instead of inf (1 / 0).
            self.weights = np.zeros_like(class_counts)
            populated = class_counts > 0
            self.weights[populated] = 1.0 / class_counts[populated]
            # Assign weight to each sample
            per_sample = np.array([self.weights[label] for label in self.labels], dtype=np.float64)
            self.sample_weights = torch.from_numpy(per_sample).float()

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return (points, label); points is a float32 array of shape (num_points, 3)
        unless the transform changes it."""
        file_path = self.files[idx]
        label = self.labels[idx]

        # Load point cloud (placeholder - actual loading depends on file format)
        # For OFF files, you'd parse the OFF format
        # For this example, we'll generate a random point cloud
        # In practice, replace this with actual file loading
        points = np.random.rand(1024, 3)  # Placeholder

        # Sample num_points points
        if points.shape[0] > self.num_points:
            indices = np.random.choice(points.shape[0], self.num_points, replace=False)
            points = points[indices]
        elif points.shape[0] < self.num_points:
            # Upsample by duplicating
            indices = np.random.choice(points.shape[0], self.num_points - points.shape[0], replace=True)
            points = np.vstack([points, points[indices]])

        # Normalize to unit sphere
        center = np.mean(points, axis=0)
        points = points - center
        dist = np.max(np.sqrt(np.sum(points ** 2, axis=1)))
        if dist > 0:  # a degenerate cloud (all points equal) stays at the origin
            points = points / dist

        # Apply random rotation
        if self.split == 'train' and self.random_rotation:
            # Random rotation around z-axis
            theta = np.random.uniform(0, 2 * np.pi)
            rotation_matrix = np.array([
                [np.cos(theta), -np.sin(theta), 0],
                [np.sin(theta), np.cos(theta), 0],
                [0, 0, 1]
            ])
            points = points @ rotation_matrix

        if self.transform:
            points = self.transform(points)

        # Hand out float32 so batches collate to FloatTensor, matching the models'
        # float32 parameters; objects returned by a transform that are not NumPy
        # arrays are passed through untouched.
        if isinstance(points, np.ndarray):
            points = points.astype(np.float32, copy=False)

        return points, label

class PointCloudCorruptor:
    """Stateless NumPy corruptions for point clouds of shape (N, 3).

    Every method draws from the global ``np.random`` state and returns a new
    array; the input is never modified in place.
    """

    @staticmethod
    def add_snow(points, density=0.1, scatter_strength=0.05):
        """Jitter every point and replace a fraction of them with snow flakes.

        Parameters:
        -----------
        points: np.ndarray, shape (N, 3)
        density: float, fraction of points replaced by flakes (0.0-1.0)
        scatter_strength: float, std of the Gaussian jitter applied to all points

        Returns:
        --------
        corrupted_points: np.ndarray, shape (N, 3)
        """
        total = points.shape[0]
        flake_count = int(total * density)

        # Flakes are uniform in the [-1, 1] cube
        flakes = np.random.uniform(-1, 1, size=(flake_count, 3))

        # Scattering: Gaussian jitter on every original point
        jitter = np.random.normal(0, scatter_strength, size=(total, 3))
        corrupted = points + jitter

        # Overwrite a random subset (without repetition) with the flakes
        replaced = np.random.choice(total, flake_count, replace=False)
        corrupted[replaced] = flakes

        return corrupted

    @staticmethod
    def add_rain(points, density=0.1, drop_length=0.05):
        """Replace a fraction of the points with the ends of short rain streaks.

        Parameters:
        -----------
        points: np.ndarray, shape (N, 3)
        density: float, fraction of points replaced by rain (0.0-1.0)
        drop_length: float, length of each streak

        Returns:
        --------
        corrupted_points: np.ndarray, shape (N, 3)
        """
        total = points.shape[0]
        drop_count = int(total * density)

        # Streak start positions, uniform in the [-1, 1] cube
        origins = np.random.uniform(-1, 1, size=(drop_count, 3))

        # Unit directions: small lateral drift, z component always <= 0
        lateral = np.random.normal(0, 0.1, size=(drop_count, 2))
        downward = -np.abs(np.random.normal(0, 0.5, drop_count))
        directions = np.column_stack([lateral, downward])
        directions = directions / np.linalg.norm(directions, axis=1, keepdims=True)

        streak_ends = origins + directions * drop_length

        # Overwrite a random subset (without repetition) in a copy of the cloud
        replaced = np.random.choice(total, drop_count, replace=False)
        corrupted = points.copy()
        corrupted[replaced] = streak_ends

        return corrupted

    @staticmethod
    def add_fog(points, density=0.2):
        """Displace points more the farther they are from the origin and
        replace randomly "obscured" ones with near-origin noise.

        Parameters:
        -----------
        points: np.ndarray, shape (N, 3)
        density: float, fog density (0.0-1.0)

        Returns:
        --------
        corrupted_points: np.ndarray, shape (N, 3)
        """
        total = points.shape[0]
        ranges = np.linalg.norm(points, axis=1)

        # Exponential attenuation with distance; 1 - attenuation is the fog intensity
        attenuation = np.exp(-density * ranges)
        displacement = (1 - attenuation)[:, None] * np.random.normal(0, 0.1, size=(total, 3))
        fogged = points + displacement

        # A point survives when its random draw exceeds its drop-out probability
        drop_probability = 1 - np.exp(-density * ranges * 2)
        survives = (np.random.random(total) > drop_probability).reshape(-1, 1)

        # Dropped points are replaced by uniform noise in [-0.1, 0.1]^3
        filler = np.random.uniform(-0.1, 0.1, size=(total, 3))
        return fogged * survives + (1 - survives) * filler

    @staticmethod
    def add_gaussian_noise(points, sigma=0.02):
        """Add i.i.d. zero-mean Gaussian noise to every coordinate.

        Parameters:
        -----------
        points: np.ndarray, shape (N, 3)
        sigma: float, standard deviation of the noise

        Returns:
        --------
        corrupted_points: np.ndarray, shape (N, 3)
        """
        return points + np.random.normal(0, sigma, size=points.shape)

    @staticmethod
    def add_depth_noise(points, k=0.05):
        """Add Gaussian noise whose std grows linearly with distance from the origin.

        Parameters:
        -----------
        points: np.ndarray, shape (N, 3)
        k: float, noise coefficient (std = k * distance)

        Returns:
        --------
        corrupted_points: np.ndarray, shape (N, 3)
        """
        ranges = np.linalg.norm(points, axis=1)
        per_point_std = k * ranges[:, None]
        perturbation = np.random.normal(0, 1, size=points.shape) * per_point_std
        return points + perturbation

    @staticmethod
    def add_occlusion(points, ratio=0.2, mode='random'):
        """Remove part of the cloud.

        Parameters:
        -----------
        points: np.ndarray, shape (N, 3)
        ratio: float, fraction of points removed in 'random' mode (0.0-1.0);
            ignored by the other modes
        mode: str, 'random' (drop a random subset), 'region' (drop everything
            inside a random sphere) or 'semantic' (keep only the points above
            the median along a random axis)

        Returns:
        --------
        corrupted_points: np.ndarray, shape (M, 3) with M <= N; the input is
            returned unchanged for any other mode
        """
        total = points.shape[0]
        drop_count = int(total * ratio)

        if mode == 'random':
            keep = np.ones(total, dtype=bool)
            dropped = np.random.choice(total, drop_count, replace=False)
            keep[dropped] = False
            return points[keep]

        if mode == 'region':
            # Random occluding sphere; only points strictly outside it survive
            sphere_center = np.random.uniform(-0.5, 0.5, size=3)
            sphere_radius = np.random.uniform(0.2, 0.5)
            offsets = np.linalg.norm(points - sphere_center, axis=1)
            return points[offsets > sphere_radius]

        if mode == 'semantic':
            # Simplified stand-in for part-aware occlusion: cut the cloud at the
            # median of a randomly chosen axis and keep the upper side.
            axis = np.random.randint(0, 3)
            cut = np.median(points[:, axis])
            return points[points[:, axis] > cut]

        return points

def prepare_dataloaders(dataset_path, batch_size=32, num_points=1024, num_workers=4):
    """
    Build the training and test DataLoaders for a ModelNet40 directory

    Parameters:
    -----------
    dataset_path: str, path to ModelNet40 dataset
    batch_size: int, batch size
    num_points: int, number of points per sample
    num_workers: int, number of dataloader workers

    Returns:
    --------
    train_loader: DataLoader for training data (class-balanced sampling with
        replacement, incomplete final batch dropped)
    test_loader: DataLoader for test data (fixed order)
    """
    # Training split: rotation augmentation and class-balancing weights enabled
    train_dataset = ModelNet40Dataset(
        root_dir=dataset_path,
        split='train',
        num_points=num_points,
        random_rotation=True,
        class_balance=True,
    )
    # Test split: no augmentation, no balancing
    test_dataset = ModelNet40Dataset(
        root_dir=dataset_path,
        split='test',
        num_points=num_points,
        random_rotation=False,
        class_balance=False,
    )

    # Draw one epoch's worth of samples according to the per-sample weights
    sampler = (
        WeightedRandomSampler(
            weights=train_dataset.sample_weights,
            num_samples=len(train_dataset),
            replacement=True,
        )
        if train_dataset.weights is not None
        else None
    )

    # Options shared by both loaders
    common = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)

    train_loader = DataLoader(train_dataset, sampler=sampler, drop_last=True, **common)
    test_loader = DataLoader(test_dataset, shuffle=False, **common)

    return train_loader, test_loader

def get_corrupted_dataset(dataset, corruption_type, severity):
    """
    Return a deep copy of `dataset` whose transform applies one corruption

    Parameters:
    -----------
    dataset: Dataset object exposing a `transform` attribute
    corruption_type: str, one of 'gaussian_noise', 'snow', 'rain', 'fog',
        'depth_noise', 'occlusion'
    severity: severity level, must be one of the keys 1-5

    Returns:
    --------
    corrupted_dataset: copy of the dataset with the corruption installed as
        its transform (any transform already on the copy is replaced)

    Raises:
    -------
    ValueError: for an unknown corruption_type
    KeyError: for a severity outside 1-5
    """
    # The caller's dataset must keep its own transform, so work on a copy
    corrupted_dataset = copy.deepcopy(dataset)

    # corruption name -> (keyword of the matching PointCloudCorruptor.add_<name>
    # method, strength for each severity level)
    schedule = {
        'gaussian_noise': ('sigma', {1: 0.01, 2: 0.02, 3: 0.03, 4: 0.04, 5: 0.05}),
        'snow': ('density', {1: 0.05, 2: 0.1, 3: 0.15, 4: 0.2, 5: 0.3}),
        'rain': ('density', {1: 0.05, 2: 0.1, 3: 0.15, 4: 0.2, 5: 0.3}),
        'fog': ('density', {1: 0.05, 2: 0.1, 3: 0.2, 4: 0.3, 5: 0.4}),
        'depth_noise': ('k', {1: 0.01, 2: 0.02, 3: 0.04, 4: 0.06, 5: 0.1}),
        'occlusion': ('ratio', {1: 0.1, 2: 0.15, 3: 0.2, 4: 0.25, 5: 0.3}),
    }

    selected = None
    for name, spec in schedule.items():
        if corruption_type == name:
            selected = (name, spec)
            break
    if selected is None:
        raise ValueError(f"Unsupported corruption type: {corruption_type}")

    name, (keyword, levels) = selected
    strength = levels[severity]
    method_name = 'add_' + name

    def transform(x):
        # The corruptor method is looked up only when a sample is transformed
        corrupt = getattr(PointCloudCorruptor, method_name)
        return corrupt(x, **{keyword: strength})

    # Set the transform in the corrupted dataset
    corrupted_dataset.transform = transform

    return corrupted_dataset

In [11]:
import torch
import numpy as np
import copy
from tqdm import tqdm

class PointCloudAttacker:
    """
    Adversarial attacks and distortion metrics for point cloud classifiers.

    Every attack accepts a batch in (B, N, 3) or (B, 3, N) layout and returns the
    adversarial batch in the same layout as its input. A 3-D tensor whose second
    dimension equals 3 is treated as (B, 3, N). Internally every attack works on,
    and calls the model with, a (B, N, 3) tensor.

    NOTE(review): the attacks always feed the model (B, N, 3) tensors regardless of
    the caller's layout - confirm this matches what the target models expect.
    """

    def __init__(self, model, device=None):
        """
        Class for implementing various adversarial attacks on point cloud models

        Parameters:
        -----------
        model: torch.nn.Module, the target model to attack (moved to ``device`` and
            switched to eval mode here)
        device: torch.device, device to perform computations on; defaults to CUDA
            when available, otherwise CPU
        """
        self.model = model
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    # ======== Layout helpers ========

    @staticmethod
    def _to_points_last(points):
        """Return a detached copy of ``points`` in (B, N, 3) layout plus a flag telling whether it was transposed."""
        was_channels_first = points.dim() == 3 and points.shape[1] == 3
        copied = points.clone().detach()
        if was_channels_first:
            copied = copied.transpose(2, 1)
        return copied, was_channels_first

    @staticmethod
    def _restore_layout(points, was_channels_first):
        """Undo the transpose applied by ``_to_points_last`` so the output matches the caller's layout."""
        return points.transpose(2, 1) if was_channels_first else points

    # ======== White-box Attacks ========

    def fgsm_attack(self, points, labels, epsilon=0.03):
        """
        Fast Gradient Sign Method attack (single signed-gradient step)

        Parameters:
        -----------
        points: torch.Tensor of shape (B, N, 3) or (B, 3, N)
        labels: torch.Tensor of shape (B,)
        epsilon: float, attack strength (step size along the gradient sign)

        Returns:
        --------
        perturbed_points: torch.Tensor, adversarial examples in the input layout
        """
        x, was_channels_first = self._to_points_last(points)
        x.requires_grad_(True)

        # Forward / backward pass to get the loss gradient w.r.t. the input
        outputs = self.model(x)
        loss = torch.nn.functional.cross_entropy(outputs, labels)
        self.model.zero_grad()
        loss.backward()

        # Step in the direction that increases the loss; no projection is applied
        perturbed_points = (x + epsilon * x.grad.sign()).detach()
        return self._restore_layout(perturbed_points, was_channels_first)

    def pgd_attack(self, points, labels, epsilon=0.03, alpha=None, num_iter=20):
        """
        Projected Gradient Descent attack (L-infinity ball of radius epsilon)

        Parameters:
        -----------
        points: torch.Tensor of shape (B, N, 3) or (B, 3, N)
        labels: torch.Tensor of shape (B,)
        epsilon: float, attack strength (maximum per-coordinate displacement)
        alpha: float, step size (if None, will be set to epsilon/4)
        num_iter: int, number of iterations

        Returns:
        --------
        perturbed_points: torch.Tensor, adversarial examples in the input layout
        """
        if alpha is None:
            alpha = epsilon / 4

        original_points, was_channels_first = self._to_points_last(points)
        perturbed_points = original_points.clone()

        for _ in range(num_iter):
            # The tensor produced by the previous projection is a fresh leaf without grad
            perturbed_points.requires_grad_(True)

            outputs = self.model(perturbed_points)
            loss = torch.nn.functional.cross_entropy(outputs, labels)

            self.model.zero_grad()
            loss.backward()

            with torch.no_grad():
                stepped = perturbed_points + alpha * perturbed_points.grad.sign()
                # Project back onto the epsilon ball around the clean input
                delta = torch.clamp(stepped - original_points, -epsilon, epsilon)
                perturbed_points = original_points + delta

        return self._restore_layout(perturbed_points.detach(), was_channels_first)

    def cw_attack(self, points, labels, confidence=0, lr=0.01, num_iter=100):
        """
        Carlini & Wagner style untargeted attack

        Minimizes the hinge ``max(Z_true - max_{i != true} Z_i + confidence, 0)`` plus a
        small L2 penalty on the perturbation, so the true-class logit is pushed below
        the best competing logit by at least ``confidence``.

        Parameters:
        -----------
        points: torch.Tensor of shape (B, N, 3) or (B, 3, N)
        labels: torch.Tensor of shape (B,), the true labels to move away from
        confidence: float, confidence parameter κ
        lr: float, learning rate
        num_iter: int, number of iterations

        Returns:
        --------
        perturbed_points: torch.Tensor, adversarial examples in the input layout
            (the final iterate; no best-example tracking is performed)
        """
        original_points, was_channels_first = self._to_points_last(points)

        # The perturbation lives on the same device as the points it is added to
        delta = torch.zeros_like(original_points, requires_grad=True)
        optimizer = torch.optim.Adam([delta], lr=lr)
        label_index = labels.unsqueeze(1)

        for _ in range(num_iter):
            optimizer.zero_grad()

            outputs = self.model(original_points + delta)

            # Logit of the true class and the largest logit among all other classes
            true_logit = outputs.gather(1, label_index).squeeze(1)
            best_other = outputs.scatter(1, label_index, float('-inf')).max(dim=1)[0]

            # Untargeted hinge: positive while the true class still wins by less than κ.
            # (The reversed difference would reinforce the correct prediction.)
            loss = (true_logit - best_other + confidence).clamp(min=0).mean()

            # Regularization term keeping the perturbation small (weight 0.01)
            loss = loss + 0.01 * torch.norm(delta, dim=[1, 2]).mean()

            loss.backward()
            optimizer.step()

        perturbed_points = (original_points + delta).detach()
        return self._restore_layout(perturbed_points, was_channels_first)

    # ======== Black-box Attacks ========

    def transfer_attack(self, points, labels, surrogate_model, attack_type='pgd', **attack_params):
        """
        Transfer attack using a surrogate model

        The surrogate temporarily replaces ``self.model`` while the adversarial examples
        are generated; the original model is always restored, even if the attack raises.
        The surrogate is used as-is: it is not moved to ``self.device`` nor switched to
        eval mode here.

        Parameters:
        -----------
        points: torch.Tensor of shape (B, N, 3) or (B, 3, N)
        labels: torch.Tensor of shape (B,)
        surrogate_model: torch.nn.Module, surrogate model to generate adversarial examples
        attack_type: str, type of attack to use ('fgsm', 'pgd', 'cw')
        attack_params: dict, parameters for the attack

        Returns:
        --------
        perturbed_points: torch.Tensor, adversarial examples

        Raises:
        -------
        ValueError: if ``attack_type`` is not supported
        """
        attack_fns = {
            'fgsm': self.fgsm_attack,
            'pgd': self.pgd_attack,
            'cw': self.cw_attack,
        }
        # Validate before touching self.model so a bad name cannot leave the surrogate installed
        if attack_type not in attack_fns:
            raise ValueError(f"Unsupported attack type: {attack_type}")

        original_model = self.model
        self.model = surrogate_model
        try:
            return attack_fns[attack_type](points, labels, **attack_params)
        finally:
            self.model = original_model

    # ======== Point Cloud-Specific Attacks ========

    def point_perturbation_attack(self, points, labels, max_displacement=0.05, num_iter=50, lr=0.01):
        """
        Point perturbation attack with maximum displacement constraints

        Maximizes the cross-entropy loss with Adam while every coordinate shift is
        clamped to ``[-max_displacement, max_displacement]``.

        Parameters:
        -----------
        points: torch.Tensor of shape (B, N, 3) or (B, 3, N)
        labels: torch.Tensor of shape (B,)
        max_displacement: float, maximum displacement per coordinate
        num_iter: int, number of iterations
        lr: float, learning rate

        Returns:
        --------
        perturbed_points: torch.Tensor, adversarial examples in the input layout
        """
        original_points, was_channels_first = self._to_points_last(points)

        delta = torch.zeros_like(original_points, requires_grad=True)
        optimizer = torch.optim.Adam([delta], lr=lr)

        for _ in range(num_iter):
            optimizer.zero_grad()

            bounded_delta = torch.clamp(delta, -max_displacement, max_displacement)
            outputs = self.model(original_points + bounded_delta)

            # Negated cross-entropy: the optimizer minimizes, the attack maximizes the loss
            loss = -torch.nn.functional.cross_entropy(outputs, labels)

            loss.backward()
            optimizer.step()

        # The raw variable may have drifted outside the bound, so clamp once more
        bounded_delta = torch.clamp(delta, -max_displacement, max_displacement)
        perturbed_points = (original_points + bounded_delta).detach()
        return self._restore_layout(perturbed_points, was_channels_first)

    def point_addition_removal(self, points, labels, ratio=0.1, mode='removal'):
        """
        Point addition or removal attack (random, gradient-free)

        'removal' drops a random ``ratio`` of the points and pads the cloud back to N by
        duplicating surviving points. 'addition' replaces a random ``ratio`` of the
        points with noisy copies (Gaussian, std 0.1) of randomly chosen points. The
        number of points per cloud is therefore unchanged in both modes.

        Parameters:
        -----------
        points: torch.Tensor of shape (B, N, 3) or (B, 3, N)
        labels: torch.Tensor of shape (B,) (unused; kept for a uniform attack signature)
        ratio: float in [0, 1], fraction of points to add/remove
        mode: str, 'addition' or 'removal'

        Returns:
        --------
        perturbed_points: torch.Tensor, adversarial examples in the input layout

        Raises:
        -------
        ValueError: if ``mode`` is unknown, ``ratio`` is outside [0, 1] or ``points``
            is not a 3-D tensor
        """
        if mode not in ('removal', 'addition'):
            raise ValueError(f"Unsupported mode: {mode}")
        if not 0 <= ratio <= 1:
            raise ValueError(f"ratio must be in [0, 1], got {ratio}")
        if points.dim() != 3:
            raise ValueError(f"points must be a 3-D tensor, got shape {tuple(points.shape)}")

        # Work per cloud on [N, 3] so that indexing selects points, not coordinate channels
        clouds, was_channels_first = self._to_points_last(points)
        num_points = clouds.shape[1]
        num_changed = int(num_points * ratio)
        device = clouds.device

        perturbed_clouds = []
        for cloud in clouds:
            if mode == 'removal':
                # Always keep at least one point so the padding step has something to copy
                num_kept = max(1, num_points - num_changed)
                kept = cloud[torch.randperm(num_points, device=device)[:num_kept]]

                num_missing = num_points - num_kept
                if num_missing > 0:
                    pad_indices = torch.randint(0, num_kept, (num_missing,), device=device)
                    kept = torch.cat([kept, kept[pad_indices]], dim=0)
                perturbed_clouds.append(kept)
            else:
                # New points are noisy copies of randomly chosen existing points
                source_indices = torch.randint(0, num_points, (num_changed,), device=device)
                sources = cloud[source_indices]
                new_points = sources + torch.randn_like(sources) * 0.1

                # Drop as many original points as were added to keep the total constant
                keep_indices = torch.randperm(num_points, device=device)[:num_points - num_changed]
                perturbed_clouds.append(torch.cat([cloud[keep_indices], new_points], dim=0))

        perturbed_points = torch.stack(perturbed_clouds, dim=0)
        return self._restore_layout(perturbed_points, was_channels_first)

    # ======== Evaluation Metrics ========

    @staticmethod
    def _nearest_sq_dists(x, y):
        """Return, for (B, N, 3) and (B, M, 3) inputs, the squared distance from each point to its nearest neighbour in the other cloud."""
        diff = x.unsqueeze(2) - y.unsqueeze(1)   # (B, N, M, 3)
        sq_dist = torch.sum(diff ** 2, dim=3)    # (B, N, M)
        nearest_in_y = torch.min(sq_dist, dim=2)[0]  # (B, N)
        nearest_in_x = torch.min(sq_dist, dim=1)[0]  # (B, M)
        return nearest_in_y, nearest_in_x

    @staticmethod
    def chamfer_distance(x, y):
        """
        Calculate the symmetric Chamfer distance between two point clouds

        Uses squared Euclidean distances: the mean nearest-neighbour distance from x
        to y plus the mean nearest-neighbour distance from y to x.

        Parameters:
        -----------
        x: torch.Tensor of shape (B, N, 3)
        y: torch.Tensor of shape (B, M, 3)

        Returns:
        --------
        distance: torch.Tensor of shape (B,)
        """
        nearest_in_y, nearest_in_x = PointCloudAttacker._nearest_sq_dists(x, y)
        return torch.mean(nearest_in_y, dim=1) + torch.mean(nearest_in_x, dim=1)

    @staticmethod
    def hausdorff_distance(x, y):
        """
        Calculate the symmetric Hausdorff distance between two point clouds

        Uses squared Euclidean distances (no square root is taken): the larger of the
        two directed worst-case nearest-neighbour distances.

        Parameters:
        -----------
        x: torch.Tensor of shape (B, N, 3)
        y: torch.Tensor of shape (B, M, 3)

        Returns:
        --------
        distance: torch.Tensor of shape (B,)
        """
        nearest_in_y, nearest_in_x = PointCloudAttacker._nearest_sq_dists(x, y)
        return torch.max(torch.max(nearest_in_y, dim=1)[0], torch.max(nearest_in_x, dim=1)[0])

class AdversarialEvaluator:
    """
    Evaluates a set of models against the attacks implemented by ``PointCloudAttacker``.

    Each public method returns nested dictionaries keyed by model name; accuracies are
    fractions in [0, 1] and distortion metrics are averaged over batches.
    """

    def __init__(self, model_list, device=None):
        """
        Class for evaluating model robustness against adversarial attacks

        Parameters:
        -----------
        model_list: dict, dictionary of models to evaluate {name: model}
        device: torch.device, device to perform computations on; defaults to CUDA
            when available, otherwise CPU
        """
        self.models = model_list
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.attackers = {name: PointCloudAttacker(model, self.device) for name, model in self.models.items()}

    def _evaluate_attack(self, model, attack_fn, batches, attacker=None, metrics=(), max_samples=None):
        """
        Run one attack configuration over an iterable of (data, labels) batches.

        Parameters:
        -----------
        model: torch.nn.Module, model whose clean/adversarial accuracy is measured
        attack_fn: callable ``(data, labels) -> perturbed_data``
        batches: iterable yielding (data, labels) pairs
        attacker: object providing ``<metric>_distance(x, y)`` for every name in ``metrics``
        metrics: tuple of str, distortion metrics to average ('chamfer', 'hausdorff')
        max_samples: int or None, stop before a new batch once this many samples were seen

        Returns:
        --------
        summary: dict with 'clean_acc', 'adv_acc' and one '<metric>_dist' per metric,
            or None when no sample was evaluated
        """
        clean_correct = 0
        adv_correct = 0
        total = 0
        distances = {metric: [] for metric in metrics}

        for data, labels in batches:
            if max_samples is not None and total >= max_samples:
                break

            data, labels = data.to(self.device), labels.to(self.device)

            # Clean accuracy
            with torch.no_grad():
                _, predicted = model(data).max(1)
                clean_correct += predicted.eq(labels).sum().item()

            # The attack needs gradients, so it runs outside no_grad
            perturbed_data = attack_fn(data, labels)

            # Adversarial accuracy
            with torch.no_grad():
                _, predicted = model(perturbed_data).max(1)
                adv_correct += predicted.eq(labels).sum().item()

            if distances:
                # The metrics expect (B, N, 3); a second dimension of 3 marks (B, 3, N) input
                if data.dim() == 3 and data.shape[1] == 3:
                    clean_points = data.transpose(2, 1)
                    adv_points = perturbed_data.transpose(2, 1)
                else:
                    clean_points = data
                    adv_points = perturbed_data
                for metric, values in distances.items():
                    distance_fn = getattr(attacker, f'{metric}_distance')
                    values.append(distance_fn(clean_points, adv_points).mean().item())

            total += labels.size(0)

        if total == 0:
            return None

        summary = {
            'clean_acc': clean_correct / total,
            'adv_acc': adv_correct / total
        }
        for metric, values in distances.items():
            summary[f'{metric}_dist'] = np.mean(values)
        return summary

    def evaluate_whitebox_attacks(self, dataloader, attack_params=None):
        """
        Evaluate models against white-box attacks (FGSM, PGD, C&W)

        Parameters:
        -----------
        dataloader: torch.utils.data.DataLoader, test data
        attack_params: dict, parameters for different attacks with optional keys
            'fgsm' ({'epsilon': [...]}), 'pgd' ({'epsilon': [...], 'num_iter': [...],
            'alpha': float or None}) and 'cw' ({'confidence': [...], 'num_iter': int}).
            A missing attack entry is skipped. When 'alpha' is None the PGD step size
            is epsilon/4.

        Returns:
        --------
        results: dict, ``{model_name: {config_key: metrics}}`` where metrics holds
            'clean_acc', 'adv_acc', 'chamfer_dist' and 'hausdorff_dist'. Configurations
            for which the dataloader yielded no samples are omitted.
        """
        if attack_params is None:
            attack_params = {
                'fgsm': {'epsilon': [0.01, 0.03, 0.05, 0.1]},
                'pgd': {
                    'epsilon': [0.03, 0.05],
                    'num_iter': [10, 20, 50],
                    'alpha': None  # Will be set to epsilon/4
                },
                'cw': {'confidence': [0, 20, 40], 'num_iter': 100}
            }

        fgsm_cfg = attack_params.get('fgsm', {})
        pgd_cfg = attack_params.get('pgd', {})
        cw_cfg = attack_params.get('cw', {})
        both_metrics = ('chamfer', 'hausdorff')

        results = {}

        for model_name, attacker in self.attackers.items():
            model_results = {}
            model = self.models[model_name]
            model.eval()

            print(f"Evaluating {model_name} against white-box attacks...")

            # FGSM attacks
            for eps in fgsm_cfg.get('epsilon', []):
                summary = self._evaluate_attack(
                    model,
                    lambda data, labels, eps=eps: attacker.fgsm_attack(data, labels, epsilon=eps),
                    tqdm(dataloader, desc=f"FGSM ε={eps}"),
                    attacker=attacker,
                    metrics=both_metrics,
                )
                if summary is not None:
                    model_results[f'fgsm_eps_{eps}'] = summary

            # PGD attacks
            for eps in pgd_cfg.get('epsilon', []):
                alpha = pgd_cfg.get('alpha')
                if alpha is None:
                    alpha = eps / 4  # Default step size
                for iters in pgd_cfg.get('num_iter', []):
                    summary = self._evaluate_attack(
                        model,
                        lambda data, labels, eps=eps, alpha=alpha, iters=iters: attacker.pgd_attack(
                            data, labels, epsilon=eps, alpha=alpha, num_iter=iters),
                        tqdm(dataloader, desc=f"PGD ε={eps}, iter={iters}"),
                        attacker=attacker,
                        metrics=both_metrics,
                    )
                    if summary is not None:
                        model_results[f'pgd_eps_{eps}_iter_{iters}'] = summary

            # C&W attacks - computation intensive, so stop once 200 samples were processed
            cw_iters = cw_cfg.get('num_iter', 100)
            for conf in cw_cfg.get('confidence', []):
                summary = self._evaluate_attack(
                    model,
                    lambda data, labels, conf=conf: attacker.cw_attack(
                        data, labels, confidence=conf, num_iter=cw_iters),
                    tqdm(dataloader, desc=f"C&W κ={conf}"),
                    attacker=attacker,
                    metrics=both_metrics,
                    max_samples=200,
                )
                if summary is not None:
                    model_results[f'cw_conf_{conf}'] = summary

            results[model_name] = model_results

        return results

    def evaluate_point_specific_attacks(self, dataloader, perturbation_max=(0.02, 0.05, 0.1),
                                        removal_ratios=(0.05, 0.1, 0.15)):
        """
        Evaluate models against point cloud-specific attacks

        Parameters:
        -----------
        dataloader: torch.utils.data.DataLoader, test data
        perturbation_max: iterable of float, maximum displacements for the point
            perturbation attack
        removal_ratios: iterable of float, removal ratios for the point removal attack

        Returns:
        --------
        results: dict, ``{model_name: {config_key: metrics}}``. Perturbation entries
            hold 'clean_acc', 'adv_acc' and 'chamfer_dist'; removal entries hold
            'clean_acc' and 'adv_acc'. Configurations without samples are omitted.
        """
        results = {}

        for model_name, attacker in self.attackers.items():
            model_results = {}
            model = self.models[model_name]
            model.eval()

            print(f"Evaluating {model_name} against point cloud-specific attacks...")

            # Point perturbation attacks
            for max_disp in perturbation_max:
                summary = self._evaluate_attack(
                    model,
                    lambda data, labels, max_disp=max_disp: attacker.point_perturbation_attack(
                        data, labels, max_displacement=max_disp),
                    tqdm(dataloader, desc=f"Perturbation max={max_disp}"),
                    attacker=attacker,
                    metrics=('chamfer',),
                )
                if summary is not None:
                    model_results[f'perturb_max_{max_disp}'] = summary

            # Point removal attacks
            for ratio in removal_ratios:
                summary = self._evaluate_attack(
                    model,
                    lambda data, labels, ratio=ratio: attacker.point_addition_removal(
                        data, labels, ratio=ratio, mode='removal'),
                    tqdm(dataloader, desc=f"Removal ratio={ratio}"),
                )
                if summary is not None:
                    model_results[f'removal_ratio_{ratio}'] = summary

            results[model_name] = model_results

        return results

    def cross_validate_attacks(self, dataset, attack_type, folds=5, **attack_params):
        """
        Evaluate an attack fold by fold over a dataset

        The dataset is split into ``folds`` contiguous index ranges that together cover
        every sample (the first ``len(dataset) % folds`` folds hold one extra sample).
        No training takes place; each fold is only used as an evaluation set.

        Parameters:
        -----------
        dataset: torch.utils.data.Dataset, full dataset
        attack_type: str, attack type to evaluate ('fgsm', 'pgd', 'cw', 'perturb')
        folds: int, number of folds for cross-validation
        attack_params: dict, parameters for the attack

        Returns:
        --------
        cv_results: dict, ``{model_name: [fold_result, ..., summary]}`` where each
            fold_result holds 'fold', 'clean_acc', 'adv_acc' and the final element
            holds the mean/std of both accuracies across folds

        Raises:
        -------
        ValueError: if ``attack_type`` is unsupported or ``folds`` is not between 1 and
            ``len(dataset)``
        """
        attack_method_names = {
            'fgsm': 'fgsm_attack',
            'pgd': 'pgd_attack',
            'cw': 'cw_attack',
            'perturb': 'point_perturbation_attack',
        }
        if attack_type not in attack_method_names:
            raise ValueError(f"Unsupported attack type for cross-validation: {attack_type}")

        dataset_size = len(dataset)
        if folds < 1 or folds > dataset_size:
            raise ValueError(f"folds must be between 1 and the dataset size ({dataset_size}), got {folds}")

        # Contiguous folds covering every index; identical to size // folds when it divides evenly
        fold_indices = [chunk.tolist() for chunk in np.array_split(np.arange(dataset_size), folds)]

        cv_results = {model_name: [] for model_name in self.models.keys()}

        for fold, test_indices in enumerate(fold_indices):
            print(f"Cross-validation fold {fold+1}/{folds}")

            test_sampler = torch.utils.data.SubsetRandomSampler(test_indices)
            test_loader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=test_sampler)

            for model_name, model in self.models.items():
                model.eval()
                attack_method = getattr(self.attackers[model_name], attack_method_names[attack_type])

                summary = self._evaluate_attack(
                    model,
                    lambda data, labels: attack_method(data, labels, **attack_params),
                    test_loader,
                )
                if summary is None:
                    continue

                cv_results[model_name].append({
                    'fold': fold,
                    'clean_acc': summary['clean_acc'],
                    'adv_acc': summary['adv_acc']
                })

        # Calculate mean and std across folds
        for model_name in cv_results:
            clean_accs = [fold['clean_acc'] for fold in cv_results[model_name]]
            adv_accs = [fold['adv_acc'] for fold in cv_results[model_name]]

            cv_results[model_name].append({
                'mean_clean_acc': np.mean(clean_accs),
                'std_clean_acc': np.std(clean_accs),
                'mean_adv_acc': np.mean(adv_accs),
                'std_adv_acc': np.std(adv_accs)
            })

        return cv_results

In [12]:
import torch
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support
from sklearn.metrics import accuracy_score, roc_auc_score, roc_curve, auc
import time
from tqdm.notebook import tqdm
from scipy import stats
import pandas as pd
from torch.profiler import profile, record_function, ProfilerActivity

import matplotlib.pyplot as plt

class EvaluationFramework:
    def __init__(self, models, device=None, num_classes=40):
        """
        Comprehensive evaluation framework for 3D point cloud models
        
        Parameters:
        -----------
        models: dict, dictionary of models to evaluate {name: model}
        device: torch.device, device to perform computations on
        num_classes: int, number of classes in the dataset
        """
        self.models = models
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.num_classes = num_classes
        self.class_names = None  # Should be set when evaluating with actual dataset
        
        # Move models to device
        for name, model in self.models.items():
            model.to(self.device)
    
    # ======== Performance Metrics ========
    
    def evaluate_classification_performance(self, dataloader, verbose=True):
        """
        Evaluate classification performance metrics for every registered model

        Parameters:
        -----------
        dataloader: DataLoader, test dataloader yielding (points, labels) batches
        verbose: bool, whether to show a progress bar and print results

        Returns:
        --------
        results: dict, metrics for each model. Each entry holds
            'top1_accuracy', 'top5_accuracy', per-class precision/recall/F1,
            macro-averaged precision/recall/F1, 'confusion_matrix',
            'all_preds' (ranked top-k class indices per sample) and 'all_labels'.
            When the model outputs fewer than 5 classes, 'top5_accuracy' is
            computed over all available classes instead.

        Raises:
        -------
        ValueError: if the dataloader yields no batches
        """
        results = {}
        class_labels = list(range(self.num_classes))

        for name, model in self.models.items():
            model.eval()
            all_preds = []
            all_labels = []

            with torch.no_grad():
                for points, labels in tqdm(dataloader, desc=f"Evaluating {name}", disable=not verbose):
                    points, labels = points.to(self.device), labels.to(self.device)

                    # Model prediction
                    logits = model(points)

                    # torch.topk raises when k exceeds the size of the class
                    # dimension, so cap k for models with fewer than 5 outputs.
                    top_k = min(5, logits.size(1))
                    _, preds = torch.topk(logits, k=top_k, dim=1)
                    all_preds.append(preds.cpu().numpy())
                    all_labels.append(labels.cpu().numpy())

            if not all_preds:
                raise ValueError(
                    f"Cannot evaluate {name}: the dataloader yielded no batches"
                )

            # Concatenate results
            all_preds = np.concatenate(all_preds, axis=0)
            all_labels = np.concatenate(all_labels, axis=0)
            top1_preds = all_preds[:, 0]

            # Calculate metrics
            top1_acc = (top1_preds == all_labels).mean()
            top5_acc = np.any(all_preds == all_labels.reshape(-1, 1), axis=1).mean()

            # Per-class metrics. zero_division=0 yields the same values as the
            # library default but without a warning for every class that is
            # absent from the labels or never predicted.
            precision, recall, f1, _ = precision_recall_fscore_support(
                all_labels, top1_preds, average=None, labels=class_labels, zero_division=0
            )

            # Macro averages are taken over the classes that actually occur in
            # the labels/predictions (no explicit label list is passed).
            avg_precision, avg_recall, avg_f1, _ = precision_recall_fscore_support(
                all_labels, top1_preds, average='macro', zero_division=0
            )

            # Confusion matrix (rows: true class, columns: predicted class)
            conf_mat = confusion_matrix(all_labels, top1_preds, labels=class_labels)

            # Store results
            results[name] = {
                'top1_accuracy': top1_acc,
                'top5_accuracy': top5_acc,
                'per_class_precision': precision,
                'per_class_recall': recall,
                'per_class_f1': f1,
                'avg_precision': avg_precision,
                'avg_recall': avg_recall,
                'avg_f1': avg_f1,
                'confusion_matrix': conf_mat,
                'all_preds': all_preds,
                'all_labels': all_labels
            }

            if verbose:
                print(f"\n{name} Classification Results:")
                print(f"Top-1 Accuracy: {top1_acc:.4f}")
                print(f"Top-5 Accuracy: {top5_acc:.4f}")
                print(f"Average Precision: {avg_precision:.4f}")
                print(f"Average Recall: {avg_recall:.4f}")
                print(f"Average F1 Score: {avg_f1:.4f}")

        return results
    
    def evaluate_robustness_metrics(self, clean_dataloader, corrupted_loaders, verbose=True):
        """
        Evaluate robustness metrics across different corruptions
        
        Parameters:
        -----------
        clean_dataloader: DataLoader, clean test data
        corrupted_loaders: dict, {corruption_name: {severity: dataloader}}
        verbose: bool, whether to print results
        
        Returns:
        --------
        results: dict, robustness metrics for each model
        """
        results = {}
        
        # Evaluate clean accuracy first
        clean_results = self.evaluate_classification_performance(clean_dataloader, verbose=False)
        
        for name, model in self.models.items():
            model.eval()
            model_results = {
                'clean_accuracy': clean_results[name]['top1_accuracy'],
                'corruption_results': {}
            }
            
            # Evaluate on each corruption type and severity
            corruption_errors = []
            
            for corruption, severity_loaders in corrupted_loaders.items():
                corruption_accs = []
                
                for severity, loader in severity_loaders.items():
                    # Evaluate on corrupted data
                    all_preds = []
                    all_labels = []
                    
                    with torch.no_grad():
                        for points, labels in loader:
                            points, labels = points.to(self.device), labels.to(self.device)
                            logits = model(points)
                            preds = logits.argmax(dim=1)
                            all_preds.append(preds.cpu().numpy())
                            all_labels.append(labels.cpu().numpy())
                    
                    # Calculate accuracy
                    all_preds = np.concatenate(all_preds, axis=0)
                    all_labels = np.concatenate(all_labels, axis=0)
                    accuracy = (all_preds == all_labels).mean()
                    
                    # Store results
                    model_results['corruption_results'][(corruption, severity)] = {
                        'accuracy': accuracy,
                        'error': 1.0 - accuracy
                    }
                    
                    corruption_accs.append(accuracy)
                    corruption_errors.append(1.0 - accuracy)
                
                # Calculate clean-corrupted accuracy gap (CCAG) for this corruption
                avg_corrupt_acc = np.mean(corruption_accs)
                ccag = clean_results[name]['top1_accuracy'] - avg_corrupt_acc
                
                model_results['corruption_results'][corruption] = {
                    'avg_accuracy': avg_corrupt_acc,
                    'ccag': ccag
                }
            
            # Calculate average corruption error (ACE)
            ace = np.mean(corruption_errors)
            
            # Calculate effective robustness (ER) - relative to baseline performance
            # This would typically compare to a baseline model, but for now we'll use relative to clean performance
            er = clean_results[name]['top1_accuracy'] - ace
            
            # Calculate area under robustness curve (AURC)
            # For simplicity, we'll calculate area under error vs corruption severity curve
            aurc = np.trapz(corruption_errors) / len(corruption_errors)
            
            model_results['summary'] = {
                'ace': ace,
                'er': er,
                'aurc': aurc
            }
            
            results[name] = model_results
            
            if verbose:
                print(f"\n{name} Robustness Results:")
                print(f"Clean Accuracy: {clean_results[name]['top1_accuracy']:.4f}")
                print(f"Average Corruption Error (ACE): {ace:.4f}")
                print(f"Effective Robustness (ER): {er:.4f}")
                print(f"Area Under Robustness Curve (AURC): {aurc:.4f}")
                
        return results
    
    def benchmark_efficiency(self, sample_input, num_runs=100, verbose=True):
        """
        Benchmark the efficiency metrics of models
        
        Parameters:
        -----------
        sample_input: torch.Tensor, sample input for models
        num_runs: int, number of runs for time measurement
        verbose: bool, whether to print results
        
        Returns:
        --------
        results: dict, efficiency metrics for each model
        """
        results = {}
        sample_input = sample_input.to(self.device)
        
        for name, model in self.models.items():
            model.eval()
            model_results = {}
            
            # Measure inference time
            start_time = time.time()
            with torch.no_grad():
                for _ in range(num_runs):
                    _ = model(sample_input)
            end_time = time.time()
            avg_inference_time = (end_time - start_time) / num_runs
            
            # Measure memory consumption
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.empty_cache()
            with torch.no_grad():
                _ = model(sample_input)
            memory_usage = torch.cuda.max_memory_allocated() / (1024 ** 2)  # MB
            
            # Estimate FLOPs using PyTorch profiler
            with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                         with_flops=True) as prof:
                with record_function("model_inference"):
                    _ = model(sample_input)
            
            flops_estimate = sum(evt.flops for evt in prof.key_averages() if evt.flops > 0)
            
            model_results = {
                'inference_time_ms': avg_inference_time * 1000,
                'memory_usage_mb': memory_usage,
                'estimated_flops': flops_estimate
            }
            
            results[name] = model_results
            
            if verbose:
                print(f"\n{name} Efficiency Metrics:")
                print(f"Inference Time: {avg_inference_time * 1000:.2f} ms")
                print(f"Memory Usage: {memory_usage:.2f} MB")
                print(f"Estimated FLOPs: {flops_estimate:.2e}")
                
        return results
    
    # ======== Analysis Procedures ========
    
    def ablation_study(self, model_name, components, dataloader):
        """
        Perform ablation study on model components
        
        Parameters:
        -----------
        model_name: str, name of the model to ablate
        components: list of tuples, (component_name, ablation_function)
        dataloader: DataLoader, test data
        
        Returns:
        --------
        results: dict, performance with each component ablated
        """
        # Get original model
        original_model = self.models[model_name]
        original_state = copy.deepcopy(original_model.state_dict())
        
        # Evaluate original performance
        original_performance = self.evaluate_classification_performance(
            dataloader, verbose=False)[model_name]['top1_accuracy']
        
        results = {'original': original_performance}
        
        # Test each ablation
        for component_name, ablation_fn in components:
            # Apply ablation
            ablation_fn(original_model)
            
            # Evaluate ablated model
            ablated_performance = self.evaluate_classification_performance(
                dataloader, verbose=False)[model_name]['top1_accuracy']
            
            # Restore original model
            original_model.load_state_dict(original_state)
            
            # Store results
            results[component_name] = ablated_performance
            
        return results
    
    def parameter_sensitivity_analysis(self, model_name, parameter_ranges, dataloader):
        """
        Analyze sensitivity to critical parameters
        
        Parameters:
        -----------
        model_name: str, name of the model to analyze
        parameter_ranges: dict, {param_name: list of values}
        dataloader: DataLoader, test data
        
        Returns:
        --------
        results: dict, performance across parameter values
        """
        # Get original model
        original_model = self.models[model_name]
        original_state = copy.deepcopy(original_model.state_dict())
        
        results = {}
        
        # Test each parameter
        for param_name, param_values in parameter_ranges.items():
            param_results = []
            
            for value in param_values:
                # Set parameter value (this would depend on model architecture)
                # Example: setattr(original_model, param_name, value)
                
                # Evaluate model with this parameter value
                performance = self.evaluate_classification_performance(
                    dataloader, verbose=False)[model_name]['top1_accuracy']
                param_results.append((value, performance))
                
                # Restore original model
                original_model.load_state_dict(original_state)
            
            results[param_name] = param_results
            
        return results
    
    def attack_transferability_analysis(self, source_models, target_models, dataloader, attack_method, **attack_params):
        """
        Analyze the transferability of adversarial attacks between models

        Adversarial examples are crafted against each source model and then
        scored on every other target model.

        Parameters:
        -----------
        source_models: list, names of models to generate attacks from
        target_models: list, names of models to test attacks on
        dataloader: DataLoader, test data
        attack_method: function, called as
            attack_method(attacker, points, labels, **attack_params) where
            attacker is a PointCloudAttacker built around the source model
        attack_params: dict, parameters for the attack

        Returns:
        --------
        results: dict, {source_name: {target_name: metrics}} with metrics keys
            'clean_accuracy', 'adv_accuracy' and 'transfer_success_rate'.
            Pairs where source and target are the same model are skipped.
            'transfer_success_rate' is NaN when the target classifies no clean
            sample correctly (the normalisation is undefined); all three
            values are NaN when the dataloader is empty.
        """
        results = {}

        for source_name in source_models:
            source_model = self.models[source_name]
            # Craft attacks against the inference-mode network so that layers
            # such as batch norm or dropout do not behave as in training.
            source_model.eval()
            source_attacker = PointCloudAttacker(source_model, self.device)

            # Results for this source model
            source_results = {}

            for target_name in target_models:
                if source_name == target_name:
                    continue  # Skip self-targeting

                target_model = self.models[target_name]
                target_model.eval()

                # Track performance
                clean_correct = 0
                adv_correct = 0
                total = 0

                for points, labels in dataloader:
                    points, labels = points.to(self.device), labels.to(self.device)

                    # Generate adversarial examples on source model
                    adv_points = attack_method(source_attacker, points, labels, **attack_params)

                    # Evaluate performance on target model
                    with torch.no_grad():
                        # Clean performance
                        clean_outputs = target_model(points)
                        clean_preds = clean_outputs.argmax(dim=1)
                        clean_correct += (clean_preds == labels).sum().item()

                        # Adversarial performance
                        adv_outputs = target_model(adv_points)
                        adv_preds = adv_outputs.argmax(dim=1)
                        adv_correct += (adv_preds == labels).sum().item()

                    total += labels.size(0)

                # Calculate transfer success rate, guarding both divisions
                if total > 0:
                    clean_acc = clean_correct / total
                    adv_acc = adv_correct / total
                else:
                    clean_acc = float("nan")
                    adv_acc = float("nan")

                if total > 0 and clean_correct > 0:
                    # Normalized attack success: share of clean accuracy lost
                    transfer_success_rate = (clean_acc - adv_acc) / clean_acc
                else:
                    transfer_success_rate = float("nan")

                source_results[target_name] = {
                    'clean_accuracy': clean_acc,
                    'adv_accuracy': adv_acc,
                    'transfer_success_rate': transfer_success_rate
                }

            results[source_name] = source_results

        return results
    
    def statistical_significance_testing(self, model1_name, model2_name, dataloader, num_bootstrap=1000):
        """
        Perform statistical significance testing between two models
        
        Parameters:
        -----------
        model1_name: str, name of first model
        model2_name: str, name of second model
        dataloader: DataLoader, test data
        num_bootstrap: int, number of bootstrap samples
        
        Returns:
        --------
        results: dict, statistical test results
        """
        model1 = self.models[model1_name]
        model2 = self.models[model2_name]
        
        model1.eval()
        model2.eval()
        
        # Collect all predictions
        all_labels = []
        model1_preds = []
        model2_preds = []
        
        with torch.no_grad():
            for points, labels in dataloader:
                points, labels = points.to(self.device), labels.to(self.device)
                
                # Model 1 predictions
                outputs1 = model1(points)
                preds1 = outputs1.argmax(dim=1)
                
                # Model 2 predictions
                outputs2 = model2(points)
                preds2 = outputs2.argmax(dim=1)
                
                # Store results
                all_labels.append(labels.cpu().numpy())
                model1_preds.append(preds1.cpu().numpy())
                model2_preds.append(preds2.cpu().numpy())
        
        # Concatenate results
        all_labels = np.concatenate(all_labels, axis=0)
        model1_preds = np.concatenate(model1_preds, axis=0)
        model2_preds = np.concatenate(model2_preds, axis=0)
        
        # Calculate accuracy
        model1_correct = (model1_preds == all_labels)
        model2_correct = (model2_preds == all_labels)
        
        model1_acc = model1_correct.mean()
        model2_acc = model2_correct.mean()
        
        # Paired t-test
        t_stat, p_value = stats.ttest_rel(model1_correct, model2_correct)
        
        # Bootstrap confidence intervals
        bootstrap_diffs = []
        for _ in range(num_bootstrap):
            indices = np.random.choice(len(all_labels), len(all_labels), replace=True)
            model1_sample = model1_correct[indices].mean()
            model2_sample = model2_correct[indices].mean()
            bootstrap_diffs.append(model1_sample - model2_sample)
            
        bootstrap_diffs = np.array(bootstrap_diffs)
        ci_lower = np.percentile(bootstrap_diffs, 2.5)
        ci_upper = np.percentile(bootstrap_diffs, 97.5)
        
        results = {
            f"{model1_name}_accuracy": model1_acc,
            f"{model2_name}_accuracy": model2_acc,
            'accuracy_diff': model1_acc - model2_acc,
            't_statistic': t_stat,
            'p_value': p_value,
            'significant': p_value < 0.05,
            'bootstrap_ci': (ci_lower, ci_upper),
            'bootstrap_significant': (ci_lower > 0 or ci_upper < 0)  # CI doesn't contain 0
        }
        
        return results
    
    # ======== Real-world Validation ========
    
    def synthetic_to_real_transfer(self, synthetic_loader, real_loader):
        """
        Evaluate synthetic-to-real transfer performance
        
        Parameters:
        -----------
        synthetic_loader: DataLoader, synthetic test data
        real_loader: DataLoader, real-world test data
        
        Returns:
        --------
        results: dict, performance comparison on synthetic and real data
        """
        results = {}
        
        for name, model in self.models.items():
            model.eval()
            
            # Evaluate on synthetic data
            synthetic_performance = self.evaluate_classification_performance(
                synthetic_loader, verbose=False)[name]
            
            # Evaluate on real data
            real_performance = self.evaluate_classification_performance(
                real_loader, verbose=False)[name]
            
            # Calculate transfer gap
            transfer_gap = synthetic_performance['top1_accuracy'] - real_performance['top1_accuracy']
            
            results[name] = {
                'synthetic_accuracy': synthetic_performance['top1_accuracy'],
                'real_accuracy': real_performance['top1_accuracy'],
                'transfer_gap': transfer_gap
            }
            
        return results
    
    # ======== Visualization Methods ========
    
    def plot_confusion_matrix(self, model_name, results=None, figsize=(12, 10)):
        """
        Plot confusion matrix for a model
        
        Parameters:
        -----------
        model_name: str, name of the model
        results: dict, results from evaluate_classification_performance
        figsize: tuple, figure size
        """
        if results is None:
            raise ValueError("Results dictionary must be provided")
            
        if model_name not in results:
            raise ValueError(f"No results found for model {model_name}")
            
        conf_mat = results[model_name]['confusion_matrix']
        
        # Normalize confusion matrix
        conf_mat_norm = conf_mat.astype('float') / conf_mat.sum(axis=1)[:, np.newaxis]
        
        plt.figure(figsize=figsize)
        sns.heatmap(conf_mat_norm, annot=False, cmap='Blues', fmt='.2f', 
                    xticklabels=self.class_names if self.class_names else range(self.num_classes),
                    yticklabels=self.class_names if self.class_names else range(self.num_classes))
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title(f'Confusion Matrix - {model_name}')
        plt.tight_layout()
        plt.show()
    
    def plot_per_class_metrics(self, model_name, results=None, figsize=(15, 6)):
        """
        Plot per-class precision, recall and F1
        
        Parameters:
        -----------
        model_name: str, name of the model
        results: dict, results from evaluate_classification_performance
        figsize: tuple, figure size
        """
        if results is None:
            raise ValueError("Results dictionary must be provided")
            
        if model_name not in results:
            raise ValueError(f"No results found for model {model_name}")
            
        precision = results[model_name]['per_class_precision']
        recall = results[model_name]['per_class_recall']
        f1 = results[model_name]['per_class_f1']
        
        # Create dataframe for plotting
        metrics_df = pd.DataFrame({
            'Class': self.class_names if self.class_names else range(self.num_classes),
            'Precision': precision,
            'Recall': recall,
            'F1 Score': f1
        })
        
        # Melt the dataframe for easier plotting
        melted_df = pd.melt(metrics_df, id_vars=['Class'], var_name='Metric', value_name='Score')
        
        plt.figure(figsize=figsize)
        sns.barplot(x='Class', y='Score', hue='Metric', data=melted_df)
        plt.title(f'Per-Class Performance Metrics - {model_name}')
        plt.xticks(rotation=90)
        plt.ylim(0, 1)
        plt.tight_layout()
        plt.show()
    
    def plot_robustness_curves(self, robustness_results, corruption_types=None, figsize=(15, 6)):
        """
        Plot robustness curves across corruption types and severities
        
        Parameters:
        -----------
        robustness_results: dict, results from evaluate_robustness_metrics
        corruption_types: list, corruption types to plot (None for all)
        figsize: tuple, figure size
        """
        if corruption_types is None:
            # Extract all corruption types from first model
            first_model = list(robustness_results.keys())[0]
            corruption_types = set()
            for key in robustness_results[first_model]['corruption_results'].keys():
                if isinstance(key, tuple):
                    corruption_types.add(key[0])
            corruption_types = list(corruption_types)
        
        # Create a plot for each corruption type
        for corruption in corruption_types:
            plt.figure(figsize=figsize)
            
            for model_name, results in robustness_results.items():
                severities = []
                accuracies = []
                
                for key, val in results['corruption_results'].items():
                    if isinstance(key, tuple) and key[0] == corruption:
                        severity = key[1]
                        severities.append(severity)
                        accuracies.append(val['accuracy'])
                
                # Sort by severity
                sorted_indices = np.argsort(severities)
                severities = [severities[i] for i in sorted_indices]
                accuracies = [accuracies[i] for i in sorted_indices]
                
                # Plot
                plt.plot(severities, accuracies, marker='o', label=model_name)
                
                # Add clean performance as severity 0
                plt.plot([0], [results['clean_accuracy']], marker='x', color='black')
                
            plt.xlabel('Corruption Severity')
            plt.ylabel('Accuracy')
            plt.title(f'Model Robustness Under {corruption.capitalize()} Corruption')
            plt.legend()
            plt.grid(True, linestyle='--', alpha=0.7)
            plt.tight_layout()
            plt.show()
    
    def plot_efficiency_comparison(self, efficiency_results, metric='inference_time_ms', figsize=(10, 6)):
        """
        Plot efficiency comparison between models
        
        Parameters:
        -----------
        efficiency_results: dict, results from benchmark_efficiency
        metric: str, which efficiency metric to plot
        figsize: tuple, figure size
        """
        model_names = list(efficiency_results.keys())
        metric_values = [efficiency_results[name][metric] for name in model_names]
        
        plt.figure(figsize=figsize)
        bars = plt.bar(model_names, metric_values)
        
        # Add values on top of bars
        for bar in bars:
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width()/2., height,
                    f'{height:.2f}',
                    ha='center', va='bottom', rotation=0)
        
        plt.ylabel(metric.replace('_', ' ').title())
        plt.title(f'Model Efficiency Comparison - {metric.replace("_", " ").title()}')
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.show()
    
    def plot_ablation_results(self, ablation_results, figsize=(12, 6)):
        """
        Plot results of ablation study
        
        Parameters:
        -----------
        ablation_results: dict, results from ablation_study
        figsize: tuple, figure size
        """
        components = list(ablation_results.keys())
        values = list(ablation_results.values())
        
        plt.figure(figsize=figsize)
        bars = plt.bar(components, values)
        
        # Add values on top of bars
        for bar in bars:
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width()/2., height,
                    f'{height:.4f}',
                    ha='center', va='bottom', rotation=0)
        
        plt.ylabel('Accuracy')
        plt.title('Model Performance with Component Ablations')
        plt.xticks(rotation=45)
        plt.ylim(0, max(values) * 1.1)
        plt.tight_layout()
        plt.show()
    
    def plot_transferability_heatmap(self, transferability_results, figsize=(10, 8)):
        """
        Plot heatmap of attack transferability
        
        Parameters:
        -----------
        transferability_results: dict, results from attack_transferability_analysis
        figsize: tuple, figure size
        """
        source_models = list(transferability_results.keys())
        target_models = set()
        for source in source_models:
            target_models.update(transferability_results[source].keys())
        target_models = list(target_models)
        
        # Create transfer success rate matrix
        transfer_matrix = np.zeros((len(source_models), len(target_models)))
        
        for i, source in enumerate(source_models):
            for j, target in enumerate(target_models):
                if target in transferability_results[source]:
                    transfer_matrix[i, j] = transferability_results[source][target]['transfer_success_rate']
                else:
                    transfer_matrix[i, j] = np.nan  # Self-transfer
        
        plt.figure(figsize=figsize)
        mask = np.isnan(transfer_matrix)
        sns.heatmap(transfer_matrix, annot=True, cmap='YlOrRd', fmt='.2f', 
                    xticklabels=target_models, yticklabels=source_models,
                    mask=mask)
        plt.xlabel('Target Model')
        plt.ylabel('Source Model')
        plt.title('Attack Transferability Success Rate')
        plt.tight_layout()
        plt.show()

In [None]:
import torch.optim as optim
import matplotlib.pyplot as plt

class ExperimentalProtocol:
    """Comprehensive experimental protocol for point cloud model training and evaluation.

    Bundles the dataset, the helper objects (training manager, cross validator,
    hyperparameter optimizer, benchmark protocol) and the routines that run
    hyperparameter search, the four training phases, and the final evaluation
    of clean accuracy, adversarial robustness and corruption robustness.
    """

    def __init__(self, models, dataset_path, num_points=1024, num_classes=40,
                 batch_size=32, device=None):
        """
        Initialize the experimental protocol

        Parameters:
        -----------
        models: dict, dictionary of models to evaluate {name: model_init_function}
        dataset_path: str, path to the ModelNet40 dataset
        num_points: int, number of points per point cloud
        num_classes: int, number of classes
        batch_size: int, batch size for training
        device: torch.device, device to use for training (CUDA when available,
            otherwise CPU, if not given)
        """
        self.models = models
        self.dataset_path = dataset_path
        self.num_points = num_points
        self.num_classes = num_classes
        self.batch_size = batch_size
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Initialize dataset
        # NOTE(review): every train/val/test subset created later is a view of
        # this single dataset object, so whatever random_rotation=True does is
        # presumably applied to validation/test samples too - confirm intended.
        self.full_dataset = ModelNet40Dataset(
            root_dir=dataset_path,
            split='train',  # We'll split this manually later
            num_points=num_points,
            random_rotation=True
        )

        # Initialize helper classes
        self.training_manager = TrainingManager(device=self.device)
        self.cross_validator = CrossValidator()
        self.hyperparameter_optimizer = HyperparameterOptimizer()
        self.benchmark_protocol = BenchmarkProtocol(device=self.device)

    def prepare_data_splits(self, test_size=0.15, val_size=0.15, seed=42):
        """
        Prepare data splits for training, validation, and testing

        Parameters:
        -----------
        test_size: float, fraction of data for testing
        val_size: float, fraction of data for validation
        seed: int, random seed for reproducibility

        Returns:
        --------
        train_dataset, val_dataset, test_dataset: Dataset objects

        Raises:
        -------
        ValueError: if test_size and val_size together select zero samples
            (this also covers an empty dataset)
        """
        # Set random seeds for reproducibility (seeds the global torch, numpy
        # and random generators)
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        # Get all labels for stratification
        dataset_size = len(self.full_dataset)
        all_labels = [self.full_dataset.labels[i] for i in range(dataset_size)]

        # Calculate split sizes
        test_count = int(dataset_size * test_size)
        val_count = int(dataset_size * val_size)
        holdout_count = val_count + test_count
        train_count = dataset_size - holdout_count

        # Without this guard the proportions below divide by zero.
        if holdout_count == 0:
            raise ValueError(
                f"test_size={test_size} and val_size={val_size} leave no samples for "
                f"validation/testing (dataset size: {dataset_size})"
            )

        # Create stratified splits; the returned indices are used as positions
        # into the label list that was passed in.
        train_idx, temp_idx = self.cross_validator.stratified_split(
            all_labels, train_count / dataset_size, seed
        )

        # Further split temporary indices into validation and test sets
        val_proportion = val_count / holdout_count
        val_pos, test_pos = self.cross_validator.stratified_split(
            [all_labels[i] for i in temp_idx], val_proportion, seed
        )
        # Map positions within temp_idx back to indices into the full dataset
        val_idx = [temp_idx[i] for i in val_pos]
        test_idx = [temp_idx[i] for i in test_pos]

        # Create dataset subsets
        train_dataset = Subset(self.full_dataset, train_idx)
        val_dataset = Subset(self.full_dataset, val_idx)
        test_dataset = Subset(self.full_dataset, test_idx)

        print(f"Dataset split: Training={len(train_dataset)}, Validation={len(val_dataset)}, Test={len(test_dataset)}")

        return train_dataset, val_dataset, test_dataset

    def run_full_protocol(self, model_name, train_dataset, val_dataset, test_dataset, output_dir="results"):
        """
        Run the complete experimental protocol for a single model

        Steps: hyperparameter optimization, model construction with the best
        hyperparameters, four training phases (a checkpoint is saved after
        each), then evaluation of classification performance, adversarial
        robustness (FGSM/PGD) and corruption robustness on the test set.

        Parameters:
        -----------
        model_name: str, name of the model to evaluate
        train_dataset, val_dataset, test_dataset: Dataset objects
        output_dir: str, directory to save results (created if missing)

        Returns:
        --------
        results: dict, complete evaluation results with keys 'best_hyperparams',
            'classification_performance', 'adversarial_robustness' and
            'corruption_robustness'
        """
        os.makedirs(output_dir, exist_ok=True)
        results = {}

        # 1. Hyperparameter optimization on validation set
        print(f"Starting hyperparameter optimization for {model_name}...")
        best_params = self.hyperparameter_optimizer.optimize_model_hyperparams(
            model_name, self.models[model_name], train_dataset, val_dataset,
            self.num_classes, self.device
        )
        results['best_hyperparams'] = best_params

        # 2. Initialize model with best hyperparameters
        model_init_fn = self.models[model_name]
        model = model_init_fn(num_classes=self.num_classes, **best_params)
        model.to(self.device)

        # 3. Multi-phase training
        print(f"Starting multi-phase training for {model_name}...")
        model = self._run_training_phases(
            model, model_name, train_dataset, val_dataset, output_dir
        )

        # 4. Comprehensive evaluation
        print(f"Evaluating final model for {model_name}...")
        # Evaluate in inference mode regardless of the mode the last training
        # phase left the model in.
        model.eval()
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)

        evaluator = EvaluationFramework({model_name: model}, device=self.device, num_classes=self.num_classes)

        # Basic classification performance
        results['classification_performance'] = evaluator.evaluate_classification_performance(test_loader)

        # Adversarial robustness
        attacker = PointCloudAttacker(model, self.device)
        adv_results = {}
        for attack_type in ['fgsm', 'pgd']:
            for strength in [0.01, 0.03, 0.05]:
                adv_results[f"{attack_type}_{strength}"] = self._evaluate_adversarial_robustness(
                    model, test_loader, attacker, attack_type, strength
                )
        results['adversarial_robustness'] = adv_results

        # Corruption robustness
        results['corruption_robustness'] = self._evaluate_corruption_robustness(model, test_dataset)

        # Save results. np.save pickles the dict into a 0-d object array; read
        # it back with np.load(path, allow_pickle=True).item().
        np.save(f"{output_dir}/{model_name}_results.npy", results)

        return results

    def _run_training_phases(self, model, model_name, train_dataset, val_dataset, output_dir):
        """Run the four training phases in order, checkpointing after each one.

        Returns the model returned by the last phase.
        """
        # Clean pre-training phase
        print("Phase 1: Clean pre-training")
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)

        model = self.training_manager.train_clean(
            model, train_loader, val_loader,
            epochs=50,
            lr=0.001,
            weight_decay=0.01
        )

        # Save checkpoint after clean training
        torch.save(model.state_dict(), f"{output_dir}/{model_name}_clean_trained.pth")

        # Basic adversarial training phase
        print("Phase 2: Adversarial training")
        model = self.training_manager.train_adversarial(
            model, train_loader, val_loader,
            epochs=30,
            lr=0.0005,
            attack_type='pgd',
            epsilon=0.03,
            alpha=0.007,
            steps=10
        )

        # Save checkpoint after adversarial training
        torch.save(model.state_dict(), f"{output_dir}/{model_name}_adv_trained.pth")

        # Environmental corruption training phase (takes datasets, not loaders)
        print("Phase 3: Environmental corruption training")
        model = self.training_manager.train_corruptions(
            model, train_dataset, val_dataset,
            epochs=30,
            lr=0.0003,
            batch_size=16,
            corruptions=['gaussian_noise', 'snow', 'fog'],
            severities=[1, 3, 5]
        )

        # Save checkpoint after corruption training
        torch.save(model.state_dict(), f"{output_dir}/{model_name}_corruption_trained.pth")

        # Combined robustness fine-tuning
        print("Phase 4: Combined robustness fine-tuning")
        model = self.training_manager.train_combined_robustness(
            model, train_dataset, val_dataset,
            epochs=20,
            lr=0.0001,
            batch_size=16,
            adv_weight=0.5,
            corruption_weight=0.5
        )

        # Save final model
        torch.save(model.state_dict(), f"{output_dir}/{model_name}_final.pth")

        return model

    def _evaluate_accuracy(self, model, dataloader):
        """Return top-1 accuracy of `model` over `dataloader` (0.0 if it yields no samples).

        Puts the model in eval mode and runs without gradient tracking.
        """
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for points, labels in dataloader:
                points, labels = points.to(self.device), labels.to(self.device)
                outputs = model(points)
                _, predicted = outputs.max(1)
                correct += predicted.eq(labels).sum().item()
                total += labels.size(0)
        return correct / total if total else 0.0

    def _evaluate_corruption_robustness(self, model, test_dataset):
        """Return {"<corruption>_<severity>": accuracy} for each corruption/severity pair."""
        corruption_results = {}
        for corruption in ['gaussian_noise', 'fog', 'snow']:
            for severity in [1, 3, 5]:
                corrupted_dataset = get_corrupted_dataset(
                    test_dataset, corruption, severity
                )
                corrupted_loader = DataLoader(
                    corrupted_dataset, batch_size=self.batch_size, shuffle=False
                )
                corruption_results[f"{corruption}_{severity}"] = self._evaluate_accuracy(
                    model, corrupted_loader
                )
        return corruption_results

    def _evaluate_adversarial_robustness(self, model, dataloader, attacker, attack_type, strength):
        """
        Helper method to evaluate adversarial robustness

        Parameters:
        -----------
        model: model to evaluate; it is switched to eval mode
        dataloader: iterable of (points, labels) batches
        attacker: object providing fgsm_attack / pgd_attack
        attack_type: str, 'fgsm' or 'pgd'
        strength: float, passed to the attack as epsilon (PGD uses alpha=strength/4)

        Returns:
        --------
        accuracy: float, fraction of adversarial samples classified correctly
            (0.0 if the dataloader yields no samples)

        Raises:
        -------
        ValueError: if attack_type is not 'fgsm' or 'pgd'
        """
        # Fail early with a clear message; previously an unknown attack type
        # surfaced as an unbound-variable error inside the loop.
        if attack_type not in ('fgsm', 'pgd'):
            raise ValueError(
                f"Unsupported attack_type: {attack_type!r} (expected 'fgsm' or 'pgd')"
            )

        model.eval()
        correct = 0
        total = 0

        for points, labels in dataloader:
            points, labels = points.to(self.device), labels.to(self.device)

            # Generate adversarial examples (outside no_grad: the attacks are
            # given the labels and are expected to need gradients)
            if attack_type == 'fgsm':
                adv_points = attacker.fgsm_attack(points, labels, epsilon=strength)
            else:
                adv_points = attacker.pgd_attack(points, labels, epsilon=strength, alpha=strength/4)

            # Evaluate
            with torch.no_grad():
                outputs = model(adv_points)
                _, predicted = outputs.max(1)
                correct += predicted.eq(labels).sum().item()
                total += labels.size(0)

        return correct / total if total else 0.0

    def run_cross_validation(self, model_name, n_folds=5, seed=42, output_dir="results"):
        """
        Run the full k-fold cross-validation procedure

        Parameters:
        -----------
        model_name: str, name of the model to evaluate
        n_folds: int, number of folds for cross-validation
        seed: int, random seed for reproducibility
        output_dir: str, root directory; fold k writes to "<output_dir>/cv_fold_<k>"

        Returns:
        --------
        cv_results: dict, cross-validation results
        """
        # Get all data points and labels
        all_data_idx = list(range(len(self.full_dataset)))
        all_labels = [self.full_dataset.labels[i] for i in all_data_idx]

        # Initialize stratified k-fold
        skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)

        fold_results = []

        for fold, (train_val_idx, test_idx) in enumerate(skf.split(all_data_idx, all_labels)):
            print(f"\n--- Fold {fold+1}/{n_folds} ---")

            # Split train_val into train and validation (80/20); a different
            # seed per fold varies the inner split.
            train_pos, val_pos = self.cross_validator.stratified_split(
                [all_labels[i] for i in train_val_idx], 0.8, seed + fold
            )
            # Map positions within train_val_idx back to full-dataset indices
            train_idx = [train_val_idx[i] for i in train_pos]
            val_idx = [train_val_idx[i] for i in val_pos]

            # Create dataset subsets
            train_dataset = Subset(self.full_dataset, train_idx)
            val_dataset = Subset(self.full_dataset, val_idx)
            test_dataset = Subset(self.full_dataset, test_idx)

            # Run the full protocol
            fold_dir = f"{output_dir}/cv_fold_{fold+1}"
            fold_results.append(self.run_full_protocol(
                model_name, train_dataset, val_dataset, test_dataset, fold_dir
            ))

        # Aggregate cross-validation results
        return self._aggregate_cv_results(fold_results)

    def _aggregate_cv_results(self, fold_results):
        """
        Helper method to aggregate cross-validation results

        For each of top1_accuracy, avg_precision, avg_recall and avg_f1 the
        per-fold values are reduced to 'mean_<metric>' and 'std_<metric>'.
        Only the first model found under 'classification_performance' of each
        fold is used.
        """
        metric_names = ('top1_accuracy', 'avg_precision', 'avg_recall', 'avg_f1')
        per_metric = {metric: [] for metric in metric_names}

        for result in fold_results:
            performance = result['classification_performance']
            # Each fold evaluates a single model; take its (only) entry.
            metrics = performance[next(iter(performance))]
            for metric in metric_names:
                per_metric[metric].append(metrics[metric])

        # Calculate mean and standard deviation
        aggregated = {}
        for metric, values in per_metric.items():
            aggregated[f'mean_{metric}'] = np.mean(values)
            aggregated[f'std_{metric}'] = np.std(values)

        return aggregated

    def benchmark_against_baselines(self, model_name, baselines, train_dataset, val_dataset, test_dataset,
                                    epochs=50, lr=0.001, weight_decay=0.01):
        """
        Benchmark the model against baseline methods

        Every model (the target model plus all baselines) is built with its
        default hyperparameters, trained on clean data with the same settings,
        and evaluated for classification performance on the test set.

        NOTE(review): the original body stopped right after instantiating each
        model and returned nothing. The training/evaluation below reuses the
        helpers already used by run_full_protocol; BenchmarkProtocol's API is
        not visible here, so it is not used - confirm this matches the intent.

        Parameters:
        -----------
        model_name: str, name of the model to evaluate
        baselines: dict, {name: model_init_function} for baseline models
        train_dataset, val_dataset, test_dataset: Dataset objects
        epochs: int, clean-training epochs per model
        lr: float, learning rate for clean training
        weight_decay: float, weight decay for clean training

        Returns:
        --------
        benchmark_results: dict, comparative results - the merged output of
            evaluate_classification_performance for every model
        """
        # Add target model to the baselines (a baseline registered under the
        # same name overrides the target model, as before)
        all_models = {model_name: self.models[model_name], **baselines}
        model_performances = {}

        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)

        # Train and evaluate each model
        for name, model_fn in all_models.items():
            print(f"\nEvaluating model: {name}")

            # Initialize model
            model = model_fn(num_classes=self.num_classes)

            model = self.training_manager.train_clean(
                model, train_loader, val_loader,
                epochs=epochs,
                lr=lr,
                weight_decay=weight_decay
            )
            model.eval()

            evaluator = EvaluationFramework({name: model}, device=self.device, num_classes=self.num_classes)
            # The evaluator's result is presumably keyed by model name (that is
            # how _aggregate_cv_results reads it), so merging keeps one entry
            # per model.
            model_performances.update(evaluator.evaluate_classification_performance(test_loader))

        return model_performances

In [None]:
class TrainingManager:
    """Class for managing model training with various techniques"""
    
    def __init__(self, device=None):
        """
        Initialize the training manager
        
        Parameters:
        -----------
        device: torch.device, device to use for training
        """
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
    def train_clean(self, model, train_loader, val_loader, epochs=100, lr=0.001, weight_decay=0.0):
        """
        Train a model on clean data
        
        Parameters:
        -----------
        model: torch.nn.Module, model to train
        train_loader: DataLoader, training data
        val_loader: DataLoader, validation data
        epochs: int, number of epochs to train
        lr: float, learning rate
        weight_decay: float, weight decay for regularization
        
        Returns:
        --------
        model: torch.nn.Module, trained model
        """
        model.to(self.device)
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        criterion = nn.CrossEntropyLoss()
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5)
        
        best_val_loss = float('inf')
        best_model_state = None
        
        for epoch in range(epochs):
            # Training
            model.train()
            train_loss = 0
            correct = 0
            total = 0
            
            for points, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]", leave=False):
                points, labels = points.to(self.device), labels.to(self.device)
                
                optimizer.zero_grad()
                outputs = model(points)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                
                train_loss += loss.item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
            
            train_acc = correct / total
            train_loss = train_loss / len(train_loader)
            
            # Validation
            model.eval()
            val_loss = 0
            correct = 0
            total = 0
            
            with torch.no_grad():
                for points, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Val]", leave=False):
                    points, labels = points.to(self.device), labels.to(self.device)
                    outputs = model(points)
                    loss = criterion(outputs, labels)
                    
                    val_loss += loss.item()
                    _, predicted = outputs.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()
            
            val_acc = correct / total
            val_loss = val_loss / len(val_loader)
            
            # Learning rate scheduler
            scheduler.step(val_loss)
            
            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_model_state = copy.deepcopy(model.state_dict())
            
            print(f"Epoch {epoch+1}/{epochs}: train_loss={train_loss:.4f}, train_acc={train_acc:.4f}, val_loss={val_loss:.4f}, val_acc={val_acc:.4f}")
        
        # Load best model
        model.load_state_dict(best_model_state)
        return model
        
    def train_adversarial(self, model, train_loader, val_loader, epochs=30, lr=0.0005, 
                          attack_type='pgd', epsilon=0.03, alpha=0.007, steps=10):
        """
        Adversarially train a model on an even blend of clean and attacked batches

        Parameters:
        -----------
        model: torch.nn.Module, model to train
        train_loader: DataLoader, training data
        val_loader: DataLoader, validation data
        epochs: int, number of epochs to train
        lr: float, learning rate
        attack_type: str, 'pgd' selects the PGD attack; any other value uses FGSM
        epsilon: float, attack strength
        alpha: float, step size for PGD attack
        steps: int, number of steps for PGD attack

        Returns:
        --------
        model: torch.nn.Module, trained model (weights as of the last epoch)
        """
        model.to(self.device)
        optimizer = optim.Adam(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        # The attacker shares the model being trained, so every batch is
        # attacked with the current weights.
        attacker = PointCloudAttacker(model, self.device)

        def craft(batch_points, batch_labels):
            """Return the adversarial counterpart of one batch"""
            if attack_type == 'pgd':
                return attacker.pgd_attack(batch_points, batch_labels, epsilon=epsilon, alpha=alpha, num_iter=steps)
            # Anything that is not 'pgd' falls back to FGSM.
            return attacker.fgsm_attack(batch_points, batch_labels, epsilon=epsilon)

        for epoch in range(epochs):
            model.train()
            running_loss = 0
            hits_clean = 0
            hits_adv = 0
            seen = 0

            progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]", leave=False)
            for points, labels in progress:
                points = points.to(self.device)
                labels = labels.to(self.device)

                # Attack first, then clear gradients, so nothing the attack
                # left on the parameters reaches the optimizer step.
                adv_points = craft(points, labels)
                optimizer.zero_grad()

                logits_clean = model(points)
                loss_clean = criterion(logits_clean, labels)
                logits_adv = model(adv_points)
                loss_adv = criterion(logits_adv, labels)

                # Equal weighting of the clean and adversarial objectives.
                loss = 0.5 * loss_clean + 0.5 * loss_adv
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                hits_clean += logits_clean.max(1)[1].eq(labels).sum().item()
                hits_adv += logits_adv.max(1)[1].eq(labels).sum().item()
                seen += labels.size(0)

            train_clean_acc = hits_clean / seen
            train_adv_acc = hits_adv / seen
            train_loss = running_loss / len(train_loader)

            # Validation on clean and attacked versions of the held-out data
            model.eval()
            val_clean_acc, val_adv_acc = self._evaluate_adversarial(
                model, val_loader, attacker, attack_type, epsilon, alpha, steps)

            print(f"Epoch {epoch+1}/{epochs}: train_loss={train_loss:.4f}, train_clean_acc={train_clean_acc:.4f}, train_adv_acc={train_adv_acc:.4f}, "
                  f"val_clean_acc={val_clean_acc:.4f}, val_adv_acc={val_adv_acc:.4f}")

        return model
    
    def _evaluate_adversarial(self, model, dataloader, attacker, attack_type, epsilon, alpha=None, steps=None):
        """Helper method to evaluate adversarial robustness"""
        model.eval()
        clean_correct = 0
        adv_correct = 0
        total = 0
        
        with torch.no_grad():
            for points, labels in dataloader:
                points, labels = points.to(self.device), labels.to(self.device)
                
                # Clean accuracy
                outputs = model(points)
                _, predicted = outputs.max(1)
                clean_correct += predicted.eq(labels).sum().item()
                
                total += labels.size(0)
        
        # Adversarial accuracy
        for points, labels in dataloader:
            points, labels = points.to(self.device), labels.to(self.device)
            
            # Generate adversarial examples
            if attack_type == 'pgd':
                adv_points = attacker.pgd_attack(points, labels, epsilon=epsilon, alpha=alpha, num_iter=steps)
            else:  # Default to FGSM
                adv_points = attacker.fgsm_attack(points, labels, epsilon=epsilon)
            
            with torch.no_grad():
                outputs = model(adv_points)
                _, predicted = outputs.max(1)
                adv_correct += predicted.eq(labels).sum().item()
        
        return clean_correct / total, adv_correct / total
    
    def train_corruptions(self, model, train_dataset, val_dataset, epochs=30, lr=0.0003, 
                          batch_size=16, corruptions=None, severities=None):
        """
        Train a model with environmental corruptions
        
        Parameters:
        -----------
        model: torch.nn.Module, model to train
        train_dataset: Dataset, training dataset
        val_dataset: Dataset, validation dataset
        epochs: int, number of epochs to train
        lr: float, learning rate
        batch_size: int, batch size
        corruptions: list, types of corruptions to use
        severities: list, severity levels to use
        
        Returns:
        --------
        model: torch.nn.Module, trained model
        """
        if corruptions is None:
            corruptions = ['gaussian_noise', 'snow', 'fog']
        
        if severities is None:
            severities = [1, 3, 5]
        
        model.to(self.device)
        optimizer = optim.Adam(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        
        # Create clean data loader
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        
        best_val_loss = float('inf')
        best_model_state = None
        
        for epoch in range(epochs):
            model.train()
            total_loss = 0
            clean_correct = 0
            corrupt_correct = 0
            total = 0
            
            # Random corruption type and severity for this epoch
            corruption = random.choice(corruptions)
            severity = random.choice(severities)
            
            # Create corrupted dataset
            corrupt_train_dataset = get_corrupted_dataset(train_dataset, corruption, severity)
            corrupt_train_loader = DataLoader(corrupt_train_dataset, batch_size=batch_size, shuffle=True)
            
            # Train on both clean and corrupted data
            for (clean_points, labels), (corrupt_points, _) in zip(train_loader, corrupt_train_loader):
                clean_points, corrupt_points, labels = clean_points.to(self.device), corrupt_points.to(self.device), labels.to(self.device)
                
                optimizer.zero_grad()
                
                # Forward pass on clean points
                outputs_clean = model(clean_points)
                loss_clean = criterion(outputs_clean, labels)
                
                # Forward pass on corrupted points
                outputs_corrupt = model(corrupt_points)
                loss_corrupt = criterion(outputs_corrupt, labels)
                
                # Combined loss
                loss = 0.5 * loss_clean + 0.5 * loss_corrupt
                
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
                
                # Clean accuracy
                _, predicted = outputs_clean.max(1)
                clean_correct += predicted.eq(labels).sum().item()
                
                # Corrupted accuracy
                _, predicted = outputs_corrupt.max(1)
                corrupt_correct += predicted.eq(labels).sum().item()
                
                total += labels.size(0)
            
            train_clean_acc = clean_correct / total
            train_corrupt_acc = corrupt_correct / total
            avg_loss = total_loss / len(train_loader)
            
            # Validation
            model.eval()
            val_loss = 0
            correct = 0
            total = 0
            
            # Random corruption for validation
            val_corruption = random.choice(corruptions)
            val_severity = random.choice(severities)
            corrupt_val_dataset = get_corrupted_dataset(val_dataset, val_corruption, val_severity)
            corrupt_val_loader = DataLoader(corrupt_val_dataset, batch_size=batch_size, shuffle=False)
            
            with torch.no_grad():
                for corrupt_points, labels in corrupt_val_loader:
                    corrupt_points, labels = corrupt_points.to(self.device), labels.to(self.device)
                    outputs = model(corrupt_points)
                    loss = criterion(outputs, labels)
                    
                    val_loss += loss.item()
                    _, predicted = outputs.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()
            
            val_acc = correct / total
            val_loss = val_loss / len(corrupt_val_loader)
            
            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_model_state = copy.deepcopy(model.state_dict())
            
            print(f"Epoch {epoch+1}/{epochs}: train_loss={avg_loss:.4f}, train_clean_acc={train_clean_acc:.4f}, "
                  f"train_corrupt_acc={train_corrupt_acc:.4f}, val_loss={val_loss:.4f}, val_acc={val_acc:.4f}")
        
        # Load best model
        model.load_state_dict(best_model_state)
        return model
    
    def train_combined_robustness(self, model, train_dataset, val_dataset, epochs=20, lr=0.0001,
                                 batch_size=16, adv_weight=0.5, corruption_weight=0.5):
        """
        Train a model with combined adversarial and corruption robustness

        Every step mixes three losses: clean, PGD-adversarial (epsilon=0.03,
        alpha=0.007, 10 iterations) and corrupted. The corruption type and
        severity are drawn at random once per epoch.

        Parameters:
        -----------
        model: torch.nn.Module, model to train
        train_dataset: Dataset, training dataset
        val_dataset: Dataset, validation dataset
        epochs: int, number of epochs to train
        lr: float, learning rate
        batch_size: int, batch size
        adv_weight: float, weight for adversarial loss
        corruption_weight: float, weight for corruption loss
                           (the clean loss gets 1 - adv_weight - corruption_weight,
                           which is 0 with the defaults)

        Returns:
        --------
        model: torch.nn.Module, trained model (weights as of the last epoch)
        """
        model.to(self.device)
        optimizer = optim.Adam(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        # Create attacker
        attacker = PointCloudAttacker(model, self.device)

        # Create data loaders
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

        # Define corruptions
        corruptions = ['gaussian_noise', 'snow', 'fog']
        severities = [1, 3, 5]

        # The remaining share of the objective goes to the clean loss.
        clean_weight = 1.0 - adv_weight - corruption_weight

        for epoch in range(epochs):
            model.train()
            total_loss = 0
            clean_correct = 0
            adv_correct = 0
            corrupt_correct = 0
            clean_total = 0
            corrupt_total = 0

            # Random corruption for this epoch
            corruption = random.choice(corruptions)
            severity = random.choice(severities)
            corrupt_train_dataset = get_corrupted_dataset(train_dataset, corruption, severity)
            corrupt_train_loader = DataLoader(corrupt_train_dataset, batch_size=batch_size, shuffle=True)

            # The loaders shuffle independently, so the corrupted batch must be
            # scored with its own labels, not those of the clean batch.
            for i, ((clean_points, clean_labels), (corrupt_points, corrupt_labels)) in enumerate(zip(train_loader, corrupt_train_loader)):
                clean_points = clean_points.to(self.device)
                clean_labels = clean_labels.to(self.device)
                corrupt_points = corrupt_points.to(self.device)
                corrupt_labels = corrupt_labels.to(self.device)

                # Craft adversarial examples BEFORE clearing gradients: any
                # gradients the attack leaves on the parameters are wiped by
                # zero_grad() and cannot leak into the optimizer step.
                adv_points = attacker.pgd_attack(clean_points, clean_labels, epsilon=0.03, alpha=0.007, num_iter=10)

                optimizer.zero_grad()

                # Clean forward pass
                outputs_clean = model(clean_points)
                loss_clean = criterion(outputs_clean, clean_labels)

                # Adversarial examples
                outputs_adv = model(adv_points)
                loss_adv = criterion(outputs_adv, clean_labels)

                # Corrupted examples
                outputs_corrupt = model(corrupt_points)
                loss_corrupt = criterion(outputs_corrupt, corrupt_labels)

                # Combined loss
                loss = clean_weight * loss_clean + adv_weight * loss_adv + corruption_weight * loss_corrupt

                loss.backward()
                optimizer.step()

                total_loss += loss.item()

                # Track accuracy
                _, predicted = outputs_clean.max(1)
                clean_correct += predicted.eq(clean_labels).sum().item()

                _, predicted = outputs_adv.max(1)
                adv_correct += predicted.eq(clean_labels).sum().item()

                _, predicted = outputs_corrupt.max(1)
                corrupt_correct += predicted.eq(corrupt_labels).sum().item()

                clean_total += clean_labels.size(0)
                corrupt_total += corrupt_labels.size(0)

                # Print progress every 50 batches
                if (i+1) % 50 == 0:
                    print(f"Epoch {epoch+1}/{epochs}, Batch {i+1}/{len(train_loader)}, Loss: {loss.item():.4f}")

            avg_loss = total_loss / len(train_loader)
            clean_acc = clean_correct / clean_total
            adv_acc = adv_correct / clean_total
            corrupt_acc = corrupt_correct / corrupt_total

            print(f"Epoch {epoch+1}/{epochs}: loss={avg_loss:.4f}, clean_acc={clean_acc:.4f}, "
                  f"adv_acc={adv_acc:.4f}, corrupt_acc={corrupt_acc:.4f}")

            # Validation uses the same corruption/severity drawn for this epoch
            val_clean_acc, val_adv_acc, val_corrupt_acc = self._evaluate_combined_robustness(
                model, val_dataset, attacker, corruption, severity, batch_size)

            print(f"Validation: clean_acc={val_clean_acc:.4f}, adv_acc={val_adv_acc:.4f}, corrupt_acc={val_corrupt_acc:.4f}")

        return model
    
    def _evaluate_combined_robustness(self, model, dataset, attacker, corruption, severity, batch_size):
        """
        Score a model on clean, PGD-attacked and corrupted versions of a dataset

        Returns:
        --------
        (clean_acc, adv_acc, corrupt_acc): tuple of floats, fraction of
        correctly classified samples under each condition
        """
        model.eval()

        clean_batches = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        corrupted_batches = DataLoader(
            get_corrupted_dataset(dataset, corruption, severity),
            batch_size=batch_size,
            shuffle=False,
        )

        def fraction_correct(batches, perturb=None):
            """Accuracy over `batches`, optionally perturbing each batch first"""
            hits = 0
            seen = 0
            for batch_points, batch_labels in batches:
                batch_points = batch_points.to(self.device)
                batch_labels = batch_labels.to(self.device)
                if perturb is not None:
                    # Perturbation happens outside no_grad (it may need gradients).
                    batch_points = perturb(batch_points, batch_labels)
                with torch.no_grad():
                    _, guesses = model(batch_points).max(1)
                    hits += guesses.eq(batch_labels).sum().item()
                    seen += batch_labels.size(0)
            return hits / seen

        def pgd(batch_points, batch_labels):
            """Fixed-strength PGD attack used for the adversarial score"""
            return attacker.pgd_attack(batch_points, batch_labels, epsilon=0.03, alpha=0.007, num_iter=10)

        clean_acc = fraction_correct(clean_batches)
        adv_acc = fraction_correct(clean_batches, perturb=pgd)
        corrupt_acc = fraction_correct(corrupted_batches)

        return clean_acc, adv_acc, corrupt_acc


class CrossValidator:
    """Class for cross-validation procedures"""
    
    def stratified_split(self, labels, train_ratio, random_seed=None):
        """
        Perform stratified split of data indices

        Each class is shuffled and split separately, so both halves keep
        approximately the class proportions of `labels`.

        Parameters:
        -----------
        labels: list, class labels for each sample
        train_ratio: float, ratio of data for training
        random_seed: int, random seed for reproducibility; when given, a
                     private generator is used so the global NumPy random
                     state is left untouched (the resulting split matches
                     what seeding the global generator would produce)

        Returns:
        --------
        train_indices, test_indices: lists of indices for training and testing
        """
        # With a seed, draw from a private legacy generator instead of
        # reseeding the process-wide one; without a seed, keep using the
        # global generator so an earlier np.random.seed() is still honoured.
        if random_seed is not None:
            rng = np.random.RandomState(random_seed)
        else:
            rng = np.random

        label_array = np.asarray(labels)
        train_indices = []
        test_indices = []

        # Split each class proportionally
        for label in np.unique(label_array):
            # Vectorised lookup of this class's positions, as plain Python ints.
            indices = np.flatnonzero(label_array == label).tolist()
            rng.shuffle(indices)

            # int() truncates, so a class too small for the ratio puts all
            # of its samples on the test side.
            split = int(len(indices) * train_ratio)

            train_indices.extend(indices[:split])
            test_indices.extend(indices[split:])

        return train_indices, test_indices


class HyperparameterOptimizer:
    """Class for hyperparameter optimization"""

    # Hyperparameters consumed by the training loop, not the model constructor.
    _TRAINING_ONLY_KEYS = ("lr", "weight_decay", "batch_size")

    # Intended search space per model. Not searched yet -- kept as a record
    # of the grid a real optimisation would explore.
    _SEARCH_SPACES = {
        "pointnet++": {
            "lr": [0.001, 0.0005],
            "weight_decay": [0.0, 0.001],
            "batch_size": [16, 32],
        },
        "dgcnn": {
            "lr": [0.001, 0.0005],
            "weight_decay": [0.0, 0.001],
            "batch_size": [16, 32],
            "k": [20, 40],  # k-nearest neighbors
        },
        "pointmlp": {
            "lr": [0.001, 0.0005],
            "weight_decay": [0.0, 0.001],
            "batch_size": [16, 32],
            "embed_dim": [64, 128],
        },
    }
    # Search space for any other model name (custom attention model).
    _DEFAULT_SEARCH_SPACE = {
        "lr": [0.001, 0.0005],
        "weight_decay": [0.0, 0.001],
        "batch_size": [16, 32],
        "embed_dim": [64, 128],
        "heads": [4, 8],
    }

    # Predefined "best" configuration returned in place of a real search.
    _PREDEFINED_BEST = {
        "pointnet++": {"lr": 0.001, "weight_decay": 0.001, "batch_size": 32},
        "dgcnn": {"lr": 0.001, "weight_decay": 0.001, "batch_size": 32, "k": 20},
        "pointmlp": {"lr": 0.001, "weight_decay": 0.001, "batch_size": 32, "embed_dim": 64},
    }
    # Configuration for any other model name (custom attention model).
    _DEFAULT_BEST = {"lr": 0.001, "weight_decay": 0.001, "batch_size": 32, "embed_dim": 128, "heads": 4}

    def optimize_model_hyperparams(self, model_name, model_init_fn, train_dataset, val_dataset, 
                                  num_classes=40, device=None):
        """
        Return the model-constructor hyperparameters for a model

        No search is performed: a predefined configuration is looked up by
        model name and its training-only entries (lr, weight_decay,
        batch_size) are dropped.

        Parameters:
        -----------
        model_name: str, name of the model ("pointnet++", "dgcnn", "pointmlp";
                    anything else selects the custom attention model settings)
        model_init_fn: function, model initialization function (currently unused)
        train_dataset: Dataset, training dataset (currently unused)
        val_dataset: Dataset, validation dataset (currently unused)
        num_classes: int, number of classes (currently unused)
        device: torch.device, device to use (currently unused)

        Returns:
        --------
        model_params: dict, keyword arguments for the model constructor only
                      (a fresh dict on every call; empty for "pointnet++")
        """
        print(f"Optimizing hyperparameters for {model_name}...")

        # For simplicity, we'll just use a predefined best configuration
        # In a real implementation, you'd perform grid search or random search
        print("Note: Using predefined hyperparameters for demonstration.")
        print("In a real implementation, this would perform actual hyperparameter optimization.")

        best_params = self._PREDEFINED_BEST.get(model_name, self._DEFAULT_BEST)

        # Keep only model initialization parameters; the comprehension builds
        # a new dict, so callers can mutate the result freely.
        return {
            key: value
            for key, value in best_params.items()
            if key not in self._TRAINING_ONLY_KEYS
        }


class BenchmarkProtocol:
    """Class for benchmarking procedures"""
    
    def __init__(self, device=None):
        """
        Initialize the benchmark protocol

        Parameters:
        -----------
        device: torch.device, device to run on; a falsy value selects CUDA
                when available and the CPU otherwise
        """
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    def benchmark_models(self, models, test_loader):
        """
        Benchmark multiple models on classification performance

        Parameters:
        -----------
        models: dict, {name: model} dictionary
        test_loader: DataLoader, test data yielding (points, labels) batches

        Returns:
        --------
        results: dict, {name: {"accuracy": float,
                               "avg_inference_time_ms": float}} where the time
                 is the mean forward-pass latency per batch; both values are
                 0.0 when the loader yields no batches
        """
        results = {}

        # CUDA kernels are launched asynchronously, so the device has to be
        # synchronised around the timed region or only the launch overhead
        # would be measured.
        target = torch.device(self.device)
        on_cuda = target.type == "cuda"

        for name, model in models.items():
            model.eval()
            model.to(self.device)

            correct = 0
            total = 0
            inference_times = []

            with torch.no_grad():
                for points, labels in tqdm(test_loader, desc=f"Benchmarking {name}"):
                    points, labels = points.to(self.device), labels.to(self.device)

                    # Measure inference time with a monotonic high-resolution clock
                    if on_cuda:
                        torch.cuda.synchronize(target)
                    start_time = time.perf_counter()
                    outputs = model(points)
                    if on_cuda:
                        torch.cuda.synchronize(target)
                    inference_times.append(time.perf_counter() - start_time)

                    # Calculate accuracy
                    _, predicted = outputs.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()

            # An empty loader yields zeros rather than a ZeroDivisionError.
            accuracy = correct / total if total else 0.0
            avg_inference_time = sum(inference_times) / len(inference_times) if inference_times else 0.0

            results[name] = {
                "accuracy": accuracy,
                "avg_inference_time_ms": avg_inference_time * 1000
            }

            print(f"{name}: Accuracy = {accuracy:.4f}, Avg. Inference Time = {avg_inference_time*1000:.2f} ms")

        return results


# Fix the ExperimentalProtocol class that was duplicated and incomplete
def _evaluate_adversarial_robustness(model, dataloader, attacker, attack_type, strength):
    """Helper method to evaluate adversarial robustness"""
    correct = 0
    total = 0
    
    for points, labels in dataloader:
        points, labels = points.to(attacker.device), labels.to(attacker.device)
        
        # Generate adversarial examples
        if attack_type == 'fgsm':
            adv_points = attacker.fgsm_attack(points, labels, epsilon=strength)
        elif attack_type == 'pgd':
            adv_points = attacker.pgd_attack(points, labels, epsilon=strength, alpha=strength/4)
            
        # Evaluate
        with torch.no_grad():
            outputs = model(adv_points)
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += labels.size(0)
    
    return correct / total