In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, ConcatDataset
import numpy as np
from skimage.filters import gabor_kernel
import copy
from PIL import Image
import matplotlib.pyplot as plt

RuntimeError: operator torchvision::nms does not exist

In [None]:
# --- FINAL, IMPROVED MODEL DEFINITION WITH CONCATENATION + SEPARATE HEADS ---
class SkipConvNet(nn.Module):
    """Small two-headed CNN that concatenates a deep path with a skip path.

    A shared 3x3 convolution feeds two parallel branches: ``middle_layers``
    (two 3x3 convolutions) and ``match_channels`` (a single 1x1 convolution).
    Their 32-channel outputs are concatenated along the channel axis,
    max-pooled 2x2, flattened, and passed to two independent fully-connected
    heads: a 10-way CIFAR-10 head and a 12-way Gabor-orientation head.

    The heads are sized for ``64 * 16 * 16`` features, which corresponds to
    3-channel 32x32 inputs.
    """

    def __init__(self):
        super(SkipConvNet, self).__init__()
        # Shared early layer
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)

        # Middle path for high-level features (the path that
        # add_noise_to_middle_layers perturbs)
        self.middle_layers = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1)
        )

        # Skip path for low-level features: a 1x1 conv that only lifts the
        # 16 shared channels to 32 so both branches have the same width
        self.match_channels = nn.Conv2d(16, 32, kernel_size=1)

        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()

        # --- Task-Specific Heads ---
        # Both heads take the full concatenated feature vector as input:
        # (32 channels from middle + 32 from skip) * 16 * 16
        fc_input_size = 64 * 16 * 16

        # Head for CIFAR-10 (high-level task)
        self.cifar_fc1 = nn.Linear(fc_input_size, 128)
        self.cifar_fc2 = nn.Linear(128, 10)

        # Head for Gabor Orientation (low-level task)
        self.gabor_fc1 = nn.Linear(fc_input_size, 128)
        self.gabor_fc2 = nn.Linear(128, 12)  # 12 orientation classes

    def forward(self, x):
        """Run both heads on a batch of images.

        Args:
            x: image batch; the heads are sized for 3-channel 32x32 inputs.

        Returns:
            Tuple ``(cifar_out, gabor_out)`` of logits with 10 and 12 columns
            respectively, one row per input sample.

        Raises:
            RuntimeError: if the flattened feature size does not match the
                size the fully-connected heads were built for.
        """
        out_conv1 = self.relu(self.conv1(x))

        # Process middle and skip paths (ReLU is applied to the middle path only)
        out_middle = self.relu(self.middle_layers(out_conv1))
        skip_out = self.match_channels(out_conv1)

        # Concatenate along the channel axis, then pool the combined map
        out_combined = torch.cat((out_middle, skip_out), 1)
        pooled_combined = self.pool(out_combined)

        # Flatten everything except the batch dimension. Unlike
        # view(-1, 64 * 16 * 16), this never folds surplus spatial features
        # into the batch dimension: a wrongly sized input fails in the Linear
        # layer instead of silently yielding the wrong number of rows.
        flat_combined = torch.flatten(pooled_combined, 1)

        # Separate, dedicated heads using the same combined input
        cifar_out = self.cifar_fc2(self.relu(self.cifar_fc1(flat_combined)))
        gabor_out = self.gabor_fc2(self.relu(self.gabor_fc1(flat_combined)))

        return cifar_out, gabor_out

# --- Noise injection function ---
def add_noise_to_middle_layers(model, noise_level=0.5):
    """Perturb the conv weights of ``model.middle_layers`` with Gaussian noise.

    Every ``nn.Conv2d`` found directly inside ``model.middle_layers`` has
    standard-normal noise scaled by ``noise_level`` added to its weight
    tensor, in place. Biases and all other layers are left untouched.

    Args:
        model: object exposing an iterable ``middle_layers`` attribute.
        noise_level: multiplier applied to the standard-normal noise.

    Returns:
        The same ``model`` object, now carrying the perturbed weights.
    """
    print(f"Injecting noise with level: {noise_level:.2f}")
    conv_layers = [m for m in model.middle_layers if isinstance(m, nn.Conv2d)]
    with torch.no_grad():
        for conv in conv_layers:
            weights = conv.weight.data
            perturbation = torch.randn_like(weights) * noise_level
            weights.add_(perturbation)
    return model

