In [27]:
! pip3 install certifi


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip3 install --upgrade pip[0m


In [55]:
from torchvision import datasets,transforms
from torch.utils.data import DataLoader, random_split

In [56]:
from PIL import Image, UnidentifiedImageError

In [57]:
def safe_loader(path):
    try:
        # Try to open with PIL
        return Image.open(path).convert("RGB")
    except UnidentifiedImageError:
        print(f"Skipping bad file: {path}")
        # Return a dummy image (black 224x224) so DataLoader doesn’t crash
        return Image.new("RGB", (224, 224))

In [58]:
# Define transforms (resize, tensor conversion, normalization)
transform = transforms.Compose([
    transforms.Resize((224, 224)),   # resize to fit CNN input
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])  # grayscale normalization
])

In [59]:
dataset = datasets.ImageFolder(root="data/rvl_cdip", loader=safe_loader,transform=transform)

In [60]:
print(len(dataset))          # total number of images
print(dataset.classes)       # list of class names
print(dataset.class_to_idx)  # mapping class → label

39997
['advertisement', 'budget', 'email', 'file_folder', 'form', 'handwritten', 'invoice', 'letter', 'memo', 'news_article', 'presentation', 'questionnaire', 'resume', 'scientific_publication', 'scientific_report', 'specification']
{'advertisement': 0, 'budget': 1, 'email': 2, 'file_folder': 3, 'form': 4, 'handwritten': 5, 'invoice': 6, 'letter': 7, 'memo': 8, 'news_article': 9, 'presentation': 10, 'questionnaire': 11, 'resume': 12, 'scientific_publication': 13, 'scientific_report': 14, 'specification': 15}


In [61]:
train_size = int(0.7 * len(dataset))   # 70% for training
test_size = int(0.2 * len(dataset))  # 20% for testing
val_size = len(dataset)-train_size-test_size # 10% for validation

In [62]:
train_dataset, test_dataset, val_dataset = random_split(dataset, [train_size, test_size, val_size])

In [71]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
val_loader   = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)
test_loader   = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)

In [48]:
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

In [64]:
import torchvision.models as models


model = models.resnet18(weights=None)  # no auto-download
state_dict = torch.load("resnet18-5c106cde.pth",weights_only=False)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [65]:
# Replace the final layer
import torch.nn as nn

num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 16)

In [66]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Move to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else 
                      "mps" if torch.backends.mps.is_available() else 
                      "cpu")

print("Using device:", device)
model.to(device)

Using device: mps


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [69]:
import os

# Make sure the checkpoints directory exists
os.makedirs("checkpoints", exist_ok=True)

best_acc = 0.0
for epoch in range(5):  # example: 5 epochs
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    acc = 100 * correct / total
    print(f"Validation Accuracy: {100 * correct / total:.2f}%")

    # --- Checkpointing ---
    checkpoint = {
        "epoch": epoch + 1,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "accuracy": acc,
        "loss": loss.item()
    }
    torch.save(checkpoint, f"checkpoints/epoch_{epoch+1}.pth")
    
    # Save best model separately
    if acc > best_acc:
        best_acc = acc
        torch.save(model.state_dict(), "checkpoints/best_model.pth")
        print("Best model updated and saved!")

Epoch 1, Loss: 0.2349
Validation Accuracy: 82.18%
Best model updated and saved!
Epoch 2, Loss: 0.1760
Validation Accuracy: 81.20%
Epoch 3, Loss: 0.1730
Validation Accuracy: 81.40%
Epoch 4, Loss: 0.2619
Validation Accuracy: 80.88%
Epoch 5, Loss: 0.1375
Validation Accuracy: 81.63%


In [73]:
model.eval()
with torch.no_grad():
    for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

print(f"Test Accuracy: {100 * correct / total:.2f}%")

Skipping bad file: data/rvl_cdip/scientific_publication/2500126531_2500126536.tif
Test Accuracy: 81.32%


In [75]:
# Step 7 
# Inference
# Load best checkpoint
state_dict = torch.load("checkpoints/best_model.pth", map_location=device)
model.load_state_dict(state_dict)

# Perform inference
model.eval()
test_data = torch.randn(10, 3,32,32)  # 10 new samples

with torch.no_grad():
    for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

print(f"Test Accuracy: {100 * correct / total:.2f}%")

Skipping bad file: data/rvl_cdip/scientific_publication/2500126531_2500126536.tif
Test Accuracy: 81.32%


In [None]:
# Step 8
# Resume the checkpoint
checkpoint = torch.load('checkpoints/best_model.pt')
model.load_state_dict(checkpoint['model_state_dict'])
optimiser.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1
# Continue training
for epoch in range(start_epoch, num_epochs + 5):
    # training
    model.train()
    train_loss = 0
    for batch_data,batch_target in train_dataloader:
        output = model(batch_data)
        loss = criterion(output,batch_target) # Compute loss
        optimiser.zero_grad() # Zero gradients from previous iteration
        loss.backward() # calcluate new gradients via back propagation
        optimiser.step() # updates gradients
        train_loss+=loss.item()
        
    train_loss/=len(train_dataloader)      
    # validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch_data,batch_target in val_dataloader:
            output = model(batch_data)
            loss = criterion(output,batch_target)
            val_loss+=loss.item()
        val_loss/=len(val_dataloader)
    print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")

    # Save best checkpoint
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimiser.state_dict(),
            'loss': val_loss,
        }
        torch.save(checkpoint, 'checkpoints/best_model.pt')
        print(f"  ✓ Best model saved!")