# Approach 1: Transfer Learning Baseline for Multi-Label Aerial Image Classification

This notebook implements a baseline multi-label classification model using transfer learning with pre-trained CNNs.

## Strategy:
- Use a pre-trained ResNet backbone (ResNet-34/50/101) and fine-tune it end to end; EfficientNet and other backbones are left as next steps
- Replace final layer with multi-label classification head
- Train with Binary Cross-Entropy loss
- Evaluate with multi-label metrics (micro/macro F1, mAP)

## Expected Performance: 70-80% F1 score

## 1. Setup and Imports

In [1]:
# Install required packages
!pip install torch torchvision transformers datasets pillow matplotlib scikit-learn tqdm



In [3]:
import json
import random
import warnings

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from PIL import Image
from sklearn.metrics import (
    average_precision_score,
    classification_report,
    f1_score,
    hamming_loss,
    precision_score,
    recall_score,
)
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from tqdm import tqdm
# Silence library warnings.
# NOTE(review): this global filter also hides deprecation warnings and sklearn's
# UndefinedMetricWarning; narrow it if those messages matter.
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility.
# One SEED constant drives every RNG the notebook touches directly: Python's
# `random`, NumPy's global RNG (np.random.choice in the prediction visualizer)
# and PyTorch (torch.manual_seed seeds the CPU and all CUDA devices).
# NOTE: cuDNN kernels can still be nondeterministic on GPU.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Check for GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

Using device: cpu


## 2. Load and Explore Dataset

In [2]:
# Fetch the AID multi-label aerial-image dataset from the HuggingFace Hub.
# Only its 'train' split is used; it is partitioned into train/val/test below.
print("Loading AID_MultiLabel dataset...")
dataset = load_dataset("jonathan-roberts1/AID_MultiLabel")
train_split = dataset['train']

print(f"\nDataset Info:")
print(f"Number of samples: {train_split.num_rows}")
print(f"Features: {train_split.features}")

# 'label' is a sequence feature whose inner feature carries the class-name
# vocabulary (presumably a ClassLabel); a label index i means class_names[i].
label_feature = train_split.features['label']
class_names = label_feature.feature.names
num_classes = len(class_names)

print(f"\nNumber of classes: {num_classes}")
print(f"Class names: {class_names}")

Loading AID_MultiLabel dataset...


NameError: name 'load_dataset' is not defined

In [None]:
# Analyze label distribution
# Each entry of `all_labels` is one image's list of class indices.
all_labels = dataset['train']['label']

# Number of labels attached to each image
labels_per_image = [len(label_list) for label_list in all_labels]

# Occurrences of each class over the whole dataset. bincount with minlength
# gives exactly one integer slot per class, including classes never seen.
flat_labels = np.fromiter(
    (cls_idx for label_list in all_labels for cls_idx in label_list),
    dtype=np.int64,
)
class_counts = np.bincount(flat_labels, minlength=num_classes)

# Visualize class distribution
fig, axes = plt.subplots(1, 2, figsize=(16, 5))

# Class frequency (ticks show class names rather than bare indices)
axes[0].bar(range(num_classes), class_counts)
axes[0].set_xlabel('Class')
axes[0].set_ylabel('Frequency')
axes[0].set_title('Class Distribution in Dataset')
axes[0].set_xticks(range(num_classes))
axes[0].set_xticklabels(class_names, rotation=45, ha='right')
axes[0].grid(axis='y', alpha=0.3)

# Labels per image: one bar per integer count, centred on its tick. Bins start
# at the observed minimum so images with 0 labels (if any) are not dropped.
min_labels, max_labels = min(labels_per_image), max(labels_per_image)
axes[1].hist(
    labels_per_image,
    bins=np.arange(min_labels - 0.5, max_labels + 1.5, 1.0),
    edgecolor='black',
)
axes[1].set_xticks(range(min_labels, max_labels + 1))
axes[1].set_xlabel('Number of Labels per Image')
axes[1].set_ylabel('Frequency')
axes[1].set_title('Distribution of Labels per Image')
axes[1].grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nDataset Statistics:")
print(f"Average labels per image: {np.mean(labels_per_image):.2f}")
print(f"Min labels per image: {np.min(labels_per_image)}")
print(f"Max labels per image: {np.max(labels_per_image)}")
print(f"\nMost frequent classes:")
for i in np.argsort(class_counts)[-5:][::-1]:
    print(f"  {class_names[i]}: {int(class_counts[i])} ({class_counts[i]/len(all_labels)*100:.1f}%)")
print(f"\nLeast frequent classes:")
for i in np.argsort(class_counts)[:5]:
    print(f"  {class_names[i]}: {int(class_counts[i])} ({class_counts[i]/len(all_labels)*100:.1f}%)")

## 3. Data Preprocessing and Augmentation

