# Импорт библиотек и загрузка данных

In [13]:
import json
import cv2 
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
import seaborn as sns
import pandas as pd
from collections import Counter
from functools import lru_cache
import os
import shutil
import re
import logging
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets, transforms, models
from torch import nn
from torch import optim
from tqdm import tqdm
import torch.nn.functional as F
from sklearn.metrics import precision_score, recall_score, accuracy_score, f1_score
import torch
import timm

  from .autonotebook import tqdm as notebook_tqdm


## Считываем данные

In [2]:
class ImageDataset:
    def __init__(self, path, path2image, n_classes=15):
        # Загрузка данных из CSV
        self.df = pd.read_csv(path)
        self.path2image = Path(path2image)
        self.N = n_classes
        
    def __len__(self):
        return len(self.df)

    def _encode_label(self, classes):
        label = np.zeros(self.N)
        for c in np.array(classes, dtype=np.int16) - 1:
            label[c] = 1
        return label
        
    def sample(self, clss=1, count=15):
        files = []
        while len(files) < count:
            indx = np.random.randint(0, len(self.df))
            labels = self.df.iloc[indx]["OUTPUT:classes"]
            try:
                if clss in self.parse(labels):
                    files.append(self.path2image / self.file_name(indx))
            except:
                continue
        return files
    
    def create_path(self, indx):
        return str(self.path2image / self.file_name(indx))
        
    def __getitem__(self, indx):
        img = cv2.imread(self.create_path(indx))
        clases = self.parse(self.df.iloc[indx]["OUTPUT:classes"])
        label = self._encode_label(clases)
        sample = {'image': img,
                  'label': label}
        return sample
            
    def file_name(self, indx):
        return self.df.iloc[indx]['file_name']
    
    def resize(self, img, scale=2):    
        shape = (np.array(img.shape[:2]) / scale).astype(np.int32)
        return cv2.resize(img, shape)
    
    def parse(self, st):        
        # Преобразуем строку вида '[2, 5, 11]' в список [2, 5, 11]
        return [int(x) for x in re.findall("\d+", st)]

  return [int(x) for x in re.findall("\d+", st)]


In [3]:
class TorchImageDataset(ImageDataset, Dataset):
    def __init__(self, path, path2image, imgsz=256, MEAN = (0.485, 0.456, 0.406), STD = (0.229, 0.224, 0.225)):
        super().__init__(path, path2image)
        self.imgsz = imgsz
        self.fransform_img= transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Resize((self.imgsz, self.imgsz)),
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomRotation(10),    
                    transforms.ColorJitter(brightness=0.2, 
                                                   contrast=0.2, 
                                                   saturation=0.2, 
                                                   hue=0.2),
                    transforms.Normalize(mean=MEAN, std=STD),
        ])
            
    @lru_cache(10000)
    def __getitem__(self,indx):
        img, label = super().__getitem__(indx).values()
        if img is not None:
            img = self.fransform_img(img)
            return {'image':img,'label':label}

## Подготовка файла

## Получаем датасет

In [4]:
dataset = TorchImageDataset("./analis/data.csv",  "../Resize/")

## DataLoader

