In [None]:
%load_ext autoreload
%autoreload 2
CUDA_LAUNCH_BLOCKING=1
!nvidia-smi

In [2]:
from os import PathLike
from pathlib import Path
import random
from easydict import EasyDict as edict
import matplotlib.pyplot as plt
import json
import os, glob
import sys

import yaml
from tqdm import tqdm, trange

import numpy as np
from numpy import zeros, newaxis

import torch
import torch.nn as nn 
import torch.nn.functional as F 
import torch.nn.functional as nnf
import torch.optim as optim
from torch.autograd import grad
from torch.autograd import Variable
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, Dataset, random_split
from torch.optim.lr_scheduler import CosineAnnealingLR,CosineAnnealingWarmRestarts,StepLR, ReduceLROnPlateau
from torchvision import datasets, transforms
import torchvision.models as models

from PIL import Image
import SimpleITK as sitk
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import pickle
import cv2

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, classification_report
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split


# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
dev = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
import os
import nibabel as nib
import numpy as np
from tqdm import tqdm

input_folder = r"..."

# Diagnosis tag embedded in each filename (matched as "_<TAG>.") -> integer class id.
label_mapping = {
    "CN": 0,
    "AD": 1,
    "EMCI": 2,
    "LMCI": 3
}

images = []
labels = []

file_list = [f for f in os.listdir(input_folder) if f.endswith(".nii.gz")]

for filename in tqdm(file_list, desc="Processing Files", unit="file"):
    # Take the first tag whose "_<TAG>." marker occurs in the filename;
    # files that carry no known tag are skipped.
    label_value = next(
        (value for key, value in label_mapping.items() if f"_{key}." in filename),
        None,
    )
    if label_value is None:
        continue

    file_path = os.path.join(input_folder, filename)
    try:
        data = nib.load(file_path).get_fdata()
    except Exception as e:
        # Best effort: report the unreadable file and carry on with the rest.
        print(f"Error reading {filename}: {e}")
        continue

    # Only 2-D arrays are kept; anything of another rank is dropped for this file.
    if data.ndim != 2:
        continue

    images.append(data)
    labels.append(label_value)

# Fail with an explicit message instead of the opaque error np.stack raises on [].
if not images:
    raise RuntimeError(f"No labelled 2-D .nii.gz images found in {input_folder!r}")

images_array = np.stack(images)
labels_array = np.array(labels)

print("Images array shape:", images_array.shape)
print("Labels array shape:", labels_array.shape)

In [None]:
import numpy as np

def sample_per_class_and_average(images_array, labels_array, samples_per_class=10, rng=None):
    """Draw up to ``samples_per_class`` images from every class and average them.

    Args:
        images_array: Array-like indexed along axis 0 by sample.
        labels_array: 1-D array-like of class labels aligned with ``images_array``.
        samples_per_class: Maximum number of images drawn (without replacement)
            from each class. Classes with fewer samples contribute all of them
            and a warning is printed.
        rng: Optional object exposing ``choice(a, size=..., replace=...)`` such
            as ``np.random.default_rng(seed)``. When omitted, the global
            ``np.random`` state is used, so ``np.random.seed`` still controls
            the draw.

    Returns:
        Tuple ``(average_image, sampled_indices, sampled_labels)``:
        the mean over axis 0 of the sampled images, the sampled indices grouped
        by class in ascending label order, and the labels at those indices.

    Raises:
        ValueError: If ``labels_array`` is empty.
    """
    images_array = np.asarray(images_array)
    labels_array = np.asarray(labels_array)
    if labels_array.size == 0:
        raise ValueError("labels_array is empty; there is nothing to sample")

    # Use the caller's generator when given so the draw does not depend on
    # (or disturb) the global NumPy random state.
    choose = np.random.choice if rng is None else rng.choice

    per_class_indices = []
    for label in np.unique(labels_array):
        class_indices = np.where(labels_array == label)[0]

        n_samples = samples_per_class
        if len(class_indices) < samples_per_class:
            print(f"Warning: Class {label} has only {len(class_indices)} samples")
            n_samples = len(class_indices)

        per_class_indices.append(choose(class_indices, size=n_samples, replace=False))

    all_sampled_indices = np.concatenate(per_class_indices)

    sampled_images = images_array[all_sampled_indices]
    sampled_labels = labels_array[all_sampled_indices]

    average_image = np.mean(sampled_images, axis=0)

    return average_image, all_sampled_indices, sampled_labels

# Fix the global NumPy seed so the per-class draw below is repeatable.
np.random.seed(42)

average_image, sampled_indices, sampled_labels = sample_per_class_and_average(
    images_array,
    labels_array,
)

print(f"\nTotal number of sampled images: {len(sampled_indices)}")
print(f"Average image shape: {average_image.shape}")

# Summary statistics of the averaged image, one per line.
print("\nAverage image stats:")
average_image_stats = (
    ("Min", average_image.min()),
    ("Max", average_image.max()),
    ("Mean", average_image.mean()),
    ("Std", average_image.std()),
)
for stat_label, stat_value in average_image_stats:
    print(f"  {stat_label}: {stat_value:.3f}")

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import numpy as np
import nibabel as nib
import torchvision.transforms as transforms

