In [2]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
from torchvision import models
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torch.utils.data import WeightedRandomSampler


train_dataset_path = "./SimpsonsDataset/train"
test_dataset_path = "./SimpsonsDataset/test"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


# Просмотр кол-ва изображений на класс

Просмотрев кол-во сэмплов для каждого класса, можем сделать вывод, что датасет плохо сбалансирован, т.к. между классами разница в кол-ве сэмплов достигает сотен.

In [3]:
def characters_count(dataset_path):
    """Подсчитывает кол-во сэмплов на каждый класс"""
    character_count = {}

    for character in os.listdir(dataset_path):
        character_path = os.path.join(dataset_path, character)
        num_files = len(os.listdir(character_path))
        character_count[character] = num_files
        
    return character_count
        
cls_cnt = characters_count(train_dataset_path)
sorted_counts = sorted(cls_cnt.items(), key=lambda x: x[1])
for character, count in sorted_counts:
    print(f"{character}: {count} images")

cletus_spuckler: 56 images
lionel_hutz: 56 images
agnes_skinner: 57 images
martin_prince: 57 images
patty_bouvier: 58 images
professor_john_frink: 58 images
rainier_wolfcastle: 58 images
sideshow_mel: 58 images
snake_jailbird: 60 images
otto_mann: 64 images
miss_hoover: 65 images
disco_stu: 66 images
gil: 66 images
fat_tony: 68 images
troy_mcclure: 68 images
ralph_wiggum: 71 images
carl_carlson: 78 images
selma_bouvier: 83 images
barney_gumble: 85 images
groundskeeper_willie: 100 images
maggie_simpson: 102 images
waylon_smithers: 145 images
mayor_quimby: 197 images
lenny_leonard: 248 images
nelson_muntz: 288 images
edna_krabappel: 366 images
comic_book_guy: 376 images
kent_brockman: 398 images
apu_nahasapeemapetilon: 498 images
sideshow_bob: 702 images
abraham_grampa_simpson: 731 images
chief_wiggum: 789 images
milhouse_van_houten: 863 images
charles_montgomery_burns: 955 images
principal_skinner: 956 images
krusty_the_clown: 965 images
marge_simpson: 1033 images
bart_simpson: 1074 ima

# Рассчет весов классов и сэмплов для балансировки данных

При обучении нейросети на плохо сбалансированных данных, нейросеть будет отдавать предпочтение классам с большим кол-вом сэмплов. Чтобы это исправить мы можем использовать **WeightedRandomSampler**, который позволит пропорцианально выбирать сэмплы на основе их весов. Таким образом сэмплы из малочисленных классов получают большие веса и соответсвтенно большую вероятность выбора, а из многочисленных классов - меньшую.

In [None]:
def compute_class_weights(class_counts):
    """Рассчет весов для каждого класса"""
    class_weights = {}

    for cls, count in class_counts.items():
        class_weights[cls] = 1 / count

    return class_weights

def compute_sample_weights(class_counts):
    """Подсчитывает веса для сэмплов в каждом классе"""
    sample_weights = []

    for character in os.listdir(train_dataset_path):
        character_path = os.path.join(train_dataset_path, character)
        num_files = len(os.listdir(character_path))

        weight = 1 / num_files
        sample_weights.extend([weight] * num_files)
        
    return torch.tensor(sample_weights, dtype=torch.float).to(device)

sample_weights = compute_sample_weights(cls_cnt)

class_weights = compute_class_weights(cls_cnt)
sorted_weights = sorted(class_weights.items(), key=lambda x: x[1])
for cls, weight in sorted_weights:
    print(f"{cls}: {weight:.4f}")

homer_simpson: 0.0006
ned_flanders: 0.0009
moe_szyslak: 0.0009
lisa_simpson: 0.0009
bart_simpson: 0.0009
marge_simpson: 0.0010
krusty_the_clown: 0.0010
principal_skinner: 0.0010
charles_montgomery_burns: 0.0010
milhouse_van_houten: 0.0012
chief_wiggum: 0.0013
abraham_grampa_simpson: 0.0014
sideshow_bob: 0.0014
apu_nahasapeemapetilon: 0.0020
kent_brockman: 0.0025
comic_book_guy: 0.0027
edna_krabappel: 0.0027
nelson_muntz: 0.0035
lenny_leonard: 0.0040
mayor_quimby: 0.0051
waylon_smithers: 0.0069
maggie_simpson: 0.0098
groundskeeper_willie: 0.0100
barney_gumble: 0.0118
selma_bouvier: 0.0120
carl_carlson: 0.0128
ralph_wiggum: 0.0141
fat_tony: 0.0147
troy_mcclure: 0.0147
disco_stu: 0.0152
gil: 0.0152
miss_hoover: 0.0154
otto_mann: 0.0156
snake_jailbird: 0.0167
patty_bouvier: 0.0172
professor_john_frink: 0.0172
rainier_wolfcastle: 0.0172
sideshow_mel: 0.0172
agnes_skinner: 0.0175
martin_prince: 0.0175
cletus_spuckler: 0.0179
lionel_hutz: 0.0179


# Аугментации данных

Аугментации позволят создать больше примеров для обучения путем случайных изменений изображений. Помимо того, что это помогает создавать больше сэмплов для малочисленных классов, аугментации так же улучшают обобщающие способности модели, делая ее более устойчивой к меняющимся условиям.