In [None]:
# Define image transformations
# Aerial (top-down) imagery has no canonical "up": the 8 symmetries of the
# square (right-angle rotations combined with flips) keep every label valid.
# Rotating only by exact multiples of 90 degrees avoids the zero-filled black
# corners that arbitrary-angle rotation introduces and that never appear at
# evaluation time.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training: aggressive augmentation for aerial images (rotation-invariant)
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    # Pick one of 0/90/180/270 degrees; a (a, a) range makes RandomRotation exact.
    transforms.RandomChoice(
        [transforms.RandomRotation((angle, angle)) for angle in (0, 90, 180, 270)]
    ),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),  # ImageNet stats
])

# Validation/Test: no augmentation, only resize and normalize.
# NOTE(review): training crops 224px out of a 256px resize while evaluation
# resizes straight to 224px, so objects appear ~12% smaller at eval time;
# consider Resize(256) + CenterCrop(224) if that mismatch hurts.
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

print("Transformations defined.")
print(f"Training transform: Aggressive augmentation with rotation, flips, color jitter")
print(f"Validation transform: Resize and normalize only")

In [None]:
# Custom Dataset class
class AIDMultiLabelDataset(Dataset):
    """Map-style dataset yielding ``(image, multi_hot_label)`` pairs.

    Args:
        images: Sequence of images. PIL images are used as-is; anything else is
            passed to ``Image.fromarray`` (so presumably a numpy array).
        labels: Sequence whose i-th item is the class index / indices of image i.
        num_classes: Length of the float32 multi-hot target vector.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, images, labels, num_classes, transform=None):
        self.images = images
        self.labels = labels
        self.num_classes = num_classes
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def _load_rgb(self, idx):
        """Return image ``idx`` as an RGB PIL image."""
        img = self.images[idx]
        img = img if isinstance(img, Image.Image) else Image.fromarray(img)
        return img if img.mode == 'RGB' else img.convert('RGB')

    def _encode_target(self, idx):
        """Return the float32 multi-hot vector for sample ``idx``."""
        target = torch.zeros(self.num_classes, dtype=torch.float32)
        target[self.labels[idx]] = 1.0
        return target

    def __getitem__(self, idx):
        sample = self._load_rgb(idx)
        if self.transform:
            sample = self.transform(sample)
        return sample, self._encode_target(idx)

print("Custom Dataset class defined.")

In [None]:
# Split dataset: 70% train, 15% validation, 15% test
# Two stages: hold out 30% of the samples, then halve that pool into
# validation and test. random_state pins both shuffles.
# NOTE(review): the split is not stratified, so rare classes may be thin in
# val/test. Column access presumably materializes every decoded image in memory.
images = dataset['train']['image']
labels = dataset['train']['label']

X_train, X_temp, y_train, y_temp = train_test_split(
    images, labels, test_size=0.3, random_state=42, shuffle=True
)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=42, shuffle=True
)

n_total = len(images)
print(f"Dataset split:")
for split_name, split_images in (("Training", X_train), ("Validation", X_val), ("Test", X_test)):
    print(f"  {split_name} samples: {len(split_images)} ({len(split_images)/n_total*100:.1f}%)")

# Augmentation only for training; validation and test share the plain transform.
train_dataset, val_dataset, test_dataset = (
    AIDMultiLabelDataset(split_x, split_y, num_classes, transform=split_tf)
    for split_x, split_y, split_tf in (
        (X_train, y_train, train_transform),
        (X_val, y_val, val_transform),
        (X_test, y_test, val_transform),
    )
)

print("\nDataset objects created successfully.")

In [None]:
# Create data loaders
BATCH_SIZE = 32  # Adjust based on GPU memory

# Settings shared by all three loaders. Pinned host memory only speeds up
# host-to-GPU copies, so it is enabled only when CUDA is present.
loader_settings = dict(
    batch_size=BATCH_SIZE,
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_settings)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_settings)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_settings)

print(f"Data loaders created:")
print(f"  Batch size: {BATCH_SIZE}")
for loader_name, loader in (("Train", train_loader), ("Validation", val_loader), ("Test", test_loader)):
    print(f"  {loader_name} batches: {len(loader)}")

## 4. Model Architecture

In [None]:
class MultiLabelCNN(nn.Module):
    """
    Multi-label classification model using pre-trained CNN backbone.

    A torchvision ResNet has its final fully-connected layer replaced by
    ``nn.Identity`` so it emits pooled features. A two-layer MLP head then maps
    those features to one raw logit per class. ``forward`` applies no sigmoid:
    train with ``nn.BCEWithLogitsLoss`` and apply ``torch.sigmoid`` at inference.

    Args:
        num_classes: Number of output logits (one per label).
        backbone: One of ``SUPPORTED_BACKBONES``.
        pretrained: If True, load torchvision's ImageNet-1k V1 weights (the same
            checkpoint the legacy ``pretrained=True`` flag selected).
        dropout: Dropout probability before each linear layer of the head.

    Raises:
        ValueError: If ``backbone`` is not in ``SUPPORTED_BACKBONES``.
    """

    # torchvision constructor names accepted for `backbone`; each is a ResNet
    # whose ImageNet classifier lives in `.fc`.
    SUPPORTED_BACKBONES = ('resnet34', 'resnet50', 'resnet101')

    def __init__(self, num_classes, backbone='resnet50', pretrained=True, dropout=0.5):
        super().__init__()

        self.num_classes = num_classes
        self.backbone_name = backbone

        if backbone not in self.SUPPORTED_BACKBONES:
            raise ValueError(f"Unsupported backbone: {backbone}")

        # Load pre-trained backbone through the `weights=` API; the boolean
        # `pretrained=` argument has been deprecated since torchvision 0.13.
        build_backbone = getattr(models, backbone)
        self.backbone = build_backbone(weights="IMAGENET1K_V1" if pretrained else None)
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()  # Remove original FC layer, keep pooled features

        # Multi-label classification head
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, num_classes)
        )

        print(f"Model initialized:")
        print(f"  Backbone: {backbone} (pretrained={pretrained})")
        print(f"  Feature dimension: {num_features}")
        print(f"  Output classes: {num_classes}")
        print(f"  Dropout: {dropout}")

    def forward(self, x):
        """Return raw per-class logits for the input batch ``x``."""
        features = self.backbone(x)
        return self.classifier(features)

    def get_num_params(self):
        """Total number of parameters (trainable or not)."""
        return sum(p.numel() for p in self.parameters())

    def get_num_trainable_params(self):
        """Number of parameters with ``requires_grad=True``."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

