# Visual Transformer with Linformer

Training Visual Transformer on *Dogs vs Cats Data*

* Dogs vs. Cats Redux: Kernels Edition - https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition
* Base Code - https://www.kaggle.com/reukki/pytorch-cnn-tutorial-with-cats-and-dogs/
* Efficient Attention Implementation - https://github.com/lucidrains/vit-pytorch#efficient-attention

In [2]:
!pip -q install vit_pytorch linformer

## Import Libraries

In [4]:
from __future__ import print_function

import glob
from itertools import chain
import os
import random
import zipfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import seaborn as sns
import time

from sklearn.metrics import accuracy_score,precision_score,recall_score,f1_score,cohen_kappa_score

from linformer import Linformer
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

from vit_pytorch.efficient import ViT


In [5]:
print(f"Torch: {torch.__version__}")

Torch: 2.7.1+cu118


In [6]:
import torch
print("Torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())


Torch: 2.7.1+cu118
CUDA available: True


In [7]:
# Training settings
batch_size = 64
epochs = 200
lr = 3e-4
gamma = 0.7
seed = 42

In [8]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(seed)

In [9]:
device = 'cuda'

## Load Data

In [11]:
# os.makedirs('data', exist_ok=True)
from pathlib import Path
parent_dir = Path.cwd().parent
root_path = os.path.join(parent_dir, 'dataset') 

In [12]:
train_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

val_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
)


test_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
)

# 使用 ImageFolder 加载数据
train_data = datasets.ImageFolder(root=os.path.join(root_path, 'train'), transform=train_transforms)
valid_data   = datasets.ImageFolder(root=os.path.join(root_path, 'val'), transform=val_transforms)
test_data  = datasets.ImageFolder(root=os.path.join(root_path, 'test'), transform=val_transforms)

# DataLoader
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
valid_loader   = DataLoader(valid_data, batch_size=64, shuffle=False)
test_loader  = DataLoader(test_data, batch_size=64, shuffle=False)

In [13]:
print(f"Train: {len(train_data)} images, {len(train_loader)} batches")
print(f"Val:   {len(valid_data)} images, {len(valid_loader)} batches")
print(f"Test:  {len(test_data)} images, {len(test_loader)} batches")
print(f"Classes: {train_data.classes}")

Train: 5041 images, 79 batches
Val:   1263 images, 20 batches
Test:  359 images, 6 batches
Classes: ['Beach', 'Bridge', 'Pond', 'Port', 'River']


In [14]:
print(len(train_data), len(train_loader))

5041 79


In [15]:
print(len(valid_data), len(valid_loader))

1263 20


## Efficient Attention

### Linformer

In [18]:
efficient_transformer = Linformer(
    dim=128,
    seq_len=49+1,  # 7x7 patches + 1 cls-token
    depth=12,
    heads=8,
    k=64
)

### Visual Transformer

In [20]:
model = ViT(
    dim=128,
    image_size=224,
    patch_size=32,
    num_classes=5,
    transformer=efficient_transformer,
    channels=3,
).to(device)

### Training

In [22]:
# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

In [23]:
checkpoint_dir = os.path.join(parent_dir, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)
best_val_acc = 0

In [24]:
history = {
    "train_loss":[],
    "val_loss":[],
    "train_acc":[],
    "val_acc":[]
}