class CustomDataset(Dataset):
    """Dataset yielding an image together with three reference images and its label.

    Every item is ``(image, registered_slice, average_image, fixed_slice, label)``:

    * ``image`` - the ``idx``-th entry of ``images``.
    * ``registered_slice`` - middle slice (third axis) of the volume at
      ``registered_img_path``; identical for every item.
    * ``average_image`` - the supplied ``average_image``; identical for every item.
    * ``fixed_slice`` - middle slice (third axis) of the fixed volume configured
      for the item's class label.

    All arrays are min-max rescaled to [0, 1]. ``images`` is rescaled with the
    min/max of the array passed in, so each dataset instance is scaled
    independently of the others.

    Args:
        images: Array of images indexed along axis 0.
        labels: Integer class labels aligned with ``images``.
        registered_img_path: Path of a volume readable by ``nibabel``.
        average_image: 2-D array used as a shared reference image.
        transform: Optional callable applied to tensors with a leading channel
            dimension of size 1 (reference images once at construction, each
            image on access).

    Raises:
        ValueError: If a fixed volume has fewer than three dimensions.
        Exception: Any error raised while loading a fixed volume is logged and
            re-raised.
    """

    @staticmethod
    def _min_max_normalize(array):
        """Linearly rescale ``array`` to [0, 1]; a constant array maps to all zeros."""
        lo = array.min()
        span = array.max() - lo
        if span == 0:
            # Constant input: dividing by the zero range would yield NaN everywhere.
            span = 1
        return (array - lo) / span

    @staticmethod
    def _middle_slice(volume):
        """Return the central slice of ``volume`` along its third axis."""
        return volume[:, :, volume.shape[2] // 2]

    def __init__(self, images, labels, registered_img_path, average_image, transform=None):
        self.images = torch.FloatTensor(self._min_max_normalize(images))
        self.labels = torch.LongTensor(labels)
        self.transform = transform

        self.average_image = torch.FloatTensor(self._min_max_normalize(average_image))

        registered_img = self._min_max_normalize(nib.load(registered_img_path).get_fdata())
        self.registered_middle_slice = torch.FloatTensor(self._middle_slice(registered_img))

        # Define fixed volume paths for each class
        self.fixed_volume_paths = {
            0: '...',
            1: '...',
            2: '...',
            3: '...'
        }

        self.fixed_middle_slices = {}
        for label, path in self.fixed_volume_paths.items():
            try:
                print(f"\nProcessing file for label {label}: {path}")
                fixed_volume = nib.load(path).get_fdata()
                print(f"Shape of fixed volume for label {label}: {fixed_volume.shape}")

                if len(fixed_volume.shape) < 3:
                    raise ValueError(f"Volume for label {label} has less than 3 dimensions: {fixed_volume.shape}")

                fixed_volume = self._min_max_normalize(fixed_volume)
                self.fixed_middle_slices[label] = torch.FloatTensor(self._middle_slice(fixed_volume))
                print(f"Successfully processed label {label}")
            except Exception as e:
                # Log which class/file failed, then let the error propagate.
                print(f"Error processing label {label}: {str(e)}")
                print(f"File path: {path}")
                raise

        # The reference images never change, so they are transformed once here
        # rather than on every __getitem__ call.
        if self.transform:
            self.registered_middle_slice = self.transform(self.registered_middle_slice.unsqueeze(0))
            self.average_image = self.transform(self.average_image.unsqueeze(0))
            for label in self.fixed_middle_slices:
                self.fixed_middle_slices[label] = self.transform(self.fixed_middle_slices[label].unsqueeze(0))

    def __len__(self):
        """Number of samples (one per label)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, registered_slice, average_image, fixed_slice, label)`` for ``idx``."""
        image = self.images[idx]
        registered_image = self.registered_middle_slice
        average_image = self.average_image
        label = self.labels[idx]

        # The fixed (target) image is chosen by the sample's class label.
        fixed_image = self.fixed_middle_slices[label.item()]

        # Add a channel dimension for the transform; it is squeezed away again below.
        if self.transform:
            image = self.transform(image.unsqueeze(0))
        else:
            image = image.unsqueeze(0)

        return (image.squeeze(0),
                registered_image.squeeze(0),
                average_image.squeeze(0),
                fixed_image.squeeze(0),
                label)

def create_stratified_splits(images_array, labels_array, registered_img_path, average_image, batch_size=32, num_workers=4):
    """Split the data 70/15/15 with stratification and wrap each part in a DataLoader.

    The data is first split 70/30 (stratified on the labels, ``random_state=42``);
    the 30% remainder is then halved (again stratified, ``random_state=42``)
    into validation and test sets. Each part becomes a ``CustomDataset`` whose
    transform resizes to 128 x 128. Only the training loader shuffles. A
    per-class count/percentage table for the three splits is printed.

    Args:
        images_array: Images indexed along axis 0.
        labels_array: Integer class labels aligned with ``images_array``.
        registered_img_path: Forwarded to ``CustomDataset``.
        average_image: Forwarded to ``CustomDataset``.
        batch_size: Batch size used by all three loaders.
        num_workers: Worker process count used by all three loaders.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    resize = transforms.Compose([
        transforms.Resize((128, 128)),
    ])

    # Stage 1: 70% train, 30% held out.
    train_images, temp_images, train_labels, temp_labels = train_test_split(
        images_array, labels_array,
        train_size=0.7, stratify=labels_array, random_state=42,
    )
    # Stage 2: split the held-out part evenly into validation and test.
    val_images, test_images, val_labels, test_labels = train_test_split(
        temp_images, temp_labels,
        train_size=0.5, stratify=temp_labels, random_state=42,
    )

    def _build_dataset(split_images, split_labels):
        return CustomDataset(split_images, split_labels, registered_img_path, average_image, transform=resize)

    def _build_loader(dataset, shuffle):
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=True,
        )

    # Datasets are built in train, val, test order (each constructor logs its progress).
    train_dataset = _build_dataset(train_images, train_labels)
    val_dataset = _build_dataset(val_images, val_labels)
    test_dataset = _build_dataset(test_images, test_labels)

    train_loader = _build_loader(train_dataset, shuffle=True)
    val_loader = _build_loader(val_dataset, shuffle=False)
    test_loader = _build_loader(test_dataset, shuffle=False)

    rule = "-" * 50
    print("\nDetailed class distribution in splits:")
    print(rule)

    train_dist, val_dist, test_dist = (
        np.bincount(split_labels) for split_labels in (train_labels, val_labels, test_labels)
    )

    print(f"{'Class':<6} {'Train':<20} {'Validation':<20} {'Test':<20}")
    print(rule)

    # One row per class id present in the training histogram.
    for class_idx in range(len(train_dist)):
        n_train = train_dist[class_idx]
        n_val = val_dist[class_idx]
        n_test = test_dist[class_idx]

        train_percent = (n_train / len(train_labels)) * 100
        val_percent = (n_val / len(val_labels)) * 100
        test_percent = (n_test / len(test_labels)) * 100

        print(f"{class_idx:<6} {n_train:>4} ({train_percent:>5.1f}%)  {n_val:>6} ({val_percent:>5.1f}%)  {n_test:>6} ({test_percent:>5.1f}%)")

    return train_loader, val_loader, test_loader

train_loader, val_loader, test_loader = create_stratified_splits(
    images_array,
    labels_array,
    registered_img_path='...',
    average_image=average_image,
    batch_size=512,
    num_workers=4,
)

# Sanity check on the first training batch only: shape and intensity range of
# each image stream, then the shape of the label tensor.
for batch_img, batch_reg_img, batch_avg_img, batch_fixed_img, batch_labels in train_loader:
    named_streams = (
        ("Regular images", batch_img),
        ("Registered image", batch_reg_img),
        ("Average image", batch_avg_img),
        ("Fixed image", batch_fixed_img),
    )
    for stream_name, stream in named_streams:
        print(f"{stream_name} shape:", stream.shape)
        print(f"{stream_name} range:", stream.min().item(), "to", stream.max().item())

    print("Labels shape:", batch_labels.shape)
    break

In [None]:
def visualize_batch(batch_img, batch_reg_img, batch_labels, num_samples=8):
    """Show the first images of a batch above their paired reference images.

    The top row shows ``batch_img`` entries titled with their labels; the bottom
    row shows the entries of ``batch_reg_img`` at the same positions.

    Args:
        batch_img: Tensor or array of 2-D images indexed along axis 0.
        batch_reg_img: Tensor or array of paired images, same indexing.
        batch_labels: Tensor or array of labels aligned with ``batch_img``.
        num_samples: Maximum number of columns to draw. Fewer are drawn when
            the batch is smaller; nothing is drawn for an empty batch.
    """
    def _as_numpy(batch):
        # detach/cpu make this safe for tensors that require grad or live on the GPU.
        return batch.detach().cpu().numpy() if torch.is_tensor(batch) else batch

    batch_img = _as_numpy(batch_img)
    batch_reg_img = _as_numpy(batch_reg_img)
    batch_labels = _as_numpy(batch_labels)

    # Never index past the end of a short (e.g. final) batch.
    num_samples = min(num_samples, len(batch_img), len(batch_reg_img), len(batch_labels))
    if num_samples <= 0:
        return

    # squeeze=False keeps `axes` two-dimensional even when there is one column.
    fig, axes = plt.subplots(2, num_samples, figsize=(20, 5), squeeze=False)
    fig.suptitle('Batch Visualization: Original Images (top) vs Registered Images (bottom)', fontsize=14)

    for col in range(num_samples):
        top, bottom = axes[0, col], axes[1, col]

        top.imshow(batch_img[col], cmap='gray')
        top.axis('off')
        top.set_title(f'Label: {batch_labels[col]}')

        bottom.imshow(batch_reg_img[col], cmap='gray')
        bottom.axis('off')

    plt.tight_layout()
    plt.show()

# Visualise only the first training batch. Note that the per-class fixed images
# (not the registered image) are passed as the second row.
for first_batch in train_loader:
    batch_img, batch_reg_img, batch_avg_img, batch_fixed_img, batch_labels = first_batch
    visualize_batch(batch_img, batch_fixed_img, batch_labels)
    break

In [13]:
################Parameter Loading#######################
def read_yaml(path):
    """Load a YAML file into an attribute-accessible ``EasyDict``.

    Args:
        path: Path of the YAML file to read.

    Returns:
        The parsed content wrapped in ``edict``, or ``None`` when the file
        cannot be opened, parsed or converted (the reason is printed).
    """
    try:
        with open(path, 'r') as f:
            # NOTE(review): yaml.safe_load is the more conservative choice if
            # this file could ever come from an untrusted source.
            return edict(yaml.load(f, Loader=yaml.FullLoader))
    except Exception as err:
        # Stay best-effort (callers get None), but report why it failed and no
        # longer swallow KeyboardInterrupt / SystemExit as a bare except would.
        print(f'NO FILE READ! ({err})')
        return None
# Experiment configuration; read_yaml returns None when the file cannot be read.
para = read_yaml('./parameters.yml')

# Image grid dimensions; xDim and yDim are passed to the network as `inshape`.
xDim, yDim, zDim = 128, 128, 1

def loss_Reg(y_pred):
    """Smoothness penalty: mean squared forward difference over the last two axes.

    Forward differences of ``y_pred`` are taken along axis 2 and along axis 3,
    squared, averaged separately, and the two means are averaged.

    Args:
        y_pred: Tensor indexed with four axes (``y_pred[:, :, i, j]``).

    Returns:
        Scalar tensor ``(mean(dx^2) + mean(dy^2)) / 2``.

    A volumetric variant would add the matching term along a fifth axis and
    divide by 3 instead of 2.
    """
    step_axis2 = torch.abs(y_pred[:, :, 1:, :] - y_pred[:, :, :-1, :])
    step_axis3 = torch.abs(y_pred[:, :, :, 1:] - y_pred[:, :, :, :-1])

    penalty = torch.mean(step_axis3 * step_axis3) + torch.mean(step_axis2 * step_axis2)
    return penalty / 2.0
    

def jacobian_determinant(displacement):
    """Penalise deviation of the deformation's Jacobian determinant from 1.

    With channel 0 read as the x displacement ``u`` and channel 1 as the y
    displacement ``v``, forward differences on the common ``(H-1, W-1)`` grid
    give ``det = (1 + du/dx) * (1 + dv/dy) - dv/dx * du/dy``.

    Args:
        displacement: Tensor of shape ``[batch, 2, H, W]``.

    Returns:
        Scalar tensor: the mean of ``(det - 1) ** 2``.
    """
    disp_x = displacement[:, 0, :, :]  # [batch, H, W]
    disp_y = displacement[:, 1, :, :]  # [batch, H, W]

    def _d_dx(field):
        # Forward difference along the last axis, cropped to the shared grid.
        return field[:, :-1, 1:] - field[:, :-1, :-1]

    def _d_dy(field):
        # Forward difference along the middle axis, cropped to the shared grid.
        return field[:, 1:, :-1] - field[:, :-1, :-1]

    det = (1 + _d_dx(disp_x)) * (1 + _d_dy(disp_y)) - _d_dx(disp_y) * _d_dy(disp_x)

    return torch.mean((det - 1) ** 2)

class NCCLoss(nn.Module):
    """Loss based on normalized cross-correlation (NCC).

    NCC is computed over the last two dimensions of each input, averaged over
    the remaining ones, and returned as ``1 - NCC``: 0 for perfectly correlated
    inputs, up to 2 for perfectly anti-correlated ones.

    Args:
        eps: Small constant added to the denominator to avoid division by zero.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target):
        """Return ``1 - mean(NCC(pred, target))`` as a scalar tensor."""
        spatial_dims = (-2, -1)

        # Remove the per-image mean from both inputs.
        pred_centered = pred - pred.mean(dim=spatial_dims, keepdim=True)
        target_centered = target - target.mean(dim=spatial_dims, keepdim=True)

        # Sums of squares and the cross term over the spatial dimensions.
        pred_energy = (pred_centered ** 2).sum(dim=spatial_dims, keepdim=True)
        target_energy = (target_centered ** 2).sum(dim=spatial_dims, keepdim=True)
        cross_term = (pred_centered * target_centered).sum(dim=spatial_dims, keepdim=True)

        denominator = torch.sqrt(pred_energy) * torch.sqrt(target_energy) + self.eps
        ncc = cross_term / denominator

        return 1 - ncc.mean()
    
from losses import MSE, Grad
#################Network optimization########################
from networks import DiffeoDense  
# Build the registration network once and move it to the selected device.
# Only one model is used downstream, so a single instance is constructed.
net = DiffeoDense(
    inshape=(xDim, yDim),
    nb_unet_features=[[16, 32], [32, 32, 16, 16]],
    nb_unet_conv_per_level=1,
    int_steps=7,
    int_downsize=2,
    src_feats=1,
    trg_feats=1,
    unet_half_res=True,
).to(dev)