# --- 1. Gabor Filter Dataset Generation ---
class GaborDataset(Dataset):
    """Synthetic dataset of Gabor patches labelled by orientation index.

    Samples are generated on the fly: each ``__getitem__`` call draws a random
    orientation and a random frequency from NumPy's global RNG, so the index
    only selects *how many* samples exist, not *which* sample is returned.
    Seed ``np.random`` beforehand if reproducible samples are needed.

    Args:
        num_samples: reported dataset length.
        img_size: side length the rendered patch is resized to.
        orientations: number of evenly spaced orientations in ``[0, pi)``;
            this is also the number of classes.
        frequencies: sequence of spatial frequencies to sample from.
    """

    def __init__(self, num_samples=10000, img_size=32, orientations=12, frequencies=(0.04, 0.08)):
        self.num_samples = num_samples
        self.img_size = img_size
        self.orientations = np.linspace(0, np.pi, orientations, endpoint=False)
        self.frequencies = frequencies
        # Resize to the target size, convert to a tensor and normalise each
        # channel with mean 0.5 / std 0.5 (same normalisation as the CIFAR
        # transform used in the main script).
        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Return one randomly generated ``(image, orientation_label)`` pair.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        # Without this bound check the dataset never signals its end, so
        # plain iteration (``for x in ds`` / ``list(ds)``), which relies on
        # IndexError to stop, would loop forever.
        if not -self.num_samples <= idx < self.num_samples:
            raise IndexError(
                f"index {idx} is out of range for GaborDataset of length {self.num_samples}"
            )

        # Draw order (orientation first, then frequency) is kept fixed so a
        # seeded np.random produces the same sequence of samples.
        orientation_idx = np.random.randint(0, len(self.orientations))
        theta = self.orientations[orientation_idx]
        freq = np.random.choice(self.frequencies)

        kernel = gabor_kernel(frequency=freq, theta=theta, sigma_x=15, sigma_y=15)
        # Min-max scale the real part of the kernel to 0..255 and replicate
        # it across three channels to obtain an RGB image.
        real_part = kernel.real
        gabor_img_real = (real_part - np.min(real_part)) / (np.max(real_part) - np.min(real_part)) * 255
        gabor_img_real = gabor_img_real.astype(np.uint8)
        gabor_img_rgb = np.stack([gabor_img_real] * 3, axis=-1)
        gabor_img_pil = Image.fromarray(gabor_img_rgb)

        label = orientation_idx
        if self.transform:
            return self.transform(gabor_img_pil), label
        # No transform configured: hand back the raw PIL image rather than
        # failing on an unassigned local variable.
        return gabor_img_pil, label

# --- 2. Model and Finetuning Setup ---
def finetune_model(model, combined_loader, num_epochs=5):
    """Finetunes the model on the combined CIFAR-10 and Gabor dataset.

    Each batch may mix both data sources. Samples are routed by label:
    labels ``< 10`` are CIFAR-10 classes and train the first model output,
    labels ``>= 10`` are Gabor classes (shifted back by 10) and train the
    second model output. The per-head cross-entropy losses are summed.

    Args:
        model: ``nn.Module`` whose forward pass returns
            ``(cifar_logits, gabor_logits)``.
        combined_loader: iterable of ``(inputs, labels)`` batches.
        num_epochs: number of passes over ``combined_loader``.

    Returns:
        The same ``model`` object after training (left in training mode).
    """
    # Use the device the model already lives on instead of relying on a
    # module-level ``device`` variable that only exists when the notebook's
    # main block has been run.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0005)
    model.train()
    print("\nStarting finetuning with separate heads...")
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(combined_loader, 0):
            inputs, labels = inputs.to(target_device), labels.to(target_device)
            optimizer.zero_grad()

            cifar_outputs, gabor_outputs = model(inputs)

            # Identify data type by label range to calculate correct loss
            is_cifar = labels < 10
            is_gabor = labels >= 10

            # Collect a loss term only for heads that have samples in this
            # batch; indexing a head with an all-False mask would give an
            # empty batch and a NaN mean loss.
            head_losses = []
            if is_cifar.any():
                head_losses.append(criterion(cifar_outputs[is_cifar], labels[is_cifar]))
            if is_gabor.any():
                gabor_labels = labels[is_gabor] - 10  # Shift labels 10-21 -> 0-11
                head_losses.append(criterion(gabor_outputs[is_gabor], gabor_labels))

            # Step whenever at least one head contributed, rather than
            # comparing a tensor against the integer 0.
            if head_losses:
                loss = head_losses[0]
                for extra in head_losses[1:]:
                    loss = loss + extra
                loss.backward()
                optimizer.step()
                running_loss += loss.item()

            if i % 100 == 99:
                print(f'[Finetune Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss / 100:.3f}')
                running_loss = 0.0
    print('Finished Finetuning')
    return model

