In [None]:
class SetAbstractionLayer(nn.Module):
    """Simplified PointNet++-style set-abstraction layer.

    Instead of farthest point sampling and ball-query grouping, this layer
    keeps the first ``npoint`` points and runs a shared point-wise MLP
    (1x1 ``Conv2d`` + ``BatchNorm2d`` + ReLU) over their features. When
    ``npoint`` is ``None`` every point is processed and the result is
    max-pooled into one global feature vector.

    ``radius`` and ``nsample`` are stored but not used by this simplified
    forward pass.

    Parameters:
    -----------
    npoint: int or None, number of points to keep (None = pool all points)
    radius: float or None, stored only
    nsample: int or None, stored only
    in_channel: int, width of the MLP input; 3 when no features are passed,
        otherwise 3 + D because coordinates are concatenated with features
    mlp: list of int, output width of each MLP layer
    """

    def __init__(self, npoint, radius, nsample, in_channel, mlp):
        super(SetAbstractionLayer, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample

        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()

        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, N, 3]
            points: input points data, [B, N, D], or None
        Return:
            new_xyz: sampled points position data, [B, S, 3]
            new_points: sample points feature data, [B, S, D']

        S is min(npoint, N); when npoint is None, S is 1 (new_xyz is all
        zeros and new_points holds the max-pooled global feature).
        """
        B, N, C = xyz.shape
        group_all = self.npoint is None
        S = N if group_all else min(self.npoint, N)

        # Simplified sampling: keep the first S points instead of running FPS.
        sampled_xyz = xyz[:, :S, :]

        # Coordinates are concatenated with the incoming features so the MLP
        # input width is 3 + D, which is what the `in_channel=128 + 3` style
        # configuration of the layers expects.
        if points is not None:
            grouped = torch.cat([sampled_xyz, points[:, :S, :]], dim=-1)
        else:
            grouped = sampled_xyz

        # [B, S, Cin] -> [B, Cin, S, 1] so the 1x1 Conv2d acts as a shared MLP.
        features = grouped.permute(0, 2, 1).unsqueeze(-1)
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            features = torch.relu(bn(conv(features)))
        features = features.squeeze(-1)  # [B, D', S]

        if group_all:
            # Global abstraction: collapse all points into one feature vector.
            new_points = features.max(dim=2, keepdim=True)[0].permute(0, 2, 1)
            new_xyz = xyz.new_zeros(B, 1, C)
        else:
            new_points = features.permute(0, 2, 1)
            new_xyz = sampled_xyz

        return new_xyz, new_points

class PointNet2Classification(nn.Module):
    """Classification network built from three set-abstraction layers and an
    MLP head.

    Parameters:
    -----------
    num_classes: int, number of output logits
    normal_channel: bool, when True the input carries 6 channels and
        channels 3 onward are fed to the first layer as per-point features
    """

    def __init__(self, num_classes=40, normal_channel=False):
        super(PointNet2Classification, self).__init__()
        in_channel = 6 if normal_channel else 3

        self.normal_channel = normal_channel

        # SA modules
        self.sa1 = SetAbstractionLayer(npoint=512, radius=0.2, nsample=32,
                                        in_channel=in_channel, mlp=[64, 64, 128])
        self.sa2 = SetAbstractionLayer(npoint=128, radius=0.4, nsample=64,
                                        in_channel=128 + 3, mlp=[128, 128, 256])
        self.sa3 = SetAbstractionLayer(npoint=None, radius=None, nsample=None,
                                        in_channel=256 + 3, mlp=[256, 512, 1024])

        # FC layers
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.4)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, xyz):
        """
        Input:
            xyz: point data in channel-first layout, [B, C, N]; C is 3, or 6
                 when normal_channel is True
        Return:
            class logits, [B, num_classes]
        """
        if self.normal_channel:
            # The extra channels must use the same [B, N, D] layout as the
            # coordinates, so they are transposed as well.
            norm = xyz[:, 3:, :].transpose(2, 1)
            xyz = xyz[:, :3, :]
        else:
            norm = None

        # Transpose to the [B, N, 3] layout the set-abstraction layers expect.
        xyz = xyz.transpose(2, 1)

        # Set Abstraction layers
        l1_xyz, l1_points = self.sa1(xyz, norm)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)

        # Collapse the point axis with a max-pool instead of assuming that
        # exactly one point is left; for a [B, 1, 1024] tensor this is a
        # plain squeeze.
        x = torch.max(l3_points, dim=1)[0]

        # FC layers
        x = self.drop1(torch.relu(self.bn1(self.fc1(x))))
        x = self.drop2(torch.relu(self.bn2(self.fc2(x))))
        x = self.fc3(x)

        return x

def _init_pointnet_plus_plus(num_classes=40):
    """Build a PointNet2Classification model with `num_classes` outputs."""
    return PointNet2Classification(num_classes=num_classes)

In [None]:
class EdgeConv(nn.Module):
    """Edge convolution over a k-nearest-neighbour graph.

    Every point is paired with its `k` nearest neighbours (measured in the
    current feature space); each edge is described by
    ``(neighbour - centre, centre)``, passed through a shared 1x1
    convolution, and max-pooled over the neighbours.
    """

    def __init__(self, k, in_channels, out_channels):
        super(EdgeConv, self).__init__()
        self.k = k
        # Edge features stack two `in_channels`-wide vectors, hence the * 2.
        edge_layers = [
            nn.Conv2d(in_channels * 2, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(negative_slope=0.2),
        ]
        self.conv = nn.Sequential(*edge_layers)

    def knn(self, x, k):
        """Return the indices of the `k` nearest points for each point.

        x is [B, C, N]; the result is [B, N, k].
        """
        sq_norm = torch.sum(x**2, dim=1, keepdim=True)
        cross = -2 * torch.matmul(x.transpose(2, 1), x)
        # -|a|^2 + 2 a.b - |b|^2 == -|a - b|^2, so the largest entries belong
        # to the closest points.
        neg_sq_dist = -sq_norm - cross - sq_norm.transpose(2, 1)

        _, neighbour_idx = neg_sq_dist.topk(k=k, dim=-1)
        return neighbour_idx

    def get_graph_feature(self, x, k=20):
        """Build edge features of shape [B, 2C, N, k] from x of shape [B, C, N]."""
        batch_size, num_dims, num_points = x.size()
        neighbour_idx = self.knn(x, k)

        # Shift per-batch indices so they address the flattened (B*N, C) table.
        offsets = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
        flat_idx = (neighbour_idx + offsets).view(-1)

        per_point = x.transpose(2, 1).contiguous()
        neighbours = per_point.view(batch_size * num_points, -1)[flat_idx, :]
        neighbours = neighbours.view(batch_size, num_points, k, num_dims)
        centres = per_point.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

        edges = torch.cat((neighbours - centres, centres), dim=3)
        return edges.permute(0, 3, 1, 2).contiguous()

    def forward(self, x):
        edges = self.get_graph_feature(x, self.k)
        activated = self.conv(edges)
        # Max over the neighbour axis.
        pooled = activated.max(dim=-1, keepdim=False)[0]
        return pooled

class DGCNN(nn.Module):
    """Classifier that stacks four EdgeConv stages, concatenates their
    outputs, and classifies the max-pooled global feature."""

    def __init__(self, num_classes=40, k=20):
        super(DGCNN, self).__init__()
        self.k = k

        # Edge convolution stages (64 + 64 + 128 + 256 = 512 channels in total)
        self.edge_conv1 = EdgeConv(k=k, in_channels=3, out_channels=64)
        self.edge_conv2 = EdgeConv(k=k, in_channels=64, out_channels=64)
        self.edge_conv3 = EdgeConv(k=k, in_channels=64, out_channels=128)
        self.edge_conv4 = EdgeConv(k=k, in_channels=128, out_channels=256)

        # Point-wise projection of the concatenated stage outputs
        self.mlp = nn.Sequential(
            nn.Conv1d(512, 1024, kernel_size=1, bias=False),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(negative_slope=0.2),
        )

        # Classification head
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.5)

        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.5)

        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, x):
        # Anything that is not already a 3-D tensor with 3 channels on axis 1
        # is treated as BxNxC and flipped to BxCxN.
        already_channel_first = x.shape[1] == 3 and len(x.shape) == 3
        if not already_channel_first:
            x = x.transpose(2, 1)

        # Run the edge-conv stages in sequence, keeping every stage's output.
        stage_outputs = []
        for stage in (self.edge_conv1, self.edge_conv2, self.edge_conv3, self.edge_conv4):
            x = stage(x)
            stage_outputs.append(x)

        # Concatenate along channels and project point-wise.
        x = self.mlp(torch.cat(stage_outputs, dim=1))

        # Global max pooling over the point axis
        x = x.max(dim=2)[0]

        # Fully connected head
        x = self.drop1(torch.relu(self.bn1(self.fc1(x))))
        x = self.drop2(torch.relu(self.bn2(self.fc2(x))))
        return self.fc3(x)

def _init_dgcnn(num_classes=40):
    """Build a DGCNN classifier with `num_classes` outputs."""
    return DGCNN(num_classes=num_classes)

In [None]:
class PointMLP(nn.Module):
    """Point-wise MLP classifier: shared 1x1 convolutions per point, a global
    max-pool over points, then a fully connected head."""

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Two Conv1d(1x1) + BatchNorm1d + ReLU units, widening to `out_channels`."""
        return nn.Sequential(
            nn.Conv1d(in_channels, out_channels, 1),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv1d(out_channels, out_channels, 1),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True),
        )

    def __init__(self, num_classes=40, points=1024, embed_dim=64):
        super(PointMLP, self).__init__()
        self.num_classes = num_classes
        self.points = points

        # Lift raw xyz coordinates to `embed_dim` channels.
        self.embedding = self._conv_stage(3, embed_dim)

        # Progressively wider point-wise feature stages.
        self.layer1 = self._conv_stage(embed_dim, 128)
        self.layer2 = self._conv_stage(128, 256)
        self.layer3 = self._conv_stage(256, 512)

        # Classification head
        head = [
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        # Accept (B, N, 3) as well as (B, 3, N); convolutions need channels first.
        if x.shape[1] != 3 and x.shape[2] == 3:
            x = x.transpose(2, 1)

        # Feature extraction
        for stage in (self.embedding, self.layer1, self.layer2, self.layer3):
            x = stage(x)

        # Global max pooling over the point axis
        global_feature = x.max(dim=2)[0]

        # Classification
        return self.classifier(global_feature)

def _init_pointmlp(num_classes=40):
    """Build a PointMLP classifier with `num_classes` outputs."""
    return PointMLP(num_classes=num_classes)

In [None]:
class PointAttentionNet(nn.Module):
    """Classifier that embeds points, applies two self-attention blocks,
    widens the features point-wise, and classifies the max-pooled result."""

    def __init__(self, num_classes=40, num_points=1024, embed_dim=128):
        super(PointAttentionNet, self).__init__()
        self.num_classes = num_classes

        # Per-point embedding of the xyz coordinates
        embed_layers = [
            nn.Conv1d(3, embed_dim, 1),
            nn.BatchNorm1d(embed_dim),
            nn.ReLU(inplace=True),
        ]
        self.embedding = nn.Sequential(*embed_layers)

        # Two stacked self-attention blocks with 4 heads each
        self.self_attn1 = PointSelfAttention(embed_dim, 4)
        self.self_attn2 = PointSelfAttention(embed_dim, 4)

        # Point-wise widening after attention
        post_attention = [
            nn.Conv1d(embed_dim, 256, 1),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Conv1d(256, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
        ]
        self.feature_conv = nn.Sequential(*post_attention)

        # Classification head
        head = [
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        # Accept (B, N, 3) as well as (B, 3, N); convolutions need channels first.
        if x.shape[1] != 3 and x.shape[2] == 3:
            x = x.transpose(2, 1)

        tokens = self.embedding(x)  # B x embed_dim x N

        # Attention blocks applied back to back
        tokens = self.self_attn2(self.self_attn1(tokens))

        # Point-wise widening followed by a global max-pool over points
        widened = self.feature_conv(tokens)
        global_feature = widened.max(dim=2)[0]

        return self.classifier(global_feature)


class PointSelfAttention(nn.Module):
    """Multi-head self-attention block for channel-first point features.

    Scaled dot-product attention across the points, followed by a point-wise
    feed-forward network; each half has a residual connection and a
    BatchNorm1d.
    """

    def __init__(self, channels, heads=4):
        super(PointSelfAttention, self).__init__()
        self.heads = heads
        self.channels = channels
        self.head_dim = channels // heads
        assert self.head_dim * heads == channels, "channels must be divisible by heads"

        # A single 1x1 convolution produces queries, keys and values at once.
        self.qkv_conv = nn.Conv1d(channels, channels * 3, 1, bias=False)
        self.out_conv = nn.Conv1d(channels, channels, 1)

        self.norm1 = nn.BatchNorm1d(channels)
        self.norm2 = nn.BatchNorm1d(channels)

        feed_forward = [
            nn.Conv1d(channels, channels * 2, 1),
            nn.BatchNorm1d(channels * 2),
            nn.ReLU(inplace=True),
            nn.Conv1d(channels * 2, channels, 1),
        ]
        self.ff = nn.Sequential(*feed_forward)

    def forward(self, x):
        # x: B x C x N
        batch_size, C, N = x.shape
        skip = x

        # Split the fused projection into per-head q, k, v: each B x H x D x N
        packed = self.qkv_conv(x).reshape(batch_size, 3, self.heads, self.head_dim, N)
        q, k, v = packed.unbind(dim=1)

        # Scaled dot-product scores between every pair of points: B x H x N x N
        scores = torch.matmul(q.transpose(-2, -1), k) / (self.head_dim ** 0.5)
        weights = torch.softmax(scores, dim=-1)

        # Weighted sum of values, then merge heads back to B x C x N
        mixed = torch.matmul(weights, v.transpose(-2, -1))  # B x H x N x D
        mixed = mixed.transpose(-2, -1).reshape(batch_size, C, N)

        # Output projection with residual connection
        x = self.norm1(self.out_conv(mixed) + skip)

        # Feed-forward with residual connection
        return self.norm2(self.ff(x) + x)


def _init_custom_attention(num_classes=40):
    """Build the attention-based PointAttentionNet with `num_classes` outputs."""
    model = PointAttentionNet(num_classes=num_classes)
    return model

In [None]:
import torch
import numpy as np
import os
from torch.utils.data import DataLoader

import torch.nn as nn
import torch.optim as optim

# Import model architectures
# Note: These would typically be imported from external files
# For demonstration, we'll include placeholders for imports

# from models.pointnet2 import PointNet2Classification
# from models.dgcnn import DGCNN
# from models.pointmlp import PointMLP
# from models.custom_attention import PointAttentionNet

class ModelEvaluator:
    """Builds one of the supported point-cloud classifiers and exposes
    robustness-evaluation entry points.

    Parameters:
    -----------
    model_name: str, one of "pointnet++", "dgcnn", "pointmlp", "custom"
    pretrained_path: str, path to a checkpoint for the selected model
    num_classes: int, number of output classes

    Raises:
    -------
    ValueError: if `model_name` is not one of the supported names
    """

    def __init__(self, model_name, pretrained_path, num_classes=40):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name
        self.num_classes = num_classes

        # Initialize selected model
        if model_name == "pointnet++":
            self.model = self._init_pointnet_plus_plus()
        elif model_name == "dgcnn":
            self.model = self._init_dgcnn()
        elif model_name == "pointmlp":
            self.model = self._init_pointmlp()
        elif model_name == "custom":
            self.model = self._init_custom_attention()
        else:
            raise ValueError(f"Unsupported model: {model_name}")

        # Load pre-trained weights, then move to the device and switch to
        # inference mode.
        self.load_pretrained(pretrained_path)
        self.model.to(self.device)
        self.model.eval()

    def _init_pointnet_plus_plus(self):
        """Return a PointNet2Classification model for `self.num_classes`."""
        return PointNet2Classification(num_classes=self.num_classes)

    def _init_dgcnn(self):
        """Return a DGCNN model for `self.num_classes`."""
        return DGCNN(num_classes=self.num_classes)

    def _init_pointmlp(self):
        """Return a PointMLP model for `self.num_classes`."""
        return PointMLP(num_classes=self.num_classes)

    def _init_custom_attention(self):
        """Return a PointAttentionNet model for `self.num_classes`."""
        # This used to return None, which made __init__ fail on
        # `self.model.to(...)` whenever model_name == "custom".
        print("Initializing custom attention model")
        return PointAttentionNet(num_classes=self.num_classes)

    def load_pretrained(self, path):
        """Report whether a checkpoint exists at `path`.

        NOTE(review): the actual `load_state_dict` call is still disabled, so
        the model keeps its freshly initialised weights even when the
        "Loading" message is printed.
        """
        if self.model is not None and os.path.exists(path):
            print(f"Loading pre-trained weights for {self.model_name} from {path}")
            # self.model.load_state_dict(torch.load(path))
        else:
            print(f"No pre-trained weights found at {path}")

    def evaluate_adversarial_robustness(self, test_loader, attack_type, epsilon):
        """
        Evaluate model robustness against adversarial attacks

        Placeholder: only prints a message and returns 0.0.

        Parameters:
        -----------
        test_loader: DataLoader containing test data
        attack_type: str, type of attack (e.g., 'fgsm', 'pgd')
        epsilon: float, attack strength

        Returns:
        --------
        accuracy: float, model accuracy under adversarial attack
        """
        print(f"Evaluating {self.model_name} against {attack_type} with ε={epsilon}")
        return 0.0

    def evaluate_corruption_robustness(self, test_loader, corruption_type, severity):
        """
        Evaluate model robustness against environmental corruptions

        Placeholder: only prints a message and returns 0.0.

        Parameters:
        -----------
        test_loader: DataLoader containing test data
        corruption_type: str, type of corruption (e.g., 'gaussian', 'impulse')
        severity: int, corruption severity level (usually 1-5)

        Returns:
        --------
        accuracy: float, model accuracy under corruption
        """
        print(f"Evaluating {self.model_name} against {corruption_type} corruption at severity {severity}")
        return 0.0

# Example usage:
# evaluator = ModelEvaluator("pointnet++", "pretrained/pointnet_modelnet40.pth")
# adv_acc = evaluator.evaluate_adversarial_robustness(test_loader, "pgd", 0.1)
# corr_acc = evaluator.evaluate_corruption_robustness(test_loader, "gaussian_noise", 3)

In [None]:
import torch
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import glob
from pathlib import Path
import h5py
import random
from scipy.spatial.transform import Rotation
import math

class ModelNet40Dataset(Dataset):
    """Dataset over a `<root>/<class>/<split>/*.off` directory tree.

    NOTE: `__getitem__` does not parse the OFF files yet; it generates a
    random placeholder point cloud for every sample.
    """

    def __init__(self, root_dir, split='train', num_points=1024, transform=None, random_rotation=True, class_balance=True):
        """
        ModelNet40 Dataset for 3D point cloud classification

        Parameters:
        -----------
        root_dir: str, path to the ModelNet40 dataset
        split: str, 'train' or 'test'
        num_points: int, number of points to sample (1024, 2048, or 4096)
        transform: callable, optional transform to be applied on a sample
        random_rotation: bool, whether to apply random rotation augmentation
        class_balance: bool, whether to use class balancing
        """
        self.root_dir = root_dir
        self.split = split
        self.num_points = num_points
        self.transform = transform
        self.random_rotation = random_rotation

        # Every sub-directory of root_dir is a class; sorting keeps the
        # label indices stable between runs.
        self.classes = sorted([d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))])
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

        # Collect the model files of the requested split, class by class.
        self.files = []
        self.labels = []
        for cls_name in self.classes:
            cls_path = os.path.join(root_dir, cls_name, split)
            if os.path.exists(cls_path):
                model_files = glob.glob(os.path.join(cls_path, '*.off'))
                self.files.extend(model_files)
                self.labels.extend([self.class_to_idx[cls_name]] * len(model_files))

        # Prepare weights for class balancing
        self.weights = None
        self.sample_weights = None
        if class_balance:
            label_array = np.asarray(self.labels, dtype=np.int64)
            class_counts = np.bincount(label_array, minlength=len(self.classes)).astype(np.float64)
            # Inverse-frequency weights; classes without any sample in this
            # split get weight 0 instead of a division by zero.
            self.weights = np.divide(
                1.0, class_counts, out=np.zeros_like(class_counts), where=class_counts > 0
            )
            # One weight per sample, looked up through its label.
            self.sample_weights = torch.from_numpy(self.weights[label_array]).float()

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return `(points, label)` for sample `idx`.

        `points` is a float32 array of shape (num_points, 3) unless the
        transform changes it; `label` is the integer class index.
        """
        file_path = self.files[idx]  # not read yet, see placeholder below
        label = self.labels[idx]

        # Load point cloud (placeholder - actual loading depends on file format)
        # For OFF files, you'd parse the OFF format
        # In practice, replace this with actual file loading
        points = np.random.rand(1024, 3)  # Placeholder

        # Sample num_points points
        if points.shape[0] > self.num_points:
            indices = np.random.choice(points.shape[0], self.num_points, replace=False)
            points = points[indices]
        elif points.shape[0] < self.num_points:
            # Upsample by duplicating randomly chosen points
            indices = np.random.choice(points.shape[0], self.num_points - points.shape[0], replace=True)
            points = np.vstack([points, points[indices]])

        # Normalize to unit sphere
        center = np.mean(points, axis=0)
        points = points - center
        dist = np.max(np.sqrt(np.sum(points ** 2, axis=1)))
        if dist > 0:
            # A cloud whose points all coincide has radius 0; leave it at
            # the origin instead of dividing by zero.
            points = points / dist

        # Apply random rotation
        if self.split == 'train' and self.random_rotation:
            # Random rotation around z-axis
            theta = np.random.uniform(0, 2 * np.pi)
            rotation_matrix = np.array([
                [np.cos(theta), -np.sin(theta), 0],
                [np.sin(theta), np.cos(theta), 0],
                [0, 0, 1]
            ])
            points = points @ rotation_matrix

        if self.transform:
            points = self.transform(points)

        # The default collate turns float64 arrays into double tensors, which
        # float32 model weights reject; hand out float32 instead.
        if isinstance(points, np.ndarray):
            points = points.astype(np.float32, copy=False)

        return points, label

class PointCloudCorruptor:
    """Class for adding various environmental corruptions to point clouds

    Every method takes an (N, 3) array and returns a new (N, 3) array, so
    corrupted samples can still be stacked into fixed-size batches.
    """

    @staticmethod
    def _fraction_to_count(num_points, fraction):
        """Convert a fraction of the cloud to a point count clamped to [0, num_points]."""
        return max(0, min(num_points, int(num_points * fraction)))

    @staticmethod
    def _refill(points, keep_mask):
        """Drop the masked-out points and pad back to the input size.

        Padding re-draws (with replacement) from the points that were kept,
        so no geometry from the occluded part leaks back in. If nothing
        would be kept, a copy of the input is returned unchanged.
        """
        kept = points[keep_mask]
        missing = points.shape[0] - kept.shape[0]
        if kept.shape[0] == 0:
            return points.copy()
        if missing == 0:
            return kept
        filler = kept[np.random.choice(kept.shape[0], missing, replace=True)]
        return np.vstack([kept, filler])

    @staticmethod
    def add_snow(points, density=0.1, scatter_strength=0.05):
        """Add snow effect to point cloud

        All points are jittered with Gaussian noise and a `density` fraction
        of them is replaced by points drawn uniformly from [-1, 1]^3.

        Parameters:
        -----------
        points: np.ndarray, shape (N, 3)
        density: float, snow density (0.0-1.0); values outside are clamped
        scatter_strength: float, scattering effect strength

        Returns:
        --------
        corrupted_points: np.ndarray, shape (N, 3)
        """
        num_points = points.shape[0]
        num_snow_points = PointCloudCorruptor._fraction_to_count(num_points, density)

        # Generate snow points
        snow_points = np.random.uniform(-1, 1, size=(num_snow_points, 3))

        # Create scattering effect
        scatter = np.random.normal(0, scatter_strength, size=(num_points, 3))
        points_scattered = points + scatter

        # Randomly select points to replace with snow
        indices = np.random.choice(num_points, num_snow_points, replace=False)
        points_scattered[indices] = snow_points

        return points_scattered

    @staticmethod
    def add_rain(points, density=0.1, drop_length=0.05):
        """Add rain effect to point cloud

        A `density` fraction of the points is replaced by points offset by
        `drop_length` along a mostly downward (negative z) direction from
        random positions in [-1, 1]^3.

        Parameters:
        -----------
        points: np.ndarray, shape (N, 3)
        density: float, rain density (0.0-1.0); values outside are clamped
        drop_length: float, length of rain drops

        Returns:
        --------
        corrupted_points: np.ndarray, shape (N, 3)
        """
        num_points = points.shape[0]
        num_rain_points = PointCloudCorruptor._fraction_to_count(num_points, density)

        # Generate rain points (starting positions)
        rain_points = np.random.uniform(-1, 1, size=(num_rain_points, 3))
        # Make rain streaks by extending in mostly downward direction
        rain_directions = np.random.normal(0, 0.1, size=(num_rain_points, 2))
        rain_directions = np.column_stack([rain_directions, -np.abs(np.random.normal(0, 0.5, num_rain_points))])
        norms = np.linalg.norm(rain_directions, axis=1, keepdims=True)
        norms[norms == 0] = 1.0  # a zero-length direction stays zero instead of becoming NaN
        rain_directions /= norms

        # Create rain droplet streaks
        rain_streaks = rain_points + rain_directions * drop_length

        # Randomly select points to replace with rain
        indices = np.random.choice(num_points, num_rain_points, replace=False)
        points_with_rain = points.copy()
        points_with_rain[indices] = rain_streaks

        return points_with_rain

    @staticmethod
    def add_fog(points, density=0.2):
        """Add fog effect to point cloud

        Points are displaced by noise that grows with their distance from the
        origin, and distant points are randomly replaced by points drawn
        uniformly from [-0.1, 0.1]^3.

        Parameters:
        -----------
        points: np.ndarray, shape (N, 3)
        density: float, fog density (0.0-1.0)

        Returns:
        --------
        corrupted_points: np.ndarray, shape (N, 3)
        """
        num_points = points.shape[0]

        # Calculate distance from origin for each point
        distances = np.linalg.norm(points, axis=1)

        # Attenuate points based on distance and fog density
        attenuation = np.exp(-density * distances)

        # Random displacement proportional to fog intensity (1 - attenuation)
        fog_displacement = (1 - attenuation).reshape(-1, 1) * np.random.normal(0, 0.1, size=(num_points, 3))
        points_with_fog = points + fog_displacement

        # Randomly drop some distant points (completely obscured by fog);
        # True in the mask means the point survives.
        dropout_prob = 1 - np.exp(-density * distances * 2)
        dropout_mask = np.random.random(num_points) > dropout_prob

        # Surviving points keep their fogged position, dropped ones are
        # replaced by noise near the origin.
        keep = dropout_mask.reshape(-1, 1)
        return points_with_fog * keep + (1 - keep) * np.random.uniform(-0.1, 0.1, size=(num_points, 3))

    @staticmethod
    def add_gaussian_noise(points, sigma=0.02):
        """Add Gaussian noise to point cloud

        Parameters:
        -----------
        points: np.ndarray, shape (N, 3)
        sigma: float, standard deviation relative to object size

        Returns:
        --------
        corrupted_points: np.ndarray, shape (N, 3)
        """
        noise = np.random.normal(0, sigma, size=points.shape)
        return points + noise

    @staticmethod
    def add_depth_noise(points, k=0.05):
        """Add depth-dependent noise (more noise at greater distances)

        Parameters:
        -----------
        points: np.ndarray, shape (N, 3)
        k: float, noise coefficient

        Returns:
        --------
        corrupted_points: np.ndarray, shape (N, 3)
        """
        # Calculate distance from origin
        distances = np.linalg.norm(points, axis=1)

        # Scale noise by distance (farther = more noise)
        noise_scale = k * distances.reshape(-1, 1)
        noise = np.random.normal(0, 1, size=points.shape) * noise_scale

        return points + noise

    @staticmethod
    def add_occlusion(points, ratio=0.2, mode='random'):
        """Add occlusion to point cloud

        Occluded points are removed and the cloud is padded back to N points
        by duplicating randomly chosen visible points, so the output keeps
        the input shape. `ratio` is only used by the 'random' mode. An
        unknown mode returns the input unchanged.

        Parameters:
        -----------
        points: np.ndarray, shape (N, 3)
        ratio: float, percentage of points to occlude (0.0-1.0); values
            outside are clamped
        mode: str, 'random', 'region', or 'semantic'

        Returns:
        --------
        corrupted_points: np.ndarray, shape (N, 3)
        """
        num_points = points.shape[0]

        if mode == 'random':
            # Random occlusion of a fixed fraction of the points
            num_to_occlude = PointCloudCorruptor._fraction_to_count(num_points, ratio)
            mask = np.ones(num_points, dtype=bool)
            indices = np.random.choice(num_points, num_to_occlude, replace=False)
            mask[indices] = False
            return PointCloudCorruptor._refill(points, mask)

        elif mode == 'region':
            # Region-based occlusion: remove everything inside a random sphere
            center = np.random.uniform(-0.5, 0.5, size=3)
            radius = np.random.uniform(0.2, 0.5)

            # Calculate distances to the center
            distances = np.linalg.norm(points - center, axis=1)

            # Keep points outside the sphere
            mask = distances > radius
            return PointCloudCorruptor._refill(points, mask)

        elif mode == 'semantic':
            # Semantic occlusion (simplified - just occludes one side)
            # In real implementation, this would target specific semantic parts
            dimension = np.random.randint(0, 3)  # Choose x, y, or z
            threshold = np.median(points[:, dimension])
            mask = points[:, dimension] > threshold
            return PointCloudCorruptor._refill(points, mask)

        return points

def prepare_dataloaders(dataset_path, batch_size=32, num_points=1024, num_workers=4):
    """
    Build the training and test DataLoaders for ModelNet40.

    The training set uses random rotation and class balancing (through a
    weighted sampler that draws with replacement); the test set uses neither
    and is iterated in order.

    Parameters:
    -----------
    dataset_path: str, path to ModelNet40 dataset
    batch_size: int, batch size
    num_points: int, number of points per sample
    num_workers: int, number of dataloader workers

    Returns:
    --------
    train_loader: DataLoader for training data
    test_loader: DataLoader for test data
    """
    train_dataset = ModelNet40Dataset(
        root_dir=dataset_path,
        split='train',
        num_points=num_points,
        random_rotation=True,
        class_balance=True,
    )
    test_dataset = ModelNet40Dataset(
        root_dir=dataset_path,
        split='test',
        num_points=num_points,
        random_rotation=False,
        class_balance=False,
    )

    # A weighted sampler is only built when the dataset computed class weights.
    balanced_sampler = None
    if train_dataset.weights is not None:
        balanced_sampler = WeightedRandomSampler(
            weights=train_dataset.sample_weights,
            num_samples=len(train_dataset),
            replacement=True,
        )

    # Settings shared by both loaders.
    shared_options = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)

    train_loader = DataLoader(
        train_dataset,
        sampler=balanced_sampler,
        drop_last=True,
        **shared_options,
    )
    test_loader = DataLoader(
        test_dataset,
        shuffle=False,
        **shared_options,
    )

    return train_loader, test_loader

def get_corrupted_dataset(dataset, corruption_type, severity):
    """
    Create a corrupted version of a dataset

    The dataset is deep-copied and the copy's `transform` attribute is set to
    the selected PointCloudCorruptor method, configured for the requested
    severity. The original dataset is left untouched.

    Parameters:
    -----------
    dataset: Dataset object
    corruption_type: str, one of 'gaussian_noise', 'snow', 'rain', 'fog',
        'depth_noise', 'occlusion'
    severity: int or float, severity level (1-5)

    Returns:
    --------
    corrupted_dataset: Dataset with corruption applied

    Raises:
    -------
    ValueError: if `corruption_type` is not supported
    KeyError: if `severity` is not one of the levels 1-5
    """
    # Imported here so the function does not depend on file-level imports.
    import copy
    import functools

    # corruption name -> (PointCloudCorruptor method, keyword it tunes,
    #                     keyword value for each severity level)
    recipes = {
        'gaussian_noise': ('add_gaussian_noise', 'sigma',
                           {1: 0.01, 2: 0.02, 3: 0.03, 4: 0.04, 5: 0.05}),
        'snow': ('add_snow', 'density',
                 {1: 0.05, 2: 0.1, 3: 0.15, 4: 0.2, 5: 0.3}),
        'rain': ('add_rain', 'density',
                 {1: 0.05, 2: 0.1, 3: 0.15, 4: 0.2, 5: 0.3}),
        'fog': ('add_fog', 'density',
                {1: 0.05, 2: 0.1, 3: 0.2, 4: 0.3, 5: 0.4}),
        'depth_noise': ('add_depth_noise', 'k',
                        {1: 0.01, 2: 0.02, 3: 0.04, 4: 0.06, 5: 0.1}),
        'occlusion': ('add_occlusion', 'ratio',
                      {1: 0.1, 2: 0.15, 3: 0.2, 4: 0.25, 5: 0.3}),
    }

    if corruption_type not in recipes:
        raise ValueError(f"Unsupported corruption type: {corruption_type}")

    method_name, keyword, levels = recipes[corruption_type]
    value = levels[severity]

    # functools.partial over a static method can be pickled, unlike a lambda,
    # so the dataset also works with DataLoader workers that are spawned
    # rather than forked.
    corruptor = getattr(PointCloudCorruptor, method_name)
    transform = functools.partial(corruptor, **{keyword: value})

    # Work on a copy so the caller's dataset keeps its own transform.
    corrupted_dataset = copy.deepcopy(dataset)
    corrupted_dataset.transform = transform

    return corrupted_dataset