print("Model class defined.")

In [None]:
# Initialize model
MODEL_BACKBONE = 'resnet50'  # Options: 'resnet34', 'resnet50', 'resnet101'
DROPOUT_RATE = 0.5

# Build the network, then move every parameter and buffer to the chosen device
# (nn.Module.to returns the module itself).
model = MultiLabelCNN(
    num_classes=num_classes,
    backbone=MODEL_BACKBONE,
    pretrained=True,
    dropout=DROPOUT_RATE,
).to(device)

print(f"\nModel moved to {device}")
print(f"Total parameters: {model.get_num_params():,}")
print(f"Trainable parameters: {model.get_num_trainable_params():,}")

## 5. Loss Function and Optimizer

In [None]:
# Multi-label objective: an independent sigmoid + binary cross-entropy per class.
# BCEWithLogitsLoss fuses both steps, which is numerically more stable than a
# separate Sigmoid followed by BCELoss, so the model must emit raw logits.
criterion = nn.BCEWithLogitsLoss()

for description_line in (
    "Loss function: BCEWithLogitsLoss",
    "  - Combines Sigmoid + BCE for numerical stability",
    "  - Suitable for multi-label classification",
    "  - Treats each label independently",
):
    print(description_line)

In [None]:
# Optimizer and Learning Rate Scheduler
LEARNING_RATE = 0.001   # LR for the randomly initialised classification head
BACKBONE_LR_MULT = 0.1  # backbone LR = LEARNING_RATE * BACKBONE_LR_MULT; 1.0 restores one shared LR
WEIGHT_DECAY = 1e-4

# The ImageNet-pretrained backbone only needs gentle adaptation, while the new
# head must learn from scratch. Giving Adam 1e-3 for both risks overwriting the
# pretrained features early on. The head group is listed first so that
# optimizer.param_groups[0]['lr'] still reports LEARNING_RATE.
optimizer = optim.Adam(
    [
        {'params': model.classifier.parameters(), 'lr': LEARNING_RATE},
        {'params': model.backbone.parameters(), 'lr': LEARNING_RATE * BACKBONE_LR_MULT},
    ],
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY
)

# Learning rate scheduler: halve every group's LR when validation loss has not
# improved for 3 epochs. `verbose=` is omitted: it is deprecated in recent
# PyTorch releases, and the training loop already prints the current LR.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3
)

print(f"\nOptimizer: Adam")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Backbone learning rate: {LEARNING_RATE * BACKBONE_LR_MULT:g}")
print(f"  Weight decay: {WEIGHT_DECAY}")
print(f"\nScheduler: ReduceLROnPlateau")
print(f"  Factor: 0.5")
print(f"  Patience: 3 epochs")

## 6. Training and Evaluation Functions