# Config-driven selection of loss, optimizer and LR scheduler.
# None of the chains has an `else`: an unrecognised config value leaves the
# corresponding name (criterion / optimizer / scheduler) unassigned here.
# NOTE(review): `criterion` and `optimizer` are reassigned in the training cell
# further down (NCCLoss and a joint Adam optimizer), so the scheduler created
# below stays bound to the optimizer built here, and no `scheduler.step()` call
# appears in the visible training loop.
if(para.model.loss == 'L2'):
    criterion = nn.MSELoss()
elif (para.model.loss == 'L1'):
    criterion = nn.L1Loss()
# Both optimizers cover only `net`'s parameters and share the configured learning rate.
if(para.model.optimizer == 'Adam'):
    optimizer = optim.Adam(net.parameters(), lr= para.solver.lr)
elif (para.model.optimizer == 'SGD'):
    optimizer = optim.SGD(net.parameters(), lr= para.solver.lr, momentum=0.9)
# Cosine annealing with a period of one pass over the training loader (T_max = number of batches).
if (para.model.scheduler == 'CosAn'):
    scheduler = CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0)
    
class SupConLoss(nn.Module):
    """Supervised contrastive loss over a batch of feature vectors.

    Features are L2-normalised, pairwise similarities are divided by
    ``temperature``, and for each anchor the mean log-probability of its
    same-label samples (excluding itself) is taken, where the probability is a
    softmax over all other samples in the batch. The loss is the negative mean
    over anchors; anchors without any positive contribute 0.

    Args:
        temperature: Scaling applied to the cosine similarities.
    """

    def __init__(self, temperature=0.1):
        super().__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        """Compute the loss.

        Args:
            features: Tensor of shape ``[batch, dim]`` or ``[batch, ...]``;
                trailing dimensions are flattened.
            labels: Tensor with one class label per sample.

        Returns:
            Scalar loss tensor. Batches with fewer than two samples have no
            contrastive pairs and yield 0 (still attached to the graph).
        """
        # Flatten any trailing dimensions; works for non-contiguous inputs and any rank.
        if features.dim() > 2:
            features = features.flatten(start_dim=1)

        device = features.device
        batch_size = features.shape[0]

        # Cosine similarity scaled by the temperature.
        features = F.normalize(features, dim=1)
        sim_matrix = torch.matmul(features, features.T) / self.temperature

        if batch_size < 2:
            # No other sample to contrast against: the softmax denominator would
            # be empty and the result NaN. Return a zero that keeps gradients finite.
            return sim_matrix.sum() * 0.0

        labels = labels.view(-1, 1).to(device)
        self_mask = torch.eye(batch_size, dtype=torch.bool, device=device)
        # Positives: same label, excluding the anchor itself.
        pos_mask = torch.eq(labels, labels.T) & ~self_mask

        # log-softmax over all non-self samples; logsumexp is stable even when
        # sim / temperature is too large for exp() in float32.
        log_denominator = torch.logsumexp(
            sim_matrix.masked_fill(self_mask, float('-inf')), dim=1, keepdim=True
        )
        log_prob = sim_matrix - log_denominator

        # Mean log-probability over each anchor's positives (0 when it has none).
        pos_count = pos_mask.sum(dim=1)
        mean_log_prob_pos = (log_prob * pos_mask).sum(dim=1) / pos_count.clamp(min=1)

        return -mean_log_prob_pos.mean()
    
class FeatureTransform(nn.Module):
    """Small convolutional head that prepares feature maps for the contrastive loss.

    Applies a 3x3 convolution (channel count unchanged), batch normalisation and
    ReLU, then adaptively average-pools the result to 8 x 8.

    Args:
        in_channels: Number of channels of the incoming feature map.
    """

    def __init__(self, in_channels=32):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(in_channels)
        self.pool = nn.AdaptiveAvgPool2d((8, 8))

    def forward(self, x):
        """Return the pooled features; if any NaN is present it is reported and sanitised."""
        activated = F.relu(self.bn(self.conv(x)))
        pooled = self.pool(activated)

        if not torch.isnan(pooled).any():
            return pooled

        # NaNs present: warn and replace them (nan -> 0.0) before returning.
        print("Warning: Features contain NaN values. Replacing with zeros.")
        return torch.nan_to_num(pooled, nan=0.0)

In [None]:
# Initialize losses and feature transform
criterion = NCCLoss()
feature_transform = FeatureTransform().to(dev)
con_criterion = SupConLoss(temperature=0.1)

# One Adam optimizer updates both the registration network and the feature head
optimizer = torch.optim.Adam([
    {'params': net.parameters()},
    {'params': feature_transform.parameters()}
], lr=para.solver.lr)

# Per-epoch averages of the total, registration and contrastive losses
training_losses = []
reg_losses = []
con_losses = []
total_batches = len(train_loader) * para.solver.epochs

# Initialize progress bar for the entire training process
with tqdm(total=total_batches, desc="Training Progress") as pbar:
    for epoch in range(para.solver.epochs):
        total = 0
        total_reg = 0
        total_con = 0
        net.train()
        feature_transform.train()

        # Each batch: (image, registered, average, fixed, label); only the
        # image (source) and the per-class fixed image (target) are used here.
        for src_bch, _, _, tar_bch, labels in train_loader:
            b, w, h = src_bch.shape
            optimizer.zero_grad()

            # Prepare input batches: add the channel dimension and move to the device
            src_bch = src_bch.reshape(b, 1, w, h).to(dev)
            tar_bch = tar_bch.reshape(b, 1, w, h).to(dev)
            labels = labels.to(dev)

            # Forward pass; pred[0] feeds the similarity loss, pred[1] the
            # smoothness penalty and pred[2] the contrastive feature head.
            pred = net(src_bch, tar_bch, registration=True)

            # Registration loss
            # NOTE(review): pred[0] is compared with the source batch; if pred[0]
            # is the warped source, the target batch may be intended - confirm
            # against DiffeoDense.
            reg_loss = criterion(pred[0], src_bch)
            smoothness_loss = loss_Reg(pred[1])

            # Transform features and compute contrastive loss
            transformed_features = feature_transform(pred[2])
            con_loss = con_criterion(transformed_features, labels)

            # Combined loss; the contrastive term is down-weighted by 1e-3
            loss_total = reg_loss + smoothness_loss + 0.001 * con_loss

            # Backward pass
            loss_total.backward()

            # Clip AFTER backward so the freshly computed gradients are the ones
            # being bounded; clipping before backward only sees cleared gradients.
            torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
            torch.nn.utils.clip_grad_norm_(feature_transform.parameters(), max_norm=1.0)

            optimizer.step()

            # Accumulate losses for the epoch averages
            total += loss_total.item()
            total_reg += reg_loss.item()
            total_con += con_loss.item()

            # Update progress bar
            pbar.set_postfix({
                "Total Loss": f"{loss_total.item():.4f}",
                "Reg Loss": f"{reg_loss.item():.4f}",
                "Con Loss": f"{con_loss.item():.4f}"
            })
            pbar.update(1)

        # Append epoch losses to lists
        avg_loss = total / len(train_loader)
        avg_reg_loss = total_reg / len(train_loader)
        avg_con_loss = total_con / len(train_loader)

        training_losses.append(avg_loss)
        reg_losses.append(avg_reg_loss)
        con_losses.append(avg_con_loss)

# Plot the per-epoch training curves side by side: total, registration, contrastive.
plt.figure(figsize=(15, 4))

epoch_axis = range(1, para.solver.epochs + 1)
loss_panels = (
    # (values, line colour, legend label, panel title)
    (training_losses, 'b', 'Total Loss', 'Total Training Loss'),
    (reg_losses, 'g', 'Registration Loss', 'Registration Loss'),
    (con_losses, 'r', 'Contrastive Loss', 'Contrastive Loss'),
)

for panel_idx, (panel_values, panel_color, panel_label, panel_title) in enumerate(loss_panels, start=1):
    plt.subplot(1, 3, panel_idx)
    plt.plot(epoch_axis, panel_values, marker='o', color=panel_color, label=panel_label)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(panel_title)
    plt.legend()
    plt.grid()

plt.tight_layout()
plt.show()

# Save the complete model objects (not just their state dicts): first the
# registration network, then the feature head.
torch.save(net, '...')
torch.save(feature_transform, '...')

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import os
from tqdm import tqdm

# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also sets ``torch.backends.cudnn.deterministic = True`` and
    ``torch.backends.cudnn.benchmark = False``.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Local import keeps the function self-contained regardless of cell order.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

class RawFeatureClassifier(nn.Module):
    """Two-layer MLP classifier used for the raw (untransformed) features.

    Layout: Linear -> ReLU -> Dropout(0.3) -> Linear.

    Args:
        input_dim: Size of each input feature vector.
        hidden_dim: Width of the hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Return the class logits for a batch of feature vectors."""
        logits = self.classifier(x)
        return logits

class TransformedFeatureClassifier(nn.Module):
    """Three-layer MLP classifier used for the transformed features.

    Layout: two hidden stages of Linear -> BatchNorm1d -> ReLU -> Dropout(0.2)
    (widths ``hidden_dim`` and ``hidden_dim // 2``) followed by a Linear output
    layer.

    Args:
        input_dim: Size of each input feature vector.
        hidden_dim: Width of the first hidden layer; the second is half of it.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()
        half_dim = hidden_dim // 2
        layers = [
            # Stage 1
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            # Stage 2
            nn.Linear(hidden_dim, half_dim),
            nn.BatchNorm1d(half_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            # Output logits
            nn.Linear(half_dim, num_classes),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Return the class logits for a batch of feature vectors."""
        logits = self.classifier(x)
        return logits

