In [1]:
import os
import time
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from timm import create_model
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import pandas as pd
import copy

In [2]:
# Preprocessing shared by both splits: resize to 224x224, convert to a tensor,
# then normalize each channel with the standard ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Dataset root. Defaults to the original local path, but can be redirected with
# the FER13_DATA_DIR environment variable so the notebook runs on other machines
# without editing this cell.
data_root = os.environ.get('FER13_DATA_DIR', 'C:/Users/liamcee/Documents/farbruh/fer13')

train_dataset = datasets.ImageFolder(root=os.path.join(data_root, 'train'), transform=transform)
test_dataset = datasets.ImageFolder(root=os.path.join(data_root, 'test'), transform=transform)

# ImageFolder derives label indices from each split's own sub-folder names, so
# a missing or extra class folder in one split would silently misalign labels.
if train_dataset.classes != test_dataset.classes:
    raise ValueError(
        f"train/test class folders differ: {train_dataset.classes} vs {test_dataset.classes}"
    )

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4, pin_memory=True)

num_classes = len(train_dataset.classes)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device, num_classes)

cuda 7


In [3]:
# Build a pretrained ConvNeXt-Tiny from timm with a classification head sized
# to the dataset's class count, then move it to the selected device.
model = create_model('convnext_tiny', pretrained=True, num_classes=num_classes).to(device)

In [4]:
# AdamW over all model parameters; the step schedule multiplies the learning
# rate by 0.1 after every 4 calls to scheduler.step().
optimizer = optim.AdamW(params=model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = StepLR(optimizer=optimizer, step_size=4, gamma=0.1)

# Classification loss used for both training and evaluation.
criterion = nn.CrossEntropyLoss()

In [5]:
# Per-epoch history: train() appends to the train_* lists and evaluate()
# appends to the test_* lists, one entry per call.
train_losses, train_accuracies = [], []
test_losses, test_accuracies = [], []

# Snapshot of the weights with the highest evaluation accuracy (in percent)
# seen so far; evaluate() replaces both values whenever accuracy improves.
best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0

def train(model, loader, optimizer, criterion):
    """Run one training epoch over `loader` and record its metrics.

    Puts `model` in training mode and performs one optimizer step per batch.
    Afterwards the epoch's mean loss and accuracy (in percent) are appended to
    the module-level `train_losses` / `train_accuracies` lists and printed.
    Batches are moved to the module-level `device` before the forward pass.

    Args:
        model: Network being optimized; called as `model(images)`.
        loader: Iterable of `(images, labels)` batches.
        optimizer: Optimizer that updates `model`'s parameters.
        criterion: Loss callable invoked as `criterion(outputs, labels)`.
            Assumed to return the batch-mean loss (the default reduction of
            `nn.CrossEntropyLoss`) so the epoch loss can be weighted per sample.

    Raises:
        ValueError: If `loader` yields no samples.
    """
    model.train()
    running_loss = 0.0  # sum of per-sample losses across the epoch
    correct = 0
    total = 0
    loop = tqdm(loader, desc="Training")

    for images, labels in loop:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Read the scalar once: every .item() call forces a host/device sync.
        batch_loss = loss.item()
        batch_size = labels.size(0)

        # Weight the batch-mean loss by the batch size so a smaller final
        # batch does not count as much as a full one in the epoch average.
        running_loss += batch_loss * batch_size
        _, predicted = outputs.max(1)
        total += batch_size
        correct += predicted.eq(labels).sum().item()

        loop.set_postfix(loss=batch_loss, acc=100.0 * correct / total)

    if total == 0:
        raise ValueError("train() received a loader that yielded no samples")

    epoch_loss = running_loss / total
    epoch_acc = 100.0 * correct / total
    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_acc)
    print(f"Train Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")