In [5]:
def to_device(data, device):
    """Move tensor(s) to chosen device"""
    if isinstance(data, (list,tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)

class DeviceDataLoader():  
    """Wrap a dataloader to move data to a device"""
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device       

    def __iter__(self):
        """Yield a batch of data after moving it to device"""
        for b in self.dl: 
            img,label = tuple(b.values())
            img = to_device(img,self.device)
            label = to_device(label, self.device)
            yield {"image":img,"label":label}

    def __len__(self):
        """Number of batches"""
        return len(self.dl)

In [6]:
def split_data(dataset, train_ratio=0.7, val_ratio=0.1):

    total_size = len(dataset)
    train_size = int(train_ratio * total_size)
    val_size = int(val_ratio * total_size)
    test_size = total_size - train_size - val_size

    train_data, val_data, test_data = random_split(dataset, [train_size, val_size, test_size])

    return train_data, val_data, test_data

In [7]:
train_data, val_data, test_data = split_data(dataset, train_ratio=0.6, val_ratio=0.05)

In [55]:
type(dataset[0])

dict

In [9]:
device = torch.device("cuda:0")

In [10]:
len(dataset)

14594

In [11]:
train_loader = DeviceDataLoader(DataLoader(train_data, batch_size=16,num_workers = 0), device=device)
val_loader   = DeviceDataLoader(DataLoader(val_data, batch_size=16,num_workers = 0), device=device)
test_loader   = DeviceDataLoader(DataLoader(test_data, batch_size=16,num_workers = 0), device=device)

# Обучение модели

In [14]:
class ModelMultilabel(nn.Module):
    def __init__(self, n_classes):
        super().__init__()
        self.n_classes = n_classes        
        self.model = None
        self.model = self.__efficientnet_b5()
        
    def __efficientnet_b5(self):
        model = models.efficientnet_b5(pretrained=True)
        model.classifier[1] = nn.Linear(2048,self.n_classes)
        return model
    def forward(self, input):
        out = self.model(input)
        out = F.sigmoid(out)
        return out

def train_model(model, dataloader_train, dataloader_valid, learningRate, num_epochs, device = 'cuda:0'):
    logger = logging.getLogger('api_log')
    logger.setLevel(logging.INFO)
    file_handler = logging.FileHandler(f'log_{learningRate}.txt')
    file_handler.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.debug(f"Start_train  lr={learningRate}")
    
    # Установка оптимизатора с текущей скоростью обучения
    optimizer = optim.Adam(model.parameters(), lr=learningRate)
    criterion = nn.CrossEntropyLoss()
    # Цикл обучения
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        all_labels = []
        all_predictions = []

        for s in tqdm(dataloader_train, desc=f'Training Epoch {epoch+1}/{num_epochs}'):
            if s is None:
                continue
            inputs = s['image']
            labels = s['label']
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)  # Выходы модели
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)

            predicted = outputs
            all_labels.extend(labels.detach().cpu().numpy())
            all_predictions.extend(predicted.detach().cpu().numpy())
        # Средняя потеря за эпоху
        epoch_loss = running_loss / len(dataloader_train)
        # precision = precision_score(all_labels, all_predictions, average='weighted')
        # recall = recall_score(all_labels, all_predictions, average='weighted')
        # accuracy = accuracy_score(all_labels, all_predictions)
        f_score = F_score(all_predictions, all_labels)
        # print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.4f}, "
        #       f"Recall: {recall:.4f}, Precision: {precision:.4f}, Accuracy: {accuracy:.4f}, F_score: {f_score:.4f}")
        logger.info(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.4f}, F_score: {f_score:.4f}")
        
        # logger.debug(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.4f}, Recall: {recall:.4f}, Precision: {precision:.4f}, Accuracy: {accuracy:.4f},  F_score: {f_score:.4f}")
        validate_model(model, dataloader_valid, logger)    
    # Сохранение модели
    torch.save(model.state_dict(), f'model_{learningRate}.pth')
    
def validate_model(model, dataloader_valid, logger, device='cuda:0'):
    model.eval()
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        for s in tqdm(dataloader_valid, desc='Validation'):
            if s is None:
                continue
            inputs = s['image']
            labels = s['label']
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            predicted = outputs
            all_labels.extend(labels.detach().cpu().numpy())
            all_predictions.extend(predicted.detach().cpu().numpy())

    # precision = precision_score(all_labels, all_predictions, average='weighted')
    # recall = recall_score(all_labels, all_predictions, average='weighted')
    # accuracy = accuracy_score(all_labels, all_predictions)
    f_score = F_score(all_predictions, all_labels)
    # print(f"Validation Results - Recall: {recall:.4f}, Precision: {precision:.4f}, Accuracy: {accuracy:.4f}")
    # logger.debug(f"Validation Results - Recall: {recall:.4f}, Precision: {precision:.4f}, Accuracy: {accuracy:.4f}, F_score: {f_score:.4f}")
    logger.info(f"Validation Results -  F_score: {f_score:.4f}")


