# Identity Document Classification using Vision Transformer

This notebook demonstrates how to use Vision Transformer (ViT) with transfer learning for identity document classification.

Unlike CNNs, Vision Transformers treat images as sequences of patches and apply self-attention mechanisms to learn relationships between these patches. This architecture has shown impressive performance on various computer vision tasks, especially when pretrained on large datasets.
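
To make the self-attention idea concrete, here is a minimal pure-Python sketch of single-head scaled dot-product attention over three toy "patch embeddings". It is a simplification, not the model used in this notebook: the token values are made up, and the queries, keys and values are the tokens themselves, whereas a real ViT layer applies learned projections first.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(tokens):
    """Single-head scaled dot-product self-attention on a list of vectors.

    Simplification: queries, keys and values are the tokens themselves
    (no learned projection matrices).
    """
    dim = len(tokens[0])
    outputs, weights = [], []
    for query in tokens:
        # Similarity of this token to every token, scaled by sqrt(dim)
        scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(dim)
                  for key in tokens]
        attn = softmax(scores)
        weights.append(attn)
        # Each output is an attention-weighted average of all tokens
        outputs.append([sum(a * value[d] for a, value in zip(attn, tokens))
                        for d in range(dim)])
    return outputs, weights

# Three toy "patch embeddings" of dimension 2 (made-up values)
patch_tokens = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]
mixed, attn_weights = self_attention(patch_tokens)
for row in attn_weights:
    print([round(w, 3) for w in row])
```

Each row of `attn_weights` sums to 1 and tells us how much one patch "looks at" every other patch, regardless of how far apart they are in the image.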

## 1. Setup and Imports

In [None]:
import os
import sys
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.notebook import tqdm
import cv2
from datetime import datetime

# Add current directory to path to import local modules
sys.path.append('.')

# Import custom modules
from dataset import load_data, IDDocumentDataset, get_transforms
from models import create_vision_transformer
from train import train_model, evaluate_model, visualize_training_history, model_summary, visualize_misclassified_samples

# Set the random seed for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
# Check for GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 2. Configuration

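Let's configure our training parameters. Self-attention memory grows quadratically with the number of patches, so Vision Transformers usually need smaller batch sizes than CNNs, and pretrained ViTs are typically trained with a low learning rate and AdamW-style weight decay.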
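
As a rough illustration of why ViTs are memory-hungry, the sketch below estimates the size of the attention score matrices for a single transformer layer. It assumes the standard ViT-B/16 configuration (12 heads, 16x16 patches, one class token) and float32 values; the model returned by `create_vision_transformer` may differ. It also ignores weights, other activations, gradients and optimizer state, so treat it as one component of the total, not a full memory budget.

```python
def sequence_length(img_size, patch_size=16):
    """Number of tokens: one per patch plus one class token."""
    return (img_size // patch_size) ** 2 + 1

def attention_map_bytes(batch_size, seq_len=197, num_heads=12, bytes_per_value=4):
    """Bytes needed to hold the attention score matrices of ONE layer.

    Assumption: ViT-B/16 defaults (12 heads) and float32 (4 bytes per value).
    """
    return batch_size * num_heads * seq_len * seq_len * bytes_per_value

for img_size in (224, 448):
    seq_len = sequence_length(img_size)
    megabytes = attention_map_bytes(batch_size=16, seq_len=seq_len) / 1e6
    print(f"{img_size}x{img_size}: {seq_len} tokens, "
          f"~{megabytes:.1f} MB of attention scores per layer")
```

Doubling the image side length roughly quadruples the token count, and the attention scores grow with the square of that, which is why batch size is usually the first thing to reduce.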

In [None]:
# Data parameters
DATA_DIR = '../../data/cropped_images'  # Path to dataset directory
IMG_SIZE = 224  # Input image size for ViT-B/16 (224x224 is standard)
VAL_SPLIT = 0.15  # Fraction of data to use for validation
TEST_SPLIT = 0.15  # Fraction of data to use for testing
BATCH_SIZE = 16  # Smaller batch size due to memory requirements of ViT
NUM_WORKERS = 4  # Number of workers for data loading

# Training parameters - tuned for Vision Transformer
EPOCHS = 20  # Maximum number of epochs to train for
LEARNING_RATE = 5e-5  # Lower learning rate for ViT
WEIGHT_DECAY = 1e-3  # Higher weight decay for better regularization
PATIENCE = 6  # Early stopping patience
CHECKPOINT_DIR = Path('../../checkpoints/vit')  # Directory to save model checkpoints
os.makedirs(CHECKPOINT_DIR, exist_ok=True)  # Create checkpoint directory if it doesn't exist

# Transfer learning settings
PRETRAINED = True  # Use pretrained weights
FREEZE_BACKBONE = True  # Freeze backbone layers except the head

## 3. Data Loading and Exploration

Let's load our dataset and explore it.
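
One quick exploration step is checking how balanced the classes are. The helper below is a small sketch, not part of the `dataset` module: `class_distribution` is a hypothetical name, and it assumes the dataset exposes integer labels as a list (as `dataset.labels` is used later in this notebook). The example labels and class names are made up.

```python
from collections import Counter

def class_distribution(labels, class_names):
    """Return {class_name: (count, fraction)} for a list of integer labels."""
    counts = Counter(int(label) for label in labels)
    total = len(labels)
    return {
        name: (counts.get(i, 0), counts.get(i, 0) / total if total else 0.0)
        for i, name in enumerate(class_names)
    }

# Made-up example: 3 hypothetical classes, 8 samples
example_names = ["passport", "id_card", "driver_license"]
example_labels = [0, 0, 0, 0, 1, 1, 2, 2]
for name, (count, fraction) in class_distribution(example_labels, example_names).items():
    print(f"{name}: {count} samples ({fraction:.1%})")
```

With the real data this could be called as `class_distribution(data['train_dataset'].labels, class_names)`. A strongly skewed distribution is worth knowing about before reading accuracy numbers.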

In [None]:
# Load the data
data = load_data(
    data_dir=DATA_DIR,
    img_size=IMG_SIZE,
    val_split=VAL_SPLIT,
    test_split=TEST_SPLIT,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    seed=SEED
)

# Get class names
class_names = data['class_names']
num_classes = data['num_classes']

print(f"Number of classes: {num_classes}")
print(f"Class names: {class_names}")

## 4. Dataset Visualization

Let's visualize some samples from our dataset to understand what we're working with.

In [None]:
def visualize_original_samples(dataset, class_names, num_images=5, figsize=(15, 10)):
    """
    Visualize original samples from a dataset without any transformations.
    
    Args:
        dataset: PyTorch dataset object
        class_names (list): List of class names
        num_images (int): Number of images to display
        figsize (tuple): Figure size
    """
    # Select random indices
    indices = np.random.choice(len(dataset), num_images, replace=False)
    
    # Create figure
    fig, axes = plt.subplots(1, num_images, figsize=figsize)
    
    for i, idx in enumerate(indices):
        # Get image path and label directly from dataset
        img_path = dataset.image_paths[idx]
        label = dataset.labels[idx]
        
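        # Load original image without transformations
        # (cv2.imread expects a str path and returns None if the file cannot be read)
        image = cv2.imread(str(img_path))
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)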
        
        # Display image
        axes[i].imshow(image)
        axes[i].set_title(f"Class: {class_names[label]}")
        axes[i].axis('off')
    
    plt.tight_layout()
    plt.show()

# Visualize original samples from the training set
print("Random training samples:")
visualize_original_samples(data['train_dataset'], class_names)

## 5. Understanding Vision Transformer

Before we create our model, let's understand how Vision Transformer processes images differently than CNNs.
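
The arithmetic behind the patch grid is simple enough to check by hand. The sketch below (the function name `vit_sequence_shape` is ours, for illustration) computes the grid size, the number of patches, the sequence length including the class token, and the length of each flattened patch before it is linearly embedded.

```python
def vit_sequence_shape(img_size=224, patch_size=16, channels=3):
    """Describe how a square image is tokenized by a ViT-style patch embedding."""
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")
    grid = img_size // patch_size
    num_patches = grid * grid
    return {
        "grid": (grid, grid),
        "num_patches": num_patches,
        "sequence_length": num_patches + 1,  # +1 for the class token
        "patch_dim": patch_size * patch_size * channels,  # flattened pixels per patch
    }

print(vit_sequence_shape())
```

For a 224x224 RGB image with 16x16 patches this gives a 14x14 grid, 196 patches, 197 tokens, and 768 values per flattened patch.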

In [None]:
def visualize_image_patches(image_path, patch_size=16, figsize=(15, 10)):
    """
    Visualize how an image is split into patches for Vision Transformer processing.
    
    Args:
        image_path: Path to the image file
        patch_size: Size of each patch (ViT-B/16 uses 16x16 patches)
        figsize: Figure size
    """
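    # Load and resize image to match ViT input size
    # (cv2.imread expects a str path and returns None if the file cannot be read)
    image = cv2.imread(str(image_path))
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (IMG_SIZE, IMG_SIZE))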
    
    # Calculate grid dimensions
    h, w, _ = image.shape
    grid_h, grid_w = h // patch_size, w // patch_size
    
    # Create figure for original image
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(image)
    ax.set_title("Original Image")
    ax.axis('off')
    plt.show()
    
    # Create figure for patched visualization
    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(image)
    
    # Draw grid lines
    for i in range(grid_h + 1):
        y = i * patch_size
        ax.axhline(y=y, color='white', linestyle='-', linewidth=1)
    
    for i in range(grid_w + 1):
        x = i * patch_size
        ax.axvline(x=x, color='white', linestyle='-', linewidth=1)
    
    # Add patch numbering for better understanding
    for i in range(grid_h):
        for j in range(grid_w):
            ax.text(j * patch_size + patch_size // 2, i * patch_size + patch_size // 2, 
                  f"{i*grid_w + j}", color='white', ha='center', va='center', fontsize=8,
                  bbox=dict(boxstyle="round,pad=0.3", fc="black", alpha=0.5))
    
    ax.set_title(f"Image Split into {patch_size}×{patch_size} Patches (Total: {grid_h*grid_w} patches)")
    ax.axis('off')
    plt.tight_layout()
    plt.show()
    
    return grid_h * grid_w  # Return number of patches

# Visualize image patches for a random sample
sample_idx = np.random.choice(len(data['train_dataset']))
sample_image_path = data['train_dataset'].image_paths[sample_idx]
sample_label = data['train_dataset'].labels[sample_idx]
print(f"Sample Class: {class_names[sample_label]}")
num_patches = visualize_image_patches(sample_image_path)
print(f"Vision Transformer processes these {num_patches} patches + 1 class token using self-attention.")

## 6. Creating and Configuring the Vision Transformer Model

Now we'll create our Vision Transformer model with pretrained weights and configure it for our document classification task.

In [None]:
# Create the Vision Transformer model
model = create_vision_transformer(
    num_classes=num_classes,
    pretrained=PRETRAINED,
    freeze_backbone=FREEZE_BACKBONE
)
model.to(device)

# Print model summary
model_summary(model)

# Count trainable parameters
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable_params:,} ({trainable_params/total_params:.2%} of total)")

## 7. Setting Up Training Components

Let's set up our loss function, optimizer, and learning rate scheduler for training the Vision Transformer.
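
To build intuition for two of these components, here is a pure-Python sketch of the closed-form cosine annealing schedule and of label-smoothed targets. The defaults mirror the values used in this notebook (base learning rate 5e-5, minimum 1e-6, 20 epochs, smoothing 0.1); the function names are ours, and the code is an illustration of the formulas rather than PyTorch's internal implementation.

```python
import math

def cosine_annealing_lr(epoch, base_lr=5e-5, eta_min=1e-6, t_max=20):
    """Closed-form cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))

def smoothed_targets(true_class, num_classes, smoothing=0.1):
    """Target distribution: (1 - smoothing) on the true class plus
    smoothing / num_classes spread uniformly over every class."""
    uniform = smoothing / num_classes
    return [uniform + (1.0 - smoothing) if c == true_class else uniform
            for c in range(num_classes)]

def smoothed_cross_entropy(probs, true_class, smoothing=0.1):
    """Cross-entropy between smoothed targets and predicted probabilities."""
    targets = smoothed_targets(true_class, len(probs), smoothing)
    return -sum(t * math.log(p) for t, p in zip(targets, probs))

for epoch in (0, 5, 10, 15, 20):
    print(f"epoch {epoch:2d}: lr = {cosine_annealing_lr(epoch):.2e}")

print("targets:", [round(t, 3) for t in smoothed_targets(true_class=0, num_classes=4)])
```

The learning rate decays slowly at first, fastest in the middle, and flattens near `eta_min`. Label smoothing keeps the target for the true class below 1, which discourages over-confident logits.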

In [None]:
# Define loss function with label smoothing for better generalization
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# Define optimizer - AdamW with weight decay is standard for ViT
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Vision Transformers benefit from cosine learning rate schedule
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-6)

## 8. Training the Model

Now we'll train our Vision Transformer model using early stopping to prevent overfitting.
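
The internals of `train_model` are not shown in this notebook, so the following is only a sketch of how patience-based early stopping commonly works. It assumes validation loss is the monitored quantity (the real implementation may monitor accuracy instead); the class name and the loss values are made up for illustration.

```python
class EarlyStopping:
    """Stop training when validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=6, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0  # improvement resets the counter
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Made-up validation losses: improvement stalls after the third epoch
val_losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.70, 0.72]
stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if stopper.step(loss):
        stopped_at = epoch
        break

print(f"Stopped at epoch index {stopped_at}, best val loss {stopper.best}")
```

In practice the checkpoint from the best epoch is usually restored after stopping, so the extra "patience" epochs do not hurt the final model.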

In [None]:
# Generate a unique model name with timestamp
model_name = f"vit_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

# Train the model
model, history = train_model(
    model=model,
    dataloaders={
        'train': data['train_loader'],
        'val': data['val_loader']
    },
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=EPOCHS,
    device=device,
    save_dir=CHECKPOINT_DIR,
    model_name=model_name,
    early_stopping_patience=PATIENCE
)

## 9. Visualizing Training Results

Let's visualize our training history to see how the model performed over time.

In [None]:
# Visualize training history
visualize_training_history(history)

## 10. Model Evaluation

Now let's evaluate our trained model on the test set to see how well it generalizes to unseen data.
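
For reference, here is how accuracy and F1 can be computed from predictions in plain Python. This is a sketch under an assumption: we use macro-averaged F1 (the unweighted mean of per-class F1), but `evaluate_model` may use a different averaging scheme for its `'f1'` value. The labels below are made up.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return correct / len(y_true)

def macro_f1(y_true, y_pred, num_classes):
    """Unweighted mean of per-class F1 scores."""
    scores = []
    for c in range(num_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        denom = 2 * tp + fp + fn
        # F1 = 2TP / (2TP + FP + FN); defined as 0 when the class never appears
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / num_classes

# Made-up labels for 3 classes
example_true = [0, 0, 1, 1, 2]
example_pred = [0, 1, 1, 1, 2]
print(f"Accuracy: {accuracy(example_true, example_pred):.4f}")
print(f"Macro F1: {macro_f1(example_true, example_pred, num_classes=3):.4f}")
```

Macro F1 weights every class equally, so it drops noticeably when a rare document type is poorly recognized even if overall accuracy stays high.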

In [None]:
# Evaluate the model on the test set
test_metrics = evaluate_model(model, data['test_loader'], device, class_names)

## 11. Analyzing Misclassifications

Let's look at examples that our model misclassified to understand its weaknesses and potential areas for improvement.
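
Besides looking at individual images, it helps to count which class pairs get confused most often. The helper below is a hypothetical sketch (it is not part of the `train` module) that works on plain lists of true and predicted label indices; the class names and labels are made up.

```python
from collections import Counter

def most_confused_pairs(y_true, y_pred, class_names, top_n=3):
    """Return the most frequent (true class, predicted class, count) errors."""
    errors = Counter((t, p) for t, p in zip(y_true, y_pred) if t != p)
    return [(class_names[t], class_names[p], count)
            for (t, p), count in errors.most_common(top_n)]

# Made-up example with 3 hypothetical document classes
example_names = ["passport", "id_card", "driver_license"]
example_true = [0, 0, 0, 1, 1, 2, 2]
example_pred = [0, 1, 1, 1, 2, 2, 2]
for true_name, pred_name, count in most_confused_pairs(example_true, example_pred, example_names):
    print(f"{true_name} predicted as {pred_name}: {count}x")
```

If one pair dominates the errors, that usually points to visually similar document layouts and suggests where more data or targeted augmentation would help most.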

In [None]:
# Visualize misclassified samples
visualize_misclassified_samples(model, data['test_loader'], class_names, device, num_samples=10)

## 12. Fine-tuning the Model

Now, let's try fine-tuning by unfreezing the backbone. Vision Transformers often benefit from fine-tuning with a very low learning rate.

In [None]:
# Unfreeze all layers for fine-tuning
for param in model.parameters():
    param.requires_grad = True

# Count trainable parameters after unfreezing
trainable_params_finetuned = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters after unfreezing: {trainable_params_finetuned:,} ({trainable_params_finetuned/total_params:.2%} of total)")

# Setup for fine-tuning with a much lower learning rate
ft_learning_rate = LEARNING_RATE / 20  # Much lower learning rate for fine-tuning ViT
ft_epochs = 8  # Fewer epochs for fine-tuning

# New optimizer and scheduler for fine-tuning
ft_optimizer = optim.AdamW(model.parameters(), lr=ft_learning_rate, weight_decay=WEIGHT_DECAY)
ft_scheduler = CosineAnnealingLR(ft_optimizer, T_max=ft_epochs, eta_min=1e-7)

# Fine-tune the model
print("\nFine-tuning the model with unfrozen layers...")
ft_model_name = f"{model_name}_finetuned"
model, ft_history = train_model(
    model=model,
    dataloaders={
        'train': data['train_loader'],
        'val': data['val_loader']
    },
    criterion=criterion,
    optimizer=ft_optimizer,
    scheduler=ft_scheduler,
    num_epochs=ft_epochs,
    device=device,
    save_dir=CHECKPOINT_DIR,
    model_name=ft_model_name,
    early_stopping_patience=3  # Shorter patience for fine-tuning
)

## 13. Evaluating the Fine-tuned Model

Let's see if fine-tuning improved the model's performance.

In [None]:
# Visualize fine-tuning history
visualize_training_history(ft_history)

# Evaluate the fine-tuned model on the test set
ft_test_metrics = evaluate_model(model, data['test_loader'], device, class_names)

# Compare metrics before and after fine-tuning
print("\nComparison of metrics before and after fine-tuning:")
print(f"Initial model - Accuracy: {test_metrics['accuracy']:.4f}, F1 Score: {test_metrics['f1']:.4f}")
print(f"Fine-tuned model - Accuracy: {ft_test_metrics['accuracy']:.4f}, F1 Score: {ft_test_metrics['f1']:.4f}")

## 14. Saving the Fine-tuned Model

Let's save our fine-tuned model for future use.

In [None]:
# Save the fine-tuned model
save_path = os.path.join(CHECKPOINT_DIR, f"{ft_model_name}_final.pth")
torch.save({
    'model_state_dict': model.state_dict(),
    'class_names': class_names,
    'num_classes': num_classes,
    'img_size': IMG_SIZE,
    'metrics': ft_test_metrics,
    'model_type': 'vit'
}, save_path)
print(f"Fine-tuned model saved to {save_path}")

## 15. Inference on Single Image

Let's use our fine-tuned model to classify individual images and visualize the results.
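
The step from raw model outputs to ranked predictions is just a softmax followed by a sort. Here is a framework-free sketch of that step; `top_k_predictions` is a name we introduce for illustration, and the logits and class names are made up. Note that `k` is capped at the number of classes, which matters for datasets with fewer than three document types.

```python
import math

def top_k_predictions(logits, class_names, k=3):
    """Convert logits to probabilities and return the k most likely classes."""
    k = min(k, len(class_names))  # never ask for more classes than exist
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(zip(class_names, probs), key=lambda item: item[1], reverse=True)
    return ranked[:k]

# Made-up logits for 3 hypothetical classes
example_logits = [2.0, 0.5, -1.0]
example_names = ["passport", "id_card", "driver_license"]
for name, prob in top_k_predictions(example_logits, example_names):
    print(f"{name}: {prob:.2%}")
```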

In [None]:
def predict_single_image(model, image_path, transform, class_names, device):
    """
    Predict class for a single image and visualize the result.
    
    Args:
        model: Trained PyTorch model
        image_path: Path to the image file
        transform: Transformation pipeline for inference
        class_names: List of class names
        device: Device to run inference on
    """
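    # Load and preprocess the image
    # (cv2.imread expects a str path and returns None if the file cannot be read)
    image = cv2.imread(str(image_path))
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)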
    
    # Display original image
    plt.figure(figsize=(8, 8))
    plt.imshow(image)
    plt.title("Input Image")
    plt.axis('off')
    plt.show()
    
    # Apply transformations
    transformed = transform(image=image)
    image_tensor = transformed['image']
    image_tensor = image_tensor.unsqueeze(0).to(device)  # Add batch dimension
    
    # Set model to evaluation mode
    model.eval()
    
    # Make prediction
    top_k = min(3, len(class_names))  # torch.topk fails if k exceeds the number of classes
    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        values, indices = torch.topk(probabilities, top_k)
    
    # Display prediction results (squeeze only the batch dimension so top_k=1 still yields arrays)
    values = values.squeeze(0).cpu().numpy() * 100  # Convert to percentage
    indices = indices.squeeze(0).cpu().numpy()
    
    print(f"Top {top_k} predictions:")
    for i in range(top_k):
        print(f"{class_names[indices[i]]}: {values[i]:.2f}%")
    
    # Create a horizontal bar chart for visualization
    plt.figure(figsize=(10, 5))
    plt.barh(y=[class_names[idx] for idx in indices[:3]], width=values[:3])
    plt.xlabel('Confidence (%)')
    plt.title('Top 3 Predictions')
    plt.xlim(0, 100)
    plt.gca().invert_yaxis()  # Highest confidence at the top
    plt.show()

# Get test images for inference
test_dataset = data['test_dataset']
test_indices = np.random.choice(len(test_dataset), 3, replace=False)

# Use validation/test transform for inference (no augmentation)
inference_transform = get_transforms(img_size=IMG_SIZE)['test']

# Run inference on multiple test images
for idx in test_indices:
    test_image_path = test_dataset.image_paths[idx]
    test_label = test_dataset.labels[idx]
    print(f"\nTrue class: {class_names[test_label]}")
    predict_single_image(model, test_image_path, inference_transform, class_names, device)

## 16. Model Inference Performance

Let's measure the inference time of our Vision Transformer model to understand its performance in a production environment.
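
A single average can hide occasional slow runs, so it can be useful to look at the median and a high percentile as well. The sketch below is a generic, framework-free timing helper (`benchmark` is a hypothetical name); the workload is a stand-in for a model call. When timing a GPU model, the callable itself would need to wait for the device to finish (for example by calling `torch.cuda.synchronize()` at its end), otherwise only the kernel launch is measured.

```python
import math
import statistics
import time

def benchmark(fn, warmup=10, runs=100):
    """Time a zero-argument callable; return mean, median and p95 in milliseconds."""
    for _ in range(warmup):
        fn()  # warm-up calls are not timed

    timings_ms = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        timings_ms.append((time.perf_counter() - start) * 1000)

    timings_ms.sort()
    p95_index = min(len(timings_ms) - 1, math.ceil(0.95 * len(timings_ms)) - 1)
    return {
        "mean_ms": statistics.mean(timings_ms),
        "median_ms": statistics.median(timings_ms),
        "p95_ms": timings_ms[p95_index],
    }

# Stand-in workload instead of a real model forward pass
stats = benchmark(lambda: sum(range(10_000)), warmup=5, runs=50)
print({name: round(value, 4) for name, value in stats.items()})
```

For a production estimate, the p95 value is often more informative than the mean, since latency budgets are usually defined by the slow requests.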

In [None]:
import time

def measure_inference_time(model, dataloader, device, num_runs=100):
    """
    Measure the average inference time per image.
    
    Args:
        model: Trained PyTorch model
        dataloader: PyTorch dataloader containing test data
        device: Device to run inference on
        num_runs: Number of inference runs to average over
    
    Returns:
        float: Average inference time per image in milliseconds
    """
    model.eval()  # Set model to evaluation mode
    
    # Get a single batch for testing
    inputs, _ = next(iter(dataloader))
    inputs = inputs.to(device)
    
    # Warm-up runs (excluded from timing)
    with torch.no_grad():
        for _ in range(10):
            _ = model(inputs[0:1])
    
    # Wait for pending GPU work so warm-up is not counted in the measurement
    if device.type == 'cuda':
        torch.cuda.synchronize()
    start_time = time.perf_counter()  # Monotonic, high-resolution clock
    
    with torch.no_grad():
        for _ in range(num_runs):
            _ = model(inputs[0:1])  # Run inference on a single image
    
    # CUDA kernels run asynchronously; synchronize before stopping the clock
    if device.type == 'cuda':
        torch.cuda.synchronize()
    end_time = time.perf_counter()
    
    # Calculate average time in milliseconds
    avg_time = (end_time - start_time) * 1000 / num_runs
    
    return avg_time

# Measure inference time
avg_inference_time = measure_inference_time(model, data['test_loader'], device)
print(f"Average inference time per image: {avg_inference_time:.2f} ms")
print(f"Throughput at batch size 1: {1000 / avg_inference_time:.2f} images/s")

## 17. Comparing Models

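Let's compare the Vision Transformer with the CNN models we've previously trained. The CNN rows below are placeholder values; replace them with your measured results before drawing any conclusions.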

In [None]:
# Define comparison metrics (hypothetical values - in practice, you would load actual results)
models_comparison = [
    {
        'Model': 'Custom CNN',
        'Accuracy': 0.92,  # Replace with actual value
        'F1 Score': 0.91,  # Replace with actual value
        'Parameters (M)': 5.2,  # Replace with actual value
        'Inference Time (ms)': 8.5  # Replace with actual value
    },
    {
        'Model': 'EfficientNet-B0',
        'Accuracy': 0.95,  # Replace with actual value
        'F1 Score': 0.94,  # Replace with actual value
        'Parameters (M)': 5.3,  # Replace with actual value
        'Inference Time (ms)': 12.3  # Replace with actual value
    },
    {
        'Model': 'ResNet50',
        'Accuracy': 0.94,  # Replace with actual value
        'F1 Score': 0.93,  # Replace with actual value
        'Parameters (M)': 25.6,  # Replace with actual value
        'Inference Time (ms)': 15.8  # Replace with actual value
    },
    {
        'Model': 'Vision Transformer',
        'Accuracy': ft_test_metrics['accuracy'],
        'F1 Score': ft_test_metrics['f1'],
        'Parameters (M)': total_params / 1e6,
        'Inference Time (ms)': avg_inference_time
    }
]

# Create a DataFrame for comparison
import pandas as pd
comparison_df = pd.DataFrame(models_comparison)
comparison_df

In [None]:
# Visualize comparison
metrics = ['Accuracy', 'F1 Score']
models = comparison_df['Model']

fig, ax = plt.subplots(figsize=(12, 6))

bar_width = 0.35
x = np.arange(len(models))

ax.bar(x - bar_width/2, comparison_df['Accuracy'], bar_width, label='Accuracy')
ax.bar(x + bar_width/2, comparison_df['F1 Score'], bar_width, label='F1 Score')

ax.set_ylim(0.85, 1.0)  # Adjust according to your actual values
ax.set_ylabel('Score')
ax.set_title('Model Performance Comparison')
ax.set_xticks(x)
ax.set_xticklabels(models)
ax.legend()

plt.tight_layout()
plt.show()

# Inference time comparison
plt.figure(figsize=(10, 6))
plt.barh(comparison_df['Model'], comparison_df['Inference Time (ms)'])
plt.xlabel('Inference Time (ms)')
plt.title('Inference Time Comparison')
plt.tight_layout()
plt.show()

# Model size comparison
plt.figure(figsize=(10, 6))
plt.barh(comparison_df['Model'], comparison_df['Parameters (M)'])
plt.xlabel('Model Size (M parameters)')
plt.title('Model Size Comparison')
plt.tight_layout()
plt.show()

## 18. Conclusion

In this notebook, we've explored the use of Vision Transformer (ViT) for identity document classification:

1. We examined how ViT processes images as sequences of patches rather than using convolutional operations
2. We created and trained a Vision Transformer model using transfer learning
3. We evaluated its performance and fine-tuned it for better results
4. We compared it with other model architectures (CNNs) we've previously explored

Key observations about Vision Transformer:

**Advantages:**
- Capable of learning global dependencies between image regions through its self-attention mechanism
- Can capture long-range interactions that CNNs might miss
- Can deliver strong document classification performance after fine-tuning pretrained weights

**Challenges:**
- Higher computational cost and memory usage compared to CNNs
- Generally requires more data or stronger regularization to prevent overfitting
- Slower inference, which could impact real-time applications

In our document classification task, whether Vision Transformer is competitive with the CNN architectures can only be judged once the placeholder CNN values in the comparison table are replaced with measured results. The ability to model global relationships between different parts of ID documents (text fields, photos, security features) is what makes ViT a promising fit for this domain.

For production deployment, the choice between ViT and CNN architectures would depend on specific requirements:
- If accuracy is the top priority and computational resources are available, a fine-tuned ViT could be the best choice
- If inference speed is critical, EfficientNet or a custom CNN might provide a better trade-off
- If model size is a constraint (e.g., for edge devices), EfficientNet would likely be preferable

Future improvements could include:
1. Hybrid approaches combining CNN features with transformer attention mechanisms
2. Ensembling multiple models for higher accuracy
3. Using domain-specific pretraining on document images
4. Exploring smaller, more efficient transformer variants like DeiT or Swin Transformer

Overall, this exploration of different architectures has provided a comprehensive evaluation of the strengths and weaknesses of various approaches to ID document classification, giving us a strong foundation for building a robust production system.