def evaluate(model, loader, criterion, epoch):
    """Evaluate `model` on `loader`, record metrics, and checkpoint the best weights.

    Runs in eval mode with gradients disabled. The mean loss and accuracy (in
    percent) are appended to the module-level `test_losses` / `test_accuracies`
    lists and printed. When the accuracy beats the module-level `best_acc`,
    `best_acc` and `best_model_wts` are updated and the weights are written to
    'best_convnext_fer13.pth'.

    NOTE(review): the best checkpoint is selected on whatever split `loader`
    provides. If that is the held-out test set, the reported best accuracy is
    optimistically biased; a separate validation split would avoid this.

    Args:
        model: Network to evaluate; called as `model(images)`.
        loader: Iterable of `(images, labels)` batches.
        criterion: Loss callable invoked as `criterion(outputs, labels)`.
            Assumed to return the batch-mean loss (the default reduction of
            `nn.CrossEntropyLoss`) so the epoch loss can be weighted per sample.
        epoch: Epoch index supplied by the caller; accepted for interface
            compatibility and currently unused.

    Raises:
        ValueError: If `loader` yields no samples.
    """
    global best_acc, best_model_wts
    model.eval()
    running_loss = 0.0  # sum of per-sample losses across the epoch
    correct = 0
    total = 0
    loop = tqdm(loader, desc="Evaluating")

    with torch.no_grad():
        for images, labels in loop:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            # Read the scalar once: every .item() call forces a host/device sync.
            batch_loss = criterion(outputs, labels).item()
            batch_size = labels.size(0)

            # Weight the batch-mean loss by the batch size so a smaller final
            # batch does not count as much as a full one in the epoch average.
            running_loss += batch_loss * batch_size
            _, predicted = outputs.max(1)
            total += batch_size
            correct += predicted.eq(labels).sum().item()

            loop.set_postfix(loss=batch_loss, acc=100.0 * correct / total)

    if total == 0:
        raise ValueError("evaluate() received a loader that yielded no samples")

    epoch_loss = running_loss / total
    epoch_acc = 100.0 * correct / total
    test_losses.append(epoch_loss)
    test_accuracies.append(epoch_acc)

    print(f"Test  Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")

    if epoch_acc > best_acc:
        best_acc = epoch_acc
        # Deep-copy so later optimizer steps cannot mutate the saved snapshot.
        best_model_wts = copy.deepcopy(model.state_dict())
        torch.save(best_model_wts, 'best_convnext_fer13.pth')
        print(f"best model saved with accuracy: {best_acc:.2f}%")

In [6]:
epochs = 20
for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}/{epochs}")
    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments during a multi-minute epoch.
    start_time = time.perf_counter()

    train(model, train_loader, optimizer, criterion)
    evaluate(model, test_loader, criterion, epoch)

    # Advance the learning-rate schedule once per epoch, after the optimizer
    # updates performed inside train().
    scheduler.step()

    elapsed = time.perf_counter() - start_time
    print(f"Time: {elapsed:.2f} s")


Epoch 1/20


Training: 100%|██████████| 1795/1795 [06:48<00:00,  4.39it/s, acc=53.4, loss=1.13] 


Train Loss: 1.2154, Accuracy: 53.36%


Evaluating: 100%|██████████| 449/449 [00:32<00:00, 13.61it/s, acc=63.1, loss=0.33]  


Test  Loss: 0.9755, Accuracy: 63.07%
best model saved with accuracy: 63.07%
Time: 441.96 s

Epoch 2/20


Training: 100%|██████████| 1795/1795 [06:50<00:00,  4.37it/s, acc=66.3, loss=0.304]


Train Loss: 0.9033, Accuracy: 66.34%


Evaluating: 100%|██████████| 449/449 [00:32<00:00, 13.65it/s, acc=64.9, loss=0.674] 


Test  Loss: 0.9701, Accuracy: 64.91%
best model saved with accuracy: 64.91%
Time: 443.66 s

Epoch 3/20


Training: 100%|██████████| 1795/1795 [06:50<00:00,  4.37it/s, acc=72.6, loss=0.0554]


Train Loss: 0.7449, Accuracy: 72.59%


Evaluating: 100%|██████████| 449/449 [00:33<00:00, 13.58it/s, acc=67, loss=0.349]   


Test  Loss: 0.9098, Accuracy: 66.98%
best model saved with accuracy: 66.98%
Time: 443.99 s

Epoch 4/20


Training: 100%|██████████| 1795/1795 [06:51<00:00,  4.36it/s, acc=79.3, loss=0.544]


Train Loss: 0.5740, Accuracy: 79.33%


Evaluating: 100%|██████████| 449/449 [00:32<00:00, 13.64it/s, acc=68.7, loss=0.144] 