# --- 3. Evaluation Functions for Multi-Task Model ---
def evaluate_cifar_finetuned(model, testloader):
    """Evaluates the CIFAR-10 head.

    Args:
        model: ``nn.Module`` whose forward pass returns
            ``(cifar_logits, gabor_logits)``; only the first output is used.
        testloader: iterable of ``(images, labels)`` batches.

    Returns:
        Top-1 accuracy in percent, or ``0`` when the loader yields no samples.
    """
    # Run on whatever device the model lives on instead of depending on a
    # module-level ``device`` variable.
    first_param = next(model.parameters(), None)
    eval_device = first_param.device if first_param is not None else torch.device("cpu")

    correct = 0
    total = 0
    # Evaluate in eval mode, then restore the caller's mode so a model that
    # is mid-training is not left switched to eval behind the caller's back.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for images, labels in testloader:
                images, labels = images.to(eval_device), labels.to(eval_device)
                cifar_outputs, _ = model(images)  # Only use CIFAR output
                _, predicted = torch.max(cifar_outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)
    accuracy = 100 * correct / total if total > 0 else 0
    return accuracy

def evaluate_gabor_finetuned(model, gabor_test_loader):
    """Evaluates the Gabor head.

    Args:
        model: ``nn.Module`` whose forward pass returns
            ``(cifar_logits, gabor_logits)``; only the second output is used.
        gabor_test_loader: iterable of ``(images, labels)`` batches whose
            labels are compared directly against the argmax of the Gabor
            logits (i.e. not offset by 10).

    Returns:
        Top-1 accuracy in percent, or ``0`` when the loader yields no samples.
    """
    # Run on whatever device the model lives on instead of depending on a
    # module-level ``device`` variable.
    first_param = next(model.parameters(), None)
    eval_device = first_param.device if first_param is not None else torch.device("cpu")

    correct = 0
    total = 0
    # Evaluate in eval mode, then restore the caller's mode so a model that
    # is mid-training is not left switched to eval behind the caller's back.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for images, labels in gabor_test_loader:
                images, labels = images.to(eval_device), labels.to(eval_device)
                _, gabor_outputs = model(images)  # Only use Gabor output
                _, predicted = torch.max(gabor_outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)
    accuracy = 100 * correct / total if total > 0 else 0
    return accuracy

# --- 4. Function to test robustness and plot results ---
def test_and_plot_robustness(model, cifar_loader, gabor_loader, noise_levels):
    """
    Evaluates model performance on both tasks across a range of noise levels
    and plots the results.
    """
    cifar_accuracies = []
    gabor_accuracies = []

    print("\n--- Testing Robustness Across Noise Levels ---")
    for level in noise_levels:
        noisy_model = copy.deepcopy(model)
        noisy_model = add_noise_to_middle_layers(noisy_model, noise_level=level)
        cifar_acc = evaluate_cifar_finetuned(noisy_model, cifar_loader)
        gabor_acc = evaluate_gabor_finetuned(noisy_model, gabor_loader)
        cifar_accuracies.append(cifar_acc)
        gabor_accuracies.append(gabor_acc)

    # Plotting absolute accuracies
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(noise_levels, cifar_accuracies, 'o-', label='CIFAR-10 Classification', color='blue')
    plt.plot(noise_levels, gabor_accuracies, 's-', label='Gabor Orientation Classification', color='red')
    plt.title('Model Performance vs. Noise Level')
    plt.xlabel('Noise Level')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.grid(True)
    plt.ylim(0, 100)
    
    # Plotting performance drop
    pristine_cifar_acc = cifar_accuracies[0]
    pristine_gabor_acc = gabor_accuracies[0]
    cifar_drop = [pristine_cifar_acc - acc for acc in cifar_accuracies]
    gabor_drop = [pristine_gabor_acc - acc for acc in gabor_accuracies]
    plt.subplot(1, 2, 2)
    plt.plot(noise_levels, cifar_drop, 'o-', label='CIFAR-10 Drop', color='blue')
    plt.plot(noise_levels, gabor_drop, 's-', label='Gabor Drop', color='red')
    plt.title('Performance Drop vs. Noise Level')
    plt.xlabel('Noise Level')
    plt.ylabel('Accuracy Drop (%)')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