In [5]:
def get_transforms(is_train=False):
    """Предобработка изображений
    
        is_train = True: Добавляет аугментации во время обучения
        
        is_tain = False: Обычная предобработка
    """
    if is_train:
        return T.Compose([
            T.Resize((224, 224)),            
            T.RandomHorizontalFlip(),        
            T.RandomResizedCrop(224, scale=(0.8, 1.0)),  
            T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
            T.RandomRotation(degrees=15),
            T.ToTensor(),      
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    else:
        return T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

# Используем модель ResNet18

In [6]:
train_set = ImageFolder(root=train_dataset_path, transform=get_transforms(is_train=True))
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
train_loader = DataLoader(train_set, batch_size=128, sampler=sampler, num_workers=4)

model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

for param in model.parameters():
    param.requires_grad = False

for name, child in model.named_children():
    if name in ['layer4','fc']:
        for param in child.parameters():
            param.requires_grad = True
            
num_classes = len(cls_cnt)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10


# Тренировочный цикл

In [None]:
classes_correct = {cls: 0 for cls in cls_cnt.keys()}
classes_total = {cls: 0 for cls in cls_cnt.keys()}
model.train()
for epoch in range(1, num_epochs+1):
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
        _, predicted = torch.max(outputs.data,dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        for label,prediction in zip(labels, predicted):
            if label == prediction:
                classes_correct[train_set.classes[label.item()]] += 1
            classes_total[train_set.classes[label.item()]] += 1
            
    epoch_loss = running_loss / len(train_loader)
    epoch_accuracy = 100 * correct / total
    
    print(f"Epoch [{epoch}/{num_epochs}], "f"Train Loss: {epoch_loss:.4f}, Train Accuracy: {epoch_accuracy:.2f}%")

for cls in classes_correct.keys():
    class_accuracy = 100 * classes_correct[cls] / classes_total[cls]
    print(f"Training accuracy for class {cls}: {class_accuracy:.2f}%, {classes_correct[cls]}/{classes_total[cls]}")
    
torch.save(model.state_dict(), "simpsons_model.pth")

Epoch [1/10], Train Loss: 0.8602, Train Accuracy: 79.35%
Epoch [2/10], Train Loss: 0.2629, Train Accuracy: 93.26%
Epoch [3/10], Train Loss: 0.1995, Train Accuracy: 94.64%
Epoch [4/10], Train Loss: 0.1541, Train Accuracy: 95.74%
Epoch [5/10], Train Loss: 0.1287, Train Accuracy: 96.45%
Epoch [6/10], Train Loss: 0.1150, Train Accuracy: 96.74%
Epoch [7/10], Train Loss: 0.0982, Train Accuracy: 97.20%
Epoch [8/10], Train Loss: 0.1002, Train Accuracy: 97.14%
Epoch [9/10], Train Loss: 0.0838, Train Accuracy: 97.53%
Epoch [10/10], Train Loss: 0.0841, Train Accuracy: 97.54%
Training accuracy for class abraham_grampa_simpson: 91.92%, 3789/4122
Training accuracy for class agnes_skinner: 97.72%, 4073/4168
Training accuracy for class apu_nahasapeemapetilon: 94.01%, 3927/4177
Training accuracy for class barney_gumble: 95.01%, 3944/4151
Training accuracy for class bart_simpson: 90.77%, 3709/4086
Training accuracy for class carl_carlson: 98.14%, 4068/4145
Training accuracy for class charles_montgomery_

# Оценка точности модели на тестовых данных

In [7]:
classes_correct_test = {cls: 0 for cls in cls_cnt.keys()}
classes_total_test = {cls: 0 for cls in cls_cnt.keys()}
test_set = ImageFolder(root=test_dataset_path, transform=get_transforms(is_train=False))
test_loader = DataLoader(test_set, batch_size=128, num_workers=4, shuffle=False)
model.load_state_dict(torch.load("simpsons_model.pth"))
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        
        outputs = model(images)

        _, predicted = torch.max(outputs.data,dim=1)
        
        total += labels.size(0)
        
        correct += (predicted == labels).sum().item()
        
        for label,prediction in zip(labels, predicted):
            if label == prediction:
                classes_correct_test[test_set.classes[label.item()]] += 1
            classes_total_test[test_set.classes[label.item()]] += 1

accuracy = 100 * correct/total
print(f"Average accuracy: {accuracy:.2f}%")
for cls in classes_correct_test.keys():
    class_accuracy_test = 100 * classes_correct_test[cls] / classes_total_test[cls]
    print(f"Testing accuracy for class {cls}: {class_accuracy_test:.2f}%, {classes_correct_test[cls]}/{classes_total_test[cls]}")

Average accuracy: 93.28%
Testing accuracy for class abraham_grampa_simpson: 95.60%, 174/182
Testing accuracy for class agnes_skinner: 100.00%, 14/14
Testing accuracy for class apu_nahasapeemapetilon: 94.35%, 117/124
Testing accuracy for class barney_gumble: 85.71%, 18/21
Testing accuracy for class bart_simpson: 91.04%, 244/268
Testing accuracy for class carl_carlson: 95.00%, 19/20
Testing accuracy for class charles_montgomery_burns: 92.44%, 220/238
Testing accuracy for class chief_wiggum: 95.43%, 188/197
Testing accuracy for class cletus_spuckler: 85.71%, 12/14
Testing accuracy for class comic_book_guy: 91.40%, 85/93
Testing accuracy for class disco_stu: 88.24%, 15/17
Testing accuracy for class edna_krabappel: 92.31%, 84/91
Testing accuracy for class fat_tony: 94.12%, 16/17
Testing accuracy for class gil: 75.00%, 12/16
Testing accuracy for class groundskeeper_willie: 80.95%, 17/21
Testing accuracy for class homer_simpson: 88.64%, 398/449
Testing accuracy for class kent_brockman: 98.99%