def F_score(output, label, threshold=0.5, beta=1): #Calculate the accuracy of the model
    if isinstance(output, list):
        output = np.array(output)
    if isinstance(label, list):
        label = np.array(label)
    print(label.shape, output.shape)
    assert (output.shape == label.shape),'shape is different'
    prob = output > threshold
    label = label > threshold

    TP = (prob == label).sum()
    TN = (np.bitwise_not(prob) == np.bitwise_not(label)).sum()
    FP = (prob == np.bitwise_not(label)).sum()
    FN = (np.bitwise_not(prob) == label).sum()
    print(TP, TN, FP, FN)
    precision = np.mean(TP / (TP + FP + 1e-12))
    recall = np.mean(TP / (TP + FN + 1e-12))
    F2 = (1 + beta**2) * precision * recall / (beta**2 * precision + recall + 1e-12)
    return F2


In [15]:
model = ModelMultilabel(15).to(device)



In [70]:
dataset[1]['label']

array([0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0.])

In [16]:
train_model(model, train_loader, val_loader, 0.001, 50)

Training Epoch 1/50: 100%|██████████| 548/548 [25:15<00:00,  2.76s/it]


(8756, 15) (8756, 15)
111503 111503 19837 19837


Validation: 100%|██████████| 46/46 [01:53<00:00,  2.47s/it]


(729, 15) (729, 15)
9292 9292 1643 1643


Training Epoch 2/50: 100%|██████████| 548/548 [03:33<00:00,  2.57it/s]


(8756, 15) (8756, 15)
114168 114168 17172 17172


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.64it/s]


(729, 15) (729, 15)
9543 9543 1392 1392


Training Epoch 3/50: 100%|██████████| 548/548 [03:33<00:00,  2.57it/s]


(8756, 15) (8756, 15)
115572 115572 15768 15768


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.60it/s]


(729, 15) (729, 15)
9496 9496 1439 1439


Training Epoch 4/50: 100%|██████████| 548/548 [03:32<00:00,  2.57it/s]


(8756, 15) (8756, 15)
116694 116694 14646 14646


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.65it/s]


(729, 15) (729, 15)
9607 9607 1328 1328


Training Epoch 5/50: 100%|██████████| 548/548 [03:33<00:00,  2.57it/s]


(8756, 15) (8756, 15)
117534 117534 13806 13806


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.60it/s]


(729, 15) (729, 15)
9586 9586 1349 1349


Training Epoch 6/50: 100%|██████████| 548/548 [03:33<00:00,  2.57it/s]


(8756, 15) (8756, 15)
117814 117814 13526 13526


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.60it/s]


(729, 15) (729, 15)
9655 9655 1280 1280


Training Epoch 7/50: 100%|██████████| 548/548 [03:33<00:00,  2.57it/s]


(8756, 15) (8756, 15)
118945 118945 12395 12395


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.61it/s]


(729, 15) (729, 15)
9698 9698 1237 1237


Training Epoch 8/50: 100%|██████████| 548/548 [03:33<00:00,  2.56it/s]


(8756, 15) (8756, 15)
118969 118969 12371 12371


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.66it/s]


(729, 15) (729, 15)
9774 9774 1161 1161


Training Epoch 9/50: 100%|██████████| 548/548 [03:33<00:00,  2.57it/s]


(8756, 15) (8756, 15)
120213 120213 11127 11127


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.63it/s]


(729, 15) (729, 15)
9771 9771 1164 1164


Training Epoch 10/50: 100%|██████████| 548/548 [03:33<00:00,  2.57it/s]


(8756, 15) (8756, 15)
120201 120201 11139 11139


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.63it/s]


(729, 15) (729, 15)
9843 9843 1092 1092


Training Epoch 11/50: 100%|██████████| 548/548 [03:32<00:00,  2.58it/s]


(8756, 15) (8756, 15)
120936 120936 10404 10404


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.66it/s]


(729, 15) (729, 15)
9824 9824 1111 1111


Training Epoch 12/50: 100%|██████████| 548/548 [03:32<00:00,  2.58it/s]


(8756, 15) (8756, 15)
121109 121109 10231 10231


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.70it/s]


(729, 15) (729, 15)
9829 9829 1106 1106


Training Epoch 13/50: 100%|██████████| 548/548 [03:32<00:00,  2.58it/s]


(8756, 15) (8756, 15)
121705 121705 9635 9635


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.65it/s]


(729, 15) (729, 15)
9863 9863 1072 1072


Training Epoch 14/50: 100%|██████████| 548/548 [03:32<00:00,  2.57it/s]


