# Импорт библиотек и загрузка данных

In [15]:
import json
import cv2 
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
import seaborn as sns
import pandas as pd
from collections import Counter
from functools import lru_cache
import os
import shutil
import re
import logging
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import datasets, transforms, models
from torch import nn
from torch import optim
from tqdm import tqdm
import torch.nn.functional as F
from sklearn.metrics import precision_score, recall_score, accuracy_score, f1_score
import torch

## Считываем данные

In [49]:
class ImageDataset:
    def __init__(self, path, path2image, n_classes=15):
        # Загрузка данных из CSV
        self.df = pd.read_csv(path)
        self.path2image = Path(path2image)
        self.N = n_classes
        
    def __len__(self):
        return len(self.df)

    def _encode_label(self, classes):
        label = np.zeros(self.N)
        for c in np.array(classes, dtype=np.int16) - 1:
            label[c] = 1
        return label
        
    def sample(self, clss=1, count=15):
        files = []
        while len(files) < count:
            indx = np.random.randint(0, len(self.df))
            labels = self.df.iloc[indx]["OUTPUT:classes"]
            try:
                if clss in self.parse(labels):
                    files.append(self.path2image / self.file_name(indx))
            except:
                continue
        return files
    
    def create_path(self, indx):
        return str(self.path2image / self.file_name(indx))
        
    def __getitem__(self, indx):
        img = cv2.imread(self.create_path(indx))
        clases = self.parse(self.df.iloc[indx]["OUTPUT:classes"])
        label = self._encode_label(clases)
        sample = {'image': img,
                  'label': label}
        return sample
            
    def file_name(self, indx):
        return self.df.iloc[indx]['file_name']
    
    def resize(self, img, scale=2):    
        shape = (np.array(img.shape[:2]) / scale).astype(np.int32)
        return cv2.resize(img, shape)
    
    def parse(self, st):        
        # Преобразуем строку вида '[2, 5, 11]' в список [2, 5, 11]
        return [int(x) for x in re.findall("\d+", st)]

  return [int(x) for x in re.findall("\d+", st)]


In [50]:
class TorchImageDataset(ImageDataset, Dataset):
    def __init__(self, path, path2image, imgsz=256, MEAN = (0.485, 0.456, 0.406), STD = (0.229, 0.224, 0.225)):
        super().__init__(path, path2image)
        self.imgsz = imgsz
        self.fransform_img= transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Resize((self.imgsz, self.imgsz)),
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomRotation(10),    
                    transforms.ColorJitter(brightness=0.2, 
                                                   contrast=0.2, 
                                                   saturation=0.2, 
                                                   hue=0.2),
                    transforms.Normalize(mean=MEAN, std=STD),
        ])
            
    @lru_cache(10000)
    def __getitem__(self,indx):
        img, label = super().__getitem__(indx).values()
        if img is not None:
            img = self.fransform_img(img)
            return {'image':img,'label':label}

## Подготовка файла

## Получаем датасет

In [51]:
dataset = TorchImageDataset("./analis/data.csv",  "../Resize/")

## DataLoader

