# Identity Document Classification using EfficientNet

This notebook demonstrates how to leverage transfer learning with EfficientNet-B0 for identity document classification.

EfficientNet is a family of convolutional neural networks that, when it was introduced in 2019, achieved state-of-the-art ImageNet accuracy with significantly fewer parameters than earlier models. It relies on compound scaling. Instead of enlarging depth, width, or input resolution independently, a single compound coefficient scales all three together by fixed ratios.
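Compound scaling can be made concrete with a little arithmetic. In the EfficientNet paper, depth, width, and resolution are multiplied by α^φ, β^φ, and γ^φ. The bases α = 1.2, β = 1.1, γ = 1.15 were found by a small grid search on B0. They are constrained so that α·β²·γ² ≈ 2, because FLOPs grow roughly with depth × width² × resolution², so each unit of φ about doubles compute. The sketch below only evaluates those formulas. The released B1-B7 checkpoints use rounded, hand-adjusted settings (B1, for instance, runs at 240px), so treat these numbers as illustrative.

```python
# Compound-scaling bases reported for EfficientNet (grid-searched on B0).
ALPHA, BETA, GAMMA = 1.2, 1.1, 1.15  # depth, width, resolution bases
BASE_RESOLUTION = 224  # EfficientNet-B0 input size in pixels


def compound_scale(phi: float) -> dict:
    """Return depth/width/resolution multipliers and the approximate FLOPs factor for phi."""
    depth = ALPHA ** phi
    width = BETA ** phi
    resolution = GAMMA ** phi
    # FLOPs of a conv net grow roughly with depth * width^2 * resolution^2
    flops = depth * width ** 2 * resolution ** 2
    return {
        "depth": depth,
        "width": width,
        "resolution_px": BASE_RESOLUTION * resolution,
        "flops": flops,
    }


for phi in range(4):
    s = compound_scale(phi)
    print(f"phi={phi}: depth x{s['depth']:.2f}, width x{s['width']:.2f}, "
          f"~{s['resolution_px']:.0f}px, FLOPs x{s['flops']:.2f}")
```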

## 1. Setup and Imports

In [None]:
import os
import sys
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.notebook import tqdm
import cv2
from datetime import datetime

# Add current directory to path to import local modules
sys.path.append('.')

# Import custom modules
from dataset import load_data, IDDocumentDataset, get_transforms
from models import create_efficient_net_b0
from train import train_model, evaluate_model, visualize_training_history, model_summary, visualize_misclassified_samples

# Set the random seed for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
# Check for GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
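Seeding `random`, NumPy, and PyTorch in the main process does not fully cover data loading. With `NUM_WORKERS = 4`, each DataLoader worker runs its own augmentation code. PyTorch's reproducibility notes recommend a `worker_init_fn` plus a seeded `torch.Generator` for this case. This notebook does not show whether `load_data` accepts such arguments, so the sketch below builds its own DataLoader over a toy `TensorDataset`. `seed_worker` and `make_loader` are hypothetical helper names.

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def seed_worker(worker_id: int) -> None:
    """Derive NumPy/random seeds inside each worker from PyTorch's per-worker seed."""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def make_loader(dataset, batch_size: int, seed: int, num_workers: int = 0) -> DataLoader:
    """Build a shuffling DataLoader whose sample order is fixed by `seed`."""
    generator = torch.Generator()
    generator.manual_seed(seed)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        worker_init_fn=seed_worker,  # only called when num_workers > 0
        generator=generator,         # drives the shuffling permutation
    )


# Toy dataset: 10 samples whose single "feature" is their index
toy = TensorDataset(torch.arange(10).float().unsqueeze(1))
first = torch.cat([batch[0] for batch in make_loader(toy, batch_size=4, seed=42)]).flatten().tolist()
second = torch.cat([batch[0] for batch in make_loader(toy, batch_size=4, seed=42)]).flatten().tolist()
print("Same order with the same seed:", first == second)
```

With `num_workers=0` the worker hook is never invoked, but the seeded generator alone makes the shuffle order repeatable.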

## 2. Configuration

Let's configure our training parameters. For EfficientNet, we'll use a modest learning rate and light weight decay. These are typical starting points when fine-tuning a pretrained network rather than training from scratch.

In [None]:
# Data parameters
DATA_DIR = '../../data/cropped_images'  # Path to dataset directory
IMG_SIZE = 224  # Input image size (EfficientNet-B0 was trained on 224px images)
VAL_SPLIT = 0.15  # Fraction of the data used for validation (0.15 = 15%)
TEST_SPLIT = 0.15  # Fraction of the data used for testing
BATCH_SIZE = 32  # Batch size for training - adjust based on your GPU memory
NUM_WORKERS = 4  # Number of workers for data loading

# Training parameters - tuned for transfer learning
EPOCHS = 30  # Maximum number of epochs to train for
LEARNING_RATE = 3e-4  # Learning rate for transfer learning
WEIGHT_DECAY = 1e-5  # Decoupled weight decay applied by AdamW (not classic L2 regularization)
PATIENCE = 7  # Early stopping patience
CHECKPOINT_DIR = Path('../../checkpoints/efficientnet')  # Directory to save model checkpoints
os.makedirs(CHECKPOINT_DIR, exist_ok=True)  # Create checkpoint directory if it doesn't exist

# Transfer learning settings
PRETRAINED = True  # Use pretrained weights
FREEZE_BACKBONE = False  # Whether to freeze the backbone layers
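`VAL_SPLIT` and `TEST_SPLIT` are fractions of the whole dataset. With 15% each, roughly 70% is left for training. This notebook does not show how `load_data` performs the split, for example whether it stratifies by class. The sketch below is a hypothetical, unstratified version that shuffles indices with a fixed seed and carves off the test and validation portions.

```python
import random


def split_indices(n_samples: int, val_split: float, test_split: float, seed: int = 42):
    """Shuffle range(n_samples) with a fixed seed and split it into train/val/test index lists."""
    if val_split + test_split >= 1:
        raise ValueError("val_split + test_split must be < 1")
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # local RNG, so the global seed state is untouched
    n_test = round(n_samples * test_split)
    n_val = round(n_samples * val_split)
    test = indices[:n_test]
    val = indices[n_test:n_test + n_val]
    train = indices[n_test + n_val:]
    return train, val, test


train_idx, val_idx, test_idx = split_indices(1000, val_split=0.15, test_split=0.15)
print(len(train_idx), len(val_idx), len(test_idx))
```

For class-imbalanced data, a stratified split (for example scikit-learn's `train_test_split(..., stratify=labels)`) keeps class proportions equal across the three subsets.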

## 3. Data Loading and Exploration

Let's load the dataset, split it into training, validation, and test sets, and check which classes it contains.

In [None]:
# Load the data
data = load_data(
    data_dir=DATA_DIR,
    img_size=IMG_SIZE,
    val_split=VAL_SPLIT,
    test_split=TEST_SPLIT,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    seed=SEED
)

# Get class names
class_names = data['class_names']
num_classes = data['num_classes']

print(f"Number of classes: {num_classes}")
print(f"Class names: {class_names}")
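Before training, it is worth checking how many samples each class has. On a heavily imbalanced dataset, plain accuracy can be misleading. The training dataset exposes its integer labels through a `labels` attribute, which the visualization code below relies on. The hypothetical helper `class_distribution` counts them and works on any list of integer labels. The class names in the toy call are placeholders.

```python
from collections import Counter


def class_distribution(labels, class_names):
    """Return {class_name: (count, fraction)} for a sequence of integer labels."""
    counts = Counter(labels)
    total = len(labels)
    return {
        name: (counts.get(idx, 0), counts.get(idx, 0) / total if total else 0.0)
        for idx, name in enumerate(class_names)
    }


# Toy example; in the notebook you could pass data['train_dataset'].labels and class_names
toy_labels = [0, 0, 0, 1, 2, 2]
for name, (count, frac) in class_distribution(toy_labels, ["passport", "id_card", "license"]).items():
    print(f"{name:>10}: {count:4d} ({frac:.1%})")
```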

## 4. Dataset Visualization

Let's visualize the original images from our dataset to understand what we're working with.

In [None]:
def visualize_original_samples(dataset, class_names, num_images=5, figsize=(15, 10)):
    """
    Visualize original samples from a dataset without any transformations.

    Args:
        dataset: Dataset object exposing `image_paths` and `labels` attributes
        class_names (list): List of class names
        num_images (int): Maximum number of images to display
        figsize (tuple): Figure size
    """
    # Never request more images than the dataset holds (np.random.choice would raise)
    num_images = min(num_images, len(dataset))
    if num_images == 0:
        return

    # Select random indices
    indices = np.random.choice(len(dataset), num_images, replace=False)

    # squeeze=False keeps `axes` a 2-D array even when num_images == 1
    fig, axes = plt.subplots(1, num_images, figsize=figsize, squeeze=False)

    for ax, idx in zip(axes[0], indices):
        # Get image path and label directly from dataset
        img_path = dataset.image_paths[idx]
        label = dataset.labels[idx]

        # Load original image without transformations (cv2 expects a str path, returns None on failure)
        image = cv2.imread(str(img_path))
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Display image
        ax.imshow(image)
        ax.set_title(f"Class: {class_names[label]}")
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Visualize original samples from each class
for i, class_name in enumerate(class_names):
    print(f"Class: {class_name}")
    # Collect the training-set indices that belong to this class
    class_indices = [j for j, label in enumerate(data['train_dataset'].labels) if label == i]
    if not class_indices:
        print("No training samples for this class.\n")
        continue

    # Choose a subset of indices
    subset_indices = np.random.choice(class_indices, min(5, len(class_indices)), replace=False)

    # Create a temporary dataset with just these samples
    temp_dataset = IDDocumentDataset(
        [data['train_dataset'].image_paths[j] for j in subset_indices],
        [data['train_dataset'].labels[j] for j in subset_indices],
        None,
        'train'
    )

    # Visualize
    visualize_original_samples(temp_dataset, class_names)
    print("\n")

## 5. Creating and Configuring the EfficientNet Model

Now we'll create our EfficientNet-B0 model with pretrained weights and configure it for our classification task.

In [None]:
# Create the EfficientNet model
model = create_efficient_net_b0(
    num_classes=num_classes,
    pretrained=PRETRAINED,
    freeze_backbone=FREEZE_BACKBONE
)
model.to(device)

# Print model summary
model_summary(model)

# Count trainable parameters
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable_params:,} ({trainable_params/total_params:.2%} of total)")

## 6. Setting Up Training Components

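Let's set up our loss function, optimizer, and learning rate scheduler. We use OneCycleLR, which warms the learning rate up from a small value to `max_lr` and then anneals it toward zero. The gradual warm-up avoids large updates in the first iterations, while the newly initialized classifier head is still random and could otherwise disturb the pretrained weights. OneCycleLR is designed to be stepped once per batch (hence `steps_per_epoch`), so this setup assumes `train_model` calls `scheduler.step()` after every optimizer step. If early stopping ends training before `EPOCHS`, the schedule will not complete its annealing phase.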

In [None]:
# Define loss function
criterion = nn.CrossEntropyLoss()

# Define optimizer - AdamW (Adam with decoupled weight decay) for regularization
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Define learning rate scheduler
# OneCycleLR: warm up to max_lr, then cosine-anneal; step it once per batch
scheduler = OneCycleLR(
    optimizer,
    max_lr=LEARNING_RATE,
    epochs=EPOCHS,
    steps_per_epoch=len(data['train_loader']),
    pct_start=0.3,  # First 30% of the total steps ramp the LR up
    div_factor=25,  # Initial LR = max_lr / 25
    final_div_factor=10000,  # Final LR = initial LR / 10000 = max_lr / 250000
    anneal_strategy='cos'  # Cosine annealing
)
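To see what these settings do, the sketch below attaches the same OneCycleLR configuration to a dummy one-parameter AdamW optimizer and records the learning rate at every step. `steps_per_epoch=10` is an arbitrary stand-in for `len(data['train_loader'])`. The recorded values start at `max_lr / 25`, peak at `max_lr` after about 30% of the steps, and end at `max_lr / 25 / 10000`.

```python
import torch
from torch.optim.lr_scheduler import OneCycleLR


def one_cycle_lrs(max_lr=3e-4, epochs=30, steps_per_epoch=10,
                  pct_start=0.3, div_factor=25, final_div_factor=10000):
    """Record the learning rate OneCycleLR assigns to each training step."""
    # A single dummy parameter is enough to drive the optimizer and scheduler
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.AdamW([param], lr=max_lr)
    scheduler = OneCycleLR(
        optimizer,
        max_lr=max_lr,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        pct_start=pct_start,
        div_factor=div_factor,
        final_div_factor=final_div_factor,
        anneal_strategy='cos',
    )
    # lrs[0] is the LR for the first batch; one step per remaining batch gives total_steps values
    lrs = [optimizer.param_groups[0]['lr']]
    for _ in range(epochs * steps_per_epoch - 1):
        optimizer.step()  # no gradients exist, so the parameter is left untouched
        scheduler.step()  # one scheduler step per batch, as OneCycleLR expects
        lrs.append(optimizer.param_groups[0]['lr'])
    return lrs


lrs = one_cycle_lrs()
print(f"start={lrs[0]:.2e}, peak={max(lrs):.2e} at step {lrs.index(max(lrs))}, end={lrs[-1]:.2e}")
```

Plotting `lrs` with matplotlib shows the warm-up and anneal curve. It also shows why stepping this scheduler once per epoch instead of per batch would leave the LR near its starting value.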

## 7. Training the Model

Now we'll train our EfficientNet model. Each epoch has a training phase and a validation phase. If validation performance does not improve for `PATIENCE` (here 7) consecutive epochs, early stopping ends training, which limits overfitting.

In [None]:
# Generate a unique model name with timestamp
model_name = f"efficientnet_b0_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

# Train the model
model, history = train_model(
    model=model,
    dataloaders={
        'train': data['train_loader'],
        'val': data['val_loader']
    },
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=EPOCHS,
    device=device,
    save_dir=CHECKPOINT_DIR,
    model_name=model_name,
    early_stopping_patience=PATIENCE
)
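The internals of `train_model` are not shown. Patience-based early stopping usually amounts to tracking the best validation loss and counting epochs without improvement. The class below is a hypothetical, framework-free sketch of that logic. Check `train.py` for whether `train_model` monitors validation loss or accuracy and whether it restores the best weights; both are assumptions here.

```python
class EarlyStopping:
    """Signal a stop once validation loss has not improved for `patience` consecutive epochs."""

    def __init__(self, patience: int = 7, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta  # improvements smaller than this do not count
        self.best_loss = float('inf')
        self.best_epoch = -1
        self.epochs_without_improvement = 0

    def step(self, epoch: int, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_epoch = epoch  # a real training loop would save a checkpoint here
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# Toy run: loss improves for three epochs, then plateaus
stopper = EarlyStopping(patience=3)
for epoch, loss in enumerate([1.0, 0.8, 0.7, 0.71, 0.72, 0.70, 0.69]):
    if stopper.step(epoch, loss):
        print(f"Stopping after epoch {epoch}; best epoch was {stopper.best_epoch}")
        break
```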

## 8. Visualizing Training Results

Let's visualize our training history to see how the model performed over time.

In [None]:
# Visualize training history
visualize_training_history(history)

## 9. Model Evaluation

Now let's evaluate our trained model on the test set to see how well it generalizes to unseen data.

In [None]:
# Evaluate the model on the test set
test_metrics = evaluate_model(model, data['test_loader'], device, class_names)
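`evaluate_model` comes from the local `train` module, and its return format is not shown. The per-class numbers such a report usually contains can all be derived from a confusion matrix. The NumPy sketch below assumes you already have arrays of true and predicted integer labels, for example gathered by running the model over `data['test_loader']`. `per_class_report` is a hypothetical helper.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, num_classes):
    """cm[i, j] = number of samples whose true class is i and predicted class is j."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


def per_class_report(cm):
    """Return overall accuracy plus per-class precision, recall, and F1."""
    tp = np.diag(cm).astype(float)
    predicted = cm.sum(axis=0)  # column sums: how often each class was predicted
    actual = cm.sum(axis=1)     # row sums: how many samples truly belong to each class
    # `where` leaves 0 instead of NaN for classes that were never predicted / never present
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, actual, out=np.zeros_like(tp), where=actual > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    accuracy = tp.sum() / cm.sum() if cm.sum() else 0.0
    return accuracy, precision, recall, f1


y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
cm = confusion_matrix(y_true, y_pred, num_classes=3)
accuracy, precision, recall, f1 = per_class_report(cm)
print(cm)
print(f"accuracy={accuracy:.3f}", precision.round(3), recall.round(3), f1.round(3))
```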

## 10. Analyzing Misclassifications

Let's look at examples that our model misclassified to understand its weaknesses and potential areas for improvement.

In [None]:
# Visualize misclassified samples
visualize_misclassified_samples(model, data['test_loader'], class_names, device, num_samples=15)
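Inspecting individual misclassified images can be complemented by ranking which (true, predicted) class pairs are confused most often. This points to document types that look alike. The hypothetical helper below reads those pairs off a confusion matrix with true classes on the rows. Whether `evaluate_model` returns such a matrix is an assumption; any square integer matrix in that layout works. The class names are placeholders.

```python
import numpy as np


def top_confusions(cm, class_names, k=3):
    """Return up to k most frequent off-diagonal entries as (true, predicted, count)."""
    cm = np.asarray(cm)
    off_diag = cm.copy()
    np.fill_diagonal(off_diag, 0)  # ignore correct predictions
    # Sort the flattened matrix so the largest counts come first
    flat_order = np.argsort(off_diag, axis=None)[::-1]
    pairs = []
    for flat_idx in flat_order[:k]:
        i, j = np.unravel_index(flat_idx, cm.shape)
        if off_diag[i, j] == 0:
            break  # no further confusions to report
        pairs.append((class_names[i], class_names[j], int(off_diag[i, j])))
    return pairs


cm = np.array([[50, 4, 1],
               [7, 40, 0],
               [0, 2, 45]])
for true_name, pred_name, count in top_confusions(cm, ["passport", "id_card", "license"]):
    print(f"{true_name} -> {pred_name}: {count}")
```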

## 11. Inference on Single Image

Let's use our trained model to classify individual images and visualize the results.

In [None]:
def predict_single_image(model, image_path, transform, class_names, device):
    """
    Predict class for a single image and visualize the result.
    
    Args:
        model: Trained PyTorch model
        image_path: Path to the image file
        transform: Transformation pipeline for inference
        class_names: List of class names
        device: Device to run inference on
    """
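    # Load the image (cv2 expects a str path and returns None on failure)
    image = cv2.imread(str(image_path))
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)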
    
    # Display original image
    plt.figure(figsize=(8, 8))
    plt.imshow(image)
    plt.title("Input Image")
    plt.axis('off')
    plt.show()
    
    # Apply transformations
    transformed = transform(image=image)
    image_tensor = transformed['image']
    image_tensor = image_tensor.unsqueeze(0).to(device)  # Add batch dimension
    
    # Set model to evaluation mode
    model.eval()
    
    # Make prediction; request at most as many classes as the model has
    k = min(3, len(class_names))
    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        values, indices = torch.topk(probabilities, k)  # Get top-k predictions
    
    # Take the single batch element ([0] rather than squeeze(), which yields a 0-d array when k == 1)
    values = values[0].cpu().numpy() * 100  # Convert to percentage
    indices = indices[0].cpu().numpy()
    
    print(f"Top {k} predictions:")
    for i in range(k):
        print(f"{class_names[indices[i]]}: {values[i]:.2f}%")
    
    # Create a horizontal bar chart for visualization
    plt.figure(figsize=(10, 5))
    plt.barh(y=[class_names[idx] for idx in indices], width=values)
    plt.xlabel('Confidence (%)')
    plt.title(f'Top {k} Predictions')
    plt.xlim(0, 100)
    plt.gca().invert_yaxis()  # Highest confidence at the top
    plt.show()

# Get test images for inference
test_dataset = data['test_dataset']
test_indices = np.random.choice(len(test_dataset), min(3, len(test_dataset)), replace=False)

# Use validation/test transform for inference (no augmentation)
inference_transform = get_transforms(img_size=IMG_SIZE)['test']

# Run inference on multiple test images
for idx in test_indices:
    test_image_path = test_dataset.image_paths[idx]
    test_label = test_dataset.labels[idx]
    print(f"\nTrue class: {class_names[test_label]}")
    predict_single_image(model, test_image_path, inference_transform, class_names, device)

## 12. Saving the Model

Let's save our trained model for future use.

In [None]:
# Save the model
save_path = os.path.join(CHECKPOINT_DIR, f"{model_name}_final.pth")
torch.save({
    'model_state_dict': model.state_dict(),
    'class_names': class_names,
    'num_classes': num_classes,
    'img_size': IMG_SIZE,
    'metrics': test_metrics,
    'model_type': 'efficientnet'
}, save_path)
print(f"Model saved to {save_path}")

## 13. Loading and Using the Model for Inference

Let's demonstrate how to load the saved model and use it for inference.

In [None]:
def load_efficientnet_model(model_path, device):
    """
    Load a saved EfficientNet model and return it with associated metadata.
    
    Args:
        model_path: Path to the saved model file
        device: Device to load the model on
        
    Returns:
        model: The loaded PyTorch model
        metadata: Dictionary containing model metadata
    """
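    # Load checkpoint. weights_only=False is needed on PyTorch >= 2.6 when the checkpoint
    # stores non-tensor objects (e.g. NumPy values inside 'metrics'); it unpickles
    # arbitrary objects, so only use it for files you trust.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)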
    
    # Extract metadata
    class_names = checkpoint.get('class_names')
    num_classes = checkpoint.get('num_classes')
    img_size = checkpoint.get('img_size', 224)
    metrics = checkpoint.get('metrics', {})
    
    # Create model
    model = create_efficient_net_b0(num_classes=num_classes, pretrained=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    
    return model, {
        'class_names': class_names,
        'num_classes': num_classes,
        'img_size': img_size,
        'metrics': metrics
    }

# Test loading the model
loaded_model, metadata = load_efficientnet_model(save_path, device)
print(f"Model loaded successfully with {metadata['num_classes']} classes")
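# Format only when a number is present: applying ':.4f' to an 'N/A' fallback raises ValueError
accuracy = metadata['metrics'].get('accuracy')
print(f"Test accuracy: {accuracy:.4f}" if accuracy is not None else "Test accuracy: N/A")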

# Test inference with the loaded model
test_idx = np.random.randint(0, len(test_dataset))
test_image_path = test_dataset.image_paths[test_idx]
test_label = test_dataset.labels[test_idx]

print(f"\nTrue class: {metadata['class_names'][test_label]}")
predict_single_image(loaded_model, test_image_path, inference_transform, metadata['class_names'], device)

## 14. Experimenting with Freezing Layers

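Let's experiment with freezing the backbone layers of EfficientNet. With a frozen backbone only the classification head is trained, a common transfer-learning technique for reducing overfitting on small datasets. The cell below compares how many parameters remain trainable with and without freezing. Measuring the effect on accuracy would require a separate training run.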

In [None]:
# Compare the number of trainable parameters with and without freezing
model_unfrozen = create_efficient_net_b0(num_classes=num_classes, pretrained=True, freeze_backbone=False)
model_frozen = create_efficient_net_b0(num_classes=num_classes, pretrained=True, freeze_backbone=True)

trainable_unfrozen = sum(p.numel() for p in model_unfrozen.parameters() if p.requires_grad)
trainable_frozen = sum(p.numel() for p in model_frozen.parameters() if p.requires_grad)

print(f"Trainable parameters (unfrozen): {trainable_unfrozen:,}")
print(f"Trainable parameters (frozen): {trainable_frozen:,}")
print(f"Reduction: {(1 - trainable_frozen/trainable_unfrozen):.2%}")

# The notebook implementation stops here, but in a real scenario you would:
# 1. Train the model with frozen layers using a smaller learning rate
# 2. Compare the performance with the unfrozen model
# 3. Potentially unfreeze layers gradually for fine-tuning
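The gradual unfreezing mentioned in the comments above can be sketched as follows. The internals of `create_efficient_net_b0` are not shown, so the example assumes a torchvision-style layout: a `features` container of blocks plus a `classifier` head. It uses a tiny stand-in network (`ToyEfficientNetLike`, hypothetical) so the parameter counts are easy to verify. `set_trainable_blocks` freezes the whole backbone and then re-enables the last `n` blocks. Calling it with a growing `n` every few epochs gives a simple unfreezing schedule.

```python
import torch
import torch.nn as nn


class ToyEfficientNetLike(nn.Module):
    """Tiny stand-in with the same .features / .classifier split as torchvision's EfficientNet."""

    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()),
            nn.Sequential(nn.Conv2d(8, 16, 3), nn.ReLU()),
            nn.Sequential(nn.Conv2d(16, 32, 3), nn.ReLU()),
        )
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(nn.Dropout(0.2), nn.Linear(32, num_classes))

    def forward(self, x):
        return self.classifier(torch.flatten(self.pool(self.features(x)), 1))


def set_trainable_blocks(model: nn.Module, n_unfrozen: int) -> int:
    """Freeze model.features except its last n_unfrozen blocks; keep the head trainable.

    Returns the resulting number of trainable parameters."""
    blocks = list(model.features.children())
    for i, block in enumerate(blocks):
        trainable = i >= len(blocks) - n_unfrozen
        for p in block.parameters():
            p.requires_grad = trainable
    for p in model.classifier.parameters():
        p.requires_grad = True
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


toy_model = ToyEfficientNetLike(num_classes=3)
# Example schedule: head only, then progressively more backbone blocks
for stage, n in enumerate([0, 1, 3]):
    print(f"stage {stage}: {n} backbone blocks unfrozen, "
          f"{set_trainable_blocks(toy_model, n):,} trainable parameters")
```

If the optimizer was built over all of `model.parameters()`, as in section 6, newly unfrozen parameters start receiving updates once they have gradients. If it was built only over the trainable subset, `optimizer.add_param_group` can register the unfrozen blocks, often with a smaller learning rate than the head.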

## 15. Conclusion

In this notebook, we've walked through the process of using EfficientNet-B0 with transfer learning for identity document classification:

1. We set up our environment and loaded our dataset
2. We explored and visualized the dataset
3. We created an EfficientNet-B0 model with pretrained weights
4. We configured the model for our classification task
5. We used advanced training techniques like OneCycleLR scheduling
6. We trained, evaluated, and analyzed the model
7. We performed inference on test images
8. We saved and loaded the model for future use
9. We explored the effects of freezing backbone layers for transfer learning

EfficientNet-B0 suits this task because it offers strong ImageNet accuracy with only about 5.3 million parameters, which keeps fine-tuning and inference cheap. Its pretrained weights provide general-purpose visual features, so the model can adapt quickly to our document classification task. This advantage matters most when the labeled dataset is small.

For further improvements, you could:
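1. Try larger EfficientNet variants (B1-B7), which may improve accuracy at the cost of more compute; they expect higher input resolutions (e.g. 240px for B1), so `IMG_SIZE` should be raised to match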
2. Experiment with more advanced data augmentation techniques
3. Implement gradual unfreezing during training
4. Try different learning rate schedules