# --- 5. Main Execution ---
if __name__ == "__main__":
    # Seed the RNGs used by weight init, DataLoader shuffling (torch) and the
    # on-the-fly Gabor sampling (NumPy global RNG) so a fresh Run All is
    # repeatable.
    SEED = 42
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    # Pick the best available device. ``torch.backends.mps`` is looked up
    # defensively because it is not present in every PyTorch build.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif mps_backend is not None and mps_backend.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    # CIFAR-10: tensors normalised per channel with mean 0.5 / std 0.5
    transform_cifar = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    cifar_trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_cifar)
    cifar_testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_cifar)
    cifar_test_loader = DataLoader(cifar_testset, batch_size=64, shuffle=False)

    gabor_trainset = GaborDataset(num_samples=20000, orientations=12)
    gabor_testset = GaborDataset(num_samples=5000, orientations=12)
    gabor_test_loader = DataLoader(gabor_testset, batch_size=64, shuffle=False)

    class GaborFinetuneDataset(Dataset):
        """Wraps a Gabor dataset and shifts its labels by +10.

        The shift keeps Gabor labels (10-21) disjoint from CIFAR-10 labels
        (0-9) so both can share one combined training loader.
        """
        def __init__(self, gabor_dataset):
            self.gabor_dataset = gabor_dataset
        def __len__(self):
            return len(self.gabor_dataset)
        def __getitem__(self, idx):
            img, label = self.gabor_dataset[idx]
            return img, label + 10 # Offset labels for combined training

    gabor_finetune_trainset = GaborFinetuneDataset(gabor_trainset)
    combined_trainset = ConcatDataset([cifar_trainset, gabor_finetune_trainset])
    combined_loader = DataLoader(combined_trainset, batch_size=64, shuffle=True)

    model = SkipConvNet().to(device)

    FINETUNED_MODEL_PATH = './finetuned_multitask_model_concat_heads.pth'
    try:
        # NOTE(review): torch.load unpickles the file, so only load
        # checkpoints from a trusted source.
        model.load_state_dict(torch.load(FINETUNED_MODEL_PATH, map_location=device))
        print("Loaded pre-trained finetuned model with concat + split heads.")
    except (FileNotFoundError, RuntimeError):
        print("No compatible finetuned model found. Starting finetuning from scratch...")
        # A failed load_state_dict may already have copied the tensors that
        # did match before raising, so rebuild the model to guarantee that
        # training really starts from freshly initialised weights.
        model = SkipConvNet().to(device)
        model = finetune_model(model, combined_loader, num_epochs=5)
        torch.save(model.state_dict(), FINETUNED_MODEL_PATH)

    print("\n--- Evaluating Pristine Finetuned Model ---")
    pristine_cifar_acc = evaluate_cifar_finetuned(model, cifar_test_loader)
    pristine_gabor_acc = evaluate_gabor_finetuned(model, gabor_test_loader)
    print(f"Pristine Model - CIFAR-10 Accuracy: {pristine_cifar_acc:.2f}%")
    print(f"Pristine Model - Gabor Orientation Accuracy: {pristine_gabor_acc:.2f}%")

    # Sweep from no noise (first entry, used as the baseline) up to 0.5
    noise_levels_to_test = np.linspace(0, 0.5, 11)
    test_and_plot_robustness(model, cifar_test_loader, gabor_test_loader, noise_levels_to_test)