In [52]:
def to_device(data, device):
    """Move tensor(s) to chosen device"""
    if isinstance(data, (list,tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)

class DeviceDataLoader():  
    """Wrap a dataloader to move data to a device"""
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device       

    def __iter__(self):
        """Yield a batch of data after moving it to device"""
        for b in self.dl: 
            img,label = tuple(b.values())
            img = to_device(img,self.device)
            label = to_device(label, self.device)
            yield {"image":img,"label":label}

    def __len__(self):
        """Number of batches"""
        return len(self.dl)

In [53]:
def split_data(dataset, train_ratio=0.7, val_ratio=0.1):

    total_size = len(dataset)
    train_size = int(train_ratio * total_size)
    val_size = int(val_ratio * total_size)
    test_size = total_size - train_size - val_size

    train_data, val_data, test_data = random_split(dataset, [train_size, val_size, test_size])

    return train_data, val_data, test_data

In [71]:
train_data, val_data, test_data = split_data(dataset, train_ratio=0.6, val_ratio=0.05)

In [55]:
type(dataset[0])

dict

In [56]:
device = torch.device("cuda:0")

In [57]:
len(dataset)

14594

In [72]:
train_loader = DeviceDataLoader(DataLoader(train_data, batch_size=16,num_workers = 0), device=device)
val_loader   = DeviceDataLoader(DataLoader(val_data, batch_size=16,num_workers = 0), device=device)
test_loader   = DeviceDataLoader(DataLoader(test_data, batch_size=16,num_workers = 0), device=device)

# Обучение модели

In [None]:
class ModelMultilabel(nn.Module):
    def __init__(self, n_classes):
        super(ModelMultilabel, self).__init__()
        self.n_classes = n_classes
        
        self.model = timm.create_model('resnet18', pretrained=True)
        
        self.model.fc = nn.Identity()
        
        self.fc = nn.Linear(512, self.n_classes) 
        
    def forward(self, x):

        x = self.model(x)

        x = self.fc(x)
        
        return x

def train_model(model, dataloader_train, dataloader_valid, learningRate, num_epochs, device = 'cuda:0'):
    logger = logging.getLogger('api_log')
    logger.setLevel(logging.INFO)
    file_handler = logging.FileHandler(f'log_{learningRate}.txt')
    file_handler.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.debug(f"Start_train  lr={learningRate}")
    
    # Установка оптимизатора с текущей скоростью обучения
    optimizer = optim.Adam(model.parameters(), lr=learningRate)
    criterion = nn.CrossEntropyLoss()
    # Цикл обучения
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        all_labels = []
        all_predictions = []

        for s in tqdm(dataloader_train, desc=f'Training Epoch {epoch+1}/{num_epochs}'):
            if s is None:
                continue
            inputs = s['image']
            labels = s['label']
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)  # Выходы модели
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)

            predicted = outputs
            all_labels.extend(labels.detach().cpu().numpy())
            all_predictions.extend(predicted.detach().cpu().numpy())
        # Средняя потеря за эпоху
        epoch_loss = running_loss / len(dataloader_train)
        # precision = precision_score(all_labels, all_predictions, average='weighted')
        # recall = recall_score(all_labels, all_predictions, average='weighted')
        # accuracy = accuracy_score(all_labels, all_predictions)
        f_score = F_score(all_predictions, all_labels)
        # print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.4f}, "
        #       f"Recall: {recall:.4f}, Precision: {precision:.4f}, Accuracy: {accuracy:.4f}, F_score: {f_score:.4f}")
        logger.info(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.4f}, F_score: {f_score:.4f}")
        
        # logger.debug(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.4f}, Recall: {recall:.4f}, Precision: {precision:.4f}, Accuracy: {accuracy:.4f},  F_score: {f_score:.4f}")
        validate_model(model, dataloader_valid, logger)    
    # Сохранение модели
    torch.save(model.state_dict(), f'model_{learningRate}.pth')
    
def validate_model(model, dataloader_valid, logger, device='cuda:0'):
    model.eval()
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        for s in tqdm(dataloader_valid, desc='Validation'):
            if s is None:
                continue
            inputs = s['image']
            labels = s['label']
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            predicted = outputs
            all_labels.extend(labels.detach().cpu().numpy())
            all_predictions.extend(predicted.detach().cpu().numpy())

    # precision = precision_score(all_labels, all_predictions, average='weighted')
    # recall = recall_score(all_labels, all_predictions, average='weighted')
    # accuracy = accuracy_score(all_labels, all_predictions)
    f_score = F_score(all_predictions, all_labels)
    # print(f"Validation Results - Recall: {recall:.4f}, Precision: {precision:.4f}, Accuracy: {accuracy:.4f}")
    # logger.debug(f"Validation Results - Recall: {recall:.4f}, Precision: {precision:.4f}, Accuracy: {accuracy:.4f}, F_score: {f_score:.4f}")
    logger.info(f"Validation Results -  F_score: {f_score:.4f}")


def F_score(output, label, threshold=0.5, beta=1): #Calculate the accuracy of the model
    if isinstance(output, list):
        output = np.array(output)
    if isinstance(label, list):
        label = np.array(label)
    print(label.shape, output.shape)
    assert (output.shape == label.shape),'shape is different'
    prob = output > threshold
    label = label > threshold

    TP = (prob == label).sum()
    TN = (np.bitwise_not(prob) == np.bitwise_not(label)).sum()
    FP = (prob == np.bitwise_not(label)).sum()
    FN = (np.bitwise_not(prob) == label).sum()
    print(TP, TN, FP, FN)
    precision = np.mean(TP / (TP + FP + 1e-12))
    recall = np.mean(TP / (TP + FN + 1e-12))
    F2 = (1 + beta**2) * precision * recall / (beta**2 * precision + recall + 1e-12)
    return F2


