In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
from tqdm import tqdm
import os

In [13]:
# ResNet expects 3-channel (RGB) input of size 224x224, normalized with ImageNet statistics
# (the pretrained weights were trained on ImageNet).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 64
SEED = 42

# Seed the global torch RNG so the shuffle order, the random augmentations and the
# new classifier head's initialization are reproducible on a fresh kernel.
torch.manual_seed(SEED)

train_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # replicate the single channel 3 times (ImageNet weights expect RGB)
    transforms.Resize(IMAGE_SIZE),                # standard ResNet input size
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),  # ImageNet normalization
])

# Evaluation pipeline: same geometry/normalization as training, no random augmentation.
test_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

train_data = datasets.ImageFolder("./train", transform=train_transform)
test_data = datasets.ImageFolder("./test", transform=test_transform)

# ImageFolder builds class->index mappings independently for each root, so a missing or
# extra class folder under ./test would silently shift every label after it.
assert train_data.class_to_idx == test_data.class_to_idx, (
    f"train/test class folders differ: {train_data.classes} vs {test_data.classes}"
)

train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

num_classes = len(train_data.classes)
print("Класи:", train_data.classes)

Класи: ['angry', 'disgusted', 'fearful', 'happy', 'neutral', 'sad', 'surprised']


In [14]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Використовується пристрій: {device}")

Використовується пристрій: cpu


In [15]:
# Start from torchvision's default pretrained weights for ResNet-18.
resnet18_weights = models.ResNet18_Weights.DEFAULT
model = models.resnet18(weights=resnet18_weights)

In [16]:
# ResNet-18 ends in a linear layer `fc`; keep its input width (num_ftrs) and replace it
# with a freshly initialized linear head that has one output per dataset class.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=num_classes)

# Move every parameter (including the new head) to the selected device.
model = model.to(device)

In [17]:
# Multi-class classification loss on the raw logits produced by the model.
criterion = nn.CrossEntropyLoss()

# Fine-tuning a pretrained network usually uses a smaller learning rate.
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-4,
)

# Halve the learning rate once the value passed to scheduler.step() (a quantity to
# minimize, mode="min") has not improved for more than `patience`=3 consecutive epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=3,
)

In [18]:
def train_epoch(loader, model, loss_fn, optimizer):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    loop = tqdm(loader, leave=False)
    for X, y in loop:
        X, y = X.to(device), y.to(device)

        optimizer.zero_grad()
        pred = model(X)
        loss = loss_fn(pred, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = torch.max(pred, 1)
        correct += (predicted == y).sum().item()
        total += y.size(0)

        loop.set_postfix(loss=loss.item(), accuracy=correct/total)

    return total_loss / len(loader), correct / total

def test_epoch(loader, model, loss_fn):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    loop = tqdm(loader, leave=False)
    with torch.no_grad():
        for X, y in loop:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            loss = loss_fn(pred, y)

            total_loss += loss.item()
            _, predicted = torch.max(pred, 1)
            correct += (predicted == y).sum().item()
            total += y.size(0)

            loop.set_postfix(loss=loss.item(), accuracy=correct/total)

    return total_loss / len(loader), correct / total

In [19]:
epochs = 15  # transfer learning usually converges quickly
best_val_acc = float("-inf")  # -inf so the first epoch always produces a checkpoint
patience = 5
counter = 0
CHECKPOINT_PATH = "resnet18_finetuned.pth"
best_state = None  # in-memory copy of the best weights seen so far

train_losses, train_accs = [], []
val_losses, val_accs = [], []

# NOTE(review): test_loader is used as the validation set, and it also drives checkpoint
# selection and early stopping, so accuracy reported on it is optimistically biased.
# Consider holding out a validation split from ./train.
for epoch in range(epochs):
    print(f"\nEpoch {epoch+1}/{epochs}")

    train_loss, train_acc = train_epoch(train_loader, model, criterion, optimizer)
    val_loss, val_acc = test_epoch(test_loader, model, criterion)

    train_losses.append(train_loss)
    train_accs.append(train_acc)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

    scheduler.step(val_loss)

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        counter = 0
        # Detached clones: state_dict() returns references that later training would overwrite.
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        torch.save(model.state_dict(), CHECKPOINT_PATH)
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping!")
            break

# Without this, `model` would keep the LAST epoch's weights rather than the best checkpoint.
if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Restored best weights (Val Acc: {best_val_acc:.4f})")


Epoch 1/15


                                                                             

Train Loss: 1.1484 | Train Acc: 0.5638 | Val Loss: 0.9929 | Val Acc: 0.6303

Epoch 2/15


                                                                             

Train Loss: 0.9163 | Train Acc: 0.6578 | Val Loss: 0.9608 | Val Acc: 0.6457

Epoch 3/15


                                                                             

Train Loss: 0.8176 | Train Acc: 0.6957 | Val Loss: 0.9285 | Val Acc: 0.6624

Epoch 4/15


                                                                             

Train Loss: 0.7246 | Train Acc: 0.7330 | Val Loss: 0.9292 | Val Acc: 0.6698

Epoch 5/15


                                                                             

Train Loss: 0.6407 | Train Acc: 0.7664 | Val Loss: 0.9531 | Val Acc: 0.6682

Epoch 6/15


                                                                             

Train Loss: 0.5552 | Train Acc: 0.8001 | Val Loss: 0.9624 | Val Acc: 0.6769

Epoch 7/15


                                                                             

Train Loss: 0.4829 | Train Acc: 0.8267 | Val Loss: 1.0383 | Val Acc: 0.6718

Epoch 8/15


                                                                             

Train Loss: 0.3278 | Train Acc: 0.8880 | Val Loss: 1.0059 | Val Acc: 0.6955

Epoch 9/15


                                                                             

Train Loss: 0.2553 | Train Acc: 0.9171 | Val Loss: 1.0558 | Val Acc: 0.6877

Epoch 10/15


                                                                              

Train Loss: 0.2163 | Train Acc: 0.9296 | Val Loss: 1.1337 | Val Acc: 0.6849

Epoch 11/15


                                                                              

Train Loss: 0.1832 | Train Acc: 0.9407 | Val Loss: 1.1544 | Val Acc: 0.6893

Epoch 12/15


                                                                              

Train Loss: 0.1324 | Train Acc: 0.9608 | Val Loss: 1.1518 | Val Acc: 0.6943

Epoch 13/15


                                                                              

Train Loss: 0.1025 | Train Acc: 0.9710 | Val Loss: 1.1976 | Val Acc: 0.6960

Epoch 14/15


                                                                              

Train Loss: 0.0940 | Train Acc: 0.9727 | Val Loss: 1.2231 | Val Acc: 0.6913

Epoch 15/15


                                                                              

Train Loss: 0.0853 | Train Acc: 0.9747 | Val Loss: 1.2544 | Val Acc: 0.6948