In [None]:
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """
    Train model for one epoch.

    Args:
        model: Module returning logits for a batch of inputs.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss on ``(logits, targets)``, assumed to use mean reduction
            over the batch (e.g. ``nn.BCEWithLogitsLoss()``).
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each batch is moved to before the forward pass.

    Returns:
        float: Mean per-sample training loss for the epoch. Each batch's mean
        loss is weighted by its batch size, so a short final batch does not
        count as much as a full one.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    num_samples = 0

    progress_bar = tqdm(dataloader, desc='Training')

    for images, labels in progress_bar:
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Undo the per-batch mean so batches are weighted by their size.
        batch_loss = loss.item()
        batch_size = images.size(0)
        running_loss += batch_loss * batch_size
        num_samples += batch_size
        progress_bar.set_postfix({'loss': batch_loss})

    if num_samples == 0:
        raise ValueError("train_one_epoch received an empty dataloader")
    return running_loss / num_samples

In [None]:
def evaluate(model, dataloader, criterion, device, threshold=0.5):
    """
    Evaluate model on validation/test set.

    Runs ``model`` in eval mode without gradient tracking over every batch and
    collects sigmoid probabilities and thresholded predictions.

    Args:
        model: Module returning one logit per class for each input.
        dataloader: Iterable of ``(inputs, targets)`` batches with multi-hot targets.
        criterion: Loss on ``(logits, targets)``, assumed to use mean reduction.
        device: Device the batches are moved to.
        threshold: A class is predicted positive when its probability is
            strictly greater than this value.

    Returns:
        tuple: ``(epoch_loss, all_preds, all_labels, all_probs)``.
        ``epoch_loss`` is the per-sample mean loss, with each batch weighted by
        its size. The three arrays are stacked row-wise in dataloader order:
        ``all_preds`` holds float32 0/1 predictions, ``all_labels`` the targets
        as yielded, and ``all_probs`` the sigmoid outputs.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    num_samples = 0
    prob_batches = []
    label_batches = []

    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc='Evaluating'):
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Weight each batch's mean loss by its size.
            batch_size = images.size(0)
            running_loss += loss.item() * batch_size
            num_samples += batch_size

            # Convert logits to probabilities
            prob_batches.append(torch.sigmoid(outputs).cpu().numpy())
            label_batches.append(labels.cpu().numpy())

    if num_samples == 0:
        raise ValueError("evaluate received an empty dataloader")

    all_probs = np.concatenate(prob_batches)
    all_labels = np.concatenate(label_batches)
    all_preds = (all_probs > threshold).astype(np.float32)  # Apply threshold
    epoch_loss = running_loss / num_samples

    return epoch_loss, all_preds, all_labels, all_probs

In [None]:
def calculate_metrics(y_true, y_pred, class_names):
    """
    Compute multi-label classification metrics from 0/1 indicator matrices.

    Args:
        y_true: (n_samples, n_classes) ground-truth indicator array.
        y_pred: (n_samples, n_classes) predicted indicator array.
        class_names: Accepted for interface compatibility; not used.

    Returns:
        dict: Scalar metrics ('micro_f1', 'macro_f1', 'weighted_f1',
        'micro_precision', 'macro_precision', 'micro_recall', 'macro_recall',
        'hamming_loss', 'subset_accuracy') followed by per-class arrays
        ('per_class_f1', 'per_class_precision', 'per_class_recall').
        Undefined ratios are reported as 0 (zero_division=0).
    """
    metrics = {}

    # Aggregate scores; tuple order fixes the key order of the returned dict.
    for key, scorer, average in (
        ('micro_f1', f1_score, 'micro'),
        ('macro_f1', f1_score, 'macro'),
        ('weighted_f1', f1_score, 'weighted'),
        ('micro_precision', precision_score, 'micro'),
        ('macro_precision', precision_score, 'macro'),
        ('micro_recall', recall_score, 'micro'),
        ('macro_recall', recall_score, 'macro'),
    ):
        metrics[key] = scorer(y_true, y_pred, average=average, zero_division=0)

    metrics['hamming_loss'] = hamming_loss(y_true, y_pred)
    # Subset accuracy: share of samples whose entire label vector is exactly right.
    metrics['subset_accuracy'] = np.mean(np.all(y_true == y_pred, axis=1))

    # Per-class scores (average=None returns one value per class).
    for key, scorer in (
        ('per_class_f1', f1_score),
        ('per_class_precision', precision_score),
        ('per_class_recall', recall_score),
    ):
        metrics[key] = scorer(y_true, y_pred, average=None, zero_division=0)

    return metrics