In [68]:
import timm

In [74]:
model = ModelMultilabel(15).to(device)

In [70]:
dataset[1]['label']

array([0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0.])

In [75]:
train_model(model, train_loader, val_loader, 0.001, 50)

Training Epoch 1/50: 100%|██████████| 548/548 [10:57<00:00,  1.20s/it]


(8756, 15) (8756, 15)
111997 111997 19343 19343


Validation: 100%|██████████| 46/46 [01:37<00:00,  2.11s/it]


(729, 15) (729, 15)
9545 9545 1390 1390


Training Epoch 2/50: 100%|██████████| 548/548 [00:39<00:00, 13.90it/s]


(8756, 15) (8756, 15)
115233 115233 16107 16107


Validation: 100%|██████████| 46/46 [00:01<00:00, 43.30it/s]


(729, 15) (729, 15)
9638 9638 1297 1297


Training Epoch 3/50: 100%|██████████| 548/548 [00:39<00:00, 13.93it/s]


(8756, 15) (8756, 15)
118204 118204 13136 13136


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.97it/s]


(729, 15) (729, 15)
9678 9678 1257 1257


Training Epoch 4/50: 100%|██████████| 548/548 [00:39<00:00, 13.92it/s]


(8756, 15) (8756, 15)
120702 120702 10638 10638


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.71it/s]


(729, 15) (729, 15)
9669 9669 1266 1266


Training Epoch 5/50: 100%|██████████| 548/548 [00:39<00:00, 13.89it/s]


(8756, 15) (8756, 15)
122879 122879 8461 8461


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.64it/s]


(729, 15) (729, 15)
9748 9748 1187 1187


Training Epoch 6/50: 100%|██████████| 548/548 [00:39<00:00, 13.82it/s]


(8756, 15) (8756, 15)
124174 124174 7166 7166


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.04it/s]


(729, 15) (729, 15)
9801 9801 1134 1134


Training Epoch 7/50: 100%|██████████| 548/548 [00:39<00:00, 13.84it/s]


(8756, 15) (8756, 15)
125429 125429 5911 5911


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.16it/s]


(729, 15) (729, 15)
9843 9843 1092 1092


Training Epoch 8/50: 100%|██████████| 548/548 [00:39<00:00, 13.81it/s]


(8756, 15) (8756, 15)
126615 126615 4725 4725


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.53it/s]


(729, 15) (729, 15)
9860 9860 1075 1075


Training Epoch 9/50: 100%|██████████| 548/548 [00:39<00:00, 13.85it/s]


(8756, 15) (8756, 15)
127489 127489 3851 3851


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.45it/s]


(729, 15) (729, 15)
9875 9875 1060 1060


Training Epoch 10/50: 100%|██████████| 548/548 [00:39<00:00, 13.88it/s]


(8756, 15) (8756, 15)
127973 127973 3367 3367


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.35it/s]


(729, 15) (729, 15)
9840 9840 1095 1095


Training Epoch 11/50: 100%|██████████| 548/548 [00:39<00:00, 13.89it/s]


(8756, 15) (8756, 15)
127379 127379 3961 3961


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.17it/s]


(729, 15) (729, 15)
9788 9788 1147 1147


Training Epoch 12/50: 100%|██████████| 548/548 [00:39<00:00, 13.92it/s]


(8756, 15) (8756, 15)
127588 127588 3752 3752


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.56it/s]


(729, 15) (729, 15)
9904 9904 1031 1031


Training Epoch 13/50: 100%|██████████| 548/548 [00:39<00:00, 13.85it/s]


(8756, 15) (8756, 15)
128956 128956 2384 2384


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.04it/s]


(729, 15) (729, 15)
9925 9925 1010 1010


Training Epoch 14/50: 100%|██████████| 548/548 [00:39<00:00, 13.88it/s]