(8756, 15) (8756, 15)
122019 122019 9321 9321


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.69it/s]


(729, 15) (729, 15)
9833 9833 1102 1102


Training Epoch 15/50: 100%|██████████| 548/548 [03:33<00:00,  2.57it/s]


(8756, 15) (8756, 15)
122139 122139 9201 9201


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.45it/s]


(729, 15) (729, 15)
9904 9904 1031 1031


Training Epoch 16/50: 100%|██████████| 548/548 [03:32<00:00,  2.58it/s]


(8756, 15) (8756, 15)
122693 122693 8647 8647


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.68it/s]


(729, 15) (729, 15)
9922 9922 1013 1013


Training Epoch 17/50: 100%|██████████| 548/548 [03:33<00:00,  2.57it/s]


(8756, 15) (8756, 15)
122823 122823 8517 8517


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.63it/s]


(729, 15) (729, 15)
9908 9908 1027 1027


Training Epoch 18/50: 100%|██████████| 548/548 [03:32<00:00,  2.57it/s]


(8756, 15) (8756, 15)
123263 123263 8077 8077


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.63it/s]


(729, 15) (729, 15)
9905 9905 1030 1030


Training Epoch 19/50: 100%|██████████| 548/548 [03:32<00:00,  2.58it/s]


(8756, 15) (8756, 15)
123611 123611 7729 7729


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.64it/s]


(729, 15) (729, 15)
9944 9944 991 991


Training Epoch 20/50: 100%|██████████| 548/548 [03:32<00:00,  2.58it/s]


(8756, 15) (8756, 15)
123453 123453 7887 7887


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.64it/s]


(729, 15) (729, 15)
9804 9804 1131 1131


Training Epoch 21/50: 100%|██████████| 548/548 [03:33<00:00,  2.57it/s]


(8756, 15) (8756, 15)
123561 123561 7779 7779


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.64it/s]


(729, 15) (729, 15)
9861 9861 1074 1074


Training Epoch 22/50: 100%|██████████| 548/548 [03:33<00:00,  2.57it/s]


(8756, 15) (8756, 15)
123771 123771 7569 7569


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.65it/s]


(729, 15) (729, 15)
9988 9988 947 947


Training Epoch 23/50: 100%|██████████| 548/548 [03:33<00:00,  2.57it/s]


(8756, 15) (8756, 15)
124006 124006 7334 7334


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.70it/s]


(729, 15) (729, 15)
9899 9899 1036 1036


Training Epoch 24/50: 100%|██████████| 548/548 [03:32<00:00,  2.58it/s]


(8756, 15) (8756, 15)
124134 124134 7206 7206


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.59it/s]


(729, 15) (729, 15)
9972 9972 963 963


Training Epoch 25/50: 100%|██████████| 548/548 [03:32<00:00,  2.58it/s]


(8756, 15) (8756, 15)
124373 124373 6967 6967


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.65it/s]


(729, 15) (729, 15)
9879 9879 1056 1056


Training Epoch 26/50: 100%|██████████| 548/548 [03:32<00:00,  2.57it/s]


(8756, 15) (8756, 15)
124381 124381 6959 6959


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.66it/s]


(729, 15) (729, 15)
9891 9891 1044 1044


Training Epoch 27/50: 100%|██████████| 548/548 [03:32<00:00,  2.57it/s]


(8756, 15) (8756, 15)
124355 124355 6985 6985


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.63it/s]


(729, 15) (729, 15)
9888 9888 1047 1047


Training Epoch 28/50: 100%|██████████| 548/548 [03:32<00:00,  2.58it/s]


(8756, 15) (8756, 15)
124370 124370 6970 6970


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.62it/s]


(729, 15) (729, 15)
10008 10008 927 927


Training Epoch 29/50: 100%|██████████| 548/548 [03:32<00:00,  2.58it/s]


(8756, 15) (8756, 15)
124663 124663 6677 6677


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.67it/s]


(729, 15) (729, 15)
9932 9932 1003 1003


Training Epoch 30/50: 100%|██████████| 548/548 [03:32<00:00,  2.58it/s]


(8756, 15) (8756, 15)
124557 124557 6783 6783


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.71it/s]


(729, 15) (729, 15)
10001 10001 934 934


