# **Attention-Enhanced ResNet for CIFAR-10: Model Evaluation Notebook**

### **A project by:**
- **Karthik Krapa (kk5754)**
- **Krish Murjani (km6520)**
- **Pratham Saraf (ps5218)**

This notebook focuses on evaluating and analyzing our trained model on the CIFAR-10 test dataset. We'll measure performance metrics, visualize predictions, and explore model behavior.

## Setup and Environment Configuration

We begin by importing necessary libraries for:
- Data handling and manipulation (NumPy, Pandas, PIL)
- Deep learning framework (PyTorch)
- Visualization tools
- Utility functions for efficient processing

The code also selects the best available compute device (Apple MPS, then CUDA, then CPU) and fixes random seeds so that runs are reproducible.

In [None]:
import pickle
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms
from tqdm import tqdm
import os
import gc

device = torch.device('mps' if torch.backends.mps.is_available() else ('cuda' if torch.cuda.is_available() else 'cpu'))
print(f"Using device: {device}")

torch.manual_seed(42)
torch.cuda.manual_seed(42)
np.random.seed(42)

Using device: cuda


## Model Architecture Components 

### Spatial Attention Module

Redefining the spatial attention module from our training architecture, so that the saved weights can be loaded into an identical model. This module:

- Focuses on important spatial regions within feature maps
- Generates attention maps by combining channel-wise average and maximum activations
- Uses a convolutional layer followed by sigmoid activation to create a spatial attention mask
- Applies this mask to the input features, emphasizing important regions

This attention mechanism helps the model focus on discriminative spatial locations in the input at very little cost: its only learned parameters are a single 7×7 convolution from 2 channels to 1 (2 × 7 × 7 = 98 weights, no bias).
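The forward pass can be sketched for a single feature map in NumPy. This is an illustration, not the notebook's code: the function name is hypothetical, the kernel values are arbitrary stand-ins for learned weights, and the explicit loop reproduces what a bias-free `nn.Conv2d` with zero padding computes (a cross-correlation).

```python
import numpy as np

def spatial_attention(x, weight):
    """Single-image sketch of SpatialAttention.forward (hypothetical helper).

    x:      (C, H, W) feature map
    weight: (2, k, k) kernel standing in for the learned conv weights
    """
    k = weight.shape[-1]
    pad = k // 2
    avg = x.mean(axis=0)                      # channel-wise average, (H, W)
    mx = x.max(axis=0)                        # channel-wise maximum, (H, W)
    stacked = np.stack([avg, mx])             # (2, H, W), same order as torch.cat([avg_out, max_out])
    padded = np.pad(stacked, ((0, 0), (pad, pad), (pad, pad)))  # zero padding keeps H and W
    h, w = avg.shape
    logits = np.empty((h, w))
    for i in range(h):
        for j in range(w):
            # cross-correlation over both input channels, as nn.Conv2d computes it
            logits[i, j] = np.sum(padded[:, i:i + k, j:j + k] * weight)
    mask = 1.0 / (1.0 + np.exp(-logits))      # sigmoid: one weight in (0, 1) per location
    return x * mask                           # the (H, W) mask broadcasts over all channels

rng = np.random.default_rng(0)
features = rng.standard_normal((4, 8, 8))
print(spatial_attention(features, rng.standard_normal((2, 7, 7))).shape)  # (4, 8, 8)
```

With an all-zero kernel every logit is 0, so the mask is 0.5 everywhere and the module simply halves its input; learned weights make the mask vary across locations.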

In [None]:
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        padding = kernel_size // 2
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        attention = torch.cat([avg_out, max_out], dim=1)
        attention = self.conv(attention)
        attention = self.sigmoid(attention)
        
        return x * attention

### Attention Residual Block

Reconstructing the core building block of our network architecture. This block combines:

1. **Standard Residual Pathway**:
  - Two convolutional layers with batch normalization and ReLU
  - Skip connection to facilitate gradient flow during training

2. **Dual Attention Mechanism**:
  - Channel attention: Captures interdependencies between feature channels
    - Uses both average and max pooling for comprehensive feature aggregation
    - Employs a bottleneck of `out_channels // max(out_channels // 16, 4)` hidden channels, which comes to 16 for every stage width used here (64, 128, 256)
  - Spatial attention: Emphasizes important regions in the feature maps
    - Implemented using the previously defined SpatialAttention module

3. **Adaptive Skip Connection**:
  - Identity mapping when dimensions match
  - 1×1 convolution with batch normalization for dimension matching

This hybrid attention-enhanced residual block provides the network with the ability to focus on both what (channels) and where (spatial locations) is important in the feature representation.
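The channel-attention step can be sketched in NumPy for one feature map (hypothetical function names, arbitrary weights). On a 1×1 pooled map, a bias-free 1×1 convolution is just a matrix-vector product. The sketch makes one detail of this implementation concrete: the sigmoid is applied to each pooled branch *before* the two are summed, so the per-channel scale lies in (0, 2), whereas CBAM passes the sum through a single sigmoid and gets a scale in (0, 1).

```python
import numpy as np

def bottleneck_width(out_channels):
    # Same arithmetic as the notebook: `reduction` is derived from the channel count
    reduction = max(out_channels // 16, 4)
    return out_channels // reduction

def channel_attention(x, w1, w2):
    """Single-image sketch of the channel-attention step (hypothetical helper).

    x:  (C, H, W) output of bn2
    w1: (hidden, C) weights of the first 1x1 conv
    w2: (C, hidden) weights of the second 1x1 conv
    """
    def gate(v):
        # 1x1 conv -> ReLU -> 1x1 conv -> sigmoid, applied to a pooled (C,) vector
        return 1.0 / (1.0 + np.exp(-(w2 @ np.maximum(w1 @ v, 0.0))))

    avg = x.mean(axis=(1, 2))     # AdaptiveAvgPool2d(1), shape (C,)
    mx = x.max(axis=(1, 2))       # AdaptiveMaxPool2d(1), shape (C,)
    scale = gate(avg) + gate(mx)  # sum of two sigmoids, so each entry lies in (0, 2)
    return x * scale[:, None, None], scale

print([bottleneck_width(c) for c in (64, 128, 256)])  # [16, 16, 16]
```

With all-zero weights each gate outputs 0.5, so the combined scale is exactly 1 and the features pass through unchanged.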

In [None]:
class AttentionResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(AttentionResidualBlock, self).__init__()
        
        # Main path
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, 
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, 
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        # Channel attention
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        
        reduction = max(out_channels // 16, 4)  
        self.channel_attention = nn.Sequential(
            nn.Conv2d(out_channels, out_channels // reduction, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels // reduction, out_channels, 1, bias=False),
            nn.Sigmoid()
        )
        
        # Spatial attention
        self.spatial_attention = SpatialAttention(kernel_size=7)
        
        # Shortcut connection
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, 
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
    
    def forward(self, x):
        identity = x
        
        # Main path
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        
        # Channel attention
        avg_out = self.channel_attention(self.avg_pool(out))
        max_out = self.channel_attention(self.max_pool(out))
        out = out * (avg_out + max_out)
        
        # Spatial attention
        out = self.spatial_attention(out)
        
        # Residual connection
        identity = self.shortcut(identity)
        out += identity
        out = self.relu(out)
        
        return out

## Enhanced Efficient ResNet Model

Reconstructing our complete model architecture for evaluation. This custom ResNet variant features:

1. **Network Structure**:
  - Initial 3×3 convolution maintaining spatial dimensions
  - Four stages of attention-enhanced residual blocks:
    - Layer 1: 2 blocks (base_width → base_width*2)
    - Layer 2: 2 blocks with downsampling (base_width*2 → base_width*4)
    - Layer 3: 2 blocks with downsampling (base_width*4 → base_width*8)
    - Layer 4: 2 blocks maintaining feature map size (base_width*8 → base_width*8)
  - Classification head with global average pooling, dropout (30%), and fully connected layer

2. **Design Considerations**:
  - Base width of 32 channels (slightly different from training)
  - Strategic downsampling placement to balance spatial information and feature abstraction
  - Consistent use of attention mechanisms throughout the network

This architecture balances model size, computational efficiency, and representational power for the CIFAR-10 classification task.
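The channel width and spatial resolution after each stage can be traced with the standard convolution output-size formula. This plain-Python sketch (hypothetical helper names) assumes 32×32 inputs. Only the first block of a stage can stride, and its 1×1 shortcut (padding 0) produces the same output size as the 3×3 path (padding 1), so one formula per stage suffices.

```python
def conv_out(size, kernel=3, stride=1, padding=1):
    # Standard convolution output-size formula
    return (size + 2 * padding - kernel) // stride + 1

def trace_shapes(base_width=32, image_size=32):
    """Return (stage, channels, spatial size) after the stem and each residual stage."""
    size = image_size
    shapes = [("stem", base_width, conv_out(size))]   # 3x3, stride 1: size unchanged
    stages = [("layer1", base_width * 2, 1), ("layer2", base_width * 4, 2),
              ("layer3", base_width * 8, 2), ("layer4", base_width * 8, 1)]
    for name, channels, stride in stages:
        size = conv_out(size, stride=stride)          # only the first block of a stage strides
        shapes.append((name, channels, size))
    return shapes

for row in trace_shapes():
    print(row)
```

The final 256×8×8 map is reduced to a 256-dimensional vector by global average pooling, which matches the input size of the `nn.Linear(base_width*8, num_classes)` classifier.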

In [None]:
class EnhancedEfficientResNet(nn.Module):
    def __init__(self, num_classes=10, base_width=32):
        super(EnhancedEfficientResNet, self).__init__()
        
        # Initial convolution
        self.conv1 = nn.Conv2d(3, base_width, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(base_width)
        self.relu = nn.ReLU(inplace=True)
        
        # Layers
        self.layer1 = self._make_layer(base_width, base_width*2, 2, stride=1)
        self.layer2 = self._make_layer(base_width*2, base_width*4, 2, stride=2)
        self.layer3 = self._make_layer(base_width*4, base_width*8, 2, stride=2)
        self.layer4 = self._make_layer(base_width*8, base_width*8, 2, stride=1)  
        
        # Global pooling and classifier
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(base_width*8, num_classes)
    
    def _make_layer(self, in_channels, out_channels, blocks, stride=1):
        layers = [
            AttentionResidualBlock(in_channels, out_channels, stride)
        ]
        
        for _ in range(1, blocks):
            layers.append(
                AttentionResidualBlock(out_channels, out_channels)
            )
        
        return nn.Sequential(*layers)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)
        
        return x

## Original Model Architecture

This section includes our earlier version of the model for compatibility with previously saved weights and comparison purposes. The architecture consists of:

### ChannelAttentionBlock
- Similar to the enhanced version but with only channel attention (no spatial attention)
- Uses residual connections and dual-pooling channel attention mechanism
- Maintains the bottleneck design for efficiency in the attention module

### EfficientResNet
- The baseline architecture with three main layers (vs. four in the enhanced version)
- Similar initial convolution and feature extraction pathway
- Differences from the enhanced version:
  - Uses ChannelAttentionBlock instead of AttentionResidualBlock
  - Has one fewer layer (missing layer4)
  - Uses a 25% dropout rate (vs. 30% in the enhanced version)
  - Has roughly 2.4M fewer trainable parameters at `base_width=32` (about 2.8M vs. 5.2M), almost all of them in the missing layer4

Keeping this version allows us to:
1. Load weights from earlier experiments
2. Perform comparative analysis between model iterations
3. Ensure backward compatibility with any saved checkpoints

In [None]:
class ChannelAttentionBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ChannelAttentionBlock, self).__init__()
        
        # Main path
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, 
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, 
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        # Channel attention
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        
        reduction = max(out_channels // 16, 4) 
        self.attention = nn.Sequential(
            nn.Conv2d(out_channels, out_channels // reduction, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels // reduction, out_channels, 1, bias=False),
            nn.Sigmoid()
        )
        
        # Shortcut connection
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, 
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
    
    def forward(self, x):
        identity = x
        
        # Main path
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        
        # Channel attention
        avg_out = self.attention(self.avg_pool(out))
        max_out = self.attention(self.max_pool(out))
        out = out * (avg_out + max_out)
        
        # Residual connection
        identity = self.shortcut(identity)
        out += identity
        out = self.relu(out)
        
        return out

class EfficientResNet(nn.Module):
    def __init__(self, num_classes=10, base_width=32):
        super(EfficientResNet, self).__init__()
        
        # Initial convolution
        self.conv1 = nn.Conv2d(3, base_width, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(base_width)
        self.relu = nn.ReLU(inplace=True)
        
        # Layers
        self.layer1 = self._make_layer(base_width, base_width*2, 2, stride=1)
        self.layer2 = self._make_layer(base_width*2, base_width*4, 2, stride=2)
        self.layer3 = self._make_layer(base_width*4, base_width*8, 2, stride=2)
        
        # Global pooling and classifier
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.25)
        self.fc = nn.Linear(base_width*8, num_classes)
    
    def _make_layer(self, in_channels, out_channels, blocks, stride=1):
        layers = [
            ChannelAttentionBlock(in_channels, out_channels, stride)
        ]
        
        for _ in range(1, blocks):
            layers.append(
                ChannelAttentionBlock(out_channels, out_channels)
            )
        
        return nn.Sequential(*layers)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)
        
        return x

## Test-Time Augmentation (TTA) Configuration

Test-time augmentation (TTA) averages predictions across multiple transformed versions of the same input image, which often makes predictions more robust and slightly more accurate than a single forward pass. Our TTA setup includes:

### Base Normalization
- Normalizes every view with the CIFAR-10 per-channel mean and standard deviation; these values must match the ones used during training

### Augmentation Ensemble
Eight transforms are applied during inference. Transforms 3-8 are random, so each pass draws a new crop offset, rotation angle, or set of jitter factors for every image:

1. **Original Image**: Standard preprocessing with normalization only
2. **Horizontal Flip**: Always flips (`p=1.0`), exploiting the fact that a mirrored object keeps its class
3. **Crop Variation 1**: Random 32×32 crop after 4-pixel reflection padding, testing robustness to small shifts
4. **Crop Variation 2**: The same random crop with edge (replicate) padding, for different boundary handling
5. **Minor Rotation 1**: Rotation by a random angle in [-5°, 5°] to test orientation robustness
6. **Minor Rotation 2**: Rotation by a random angle in [-10°, 10°] for a wider range of orientations
7. **Color Variation 1**: Random brightness, contrast, and saturation changes of up to ±10%, plus a hue shift of up to ±0.02
8. **Color Variation 2**: Alternative color profile with stronger contrast changes (up to ±15%), milder brightness and saturation changes (up to ±5%), and no hue shift

Averaging over these views makes the final prediction less sensitive to small shifts, rotations, and color changes than a single forward pass.
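Because `RandomCrop`, `RandomRotation`, and `ColorJitter` are random, each TTA pass sees only one random draw per image, and the result depends on the seed. If fully deterministic views are preferred, one option (not what this notebook does) is to use fixed flips and crop offsets. The NumPy sketch below uses a hypothetical function name and hand-picked offsets. It builds such views from a 32×32×3 `uint8` array; `np.pad(..., mode="reflect")` pads the same way as torchvision's `padding_mode='reflect'`. Each view would then go through `ToTensor()` and the normalization as before.

```python
import numpy as np

def deterministic_views(img, pad=4, offsets=((0, 0), (8, 8), (0, 8), (8, 0))):
    """Fixed TTA views for one (H, W, C) image: identity, horizontal flip, and
    reflect-padded crops at fixed offsets (hypothetical choice of offsets)."""
    h, w = img.shape[:2]
    views = [img, img[:, ::-1]]                        # identity and horizontal flip
    padded = np.pad(img, ((pad, pad), (pad, pad), (0, 0)), mode="reflect")
    for dy, dx in offsets:
        # Like RandomCrop(32, padding=4, padding_mode='reflect'), but at a fixed position
        views.append(padded[dy:dy + h, dx:dx + w])
    # Contiguous copies so that PIL.Image.fromarray accepts the flipped/cropped views
    return [np.ascontiguousarray(v) for v in views]

image = np.zeros((32, 32, 3), dtype=np.uint8)
print(len(deterministic_views(image)))  # 6
```

An offset equal to `pad` in both directions reproduces the original image, so offsets around it correspond to shifts of up to four pixels.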

In [None]:
test_normalization = transforms.Normalize(
    mean=(0.4914, 0.4822, 0.4465), 
    std=(0.2023, 0.1994, 0.2010)
)

# Advanced test-time augmentation transforms
advanced_transforms = [
    # 1. Original transform 
    transforms.Compose([
        transforms.ToTensor(),
        test_normalization,
    ]),
    # 2. Horizontal flip
    transforms.Compose([
        transforms.RandomHorizontalFlip(p=1.0),
        transforms.ToTensor(),
        test_normalization,
    ]),
    # 3. Small crop 1
    transforms.Compose([
        transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
        transforms.ToTensor(),
        test_normalization,
    ]),
    # 4. Small crop 2 
    transforms.Compose([
        transforms.RandomCrop(32, padding=4, padding_mode='edge'),
        transforms.ToTensor(),
        test_normalization,
    ]),
    # 5. Slight rotate 1
    transforms.Compose([
        transforms.RandomRotation(5),
        transforms.ToTensor(),
        test_normalization,
    ]),
    # 6. Slight rotate 2
    transforms.Compose([
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        test_normalization,
    ]),
    # 7. Color jitter
    transforms.Compose([
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.02),
        transforms.ToTensor(),
        test_normalization,
    ]),
    # 8. Color jitter 2
    transforms.Compose([
        transforms.ColorJitter(brightness=0.05, contrast=0.15, saturation=0.05, hue=0),
        transforms.ToTensor(),
        test_normalization,
    ]),
]

## Enhanced CIFAR-10 Test Dataset Implementation

This custom Dataset class provides robust handling of test data with several improvements:

### Advanced Features:
- **Pickle Loading**: Reads the pickle-formatted CIFAR test file (key `b'data'` and, if present, `b'ids'`), reporting and re-raising any loading error
- **Format Adaptability**: If the images are stored as flat 3072-value rows, reshapes them to 3×32×32 and transposes them to the 32×32×3 (HWC) layout that PIL expects; 4-D arrays are assumed to already be in HWC layout
- **Error Resilience**: A sample that fails to load or transform is replaced by an all-zero image instead of aborting the run
- **ID Preservation**: Returns each image's original identifier (or its index when the file stores no IDs) for accurate submission generation

### Error Handling Strategy:
- Detailed error reporting during data loading and image processing
- Fallback to zero tensors when individual images fail to process
- Conversion of NumPy arrays to PIL Images before the transforms are applied

This keeps the evaluation pipeline running even when the test file contains problematic entries. A zero-filled fallback image still receives a prediction, so any "Error processing image" messages should be investigated before submitting.
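The reshape relies on the standard CIFAR-10 row layout: each 3072-value row holds the 1024 red values, then the 1024 green values, then the 1024 blue values, each 32×32 plane stored in row-major order. The small NumPy check below uses a hypothetical function name and assumes the test file follows this layout, as the notebook's reshape does. It confirms where each value lands.

```python
import numpy as np

def flat_to_hwc(flat):
    """(N, 3072) CIFAR rows -> (N, 32, 32, 3) images, as in EnhancedCIFARTestDataset."""
    # View each row as (3, 32, 32) planes, then move the channel axis last for PIL
    return flat.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

flat = np.arange(3072).reshape(1, 3072)
img = flat_to_hwc(flat)
print(img.shape)       # (1, 32, 32, 3)
print(img[0, 1, 2])    # [  34 1058 2082]: pixel (row 1, col 2) in the R, G, B planes
```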

In [None]:
class EnhancedCIFARTestDataset(Dataset):
    def __init__(self, pkl_file_path, transform=None):
        """
        Args:
            pkl_file_path (string): Path to the .pkl file containing test data
            transform (callable, optional): Transform to be applied on a sample
        """
        self.transform = transform
        
        try:
            with open(pkl_file_path, 'rb') as f:
                data = pickle.load(f, encoding='bytes')
            
            self.images = data[b'data']
            self.ids = data[b'ids'] if b'ids' in data else np.arange(len(self.images))
            
            if len(self.images.shape) == 2:
                print(f"Reshaping flat images of shape {self.images.shape}")
                self.images = self.images.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
                print(f"Reshaped to {self.images.shape}")
        except Exception as e:
            print(f"Error loading data: {e}")
            raise
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        try:
            image = self.images[idx]
            if not isinstance(image, Image.Image):
                image = Image.fromarray(image.astype('uint8'))
            
            if self.transform:
                image = self.transform(image)
            
            return image, self.ids[idx]
        except Exception as e:
            print(f"Error processing image {idx}: {e}")
            if self.transform:
                return torch.zeros(3, 32, 32), self.ids[idx]
            else:
                return np.zeros((32, 32, 3), dtype=np.uint8), self.ids[idx]

## Enhanced Test-Time Augmentation (TTA) Prediction

This function implements a sophisticated prediction pipeline that leverages multiple augmented views of each test image to produce more robust classifications.

### Key Features:

1. **Weighted Ensemble Averaging**:
  - Gives the original (un-augmented) view a weight of 2.0 and every other view a weight of 1.0
  - Normalizes the weights to sum to 1, so the weighted average of softmax outputs is itself a probability distribution (with all eight views: 2/9 for the original and 1/9 for each of the rest)
  - Takes the argmax of this weighted average as the final label

2. **Temperature Scaling**:
  - Divides the logits by T=0.9 before the softmax
  - Because T < 1, this sharpens each view's distribution, giving confident views more influence in the average
  - It never changes a single view's argmax; it only changes how the views are combined

3. **Memory Optimization**:
  - Uses a batch size of 32 to limit peak memory use
  - Deletes the batch tensors and, on CUDA, calls `torch.cuda.empty_cache()` after every batch
  - Runs garbage collection after each augmentation pass to free CPU memory

4. **Workflow Management**:
  - Progress tracking with tqdm across the augmentation pipelines
  - Configurable number of augmentations via the `num_transforms` parameter
  - Because `shuffle=False`, every pass yields images in the same order, so the IDs collected in the first pass stay aligned with all later passes

Averaging over several views of each image typically gives a modest accuracy gain over single-view inference, at roughly `num_transforms` times the inference cost.
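The weighting and temperature logic can be isolated from the data loading. In this NumPy sketch the function names are hypothetical and the toy logits are made up for illustration. It reproduces the combination step: per-view softmax at temperature T, normalized weights with double weight on view 0, and an argmax over the weighted sum.

```python
import numpy as np

def softmax(logits, temperature=1.0):
    z = logits / temperature
    z = z - z.max(axis=1, keepdims=True)    # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def tta_weights(num_views, original_weight=2.0):
    w = np.array([original_weight] + [1.0] * (num_views - 1))
    return w / w.sum()                       # weights sum to 1

def weighted_tta(view_logits, temperature=0.9, original_weight=2.0):
    """view_logits: list of (N, K) logit arrays; index 0 is the un-augmented view."""
    weights = tta_weights(len(view_logits), original_weight)
    probs = sum(w * softmax(l, temperature) for w, l in zip(weights, view_logits))
    return probs, probs.argmax(axis=1)

original = np.array([[2.0, 0.0]])            # un-augmented view favors class 0
others = [np.array([[0.0, 1.0]])] * 3        # three augmented views mildly favor class 1
print(weighted_tta([original] + others)[1])                       # [0]
print(weighted_tta([original] + others, original_weight=1.0)[1])  # [1]
```

In this toy case the double weight on the original view is what tips the decision: with equal weights the three augmented views outvote it.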

In [None]:
def enhanced_tta_prediction(model, pkl_file_path, output_filename="enhanced_submission.csv", num_transforms=8):
    """
    Advanced test-time augmentation with weighted averaging of predictions
    """
    print("Starting enhanced TTA prediction...")
    model.eval()
    
    transforms_to_use = advanced_transforms[:num_transforms]
    
    weights = [2.0] 
    weights.extend([1.0] * (len(transforms_to_use) - 1))
    
    weights = [w / sum(weights) for w in weights]
    
    all_probs = []
    image_ids = None
    
    for i, transform in enumerate(tqdm(transforms_to_use, desc="Processing augmentations")):
        dataset = EnhancedCIFARTestDataset(pkl_file_path, transform=transform)
        dataloader = DataLoader(
            dataset, 
            batch_size=32,  
            shuffle=False, 
            num_workers=2, 
            pin_memory=True
        )
        
        batch_probs = []
        batch_ids = []
        
        with torch.no_grad():
            for images, ids in dataloader:
                images = images.to(device)
                outputs = model(images)
                
                outputs = outputs / 0.9  # temperature T=0.9 (< 1 sharpens the softmax)
                
                probs = F.softmax(outputs, dim=1)
                
                batch_probs.append(probs.cpu().numpy())
                batch_ids.append(ids.numpy())
                
                del images, outputs, probs
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
        
        augmentation_probs = np.concatenate(batch_probs)
        
        all_probs.append(augmentation_probs * weights[i])
        
        if image_ids is None:
            image_ids = np.concatenate(batch_ids)
        
        del batch_probs, augmentation_probs
        gc.collect()
    
    avg_probs = np.sum(all_probs, axis=0)
    final_preds = np.argmax(avg_probs, axis=1)
    
    submission_df = pd.DataFrame({
        'ID': image_ids, 
        'Labels': final_preds
    })
    submission_df = submission_df.sort_values('ID')
    submission_df.to_csv(output_filename, index=False)
    print(f"Enhanced TTA submission file created: {output_filename}")
    return submission_df

## Class-Specialized Prediction Strategy

This advanced prediction approach combines efficiency with accuracy by applying different levels of augmentation based on prediction confidence:

### Two-Stage Prediction Pipeline:

1. **Initial Confidence Assessment**:
   - First pass uses only the base transform for all images
   - Classifies each image and records its top softmax probability as the confidence
   - Compares that confidence with a threshold chosen for the predicted class, based on known difficulty patterns
   
2. **Targeted Enhancement**:
   - Identifies low-confidence predictions that fall below their class-specific threshold
   - Re-predicts only these uncertain cases with all eight transforms, averaging the views with equal weight
   - Keeps the first-pass label for already confident predictions
   - Divides the logits by T=1.2 before the softmax; since T > 1, this softens each view's distribution so that no single overconfident view dominates the average

### Class-Aware Confidence Thresholds:
- Different thresholds for each class based on empirical confusion patterns:
  - Higher thresholds (0.85-0.90) for the more easily recognized classes (airplane, automobile, frog, horse, ship, truck)
  - Lower thresholds (0.70-0.75) for challenging classes (bird, cat, deer, dog)

This hybrid approach aims to cut the cost of full TTA by spending augmentation only on the most uncertain predictions. In the current implementation, however, the check is made per batch of 64: any batch containing at least one uncertain image is run through the model in full, and every image is still loaded and transformed for each view, so the actual saving depends on how rare the uncertain images are.
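The selection step can be written as a vectorized mask over the first-pass probabilities. The NumPy sketch below uses hypothetical helper names; its thresholds are copied from the notebook. It also estimates how many 64-image batches would still be processed. That estimate assumes uncertain images are scattered independently through the dataset, which is a simplification.

```python
import numpy as np

# Per-class thresholds from the notebook, indexed by class id (0 = airplane, ..., 9 = truck)
THRESHOLDS = np.array([0.85, 0.90, 0.75, 0.70, 0.75, 0.70, 0.85, 0.85, 0.90, 0.90])

def low_confidence_mask(probs, thresholds=THRESHOLDS):
    """probs: (N, 10) softmax outputs of the base transform -> boolean mask of images to re-run."""
    preds = probs.argmax(axis=1)
    conf = probs.max(axis=1)
    return conf < thresholds[preds]          # the threshold depends on the *predicted* class

def expected_batch_fraction(frac_uncertain, batch_size=64):
    """Expected share of batches holding at least one uncertain image,
    assuming uncertain images are placed independently (a simplification)."""
    return 1.0 - (1.0 - frac_uncertain) ** batch_size

probs = np.zeros((2, 10))
probs[0, 3], probs[0, 5] = 0.72, 0.28        # cat at 0.72 >= 0.70: kept
probs[1, 0], probs[1, 8] = 0.80, 0.20        # airplane at 0.80 < 0.85: re-run with TTA
print(low_confidence_mask(probs))            # [False  True]
```

Under this assumption, if only 5% of images are uncertain, about 96% of batches still contain at least one of them, so per-image selection (for example, a `torch.utils.data.Subset` of the uncertain indices) would be needed to realize most of the intended saving.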

In [None]:
def class_specialized_prediction(model, pkl_file_path, output_filename="specialized_submission.csv"):
    """
    Creates predictions with specialized handling for different classes based on confidence thresholds
    """
    print("Starting class-specialized prediction...")
    model.eval()
    
    base_transform = advanced_transforms[0]
    dataset = EnhancedCIFARTestDataset(pkl_file_path, transform=base_transform)
    dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=2, pin_memory=True)
    
    initial_preds = []
    confidence_scores = []
    image_ids = []
    
    with torch.no_grad():
        for images, ids in tqdm(dataloader, desc="Initial prediction pass"):
            images = images.to(device)
            outputs = model(images)
            
            probs = F.softmax(outputs, dim=1)
            
            values, preds = torch.max(probs, dim=1)
            
            initial_preds.extend(preds.cpu().numpy())
            confidence_scores.extend(values.cpu().numpy())
            image_ids.extend(ids.numpy())
            
            del images, outputs, probs, values, preds
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    
    # Confidence thresholds for different classes
    class_confidence_thresholds = {
        0: 0.85,  # airplane
        1: 0.90,  # automobile
        2: 0.75,  # bird - still relatively difficult
        3: 0.70,  # cat - difficult class
        4: 0.75,  # deer
        5: 0.70,  # dog - difficult class
        6: 0.85,  # frog
        7: 0.85,  # horse
        8: 0.90,  # ship
        9: 0.90,  # truck
    }
    
    low_conf_indices = []
    for i, (pred, conf) in enumerate(zip(initial_preds, confidence_scores)):
        if conf < class_confidence_thresholds.get(pred, 0.75):
            low_conf_indices.append(i)
    
    print(f"Found {len(low_conf_indices)} low confidence predictions ({len(low_conf_indices)/len(initial_preds)*100:.2f}%)")
    
    final_preds = list(initial_preds)
    
    if low_conf_indices:
        low_conf_probs = []
        low_conf_set = set(low_conf_indices)  # O(1) membership tests instead of scanning a list
        batch_size = 64
        
        for transform in tqdm(advanced_transforms, desc="Processing difficult cases"):
            dataset = EnhancedCIFARTestDataset(pkl_file_path, transform=transform)
            dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
            
            all_outputs = []
            
            with torch.no_grad():
                for batch_idx, (images, _) in enumerate(dataloader):
                    # With shuffle=False, batch k covers dataset indices [k * batch_size, k * batch_size + len(images))
                    batch_start = batch_idx * batch_size
                    batch_end = batch_start + len(images)
                    batch_indices = range(batch_start, batch_end)
                    
                    # Run the model only on batches that contain at least one low-confidence image
                    process_batch = any(idx in low_conf_set for idx in batch_indices)
                    
                    if process_batch:
                        images = images.to(device)
                        outputs = model(images)
                        all_outputs.append(outputs.cpu())
                    else:
                        all_outputs.append(torch.zeros(len(batch_indices), 10))
            
            all_outputs = torch.cat(all_outputs)
            
            selected_outputs = all_outputs[low_conf_indices]
            selected_probs = F.softmax(selected_outputs / 1.2, dim=1).numpy()  # temperature T=1.2 (> 1 softens)
            
            low_conf_probs.append(selected_probs)
        

            del all_outputs, selected_outputs, selected_probs
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        
        avg_probs = np.mean(np.stack(low_conf_probs), axis=0)
        improved_preds = np.argmax(avg_probs, axis=1)
        
        for i, idx in enumerate(low_conf_indices):
            final_preds[idx] = improved_preds[i]
    
    submission_df = pd.DataFrame({
        'ID': image_ids, 
        'Labels': final_preds
    })
    submission_df = submission_df.sort_values('ID')
    submission_df.to_csv(output_filename, index=False)
    print(f"Class-specialized submission file created: {output_filename}")
    return submission_df

## Adaptive Multi-Strategy Prediction System

This section implements our most sophisticated prediction approach - a hybrid system that adaptively combines multiple prediction strategies to maximize accuracy.

### Four-Phase Pipeline:

1. **Model Loading with Fail-safe Mechanism**:
  - Attempts to load using the enhanced architecture first
  - Falls back to original architecture if needed
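  - If both strict loads fail, tries a `{'state_dict': ...}` checkpoint wrapper and, as a last resort, loads with `strict=False`
  - A non-strict load silently skips keys that do not match (for example, a `module.` prefix left by `DataParallel`), and the unmatched layers keep their random initialization, so its warning should be taken seriously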

2. **Dual Prediction Generation**:
  - Performs full enhanced TTA with all eight transforms
  - Executes class-specialized prediction with targeted augmentation
  - Saves intermediate results to temporary files

3. **Confidence-Based Integration**:
  - Re-runs the model once on the un-augmented images and reads both candidate labels' probabilities from this single-view softmax output
  - Adds a fixed +0.05 confidence boost to labels from the challenging categories (bird, cat, dog)
  - Uses tiered confidence thresholds (HIGH_CONF: 0.95, MED_CONF: 0.80)

4. **Intelligent Decision Logic**:
  - Keeps the shared label whenever the two methods agree
  - Otherwise, takes a label whose (boosted) confidence exceeds HIGH_CONF, checking the TTA label first
  - Failing that, takes the more confident label if it exceeds MED_CONF
  - If still unresolved, defers to the class-specialized label when the TTA label is bird, cat, or dog, and keeps the TTA label otherwise

This adaptive approach combines the strengths of both prediction strategies, with extra care for edge cases and the traditionally difficult classes.
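The decision rule can be isolated as a pure function, which makes it easy to test on hand-made probability vectors. This plain-Python sketch uses hypothetical function and constant names. It reproduces the branch order of the loop in `adaptive_prediction` for one image. Both confidences are read from the same single-view probability vector, so whenever the specialized label is simply the base-view argmax, its unboosted confidence is never below the TTA label's.

```python
HIGH_CONF, MED_CONF = 0.95, 0.80
CLASS_BOOST = {2: 0.05, 3: 0.05, 5: 0.05}   # bird, cat, dog
HARD_CLASSES = {2, 3, 5}

def combine(tta_pred, spec_pred, prob_vector):
    """prob_vector: single-view softmax probabilities for one image (10 floats)."""
    if tta_pred == spec_pred:
        return tta_pred                                   # methods agree
    tta_conf = prob_vector[tta_pred] + CLASS_BOOST.get(tta_pred, 0.0)
    spec_conf = prob_vector[spec_pred] + CLASS_BOOST.get(spec_pred, 0.0)
    if tta_conf > HIGH_CONF:
        return tta_pred
    if spec_conf > HIGH_CONF:
        return spec_pred
    if tta_conf > MED_CONF and tta_conf > spec_conf:
        return tta_pred
    if spec_conf > MED_CONF and spec_conf > tta_conf:
        return spec_pred
    if tta_pred in HARD_CLASSES:
        return spec_pred                                  # defer on bird / cat / dog
    return tta_pred

p = [0.0] * 10
p[3], p[5] = 0.46, 0.44          # cat vs. dog, neither confident
print(combine(3, 5, p))          # 5: the TTA label is a hard class, so defer
```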

In [None]:
def adaptive_prediction(model_path, pkl_file_path, output_filename="adaptive_submission.csv"):
    """
    Creates predictions using an adaptive approach that combines multiple techniques
    """
    print("Starting adaptive prediction process...")
    
    try:
        model = EnhancedEfficientResNet(num_classes=10)
        model.load_state_dict(torch.load(model_path, map_location=device))
        print("Successfully loaded model with enhanced architecture")
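    except Exception:  # a bare `except:` would also swallow KeyboardInterrupt
        try:
            model = EfficientResNet(num_classes=10)
            model.load_state_dict(torch.load(model_path, map_location=device))
            print("Successfully loaded model with original architecture")
        except Exception as e:
            print(f"Error loading model: {e}")
            print("Attempting alternative loading method...")
            
            model = EnhancedEfficientResNet(num_classes=10)
            state_dict = torch.load(model_path, map_location=device)
            if isinstance(state_dict, dict) and 'state_dict' in state_dict:
                # Checkpoint saved as {'state_dict': ..., ...}
                model.load_state_dict(state_dict['state_dict'])
            else:
                # strict=False silently skips mismatched keys; report them so that
                # a partially random model does not go unnoticed
                result = model.load_state_dict(state_dict, strict=False)
                print("Warning: Model loaded with strict=False, some weights may not be loaded")
                print(f"  missing keys: {len(result.missing_keys)}, unexpected keys: {len(result.unexpected_keys)}")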
    
    model = model.to(device)
    model.eval()
    
    # Step 1: Run enhanced TTA
    print("\nStep 1: Running enhanced TTA prediction...")
    enhanced_tta_prediction(model, pkl_file_path, "tmp_tta.csv", num_transforms=8)
    
    # Step 2: Run class-specialized prediction
    print("\nStep 2: Running class-specialized prediction...")
    class_specialized_prediction(model, pkl_file_path, "tmp_spec.csv")
    
    # Step 3: Combine predictions based on confidence
    print("\nStep 3: Combining predictions adaptively...")

    tta_df = pd.read_csv("tmp_tta.csv")
    spec_df = pd.read_csv("tmp_spec.csv")
    
    tta_df = tta_df.sort_values('ID')
    spec_df = spec_df.sort_values('ID')
    
    base_transform = advanced_transforms[0]
    dataset = EnhancedCIFARTestDataset(pkl_file_path, transform=base_transform)
    dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=2, pin_memory=True)
    
    all_probs = []
    all_ids = []
    
    with torch.no_grad():
        for images, ids in tqdm(dataloader, desc="Computing confidence scores"):
            images = images.to(device)
            outputs = model(images)
            probs = F.softmax(outputs, dim=1)
            all_probs.append(probs.cpu().numpy())
            all_ids.extend(ids.numpy())
    
    all_probs = np.concatenate(all_probs)
    id_to_index = {id_val: i for i, id_val in enumerate(all_ids)}
    
    final_labels = []
    
    HIGH_CONF = 0.95 
    MED_CONF = 0.80   
    
    class_boost = {
        2: 0.05,  # bird
        3: 0.05,  # cat
        5: 0.05,  # dog
    }
    
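    # Map each image ID to its class-specialized label once, instead of
    # filtering spec_df for every row (a full scan of the frame each time)
    spec_labels = dict(zip(spec_df['ID'], spec_df['Labels']))
    
    for i in range(len(tta_df)):
        img_id = tta_df.iloc[i]['ID']
        tta_pred = tta_df.iloc[i]['Labels']
        spec_pred = spec_labels[img_id]
        
        # Confidence comes from the single-pass softmax computed above,
        # plus a fixed boost for the hard classes (bird, cat, dog)
        idx = id_to_index[img_id]
        prob_vector = all_probs[idx]
        
        tta_conf = prob_vector[tta_pred]
        if tta_pred in class_boost:
            tta_conf += class_boost[tta_pred]
            
        spec_conf = prob_vector[spec_pred]
        if spec_pred in class_boost:
            spec_conf += class_boost[spec_pred]
        
        # Decision logic: agreement first, then a high-confidence label,
        # then the more confident medium-confidence label, then a fallback
        if tta_pred == spec_pred:
            final_labels.append(tta_pred)
        elif tta_conf > HIGH_CONF:
            final_labels.append(tta_pred)
        elif spec_conf > HIGH_CONF:
            final_labels.append(spec_pred)
        elif tta_conf > MED_CONF and tta_conf > spec_conf:
            final_labels.append(tta_pred)
        elif spec_conf > MED_CONF and spec_conf > tta_conf:
            final_labels.append(spec_pred)
        elif tta_pred in [2, 3, 5]:
            # Low-confidence TTA label in a hard class: trust the specialized pass
            final_labels.append(spec_pred)
        else:
            final_labels.append(tta_pred)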
    
    submission_df = pd.DataFrame({
        'ID': tta_df['ID'],
        'Labels': final_labels
    })
    submission_df.to_csv(output_filename, index=False)
    
    if os.path.exists("tmp_tta.csv"):
        os.remove("tmp_tta.csv")
    if os.path.exists("tmp_spec.csv"):
        os.remove("tmp_spec.csv")
    
    print(f"Adaptive submission file created: {output_filename}")
    return submission_df

## Multi-Method Ensemble Generation

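This function writes three prediction files, each produced by a different inference strategy applied to the same trained model. Because every file comes from one set of weights, the diversity comes from the inference procedure (augmentation, extra processing of low-confidence images, confidence-based fusion) rather than from independently trained models; combining the files can still correct some errors that a single strategy makes.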

### Generated Prediction Set:

1. **Enhanced TTA Prediction**:
  - Applies test-time augmentation (TTA) with 6 transforms
  - Uses a different transform count from the adaptive method (8) to add some diversity between the files
  - Serves as the baseline prediction

2. **Class-Specialized Prediction**:
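  - Makes an initial pass, then applies extra processing only to low-confidence images (6,818 of the 10,000 test images in the run at the end of this notebook)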
  - Uses class-specific confidence thresholds
  - Handles challenging classes with targeted processing

3. **Adaptive Fusion Prediction**:
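  - Re-runs TTA (with 8 transforms) and class-specialized prediction, and keeps the shared label wherever the two agree
  - Resolves disagreements with confidence thresholds of 0.95 and 0.80, read from a single forward pass that uses the first transform in `advanced_transforms`
  - Adds a +0.05 confidence boost for bird, cat, and dog, and defers to the class-specialized label when a low-confidence TTA prediction falls in one of those classes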

The three files can be submitted individually or combined by majority vote. Because the test set is unlabeled, their accuracy cannot be compared here; choosing among them requires a separate validation signal.
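The fusion rule inside `adaptive_prediction` can be isolated as a pure function, which makes each branch easy to check on hand-made probability vectors. The sketch below copies the thresholds and class boosts from the code above; the name `fuse_labels` is hypothetical, and `prob_vector` stands for the single-pass softmax vector that `adaptive_prediction` computes for each image.

```python
import numpy as np

# Constants copied from adaptive_prediction
HIGH_CONF = 0.95
MED_CONF = 0.80
CLASS_BOOST = {2: 0.05, 3: 0.05, 5: 0.05}  # bird, cat, dog
HARD_CLASSES = {2, 3, 5}


def fuse_labels(tta_pred, spec_pred, prob_vector):
    """Choose between the TTA and class-specialized labels for one image.

    prob_vector is a softmax vector over the 10 classes; the (boosted)
    confidence of each candidate label is read from it.
    """
    if tta_pred == spec_pred:
        return tta_pred
    tta_conf = prob_vector[tta_pred] + CLASS_BOOST.get(tta_pred, 0.0)
    spec_conf = prob_vector[spec_pred] + CLASS_BOOST.get(spec_pred, 0.0)
    if tta_conf > HIGH_CONF:
        return tta_pred
    if spec_conf > HIGH_CONF:
        return spec_pred
    if tta_conf > MED_CONF and tta_conf > spec_conf:
        return tta_pred
    if spec_conf > MED_CONF and spec_conf > tta_conf:
        return spec_pred
    # Low-confidence disagreement: when TTA chose a hard class,
    # defer to the class-specialized label
    return spec_pred if tta_pred in HARD_CLASSES else tta_pred


# Example: TTA says cat (0.50), the specialized pass says dog (0.32)
probs = np.full(10, 0.02)
probs[3], probs[5] = 0.50, 0.32
print(fuse_labels(3, 5, probs))  # 5: hard-class fallback picks dog
```

In the example, neither boosted confidence (0.55 and 0.37) reaches 0.80, so the fallback returns dog even though the TTA label scores higher. Checking agreement before computing confidences does not change any result, since the original loop's first branch returns the shared label regardless of confidence.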

In [None]:
def generate_multiple_predictions(model_path, pkl_file_path, base_filename="submission"):
    """
    Generates three prediction files (TTA, class-specialized, adaptive) for the final ensemble
    """
    print("Generating multiple predictions for ensemble...")
    
    # Try the enhanced architecture first; load_state_dict raises RuntimeError
    # on missing or unexpected keys, in which case fall back to EfficientResNet
    try:
        model = EnhancedEfficientResNet(num_classes=10)
        model.load_state_dict(torch.load(model_path, map_location=device))
    except RuntimeError:
        model = EfficientResNet(num_classes=10)
        model.load_state_dict(torch.load(model_path, map_location=device))
    
    model = model.to(device)
    model.eval()
    
    enhanced_tta_prediction(
        model, 
        pkl_file_path, 
        f"{base_filename}_tta.csv",
        num_transforms=6  
    )
    
    class_specialized_prediction(
        model,
        pkl_file_path,
        f"{base_filename}_spec.csv"
    )
    
    adaptive_prediction(
        model_path,
        pkl_file_path,
        f"{base_filename}_adaptive.csv"
    )
    
    print(f"Generated 3 prediction files for ensemble submission")

## Ensemble Voting System

This function merges several prediction files into one by per-image majority vote. Voting helps most when the inputs make different mistakes; because every file here comes from the same trained model, their errors are likely correlated and the gain may be small.

### Ensemble Process:

1. **Input Validation**:
  - Loads all prediction files and ensures consistent sorting by ID
  - Performs strict verification that all files contain identical image IDs
  - Prepares predictions for majority voting

2. **Majority Voting Algorithm**:
  - Stacks predictions from all methods for each sample
  - Determines the most frequent class prediction for each image
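  - Breaks ties in favor of the smallest class index: `np.unique` returns labels in sorted order and `np.argmax` takes the first maximum, so a three-way split is settled by class number, not by file order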

3. **Result Generation**:
  - Creates a consolidated submission file with consensus predictions
  - Maintains the original image ID ordering

With three files, majority voting keeps any label that at least two of them share. The votes are not independent, though: the adaptive file is built from a TTA pass and a class-specialized pass, so whenever its 8-transform TTA and class-specialized predictions agree, it repeats their label.
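A vectorized version of the same vote makes the tie-breaking rule explicit. This is a sketch, assuming labels are non-negative integers stacked into a `(num_files, num_images)` array like `all_preds` in the function below; `majority_vote` is a hypothetical helper, not part of the notebook.

```python
import numpy as np


def majority_vote(label_matrix):
    """Per-image majority vote over a (num_files, num_images) label array.

    Ties go to the smallest class index, matching the np.unique + np.argmax
    loop in ensemble_prediction_files.
    """
    label_matrix = np.asarray(label_matrix)
    num_classes = int(label_matrix.max()) + 1
    # counts[c, j] = number of files that predicted class c for image j
    counts = np.stack([(label_matrix == c).sum(axis=0) for c in range(num_classes)])
    # argmax returns the first (smallest) class among tied maxima
    return counts.argmax(axis=0)


preds = np.array([
    [3, 1, 0],  # file 1
    [3, 2, 4],  # file 2
    [5, 2, 9],  # file 3
])
print(majority_vote(preds))  # [3 2 0]: the third image is a three-way tie, so class 0 wins
```

Under these assumptions, `final_preds = majority_vote(all_preds)` could replace the per-image loop in `ensemble_prediction_files` and return the same labels.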

In [None]:
def ensemble_prediction_files(file_paths, output_filename="ensemble_submission.csv"):
    """
    Ensembles multiple prediction CSV files
    """
    print(f"Creating ensemble from {len(file_paths)} prediction files...")
    
    dataframes = []
    for file_path in file_paths:
        df = pd.read_csv(file_path)
        df = df.sort_values('ID')
        dataframes.append(df)
    
    for i in range(1, len(dataframes)):
        assert np.array_equal(dataframes[0]['ID'].values, dataframes[i]['ID'].values), "ID mismatch between files"
    
    all_preds = np.array([df['Labels'].values for df in dataframes])
    
    final_preds = []
    for i in range(len(dataframes[0])):
        sample_preds = all_preds[:, i]
        values, counts = np.unique(sample_preds, return_counts=True)
        max_count_idx = np.argmax(counts)
        final_preds.append(values[max_count_idx])
    
    ensemble_df = pd.DataFrame({
        'ID': dataframes[0]['ID'],
        'Labels': final_preds
    })
    ensemble_df.to_csv(output_filename, index=False)
    print(f"Ensemble submission file created: {output_filename}")
    return ensemble_df

## Prediction Analysis Tool

This utility function summarizes where the prediction files disagree, both overall and by class.

### Analysis Components:

1. **Agreement Rate Calculation**:
  - Measures the percentage of images where all prediction methods agree
  - Since all files come from the same model, a high rate shows that the strategies are consistent with one another; it does not by itself indicate accuracy

2. **Class-Specific Disagreement Analysis**:
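  - For every image where the files disagree, counts each file's predicted label once, so a class total is a number of votes rather than a number of images
  - Highlights the classes most involved in disagreements (cat, bird, and airplane in the run at the end of this notebook)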
  - Uses CIFAR-10 class names for intuitive reporting

3. **Diagnostic Insights**:
  - Shows whether disagreements concentrate in the classes that `adaptive_prediction` already treats specially (bird, cat, and dog)
  - Indicates where the strategies diverge, which can guide changes to the thresholds or class boosts

Because the test set is unlabeled, this analysis measures consistency between strategies, not accuracy: it shows where they diverge, but not which one is right.
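The same statistics can be computed without a per-row Python loop. The sketch below assumes the files have already been aligned into a `(num_files, num_images)` integer array; `disagreement_stats` is a hypothetical name. It also makes explicit what the per-class numbers count: each file's label on a disputed image is one vote, so the class totals add up to disputed images × files. In the run at the end of this notebook, 2.84% of 10,000 images gives 284 disputed images, and 284 × 5 files = 1,420, which equals the sum of the reported per-class counts.

```python
import numpy as np


def disagreement_stats(label_matrix, num_classes=10):
    """Agreement rate (%) and per-class disagreement votes.

    label_matrix has shape (num_files, num_images) and holds integer labels.
    """
    label_matrix = np.asarray(label_matrix)
    # An image is unanimous when every file matches the first file's label
    unanimous = (label_matrix == label_matrix[0]).all(axis=0)
    agreement_rate = unanimous.mean() * 100
    # Each file's label on a disputed image counts as one vote for that class
    disputed_votes = label_matrix[:, ~unanimous].ravel()
    per_class = np.bincount(disputed_votes, minlength=num_classes)
    return agreement_rate, per_class


labels = np.array([
    [0, 1, 2, 3],
    [0, 1, 5, 3],
    [0, 2, 5, 3],
])
rate, votes = disagreement_stats(labels)
print(rate)          # 50.0
print(votes.sum())   # 6 = 2 disputed images x 3 files
```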

In [None]:
def analyze_prediction_differences(file_paths, class_names=None):
    """
    Analyzes and visualizes differences between prediction files
    """
    if class_names is None:
        class_names = [
            'airplane', 'automobile', 'bird', 'cat', 'deer',
            'dog', 'frog', 'horse', 'ship', 'truck'
        ]
    
    dataframes = []
    for file_path in file_paths:
        df = pd.read_csv(file_path)
        df = df.sort_values('ID')
        dataframes.append(df)
    
    agreement_count = 0
    class_disagreements = {i: 0 for i in range(len(class_names))}
    
    for i in range(len(dataframes[0])):
        preds = [df.iloc[i]['Labels'] for df in dataframes]
        if len(set(preds)) == 1:
            agreement_count += 1
        else:
            for pred in preds:
                class_disagreements[pred] += 1
    
    agreement_rate = agreement_count / len(dataframes[0]) * 100
    
    print(f"Agreement rate between prediction files: {agreement_rate:.2f}%")
    print("\nDisagreements by class:")
    for class_idx, count in class_disagreements.items():
        print(f"{class_names[class_idx]}: {count} disagreements")
    
    return agreement_rate, class_disagreements

## Main Execution Pipeline

This function runs the full prediction workflow and writes the CIFAR-10 test-set submission files:

### Multi-Stage Process:

1. **Multiple Prediction Generation**:
  - Creates three prediction files with different inference strategies
  - Provides the inputs for the ensemble step

2. **Ensemble Combination**:
  - Combines the three files by per-image majority vote
  - Writes "ensemble_submission.csv"

3. **Adaptive Prediction**:
  - Calls `adaptive_prediction` again with the same model, repeating the procedure that produced "submission_adaptive.csv"
  - Writes the result to "final_submission.csv", the primary recommendation

4. **Comparative Analysis**:
  - Reports the percentage of images on which all five files agree
  - Counts, per class, the votes cast on images where the files disagree

5. **Recommendation**:
  - Prints a fixed recommendation: "final_submission.csv" as the primary file and "ensemble_submission.csv" as the alternative

The recommendation is hard-coded rather than derived from the analysis; since the test labels are unavailable, the agreement statistics show how consistent the strategies are, not which file scores higher.
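The analysis function reports one agreement rate across all files. A pairwise table shows which specific pairs differ, for example whether "final_submission.csv" and "submission_adaptive.csv" match exactly. This is a hypothetical extension (`pairwise_agreement` is not in the notebook), assuming each frame has the `ID` and `Labels` columns written by the prediction functions.

```python
import itertools

import pandas as pd


def pairwise_agreement(named_frames):
    """Percentage of images with matching labels for every pair of frames.

    named_frames maps a display name to a DataFrame with 'ID' and 'Labels'.
    """
    names = list(named_frames)
    table = pd.DataFrame(100.0, index=names, columns=names)
    for a, b in itertools.combinations(names, 2):
        # Inner merge on ID: row order does not matter, and IDs missing
        # from either frame are ignored
        merged = named_frames[a].merge(named_frames[b], on="ID", suffixes=("_a", "_b"))
        rate = (merged["Labels_a"] == merged["Labels_b"]).mean() * 100
        table.loc[a, b] = rate
        table.loc[b, a] = rate
    return table
```

With the notebook's files, it could be called as `pairwise_agreement({f: pd.read_csv(f) for f in files})`, where `files` is the list passed to `analyze_prediction_differences` in `main`.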

In [None]:
def main(model_path, pkl_file_path):
    print(f"Starting prediction process with model: {model_path}")
    print(f"Test data: {pkl_file_path}")
    
    # Method 1: Generate 3 different prediction files
    generate_multiple_predictions(model_path, pkl_file_path)
    
    # Method 2: Create ensemble from the 3 prediction files
    ensemble_prediction_files(
        [
            "submission_tta.csv",
            "submission_spec.csv",
            "submission_adaptive.csv"
        ],
        "ensemble_submission.csv"
    )
    
    # Method 3: Create adaptive prediction directly
    adaptive_prediction(model_path, pkl_file_path, "final_submission.csv")
    
    print("\nAnalyzing prediction differences:")
    analyze_prediction_differences([
        "submission_tta.csv",
        "submission_spec.csv",
        "submission_adaptive.csv",
        "ensemble_submission.csv",
        "final_submission.csv"
    ])
    
    print("\nRecommended submission file: final_submission.csv")
    print("Alternative submission file: ensemble_submission.csv")

In [None]:
if __name__ == "__main__":
    MODEL_PATH = '/kaggle/input/best_ema_model_500epoch_95.30/pytorch/default/1/best_ema_model.pth'
    TEST_DATA_PATH = '/kaggle/input/deep-learning-spring-2025-project-1/cifar_test_nolabel.pkl'

main(MODEL_PATH, TEST_DATA_PATH)

Starting prediction process with model: /kaggle/input/best_ema_model_500epoch_95.30/pytorch/default/1/best_ema_model.pth
Test data: /kaggle/input/deep-learning-spring-2025-project-1/cifar_test_nolabel.pkl
Generating multiple predictions for ensemble...


Starting enhanced TTA prediction...


Processing augmentations: 100%|██████████| 6/6 [00:27<00:00,  4.57s/it]


Enhanced TTA submission file created: submission_tta.csv
Starting class-specialized prediction...


Initial prediction pass: 100%|██████████| 157/157 [00:03<00:00, 45.42it/s]


Found 6818 low confidence predictions (68.18%)


Processing difficult cases: 100%|██████████| 8/8 [00:30<00:00,  3.77s/it]


Class-specialized submission file created: submission_spec.csv
Starting adaptive prediction process...
Successfully loaded model with enhanced architecture

Step 1: Running enhanced TTA prediction...
Starting enhanced TTA prediction...


Processing augmentations: 100%|██████████| 8/8 [00:40<00:00,  5.01s/it]


Enhanced TTA submission file created: tmp_tta.csv

Step 2: Running class-specialized prediction...
Starting class-specialized prediction...


Initial prediction pass: 100%|██████████| 157/157 [00:03<00:00, 41.82it/s]


Found 6818 low confidence predictions (68.18%)


Processing difficult cases: 100%|██████████| 8/8 [00:29<00:00,  3.73s/it]


Class-specialized submission file created: tmp_spec.csv

Step 3: Combining predictions adaptively...


Computing confidence scores: 100%|██████████| 157/157 [00:03<00:00, 51.39it/s]


Adaptive submission file created: submission_adaptive.csv
Generated 3 prediction files for ensemble submission
Creating ensemble from 3 prediction files...
Ensemble submission file created: ensemble_submission.csv
Starting adaptive prediction process...


Successfully loaded model with enhanced architecture

Step 1: Running enhanced TTA prediction...
Starting enhanced TTA prediction...


Processing augmentations: 100%|██████████| 8/8 [00:39<00:00,  4.98s/it]


Enhanced TTA submission file created: tmp_tta.csv

Step 2: Running class-specialized prediction...
Starting class-specialized prediction...


Initial prediction pass: 100%|██████████| 157/157 [00:03<00:00, 41.95it/s]


Found 6818 low confidence predictions (68.18%)


Processing difficult cases: 100%|██████████| 8/8 [00:30<00:00,  3.79s/it]


Class-specialized submission file created: tmp_spec.csv

Step 3: Combining predictions adaptively...


Computing confidence scores: 100%|██████████| 157/157 [00:03<00:00, 51.17it/s]


Adaptive submission file created: final_submission.csv

Analyzing prediction differences:
Agreement rate between prediction files: 97.16%

Disagreements by class:
airplane: 171 disagreements
automobile: 123 disagreements
bird: 177 disagreements
cat: 252 disagreements
deer: 97 disagreements
dog: 159 disagreements
frog: 94 disagreements
horse: 98 disagreements
ship: 110 disagreements
truck: 139 disagreements

Recommended submission file: final_submission.csv
Alternative submission file: ensemble_submission.csv