(8756, 15) (8756, 15)
129823 129823 1517 1517


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.28it/s]


(729, 15) (729, 15)
9950 9950 985 985


Training Epoch 15/50: 100%|██████████| 548/548 [00:39<00:00, 13.94it/s]


(8756, 15) (8756, 15)
130285 130285 1055 1055


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.82it/s]


(729, 15) (729, 15)
9976 9976 959 959


Training Epoch 16/50: 100%|██████████| 548/548 [00:39<00:00, 13.91it/s]


(8756, 15) (8756, 15)
130496 130496 844 844


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.41it/s]


(729, 15) (729, 15)
9956 9956 979 979


Training Epoch 17/50: 100%|██████████| 548/548 [00:39<00:00, 13.81it/s]


(8756, 15) (8756, 15)
130399 130399 941 941


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.10it/s]


(729, 15) (729, 15)
9957 9957 978 978


Training Epoch 18/50: 100%|██████████| 548/548 [00:39<00:00, 13.93it/s]


(8756, 15) (8756, 15)
128136 128136 3204 3204


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.88it/s]


(729, 15) (729, 15)
9791 9791 1144 1144


Training Epoch 19/50: 100%|██████████| 548/548 [00:39<00:00, 13.90it/s]


(8756, 15) (8756, 15)
128003 128003 3337 3337


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.67it/s]


(729, 15) (729, 15)
9909 9909 1026 1026


Training Epoch 20/50: 100%|██████████| 548/548 [00:39<00:00, 13.92it/s]


(8756, 15) (8756, 15)
130067 130067 1273 1273


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.90it/s]


(729, 15) (729, 15)
9937 9937 998 998


Training Epoch 21/50: 100%|██████████| 548/548 [00:39<00:00, 13.94it/s]


(8756, 15) (8756, 15)
130714 130714 626 626


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.53it/s]


(729, 15) (729, 15)
9992 9992 943 943


Training Epoch 22/50: 100%|██████████| 548/548 [00:39<00:00, 13.91it/s]


(8756, 15) (8756, 15)
131021 131021 319 319


Validation: 100%|██████████| 46/46 [00:01<00:00, 43.03it/s]


(729, 15) (729, 15)
10017 10017 918 918


Training Epoch 23/50: 100%|██████████| 548/548 [00:39<00:00, 13.92it/s]


(8756, 15) (8756, 15)
131133 131133 207 207


Validation: 100%|██████████| 46/46 [00:01<00:00, 43.13it/s]


(729, 15) (729, 15)
10019 10019 916 916


Training Epoch 24/50: 100%|██████████| 548/548 [00:39<00:00, 13.89it/s]


(8756, 15) (8756, 15)
131165 131165 175 175


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.74it/s]


(729, 15) (729, 15)
10032 10032 903 903


Training Epoch 25/50: 100%|██████████| 548/548 [00:39<00:00, 13.85it/s]


(8756, 15) (8756, 15)
131187 131187 153 153


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.26it/s]


(729, 15) (729, 15)
10015 10015 920 920


Training Epoch 26/50: 100%|██████████| 548/548 [00:39<00:00, 13.92it/s]


(8756, 15) (8756, 15)
129525 129525 1815 1815


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.86it/s]


(729, 15) (729, 15)
9852 9852 1083 1083


Training Epoch 27/50: 100%|██████████| 548/548 [00:39<00:00, 13.96it/s]


(8756, 15) (8756, 15)
128021 128021 3319 3319


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.90it/s]


(729, 15) (729, 15)
9963 9963 972 972


Training Epoch 28/50: 100%|██████████| 548/548 [00:39<00:00, 13.90it/s]


(8756, 15) (8756, 15)
130381 130381 959 959


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.73it/s]


(729, 15) (729, 15)
9986 9986 949 949


Training Epoch 29/50: 100%|██████████| 548/548 [00:39<00:00, 13.94it/s]


(8756, 15) (8756, 15)
130920 130920 420 420


Validation: 100%|██████████| 46/46 [00:01<00:00, 43.10it/s]


(729, 15) (729, 15)
9998 9998 937 937


Training Epoch 30/50: 100%|██████████| 548/548 [00:39<00:00, 13.93it/s]


(8756, 15) (8756, 15)
131132 131132 208 208


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.58it/s]