Test  Loss: 0.9052, Accuracy: 68.72%
best model saved with accuracy: 68.72%
Time: 444.60 s

Epoch 5/20


Training: 100%|██████████| 1795/1795 [15:23<00:00,  1.94it/s, acc=92.6, loss=0.57]   


Train Loss: 0.2238, Accuracy: 92.56%


Evaluating: 100%|██████████| 449/449 [01:07<00:00,  6.65it/s, acc=70.5, loss=0.957] 


Test  Loss: 1.1460, Accuracy: 70.49%
best model saved with accuracy: 70.49%
Time: 991.37 s

Epoch 6/20


Training: 100%|██████████| 1795/1795 [16:28<00:00,  1.82it/s, acc=96.9, loss=0.00433]


Train Loss: 0.1043, Accuracy: 96.89%


Evaluating: 100%|██████████| 449/449 [01:19<00:00,  5.62it/s, acc=70.3, loss=1.46]   


Test  Loss: 1.4358, Accuracy: 70.33%
Time: 1068.52 s

Epoch 7/20


Training: 100%|██████████| 1795/1795 [16:39<00:00,  1.80it/s, acc=98.8, loss=0.00397] 


Train Loss: 0.0456, Accuracy: 98.83%


Evaluating: 100%|██████████| 449/449 [01:05<00:00,  6.86it/s, acc=70.1, loss=2.26]   


Test  Loss: 1.8459, Accuracy: 70.14%
Time: 1064.84 s

Epoch 8/20


Training: 100%|██████████| 1795/1795 [17:09<00:00,  1.74it/s, acc=99.4, loss=0.000596]


Train Loss: 0.0260, Accuracy: 99.39%


Evaluating: 100%|██████████| 449/449 [01:05<00:00,  6.83it/s, acc=69.9, loss=2.18]    


Test  Loss: 2.1026, Accuracy: 69.89%
Time: 1095.57 s

Epoch 9/20


Training: 100%|██████████| 1795/1795 [17:09<00:00,  1.74it/s, acc=99.7, loss=0.000288]


Train Loss: 0.0119, Accuracy: 99.70%


Evaluating: 100%|██████████| 449/449 [01:05<00:00,  6.85it/s, acc=69.9, loss=2.31]    


Test  Loss: 2.1493, Accuracy: 69.95%
Time: 1094.83 s

Epoch 10/20


Training: 100%|██████████| 1795/1795 [16:58<00:00,  1.76it/s, acc=99.7, loss=0.878]   


Train Loss: 0.0091, Accuracy: 99.74%


Evaluating: 100%|██████████| 449/449 [00:49<00:00,  9.14it/s, acc=70.3, loss=2.29]    


Test  Loss: 2.2422, Accuracy: 70.34%
Time: 1067.80 s

Epoch 11/20


Training: 100%|██████████| 1795/1795 [17:12<00:00,  1.74it/s, acc=99.8, loss=0.000146]


Train Loss: 0.0068, Accuracy: 99.75%


Evaluating: 100%|██████████| 449/449 [01:05<00:00,  6.89it/s, acc=70.1, loss=2.41]    


Test  Loss: 2.3636, Accuracy: 70.13%
Time: 1097.34 s

Epoch 12/20


Training: 100%|██████████| 1795/1795 [17:12<00:00,  1.74it/s, acc=99.7, loss=0.000313]


Train Loss: 0.0059, Accuracy: 99.74%


Evaluating: 100%|██████████| 449/449 [01:04<00:00,  6.92it/s, acc=70, loss=2.46]      


Test  Loss: 2.4394, Accuracy: 69.99%
Time: 1097.73 s

Epoch 13/20


Training: 100%|██████████| 1795/1795 [17:12<00:00,  1.74it/s, acc=99.8, loss=0.000783]


Train Loss: 0.0043, Accuracy: 99.84%


Evaluating: 100%|██████████| 449/449 [01:05<00:00,  6.85it/s, acc=70.2, loss=2.5]     


Test  Loss: 2.4510, Accuracy: 70.17%
Time: 1097.91 s

Epoch 14/20


Training: 100%|██████████| 1795/1795 [17:13<00:00,  1.74it/s, acc=99.8, loss=0.000117]