def print_metrics(metrics):
    """
    Print the scalar entries of a ``calculate_metrics`` result as a report.
    """
    rule = "=" * 60
    # (line prefix, metrics key); a leading "\n" starts a new visual group.
    report_rows = (
        ("  Micro F1:      ", 'micro_f1'),
        ("  Macro F1:      ", 'macro_f1'),
        ("  Weighted F1:   ", 'weighted_f1'),
        ("\n  Micro Precision: ", 'micro_precision'),
        ("  Macro Precision: ", 'macro_precision'),
        ("\n  Micro Recall:    ", 'micro_recall'),
        ("  Macro Recall:    ", 'macro_recall'),
        ("\n  Hamming Loss:    ", 'hamming_loss'),
        ("  Subset Accuracy: ", 'subset_accuracy'),
    )

    print("\n" + rule)
    print("EVALUATION METRICS")
    print(rule)
    print(f"\nOverall Performance:")
    for prefix, key in report_rows:
        print(f"{prefix}{metrics[key]:.4f}")
    print(rule)

print("Training and evaluation functions defined.")

## 7. Training Loop

In [None]:
# Training configuration
NUM_EPOCHS = 25
EARLY_STOP_PATIENCE = 7

# Per-epoch metric curves, appended to by the training loop (one list per key)
history = {key: [] for key in ('train_loss', 'val_loss', 'val_micro_f1', 'val_macro_f1')}

# Checkpoint / early-stopping state
best_model_path = 'best_model_baseline.pth'
best_val_f1 = 0.0
epochs_no_improve = 0

print(f"Starting training for {NUM_EPOCHS} epochs...")
print(f"Early stopping patience: {EARLY_STOP_PATIENCE}")
print(f"Best model will be saved to: {best_model_path}\n")

In [None]:
# Main training loop
# Per epoch: train, validate at threshold 0.5, record `history`, step the
# plateau scheduler on validation loss, checkpoint on a new best validation
# micro-F1, and stop once it has not improved for EARLY_STOP_PATIENCE epochs.
for epoch in range(NUM_EPOCHS):
    print(f"\n{'='*60}")
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
    print(f"{'='*60}")
    
    # Train
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
    
    # Validate
    val_loss, val_preds, val_labels, val_probs = evaluate(
        model, val_loader, criterion, device, threshold=0.5
    )
    
    # Calculate metrics
    val_metrics = calculate_metrics(val_labels, val_preds, class_names)
    
    # Store history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['val_micro_f1'].append(val_metrics['micro_f1'])
    history['val_macro_f1'].append(val_metrics['macro_f1'])
    
    # Print epoch summary
    print(f"\nEpoch {epoch+1} Summary:")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss:   {val_loss:.4f}")
    print(f"  Val Micro F1: {val_metrics['micro_f1']:.4f}")
    print(f"  Val Macro F1: {val_metrics['macro_f1']:.4f}")
    
    # Learning rate scheduling (driven by validation loss); report the LR of
    # the first parameter group.
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']
    print(f"  Current LR: {current_lr:.6f}")
    
    # Save best model (selected on validation micro-F1)
    if val_metrics['micro_f1'] > best_val_f1:
        best_val_f1 = val_metrics['micro_f1']
        # Store plain Python floats/lists instead of NumPy scalars/arrays so the
        # checkpoint also loads under torch.load(weights_only=True), which is
        # the default since PyTorch 2.6.
        checkpoint_metrics = {
            name: value.tolist() if isinstance(value, np.ndarray) else float(value)
            for name, value in val_metrics.items()
        }
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_f1': float(best_val_f1),
            'metrics': checkpoint_metrics
        }, best_model_path)
        print(f"  ✓ New best model saved! (Micro F1: {best_val_f1:.4f})")
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        print(f"  No improvement for {epochs_no_improve} epoch(s)")
    
    # Early stopping
    if epochs_no_improve >= EARLY_STOP_PATIENCE:
        print(f"\nEarly stopping triggered after {epoch+1} epochs.")
        break

print(f"\n{'='*60}")
print("Training completed!")
print(f"Best validation Micro F1: {best_val_f1:.4f}")
print(f"{'='*60}")

## 8. Visualize Training Progress

In [None]:
# Plot training history
# The x-axis uses 1-based epoch numbers so it matches the "Epoch k/N" log lines
# (plotting against the raw list index would start at epoch 0).
epochs_ran = range(1, len(history['train_loss']) + 1)

fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Loss curves
axes[0].plot(epochs_ran, history['train_loss'], label='Train Loss', marker='o')
axes[0].plot(epochs_ran, history['val_loss'], label='Val Loss', marker='s')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Validation Loss')
axes[0].legend()
axes[0].grid(alpha=0.3)

# F1 scores
axes[1].plot(epochs_ran, history['val_micro_f1'], label='Micro F1', marker='o')
axes[1].plot(epochs_ran, history['val_macro_f1'], label='Macro F1', marker='s')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('F1 Score')
axes[1].set_title('Validation F1 Scores')
axes[1].legend()
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.savefig('training_history.png', dpi=150, bbox_inches='tight')
plt.show()

print("Training history visualized and saved to 'training_history.png'")

## 9. Test Set Evaluation