(729, 15) (729, 15)
10024 10024 911 911


Training Epoch 31/50: 100%|██████████| 548/548 [00:39<00:00, 13.94it/s]


(8756, 15) (8756, 15)
131234 131234 106 106


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.95it/s]


(729, 15) (729, 15)
10026 10026 909 909


Training Epoch 32/50: 100%|██████████| 548/548 [00:39<00:00, 13.88it/s]


(8756, 15) (8756, 15)
131256 131256 84 84


Validation: 100%|██████████| 46/46 [00:01<00:00, 43.13it/s]


(729, 15) (729, 15)
10017 10017 918 918


Training Epoch 33/50: 100%|██████████| 548/548 [00:39<00:00, 13.97it/s]


(8756, 15) (8756, 15)
131274 131274 66 66


Validation: 100%|██████████| 46/46 [00:01<00:00, 43.07it/s]


(729, 15) (729, 15)
10046 10046 889 889


Training Epoch 34/50: 100%|██████████| 548/548 [00:39<00:00, 13.90it/s]


(8756, 15) (8756, 15)
131289 131289 51 51


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.43it/s]


(729, 15) (729, 15)
10026 10026 909 909


Training Epoch 35/50: 100%|██████████| 548/548 [00:39<00:00, 13.96it/s]


(8756, 15) (8756, 15)
129415 129415 1925 1925


Validation: 100%|██████████| 46/46 [00:01<00:00, 41.98it/s]


(729, 15) (729, 15)
9815 9815 1120 1120


Training Epoch 36/50: 100%|██████████| 548/548 [00:39<00:00, 13.97it/s]


(8756, 15) (8756, 15)
128661 128661 2679 2679


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.75it/s]


(729, 15) (729, 15)
9967 9967 968 968


Training Epoch 37/50: 100%|██████████| 548/548 [00:39<00:00, 13.86it/s]


(8756, 15) (8756, 15)
130683 130683 657 657


Validation: 100%|██████████| 46/46 [00:01<00:00, 41.79it/s]


(729, 15) (729, 15)
9924 9924 1011 1011


Training Epoch 38/50: 100%|██████████| 548/548 [00:39<00:00, 13.94it/s]


(8756, 15) (8756, 15)
131046 131046 294 294


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.87it/s]


(729, 15) (729, 15)
9972 9972 963 963


Training Epoch 39/50: 100%|██████████| 548/548 [00:39<00:00, 13.94it/s]


(8756, 15) (8756, 15)
131248 131248 92 92


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.83it/s]


(729, 15) (729, 15)
10010 10010 925 925


Training Epoch 40/50: 100%|██████████| 548/548 [00:39<00:00, 13.99it/s]


(8756, 15) (8756, 15)
131293 131293 47 47


Validation: 100%|██████████| 46/46 [00:01<00:00, 43.30it/s]


(729, 15) (729, 15)
10014 10014 921 921


Training Epoch 41/50: 100%|██████████| 548/548 [00:39<00:00, 14.01it/s]


(8756, 15) (8756, 15)
131308 131308 32 32


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.86it/s]


(729, 15) (729, 15)
10017 10017 918 918


Training Epoch 42/50: 100%|██████████| 548/548 [00:39<00:00, 13.95it/s]


(8756, 15) (8756, 15)
131314 131314 26 26


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.50it/s]


(729, 15) (729, 15)
10023 10023 912 912


Training Epoch 43/50: 100%|██████████| 548/548 [00:39<00:00, 13.88it/s]


(8756, 15) (8756, 15)
131305 131305 35 35


Validation: 100%|██████████| 46/46 [00:01<00:00, 41.92it/s]


(729, 15) (729, 15)
10019 10019 916 916


Training Epoch 44/50: 100%|██████████| 548/548 [00:39<00:00, 13.85it/s]


(8756, 15) (8756, 15)
129087 129087 2253 2253


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.48it/s]


(729, 15) (729, 15)
9874 9874 1061 1061


Training Epoch 45/50: 100%|██████████| 548/548 [00:39<00:00, 13.90it/s]


(8756, 15) (8756, 15)
130024 130024 1316 1316


Validation: 100%|██████████| 46/46 [00:01<00:00, 43.20it/s]


(729, 15) (729, 15)
9982 9982 953 953


