<a href="https://colab.research.google.com/github/shockless/microcirculation-solution/blob/main/Pipeline%20Notebooks/Pretrain_LoD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Подготовка сессии

## Импорты

In [1]:
# Install gdown; the next cell uses it to fetch a file from Google Drive by id.
!pip install gdown

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
# Download the requirements file from Google Drive (it is saved as
# /content/requirements.txt) and quietly install everything it lists.
!gdown 12hNXcrHr0v48m9VRr-GLj5eq3pnYU7Im
!pip install -r /content/requirements.txt -q

Downloading...
From: https://drive.google.com/uc?id=12hNXcrHr0v48m9VRr-GLj5eq3pnYU7Im
To: /content/requirements.txt
  0% 0.00/225 [00:00<?, ?B/s]100% 225/225 [00:00<00:00, 413kB/s]


In [None]:
# Mount Google Drive at /content/drive; the trained model is saved under this
# mount point at the end of the notebook.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

import pandas as pd 
import numpy as np
import glob
import multiprocessing
from tqdm import tqdm
import cv2
from sklearn.model_selection import train_test_split
import segmentation_models_pytorch as smp

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torchvision import datasets, models, transforms
from torchvision.models import resnet18
from torchvision.utils import draw_segmentation_masks

from PIL import Image
import matplotlib.pyplot as plt
from IPython.display import clear_output
import os

import warnings
warnings.filterwarnings("ignore")
import json
from GPUtil import showUtilization as gpu_usage
import GPUtil

# Класс тренировщика модели

In [None]:
# `Callable` lives in collections.abc; the alias in `collections` was removed in
# Python 3.10, and the resulting ImportError would abort this whole cell and
# leave `defaultdict`, `Iterator`, `Dict`, ... undefined for the cells below.
from collections import defaultdict
from collections.abc import Callable
from typing import Tuple, Optional, List, Any, Iterator, Dict

import numpy as np
import torch
from torch import nn
from torch.utils.data.dataloader import DataLoader

from tqdm import tqdm