def extract_features(loader, pretrained_net, feature_transform=None, device='cuda'):
    """Run the pretrained network over ``loader`` and collect flattened features.

    For every batch, the image and the per-class fixed image are given a channel
    dimension, moved to ``device`` and passed through ``pretrained_net`` with
    ``registration=True``. The third element of the network output is taken as
    the raw feature map. Everything runs in eval mode without gradients.

    Args:
        loader: Iterable of ``(image, registered, average, fixed, label)`` batches.
        pretrained_net: Network called as ``net(src, tgt, registration=True)``.
        feature_transform: Optional module applied to the raw feature map.
        device: Device the inputs are moved to.

    Returns:
        ``(raw_features, transformed_features, labels)`` when
        ``feature_transform`` is given, otherwise ``(raw_features, labels)``.
        Feature tensors are flattened per sample and live on the CPU.
    """
    use_transform = feature_transform is not None

    pretrained_net.eval()
    if use_transform:
        feature_transform.eval()

    def _with_channel(batch):
        # [B, H, W] -> [B, 1, H, W], then move to the target device.
        return batch.reshape(batch.shape[0], 1, batch.shape[1], batch.shape[2]).to(device)

    raw_chunks, transformed_chunks, label_chunks = [], [], []

    with torch.no_grad():
        for batch_img, _reg_img, _avg_img, batch_const_img, batch_labels in tqdm(loader, desc="Extracting features"):
            source = _with_channel(batch_img)
            target = _with_channel(batch_const_img)

            # The third output of the network is the feature map of interest.
            feature_map = pretrained_net(source, target, registration=True)[2]

            raw_chunks.append(feature_map.view(feature_map.size(0), -1).cpu())

            if use_transform:
                transformed = feature_transform(feature_map)
                transformed_chunks.append(transformed.view(transformed.size(0), -1).cpu())

            label_chunks.append(batch_labels)

    raw_features = torch.cat(raw_chunks, dim=0)
    labels = torch.cat(label_chunks, dim=0)

    if not use_transform:
        return raw_features, labels

    transformed_features = torch.cat(transformed_chunks, dim=0)
    return raw_features, transformed_features, labels

def calculate_metrics(y_true, y_pred, y_score=None):
    """Compute classification metrics, plus AUC when class scores are supplied.

    Args:
        y_true: Ground-truth labels (array-like).
        y_pred: Predicted labels (array-like), aligned with ``y_true``.
        y_score: Optional per-class scores passed to ``roc_auc_score`` with
            ``multi_class='ovr'``.

    Returns:
        Dict with keys ``accuracy``, ``precision``, ``recall``, ``f1``
        (weighted averages), ``per_class_acc`` (confusion-matrix diagonal over
        row sums; NaN for classes with no true samples), ``conf_matrix`` and
        ``auc`` (``None`` when scores are missing or AUC cannot be computed).
    """
    # Accept plain lists as well as arrays; element-wise comparison needs arrays.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    # Overall accuracy
    accuracy = (y_true == y_pred).mean()

    # Weighted-average precision / recall / F1
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted')

    # Per-class accuracy: correct predictions over true samples of each class
    conf_matrix = confusion_matrix(y_true, y_pred)
    support = conf_matrix.sum(axis=1)
    per_class_acc = np.divide(
        conf_matrix.diagonal(),
        support,
        out=np.full(len(support), np.nan),
        where=support > 0,  # classes without true samples stay NaN, without a 0/0 warning
    )

    metrics_dict = {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'per_class_acc': per_class_acc,
        'conf_matrix': conf_matrix
    }

    # Add AUC score if probability scores are provided
    metrics_dict['auc'] = None
    if y_score is not None:
        try:
            metrics_dict['auc'] = roc_auc_score(y_true, y_score, multi_class='ovr')
        except Exception as err:
            # Best effort: AUC is optional, but report why it was skipped and do
            # not swallow KeyboardInterrupt / SystemExit.
            print(f"Warning: Could not calculate AUC score ({err})")

    return metrics_dict