Training Epoch 31/50: 100%|██████████| 548/548 [03:32<00:00,  2.58it/s]


(8756, 15) (8756, 15)
124856 124856 6484 6484


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.66it/s]


(729, 15) (729, 15)
10037 10037 898 898


Training Epoch 32/50: 100%|██████████| 548/548 [03:33<00:00,  2.57it/s]


(8756, 15) (8756, 15)
124766 124766 6574 6574


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.69it/s]


(729, 15) (729, 15)
9991 9991 944 944


Training Epoch 33/50: 100%|██████████| 548/548 [03:32<00:00,  2.57it/s]


(8756, 15) (8756, 15)
125112 125112 6228 6228


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.62it/s]


(729, 15) (729, 15)
9999 9999 936 936


Training Epoch 34/50: 100%|██████████| 548/548 [03:32<00:00,  2.58it/s]


(8756, 15) (8756, 15)
125153 125153 6187 6187


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.63it/s]


(729, 15) (729, 15)
9936 9936 999 999


Training Epoch 35/50: 100%|██████████| 548/548 [03:32<00:00,  2.58it/s]


(8756, 15) (8756, 15)
125128 125128 6212 6212


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.66it/s]


(729, 15) (729, 15)
10013 10013 922 922


Training Epoch 36/50: 100%|██████████| 548/548 [03:32<00:00,  2.57it/s]


(8756, 15) (8756, 15)
124853 124853 6487 6487


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.63it/s]


(729, 15) (729, 15)
10023 10023 912 912


Training Epoch 37/50: 100%|██████████| 548/548 [03:32<00:00,  2.58it/s]


(8756, 15) (8756, 15)
125125 125125 6215 6215


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.66it/s]


(729, 15) (729, 15)
9950 9950 985 985


Training Epoch 38/50: 100%|██████████| 548/548 [03:32<00:00,  2.57it/s]


(8756, 15) (8756, 15)
124896 124896 6444 6444


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.66it/s]


(729, 15) (729, 15)
9974 9974 961 961


Training Epoch 39/50: 100%|██████████| 548/548 [03:32<00:00,  2.58it/s]


(8756, 15) (8756, 15)
125318 125318 6022 6022


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.65it/s]


(729, 15) (729, 15)
10062 10062 873 873


Training Epoch 40/50: 100%|██████████| 548/548 [03:32<00:00,  2.57it/s]


(8756, 15) (8756, 15)
125345 125345 5995 5995


Validation: 100%|██████████| 46/46 [00:04<00:00,  9.63it/s]


(729, 15) (729, 15)
9953 9953 982 982


Training Epoch 41/50: 100%|██████████| 548/548 [03:20<00:00,  2.73it/s]


(8756, 15) (8756, 15)
125253 125253 6087 6087


Validation: 100%|██████████| 46/46 [00:04<00:00, 10.38it/s]


(729, 15) (729, 15)
9992 9992 943 943


Training Epoch 42/50: 100%|██████████| 548/548 [03:18<00:00,  2.77it/s]


(8756, 15) (8756, 15)
125486 125486 5854 5854


Validation: 100%|██████████| 46/46 [00:04<00:00, 10.41it/s]


(729, 15) (729, 15)
9883 9883 1052 1052


Training Epoch 43/50: 100%|██████████| 548/548 [03:18<00:00,  2.76it/s]


(8756, 15) (8756, 15)
125418 125418 5922 5922


Validation: 100%|██████████| 46/46 [00:04<00:00, 10.38it/s]


(729, 15) (729, 15)
9982 9982 953 953


Training Epoch 44/50: 100%|██████████| 548/548 [03:18<00:00,  2.77it/s]


(8756, 15) (8756, 15)
125424 125424 5916 5916


Validation: 100%|██████████| 46/46 [00:04<00:00, 10.38it/s]


(729, 15) (729, 15)
9954 9954 981 981


Training Epoch 45/50: 100%|██████████| 548/548 [03:18<00:00,  2.77it/s]


(8756, 15) (8756, 15)
125509 125509 5831 5831


Validation: 100%|██████████| 46/46 [00:04<00:00, 10.38it/s]


(729, 15) (729, 15)
9971 9971 964 964


Training Epoch 46/50: 100%|██████████| 548/548 [03:18<00:00,  2.77it/s]