Train Loss: 0.0042, Accuracy: 99.84%


Evaluating: 100%|██████████| 449/449 [01:02<00:00,  7.24it/s, acc=70.3, loss=2.54]    


Test  Loss: 2.4653, Accuracy: 70.33%
Time: 1095.80 s

Epoch 15/20


Training: 100%|██████████| 1795/1795 [17:06<00:00,  1.75it/s, acc=99.8, loss=2.99e-5] 


Train Loss: 0.0041, Accuracy: 99.84%


Evaluating: 100%|██████████| 449/449 [01:05<00:00,  6.85it/s, acc=70.3, loss=2.57]    


Test  Loss: 2.4810, Accuracy: 70.34%
Time: 1092.15 s

Epoch 16/20


Training: 100%|██████████| 1795/1795 [17:16<00:00,  1.73it/s, acc=99.8, loss=0.000535]


Train Loss: 0.0040, Accuracy: 99.83%


Evaluating: 100%|██████████| 449/449 [01:05<00:00,  6.85it/s, acc=70.3, loss=2.6]     


Test  Loss: 2.4960, Accuracy: 70.33%
Time: 1102.56 s

Epoch 17/20


Training: 100%|██████████| 1795/1795 [17:16<00:00,  1.73it/s, acc=99.8, loss=0.000145]


Train Loss: 0.0038, Accuracy: 99.84%


Evaluating: 100%|██████████| 449/449 [01:05<00:00,  6.86it/s, acc=70.3, loss=2.6]     


Test  Loss: 2.4974, Accuracy: 70.34%
Time: 1101.47 s

Epoch 18/20


Training: 100%|██████████| 1795/1795 [17:10<00:00,  1.74it/s, acc=99.8, loss=0.0014]  


Train Loss: 0.0038, Accuracy: 99.84%


Evaluating: 100%|██████████| 449/449 [00:59<00:00,  7.56it/s, acc=70.3, loss=2.6]     


Test  Loss: 2.4989, Accuracy: 70.34%
Time: 1089.48 s

Epoch 19/20


Training: 100%|██████████| 1795/1795 [16:53<00:00,  1.77it/s, acc=99.8, loss=3.07e-5] 


Train Loss: 0.0038, Accuracy: 99.84%


Evaluating: 100%|██████████| 449/449 [01:05<00:00,  6.85it/s, acc=70.3, loss=2.6]     


Test  Loss: 2.5003, Accuracy: 70.34%
Time: 1078.55 s

Epoch 20/20


Training: 100%|██████████| 1795/1795 [17:08<00:00,  1.75it/s, acc=99.8, loss=0.000531]


Train Loss: 0.0038, Accuracy: 99.84%


Evaluating: 100%|██████████| 449/449 [01:05<00:00,  6.87it/s, acc=70.3, loss=2.6]     

Test  Loss: 2.5017, Accuracy: 70.33%
Time: 1093.67 s





In [7]:
# Size the epoch column from the recorded history rather than the `epochs`
# constant. Re-running the training cell, or interrupting it mid-epoch, leaves
# the history lists with a length other than `epochs` (or unequal to each
# other), which would make the DataFrame constructor fail.
num_recorded = min(
    len(train_losses), len(train_accuracies),
    len(test_losses), len(test_accuracies),
)

metrics = {
    'epoch': list(range(1, num_recorded + 1)),
    'train_loss': train_losses[:num_recorded],
    'train_accuracy': train_accuracies[:num_recorded],
    'test_loss': test_losses[:num_recorded],
    'test_accuracy': test_accuracies[:num_recorded]
}

df_metrics = pd.DataFrame(metrics)

df_metrics.to_csv("convnext_fer13_metrics.csv", index=False)

# Weights of the model as it stands after the last epoch; the best-accuracy
# weights are written separately by evaluate() to 'best_convnext_fer13.pth'.
torch.save(model.state_dict(), 'ferconvnext_weights.pth')
# NOTE(review): this pickles the whole module object. Such a file can only be
# loaded where the model's class definitions are importable, and recent
# PyTorch versions may require torch.load(..., weights_only=False) - confirm
# this is wanted; the state_dict file above is the more portable artifact.
torch.save(model, 'ferconvnext_full.pth')