In [None]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
from torchvision import models
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torch.utils.data import WeightedRandomSampler
import matplotlib.pyplot as plt
from torch.utils.data import random_split
from torchmetrics.classification import MulticlassF1Score, MulticlassPrecision, MulticlassRecall


train_dataset_path = "./SimpsonsDataset/train"
val_dataset_path = "./SimpsonsDataset/val"
test_dataset_path = "./SimpsonsDataset/test"

# Seed torch's global RNG (this also seeds CUDA). The weighted sampler, the
# random augmentations (including DataLoader worker seeds) and the freshly
# initialised classifier head all draw from it, so reruns are repeatable up to
# cuDNN nondeterminism.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


# Просмотр кол-ва изображений на класс

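Просмотрев количество сэмплов в каждом классе, можно сделать вывод, что датасет плохо сбалансирован: в самых малочисленных классах всего около 50 изображений, а в самых многочисленных — сотни и даже более тысячи, т.е. разница между классами достигает десятков раз.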

In [2]:
def characters_count(dataset_path, extensions=None):
    """Count the samples stored under each class folder of an ImageFolder-style tree.

    Args:
        dataset_path: root directory containing one sub-directory per class.
        extensions: optional iterable of file suffixes (e.g. ``(".jpg", ".png")``).
            When given, only files whose lower-cased name ends with one of them
            are counted. ``None`` (the default) counts every file.

    Returns:
        dict ``class name -> number of files``, ordered by sorted class name
        (the order ``ImageFolder`` uses for class indices). Stray top-level
        files such as ``.DS_Store`` are skipped instead of raising. Files in
        nested sub-folders of a class are counted too, mirroring how
        ``ImageFolder`` walks each class directory.
    """
    suffixes = None if extensions is None else tuple(s.lower() for s in extensions)

    with os.scandir(dataset_path) as entries:
        class_dirs = sorted(entry.name for entry in entries if entry.is_dir())

    character_count = {}
    for character in class_dirs:
        num_files = 0
        for _, _, files in os.walk(os.path.join(dataset_path, character), followlinks=True):
            if suffixes is None:
                num_files += len(files)
            else:
                num_files += sum(name.lower().endswith(suffixes) for name in files)
        character_count[character] = num_files

    return character_count
        
cls_cnt = characters_count(train_dataset_path)
print(len(cls_cnt))  # number of class folders found

# List classes from the rarest to the most frequent to expose the imbalance.
sorted_counts = sorted(cls_cnt.items(), key=lambda pair: pair[1])
for class_name, n_images in sorted_counts:
    print(f"{class_name}: {n_images} images")

42
cletus_spuckler: 50 images
lionel_hutz: 50 images
agnes_skinner: 51 images
martin_prince: 51 images
patty_bouvier: 52 images
professor_john_frink: 52 images
rainier_wolfcastle: 52 images
sideshow_mel: 52 images
snake_jailbird: 54 images
otto_mann: 58 images
disco_stu: 59 images
gil: 59 images
miss_hoover: 59 images
fat_tony: 61 images
troy_mcclure: 61 images
ralph_wiggum: 64 images
carl_carlson: 70 images
selma_bouvier: 75 images
barney_gumble: 77 images
groundskeeper_willie: 90 images
maggie_simpson: 92 images
waylon_smithers: 131 images
mayor_quimby: 177 images
lenny_leonard: 223 images
nelson_muntz: 259 images
edna_krabappel: 329 images
comic_book_guy: 338 images
kent_brockman: 358 images
apu_nahasapeemapetilon: 448 images
sideshow_bob: 632 images
abraham_grampa_simpson: 658 images
chief_wiggum: 710 images
milhouse_van_houten: 777 images
charles_montgomery_burns: 860 images
principal_skinner: 860 images
krusty_the_clown: 869 images
marge_simpson: 930 images
bart_simpson: 967 imag

# Расчёт весов классов и сэмплов для балансировки данных

При обучении на плохо сбалансированных данных нейросеть будет отдавать предпочтение классам с большим количеством сэмплов. Чтобы это исправить, можно использовать **WeightedRandomSampler**, который выбирает сэмплы с вероятностью, пропорциональной их весам. Вес каждого сэмпла равен $1/N_c$, где $N_c$ — число сэмплов в его классе, поэтому суммарный вес каждого класса одинаков и классы в среднем выбираются одинаково часто. Таким образом, сэмплы из малочисленных классов получают большие веса и, соответственно, большую вероятность выбора, а из многочисленных — меньшую.

In [3]:
def compute_class_weights(class_counts):
    """Map every class to the inverse of its sample count.

    Args:
        class_counts: mapping ``class name -> number of samples``.

    Returns:
        dict with the same keys, in the same iteration order, whose values are
        ``1 / count``. The rarer the class, the larger its weight.

    Raises:
        ZeroDivisionError: if any class has a count of 0.
    """
    return {name: 1 / n_samples for name, n_samples in class_counts.items()}

def compute_sample_weights(class_counts):
    """Build one sampling weight per training image for ``WeightedRandomSampler``.

    Every image of a class gets weight ``1 / count``. Each class therefore has
    the same total weight (1.0) and is drawn equally often on average.

    Args:
        class_counts: mapping ``class name -> number of images``, e.g. the
            result of ``characters_count``.

    Returns:
        1-D float32 CPU tensor of length ``sum(class_counts.values())``.
        Weights are laid out class by class in sorted class-name order, which
        is the order ``ImageFolder`` assigns class indices and groups samples
        in. The weights line up with dataset indices only if ``ImageFolder``
        finds exactly ``count`` images for every class.
    """
    sample_weights = []
    # Sorted order, not os.listdir order: the sampler indexes weights by
    # dataset position, so the class blocks must follow ImageFolder's order.
    for character in sorted(class_counts):
        num_files = class_counts[character]
        sample_weights.extend([1 / num_files] * num_files)

    # Kept on the CPU: the sampler only draws indices from these weights, so
    # moving them to the GPU buys nothing.
    return torch.tensor(sample_weights, dtype=torch.float)

sample_weights = compute_sample_weights(cls_cnt)
class_weights = compute_class_weights(cls_cnt)

# Smallest weight first (most frequent class), largest last (rarest class).
sorted_weights = sorted(class_weights.items(), key=lambda kv: kv[1])
for class_name, class_weight in sorted_weights:
    print(f"{class_name}: {class_weight:.4f}")

homer_simpson: 0.0006
ned_flanders: 0.0010
moe_szyslak: 0.0010
lisa_simpson: 0.0010
bart_simpson: 0.0010
marge_simpson: 0.0011
krusty_the_clown: 0.0012
charles_montgomery_burns: 0.0012
principal_skinner: 0.0012
milhouse_van_houten: 0.0013
chief_wiggum: 0.0014
abraham_grampa_simpson: 0.0015
sideshow_bob: 0.0016
apu_nahasapeemapetilon: 0.0022
kent_brockman: 0.0028
comic_book_guy: 0.0030
edna_krabappel: 0.0030
nelson_muntz: 0.0039
lenny_leonard: 0.0045
mayor_quimby: 0.0056
waylon_smithers: 0.0076
maggie_simpson: 0.0109
groundskeeper_willie: 0.0111
barney_gumble: 0.0130
selma_bouvier: 0.0133
carl_carlson: 0.0143
ralph_wiggum: 0.0156
fat_tony: 0.0164
troy_mcclure: 0.0164
disco_stu: 0.0169
gil: 0.0169
miss_hoover: 0.0169
otto_mann: 0.0172
snake_jailbird: 0.0185
patty_bouvier: 0.0192
professor_john_frink: 0.0192
rainier_wolfcastle: 0.0192
sideshow_mel: 0.0192
agnes_skinner: 0.0196
martin_prince: 0.0196
cletus_spuckler: 0.0200
lionel_hutz: 0.0200


# Аугментации данных

Аугментации позволяют получать больше разнообразных примеров для обучения за счёт случайных изменений изображений. Это особенно важно вместе с WeightedRandomSampler, который выбирает сэмплы с повторениями: изображения малочисленных классов попадают в обучение многократно, и аугментации делают эти повторы разными. Кроме того, аугментации улучшают обобщающую способность модели, делая её более устойчивой к меняющимся условиям.

In [4]:
def get_transforms(is_train=False):
    """Build the image preprocessing pipeline.

    Args:
        is_train: when True, insert the random augmentations used only during
            training (horizontal flip, resized crop, colour jitter, rotation)
            after the initial resize. When False, return the deterministic
            evaluation pipeline.

    Returns:
        A ``T.Compose`` that resizes to 224x224, optionally augments, converts
        to a tensor and normalises with the ImageNet channel mean/std.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    steps = [T.Resize((224, 224))]
    if is_train:
        steps += [
            T.RandomHorizontalFlip(),
            T.RandomResizedCrop(224, scale=(0.8, 1.0)),
            T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
            T.RandomRotation(degrees=15),
        ]
    steps += [
        T.ToTensor(),
        T.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return T.Compose(steps)

# Используем модель ResNet18

In [31]:
train_set = ImageFolder(root=train_dataset_path, transform=get_transforms(is_train=True))

# Derive the sampler weights from the dataset's own labels, so weight i always
# belongs to sample i. Each sample gets 1 / (size of its class), which gives
# every class the same total sampling weight.
train_targets = torch.tensor(train_set.targets)
train_class_counts = torch.bincount(train_targets)
train_sample_weights = (1.0 / train_class_counts.float())[train_targets]
sampler = WeightedRandomSampler(weights=train_sample_weights, num_samples=len(train_sample_weights), replacement=True)
train_loader = DataLoader(train_set, batch_size=128, sampler=sampler, num_workers=4)

val_set = ImageFolder(root=val_dataset_path, transform=get_transforms(is_train=False))
# Label indices come from each folder's own class list; they only match the
# model's outputs if validation has exactly the training classes.
assert val_set.classes == train_set.classes, "validation classes differ from training classes"
val_loader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=4)

# Taken from ImageFolder itself, so it always matches the label range.
num_classes = len(train_set.classes)

model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Fine-tune only the last residual stage; every other pretrained layer stays frozen.
for param in model.parameters():
    param.requires_grad = False
for param in model.layer4.parameters():
    param.requires_grad = True

# Fresh classifier head sized for our classes (new parameters are trainable by default).
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
# Only hand the trainable parameters to the optimiser.
optimizer = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=0.0001)
num_epochs = 11

# Тренировочный цикл

In [None]:
def train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader`` on ``device``.

    Returns:
        (mean per-batch loss, accuracy in percent) over the samples drawn.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # Predictions come from detached logits (instead of the legacy `.data`).
        predicted = outputs.detach().argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    return running_loss / len(loader), 100 * correct / total


def evaluate(model, loader, criterion, f1_metric):
    """Evaluate ``model`` on ``loader`` without gradients.

    ``f1_metric`` is reset first, so the returned score covers this pass only.

    Returns:
        (mean per-batch loss, accuracy in percent, F1 score as a Python float).
    """
    model.eval()
    f1_metric.reset()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            running_loss += criterion(outputs, labels).item()

            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
            f1_metric.update(predicted, labels)

    return running_loss / len(loader), 100 * correct / total, f1_metric.compute().item()


val_f1_metric = MulticlassF1Score(num_classes=num_classes, average="macro").to(device)
# Keep the weights of the epoch with the best validation macro-F1, not simply
# the last epoch; the test cell reloads this checkpoint file.
best_val_f1 = float("-inf")
for epoch in range(1, num_epochs + 1):
    epoch_loss, epoch_accuracy = train_one_epoch(model, train_loader, criterion, optimizer)
    val_loss, val_accuracy, val_f1 = evaluate(model, val_loader, criterion, val_f1_metric)

    # `.4f` = four decimal places (`:2f` would only set a minimum field width).
    print(
        f"Epoch [{epoch}/{num_epochs}], "
        f"Train Loss: {epoch_loss:.4f}, Train Accuracy: {epoch_accuracy:.2f}%, "
        f"Val loss: {val_loss:.4f}, Val accuracy: {val_accuracy:.2f}, Val f1-score: {val_f1:.4f}"
    )

    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        torch.save(model.state_dict(), "simpsons_model.pth")

Epoch [1/11], Train Loss: 2.3407, Train Accuracy: 50.79%, Val loss: 1.3902, Val accuracy: 73.65, Val f1-score: 0.622467
Epoch [2/11], Train Loss: 0.8840, Train Accuracy: 84.97%, Val loss: 0.7818, Val accuracy: 85.20, Val f1-score: 0.752839
Epoch [3/11], Train Loss: 0.4900, Train Accuracy: 91.63%, Val loss: 0.5388, Val accuracy: 89.09, Val f1-score: 0.816090
Epoch [4/11], Train Loss: 0.3337, Train Accuracy: 94.05%, Val loss: 0.4411, Val accuracy: 91.00, Val f1-score: 0.858145
Epoch [5/11], Train Loss: 0.2407, Train Accuracy: 95.73%, Val loss: 0.3692, Val accuracy: 91.76, Val f1-score: 0.871899
Epoch [6/11], Train Loss: 0.1850, Train Accuracy: 96.68%, Val loss: 0.3245, Val accuracy: 92.98, Val f1-score: 0.891713
Epoch [7/11], Train Loss: 0.1552, Train Accuracy: 97.07%, Val loss: 0.3055, Val accuracy: 93.33, Val f1-score: 0.892583
Epoch [8/11], Train Loss: 0.1296, Train Accuracy: 97.61%, Val loss: 0.2967, Val accuracy: 93.44, Val f1-score: 0.892217
Epoch [9/11], Train Loss: 0.1052, Train 

# Оценка точности модели на тестовых данных

In [36]:
test_set = ImageFolder(root=test_dataset_path, transform=get_transforms(is_train=False))
# Label indices come from each folder's own class list; they only match the
# model's outputs if the test split has exactly the training classes.
assert test_set.classes == train_set.classes, "test classes differ from training classes"
test_loader = DataLoader(test_set, batch_size=128, num_workers=4, shuffle=False)

# map_location lets a checkpoint saved on the GPU be reloaded on a CPU-only machine.
model.load_state_dict(torch.load("simpsons_model.pth", map_location=device))
model.eval()

# average=None -> one score per class, indexed like test_set.classes.
f1_metric = MulticlassF1Score(num_classes=num_classes, average=None).to(device)
precision_metric = MulticlassPrecision(num_classes=num_classes, average=None).to(device)
recall_metric = MulticlassRecall(num_classes=num_classes, average=None).to(device)

correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        predicted = model(images).argmax(dim=1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        for metric in (precision_metric, recall_metric, f1_metric):
            metric.update(predicted, labels)

precision = precision_metric.compute()
recall = recall_metric.compute()
f1 = f1_metric.compute()

accuracy = 100 * correct / total
print(f"Accuracy: {accuracy:.2f}%")

for i, cls in enumerate(test_set.classes):
    print(f"Class: {cls}")
    print(f" Precision: {precision[i].item():.4f}")
    print(f" Recall: {recall[i].item():.4f}")
    print(f" F1-score: {f1[i].item():.4f}\n")


Accuracy: 93.82%
Class: abraham_grampa_simpson
 Precision: 0.9500
 Recall: 0.9396
 F1-score: 0.9448

Class: agnes_skinner
 Precision: 1.0000
 Recall: 0.9286
 F1-score: 0.9630

Class: apu_nahasapeemapetilon
 Precision: 0.9297
 Recall: 0.9597
 F1-score: 0.9444

Class: barney_gumble
 Precision: 0.7778
 Recall: 0.6667
 F1-score: 0.7179

Class: bart_simpson
 Precision: 0.9843
 Recall: 0.9366
 F1-score: 0.9598

Class: carl_carlson
 Precision: 0.9444
 Recall: 0.8500
 F1-score: 0.8947

Class: charles_montgomery_burns
 Precision: 0.9109
 Recall: 0.9454
 F1-score: 0.9278

Class: chief_wiggum
 Precision: 0.9541
 Recall: 0.9492
 F1-score: 0.9517

Class: cletus_spuckler
 Precision: 0.7647
 Recall: 0.9286
 F1-score: 0.8387

Class: comic_book_guy
 Precision: 0.8947
 Recall: 0.9140
 F1-score: 0.9043

Class: disco_stu
 Precision: 0.8889
 Recall: 0.9412
 F1-score: 0.9143

Class: edna_krabappel
 Precision: 0.9362
 Recall: 0.9670
 F1-score: 0.9514

Class: fat_tony
 Precision: 0.8947
 Recall: 1.0000
 F1-sc