def plot_confusion_matrix(conf_matrix, save_path=None):
    """Render a confusion matrix as an annotated heatmap and optionally save it.

    Args:
        conf_matrix: Matrix of integer counts (cells are formatted with ``'d'``).
        save_path: If truthy, the figure is written to this path.

    The figure is closed afterwards without being shown.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', ax=ax)
    ax.set_title('Confusion Matrix')
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    if save_path:
        fig.savefig(save_path)
    plt.close(fig)

def train_classifier(pretrained_net, train_loader, val_loader, feature_transform=None, save_dir='./classifier_results', feature_type='raw'):
    """Train a classifier on features extracted by the pretrained network.

    Features are extracted once for the training and validation loaders, then a
    ``TransformedFeatureClassifier`` (``feature_type == 'transformed'``) or a
    ``RawFeatureClassifier`` (any other value) is trained with cross-entropy,
    Adam and early stopping on the validation loss. The best weights and a
    loss/accuracy history plot are written to ``save_dir``.

    Args:
        pretrained_net: Network used by ``extract_features``.
        train_loader: Loader providing the training batches.
        val_loader: Loader providing the validation batches.
        feature_transform: Optional module producing transformed features. When
            ``None``, raw features are used regardless of ``feature_type``.
        save_dir: Output directory (created if missing).
        feature_type: ``'transformed'`` or ``'raw'``; selects the features (when
            a transform is given), the classifier architecture and the learning
            rate.

    Returns:
        Tuple ``(model, best_model_state)``: the model as it stands after the
        last trained epoch, and a snapshot of the state dict with the lowest
        validation loss.
    """
    # Local import: only needed for snapshotting the best weights.
    import copy

    os.makedirs(save_dir, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Extract features
    print("Extracting features...")
    if feature_transform is not None:
        train_raw, train_transformed, train_labels = extract_features(train_loader, pretrained_net, feature_transform, device)
        val_raw, val_transformed, val_labels = extract_features(val_loader, pretrained_net, feature_transform, device)

        # Select features based on feature_type
        train_features = train_transformed if feature_type == 'transformed' else train_raw
        val_features = val_transformed if feature_type == 'transformed' else val_raw
    else:
        train_features, train_labels = extract_features(train_loader, pretrained_net, None, device)
        val_features, val_labels = extract_features(val_loader, pretrained_net, None, device)

    # Convert labels to long type
    train_labels = train_labels.long()
    val_labels = val_labels.long()

    # Create datasets and dataloaders
    train_dataset = TensorDataset(train_features, train_labels)
    val_dataset = TensorDataset(val_features, val_labels)

    batch_size = 32
    train_loader_clf = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader_clf = DataLoader(val_dataset, batch_size=batch_size)

    # Initialize model based on feature type
    input_dim = train_features.shape[1]
    hidden_dim = 256
    num_classes = len(torch.unique(train_labels))

    if feature_type == 'transformed':
        model = TransformedFeatureClassifier(input_dim, hidden_dim, num_classes).to(device)
        learning_rate = 0.000005
    else:
        model = RawFeatureClassifier(input_dim, hidden_dim, num_classes).to(device)
        learning_rate = 0.00005

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)

    # Training parameters
    num_epochs = 300
    early_stopping_patience = 30
    best_val_loss = float('inf')
    patience_counter = 0
    best_model_state = None

    # Training history
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

    # Training loop with progress bar
    print(f"Starting training with {feature_type} features...")
    pbar = tqdm(range(num_epochs), desc="Training")
    for epoch in pbar:
        model.train()
        train_loss = 0
        train_correct = 0
        total_train = 0

        for features, labels in train_loader_clf:
            features, labels = features.to(device), labels.to(device)
            features = features.float()

            optimizer.zero_grad()
            outputs = model(features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total_train += labels.size(0)
            train_correct += (predicted == labels).sum().item()

        # Validation
        model.eval()
        val_loss = 0
        val_correct = 0
        total_val = 0

        with torch.no_grad():
            for features, labels in val_loader_clf:
                features, labels = features.to(device), labels.to(device)
                features = features.float()
                outputs = model(features)
                loss = criterion(outputs, labels)

                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        # Calculate epoch metrics (losses averaged per batch, accuracy per sample)
        train_loss = train_loss / len(train_loader_clf)
        val_loss = val_loss / len(val_loader_clf)
        train_acc = train_correct / total_train
        val_acc = val_correct / total_val

        # Update history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)

        # Update progress bar
        pbar.set_postfix({
            'train_loss': f'{train_loss:.4f}',
            'val_loss': f'{val_loss:.4f}',
            'train_acc': f'{train_acc:.4f}',
            'val_acc': f'{val_acc:.4f}'
        })

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # state_dict() returns references to the live parameter tensors, so
            # take a deep copy; otherwise later optimizer steps would silently
            # overwrite the "best" weights.
            best_model_state = copy.deepcopy(model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                print(f"\nEarly stopping triggered after {epoch+1} epochs")
                break

    # Save best model
    model_filename = f'best_feature_classifier_{feature_type}.pth'
    torch.save(best_model_state, os.path.join(save_dir, model_filename))

    # Plot training history
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Val Loss')
    plt.title(f'Loss History ({feature_type} features)')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'], label='Train Acc')
    plt.plot(history['val_acc'], label='Val Acc')
    plt.title(f'Accuracy History ({feature_type} features)')
    plt.legend()
    plt.savefig(os.path.join(save_dir, f'training_history_{feature_type}.png'))
    plt.close()

    return model, best_model_state

def test_classifier(pretrained_net, test_loader, model_state, feature_transform=None, feature_type='raw', save_dir='./classifier_results'):
    """Evaluate a trained feature classifier on the test split.

    Features are produced by ``extract_features``, a fresh classifier of the
    matching type is built and loaded with ``model_state``, and the metrics
    returned by ``calculate_metrics`` are printed and returned.

    Args:
        pretrained_net: network handed to ``extract_features``.
        test_loader: loader handed to ``extract_features``.
        model_state: state dict loaded into the freshly built classifier.
        feature_transform: optional transform. When given, ``extract_features``
            is unpacked as ``(raw, transformed, labels)``; otherwise as
            ``(features, labels)``.
        feature_type: ``'transformed'`` selects the transformed features and
            ``TransformedFeatureClassifier``; any other value selects the raw
            features and ``RawFeatureClassifier``.
        save_dir: directory receiving ``confusion_matrix_{feature_type}.png``.
            It is created when missing.

    Returns:
        The metrics dict produced by ``calculate_metrics``.
    """
    # The confusion matrix is written at the very end; create the target
    # directory up front so a missing folder cannot waste the whole evaluation.
    os.makedirs(save_dir, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Extract test features
    if feature_transform is not None:
        test_raw, test_transformed, test_labels = extract_features(test_loader, pretrained_net, feature_transform, device)
        test_features = test_transformed if feature_type == 'transformed' else test_raw
    else:
        test_features, test_labels = extract_features(test_loader, pretrained_net, None, device)
    
    test_labels = test_labels.long()
    test_dataset = TensorDataset(test_features, test_labels)
    test_loader_clf = DataLoader(test_dataset, batch_size=32)
    
    # Initialize appropriate model type
    input_dim = test_features.shape[1]
    hidden_dim = 256
    # NOTE(review): the class count is taken from the labels present in the test
    # set; it must match the count used in training or load_state_dict fails.
    num_classes = len(torch.unique(test_labels))
    
    if feature_type == 'transformed':
        model = TransformedFeatureClassifier(input_dim, hidden_dim, num_classes).to(device)
    else:
        model = RawFeatureClassifier(input_dim, hidden_dim, num_classes).to(device)
    
    model.load_state_dict(model_state)
    model.eval()
    
    all_preds = []
    all_labels = []
    all_scores = []
    
    with torch.no_grad():
        for features, labels in tqdm(test_loader_clf, desc=f"Testing ({feature_type} features)"):
            features, labels = features.to(device), labels.to(device)
            features = features.float()
            outputs = model(features)
            # Softmax probabilities are kept alongside the hard predictions so
            # calculate_metrics can use them as scores.
            scores = torch.softmax(outputs, dim=1)
            _, predicted = torch.max(outputs.data, 1)
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_scores.extend(scores.cpu().numpy())
    
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    all_scores = np.array(all_scores)
    
    # Calculate metrics
    metrics = calculate_metrics(all_labels, all_preds, all_scores)
    
    print(f"\nTest Set Metrics ({feature_type} features):")
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall: {metrics['recall']:.4f}")
    print(f"F1 Score: {metrics['f1']:.4f}")
    if metrics['auc'] is not None:
        print(f"AUC Score: {metrics['auc']:.4f}")
    print("\nPer-class Accuracy:")
    for i, acc in enumerate(metrics['per_class_acc']):
        print(f"Class {i}: {acc:.4f}")
    
    # Plot confusion matrix
    plot_confusion_matrix(metrics['conf_matrix'], os.path.join(save_dir, f'confusion_matrix_{feature_type}.png'))
    
    return metrics

# Example usage:
if __name__ == "__main__":
    # Set random seed for reproducibility
    set_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Load models.
    # map_location remaps tensors that were saved on another device, so a
    # checkpoint written on a GPU machine still loads when only a CPU is
    # available (device falls back to "cpu" above).
    # NOTE(review): these files are loaded as whole pickled modules; only load
    # checkpoints from a trusted source. Recent PyTorch releases may also need
    # weights_only=False for full-module checkpoints - confirm against the
    # installed version.
    pretrained_net = torch.load('...', map_location=device).to(device)
    feature_transform = torch.load('...', map_location=device).to(device)
    
    save_dir = './saved_models'
    os.makedirs(save_dir, exist_ok=True)
    
    # Train and test with raw features
    print("\nTraining with raw features...")
    model_raw, state_raw = train_classifier(pretrained_net, train_loader, val_loader, 
                                          feature_transform=feature_transform, 
                                          feature_type='raw', 
                                          save_dir=save_dir)
    
    print("\nTesting with raw features...")
    metrics_raw = test_classifier(pretrained_net, test_loader, state_raw, 
                                feature_transform=feature_transform, 
                                feature_type='raw', 
                                save_dir=save_dir)
    
    # Train and test with transformed features
    print("\nTraining with transformed features...")
    model_transformed, state_transformed = train_classifier(pretrained_net, train_loader, val_loader, 
                                                         feature_transform=feature_transform, 
                                                         feature_type='transformed', 
                                                         save_dir=save_dir)
    
    print("\nTesting with transformed features...")
    metrics_transformed = test_classifier(pretrained_net, test_loader, state_transformed, 
                                       feature_transform=feature_transform, 
                                       feature_type='transformed', 
                                       save_dir=save_dir)

In [None]:
def add_gaussian_noise(images, noise_level):
    """Perturb ``images`` with zero-mean Gaussian noise and clip to [0, 1].

    Standard-normal samples shaped like ``images`` are scaled by
    ``noise_level`` and added; the sum is clamped to the [0, 1] range.
    """
    perturbed = images + torch.randn_like(images) * noise_level
    return perturbed.clamp(min=0, max=1)

def test_noise_robustness(pretrained_net, feature_transform, test_loader, model_states, 
                         noise_levels=[0, 0.001, 0.005, 0.01, 0.05], save_dir='./classifier_results'):
    """Measure raw- and transformed-feature classifier accuracy under Gaussian image noise.

    For every noise level the test images are perturbed with
    ``add_gaussian_noise``, features are taken from ``pretrained_net`` (third
    element of its output) and from ``feature_transform`` applied to them, and
    both classifiers predict on the flattened features.

    Args:
        pretrained_net: network called as ``pretrained_net(img, const_img, registration=True)``.
        feature_transform: module applied to the raw features.
        test_loader: yields 5-tuples; the first, fourth and fifth elements
            (image, constant image, labels) are used.
        model_states: dict with ``'raw'`` and ``'transformed'`` state dicts.
        noise_levels: noise scales to evaluate (the default list is only read).
        save_dir: output directory for the plots; created when missing.

    Returns:
        ``{'raw': {level: metrics}, 'transformed': {level: metrics}}``.
    """
    # Plots are written below; make sure the directory exists first.
    os.makedirs(save_dir, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    pretrained_net.eval()
    feature_transform.eval()

    def output_size(state_dict, fallback):
        """Rows of the last 2-D weight in ``state_dict``, or ``fallback`` if there is none."""
        # NOTE(review): assumes the classifier ends in a linear layer, so its
        # weight is the last 2-D tensor of the state dict - confirm against the
        # classifier definitions.
        matrices = [t for t in state_dict.values() if torch.is_tensor(t) and t.dim() == 2]
        return matrices[-1].shape[0] if matrices else fallback
    
    # Results dictionary to store metrics for each noise level and feature type
    noise_results = {'raw': {}, 'transformed': {}}

    # The classifiers are built lazily (feature widths are only known after a
    # forward pass) but only once, instead of once per noise level.
    raw_classifier = None
    transformed_classifier = None
    
    # Test for each noise level
    for noise_level in noise_levels:
        print(f"\nTesting with noise level: {noise_level}")
        
        all_raw_preds = []
        all_transformed_preds = []
        all_labels = []
        
        # Extract features with noisy images
        with torch.no_grad():
            for batch_img, batch_reg_img, batch_avg_img, batch_const_img, batch_labels in tqdm(test_loader, 
                                                                                              desc=f"Processing (noise={noise_level})"):
                # Add noise to the image
                batch_img = batch_img.reshape(batch_img.shape[0], 1, batch_img.shape[1], batch_img.shape[2])
                noisy_img = add_gaussian_noise(batch_img, noise_level).to(device)
                batch_const_img = batch_const_img.reshape(batch_const_img.shape[0], 1, batch_const_img.shape[1], 
                                                        batch_const_img.shape[2]).to(device)
                
                # Get features
                pred = pretrained_net(noisy_img, batch_const_img, registration=True)
                raw_features = pred[2]
                transformed_features = feature_transform(raw_features)
                
                # Flatten features
                raw_features = raw_features.view(raw_features.size(0), -1)
                transformed_features = transformed_features.view(transformed_features.size(0), -1)
                
                # Initialize classifiers on the very first batch
                if raw_classifier is None:
                    raw_dim = raw_features.shape[1]
                    transformed_dim = transformed_features.shape[1]
                    hidden_dim = 256
                    # A single batch may not contain every class, so the class
                    # count is read from the checkpoints; the batch-derived
                    # count is only a fallback.
                    batch_classes = len(torch.unique(batch_labels))
                    raw_classes = output_size(model_states['raw'], batch_classes)
                    transformed_classes = output_size(model_states['transformed'], batch_classes)
                    
                    # Initialize both classifiers
                    raw_classifier = RawFeatureClassifier(raw_dim, hidden_dim, raw_classes).to(device)
                    transformed_classifier = TransformedFeatureClassifier(transformed_dim, hidden_dim, transformed_classes).to(device)
                    
                    # Load states
                    raw_classifier.load_state_dict(model_states['raw'])
                    transformed_classifier.load_state_dict(model_states['transformed'])
                    
                    raw_classifier.eval()
                    transformed_classifier.eval()
                
                # Get predictions
                raw_outputs = raw_classifier(raw_features.float())
                transformed_outputs = transformed_classifier(transformed_features.float())
                
                _, raw_predicted = torch.max(raw_outputs.data, 1)
                _, transformed_predicted = torch.max(transformed_outputs.data, 1)
                
                all_raw_preds.extend(raw_predicted.cpu().numpy())
                all_transformed_preds.extend(transformed_predicted.cpu().numpy())
                all_labels.extend(batch_labels.numpy())
        
        # Calculate metrics for this noise level
        raw_metrics = calculate_metrics(np.array(all_labels), np.array(all_raw_preds))
        transformed_metrics = calculate_metrics(np.array(all_labels), np.array(all_transformed_preds))
        
        noise_results['raw'][noise_level] = raw_metrics
        noise_results['transformed'][noise_level] = transformed_metrics
        
        # Print results for this noise level
        for feature_type in ['raw', 'transformed']:
            metrics = noise_results[feature_type][noise_level]
            print(f"\nMetrics for {feature_type} features at noise level {noise_level}:")
            print(f"Accuracy: {metrics['accuracy']:.4f}")
            print(f"Precision: {metrics['precision']:.4f}")
            print(f"Recall: {metrics['recall']:.4f}")
            print(f"F1 Score: {metrics['f1']:.4f}")
            print("\nPer-class Accuracy:")
            for i, acc in enumerate(metrics['per_class_acc']):
                print(f"Class {i}: {acc:.4f}")
            
            # Plot and save confusion matrix
            plot_confusion_matrix(
                metrics['conf_matrix'], 
                os.path.join(save_dir, f'confusion_matrix_{feature_type}_noise_{noise_level}.png')
            )
    
    # Plot accuracy vs noise level for both feature types
    plt.figure(figsize=(12, 6))
    noise_levels_list = list(noise_results['raw'].keys())
    
    # Plot raw features accuracy
    raw_accuracies = [noise_results['raw'][level]['accuracy'] for level in noise_levels_list]
    plt.plot(noise_levels_list, raw_accuracies, 'bo-', label='Raw Features')
    
    # Plot transformed features accuracy
    transformed_accuracies = [noise_results['transformed'][level]['accuracy'] for level in noise_levels_list]
    plt.plot(noise_levels_list, transformed_accuracies, 'ro-', label='Transformed Features')
    
    plt.xlabel('Noise Level (σ)')
    plt.ylabel('Accuracy')
    plt.title('Model Accuracy vs. Gaussian Noise Level')
    plt.grid(True)
    plt.legend()
    plt.savefig(os.path.join(save_dir, 'noise_robustness_comparison.png'))
    plt.close()
    
    return noise_results

# Usage after training both classifiers:
print("\nTesting noise robustness...")

# One trained state dict per feature type, keyed the way
# test_noise_robustness looks them up.
model_states = dict(raw=state_raw, transformed=state_transformed)

noise_results = test_noise_robustness(
    pretrained_net=pretrained_net, feature_transform=feature_transform,
    test_loader=test_loader, model_states=model_states,
    noise_levels=[0, 0.001, 0.005, 0.01, 0.05], save_dir=save_dir,
)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import os
from tqdm import tqdm
import torchvision.models as models

# Feature Extractors
class ImageFeatureExtractor(nn.Module):
    """ResNet-18 backbone adapted to single-channel input, emitting flat feature vectors.

    The final child of the torchvision ResNet-18 is dropped and the output of
    the remaining stack is flattened per sample.

    Args:
        finetune: when False, every backbone parameter (including the adapted
            first convolution) has ``requires_grad`` set to False.
    """

    def __init__(self, finetune=False):
        super().__init__()
        backbone = models.resnet18(pretrained=True)

        # Swap the RGB stem for a 1-channel one, but keep the pretrained
        # filters: summing them over the colour axis gives the same response as
        # feeding the grey image replicated on three channels. A randomly
        # initialised stem would throw the pretrained first layer away, and
        # with finetune=False it would also stay frozen at random values.
        rgb_conv = backbone.conv1
        gray_conv = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        with torch.no_grad():
            gray_conv.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))
        backbone.conv1 = gray_conv

        # Keep everything except the last child of the network.
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])
        
        if not finetune:
            for param in self.resnet.parameters():
                param.requires_grad = False
    
    def forward(self, x):
        """Return the backbone output flattened to ``(batch, features)``."""
        x = self.resnet(x)
        return x.view(x.size(0), -1)

# Combined Classifiers
class CombinedRawClassifier(nn.Module):
    """MLP head over concatenated image and raw shape features.

    Two hidden blocks (Linear -> BatchNorm1d -> ReLU -> Dropout(0.3)) of widths
    ``hidden_dim`` and ``hidden_dim // 2`` feed a final linear layer with
    ``num_classes`` outputs.
    """

    def __init__(self, img_dim, shape_dim, hidden_dim, num_classes):
        super().__init__()
        self.combined_dim = img_dim + shape_dim
        narrow_dim = hidden_dim // 2
        drop_prob = 0.3
        stack = [
            # first hidden block
            nn.Linear(self.combined_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(drop_prob),
            # second hidden block
            nn.Linear(hidden_dim, narrow_dim),
            nn.BatchNorm1d(narrow_dim),
            nn.ReLU(),
            nn.Dropout(drop_prob),
            # output logits
            nn.Linear(narrow_dim, num_classes),
        ]
        self.classifier = nn.Sequential(*stack)

    def forward(self, img_features, shape_features):
        """Concatenate both feature sets along dim 1 and return class logits."""
        return self.classifier(torch.cat((img_features, shape_features), dim=1))

class CombinedTransformedClassifier(nn.Module):
    """MLP head over concatenated image and transformed shape features.

    Same layout as the raw-feature head, but with a dropout rate of 0.2
    (lower dropout for transformed features).
    """

    def __init__(self, img_dim, shape_dim, hidden_dim, num_classes):
        super().__init__()
        self.combined_dim = img_dim + shape_dim

        layers = []
        width_in = self.combined_dim
        # Hidden blocks: Linear -> BatchNorm1d -> ReLU -> Dropout(0.2)
        for width_out in (hidden_dim, hidden_dim // 2):
            layers.extend([
                nn.Linear(width_in, width_out),
                nn.BatchNorm1d(width_out),
                nn.ReLU(),
                nn.Dropout(0.2),
            ])
            width_in = width_out
        # Output logits
        layers.append(nn.Linear(width_in, num_classes))
        self.classifier = nn.Sequential(*layers)

    def forward(self, img_features, shape_features):
        """Join both feature sets along dim 1 and return class logits."""
        joined = torch.cat((img_features, shape_features), dim=1)
        return self.classifier(joined)

def extract_features(loader, img_net, shape_net, feature_transform=None, feature_type='raw', device='cuda'):
    """Collect image features, flattened shape features and labels for a whole loader.

    Every network is switched to eval mode and run without gradients. Shape
    features are the third element of ``shape_net``'s output; they go through
    ``feature_transform`` only when ``feature_type`` is ``'transformed'`` and a
    transform was supplied.

    Returns:
        ``(img_features, shape_features, labels)``, each concatenated along
        dim 0; both feature tensors are moved to the CPU.
    """
    for net in (img_net, shape_net):
        net.eval()
    if feature_transform is not None:
        feature_transform.eval()

    apply_transform = feature_type == 'transformed' and feature_transform is not None

    def with_channel_axis(batch):
        # Insert a singleton channel axis after the batch axis, then move to device.
        return batch.reshape(batch.shape[0], 1, batch.shape[1], batch.shape[2]).to(device)

    img_chunks, shape_chunks, label_chunks = [], [], []

    with torch.no_grad():
        for images, _reg, _avg, const_images, targets in tqdm(loader, desc="Extracting features"):
            images = with_channel_axis(images)
            const_images = with_channel_axis(const_images)

            # Image branch
            img_vecs = img_net(images)

            # Shape branch
            shape_vecs = shape_net(images, const_images, registration=True)[2]
            if apply_transform:
                shape_vecs = feature_transform(shape_vecs)
            shape_vecs = shape_vecs.view(shape_vecs.size(0), -1)

            img_chunks.append(img_vecs.cpu())
            shape_chunks.append(shape_vecs.cpu())
            label_chunks.append(targets)

    all_img = torch.cat(img_chunks, dim=0)
    all_shape = torch.cat(shape_chunks, dim=0)
    all_targets = torch.cat(label_chunks, dim=0)
    return all_img, all_shape, all_targets

def train_combined_classifier(train_loader, val_loader, img_net, shape_net, feature_transform=None, 
                            feature_type='raw', finetune_resnet=False, save_dir='./classifier_results'):
    """Train a classifier on concatenated image and shape features.

    Image features come from ``img_net``; shape features are the third element
    of ``shape_net(img, const_img, registration=True)``, optionally passed
    through ``feature_transform`` when ``feature_type == 'transformed'``.
    Training uses Adam with cross-entropy loss and stops early once the
    validation loss has not improved for 10 epochs.

    Args:
        train_loader, val_loader: yield 5-tuples; the first, fourth and fifth
            elements (image, constant image, labels) are used.
        img_net: image feature network; optimised only when ``finetune_resnet``.
        shape_net: shape network; never optimised here.
        feature_transform: optional module applied to the shape features.
        feature_type: ``'raw'`` or ``'transformed'``; selects the classifier
            class and hidden width.
        finetune_resnet: also optimise ``img_net`` and checkpoint its weights.
        save_dir: output directory for the checkpoint and the history plot.

    Returns:
        ``(model, img_net, best_model_states)`` where ``best_model_states`` is
        ``{'classifier': state_dict, 'img_net': state_dict or None}`` captured
        at the lowest validation loss. ``model`` and ``img_net`` hold the
        weights of the last epoch that ran.
    """
    import copy  # local import: used to snapshot the best checkpoint

    os.makedirs(save_dir, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    use_transform = feature_type == 'transformed' and feature_transform is not None

    # shape_net / feature_transform are fixed feature extractors here (their
    # parameters are not handed to the optimizer), so run them in eval mode,
    # matching how the test functions use them.
    shape_net.eval()
    if feature_transform is not None:
        feature_transform.eval()

    def add_channel_axis(batch):
        """Insert a singleton channel axis after the batch axis and move to ``device``."""
        return batch.reshape(batch.shape[0], 1, batch.shape[1], batch.shape[2]).to(device)

    def frozen_shape_features(batch_img, batch_const_img):
        """Flattened shape features, computed without building an autograd graph."""
        with torch.no_grad():
            feats = shape_net(batch_img, batch_const_img, registration=True)[2]
            if use_transform:
                feats = feature_transform(feats)
        return feats.view(feats.size(0), -1)

    def snapshot(module):
        """Independent copy of a state dict; state_dict() itself returns live references."""
        return copy.deepcopy(module.state_dict())
    
    # Probe one batch to get the feature dimensions
    sample_batch = next(iter(train_loader))
    img_net.eval()  # keep the probe from touching normalisation statistics
    with torch.no_grad():
        probe_img = add_channel_axis(sample_batch[0])
        probe_const = add_channel_axis(sample_batch[3])
        img_dim = img_net(probe_img).shape[1]
        shape_dim = frozen_shape_features(probe_img, probe_const).shape[1]

    # Size the output layer from every training label: a single batch can miss
    # a class, which would leave the classifier with too few outputs.
    train_labels = torch.cat([batch[-1].reshape(-1) for batch in train_loader])
    num_classes = int(train_labels.max().item()) + 1
    hidden_dim = 1024 if feature_type == 'raw' else 512  # Larger for raw features
    
    # Initialize appropriate classifier
    if feature_type == 'raw':
        model = CombinedRawClassifier(img_dim, shape_dim, hidden_dim, num_classes).to(device)
    else:
        model = CombinedTransformedClassifier(img_dim, shape_dim, hidden_dim, num_classes).to(device)
    
    # Create parameter groups
    params = []
    if finetune_resnet:
        params.extend(list(img_net.parameters()))
    params.extend(list(model.parameters()))
    
    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(params, lr=0.00005, weight_decay=0)
    
    # Training parameters
    num_epochs = 300
    early_stopping_patience = 10
    best_val_loss = float('inf')
    patience_counter = 0
    best_model_states = None
    
    # Training history
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    
    # Training loop
    print(f"Starting training with {feature_type} shape features...")
    pbar = tqdm(range(num_epochs), desc="Training")
    for epoch in pbar:
        # img_net only trains when it is being finetuned; otherwise keep it in
        # eval mode for every epoch, including the first one.
        if finetune_resnet:
            img_net.train()
        else:
            img_net.eval()
        model.train()
        
        train_loss = 0
        train_correct = 0
        total_train = 0
        
        # Training
        for batch_img, _, _, batch_const_img, batch_labels in train_loader:
            batch_img = add_channel_axis(batch_img)
            batch_const_img = add_channel_axis(batch_const_img)
            batch_labels = batch_labels.to(device)
            
            # Get features
            img_features = img_net(batch_img)
            shape_features = frozen_shape_features(batch_img, batch_const_img)
            
            # Forward pass
            optimizer.zero_grad()
            outputs = model(img_features, shape_features)
            loss = criterion(outputs, batch_labels)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total_train += batch_labels.size(0)
            train_correct += (predicted == batch_labels).sum().item()
        
        # Validation
        model.eval()
        img_net.eval()
        val_loss = 0
        val_correct = 0
        total_val = 0
        
        with torch.no_grad():
            for batch_img, _, _, batch_const_img, batch_labels in val_loader:
                batch_img = add_channel_axis(batch_img)
                batch_const_img = add_channel_axis(batch_const_img)
                batch_labels = batch_labels.to(device)
                
                # Get features
                img_features = img_net(batch_img)
                shape_features = frozen_shape_features(batch_img, batch_const_img)
                
                outputs = model(img_features, shape_features)
                loss = criterion(outputs, batch_labels)
                
                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total_val += batch_labels.size(0)
                val_correct += (predicted == batch_labels).sum().item()
        
        # Calculate metrics
        train_loss = train_loss / len(train_loader)
        val_loss = val_loss / len(val_loader)
        train_acc = train_correct / total_train
        val_acc = val_correct / total_val
        
        # Update history
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)
        
        # Update progress bar
        pbar.set_postfix({
            'train_loss': f'{train_loss:.4f}',
            'val_loss': f'{val_loss:.4f}',
            'train_acc': f'{train_acc:.4f}',
            'val_acc': f'{val_acc:.4f}'
        })
        
        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Store copies: the tensors inside state_dict() are the live
            # parameters, so without a copy later optimizer steps would
            # silently turn the "best" checkpoint into the last one.
            best_model_states = {
                'classifier': snapshot(model),
                'img_net': snapshot(img_net) if finetune_resnet else None
            }
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                print(f"\nEarly stopping triggered after {epoch+1} epochs")
                break
    
    # Save best model
    torch.save(best_model_states, os.path.join(save_dir, f'best_combined_classifier_{feature_type}.pth'))
    
    # Plot training history
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Val Loss')
    plt.title(f'Loss History ({feature_type})')
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'], label='Train Acc')
    plt.plot(history['val_acc'], label='Val Acc')
    plt.title(f'Accuracy History ({feature_type})')
    plt.legend()
    plt.savefig(os.path.join(save_dir, f'training_history_{feature_type}.png'))
    plt.close()
    
    return model, img_net, best_model_states

def test_combined_classifier(test_loader, img_net, shape_net, feature_transform=None, 
                           feature_type='raw', model_states=None, save_dir='./classifier_results'):
    """Evaluate the combined image + shape classifier on the test set.

    Args:
        test_loader: yields 5-tuples; the first, fourth and fifth elements
            (image, constant image, labels) are used.
        img_net: image feature network. When ``model_states['img_net']`` is not
            None, those weights are loaded into it in place.
        shape_net: network called as ``shape_net(img, const_img, registration=True)``;
            the third element of its output is the shape feature.
        feature_transform: optional module applied to the shape features when
            ``feature_type == 'transformed'``.
        feature_type: ``'raw'`` or ``'transformed'``; selects the classifier
            class and hidden width.
        model_states: dict with ``'classifier'`` and ``'img_net'`` entries
            (required despite the ``None`` default).
        save_dir: directory receiving the confusion-matrix image; created when
            missing.

    Returns:
        The metrics dict produced by ``calculate_metrics``.

    Raises:
        ValueError: if ``model_states`` is not provided.
    """
    if model_states is None:
        raise ValueError("model_states is required: pass the state dict returned by train_combined_classifier")

    # The confusion matrix is written at the end; make sure the target exists.
    os.makedirs(save_dir, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Initialize networks
    if model_states['img_net'] is not None:
        img_net.load_state_dict(model_states['img_net'])
    img_net.eval()
    shape_net.eval()
    if feature_transform is not None:
        feature_transform.eval()
    
    # Get dimensions from a single batch
    sample_batch = next(iter(test_loader))
    batch_img = sample_batch[0].reshape(-1, 1, sample_batch[0].shape[1], sample_batch[0].shape[2]).to(device)
    batch_const_img = sample_batch[3].reshape(-1, 1, sample_batch[3].shape[1], sample_batch[3].shape[2]).to(device)
    
    with torch.no_grad():
        img_features = img_net(batch_img)
        shape_pred = shape_net(batch_img, batch_const_img, registration=True)
        shape_features = shape_pred[2]
        
        if feature_type == 'transformed' and feature_transform is not None:
            shape_features = feature_transform(shape_features)
        
        shape_features = shape_features.view(shape_features.size(0), -1)
    
    # Get dimensions
    img_dim = img_features.shape[1]
    shape_dim = shape_features.shape[1]
    # Read the class count from the checkpoint: both combined classifiers end
    # in a linear layer, whose weight is the last 2-D tensor of the state dict.
    # One test batch may not contain every class, so counting its labels is
    # only a fallback.
    weight_matrices = [t for t in model_states['classifier'].values() if torch.is_tensor(t) and t.dim() == 2]
    if weight_matrices:
        num_classes = weight_matrices[-1].shape[0]
    else:
        num_classes = len(torch.unique(sample_batch[-1]))
    hidden_dim = 1024 if feature_type == 'raw' else 512
    
    # Initialize appropriate classifier
    if feature_type == 'raw':
        model = CombinedRawClassifier(img_dim, shape_dim, hidden_dim, num_classes).to(device)
    else:
        model = CombinedTransformedClassifier(img_dim, shape_dim, hidden_dim, num_classes).to(device)
    
    model.load_state_dict(model_states['classifier'])
    model.eval()
    
    # Testing
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for batch_img, _, _, batch_const_img, batch_labels in tqdm(test_loader, desc=f"Testing ({feature_type})"):
            batch_img = batch_img.reshape(batch_img.shape[0], 1, batch_img.shape[1], batch_img.shape[2]).to(device)
            batch_const_img = batch_const_img.reshape(batch_const_img.shape[0], 1, batch_const_img.shape[1], batch_const_img.shape[2]).to(device)
            
            # Extract features
            img_features = img_net(batch_img)
            shape_pred = shape_net(batch_img, batch_const_img, registration=True)
            shape_features = shape_pred[2]
            
            if feature_type == 'transformed' and feature_transform is not None:
                shape_features = feature_transform(shape_features)
            
            shape_features = shape_features.view(shape_features.size(0), -1)
            
            # Get predictions
            outputs = model(img_features, shape_features)
            _, predicted = torch.max(outputs.data, 1)
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(batch_labels.numpy())
    
    # Calculate metrics
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)
    metrics = calculate_metrics(all_labels, all_preds)
    
    print(f"\nTest Set Metrics ({feature_type}):")
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall: {metrics['recall']:.4f}")
    print(f"F1 Score: {metrics['f1']:.4f}")
    print("\nPer-class Accuracy:")
    for i, acc in enumerate(metrics['per_class_acc']):
        print(f"Class {i}: {acc:.4f}")
    
    # Plot confusion matrix
    plot_confusion_matrix(metrics['conf_matrix'], 
                        os.path.join(save_dir, f'confusion_matrix_{feature_type}.png'))
    
    return metrics

def test_noise_robustness(test_loader, img_net, shape_net, feature_transform=None, 
                         feature_type='raw', model_states=None, 
                         noise_levels=[0, 0.001, 0.005, 0.01, 0.05], 
                         save_dir='./classifier_results'):
    """Measure combined-classifier accuracy under increasing Gaussian image noise.

    For every entry of ``noise_levels`` the test images are perturbed with
    ``add_gaussian_noise`` before both feature networks see them; the constant
    image is left untouched. Metrics and a confusion matrix are produced per
    level, plus one accuracy-vs-noise plot.

    Args:
        test_loader: yields 5-tuples; the first, fourth and fifth elements
            (image, constant image, labels) are used.
        img_net: image feature network. When ``model_states['img_net']`` is not
            None, those weights are loaded into it in place.
        shape_net: network called as ``shape_net(img, const_img, registration=True)``.
        feature_transform: optional module applied to the shape features when
            ``feature_type == 'transformed'``.
        feature_type: ``'raw'`` or ``'transformed'``.
        model_states: dict with ``'classifier'`` and ``'img_net'`` entries
            (required despite the ``None`` default).
        noise_levels: noise scales to evaluate (the default list is only read).
        save_dir: output directory for the plots; created when missing.

    Returns:
        Dict mapping each noise level to its ``calculate_metrics`` result.

    Raises:
        ValueError: if ``model_states`` is not provided.
    """
    if model_states is None:
        raise ValueError("model_states is required: pass the state dict returned by train_combined_classifier")

    # Plots are written below; make sure the directory exists first.
    os.makedirs(save_dir, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Initialize networks
    if model_states['img_net'] is not None:
        img_net.load_state_dict(model_states['img_net'])
    img_net.eval()
    shape_net.eval()
    if feature_transform is not None:
        feature_transform.eval()
    
    # Get dimensions from a single batch
    sample_batch = next(iter(test_loader))
    batch_img = sample_batch[0].reshape(-1, 1, sample_batch[0].shape[1], sample_batch[0].shape[2]).to(device)
    batch_const_img = sample_batch[3].reshape(-1, 1, sample_batch[3].shape[1], sample_batch[3].shape[2]).to(device)
    
    with torch.no_grad():
        img_features = img_net(batch_img)
        shape_pred = shape_net(batch_img, batch_const_img, registration=True)
        shape_features = shape_pred[2]
        
        if feature_type == 'transformed' and feature_transform is not None:
            shape_features = feature_transform(shape_features)
        
        shape_features = shape_features.view(shape_features.size(0), -1)
    
    # Get dimensions
    img_dim = img_features.shape[1]
    shape_dim = shape_features.shape[1]
    # Read the class count from the checkpoint: both combined classifiers end
    # in a linear layer, whose weight is the last 2-D tensor of the state dict.
    # One test batch may not contain every class, so counting its labels is
    # only a fallback.
    weight_matrices = [t for t in model_states['classifier'].values() if torch.is_tensor(t) and t.dim() == 2]
    if weight_matrices:
        num_classes = weight_matrices[-1].shape[0]
    else:
        num_classes = len(torch.unique(sample_batch[-1]))
    hidden_dim = 1024 if feature_type == 'raw' else 512
    
    # Initialize appropriate classifier
    if feature_type == 'raw':
        model = CombinedRawClassifier(img_dim, shape_dim, hidden_dim, num_classes).to(device)
    else:
        model = CombinedTransformedClassifier(img_dim, shape_dim, hidden_dim, num_classes).to(device)
    
    model.load_state_dict(model_states['classifier'])
    model.eval()
    
    # Results dictionary to store metrics for each noise level
    noise_results = {}
    
    # Test for each noise level
    for noise_level in noise_levels:
        print(f"\nTesting with noise level: {noise_level}")
        
        all_preds = []
        all_labels = []
        
        # Test with noisy images
        with torch.no_grad():
            for batch_img, _, _, batch_const_img, batch_labels in tqdm(test_loader, desc=f"Testing noise={noise_level}"):
                # Add noise to the images
                batch_img = batch_img.reshape(batch_img.shape[0], 1, batch_img.shape[1], batch_img.shape[2])
                batch_img = add_gaussian_noise(batch_img, noise_level)
                batch_img = batch_img.to(device)
                
                batch_const_img = batch_const_img.reshape(batch_const_img.shape[0], 1, batch_const_img.shape[1], batch_const_img.shape[2]).to(device)
                
                # Extract features
                img_features = img_net(batch_img)
                shape_pred = shape_net(batch_img, batch_const_img, registration=True)
                shape_features = shape_pred[2]
                
                if feature_type == 'transformed' and feature_transform is not None:
                    shape_features = feature_transform(shape_features)
                
                shape_features = shape_features.view(shape_features.size(0), -1)
                
                # Get predictions
                outputs = model(img_features, shape_features)
                _, predicted = torch.max(outputs.data, 1)
                
                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(batch_labels.numpy())
        
        # Calculate metrics for this noise level
        metrics = calculate_metrics(np.array(all_labels), np.array(all_preds))
        noise_results[noise_level] = metrics
        
        # Print results for this noise level
        print(f"\nMetrics for noise level {noise_level}:")
        print(f"Accuracy: {metrics['accuracy']:.4f}")
        print(f"Precision: {metrics['precision']:.4f}")
        print(f"Recall: {metrics['recall']:.4f}")
        print(f"F1 Score: {metrics['f1']:.4f}")
        print("\nPer-class Accuracy:")
        for i, acc in enumerate(metrics['per_class_acc']):
            print(f"Class {i}: {acc:.4f}")
        
        # Plot confusion matrix
        plot_confusion_matrix(
            metrics['conf_matrix'], 
            os.path.join(save_dir, f'confusion_matrix_{feature_type}_noise_{noise_level}.png')
        )
    
    # Plot accuracy vs noise level
    plt.figure(figsize=(10, 6))
    noise_levels_list = list(noise_results.keys())
    accuracies = [noise_results[level]['accuracy'] for level in noise_levels_list]
    
    plt.plot(noise_levels_list, accuracies, 'bo-')
    plt.xlabel('Noise Level (σ)')
    plt.ylabel('Accuracy')
    plt.title(f'Model Accuracy vs. Gaussian Noise Level ({feature_type})')
    plt.grid(True)
    plt.savefig(os.path.join(save_dir, f'noise_robustness_{feature_type}.png'))
    plt.close()
    
    return noise_results

# Utility functions
def add_gaussian_noise(images, noise_level):
    """Return ``images`` plus scaled standard-normal noise, clipped to [0, 1].

    The noise tensor matches ``images`` in shape and is multiplied by
    ``noise_level`` before being added.
    """
    scaled_noise = torch.randn_like(images) * noise_level
    return torch.clamp(images + scaled_noise, min=0, max=1)

def calculate_metrics(y_true, y_pred):
    """Compute overall and per-class classification metrics.

    Args:
        y_true: Ground-truth class labels (1-D array-like).
        y_pred: Predicted class labels, aligned element-wise with ``y_true``.

    Returns:
        dict with the keys:
            'accuracy': fraction of positions where ``y_true == y_pred``.
            'precision', 'recall', 'f1': values returned by
                ``precision_recall_fscore_support(..., average='weighted')``.
            'per_class_acc': confusion-matrix diagonal divided by the matrix
                row sums; NaN for a class whose row sum is zero.
            'conf_matrix': result of ``confusion_matrix(y_true, y_pred)``.
    """
    # Coerce to arrays so plain Python lists also support the element-wise
    # comparison and ``.mean()`` below.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    accuracy = (y_true == y_pred).mean()
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted')
    conf_matrix = confusion_matrix(y_true, y_pred)

    # A label that shows up only in the predictions yields an all-zero row.
    # Divide only where the row sum is positive so that case produces NaN
    # without triggering a 0/0 RuntimeWarning.
    row_totals = conf_matrix.sum(axis=1)
    per_class_acc = np.divide(
        conf_matrix.diagonal(),
        row_totals,
        out=np.full(row_totals.shape, np.nan, dtype=float),
        where=row_totals > 0,
    )

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'per_class_acc': per_class_acc,
        'conf_matrix': conf_matrix
    }

def plot_confusion_matrix(conf_matrix, save_path=None):
    """Render a confusion matrix as an annotated heatmap and optionally save it.

    Args:
        conf_matrix: 2-D matrix of counts; cells are annotated with integer
            formatting (``fmt='d'``).
        save_path: Destination file for the figure. When falsy, the figure is
            drawn and closed without being written to disk.

    The figure is always closed before returning, including when drawing or
    saving raises, so repeated calls in a loop do not accumulate open figures.
    """
    fig = plt.figure(figsize=(10, 8))
    try:
        ax = fig.add_subplot(111)
        # Draw on the explicit axes instead of relying on pyplot's notion of
        # the "current" figure, which other code may have changed.
        sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', ax=ax)
        ax.set_title('Confusion Matrix')
        ax.set_ylabel('True Label')
        ax.set_xlabel('Predicted Label')
        if save_path:
            fig.savefig(save_path)
    finally:
        # Close exactly the figure created above, even on error.
        plt.close(fig)

# Example usage
if __name__ == "__main__":
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load models.
    # NOTE(review): torch.load unpickles arbitrary Python objects; only load
    # checkpoints that come from a trusted source.
    # map_location remaps stored tensors onto `device`, so checkpoints saved
    # on a GPU can still be loaded when the code falls back to the CPU.
    pretrained_net = torch.load('...', map_location=device).to(device)
    feature_transform = torch.load('...', map_location=device).to(device)
    img_net = ImageFeatureExtractor(finetune=True).to(device)

    save_dir = './saved_models'
    os.makedirs(save_dir, exist_ok=True)

    def _run_experiment(feature_type, transform, image_net):
        """Train, test and noise-test the classifier for one feature setup.

        Args:
            feature_type: Tag passed through to the train/test helpers and
                used in the progress messages ('raw' or 'transformed').
            transform: Value forwarded as ``feature_transform`` (``None`` for
                raw shape features).
            image_net: Image feature extractor handed to training.

        Returns:
            Tuple ``(model, trained_img_net, states, metrics, noise_results)``.
        """
        print(f"\nTraining with {feature_type} shape features...")
        model, trained_img_net, states = train_combined_classifier(
            train_loader,
            val_loader,
            img_net=image_net,
            shape_net=pretrained_net,
            feature_transform=transform,
            feature_type=feature_type,
            finetune_resnet=True,
            save_dir=save_dir
        )

        print(f"\nTesting with {feature_type} shape features...")
        metrics = test_combined_classifier(
            test_loader,
            img_net=trained_img_net,
            shape_net=pretrained_net,
            feature_transform=transform,
            feature_type=feature_type,
            model_states=states,
            save_dir=save_dir
        )

        print(f"\nTesting noise robustness with {feature_type} shape features...")
        noise_results = test_noise_robustness(
            test_loader,
            img_net=trained_img_net,
            shape_net=pretrained_net,
            feature_transform=transform,
            feature_type=feature_type,
            model_states=states,
            save_dir=save_dir
        )
        return model, trained_img_net, states, metrics, noise_results

    # Train and test with raw shape features
    (model_raw, img_net_raw, states_raw,
     metrics_raw, noise_results_raw) = _run_experiment('raw', None, img_net)

    # Train and test with transformed shape features.
    # NOTE(review): both runs request finetune_resnet=True. If training
    # updates the extractor in place (not visible from here), reusing
    # `img_net` would start this run from the weights left by the raw run.
    # A fresh extractor keeps the two experiments independent.
    (model_transformed, img_net_transformed, states_transformed,
     metrics_transformed, noise_results_transformed) = _run_experiment(
        'transformed',
        feature_transform,
        ImageFeatureExtractor(finetune=True).to(device),
    )