In [25]:
def load_checkpoint(checkpoint_path, model, optimizer, scheduler):
    """加载检查点函数"""
    if os.path.exists(checkpoint_path):
        print(f"Loading checkpoint from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_val_acc = checkpoint['best_val_acc']
        print(f"Checkpoint loaded. Resuming from epoch {start_epoch}")
        return start_epoch, best_val_acc
    else:
        return 0, 0

In [26]:
for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)

    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)

    # 在这里添加检查点保存逻辑
    # 1. 保存最佳模型
    if epoch_val_accuracy > best_val_acc:
        best_val_acc = epoch_val_accuracy
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_acc': best_val_acc,
            'loss': epoch_val_loss,
        }
        torch.save(checkpoint, os.path.join(checkpoint_dir, 'best_model.pth'))
        print(f"Saved best model with validation accuracy: {best_val_acc:.4f}")
    
    # 2. 定期保存检查点（每10个epoch）
    if (epoch + 1) % 10 == 0:
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_acc': best_val_acc,
            'loss': epoch_val_loss,
        }
        torch.save(checkpoint, os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth'))
        print(f"Saved regular checkpoint at epoch {epoch+1}")

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

    history["train_loss"].append(epoch_loss.detach().cpu().item())
    history["val_loss"].append(epoch_val_loss.detach().cpu().item())
    history["train_acc"].append(epoch_accuracy.detach().cpu().item() * 100)
    history["val_acc"].append(epoch_val_accuracy.detach().cpu().item() * 100)


  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.5590
Epoch : 1 - loss : 1.4517 - acc: 0.3568 - val_loss : 1.1611 - val_acc: 0.5590



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.6407
Epoch : 2 - loss : 1.1147 - acc: 0.5517 - val_loss : 0.9043 - val_acc: 0.6407



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.7037
Epoch : 3 - loss : 1.0093 - acc: 0.6000 - val_loss : 0.8082 - val_acc: 0.7037



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.7487
Epoch : 4 - loss : 0.9393 - acc: 0.6373 - val_loss : 0.6920 - val_acc: 0.7487



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.7547
Epoch : 5 - loss : 0.8400 - acc: 0.6828 - val_loss : 0.6591 - val_acc: 0.7547



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.8085
Epoch : 6 - loss : 0.8345 - acc: 0.6885 - val_loss : 0.5707 - val_acc: 0.8085



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.8232
Epoch : 7 - loss : 0.7527 - acc: 0.7210 - val_loss : 0.4910 - val_acc: 0.8232



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.8463
Epoch : 8 - loss : 0.7127 - acc: 0.7348 - val_loss : 0.4684 - val_acc: 0.8463



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.8484
Epoch : 9 - loss : 0.6921 - acc: 0.7437 - val_loss : 0.4709 - val_acc: 0.8484



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.8544
Saved regular checkpoint at epoch 10
Epoch : 10 - loss : 0.6734 - acc: 0.7532 - val_loss : 0.4242 - val_acc: 0.8544



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.8658
Epoch : 11 - loss : 0.6436 - acc: 0.7710 - val_loss : 0.4186 - val_acc: 0.8658



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 12 - loss : 0.6201 - acc: 0.7712 - val_loss : 0.4408 - val_acc: 0.8541



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 13 - loss : 0.6023 - acc: 0.7765 - val_loss : 0.4276 - val_acc: 0.8552



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.8755
Epoch : 14 - loss : 0.6056 - acc: 0.7772 - val_loss : 0.3999 - val_acc: 0.8755



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.8815
Epoch : 15 - loss : 0.5716 - acc: 0.7961 - val_loss : 0.3668 - val_acc: 0.8815



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 16 - loss : 0.5585 - acc: 0.8010 - val_loss : 0.3922 - val_acc: 0.8682



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.8864
Epoch : 17 - loss : 0.5300 - acc: 0.8104 - val_loss : 0.3524 - val_acc: 0.8864



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.8935
Epoch : 18 - loss : 0.5541 - acc: 0.8021 - val_loss : 0.3272 - val_acc: 0.8935



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.8994
Epoch : 19 - loss : 0.5187 - acc: 0.8148 - val_loss : 0.3147 - val_acc: 0.8994



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.9008
Saved regular checkpoint at epoch 20
Epoch : 20 - loss : 0.5248 - acc: 0.8086 - val_loss : 0.3100 - val_acc: 0.9008



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.9028
Epoch : 21 - loss : 0.5203 - acc: 0.8165 - val_loss : 0.3069 - val_acc: 0.9028



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 22 - loss : 0.4910 - acc: 0.8233 - val_loss : 0.3363 - val_acc: 0.8908



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 23 - loss : 0.4879 - acc: 0.8285 - val_loss : 0.2985 - val_acc: 0.9026



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 24 - loss : 0.5012 - acc: 0.8197 - val_loss : 0.3012 - val_acc: 0.8974



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.9060
Epoch : 25 - loss : 0.4977 - acc: 0.8269 - val_loss : 0.2942 - val_acc: 0.9060



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 26 - loss : 0.4688 - acc: 0.8350 - val_loss : 0.2868 - val_acc: 0.9057



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.9060
Epoch : 27 - loss : 0.4960 - acc: 0.8266 - val_loss : 0.2882 - val_acc: 0.9060



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 28 - loss : 0.4548 - acc: 0.8452 - val_loss : 0.3017 - val_acc: 0.8987



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.9065
Epoch : 29 - loss : 0.4506 - acc: 0.8472 - val_loss : 0.2793 - val_acc: 0.9065



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.9156
Saved regular checkpoint at epoch 30
Epoch : 30 - loss : 0.4511 - acc: 0.8418 - val_loss : 0.2634 - val_acc: 0.9156



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 31 - loss : 0.4481 - acc: 0.8410 - val_loss : 0.2861 - val_acc: 0.9065



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 32 - loss : 0.4350 - acc: 0.8492 - val_loss : 0.2804 - val_acc: 0.9060



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.9166
Epoch : 33 - loss : 0.4384 - acc: 0.8508 - val_loss : 0.2541 - val_acc: 0.9166



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.9252
Epoch : 34 - loss : 0.4251 - acc: 0.8489 - val_loss : 0.2336 - val_acc: 0.9252



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 35 - loss : 0.4356 - acc: 0.8433 - val_loss : 0.2461 - val_acc: 0.9169



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.9315
Epoch : 36 - loss : 0.4359 - acc: 0.8491 - val_loss : 0.2309 - val_acc: 0.9315



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 37 - loss : 0.4050 - acc: 0.8606 - val_loss : 0.2411 - val_acc: 0.9213



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.9328
Epoch : 38 - loss : 0.4091 - acc: 0.8560 - val_loss : 0.2150 - val_acc: 0.9328



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 39 - loss : 0.3996 - acc: 0.8584 - val_loss : 0.2305 - val_acc: 0.9247



  0%|          | 0/79 [00:00<?, ?it/s]

Saved regular checkpoint at epoch 40
Epoch : 40 - loss : 0.4110 - acc: 0.8582 - val_loss : 0.2412 - val_acc: 0.9247



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 41 - loss : 0.3910 - acc: 0.8618 - val_loss : 0.2171 - val_acc: 0.9320



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.9388
Epoch : 42 - loss : 0.3826 - acc: 0.8624 - val_loss : 0.1931 - val_acc: 0.9388



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 43 - loss : 0.3954 - acc: 0.8568 - val_loss : 0.1966 - val_acc: 0.9385



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 44 - loss : 0.3693 - acc: 0.8728 - val_loss : 0.2091 - val_acc: 0.9349



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.9458
Epoch : 45 - loss : 0.3839 - acc: 0.8669 - val_loss : 0.1839 - val_acc: 0.9458



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 46 - loss : 0.3873 - acc: 0.8628 - val_loss : 0.2196 - val_acc: 0.9263



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 47 - loss : 0.3679 - acc: 0.8761 - val_loss : 0.2060 - val_acc: 0.9328



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 48 - loss : 0.3548 - acc: 0.8796 - val_loss : 0.1861 - val_acc: 0.9375



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 49 - loss : 0.3515 - acc: 0.8798 - val_loss : 0.1888 - val_acc: 0.9427



  0%|          | 0/79 [00:00<?, ?it/s]

Saved regular checkpoint at epoch 50
Epoch : 50 - loss : 0.3496 - acc: 0.8803 - val_loss : 0.1903 - val_acc: 0.9396



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 51 - loss : 0.3562 - acc: 0.8780 - val_loss : 0.1841 - val_acc: 0.9438



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 52 - loss : 0.3635 - acc: 0.8719 - val_loss : 0.1952 - val_acc: 0.9414



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 53 - loss : 0.3549 - acc: 0.8750 - val_loss : 0.1939 - val_acc: 0.9354



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 54 - loss : 0.3172 - acc: 0.8944 - val_loss : 0.1754 - val_acc: 0.9424



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 55 - loss : 0.3276 - acc: 0.8827 - val_loss : 0.1872 - val_acc: 0.9388



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 56 - loss : 0.3255 - acc: 0.8860 - val_loss : 0.1814 - val_acc: 0.9424



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 57 - loss : 0.3193 - acc: 0.8849 - val_loss : 0.2048 - val_acc: 0.9296



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 58 - loss : 0.3149 - acc: 0.8874 - val_loss : 0.1666 - val_acc: 0.9455



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 59 - loss : 0.3071 - acc: 0.8944 - val_loss : 0.1647 - val_acc: 0.9455



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.9567
Saved regular checkpoint at epoch 60
Epoch : 60 - loss : 0.3149 - acc: 0.8941 - val_loss : 0.1457 - val_acc: 0.9567



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 61 - loss : 0.3360 - acc: 0.8812 - val_loss : 0.1811 - val_acc: 0.9455



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.9599
Epoch : 62 - loss : 0.3141 - acc: 0.8899 - val_loss : 0.1454 - val_acc: 0.9599



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 63 - loss : 0.2984 - acc: 0.8992 - val_loss : 0.1759 - val_acc: 0.9364



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 64 - loss : 0.3012 - acc: 0.8960 - val_loss : 0.1713 - val_acc: 0.9452



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 65 - loss : 0.3083 - acc: 0.8938 - val_loss : 0.1546 - val_acc: 0.9497



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 66 - loss : 0.3121 - acc: 0.8926 - val_loss : 0.1325 - val_acc: 0.9583



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.9612
Epoch : 67 - loss : 0.3002 - acc: 0.9019 - val_loss : 0.1305 - val_acc: 0.9612



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 68 - loss : 0.2807 - acc: 0.9040 - val_loss : 0.1452 - val_acc: 0.9526



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 69 - loss : 0.2974 - acc: 0.8995 - val_loss : 0.1326 - val_acc: 0.9591



  0%|          | 0/79 [00:00<?, ?it/s]

Saved regular checkpoint at epoch 70
Epoch : 70 - loss : 0.2921 - acc: 0.8941 - val_loss : 0.1401 - val_acc: 0.9596



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.9692
Epoch : 71 - loss : 0.2801 - acc: 0.9057 - val_loss : 0.1185 - val_acc: 0.9692



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 72 - loss : 0.2831 - acc: 0.8983 - val_loss : 0.1201 - val_acc: 0.9658



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 73 - loss : 0.2778 - acc: 0.9005 - val_loss : 0.1336 - val_acc: 0.9583



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 74 - loss : 0.2835 - acc: 0.9044 - val_loss : 0.1770 - val_acc: 0.9460



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 75 - loss : 0.2829 - acc: 0.9054 - val_loss : 0.1503 - val_acc: 0.9552



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 76 - loss : 0.2737 - acc: 0.9060 - val_loss : 0.1170 - val_acc: 0.9635



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 77 - loss : 0.2709 - acc: 0.9085 - val_loss : 0.1408 - val_acc: 0.9580



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 78 - loss : 0.2778 - acc: 0.9071 - val_loss : 0.1635 - val_acc: 0.9450



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 79 - loss : 0.2687 - acc: 0.9104 - val_loss : 0.1150 - val_acc: 0.9677



  0%|          | 0/79 [00:00<?, ?it/s]

Saved regular checkpoint at epoch 80
Epoch : 80 - loss : 0.2542 - acc: 0.9140 - val_loss : 0.1258 - val_acc: 0.9630



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 81 - loss : 0.2403 - acc: 0.9197 - val_loss : 0.1105 - val_acc: 0.9666



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 82 - loss : 0.2599 - acc: 0.9114 - val_loss : 0.1755 - val_acc: 0.9479



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 83 - loss : 0.2507 - acc: 0.9185 - val_loss : 0.1283 - val_acc: 0.9591



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 84 - loss : 0.2518 - acc: 0.9122 - val_loss : 0.1279 - val_acc: 0.9640



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 85 - loss : 0.2530 - acc: 0.9134 - val_loss : 0.1216 - val_acc: 0.9632



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 86 - loss : 0.2358 - acc: 0.9181 - val_loss : 0.1216 - val_acc: 0.9578



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 87 - loss : 0.2494 - acc: 0.9160 - val_loss : 0.1221 - val_acc: 0.9635



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 88 - loss : 0.2351 - acc: 0.9180 - val_loss : 0.1112 - val_acc: 0.9627



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 89 - loss : 0.2596 - acc: 0.9113 - val_loss : 0.1145 - val_acc: 0.9651



  0%|          | 0/79 [00:00<?, ?it/s]

Saved regular checkpoint at epoch 90
Epoch : 90 - loss : 0.2369 - acc: 0.9224 - val_loss : 0.1279 - val_acc: 0.9651



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 91 - loss : 0.2482 - acc: 0.9133 - val_loss : 0.1139 - val_acc: 0.9638



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 92 - loss : 0.2613 - acc: 0.9173 - val_loss : 0.1192 - val_acc: 0.9607



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 93 - loss : 0.2325 - acc: 0.9205 - val_loss : 0.1247 - val_acc: 0.9625



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 94 - loss : 0.2467 - acc: 0.9197 - val_loss : 0.1112 - val_acc: 0.9674



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 95 - loss : 0.2406 - acc: 0.9157 - val_loss : 0.1024 - val_acc: 0.9677



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 96 - loss : 0.2143 - acc: 0.9256 - val_loss : 0.1335 - val_acc: 0.9594



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 97 - loss : 0.2283 - acc: 0.9255 - val_loss : 0.0946 - val_acc: 0.9627



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 98 - loss : 0.2486 - acc: 0.9139 - val_loss : 0.1100 - val_acc: 0.9669



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.9781
Epoch : 99 - loss : 0.2302 - acc: 0.9255 - val_loss : 0.0939 - val_acc: 0.9781



  0%|          | 0/79 [00:00<?, ?it/s]

Saved regular checkpoint at epoch 100
Epoch : 100 - loss : 0.2236 - acc: 0.9212 - val_loss : 0.1111 - val_acc: 0.9635



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 101 - loss : 0.2305 - acc: 0.9201 - val_loss : 0.1075 - val_acc: 0.9677



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 102 - loss : 0.2253 - acc: 0.9232 - val_loss : 0.1148 - val_acc: 0.9703



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 103 - loss : 0.2160 - acc: 0.9288 - val_loss : 0.1314 - val_acc: 0.9594



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 104 - loss : 0.2347 - acc: 0.9221 - val_loss : 0.0846 - val_acc: 0.9766



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 105 - loss : 0.2240 - acc: 0.9223 - val_loss : 0.1243 - val_acc: 0.9578



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 106 - loss : 0.2068 - acc: 0.9337 - val_loss : 0.1099 - val_acc: 0.9682



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 107 - loss : 0.2244 - acc: 0.9257 - val_loss : 0.0880 - val_acc: 0.9755



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 108 - loss : 0.2206 - acc: 0.9260 - val_loss : 0.1075 - val_acc: 0.9685



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 109 - loss : 0.2189 - acc: 0.9262 - val_loss : 0.1038 - val_acc: 0.9697



  0%|          | 0/79 [00:00<?, ?it/s]

Saved regular checkpoint at epoch 110
Epoch : 110 - loss : 0.2263 - acc: 0.9233 - val_loss : 0.0984 - val_acc: 0.9724



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 111 - loss : 0.2123 - acc: 0.9294 - val_loss : 0.1011 - val_acc: 0.9711



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 112 - loss : 0.1905 - acc: 0.9339 - val_loss : 0.0883 - val_acc: 0.9721



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 113 - loss : 0.2066 - acc: 0.9297 - val_loss : 0.0965 - val_acc: 0.9692



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 114 - loss : 0.2058 - acc: 0.9293 - val_loss : 0.1057 - val_acc: 0.9692



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 115 - loss : 0.2127 - acc: 0.9313 - val_loss : 0.0971 - val_acc: 0.9739



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 116 - loss : 0.2044 - acc: 0.9309 - val_loss : 0.0988 - val_acc: 0.9742



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.9820
Epoch : 117 - loss : 0.2051 - acc: 0.9258 - val_loss : 0.0816 - val_acc: 0.9820



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 118 - loss : 0.2013 - acc: 0.9276 - val_loss : 0.0846 - val_acc: 0.9739



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 119 - loss : 0.1859 - acc: 0.9408 - val_loss : 0.0976 - val_acc: 0.9716



  0%|          | 0/79 [00:00<?, ?it/s]

Saved regular checkpoint at epoch 120
Epoch : 120 - loss : 0.1837 - acc: 0.9370 - val_loss : 0.0941 - val_acc: 0.9692



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 121 - loss : 0.2105 - acc: 0.9264 - val_loss : 0.1063 - val_acc: 0.9688



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 122 - loss : 0.1967 - acc: 0.9367 - val_loss : 0.0831 - val_acc: 0.9734



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 123 - loss : 0.1877 - acc: 0.9380 - val_loss : 0.1002 - val_acc: 0.9700



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 124 - loss : 0.2005 - acc: 0.9341 - val_loss : 0.1093 - val_acc: 0.9664



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 125 - loss : 0.1881 - acc: 0.9385 - val_loss : 0.1171 - val_acc: 0.9688



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 126 - loss : 0.2047 - acc: 0.9326 - val_loss : 0.1055 - val_acc: 0.9727



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 127 - loss : 0.1726 - acc: 0.9416 - val_loss : 0.1097 - val_acc: 0.9700



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 128 - loss : 0.1951 - acc: 0.9360 - val_loss : 0.0838 - val_acc: 0.9750



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 129 - loss : 0.1816 - acc: 0.9411 - val_loss : 0.0936 - val_acc: 0.9766



  0%|          | 0/79 [00:00<?, ?it/s]

Saved regular checkpoint at epoch 130
Epoch : 130 - loss : 0.1725 - acc: 0.9403 - val_loss : 0.0906 - val_acc: 0.9766



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 131 - loss : 0.1787 - acc: 0.9400 - val_loss : 0.1015 - val_acc: 0.9742



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 132 - loss : 0.1890 - acc: 0.9376 - val_loss : 0.1223 - val_acc: 0.9680



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 133 - loss : 0.2017 - acc: 0.9326 - val_loss : 0.0897 - val_acc: 0.9742



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 134 - loss : 0.1739 - acc: 0.9406 - val_loss : 0.0861 - val_acc: 0.9789



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 135 - loss : 0.1805 - acc: 0.9399 - val_loss : 0.1015 - val_acc: 0.9688



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 136 - loss : 0.1701 - acc: 0.9422 - val_loss : 0.0864 - val_acc: 0.9766



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 137 - loss : 0.1849 - acc: 0.9398 - val_loss : 0.0924 - val_acc: 0.9734



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 138 - loss : 0.1764 - acc: 0.9427 - val_loss : 0.0948 - val_acc: 0.9747



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 139 - loss : 0.1756 - acc: 0.9415 - val_loss : 0.0920 - val_acc: 0.9742



  0%|          | 0/79 [00:00<?, ?it/s]

Saved regular checkpoint at epoch 140
Epoch : 140 - loss : 0.1747 - acc: 0.9424 - val_loss : 0.0827 - val_acc: 0.9781



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 141 - loss : 0.1854 - acc: 0.9337 - val_loss : 0.0967 - val_acc: 0.9755



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 142 - loss : 0.1894 - acc: 0.9348 - val_loss : 0.1035 - val_acc: 0.9758



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 143 - loss : 0.1802 - acc: 0.9377 - val_loss : 0.0955 - val_acc: 0.9747



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 144 - loss : 0.1680 - acc: 0.9430 - val_loss : 0.0927 - val_acc: 0.9758



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 145 - loss : 0.1712 - acc: 0.9449 - val_loss : 0.0852 - val_acc: 0.9758



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 146 - loss : 0.1606 - acc: 0.9484 - val_loss : 0.0799 - val_acc: 0.9773



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 147 - loss : 0.1778 - acc: 0.9409 - val_loss : 0.0970 - val_acc: 0.9719



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 148 - loss : 0.1761 - acc: 0.9435 - val_loss : 0.0788 - val_acc: 0.9802



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 149 - loss : 0.1618 - acc: 0.9450 - val_loss : 0.0795 - val_acc: 0.9742



  0%|          | 0/79 [00:00<?, ?it/s]

Saved regular checkpoint at epoch 150
Epoch : 150 - loss : 0.1643 - acc: 0.9408 - val_loss : 0.0834 - val_acc: 0.9797



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 151 - loss : 0.1663 - acc: 0.9421 - val_loss : 0.0731 - val_acc: 0.9773



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 152 - loss : 0.1650 - acc: 0.9440 - val_loss : 0.0870 - val_acc: 0.9750



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 153 - loss : 0.1661 - acc: 0.9435 - val_loss : 0.0954 - val_acc: 0.9766



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 154 - loss : 0.1578 - acc: 0.9485 - val_loss : 0.0722 - val_acc: 0.9797



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 155 - loss : 0.1591 - acc: 0.9443 - val_loss : 0.0960 - val_acc: 0.9708



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 156 - loss : 0.1587 - acc: 0.9480 - val_loss : 0.0738 - val_acc: 0.9797



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 157 - loss : 0.1628 - acc: 0.9454 - val_loss : 0.0921 - val_acc: 0.9742



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 158 - loss : 0.1626 - acc: 0.9470 - val_loss : 0.0853 - val_acc: 0.9789



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 159 - loss : 0.1790 - acc: 0.9415 - val_loss : 0.1000 - val_acc: 0.9700



  0%|          | 0/79 [00:00<?, ?it/s]

Saved regular checkpoint at epoch 160
Epoch : 160 - loss : 0.1496 - acc: 0.9480 - val_loss : 0.0805 - val_acc: 0.9786



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 161 - loss : 0.1596 - acc: 0.9464 - val_loss : 0.0796 - val_acc: 0.9789



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 162 - loss : 0.1454 - acc: 0.9472 - val_loss : 0.0937 - val_acc: 0.9781



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 163 - loss : 0.1593 - acc: 0.9493 - val_loss : 0.0739 - val_acc: 0.9820



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 164 - loss : 0.1532 - acc: 0.9500 - val_loss : 0.0754 - val_acc: 0.9781



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 165 - loss : 0.1644 - acc: 0.9450 - val_loss : 0.0796 - val_acc: 0.9750



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 166 - loss : 0.1650 - acc: 0.9460 - val_loss : 0.0868 - val_acc: 0.9739



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 167 - loss : 0.1385 - acc: 0.9533 - val_loss : 0.0847 - val_acc: 0.9750



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 168 - loss : 0.1607 - acc: 0.9468 - val_loss : 0.0698 - val_acc: 0.9794



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 169 - loss : 0.1596 - acc: 0.9488 - val_loss : 0.0838 - val_acc: 0.9766



  0%|          | 0/79 [00:00<?, ?it/s]

Saved regular checkpoint at epoch 170
Epoch : 170 - loss : 0.1636 - acc: 0.9455 - val_loss : 0.0806 - val_acc: 0.9747



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 171 - loss : 0.1595 - acc: 0.9464 - val_loss : 0.0810 - val_acc: 0.9763



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.9844
Epoch : 172 - loss : 0.1439 - acc: 0.9521 - val_loss : 0.0669 - val_acc: 0.9844



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 173 - loss : 0.1486 - acc: 0.9496 - val_loss : 0.0557 - val_acc: 0.9828



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 174 - loss : 0.1467 - acc: 0.9524 - val_loss : 0.0759 - val_acc: 0.9813



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 175 - loss : 0.1344 - acc: 0.9548 - val_loss : 0.0751 - val_acc: 0.9813



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 176 - loss : 0.1501 - acc: 0.9515 - val_loss : 0.0687 - val_acc: 0.9836



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 177 - loss : 0.1527 - acc: 0.9506 - val_loss : 0.0685 - val_acc: 0.9813



  0%|          | 0/79 [00:00<?, ?it/s]

Saved best model with validation accuracy: 0.9859
Epoch : 178 - loss : 0.1520 - acc: 0.9504 - val_loss : 0.0663 - val_acc: 0.9859



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 179 - loss : 0.1376 - acc: 0.9564 - val_loss : 0.0783 - val_acc: 0.9766



  0%|          | 0/79 [00:00<?, ?it/s]

Saved regular checkpoint at epoch 180
Epoch : 180 - loss : 0.1548 - acc: 0.9507 - val_loss : 0.0796 - val_acc: 0.9750



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 181 - loss : 0.1331 - acc: 0.9560 - val_loss : 0.0710 - val_acc: 0.9797



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 182 - loss : 0.1380 - acc: 0.9568 - val_loss : 0.0643 - val_acc: 0.9820



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 183 - loss : 0.1448 - acc: 0.9501 - val_loss : 0.0752 - val_acc: 0.9773



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 184 - loss : 0.1334 - acc: 0.9548 - val_loss : 0.0851 - val_acc: 0.9797



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 185 - loss : 0.1503 - acc: 0.9512 - val_loss : 0.0947 - val_acc: 0.9732



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 186 - loss : 0.1471 - acc: 0.9527 - val_loss : 0.0749 - val_acc: 0.9817



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 187 - loss : 0.1247 - acc: 0.9593 - val_loss : 0.0726 - val_acc: 0.9813



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 188 - loss : 0.1381 - acc: 0.9576 - val_loss : 0.0726 - val_acc: 0.9810



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 189 - loss : 0.1362 - acc: 0.9559 - val_loss : 0.0839 - val_acc: 0.9750



  0%|          | 0/79 [00:00<?, ?it/s]

Saved regular checkpoint at epoch 190
Epoch : 190 - loss : 0.1380 - acc: 0.9551 - val_loss : 0.0803 - val_acc: 0.9825



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 191 - loss : 0.1399 - acc: 0.9566 - val_loss : 0.0799 - val_acc: 0.9794



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 192 - loss : 0.1313 - acc: 0.9555 - val_loss : 0.0768 - val_acc: 0.9789



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 193 - loss : 0.1421 - acc: 0.9546 - val_loss : 0.0672 - val_acc: 0.9836



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 194 - loss : 0.1355 - acc: 0.9543 - val_loss : 0.0713 - val_acc: 0.9789



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 195 - loss : 0.1158 - acc: 0.9594 - val_loss : 0.0757 - val_acc: 0.9742



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 196 - loss : 0.1340 - acc: 0.9551 - val_loss : 0.0941 - val_acc: 0.9727



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 197 - loss : 0.1388 - acc: 0.9548 - val_loss : 0.0711 - val_acc: 0.9778



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 198 - loss : 0.1349 - acc: 0.9541 - val_loss : 0.0899 - val_acc: 0.9758



  0%|          | 0/79 [00:00<?, ?it/s]

Epoch : 199 - loss : 0.1420 - acc: 0.9512 - val_loss : 0.0849 - val_acc: 0.9771



  0%|          | 0/79 [00:00<?, ?it/s]

Saved regular checkpoint at epoch 200
Epoch : 200 - loss : 0.1365 - acc: 0.9561 - val_loss : 0.0866 - val_acc: 0.9734



In [27]:
sns.set_theme(style="whitegrid")

def plot_curves(hist, out="curves.png"):
    train_loss = [
        v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
        for v in hist["train_loss"]
    ]
    val_loss = [
        v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
        for v in hist["val_loss"]
    ]
    train_acc = [
        v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
        for v in hist["train_acc"]
    ]
    val_acc = [
        v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
        for v in hist["val_acc"]
    ]

    epochs = range(1, len(train_loss) + 1)

  
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

    ax1.plot(epochs, train_loss, label="Train Loss")
    ax1.plot(epochs, val_loss,   label="Val Loss")
    ax1.set_title("Training & Validation Loss")
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Loss")
    ax1.legend()
    ax1.grid(alpha=0.3)

    ax2.plot(epochs, train_acc, label="Train Acc@1")
    ax2.plot(epochs, val_acc,   label="Val Acc@1")
    ax2.set_title("Training & Validation Accuracy")
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("Accuracy (%)")
    ax2.legend()
    ax2.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(out, dpi=300)
    plt.close()
    print(f"Curves saved to {out}")


# plot_curves(history)
curves_path = os.path.join(checkpoint_dir, "training_curves.png")
plot_curves(history, curves_path)


Curves saved to G:\remotesensing\checkpoints\training_curves.png


In [28]:
def evaluate():
    y_true, y_pred = [], []
    num_classes = len(train_data.classes)
    class_map = {str(i): name for i, name in enumerate(train_data.classes)}
    total_per_class   = [0] * num_classes
    correct_per_class = [0] * num_classes

    model.eval()
    start = time.time()
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs  = imgs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            logits = model(imgs)
            preds  = logits.argmax(dim=1)

            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())

            for t, p in zip(labels, preds):
                total_per_class[t]   += 1
                correct_per_class[t] += int(t == p)

    elapsed = time.time() - start
    y_true  = np.array(y_true)
    y_pred  = np.array(y_pred)

    oa    = accuracy_score(y_true, y_pred)
    class_accs = [
        np.mean(y_pred[y_true == i] == y_true[y_true == i])
        for i in range(num_classes) if np.sum(y_true == i) > 0
    ]
    macc  = np.mean(class_accs)
    kappa = cohen_kappa_score(y_true, y_pred)
    prec  = precision_score(y_true, y_pred, average='macro', zero_division=0)
    rec   = recall_score(y_true, y_pred, average='macro', zero_division=0)
    f1    = f1_score(y_true, y_pred, average='macro', zero_division=0)

    print(f"Test completed. Samples: {len(y_true)}  Time: {elapsed:.2f}s")
    print(f"Overall Accuracy (OA) : {oa:.4f}")
    print(f"Mean Accuracy (mAcc)  : {macc:.4f}")
    print(f"Cohen-Kappa           : {kappa:.4f}")
    print(f"Precision (macro)     : {prec:.4f}")
    print(f"Recall    (macro)     : {rec:.4f}")
    print(f"F1-score  (macro)     : {f1:.4f}")

    print("Per-class accuracy:")
    for idx in range(num_classes):
        total = total_per_class[idx]
        correct = correct_per_class[idx]
        acc_cls = 100.0 * correct / total if total else 0.0
        cls_name = class_map.get(str(idx), f'class_{idx}')
        print(f" [{idx:02d}] {cls_name:<20s}: {acc_cls:6.2f}%  ({correct}/{total})")

evaluate()

Test completed. Samples: 359  Time: 9.42s
Overall Accuracy (OA) : 0.9331
Mean Accuracy (mAcc)  : 0.9333
Cohen-Kappa           : 0.9164
Precision (macro)     : 0.9341
Recall    (macro)     : 0.9333
F1-score  (macro)     : 0.9331
Per-class accuracy:
 [00] Beach               :  95.83%  (69/72)
 [01] Bridge              :  92.42%  (61/66)
 [02] Pond                :  88.31%  (68/77)
 [03] Port                :  92.75%  (64/69)
 [04] River               :  97.33%  (73/75)