(8756, 15) (8756, 15)
125563 125563 5777 5777


Validation: 100%|██████████| 46/46 [00:04<00:00, 10.42it/s]


(729, 15) (729, 15)
9986 9986 949 949


Training Epoch 47/50: 100%|██████████| 548/548 [03:18<00:00,  2.77it/s]


(8756, 15) (8756, 15)
125683 125683 5657 5657


Validation: 100%|██████████| 46/46 [00:04<00:00, 10.41it/s]


(729, 15) (729, 15)
9998 9998 937 937


Training Epoch 48/50: 100%|██████████| 548/548 [03:18<00:00,  2.76it/s]


(8756, 15) (8756, 15)
125561 125561 5779 5779


Validation: 100%|██████████| 46/46 [00:04<00:00, 10.38it/s]


(729, 15) (729, 15)
9983 9983 952 952


Training Epoch 49/50: 100%|██████████| 548/548 [03:18<00:00,  2.77it/s]


(8756, 15) (8756, 15)
125671 125671 5669 5669


Validation: 100%|██████████| 46/46 [00:04<00:00, 10.38it/s]


(729, 15) (729, 15)
9993 9993 942 942


Training Epoch 50/50: 100%|██████████| 548/548 [03:18<00:00,  2.77it/s]


(8756, 15) (8756, 15)
125670 125670 5670 5670


Validation: 100%|██████████| 46/46 [00:04<00:00, 10.38it/s]


(729, 15) (729, 15)
9947 9947 988 988


In [78]:
torch.save(model, "model80.pth")

In [None]:
model = torch.load("model.pth")
model.eval()

## Метрики

In [17]:
model.eval()
all_predictions = []
all_labels = []
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Evaluating"):
        inputs = batch['image'].to(device)
        labels = batch['label'].to(device)

        outputs = model(inputs)
        outputs = (outputs > 0.5).float()

        all_predictions.append(outputs.cpu().numpy())
        all_labels.append(labels.cpu().numpy())
    
all_predictions = np.vstack(all_predictions)
all_labels = np.vstack(all_labels)
    
f1 = f1_score(all_labels, all_predictions, average='weighted')  
print(f"F1-Score: {f1:.4f}")

Evaluating: 100%|██████████| 320/320 [13:22<00:00,  2.51s/it]

F1-Score: 0.7977





In [19]:
f1_score(all_labels, all_predictions, average='macro')  

np.float64(0.5278923969003797)

In [20]:
f1_score(all_labels, all_predictions, average='micro')  

np.float64(0.8147874516935667)

In [21]:
import numpy as np
from sklearn.metrics import f1_score
import torch
from tqdm import tqdm

# Переводим модель в режим оценки
model.eval()

all_predictions = []
all_labels = []

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Evaluating"):
        inputs = batch['image'].to(device)
        labels = batch['label'].to(device)

        outputs = model(inputs)
        outputs = (outputs > 0.5).float()  # Если задача бинарной классификации, иначе можно использовать softmax

        all_predictions.append(outputs.cpu().numpy())
        all_labels.append(labels.cpu().numpy())

# Преобразуем списки в массивы
all_predictions = np.vstack(all_predictions)
all_labels = np.vstack(all_labels)

# Вычисляем F1-Score для каждого класса
f1_per_class = f1_score(all_labels, all_predictions, average=None)

# Выводим F1-Score для каждого класса
for i, f1 in enumerate(f1_per_class):
    print(f"Class {i + 1} F1-Score: {f1:.4f}")

Evaluating: 100%|██████████| 320/320 [00:33<00:00,  9.50it/s]

Class 1 F1-Score: 0.7377
Class 2 F1-Score: 0.8680
Class 3 F1-Score: 0.4062
Class 4 F1-Score: 0.4911
Class 5 F1-Score: 0.9222
Class 6 F1-Score: 0.0000
Class 7 F1-Score: 0.8441
Class 8 F1-Score: 0.8781
Class 9 F1-Score: 0.4490
Class 10 F1-Score: 0.8088
Class 11 F1-Score: 0.8221
Class 12 F1-Score: 0.0000
Class 13 F1-Score: 0.6909
Class 14 F1-Score: 0.0000
Class 15 F1-Score: 0.0000