In [None]:
# Load best model
print("Loading best model for testing...")
# map_location lets a checkpoint written on GPU load on a CPU-only machine.
# weights_only=False: this notebook wrote the file itself (trusted), and the
# checkpoint may contain NumPy metric objects that the weights-only unpickler
# (the default since PyTorch 2.6) rejects. Never do this for untrusted files.
checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Best model from epoch {checkpoint['epoch']+1} loaded.\n")

# Evaluate on test set
test_loss, test_preds, test_labels, test_probs = evaluate(
    model, test_loader, criterion, device, threshold=0.5
)

# Calculate metrics
test_metrics = calculate_metrics(test_labels, test_preds, class_names)

# Threshold-independent ranking metric: macro mean average precision over the
# classes that have at least one positive test sample (AP is undefined otherwise).
classes_with_positives = test_labels.sum(axis=0) > 0
if classes_with_positives.any():
    test_metrics['mAP'] = float(average_precision_score(
        test_labels[:, classes_with_positives],
        test_probs[:, classes_with_positives],
        average='macro',
    ))
else:
    test_metrics['mAP'] = float('nan')

# Print results
print_metrics(test_metrics)
print(f"  mAP (macro over {int(classes_with_positives.sum())} classes with test positives): "
      f"{test_metrics['mAP']:.4f}")

In [None]:
# Per-class performance visualization: grouped precision / recall / F1 bars
# for every class, using the threshold-0.5 test metrics.
per_class_data = {
    'Class': class_names,
    'F1': test_metrics['per_class_f1'],
    'Precision': test_metrics['per_class_precision'],
    'Recall': test_metrics['per_class_recall'],
}

fig, ax = plt.subplots(figsize=(14, 8))

x = np.arange(len(class_names))
width = 0.25

# (offset within the group, data key, legend label), drawn left to right
for bar_offset, data_key, legend_label in (
    (-width, 'Precision', 'Precision'),
    (0.0, 'Recall', 'Recall'),
    (width, 'F1', 'F1 Score'),
):
    ax.bar(x + bar_offset, per_class_data[data_key], width, label=legend_label, alpha=0.8)

ax.set_title('Per-Class Performance on Test Set', fontsize=14, fontweight='bold')
ax.set_xlabel('Class', fontsize=12)
ax.set_ylabel('Score', fontsize=12)
ax.set_xticks(x)
ax.set_xticklabels(class_names, rotation=45, ha='right')
ax.set_ylim([0, 1.0])
ax.grid(axis='y', alpha=0.3)
ax.legend()

plt.tight_layout()
plt.savefig('per_class_performance.png', dpi=150, bbox_inches='tight')
plt.show()

print("Per-class performance visualized and saved to 'per_class_performance.png'")

In [None]:
# Detailed per-class results table; support = number of positive test samples.
table_rule = "=" * 80
print("\nDetailed Per-Class Results:")
print(table_rule)
print(f"{'Class':<20} {'Precision':>12} {'Recall':>12} {'F1 Score':>12} {'Support':>10}")
print(table_rule)

per_class_rows = zip(
    class_names,
    test_metrics['per_class_precision'],
    test_metrics['per_class_recall'],
    test_metrics['per_class_f1'],
    test_labels.sum(axis=0),
)
for class_name, cls_precision, cls_recall, cls_f1, cls_support in per_class_rows:
    print(f"{class_name:<20} {cls_precision:>12.4f} {cls_recall:>12.4f} "
          f"{cls_f1:>12.4f} {int(cls_support):>10}")

print(table_rule)

# Identify best and worst performing classes by test F1
per_class_f1 = test_metrics['per_class_f1']
best_idx = np.argmax(per_class_f1)
worst_idx = np.argmin(per_class_f1)

print(f"\nBest performing class: {class_names[best_idx]} (F1: {per_class_f1[best_idx]:.4f})")
print(f"Worst performing class: {class_names[worst_idx]} (F1: {per_class_f1[worst_idx]:.4f})")

## 10. Prediction Visualization