Training Epoch 46/50: 100%|██████████| 548/548 [00:39<00:00, 13.95it/s]


(8756, 15) (8756, 15)
130978 130978 362 362


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.46it/s]


(729, 15) (729, 15)
10031 10031 904 904


Training Epoch 47/50: 100%|██████████| 548/548 [00:39<00:00, 13.86it/s]


(8756, 15) (8756, 15)
131198 131198 142 142


Validation: 100%|██████████| 46/46 [00:01<00:00, 42.79it/s]


(729, 15) (729, 15)
10023 10023 912 912


Training Epoch 48/50: 100%|██████████| 548/548 [00:39<00:00, 13.87it/s]


(8756, 15) (8756, 15)
131300 131300 40 40


Validation: 100%|██████████| 46/46 [00:01<00:00, 43.01it/s]


(729, 15) (729, 15)
10043 10043 892 892


Training Epoch 49/50: 100%|██████████| 548/548 [00:39<00:00, 13.94it/s]


(8756, 15) (8756, 15)
131322 131322 18 18


Validation: 100%|██████████| 46/46 [00:01<00:00, 43.22it/s]


(729, 15) (729, 15)
10037 10037 898 898


Training Epoch 50/50: 100%|██████████| 548/548 [00:39<00:00, 13.97it/s]


(8756, 15) (8756, 15)
131333 131333 7 7


Validation: 100%|██████████| 46/46 [00:01<00:00, 43.06it/s]


(729, 15) (729, 15)
10038 10038 897 897


In [78]:
torch.save(model, "model80.pth")

In [None]:
model = torch.load("model.pth")
model.eval()

## Метрики

In [76]:
model.eval()
all_predictions = []
all_labels = []
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Evaluating"):
        inputs = batch['image'].to(device)
        labels = batch['label'].to(device)

        outputs = model(inputs)
        outputs = (outputs > 0.5).float()

        all_predictions.append(outputs.cpu().numpy())
        all_labels.append(labels.cpu().numpy())
    
all_predictions = np.vstack(all_predictions)
all_labels = np.vstack(all_labels)
    
f1 = f1_score(all_labels, all_predictions, average='weighted')  
print(f"F1-Score: {f1:.4f}")

Evaluating: 100%|██████████| 320/320 [13:07<00:00,  2.46s/it]

F1-Score: 0.8061





In [79]:
f1_score(all_labels, all_predictions, average='macro')  

np.float64(0.6028597595650425)

In [80]:
f1_score(all_labels, all_predictions, average='micro')  

np.float64(0.818353419485988)

In [77]:
import numpy as np
from sklearn.metrics import f1_score
import torch
from tqdm import tqdm

# Переводим модель в режим оценки
model.eval()

all_predictions = []
all_labels = []

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Evaluating"):
        inputs = batch['image'].to(device)
        labels = batch['label'].to(device)

        outputs = model(inputs)
        outputs = (outputs > 0.5).float()  # Если задача бинарной классификации, иначе можно использовать softmax

        all_predictions.append(outputs.cpu().numpy())
        all_labels.append(labels.cpu().numpy())

# Преобразуем списки в массивы
all_predictions = np.vstack(all_predictions)
all_labels = np.vstack(all_labels)

# Вычисляем F1-Score для каждого класса
f1_per_class = f1_score(all_labels, all_predictions, average=None)

# Выводим F1-Score для каждого класса
for i, f1 in enumerate(f1_per_class):
    print(f"Class {i} F1-Score: {f1:.4f}")

Evaluating: 100%|██████████| 320/320 [00:07<00:00, 42.48it/s]

Class 0 F1-Score: 0.7120
Class 1 F1-Score: 0.8665
Class 2 F1-Score: 0.4981
Class 3 F1-Score: 0.5759
Class 4 F1-Score: 0.9214
Class 5 F1-Score: 0.0551
Class 6 F1-Score: 0.8484
Class 7 F1-Score: 0.8639
Class 8 F1-Score: 0.4510
Class 9 F1-Score: 0.8189
Class 10 F1-Score: 0.8396
Class 11 F1-Score: 0.1928
Class 12 F1-Score: 0.6957
Class 13 F1-Score: 0.2500
Class 14 F1-Score: 0.4536



