# Identity Document Classification using ResNet50

This notebook demonstrates how to use ResNet50 with transfer learning for identity document classification.

ResNet (Residual Network) is a CNN architecture that introduced skip connections to mitigate the vanishing gradient problem in deep networks. ResNet50 is its 50-layer variant. The pretrained weights used here come from training on the ImageNet dataset, so the network already extracts rich features that transfer to many computer vision tasks.
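
To make the skip-connection idea concrete, here is a minimal, framework-free sketch. It simplifies each layer to a scalar function with the same local derivative (`local_grad`, an illustrative assumption) and compares how much gradient survives a stack of plain layers `y = f(x)` versus residual layers `y = x + f(x)`.

```python
def plain_gradient(local_grad, depth):
    """Gradient through `depth` plain layers: the product of the local derivatives."""
    return local_grad ** depth


def residual_gradient(local_grad, depth):
    """Gradient through `depth` residual layers: each factor is 1 + f'(x),
    because the identity shortcut contributes a derivative of 1."""
    return (1.0 + local_grad) ** depth


# Illustrative values only: a small local derivative of 0.01 per layer
for depth in (5, 20, 50):
    print(depth, plain_gradient(0.01, depth), residual_gradient(0.01, depth))
```

In the plain stack the gradient shrinks geometrically with depth, while the shortcut keeps every factor close to 1. Real networks are not scalar chains, so this only illustrates the mechanism.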

## 1. Setup and Imports

In [None]:
import os
import sys
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.notebook import tqdm
import cv2
from datetime import datetime

# Add current directory to path to import local modules
sys.path.append('.')

# Import custom modules
from dataset import load_data, IDDocumentDataset, get_transforms
from models import create_resnet50
from train import train_model, evaluate_model, visualize_training_history, model_summary, visualize_misclassified_samples

# Set the random seed for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
# Check for GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 2. Configuration

Let's configure our training parameters. ResNet50 is a larger model compared to EfficientNet-B0, so we'll adjust the learning parameters accordingly.

In [None]:
# Data parameters
DATA_DIR = '../../data/cropped_images'  # Path to dataset directory
IMG_SIZE = 224  # Input image size for ResNet50
VAL_SPLIT = 0.15  # Fraction of data to use for validation
TEST_SPLIT = 0.15  # Fraction of data to use for testing
BATCH_SIZE = 24  # Slightly smaller batch size due to model size
NUM_WORKERS = 4  # Number of workers for data loading

# Training parameters - tuned for ResNet50 transfer learning
EPOCHS = 25  # Maximum number of epochs to train for
LEARNING_RATE = 1e-4  # Lower learning rate for ResNet50
WEIGHT_DECAY = 1e-4  # L2 regularization
PATIENCE = 8  # Early stopping patience
CHECKPOINT_DIR = Path('../../checkpoints/resnet50')  # Directory to save model checkpoints
os.makedirs(CHECKPOINT_DIR, exist_ok=True)  # Create checkpoint directory if it doesn't exist

# Transfer learning settings
PRETRAINED = True  # Use pretrained weights
FREEZE_BACKBONE = True  # Freeze backbone layers except the last one
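
As a quick sanity check on these settings, the sketch below shows how the split fractions and batch size translate into sample counts and batches per epoch. The dataset size of 1,000 images is hypothetical, and the sketch assumes that `load_data` applies both fractions to the full dataset and keeps the last partial batch. Neither detail is visible in this notebook.

```python
import math


def split_sizes(n_samples, val_split, test_split):
    """Return (train, val, test) sizes; whatever is left after val/test goes to train."""
    n_val = round(n_samples * val_split)
    n_test = round(n_samples * test_split)
    return n_samples - n_val - n_test, n_val, n_test


# Hypothetical dataset size, combined with VAL_SPLIT, TEST_SPLIT and BATCH_SIZE from above
n_train, n_val, n_test = split_sizes(1000, val_split=0.15, test_split=0.15)
steps_per_epoch = math.ceil(n_train / 24)  # assumes the last partial batch is kept
print(n_train, n_val, n_test, steps_per_epoch)
```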

## 3. Data Loading and Exploration

Let's load our dataset and explore it. We'll use the same data loading function as in the other notebooks.

In [None]:
# Load the data
data = load_data(
    data_dir=DATA_DIR,
    img_size=IMG_SIZE,
    val_split=VAL_SPLIT,
    test_split=TEST_SPLIT,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    seed=SEED
)

# Get class names
class_names = data['class_names']
num_classes = data['num_classes']

print(f"Number of classes: {num_classes}")
print(f"Class names: {class_names}")

## 4. Dataset Visualization

Let's visualize some random samples from our dataset to understand what we're working with.

In [None]:
def visualize_original_samples(dataset, class_names, num_images=5, figsize=(15, 10)):
    """
    Visualize original samples from a dataset without any transformations.
    
    Args:
        dataset: PyTorch dataset object
        class_names (list): List of class names
        num_images (int): Number of images to display
        figsize (tuple): Figure size
    """
    # Select random indices
    indices = np.random.choice(len(dataset), num_images, replace=False)
    
    # Create figure
    fig, axes = plt.subplots(1, num_images, figsize=figsize)
    axes = np.atleast_1d(axes)  # plt.subplots returns a bare Axes when num_images == 1
    
    for i, idx in enumerate(indices):
        # Get image path and label directly from dataset
        img_path = dataset.image_paths[idx]
        label = dataset.labels[idx]
        
        # Load original image without transformations
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # Display image
        axes[i].imshow(image)
        axes[i].set_title(f"Class: {class_names[label]}")
        axes[i].axis('off')
    
    plt.tight_layout()
    plt.show()

# Visualize random samples from the training set
print("Random training samples:")
visualize_original_samples(data['train_dataset'], class_names)

# Visualize one sample from each class
print("\nOne sample from each class:")
for i, class_name in enumerate(class_names):
    # Find indices of samples from this class
    class_indices = [j for j, label in enumerate(data['train_dataset'].labels) if label == i]
    if class_indices:  # Make sure we have samples of this class
        # Select one random sample
        idx = random.choice(class_indices)
        
        # Get image path and load image
        img_path = data['train_dataset'].image_paths[idx]
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # Display image
        plt.figure(figsize=(5, 5))
        plt.imshow(image)
        plt.title(f"Class: {class_name}")
        plt.axis('off')
        plt.show()

## 5. Creating and Configuring the ResNet50 Model

Now we'll create our ResNet50 model with pretrained weights and configure it for our document classification task.

In [None]:
# Create the ResNet50 model
model = create_resnet50(
    num_classes=num_classes,
    pretrained=PRETRAINED,
    freeze_backbone=FREEZE_BACKBONE
)
model.to(device)

# Print model summary
model_summary(model)

# Count trainable parameters
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable_params:,} ({trainable_params/total_params:.2%} of total)")
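
As a rough cross-check of the printed ratio, the arithmetic below estimates the trainable fraction under one specific assumption: that only a single new linear head on ResNet50's 2048-dimensional pooled features is trainable. How `create_resnet50` builds its head and which layers it leaves unfrozen is not shown in this notebook, so the real figure may differ. The class count of 5 is hypothetical.

```python
RESNET50_BACKBONE_PARAMS = 23_508_032  # torchvision ResNet50 without its final fc layer
FEATURE_DIM = 2048  # width of the pooled feature vector that feeds the classifier


def head_params(num_classes):
    """Weights plus biases of one linear layer mapping 2048 features to the classes."""
    return FEATURE_DIM * num_classes + num_classes


def head_only_fraction(num_classes):
    """Share of parameters that train if everything except the head is frozen."""
    head = head_params(num_classes)
    return head / (RESNET50_BACKBONE_PARAMS + head)


# Hypothetical 5-class problem
print(f"{head_params(5):,} trainable ({head_only_fraction(5):.4%} of total)")
```

If the notebook reports a much larger share, part of the backbone (for example the last residual stage) is probably left trainable as well.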

## 6. Setting Up Training Components

Let's set up our loss function, optimizer, and learning rate scheduler for training ResNet50.
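
The loss in the next cell uses label smoothing. As a minimal sketch of what that means, the pure-Python function below spreads a share `smoothing` of the target probability uniformly over all classes before taking the cross-entropy. It is meant to illustrate the idea behind `nn.CrossEntropyLoss(label_smoothing=...)` for a single sample. The function names are illustrative.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


def smoothed_cross_entropy(logits, target, smoothing=0.1):
    """Cross-entropy against a smoothed target distribution for one sample."""
    k = len(logits)
    log_probs = log_softmax(logits)
    # Every class gets smoothing / k; the true class additionally gets 1 - smoothing
    weights = [smoothing / k] * k
    weights[target] += 1.0 - smoothing
    return -sum(w * lp for w, lp in zip(weights, log_probs))


# A confident, correct prediction is penalized slightly more once smoothing is on
print(smoothed_cross_entropy([10.0, 0.0, 0.0], target=0, smoothing=0.0))
print(smoothed_cross_entropy([10.0, 0.0, 0.0], target=0, smoothing=0.1))
```

The extra penalty on over-confident outputs is what discourages the network from pushing logits to extremes.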

In [None]:
# Define loss function with label smoothing for better generalization
criterion = nn.CrossEntropyLoss(label_smoothing=0.1) 

# Define optimizer - AdamW with weight decay for regularization
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

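# For a partially frozen model like this, we'll use OneCycleLR,
# which works well for transfer learning with larger models.
# Note: OneCycleLR is sized in batches (epochs * steps_per_epoch), so it expects
# scheduler.step() after every batch; this is assumed to be handled by train_model.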
scheduler = OneCycleLR(
    optimizer,
    max_lr=LEARNING_RATE,
    epochs=EPOCHS,
    steps_per_epoch=len(data['train_loader']),
    pct_start=0.2,  # Spend 20% of time warming up
    div_factor=10,  # LR starts at max_lr/10
    final_div_factor=100,  # Final LR is initial_lr/100, i.e. max_lr/1000
    anneal_strategy='cos'  # Cosine annealing
)

## 7. Training the Model

Now we'll train our ResNet50 model. We use early stopping to limit overfitting, and the best model is saved along the way.

In [None]:
# Generate a unique model name with timestamp
model_name = f"resnet50_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

# Train the model
model, history = train_model(
    model=model,
    dataloaders={
        'train': data['train_loader'],
        'val': data['val_loader']
    },
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=EPOCHS,
    device=device,
    save_dir=CHECKPOINT_DIR,
    model_name=model_name,
    early_stopping_patience=PATIENCE
)
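
The internals of `train_model` are not shown in this notebook. As a minimal sketch of how patience-based early stopping typically works, the class below tracks the best validation loss and signals a stop after `patience` epochs without improvement. Monitoring validation loss rather than accuracy is an assumption, and the class name and loss values are illustrative.

```python
class EarlyStopping:
    """Stop training after `patience` consecutive epochs without a better validation loss."""

    def __init__(self, patience):
        self.patience = patience
        self.best_loss = float("inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def step(self, epoch, val_loss):
        """Record one epoch and return True when training should stop."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.bad_epochs = 0  # a real loop would also save a checkpoint here
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Simulated validation losses, with a patience of 3 for brevity
stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, loss in enumerate([1.0, 0.8, 0.9, 0.85, 0.81, 0.7]):
    if stopper.step(epoch, loss):
        stopped_at = epoch
        break
print(stopped_at, stopper.best_epoch, stopper.best_loss)
```

The simulated run never reaches the last loss value of 0.7, which shows the trade-off: a short patience saves time but can stop before a late improvement.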

## 8. Visualizing Training Results

Let's visualize our training history to see how the model performed over time.

In [None]:
# Visualize training history
visualize_training_history(history)

## 9. Model Evaluation

Now let's evaluate our trained model on the test set to see how well it generalizes to unseen data.

In [None]:
# Evaluate the model on the test set
test_metrics = evaluate_model(model, data['test_loader'], device, class_names)
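
For reference, accuracy and F1 can be computed from true and predicted labels as sketched below. The sketch uses macro averaging (an unweighted mean of per-class F1 scores). Which averaging `evaluate_model` uses for its `f1` entry is not shown here, so treat that as an assumption. The labels are made up.

```python
def accuracy(y_true, y_pred):
    """Fraction of samples whose predicted label matches the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def macro_f1(y_true, y_pred, num_classes):
    """Unweighted mean of per-class F1 scores."""
    f1_scores = []
    for c in range(num_classes):
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        f1_scores.append(2 * tp / denom if denom else 0.0)
    return sum(f1_scores) / num_classes


# Made-up labels for a 3-class problem
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
print(accuracy(y_true, y_pred), macro_f1(y_true, y_pred, num_classes=3))
```

Macro F1 weights every class equally, so it drops noticeably when a rare document type is classified poorly even if overall accuracy stays high.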

## 10. Analyzing Misclassifications

Let's look at examples that our model misclassified to understand its weaknesses and potential areas for improvement.

In [None]:
# Visualize misclassified samples
visualize_misclassified_samples(model, data['test_loader'], class_names, device, num_samples=10)

## 11. Fine-tuning the Model

Now, let's try fine-tuning by unfreezing more layers of the network. This is a common practice to further improve performance after initial training with a frozen backbone.

In [None]:
# Unfreeze the last two blocks (layer3 and layer4) for fine-tuning
for name, param in model.named_parameters():
    if any(block in name for block in ['layer3', 'layer4', 'fc']):
        param.requires_grad = True

# Count trainable parameters after unfreezing
trainable_params_finetuned = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters after unfreezing: {trainable_params_finetuned:,} ({trainable_params_finetuned/total_params:.2%} of total)")

# Setup for fine-tuning with a lower learning rate
ft_learning_rate = LEARNING_RATE / 10  # Lower learning rate for fine-tuning
ft_epochs = 10  # Fewer epochs for fine-tuning

# New optimizer and scheduler for fine-tuning
ft_optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), 
                         lr=ft_learning_rate, 
                         weight_decay=WEIGHT_DECAY)

ft_scheduler = CosineAnnealingLR(ft_optimizer, T_max=ft_epochs)

# Fine-tune the model
print("\nFine-tuning the model with unfrozen layers...")
ft_model_name = f"{model_name}_finetuned"
model, ft_history = train_model(
    model=model,
    dataloaders={
        'train': data['train_loader'],
        'val': data['val_loader']
    },
    criterion=criterion,
    optimizer=ft_optimizer,
    scheduler=ft_scheduler,
    num_epochs=ft_epochs,
    device=device,
    save_dir=CHECKPOINT_DIR,
    model_name=ft_model_name,
    early_stopping_patience=3  # Shorter patience for fine-tuning
)
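
The fine-tuning schedule follows the closed-form cosine annealing curve sketched below. The sketch assumes the scheduler is stepped once per epoch (so that `T_max=ft_epochs` covers the whole run) and uses the default minimum rate of 0. The function name is illustrative.

```python
import math


def cosine_annealing_lr(epoch, base_lr, t_max, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


# ft_learning_rate = 1e-4 / 10 and ft_epochs = 10, as configured above
for epoch in range(11):
    print(epoch, cosine_annealing_lr(epoch, base_lr=1e-5, t_max=10))
```

Because early stopping can end fine-tuning before epoch 10, the rate may never reach the bottom of the curve.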

## 12. Evaluating the Fine-tuned Model

Let's see if fine-tuning improved the model's performance.

In [None]:
# Visualize fine-tuning history
visualize_training_history(ft_history)

# Evaluate the fine-tuned model on the test set
ft_test_metrics = evaluate_model(model, data['test_loader'], device, class_names)

# Compare metrics before and after fine-tuning
print("\nComparison of metrics before and after fine-tuning:")
print(f"Initial model - Accuracy: {test_metrics['accuracy']:.4f}, F1 Score: {test_metrics['f1']:.4f}")
print(f"Fine-tuned model - Accuracy: {ft_test_metrics['accuracy']:.4f}, F1 Score: {ft_test_metrics['f1']:.4f}")

## 13. Saving the Fine-tuned Model

Let's save our fine-tuned model for future use.

In [None]:
# Save the fine-tuned model
save_path = os.path.join(CHECKPOINT_DIR, f"{ft_model_name}_final.pth")
torch.save({
    'model_state_dict': model.state_dict(),
    'class_names': class_names,
    'num_classes': num_classes,
    'img_size': IMG_SIZE,
    'metrics': ft_test_metrics,
    'model_type': 'resnet50'
}, save_path)
print(f"Fine-tuned model saved to {save_path}")

## 14. Inference on Single Image

Let's use our fine-tuned model to classify individual images and visualize the results.

In [None]:
def predict_single_image(model, image_path, transform, class_names, device):
    """
    Predict class for a single image and visualize the result.
    
    Args:
        model: Trained PyTorch model
        image_path: Path to the image file
        transform: Transformation pipeline for inference
        class_names: List of class names
        device: Device to run inference on
    """
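    # Load the image (cv2.imread returns None instead of raising on failure)
    image = cv2.imread(str(image_path))
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)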
    
    # Display original image
    plt.figure(figsize=(8, 8))
    plt.imshow(image)
    plt.title("Input Image")
    plt.axis('off')
    plt.show()
    
    # Apply transformations
    transformed = transform(image=image)
    image_tensor = transformed['image']
    image_tensor = image_tensor.unsqueeze(0).to(device)  # Add batch dimension
    
    # Set model to evaluation mode
    model.eval()
    
    # Make prediction
    top_k = min(3, len(class_names))  # torch.topk fails if k exceeds the number of classes
    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        values, indices = torch.topk(probabilities, top_k)
    
    # Display prediction results (squeeze only the batch dimension)
    values = values.squeeze(0).cpu().numpy() * 100  # Convert to percentage
    indices = indices.squeeze(0).cpu().numpy()
    
    print(f"Top {top_k} predictions:")
    for i in range(top_k):
        print(f"{class_names[indices[i]]}: {values[i]:.2f}%")
    
    # Create a horizontal bar chart for visualization
    plt.figure(figsize=(10, 5))
    plt.barh(y=[class_names[idx] for idx in indices[:3]], width=values[:3])
    plt.xlabel('Confidence (%)')
    plt.title('Top 3 Predictions')
    plt.xlim(0, 100)
    plt.gca().invert_yaxis()  # Highest confidence at the top
    plt.show()

# Get test images for inference
test_dataset = data['test_dataset']
test_indices = np.random.choice(len(test_dataset), 3, replace=False)

# Use validation/test transform for inference (no augmentation)
inference_transform = get_transforms(img_size=IMG_SIZE)['test']

# Run inference on multiple test images
for idx in test_indices:
    test_image_path = test_dataset.image_paths[idx]
    test_label = test_dataset.labels[idx]
    print(f"\nTrue class: {class_names[test_label]}")
    predict_single_image(model, test_image_path, inference_transform, class_names, device)
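
The core of the prediction step, turning raw logits into ranked class probabilities, can be sketched without any framework. The class names and logits below are hypothetical, and the helper names are illustrative.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probs, class_names, k=3):
    """Return up to k (class_name, probability) pairs, most likely first."""
    k = min(k, len(class_names))  # never ask for more classes than exist
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(class_names[i], probs[i]) for i in ranked[:k]]


# Hypothetical class names and logits for one image
names = ["passport", "id_card", "driver_license"]
for name, prob in top_k(softmax([2.0, 0.5, 1.0]), names):
    print(f"{name}: {prob * 100:.2f}%")
```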

## 15. Model Inference Performance

Let's measure the single-image inference latency of our ResNet50 model on the current device to get a rough idea of how it would perform in a production environment.

In [None]:
import time

def measure_inference_time(model, dataloader, device, num_runs=100):
    """
    Measure the average inference time per image.
    
    Args:
        model: Trained PyTorch model
        dataloader: PyTorch dataloader containing test data
        device: Device to run inference on
        num_runs: Number of inference runs to average over
    
    Returns:
        float: Average inference time per image in milliseconds
    """
    model.eval()  # Set model to evaluation mode
    
    # Get a single batch for testing
    inputs, _ = next(iter(dataloader))
    inputs = inputs.to(device)
    
    # Warm-up runs
    with torch.no_grad():
        for _ in range(10):
            _ = model(inputs[0:1])
    
    # Measure inference time (synchronize so queued GPU work is included in the timing)
    if device.type == 'cuda':
        torch.cuda.synchronize()
    start_time = time.perf_counter()  # Monotonic, high-resolution timer
    
    with torch.no_grad():
        for _ in range(num_runs):
            _ = model(inputs[0:1])  # Run inference on a single image
    
    if device.type == 'cuda':
        torch.cuda.synchronize()
    end_time = time.perf_counter()
    
    # Calculate average time in milliseconds
    avg_time = (end_time - start_time) * 1000 / num_runs
    
    return avg_time

# Measure inference time
avg_inference_time = measure_inference_time(model, data['test_loader'], device)
print(f"Average inference time per image: {avg_inference_time:.2f} ms")
print(f"Frames per second: {1000 / avg_inference_time:.2f} FPS")

## 16. Conclusion

In this notebook, we've walked through the process of using ResNet50 with transfer learning for identity document classification:

1. We set up our environment and loaded our dataset
2. We explored and visualized the dataset
3. We created a ResNet50 model with pretrained weights
4. We configured the model for our classification task, initially freezing most of the backbone
5. We trained the model using OneCycleLR scheduling and label smoothing
6. We evaluated the model and analyzed its performance
7. We fine-tuned the model by unfreezing more layers and training with a lower learning rate
8. We compared performance before and after fine-tuning
9. We performed inference on test images
10. We measured inference performance to understand real-world usage characteristics

ResNet50 is a powerful architecture for this document classification task because:
1. Its residual connections allow for deeper networks without vanishing gradients
2. The pretrained weights provide a strong starting point for transfer learning
3. It has a good balance of model capacity and inference speed

For further improvements, you could:
1. Try different training strategies (e.g., gradual unfreezing)
2. Experiment with different learning rates for different layers
3. Implement more advanced data augmentation techniques specific to document images
4. Ensemble this model with other architectures, like EfficientNet
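
As a starting point for the second suggestion, the sketch below assigns a smaller learning rate to blocks closer to the input. The decay factor of 0.5 and the helper name are illustrative assumptions. The block names match the `layer3`, `layer4`, and `fc` prefixes used in the fine-tuning cell.

```python
def layerwise_learning_rates(block_names, head_lr, decay=0.5):
    """Give the last block `head_lr` and shrink the rate by `decay` per block toward the input."""
    rates = {}
    for depth, name in enumerate(reversed(block_names)):
        rates[name] = head_lr * decay ** depth
    return rates


# Blocks listed from input side to output side
rates = layerwise_learning_rates(["layer3", "layer4", "fc"], head_lr=1e-4)
for name, lr in rates.items():
    print(f"{name}: {lr:.2e}")
```

Each entry could then become one optimizer parameter group, for example a dict with `params` and `lr` keys passed to `optim.AdamW`, so that earlier layers change more slowly than the classifier head.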