In [None]:
def visualize_predictions(model, dataset, device, num_samples=8, threshold=0.5, ncols=4):
    """
    Visualize model predictions on randomly chosen samples from ``dataset``.

    Each panel shows the de-normalized image with its true and predicted label
    names. The panel frame is green when the predicted label set exactly
    matches the true set, red otherwise. The figure is saved to
    'prediction_samples.png' and shown.

    Args:
        model: Multi-label model returning one logit per class.
        dataset: Dataset yielding ``(normalized CHW image tensor, multi-hot label)``.
        device: Device used for inference.
        num_samples: Number of samples to draw (capped at ``len(dataset)``).
        threshold: Probability cut-off for a positive prediction.
        ncols: Maximum number of panels per row.

    Note: label names come from the notebook-level ``class_names``.
    """
    model.eval()

    # Never ask np.random.choice for more distinct indices than exist.
    num_samples = min(num_samples, len(dataset))
    if num_samples == 0:
        print("No samples to visualize.")
        return

    # Randomly select samples
    indices = np.random.choice(len(dataset), num_samples, replace=False)

    # Size the grid to the request (5x5 inches per panel) instead of a fixed 2x4.
    ncols = max(1, min(ncols, num_samples))
    nrows = (num_samples + ncols - 1) // ncols
    fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 5 * nrows), squeeze=False)
    axes = axes.flatten()

    # ImageNet statistics used by the Normalize transform, for de-normalizing.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    for idx, ax in zip(indices, axes):
        image, label = dataset[idx]

        # Predict
        with torch.no_grad():
            output = model(image.unsqueeze(0).to(device))
            probs = torch.sigmoid(output).cpu().numpy()[0]
        pred = (probs > threshold).astype(int)

        # CHW -> HWC, undo normalization, clip into the displayable [0, 1] range
        img_display = np.clip(std * image.cpu().numpy().transpose(1, 2, 0) + mean, 0, 1)

        true_labels = [name for name, flag in zip(class_names, label.tolist()) if flag == 1]
        pred_labels = [name for name, flag in zip(class_names, pred.tolist()) if flag == 1]

        ax.imshow(img_display)
        # Hide ticks but keep the spines: ax.axis('off') would also hide the
        # spines and make the coloured correctness border below invisible.
        ax.set_xticks([])
        ax.set_yticks([])

        title = f"True: {', '.join(true_labels)}\nPred: {', '.join(pred_labels)}"
        ax.set_title(title, fontsize=8, wrap=True)

        # Add colored border (green=correct, red=incorrect)
        border_color = 'green' if set(true_labels) == set(pred_labels) else 'red'
        for spine in ax.spines.values():
            spine.set_edgecolor(border_color)
            spine.set_linewidth(3)

    # Blank out grid cells left over when the last row is not full.
    for unused_ax in axes[num_samples:]:
        unused_ax.axis('off')

    plt.tight_layout()
    plt.savefig('prediction_samples.png', dpi=150, bbox_inches='tight')
    plt.show()

# Visualize predictions
visualize_predictions(model, test_dataset, device, num_samples=8, threshold=0.5)
print("Prediction samples visualized and saved to 'prediction_samples.png'")
print("Green border = Correct prediction, Red border = Incorrect prediction")

## 11. Threshold Optimization (Optional)

In [None]:
# Find optimal threshold on validation set
print("Searching for optimal threshold...\n")

DEFAULT_THRESHOLD = 0.5

# Recompute validation probabilities with the currently loaded weights (the
# best checkpoint). The val_probs left over from the training loop come from
# the final epoch, which is not necessarily the best one.
_, _, val_labels, val_probs = evaluate(
    model, val_loader, criterion, device, threshold=DEFAULT_THRESHOLD
)

def micro_f1_at(threshold):
    """Validation micro-F1 when every class uses the same probability cut-off."""
    preds = (val_probs > threshold).astype(int)
    return f1_score(val_labels, preds, average='micro', zero_division=0)

# Candidate thresholds 0.10, 0.15, ..., 0.85. Rounding removes np.arange float
# drift (e.g. 0.15000000000000002) from the reported and saved threshold.
thresholds = np.round(np.arange(0.1, 0.9, 0.05), 2)
f1_scores = [micro_f1_at(thresh) for thresh in thresholds]

# Find best threshold
best_thresh_idx = int(np.argmax(f1_scores))
best_threshold = thresholds[best_thresh_idx]
best_f1 = f1_scores[best_thresh_idx]
# Score the default threshold directly instead of assuming its index in `thresholds`.
default_f1 = micro_f1_at(DEFAULT_THRESHOLD)

print(f"Optimal threshold: {best_threshold:.2f}")
print(f"Micro F1 at optimal threshold: {best_f1:.4f}")
print(f"Micro F1 at default threshold (0.5): {default_f1:.4f}")
print(f"Improvement: {best_f1 - default_f1:.4f}")