In [None]:
from typing import Optional
class UnetTrainer:
    """
    Runs the training and evaluation loops for a segmentation model.

    Every batch drawn from a loader must support ``batch['image']`` and
    ``batch['mask']``. Metrics are supplied as ``(name, fn)`` pairs; each ``fn``
    is called as ``fn(predictions, targets)`` on CPU tensors and its result must
    expose ``.item()``.
    """

    def __init__(self,
                 model: nn.Module,
                 optimizer: torch.optim.Optimizer,
                 criterion,
                 device: str,
                 metric_functions = None,
                 epoch_number: int = 0,
                 lr_scheduler = None):
        """
        Args:
            model: network to train; it is moved to ``device`` here.
            optimizer: optimizer already bound to the model parameters.
            criterion: callable ``criterion(prediction, target)`` returning a
                scalar loss tensor.
            device: device identifier passed to ``.to(...)``.
            metric_functions: sequence of ``(name, fn)`` pairs; ``None`` means
                that only the loss is reported.
            epoch_number: number of epochs already completed; ``fit`` resumes
                counting from it.
            lr_scheduler: optional scheduler whose ``step()`` is called once
                per epoch.
        """
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.lr_scheduler = lr_scheduler
        self.device = device

        self.model.to(self.device)

        # A fresh list per instance: a mutable default argument would be shared
        # by every trainer created without explicit metrics.
        self.metric_functions = metric_functions if metric_functions is not None else []

        self.epoch_number = epoch_number

    def _collect_metrics(self, losses, predictions, targets) -> Dict[str, float]:
        """Average the losses and evaluate every metric on the concatenated batches."""
        predictions = torch.cat(predictions, dim=0)
        targets = torch.cat(targets, dim=0)

        metrics = {'loss': np.mean(losses)}
        for metric_name, metric_fn in self.metric_functions:
            metrics[metric_name] = metric_fn(predictions, targets).item()

        return metrics

    @staticmethod
    def _average(metrics_sum, num_batches: int) -> Dict[str, float]:
        """Turn per-metric sums into means; an empty sum dict yields an empty result."""
        return {name: total / num_batches for name, total in metrics_sum.items()}

    @torch.no_grad()
    def evaluate_batch(self, val_iterator: Iterator, eval_on_n_batches: int) -> Optional[Dict[str, float]]:
        """
        Evaluate the model on up to ``eval_on_n_batches`` batches from the iterator.

        Returns:
            A dict with the mean loss and every configured metric, or ``None``
            when the iterator was already exhausted.
        """
        self.model.eval()

        predictions = []
        targets = []
        losses = []

        for real_batch_number in range(eval_on_n_batches):
            try:
                batch = next(val_iterator)
            except StopIteration:
                if real_batch_number == 0:
                    return None
                break

            xs = batch['image'].to(self.device)
            ys_true = batch['mask'].to(self.device)

            ys_pred = self.model(xs)
            loss = self.criterion(ys_pred, ys_true)

            losses.append(loss.item())
            predictions.append(ys_pred.cpu())
            targets.append(ys_true.cpu())

        return self._collect_metrics(losses, predictions, targets)

    @torch.no_grad()
    def evaluate(self, val_loader, eval_on_n_batches: int = 1) -> Dict[str, float]:
        """
        Compute metrics over the whole loader.

        The loader is consumed in groups of ``eval_on_n_batches`` batches and
        the per-group metrics are averaged with equal weight.
        """
        metrics_sum = defaultdict(float)
        num_batches = 0

        val_iterator = iter(val_loader)

        while True:
            batch_metrics = self.evaluate_batch(val_iterator, eval_on_n_batches)

            if batch_metrics is None:
                break

            for metric_name, metric_value in batch_metrics.items():
                metrics_sum[metric_name] += metric_value

            num_batches += 1

        return self._average(metrics_sum, num_batches)

    def fit_batch(self, train_iterator: Iterator, update_every_n_batches: int) -> Optional[Dict[str, float]]:
        """
        Perform one optimizer step, accumulating gradients over up to
        ``update_every_n_batches`` batches drawn from the iterator.

        Returns:
            A dict with the mean loss and every configured metric for the
            consumed batches, or ``None`` (without stepping the optimizer)
            when the iterator was already exhausted.
        """
        self.model.train()

        predictions = []
        targets = []
        losses = []

        # Gradients are cleared once per optimizer step. Clearing them inside
        # the loop would discard every micro-batch but the last one and defeat
        # gradient accumulation.
        self.optimizer.zero_grad()

        for real_batch_number in range(update_every_n_batches):
            try:
                batch = next(train_iterator)
            except StopIteration:
                if real_batch_number == 0:
                    return None
                break

            xs = batch['image'].to(self.device)
            ys_true = batch['mask'].to(self.device)

            ys_pred = self.model(xs)
            loss = self.criterion(ys_pred, ys_true)

            # Scale so that the accumulated gradient is the mean over micro-batches.
            (loss / update_every_n_batches).backward()

            losses.append(loss.item())

            # Detach before storing: the metrics need values only, and keeping
            # the autograd references alive would waste memory.
            predictions.append(ys_pred.detach().cpu())
            targets.append(ys_true.cpu())

        self.optimizer.step()

        return self._collect_metrics(losses, predictions, targets)

    def fit_epoch(self, train_loader, update_every_n_batches: int = 1) -> Dict[str, float]:
        """
        Train for one pass over the loader and return the metrics averaged
        over optimizer steps.
        """
        metrics_sum = defaultdict(float)
        num_batches = 0

        train_iterator = iter(train_loader)

        # The bar ticks once per optimizer step, and each step consumes up to
        # update_every_n_batches loader batches, so the total is the ceiling
        # of the ratio.
        n_steps = -(-len(train_loader) // update_every_n_batches)

        with tqdm(total=n_steps) as pbar:
            while True:
                batch_metrics = self.fit_batch(train_iterator, update_every_n_batches)

                if batch_metrics is None:
                    break

                for metric_name, metric_value in batch_metrics.items():
                    metrics_sum[metric_name] += metric_value

                # Show the latest step metrics on the bar instead of printing
                # one line per step.
                pbar.set_postfix(batch_metrics)
                pbar.update(1)
                num_batches += 1

        return self._average(metrics_sum, num_batches)

    def fit(self, train_loader, num_epochs: int,
            val_loader = None, update_every_n_batches: int = 1,
            ) -> Dict[str, np.ndarray]:
        """
        Train the model until ``num_epochs`` epochs have been completed in
        total and record the metrics of every epoch.

        Epochs already counted in ``self.epoch_number`` are skipped, so calling
        ``fit`` again with a larger ``num_epochs`` continues training.

        Returns:
            A dict mapping ``'<metric>_train'`` (and ``'<metric>_test'`` when
            ``val_loader`` is given) to an array with one value per epoch run
            by this call.
        """
        summary = defaultdict(list)

        def save_metrics(metrics: Dict[str, float], postfix: str = '') -> None:
            # Append this epoch's values to the per-metric history.
            for metric, metric_value in metrics.items():
                summary[f'{metric}{postfix}'].append(metric_value)

        for _ in tqdm(range(num_epochs - self.epoch_number), initial=self.epoch_number, total=num_epochs):
            self.epoch_number += 1

            train_metrics = self.fit_epoch(train_loader, update_every_n_batches)
            save_metrics(train_metrics, postfix='_train')

            if val_loader is not None:
                # evaluate() already disables gradient tracking itself.
                test_metrics = self.evaluate(val_loader)
                save_metrics(test_metrics, postfix='_test')

            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

        return {metric: np.array(values) for metric, values in summary.items()}


# Метрики

In [None]:
# F1 score (soft Dice)
class SoftDice:
    """
    Soft Dice coefficient computed over all elements of the two tensors:
    ``2 * sum(p * t) / (sum(p + t) + epsilon)``.
    """

    def __init__(self, epsilon=1e-8):
        # Added to the denominator so that two all-zero inputs give 0, not 0/0.
        self.epsilon = epsilon

    def __call__(self, predictions: torch.Tensor,
                 targets: torch.Tensor) -> torch.Tensor:
        """
        Args:
            predictions: tensor of predicted scores (the arithmetic below
                requires tensors, not lists).
            targets: tensor of reference values that multiplies and adds
                elementwise with ``predictions``.

        Returns:
            A zero-dimensional tensor holding the coefficient.
        """
        overlap = torch.sum(2 * predictions * targets)
        total = torch.sum(predictions + targets)
        return overlap / (total + self.epsilon)

# Recall metric
class Recall:
    """
    Soft recall computed over all elements of the two tensors:
    ``sum(p * t) / (sum(t) + epsilon)``.
    """

    def __init__(self, epsilon=1e-8):
        # The previous default (1e-81) is below the smallest float32 value and
        # rounds to zero, so an all-zero target produced 0/0 = NaN. 1e-8
        # matches the other metrics in this notebook.
        self.epsilon = epsilon

    def __call__(self, predictions: torch.Tensor,
                 targets: torch.Tensor) -> torch.Tensor:
        """
        Args:
            predictions: tensor of predicted scores.
            targets: tensor of reference values that multiplies elementwise
                with ``predictions``.

        Returns:
            A zero-dimensional tensor holding the recall.
        """
        true_positive = torch.sum(predictions * targets)
        actual_positive = torch.sum(targets)

        return true_positive / (actual_positive + self.epsilon)

# Precision metric
class Accuracy:
    """
    Soft precision computed over all elements of the two tensors:
    ``sum(p * t) / (sum(p) + epsilon)``.

    Despite the class name, the formula is precision (overlap divided by the
    total predicted mass), not classification accuracy.
    """

    def __init__(self, epsilon=1e-8):
        # Added to the denominator so that an all-zero prediction gives 0, not 0/0.
        self.epsilon = epsilon

    def __call__(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Args:
            predictions: tensor of predicted scores (the arithmetic below
                requires tensors, not lists).
            targets: tensor of reference values that multiplies elementwise
                with ``predictions``.

        Returns:
            A zero-dimensional tensor holding the precision.
        """
        true_positive = torch.sum(predictions * targets)
        predicted_positive = torch.sum(predictions)

        return true_positive / (predicted_positive + self.epsilon)

def make_metrics():
    """
    Build the list of ``(name, fn)`` metric pairs used by the trainer.

    Every ``fn(pred, target)`` drops channel 0 of both tensors, exponentiates
    the remaining prediction channels (the model in this notebook is built with
    a log-softmax activation) and applies the wrapped metric to the result.
    """

    def on_exp_foreground(metric):
        # Adapt a metric to exponentiated predictions with channel 0 removed.
        def wrapped(pred, target):
            return metric(torch.exp(pred[:, 1:]), target[:, 1:])

        return wrapped

    named_metrics = [
        ('dice', SoftDice()),
        ('accuracy', Accuracy()),
        ('recall', Recall()),
    ]

    return [(name, on_exp_foreground(metric)) for name, metric in named_metrics]

# Датасет претрейна


In [None]:
class PngEyeDataset(Dataset):
    """
    Dataset that loads images and their vessel masks from several collections
    stored under one root folder.

    Expected layout for every name in ``dataset_names``::

        <data_folder>/<dataset_name>/
        ....Test/
        ........Original/
        ............Images/
        ............Labels/
        ....Train/
        ........Original/
        ............Images/
        ............Labels/

    The Train and Test parts of each collection are merged. Images and masks
    are paired purely by their position in the sorted path lists, so the
    ``Images`` and ``Labels`` folders must sort in the same order.

    Modes ``"train"`` and ``"val"`` yield ``{'image', 'mask'}`` samples; any
    other mode yields ``{'image'}`` only and never touches the ``Labels``
    folders.
    """

    # Modes in which masks are loaded and returned.
    _MODES_WITH_MASKS = ("train", "val")

    def __init__(self,
                 data_folder: str,
                 mode: str = "train",
                 dataset_names=('HRF', 'ChaseDB', 'ORVS', 'DR-Hagis', 'IOSTAR', 'ARIA', 'DRIVE'),
                 transform=None):
        """
        Args:
            data_folder: root folder that contains one sub-folder per dataset.
            mode: case-insensitive mode name, see the class docstring.
            dataset_names: names of the dataset sub-folders to include.
            transform: optional callable invoked as ``transform(**sample)``.
        """
        self.class_ids = {"vessel": 1}

        self.mode = mode.lower()
        self.data_folder = data_folder
        self.transform = transform

        # Copy so that neither the default nor the caller's sequence is shared.
        self.dataset_names = list(dataset_names)
        self._image_files = []
        self._mask_files = []

        for dataset_name in self.dataset_names:
            dataset_path = f"{data_folder}/{dataset_name}"
            parts = (f"{dataset_path}/Train/Original", f"{dataset_path}/Test/Original")

            self._image_files.extend(self._list_sorted(parts, "Images"))

            # Masks are needed for every mode that __getitem__ serves them in;
            # loading them only for "train" made "val" fail with an IndexError.
            if self.mode in self._MODES_WITH_MASKS:
                self._mask_files.extend(self._list_sorted(parts, "Labels"))

    @staticmethod
    def _list_sorted(parts, subfolder: str) -> list:
        """Return the sorted paths of all files in ``<part>/<subfolder>/`` over all parts."""
        paths = []
        for part in parts:
            folder = f"{part}/{subfolder}/"
            paths.extend(folder + name for name in os.listdir(folder))
        return sorted(paths)

    @staticmethod
    def read_image(path: str) -> np.ndarray:
        """
        Read an image as an RGB float32 array scaled to [0, 1].

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        image = cv2.imread(str(path), cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Cannot read image: {path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = np.array(image / 255, dtype=np.float32)
        return image

    @staticmethod
    def read_mask(path: str) -> np.ndarray:
        """
        Read a mask and return a two-channel array of zeros and ones.

        The grayscale file is binarised at 128; channel 0 is the inverted mask
        and channel 1 is the mask itself.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        gray = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if gray is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Cannot read mask: {path}")
        binary = cv2.threshold(gray, 128, 255, cv2.THRESH_BINARY)[1]
        mask = np.expand_dims(binary, 2)
        mask = np.concatenate(((255 - mask), mask), axis=2)/255

        return mask

    def __getitem__(self, idx: int) -> dict:
        """Return the sample at ``idx``, passed through ``transform`` when one is set."""
        sample = {'image': self.read_image(self._image_files[idx])}

        if self.mode in self._MODES_WITH_MASKS:
            # The mask is the file at the same position in the sorted mask list.
            sample['mask'] = self.read_mask(self._mask_files[idx])

        if self.transform is not None:
            sample = self.transform(**sample)

        return sample

    def __len__(self):
        """Number of images found."""
        return len(self._image_files)

    # Helper for checking the state of the dataset
    def make_report(self):
        """Return a list of human-readable messages describing what was loaded."""
        reports = []
        if (not self.data_folder):
            reports.append("data folder path is not given")
        if (len(self._image_files) == 0):
            reports.append("No images were loaded")
        else:
            reports.append(f"Found {len(self._image_files)} images")
        if (len(self._mask_files) == 0):
            reports.append("No masks were loaded")
        else:
            reports.append(f"Found {len(self._mask_files)} masks")
        return reports


# Модель

In [None]:
# Clone the repository whose Vessels-Datasets folder serves as the pretraining
# data root (see pretrain_data_folder below).
!git clone https://github.com/AbdullahSarhan/ICPRVessels/

In [None]:
def make_criterion():
    """
    Build the training loss: one minus the soft Dice coefficient.

    The returned callable drops channel 0 of both tensors and exponentiates the
    remaining prediction channels before computing the coefficient.
    """
    dice_score = SoftDice()

    def exp_dice(pred, target):
        foreground_pred = torch.exp(pred[:, 1:])
        foreground_true = target[:, 1:]
        return 1 - dice_score(foreground_pred, foreground_true)

    return exp_dice


criterion = make_criterion()

In [None]:
# Number of CPU cores; used as the DataLoader worker count.
cores = multiprocessing.cpu_count()
# Samples per batch.
batch_size=1
# Encoder name passed to the segmentation_models_pytorch model constructor.
encoder_pretrain='timm-regnety_064'
# NOTE(review): these look like separate encoder/decoder learning rates, but no
# visible cell reads them; the optimizer is created with a single 1e-4. Confirm.
lr_e_pretrain =  1e-5
lr_d_pretrain = 1e-3
# Root folder of the cloned datasets, one sub-folder per dataset.
pretrain_data_folder='/content/ICPRVessels/Vessels-Datasets'

In [None]:
# The DRIVE test split keeps its masks in a lowercase `labels` folder; rename it
# to `Labels`, the folder name PngEyeDataset lists for every dataset.
# NOTE: not idempotent; on a second run `mv` reports that the source is missing.
!mv /content/ICPRVessels/Vessels-Datasets/DRIVE/Test/Original/labels /content/ICPRVessels/Vessels-Datasets/DRIVE/Test/Original/Labels

In [None]:
# Release cached, currently unused GPU memory held by PyTorch before training.
torch.cuda.empty_cache()

In [None]:
# Install the MADGRAD optimizer package used to train the model below.
!pip install madgrad

In [None]:
import madgrad

In [None]:
# Side length of the square input produced by the preprocessing pipeline.
size = 1024
# Preprocessing applied to every sample, in order: resize to one fixed shape,
# rescale so that the longest side equals `size`, pad up to size x size, then
# convert to tensors (transpose_mask=True also moves the mask channels first).
train_list = [A.Resize(1624,1232),
              A.LongestMaxSize(size, interpolation=cv2.INTER_CUBIC),
              A.PadIfNeeded(size, size),
              ToTensorV2(transpose_mask=True),
              ]

# NOTE(review): this rebinds `transforms`, which shadows the `transforms` module
# imported from torchvision at the top of the notebook.
transforms = {'train': train_list}

# Every listed dataset, Train and Test parts merged, is used for pretraining.
dataset_pretrain = PngEyeDataset(
    pretrain_data_folder, 
    mode="train",
    transform=A.Compose(transforms['train'])
)
# Print how many images and masks were found as a quick sanity check.
for msg in dataset_pretrain.make_report():
    print(msg)

pretrain_loader = torch.utils.data.DataLoader(dataset_pretrain, batch_size,
                                num_workers=cores,
                                shuffle=True, drop_last=True)



torch.cuda.empty_cache()
# Two output classes with a log-softmax activation: the loss and the metrics
# exponentiate the predictions to get probabilities back. The encoder starts
# from random weights (encoder_weights=None).
model_pretrain = smp.UnetPlusPlus(encoder_pretrain, activation='logsoftmax', classes=2, encoder_weights=None).cuda()
# A single learning rate of 1e-4 is used for all parameters.
optimizer = madgrad.MADGRAD(model_pretrain.parameters(), 1e-4)
trainer_pretrain = UnetTrainer(model_pretrain, optimizer, criterion, 'cuda', metric_functions=make_metrics())
# Train for 10 epochs; no validation loader is passed, so only training metrics are recorded.
summary = trainer_pretrain.fit(pretrain_loader, 10)
# NOTE(review): this pickles the whole module object rather than its state_dict,
# so loading it later needs the same model classes to be importable. The target
# folder on Drive must already exist.
torch.save(model_pretrain, "/content/drive/MyDrive/Competitions/LeadersOfDigital/models/model_pretrain.pt")

In [None]:
# Display the per-epoch training metrics returned by fit().
summary