In [1]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import aim
import torch.optim as optim

from library.epoch import train_epoch, evaluate_model
from library.model import VisualClassifier


os.environ["TOKENIZERS_PARALLELISM"] = "false"  # Отключить параллелизм токенайзера


# Трансформации данных
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Датасеты и DataLoader
train_dataset = datasets.Flowers102(root="./data", split="train", transform=train_transform, download=True)
val_dataset = datasets.Flowers102(root="./data", split="val", transform=val_transform, download=True)
test_dataset = datasets.Flowers102(root="./data", split="test", transform=val_transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Основной цикл обучения
num_epochs = 20
learning_rate = 1e-4
hidden_dim = 1024
num_classes = 102  # Количество классов в Flowers102

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VisualClassifier("google/vit-base-patch16-224-in21k", num_classes, hidden_dim).to(device)

# Заморозка параметров ViT
for param in model.vision_model.parameters():
    param.requires_grad = False

# Оптимизация и обучение
optimizer = optim.AdamW(model.classifier.parameters(), lr=learning_rate, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

# Инициализация AIM
run = aim.Run(repo=".", experiment="flowers102_visual_classifier")
run["hparams"] = {"num_epochs": num_epochs, "learning_rate": learning_rate, "hidden_dim": hidden_dim}



In [3]:
# Цикл обучения
for epoch in range(num_epochs):
    train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
    val_metrics = evaluate_model(model, val_loader, device, num_classes)

    # Логирование метрик
    run.track(train_loss, name="train_loss", step=epoch)
    for metric_name, metric_value in val_metrics.items():
        run.track(metric_value, name=metric_name, step=epoch, context={"subset": "validation"})

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {train_loss:.4f}, Validation Metrics: {val_metrics}")

# Финальное тестирование
test_metrics = evaluate_model(model, test_loader, device, num_classes)
print(f"Test Metrics: {test_metrics}")

run.close()

100%|██████████| 16/16 [00:03<00:00,  5.26it/s]
100%|██████████| 16/16 [00:02<00:00,  6.40it/s]


Epoch 1/20, Loss: 4.5711, Validation Metrics: {'accuracy': 0.2411764705882353, 'precision_weighted': 0.33549134845285644, 'recall_weighted': 0.2411764705882353, 'f1_score_weighted': 0.21420266497564225, 'precision_class_0': 0.45454545454545453, 'recall_class_0': 1.0, 'f1_score_class_0': 0.625, 'precision_class_1': 0.1038961038961039, 'recall_class_1': 0.8, 'f1_score_class_1': 0.1839080459770115, 'precision_class_2': 0.8333333333333334, 'recall_class_2': 0.5, 'f1_score_class_2': 0.625, 'precision_class_3': 0.0, 'recall_class_3': 0.0, 'f1_score_class_3': 0.0, 'precision_class_4': 0.6666666666666666, 'recall_class_4': 0.2, 'f1_score_class_4': 0.3076923076923077, 'precision_class_5': 0.09803921568627451, 'recall_class_5': 1.0, 'f1_score_class_5': 0.17857142857142858, 'precision_class_6': 0.2, 'recall_class_6': 0.1, 'f1_score_class_6': 0.13333333333333333, 'precision_class_7': 0.0, 'recall_class_7': 0.0, 'f1_score_class_7': 0.0, 'precision_class_8': 1.0, 'recall_class_8': 0.2, 'f1_score_cla

100%|██████████| 16/16 [00:02<00:00,  5.81it/s]
100%|██████████| 16/16 [00:02<00:00,  6.32it/s]


Epoch 2/20, Loss: 4.3806, Validation Metrics: {'accuracy': 0.7970588235294118, 'precision_weighted': 0.841626308877934, 'recall_weighted': 0.7970588235294118, 'f1_score_weighted': 0.7895113069690666, 'precision_class_0': 1.0, 'recall_class_0': 1.0, 'f1_score_class_0': 1.0, 'precision_class_1': 0.5882352941176471, 'recall_class_1': 1.0, 'f1_score_class_1': 0.7407407407407407, 'precision_class_2': 0.8333333333333334, 'recall_class_2': 1.0, 'f1_score_class_2': 0.9090909090909091, 'precision_class_3': 1.0, 'recall_class_3': 0.2, 'f1_score_class_3': 0.3333333333333333, 'precision_class_4': 0.8181818181818182, 'recall_class_4': 0.9, 'f1_score_class_4': 0.8571428571428571, 'precision_class_5': 0.5882352941176471, 'recall_class_5': 1.0, 'f1_score_class_5': 0.7407407407407407, 'precision_class_6': 1.0, 'recall_class_6': 1.0, 'f1_score_class_6': 1.0, 'precision_class_7': 1.0, 'recall_class_7': 0.9, 'f1_score_class_7': 0.9473684210526315, 'precision_class_8': 1.0, 'recall_class_8': 0.8, 'f1_score

100%|██████████| 16/16 [00:02<00:00,  5.78it/s]
100%|██████████| 16/16 [00:02<00:00,  6.28it/s]


Epoch 3/20, Loss: 4.1600, Validation Metrics: {'accuracy': 0.8862745098039215, 'precision_weighted': 0.9105038752097576, 'recall_weighted': 0.8862745098039215, 'f1_score_weighted': 0.8853479122813132, 'precision_class_0': 1.0, 'recall_class_0': 1.0, 'f1_score_class_0': 1.0, 'precision_class_1': 0.9090909090909091, 'recall_class_1': 1.0, 'f1_score_class_1': 0.9523809523809523, 'precision_class_2': 0.9090909090909091, 'recall_class_2': 1.0, 'f1_score_class_2': 0.9523809523809523, 'precision_class_3': 1.0, 'recall_class_3': 0.4, 'f1_score_class_3': 0.5714285714285714, 'precision_class_4': 1.0, 'recall_class_4': 1.0, 'f1_score_class_4': 1.0, 'precision_class_5': 1.0, 'recall_class_5': 1.0, 'f1_score_class_5': 1.0, 'precision_class_6': 1.0, 'recall_class_6': 1.0, 'f1_score_class_6': 1.0, 'precision_class_7': 1.0, 'recall_class_7': 0.9, 'f1_score_class_7': 0.9473684210526315, 'precision_class_8': 1.0, 'recall_class_8': 0.8, 'f1_score_class_8': 0.8888888888888888, 'precision_class_9': 1.0, 'r

100%|██████████| 16/16 [00:02<00:00,  5.81it/s]
100%|██████████| 16/16 [00:02<00:00,  6.43it/s]


Epoch 4/20, Loss: 3.9003, Validation Metrics: {'accuracy': 0.9225490196078432, 'precision_weighted': 0.9356926069000373, 'recall_weighted': 0.9225490196078432, 'f1_score_weighted': 0.9204514138367873, 'precision_class_0': 1.0, 'recall_class_0': 1.0, 'f1_score_class_0': 1.0, 'precision_class_1': 0.7692307692307693, 'recall_class_1': 1.0, 'f1_score_class_1': 0.8695652173913043, 'precision_class_2': 1.0, 'recall_class_2': 1.0, 'f1_score_class_2': 1.0, 'precision_class_3': 1.0, 'recall_class_3': 0.4, 'f1_score_class_3': 0.5714285714285714, 'precision_class_4': 1.0, 'recall_class_4': 1.0, 'f1_score_class_4': 1.0, 'precision_class_5': 1.0, 'recall_class_5': 1.0, 'f1_score_class_5': 1.0, 'precision_class_6': 1.0, 'recall_class_6': 1.0, 'f1_score_class_6': 1.0, 'precision_class_7': 1.0, 'recall_class_7': 1.0, 'f1_score_class_7': 1.0, 'precision_class_8': 1.0, 'recall_class_8': 0.9, 'f1_score_class_8': 0.9473684210526315, 'precision_class_9': 1.0, 'recall_class_9': 1.0, 'f1_score_class_9': 1.0,

100%|██████████| 16/16 [00:02<00:00,  5.66it/s]
100%|██████████| 16/16 [00:02<00:00,  6.39it/s]


Epoch 5/20, Loss: 3.6063, Validation Metrics: {'accuracy': 0.9441176470588235, 'precision_weighted': 0.9566571829729724, 'recall_weighted': 0.9441176470588235, 'f1_score_weighted': 0.9423873497641018, 'precision_class_0': 1.0, 'recall_class_0': 1.0, 'f1_score_class_0': 1.0, 'precision_class_1': 1.0, 'recall_class_1': 1.0, 'f1_score_class_1': 1.0, 'precision_class_2': 1.0, 'recall_class_2': 1.0, 'f1_score_class_2': 1.0, 'precision_class_3': 1.0, 'recall_class_3': 0.6, 'f1_score_class_3': 0.75, 'precision_class_4': 1.0, 'recall_class_4': 1.0, 'f1_score_class_4': 1.0, 'precision_class_5': 1.0, 'recall_class_5': 1.0, 'f1_score_class_5': 1.0, 'precision_class_6': 1.0, 'recall_class_6': 1.0, 'f1_score_class_6': 1.0, 'precision_class_7': 1.0, 'recall_class_7': 1.0, 'f1_score_class_7': 1.0, 'precision_class_8': 1.0, 'recall_class_8': 0.9, 'f1_score_class_8': 0.9473684210526315, 'precision_class_9': 1.0, 'recall_class_9': 1.0, 'f1_score_class_9': 1.0, 'precision_class_10': 0.8333333333333334, '

100%|██████████| 16/16 [00:02<00:00,  5.82it/s]
100%|██████████| 16/16 [00:02<00:00,  6.36it/s]


Epoch 6/20, Loss: 3.2824, Validation Metrics: {'accuracy': 0.9529411764705882, 'precision_weighted': 0.9633499469177219, 'recall_weighted': 0.9529411764705882, 'f1_score_weighted': 0.9524857067563315, 'precision_class_0': 1.0, 'recall_class_0': 1.0, 'f1_score_class_0': 1.0, 'precision_class_1': 0.9090909090909091, 'recall_class_1': 1.0, 'f1_score_class_1': 0.9523809523809523, 'precision_class_2': 1.0, 'recall_class_2': 1.0, 'f1_score_class_2': 1.0, 'precision_class_3': 1.0, 'recall_class_3': 0.6, 'f1_score_class_3': 0.75, 'precision_class_4': 1.0, 'recall_class_4': 1.0, 'f1_score_class_4': 1.0, 'precision_class_5': 1.0, 'recall_class_5': 1.0, 'f1_score_class_5': 1.0, 'precision_class_6': 1.0, 'recall_class_6': 1.0, 'f1_score_class_6': 1.0, 'precision_class_7': 1.0, 'recall_class_7': 1.0, 'f1_score_class_7': 1.0, 'precision_class_8': 1.0, 'recall_class_8': 0.9, 'f1_score_class_8': 0.9473684210526315, 'precision_class_9': 1.0, 'recall_class_9': 1.0, 'f1_score_class_9': 1.0, 'precision_cl

100%|██████████| 16/16 [00:02<00:00,  5.72it/s]
100%|██████████| 16/16 [00:02<00:00,  6.37it/s]


Epoch 7/20, Loss: 2.9500, Validation Metrics: {'accuracy': 0.9607843137254902, 'precision_weighted': 0.9657231113113467, 'recall_weighted': 0.9607843137254902, 'f1_score_weighted': 0.9592269586045963, 'precision_class_0': 1.0, 'recall_class_0': 1.0, 'f1_score_class_0': 1.0, 'precision_class_1': 0.7692307692307693, 'recall_class_1': 1.0, 'f1_score_class_1': 0.8695652173913043, 'precision_class_2': 1.0, 'recall_class_2': 1.0, 'f1_score_class_2': 1.0, 'precision_class_3': 1.0, 'recall_class_3': 0.7, 'f1_score_class_3': 0.8235294117647058, 'precision_class_4': 1.0, 'recall_class_4': 1.0, 'f1_score_class_4': 1.0, 'precision_class_5': 1.0, 'recall_class_5': 1.0, 'f1_score_class_5': 1.0, 'precision_class_6': 1.0, 'recall_class_6': 1.0, 'f1_score_class_6': 1.0, 'precision_class_7': 0.9090909090909091, 'recall_class_7': 1.0, 'f1_score_class_7': 0.9523809523809523, 'precision_class_8': 1.0, 'recall_class_8': 1.0, 'f1_score_class_8': 1.0, 'precision_class_9': 1.0, 'recall_class_9': 1.0, 'f1_score

100%|██████████| 16/16 [00:02<00:00,  5.78it/s]
100%|██████████| 16/16 [00:02<00:00,  6.35it/s]


Epoch 8/20, Loss: 2.6078, Validation Metrics: {'accuracy': 0.961764705882353, 'precision_weighted': 0.965904520316285, 'recall_weighted': 0.961764705882353, 'f1_score_weighted': 0.9607765901549008, 'precision_class_0': 1.0, 'recall_class_0': 1.0, 'f1_score_class_0': 1.0, 'precision_class_1': 0.8333333333333334, 'recall_class_1': 1.0, 'f1_score_class_1': 0.9090909090909091, 'precision_class_2': 1.0, 'recall_class_2': 1.0, 'f1_score_class_2': 1.0, 'precision_class_3': 1.0, 'recall_class_3': 0.8, 'f1_score_class_3': 0.8888888888888888, 'precision_class_4': 1.0, 'recall_class_4': 1.0, 'f1_score_class_4': 1.0, 'precision_class_5': 1.0, 'recall_class_5': 1.0, 'f1_score_class_5': 1.0, 'precision_class_6': 1.0, 'recall_class_6': 1.0, 'f1_score_class_6': 1.0, 'precision_class_7': 0.9090909090909091, 'recall_class_7': 1.0, 'f1_score_class_7': 0.9523809523809523, 'precision_class_8': 1.0, 'recall_class_8': 1.0, 'f1_score_class_8': 1.0, 'precision_class_9': 1.0, 'recall_class_9': 1.0, 'f1_score_cl

100%|██████████| 16/16 [00:02<00:00,  5.78it/s]
100%|██████████| 16/16 [00:02<00:00,  6.40it/s]


Epoch 9/20, Loss: 2.2768, Validation Metrics: {'accuracy': 0.9598039215686275, 'precision_weighted': 0.9650132547191371, 'recall_weighted': 0.9598039215686275, 'f1_score_weighted': 0.9575858780496012, 'precision_class_0': 1.0, 'recall_class_0': 1.0, 'f1_score_class_0': 1.0, 'precision_class_1': 0.8333333333333334, 'recall_class_1': 1.0, 'f1_score_class_1': 0.9090909090909091, 'precision_class_2': 1.0, 'recall_class_2': 1.0, 'f1_score_class_2': 1.0, 'precision_class_3': 1.0, 'recall_class_3': 0.6, 'f1_score_class_3': 0.75, 'precision_class_4': 1.0, 'recall_class_4': 1.0, 'f1_score_class_4': 1.0, 'precision_class_5': 1.0, 'recall_class_5': 1.0, 'f1_score_class_5': 1.0, 'precision_class_6': 1.0, 'recall_class_6': 1.0, 'f1_score_class_6': 1.0, 'precision_class_7': 0.9090909090909091, 'recall_class_7': 1.0, 'f1_score_class_7': 0.9523809523809523, 'precision_class_8': 1.0, 'recall_class_8': 1.0, 'f1_score_class_8': 1.0, 'precision_class_9': 1.0, 'recall_class_9': 1.0, 'f1_score_class_9': 1.0

100%|██████████| 16/16 [00:02<00:00,  5.78it/s]
100%|██████████| 16/16 [00:02<00:00,  6.36it/s]


Epoch 10/20, Loss: 1.9563, Validation Metrics: {'accuracy': 0.9607843137254902, 'precision_weighted': 0.965821759939407, 'recall_weighted': 0.9607843137254902, 'f1_score_weighted': 0.9593364553556329, 'precision_class_0': 1.0, 'recall_class_0': 1.0, 'f1_score_class_0': 1.0, 'precision_class_1': 0.9090909090909091, 'recall_class_1': 1.0, 'f1_score_class_1': 0.9523809523809523, 'precision_class_2': 1.0, 'recall_class_2': 1.0, 'f1_score_class_2': 1.0, 'precision_class_3': 1.0, 'recall_class_3': 0.7, 'f1_score_class_3': 0.8235294117647058, 'precision_class_4': 1.0, 'recall_class_4': 1.0, 'f1_score_class_4': 1.0, 'precision_class_5': 1.0, 'recall_class_5': 1.0, 'f1_score_class_5': 1.0, 'precision_class_6': 0.9090909090909091, 'recall_class_6': 1.0, 'f1_score_class_6': 0.9523809523809523, 'precision_class_7': 0.9090909090909091, 'recall_class_7': 1.0, 'f1_score_class_7': 0.9523809523809523, 'precision_class_8': 1.0, 'recall_class_8': 1.0, 'f1_score_class_8': 1.0, 'precision_class_9': 1.0, 'r

100%|██████████| 16/16 [00:02<00:00,  5.83it/s]
100%|██████████| 16/16 [00:02<00:00,  6.38it/s]


Epoch 11/20, Loss: 1.6509, Validation Metrics: {'accuracy': 0.9715686274509804, 'precision_weighted': 0.9743018419489009, 'recall_weighted': 0.9715686274509804, 'f1_score_weighted': 0.9708011202729834, 'precision_class_0': 1.0, 'recall_class_0': 1.0, 'f1_score_class_0': 1.0, 'precision_class_1': 0.9090909090909091, 'recall_class_1': 1.0, 'f1_score_class_1': 0.9523809523809523, 'precision_class_2': 1.0, 'recall_class_2': 1.0, 'f1_score_class_2': 1.0, 'precision_class_3': 1.0, 'recall_class_3': 0.7, 'f1_score_class_3': 0.8235294117647058, 'precision_class_4': 1.0, 'recall_class_4': 1.0, 'f1_score_class_4': 1.0, 'precision_class_5': 1.0, 'recall_class_5': 1.0, 'f1_score_class_5': 1.0, 'precision_class_6': 0.9090909090909091, 'recall_class_6': 1.0, 'f1_score_class_6': 0.9523809523809523, 'precision_class_7': 0.9090909090909091, 'recall_class_7': 1.0, 'f1_score_class_7': 0.9523809523809523, 'precision_class_8': 1.0, 'recall_class_8': 1.0, 'f1_score_class_8': 1.0, 'precision_class_9': 1.0, '

100%|██████████| 16/16 [00:02<00:00,  5.71it/s]
100%|██████████| 16/16 [00:02<00:00,  6.41it/s]


Epoch 12/20, Loss: 1.3747, Validation Metrics: {'accuracy': 0.9676470588235294, 'precision_weighted': 0.9712036329683388, 'recall_weighted': 0.9676470588235294, 'f1_score_weighted': 0.9661151541956495, 'precision_class_0': 1.0, 'recall_class_0': 1.0, 'f1_score_class_0': 1.0, 'precision_class_1': 0.9090909090909091, 'recall_class_1': 1.0, 'f1_score_class_1': 0.9523809523809523, 'precision_class_2': 1.0, 'recall_class_2': 1.0, 'f1_score_class_2': 1.0, 'precision_class_3': 1.0, 'recall_class_3': 0.9, 'f1_score_class_3': 0.9473684210526315, 'precision_class_4': 1.0, 'recall_class_4': 1.0, 'f1_score_class_4': 1.0, 'precision_class_5': 1.0, 'recall_class_5': 1.0, 'f1_score_class_5': 1.0, 'precision_class_6': 0.9090909090909091, 'recall_class_6': 1.0, 'f1_score_class_6': 0.9523809523809523, 'precision_class_7': 0.9090909090909091, 'recall_class_7': 1.0, 'f1_score_class_7': 0.9523809523809523, 'precision_class_8': 1.0, 'recall_class_8': 1.0, 'f1_score_class_8': 1.0, 'precision_class_9': 1.0, '

100%|██████████| 16/16 [00:02<00:00,  5.81it/s]
100%|██████████| 16/16 [00:02<00:00,  6.34it/s]


Epoch 13/20, Loss: 1.1370, Validation Metrics: {'accuracy': 0.9725490196078431, 'precision_weighted': 0.9749851455733809, 'recall_weighted': 0.9725490196078431, 'f1_score_weighted': 0.9713280615119988, 'precision_class_0': 1.0, 'recall_class_0': 1.0, 'f1_score_class_0': 1.0, 'precision_class_1': 0.9090909090909091, 'recall_class_1': 1.0, 'f1_score_class_1': 0.9523809523809523, 'precision_class_2': 1.0, 'recall_class_2': 1.0, 'f1_score_class_2': 1.0, 'precision_class_3': 1.0, 'recall_class_3': 0.7, 'f1_score_class_3': 0.8235294117647058, 'precision_class_4': 1.0, 'recall_class_4': 1.0, 'f1_score_class_4': 1.0, 'precision_class_5': 1.0, 'recall_class_5': 1.0, 'f1_score_class_5': 1.0, 'precision_class_6': 1.0, 'recall_class_6': 1.0, 'f1_score_class_6': 1.0, 'precision_class_7': 0.9090909090909091, 'recall_class_7': 1.0, 'f1_score_class_7': 0.9523809523809523, 'precision_class_8': 1.0, 'recall_class_8': 1.0, 'f1_score_class_8': 1.0, 'precision_class_9': 1.0, 'recall_class_9': 1.0, 'f1_scor

100%|██████████| 16/16 [00:02<00:00,  5.79it/s]
100%|██████████| 16/16 [00:02<00:00,  6.42it/s]


Epoch 14/20, Loss: 0.9480, Validation Metrics: {'accuracy': 0.9735294117647059, 'precision_weighted': 0.9758764111705289, 'recall_weighted': 0.9735294117647059, 'f1_score_weighted': 0.9724356941075207, 'precision_class_0': 1.0, 'recall_class_0': 1.0, 'f1_score_class_0': 1.0, 'precision_class_1': 0.9090909090909091, 'recall_class_1': 1.0, 'f1_score_class_1': 0.9523809523809523, 'precision_class_2': 1.0, 'recall_class_2': 1.0, 'f1_score_class_2': 1.0, 'precision_class_3': 1.0, 'recall_class_3': 0.8, 'f1_score_class_3': 0.8888888888888888, 'precision_class_4': 1.0, 'recall_class_4': 1.0, 'f1_score_class_4': 1.0, 'precision_class_5': 1.0, 'recall_class_5': 1.0, 'f1_score_class_5': 1.0, 'precision_class_6': 0.9090909090909091, 'recall_class_6': 1.0, 'f1_score_class_6': 0.9523809523809523, 'precision_class_7': 0.9090909090909091, 'recall_class_7': 1.0, 'f1_score_class_7': 0.9523809523809523, 'precision_class_8': 1.0, 'recall_class_8': 1.0, 'f1_score_class_8': 1.0, 'precision_class_9': 1.0, '

100%|██████████| 16/16 [00:02<00:00,  5.80it/s]
100%|██████████| 16/16 [00:02<00:00,  6.26it/s]


Epoch 15/20, Loss: 0.7906, Validation Metrics: {'accuracy': 0.9725490196078431, 'precision_weighted': 0.974589027530204, 'recall_weighted': 0.9725490196078431, 'f1_score_weighted': 0.9713627503101185, 'precision_class_0': 1.0, 'recall_class_0': 1.0, 'f1_score_class_0': 1.0, 'precision_class_1': 0.9090909090909091, 'recall_class_1': 1.0, 'f1_score_class_1': 0.9523809523809523, 'precision_class_2': 1.0, 'recall_class_2': 1.0, 'f1_score_class_2': 1.0, 'precision_class_3': 1.0, 'recall_class_3': 0.8, 'f1_score_class_3': 0.8888888888888888, 'precision_class_4': 1.0, 'recall_class_4': 1.0, 'f1_score_class_4': 1.0, 'precision_class_5': 1.0, 'recall_class_5': 1.0, 'f1_score_class_5': 1.0, 'precision_class_6': 1.0, 'recall_class_6': 1.0, 'f1_score_class_6': 1.0, 'precision_class_7': 0.9090909090909091, 'recall_class_7': 1.0, 'f1_score_class_7': 0.9523809523809523, 'precision_class_8': 1.0, 'recall_class_8': 1.0, 'f1_score_class_8': 1.0, 'precision_class_9': 1.0, 'recall_class_9': 1.0, 'f1_score

100%|██████████| 16/16 [00:02<00:00,  5.77it/s]
100%|██████████| 16/16 [00:02<00:00,  6.33it/s]


Epoch 16/20, Loss: 0.6645, Validation Metrics: {'accuracy': 0.9745098039215686, 'precision_weighted': 0.9766191325014856, 'recall_weighted': 0.9745098039215686, 'f1_score_weighted': 0.9733761021686717, 'precision_class_0': 1.0, 'recall_class_0': 1.0, 'f1_score_class_0': 1.0, 'precision_class_1': 0.9090909090909091, 'recall_class_1': 1.0, 'f1_score_class_1': 0.9523809523809523, 'precision_class_2': 1.0, 'recall_class_2': 1.0, 'f1_score_class_2': 1.0, 'precision_class_3': 1.0, 'recall_class_3': 0.9, 'f1_score_class_3': 0.9473684210526315, 'precision_class_4': 1.0, 'recall_class_4': 1.0, 'f1_score_class_4': 1.0, 'precision_class_5': 1.0, 'recall_class_5': 1.0, 'f1_score_class_5': 1.0, 'precision_class_6': 0.9090909090909091, 'recall_class_6': 1.0, 'f1_score_class_6': 0.9523809523809523, 'precision_class_7': 0.9090909090909091, 'recall_class_7': 1.0, 'f1_score_class_7': 0.9523809523809523, 'precision_class_8': 1.0, 'recall_class_8': 1.0, 'f1_score_class_8': 1.0, 'precision_class_9': 1.0, '

100%|██████████| 16/16 [00:02<00:00,  5.79it/s]
100%|██████████| 16/16 [00:02<00:00,  6.31it/s]


Epoch 17/20, Loss: 0.5623, Validation Metrics: {'accuracy': 0.9735294117647059, 'precision_weighted': 0.9760992275698158, 'recall_weighted': 0.9735294117647059, 'f1_score_weighted': 0.9725205765453443, 'precision_class_0': 1.0, 'recall_class_0': 1.0, 'f1_score_class_0': 1.0, 'precision_class_1': 0.9090909090909091, 'recall_class_1': 1.0, 'f1_score_class_1': 0.9523809523809523, 'precision_class_2': 1.0, 'recall_class_2': 1.0, 'f1_score_class_2': 1.0, 'precision_class_3': 1.0, 'recall_class_3': 0.9, 'f1_score_class_3': 0.9473684210526315, 'precision_class_4': 1.0, 'recall_class_4': 1.0, 'f1_score_class_4': 1.0, 'precision_class_5': 1.0, 'recall_class_5': 1.0, 'f1_score_class_5': 1.0, 'precision_class_6': 1.0, 'recall_class_6': 1.0, 'f1_score_class_6': 1.0, 'precision_class_7': 0.9090909090909091, 'recall_class_7': 1.0, 'f1_score_class_7': 0.9523809523809523, 'precision_class_8': 1.0, 'recall_class_8': 1.0, 'f1_score_class_8': 1.0, 'precision_class_9': 1.0, 'recall_class_9': 1.0, 'f1_scor

100%|██████████| 16/16 [00:02<00:00,  5.84it/s]
100%|██████████| 16/16 [00:02<00:00,  6.37it/s]


Epoch 18/20, Loss: 0.4823, Validation Metrics: {'accuracy': 0.9754901960784313, 'precision_weighted': 0.9775103980986334, 'recall_weighted': 0.9754901960784313, 'f1_score_weighted': 0.974416284323405, 'precision_class_0': 1.0, 'recall_class_0': 1.0, 'f1_score_class_0': 1.0, 'precision_class_1': 0.9090909090909091, 'recall_class_1': 1.0, 'f1_score_class_1': 0.9523809523809523, 'precision_class_2': 1.0, 'recall_class_2': 1.0, 'f1_score_class_2': 1.0, 'precision_class_3': 1.0, 'recall_class_3': 0.9, 'f1_score_class_3': 0.9473684210526315, 'precision_class_4': 1.0, 'recall_class_4': 1.0, 'f1_score_class_4': 1.0, 'precision_class_5': 1.0, 'recall_class_5': 1.0, 'f1_score_class_5': 1.0, 'precision_class_6': 1.0, 'recall_class_6': 1.0, 'f1_score_class_6': 1.0, 'precision_class_7': 0.9090909090909091, 'recall_class_7': 1.0, 'f1_score_class_7': 0.9523809523809523, 'precision_class_8': 1.0, 'recall_class_8': 1.0, 'f1_score_class_8': 1.0, 'precision_class_9': 1.0, 'recall_class_9': 1.0, 'f1_score

100%|██████████| 16/16 [00:02<00:00,  5.72it/s]
100%|██████████| 16/16 [00:02<00:00,  6.33it/s]


Epoch 19/20, Loss: 0.4176, Validation Metrics: {'accuracy': 0.9754901960784313, 'precision_weighted': 0.9777332144979204, 'recall_weighted': 0.9754901960784313, 'f1_score_weighted': 0.9745011667612287, 'precision_class_0': 1.0, 'recall_class_0': 1.0, 'f1_score_class_0': 1.0, 'precision_class_1': 0.9090909090909091, 'recall_class_1': 1.0, 'f1_score_class_1': 0.9523809523809523, 'precision_class_2': 1.0, 'recall_class_2': 1.0, 'f1_score_class_2': 1.0, 'precision_class_3': 1.0, 'recall_class_3': 0.9, 'f1_score_class_3': 0.9473684210526315, 'precision_class_4': 1.0, 'recall_class_4': 1.0, 'f1_score_class_4': 1.0, 'precision_class_5': 1.0, 'recall_class_5': 1.0, 'f1_score_class_5': 1.0, 'precision_class_6': 1.0, 'recall_class_6': 1.0, 'f1_score_class_6': 1.0, 'precision_class_7': 0.9090909090909091, 'recall_class_7': 1.0, 'f1_score_class_7': 0.9523809523809523, 'precision_class_8': 1.0, 'recall_class_8': 1.0, 'f1_score_class_8': 1.0, 'precision_class_9': 1.0, 'recall_class_9': 1.0, 'f1_scor

100%|██████████| 16/16 [00:02<00:00,  5.75it/s]
100%|██████████| 16/16 [00:02<00:00,  6.27it/s]


Epoch 20/20, Loss: 0.3661, Validation Metrics: {'accuracy': 0.9754901960784313, 'precision_weighted': 0.9777332144979204, 'recall_weighted': 0.9754901960784313, 'f1_score_weighted': 0.9745011667612287, 'precision_class_0': 1.0, 'recall_class_0': 1.0, 'f1_score_class_0': 1.0, 'precision_class_1': 0.9090909090909091, 'recall_class_1': 1.0, 'f1_score_class_1': 0.9523809523809523, 'precision_class_2': 1.0, 'recall_class_2': 1.0, 'f1_score_class_2': 1.0, 'precision_class_3': 1.0, 'recall_class_3': 0.9, 'f1_score_class_3': 0.9473684210526315, 'precision_class_4': 1.0, 'recall_class_4': 1.0, 'f1_score_class_4': 1.0, 'precision_class_5': 1.0, 'recall_class_5': 1.0, 'f1_score_class_5': 1.0, 'precision_class_6': 1.0, 'recall_class_6': 1.0, 'f1_score_class_6': 1.0, 'precision_class_7': 0.9090909090909091, 'recall_class_7': 1.0, 'f1_score_class_7': 0.9523809523809523, 'precision_class_8': 1.0, 'recall_class_8': 1.0, 'f1_score_class_8': 1.0, 'precision_class_9': 1.0, 'recall_class_9': 1.0, 'f1_scor

100%|██████████| 97/97 [00:13<00:00,  7.16it/s]


Test Metrics: {'accuracy': 0.964221824686941, 'precision_weighted': 0.9666759420735439, 'recall_weighted': 0.964221824686941, 'f1_score_weighted': 0.9638471307584779, 'precision_class_0': 1.0, 'recall_class_0': 1.0, 'f1_score_class_0': 1.0, 'precision_class_1': 1.0, 'recall_class_1': 0.975, 'f1_score_class_1': 0.9873417721518988, 'precision_class_2': 0.8181818181818182, 'recall_class_2': 0.9, 'f1_score_class_2': 0.8571428571428571, 'precision_class_3': 0.8571428571428571, 'recall_class_3': 0.8333333333333334, 'f1_score_class_3': 0.8450704225352113, 'precision_class_4': 1.0, 'recall_class_4': 0.9777777777777777, 'f1_score_class_4': 0.9887640449438202, 'precision_class_5': 0.8928571428571429, 'recall_class_5': 1.0, 'f1_score_class_5': 0.9433962264150944, 'precision_class_6': 0.8333333333333334, 'recall_class_6': 1.0, 'f1_score_class_6': 0.9090909090909091, 'precision_class_7': 1.0, 'recall_class_7': 1.0, 'f1_score_class_7': 1.0, 'precision_class_8': 1.0, 'recall_class_8': 0.9230769230769

In [4]:
{'accuracy': 0.964221824686941, 'precision_weighted': 0.9666759420735439, 'recall_weighted': 0.964221824686941, 'f1_score_weighted': 0.9638471307584779,}

SyntaxError: invalid syntax (3845707928.py, line 1)