# Plot threshold analysis
plt.figure(figsize=(10, 6))
plt.plot(thresholds, f1_scores, marker='o', linewidth=2)
plt.axvline(best_threshold, color='r', linestyle='--', label=f'Optimal: {best_threshold:.2f}')
plt.axvline(DEFAULT_THRESHOLD, color='g', linestyle='--', label=f'Default: {DEFAULT_THRESHOLD:.2f}')
plt.xlabel('Threshold', fontsize=12)
plt.ylabel('Micro F1 Score', fontsize=12)
plt.title('Threshold Optimization on Validation Set', fontsize=14, fontweight='bold')
plt.legend()
plt.grid(alpha=0.3)
plt.tight_layout()
plt.savefig('threshold_optimization.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Re-evaluate on test set with optimal threshold
print(f"\nRe-evaluating on test set with optimal threshold ({best_threshold:.2f})...\n")

# The cached test probabilities are reused; only the decision threshold changes.
test_preds_optimized = (test_probs > best_threshold).astype(int)
test_metrics_optimized = calculate_metrics(test_labels, test_preds_optimized, class_names)

comparison_rule = "=" * 70
comparison_rows = (
    ('Micro F1', 'micro_f1'),
    ('Macro F1', 'macro_f1'),
    ('Micro Precision', 'micro_precision'),
    ('Micro Recall', 'micro_recall'),
    ('Hamming Loss', 'hamming_loss'),
)

print("\nComparison: Default (0.5) vs Optimized Threshold")
print(comparison_rule)
print(f"{'Metric':<25} {'Default (0.5)':>20} {'Optimized':>20}")
print(comparison_rule)
for metric_label, metric_key in comparison_rows:
    print(f"{metric_label:<25} {test_metrics[metric_key]:>20.4f} {test_metrics_optimized[metric_key]:>20.4f}")
print(comparison_rule)

## 12. Summary and Next Steps

In [None]:
print("\n" + "="*70)
print("BASELINE MODEL SUMMARY")
print("="*70)
print(f"\nModel Architecture: {MODEL_BACKBONE}")
print(f"Total Parameters: {model.get_num_params():,}")
print(f"Training Samples: {len(train_dataset)}")
print(f"Validation Samples: {len(val_dataset)}")
print(f"Test Samples: {len(test_dataset)}")
print(f"\nBest Validation Micro F1: {best_val_f1:.4f}")
print(f"Test Micro F1 (threshold=0.5): {test_metrics['micro_f1']:.4f}")
print(f"Test Micro F1 (optimized threshold={best_threshold:.2f}): {test_metrics_optimized['micro_f1']:.4f}")
# Label both macro-F1 numbers with their threshold; a bare "Test Macro F1"
# would hide that the value comes from the tuned threshold.
print(f"Test Macro F1 (threshold=0.5): {test_metrics['macro_f1']:.4f}")
print(f"Test Macro F1 (optimized threshold={best_threshold:.2f}): {test_metrics_optimized['macro_f1']:.4f}")
print("\n" + "="*70)
print("NEXT STEPS FOR IMPROVEMENT:")
print("="*70)
next_steps = (
    "Try different backbones (EfficientNet, ResNet101, Vision Transformer)",
    "Implement class-weighted loss or Focal Loss to handle class imbalance",
    "Advanced data augmentation (Mixup, CutMix, AutoAugment)",
    "Add attention mechanisms (CBAM, SENet)",
    "Implement label correlation modeling with GNN",
    "Build ensemble of multiple models",
    "Add Grad-CAM for interpretability",
    "Per-class threshold optimization",
)
for step_number, step in enumerate(next_steps, start=1):
    print(f"{step_number}. {step}")
print("="*70)

## 13. Save Results

In [None]:
# Save all results to a dictionary
results = {
    'model_config': {
        'backbone': MODEL_BACKBONE,
        'num_classes': num_classes,
        'dropout': DROPOUT_RATE,
        'total_params': model.get_num_params(),
    },
    'training_config': {
        'batch_size': BATCH_SIZE,
        'learning_rate': LEARNING_RATE,
        'weight_decay': WEIGHT_DECAY,
        'num_epochs_trained': len(history['train_loss']),
        'early_stop_patience': EARLY_STOP_PATIENCE,
    },
    'dataset_split': {
        'train_size': len(train_dataset),
        'val_size': len(val_dataset),
        'test_size': len(test_dataset),
    },
    'test_metrics_default_threshold': test_metrics,
    'test_metrics_optimized_threshold': test_metrics_optimized,
    'optimal_threshold': float(best_threshold),
    'training_history': history,
    'class_names': class_names,
}


def to_jsonable(obj):
    """Recursively convert NumPy arrays/scalars and tuples into JSON-native types.

    Returns a new structure; the input objects are never modified.
    """
    if isinstance(obj, dict):
        return {key: to_jsonable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(value) for value in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    return obj


# Convert into a fresh structure instead of mutating test_metrics /
# test_metrics_optimized in place, and serialize before opening the file so a
# conversion error cannot leave a truncated JSON file behind.
results_json = json.dumps(to_jsonable(results), indent=2)
with open('baseline_results.json', 'w') as f:
    f.write(results_json)

print("Results saved to 'baseline_results.json'")
print("\nFiles generated:")
print("  - best_model_baseline.pth (model checkpoint)")
print("  - baseline_results.json (all metrics and config)")
print("  - training_history.png (loss and F1 curves)")
print("  - per_class_performance.png (per-class metrics)")
print("  - prediction_samples.png (sample predictions)")
print("  - threshold_optimization.png (threshold analysis)")