<a href="https://colab.research.google.com/github/tzopiz/TMJ/blob/master/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# colab не может найти зависимость
# !pip install -r requirements.txt

In [None]:
import os
import random
from os.path import join as pjoin
from shutil import rmtree

import albumentations as A
import numpy as np
import torch

from accelerate import Accelerator
from albumentations.pytorch.transforms import ToTensorV2

from matplotlib import pyplot as plt
from torch.nn import functional as F
from torch.utils.data import DataLoader, Subset, random_split
from torch.utils.tensorboard import SummaryWriter

from train import (
    CheckpointSaver,
    load_checkpoint,
    train
)
from LossFunc.DiceLoss import DiceLoss
from MetricFunc.MeanIoU import MeanIoU
from LossFunc.FocalLoss import FocalLoss
from MetricFunc.CustomMeanIoU import CustomMeanIoU

from helpy import compute_class_weights, visualize_prediction

from TMJDataset import TMJDataset
from unet import UNet

In [None]:
def seed_everything(seed: int = 314159, torch_deterministic: bool = False) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Seed applied to ``random``, ``numpy.random`` and PyTorch
            (CPU and every visible CUDA device).
        torch_deterministic: If True, switch PyTorch to deterministic
            algorithms (deterministic, non-benchmarking cuDNN included);
            operations without a deterministic implementation will then
            raise. If False, deterministic mode is turned off again.
    """
    # Only affects subprocesses started after this point: the hash seed of the
    # running interpreter is fixed at start-up.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every GPU, not just the current one (silently ignored without CUDA).
    torch.cuda.manual_seed_all(seed)

    if torch_deterministic:
        # Deterministic cuBLAS requires a fixed workspace size; without it,
        # use_deterministic_algorithms(True) makes CUDA matmuls raise.
        # Should be in place before the first CUDA matmul of the process.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        # Benchmark mode selects kernels by timing, which is non-deterministic.
        torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = torch_deterministic
    torch.use_deterministic_algorithms(torch_deterministic)


seed_everything(42, torch_deterministic=False)

In [None]:
# Side length (px) every image/mask pair is brought to.
IMAGE_SIZE = 512

# Geometric preprocessing applied to every sample.
# NOTE: after Resize to exactly IMAGE_SIZE x IMAGE_SIZE, PadIfNeeded and
# CropNonEmptyMaskIfExists with the same size have nothing left to pad or crop;
# they only take effect if the Resize step is removed or changed.
basic_transforms = A.Compose([
    A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE, p=1.0),
    A.PadIfNeeded(min_height=IMAGE_SIZE, min_width=IMAGE_SIZE, p=1.0),
    A.CropNonEmptyMaskIfExists(height=IMAGE_SIZE, width=IMAGE_SIZE),
    A.HorizontalFlip(p=0.5),  # random left-right flip
])

# Random augmentations; each group fires only part of the time so that strong
# distortions stay rare.
aggressive_transforms = A.Compose([
    # Either blur or CLAHE, with probability 0.3 for the group.
    A.OneOf([
        A.AdvancedBlur(p=0.5),
        A.CLAHE(p=0.5),
    ], p=0.3),

    # One brightness / contrast / gamma / colour change, with probability 0.3.
    A.OneOf([
        A.RandomBrightnessContrast(p=0.5),
        A.RandomGamma(p=0.5),
        A.ColorJitter(p=0.5),
    ], p=0.3),

    # Small rotations only, to preserve the shape of small structures.
    A.Rotate(limit=15, p=0.3),
    A.ElasticTransform(alpha=1, sigma=50, p=0.5),  # elastic deformation
])

# Full training pipeline: geometry first, then augmentations, then tensor
# conversion. This pipeline is random, so any dataset it is attached to yields
# a differently augmented sample on each access.
transforms = A.Compose([
    basic_transforms,
    aggressive_transforms,
    ToTensorV2(),
])

In [None]:
# One dataset over every image/mask pair on disk, with the training
# `transforms` pipeline attached.
full_dataset = TMJDataset(
    transforms=transforms,
    mask_dir="full_dataset/masks",
    image_dir="full_dataset/images",
)

In [None]:
# Preview sample #0 through the dataset's own visualize helper.
preview_index = 0
full_dataset.visualize(preview_index)

In [None]:
# Split sizes: 70% train, 20% validation, the remainder (~10%) test.
train_size = int(0.7 * len(full_dataset))
val_size = int(0.2 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

# A dedicated, seeded generator keeps the split identical when this cell is
# re-run or other code consumes the global torch RNG; otherwise a re-run could
# move training images into the test split.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(
    full_dataset, [train_size, val_size, test_size], generator=split_generator
)

# random_split only wraps `full_dataset`, so every split would otherwise share
# its random training augmentations. Validation and test must be deterministic:
# re-wrap their indices around a second copy of the data that only resizes.
# NOTE(review): assumes TMJDataset enumerates the same files in the same order
# on every construction from the same directories — confirm.
eval_dataset = TMJDataset(
    image_dir="full_dataset/images",
    mask_dir="full_dataset/masks",
    transforms=A.Compose([
        A.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE, p=1.0),
        ToTensorV2(),
    ]),
)
val_dataset = Subset(eval_dataset, val_dataset.indices)
test_dataset = Subset(eval_dataset, test_dataset.indices)

print('Количество изображений в полном датасете:',len(full_dataset))
print('Количество изображений в тренировочном датасете:',len(train_dataset))
print('Количество изображений в валидационном датасете:',len(val_dataset))
print('Количество изображений в тестовом датасете:',len(test_dataset))

## Обучение модели

In [None]:
# --- Task ---------------------------------------------------------------------
CLASSES_NUM = 3

# --- Optimisation -------------------------------------------------------------
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-6
BETAS = (0.8, 0.999)
EPOCH_NUM = 50

# --- Data loading -------------------------------------------------------------
BATCH_SIZE = 32
NUM_WORKERS = 4

# --- Output locations ---------------------------------------------------------
CHECKPOINTS_DIR = "checkpoints"
TENSORBOARD_DIR = "tensorboard"
RM_CHECKPOINTS_DIR = False

# Per-class weights from the project helper, computed over `full_dataset`.
# NOTE(review): this covers val/test samples too (and presumably reads them
# through the random training transforms) — consider computing on train only.
CLASS_WEIGHTS = compute_class_weights(dataset=full_dataset, num_classes=CLASSES_NUM)

# Device for the standalone inference cells below.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [None]:
# Show the computed per-class weights (the notebook displays a cell's last expression).
CLASS_WEIGHTS

In [None]:
# Training batches: shuffled; the incomplete last batch is dropped so every
# optimisation step sees a full batch.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=True,
    drop_last=True,
)
# Validation must see every sample exactly once: no shuffling, and keep the
# last partial batch. With drop_last=True, samples were silently skipped, and a
# validation split smaller than BATCH_SIZE produced no batches at all.
val_dataloader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=False,
    drop_last=False,
)

# fp16 mixed precision needs a GPU (Accelerate raises a ValueError for it on
# CPU), so fall back to full precision on CPU-only machines.
accelerator = Accelerator(
    cpu=False,
    mixed_precision="fp16" if torch.cuda.is_available() else "no",
)
# 3 input channels, one output channel per class.
model = UNet(in_channels=3, out_channels=CLASSES_NUM)

loss_fn = FocalLoss()
# Mean IoU metric; presumably weights per-class IoU by CLASS_WEIGHTS.
metric_fn = CustomMeanIoU(num_classes=CLASSES_NUM, class_weights=CLASS_WEIGHTS)

optimizer = torch.optim.AdamW(
    model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY, betas=BETAS
)
# Multiply the LR by 0.85 every 10 scheduler steps (epochs or batches,
# depending on how `train` calls `step()`).
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optimizer, step_size=10, gamma=0.85
)

os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
# Checkpoints are ranked by mIoU, higher is better (should_minimize=False);
# max_history=5 presumably caps how many are kept on disk.
checkpointer = CheckpointSaver(
    accelerator=accelerator,
    model=model,
    metric_name="mIoU",
    save_dir=CHECKPOINTS_DIR,
    rm_save_dir=RM_CHECKPOINTS_DIR,
    max_history=5,
    should_minimize=False,
)

In [None]:
# Make sure the log directory exists, then attach a TensorBoard writer to it.
os.makedirs(TENSORBOARD_DIR, exist_ok=True)
tensorboard_logger = SummaryWriter(log_dir=TENSORBOARD_DIR)

In [None]:
# акселерируем
model, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, val_dataloader, lr_scheduler
)

In [None]:
# Launch the training loop with every component configured above.
# save_on_val=True presumably checkpoints after each validation pass;
# progress output every 15 batches.
train_config = dict(
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    accelerator=accelerator,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    loss_function=loss_fn,
    metric_function=metric_fn,
    epoch_num=EPOCH_NUM,
    checkpointer=checkpointer,
    tb_logger=tensorboard_logger,
    save_on_val=True,
    show_every_x_batch=15,
)
train(**train_config)

In [None]:
# Rebuild a fresh UNet, load the best checkpoint into it, move it to DEVICE and
# switch to eval mode (the assignment keeps the cell output quiet).
model = UNet(in_channels=3, out_channels=CLASSES_NUM)
best_checkpoint_path = pjoin(CHECKPOINTS_DIR, "model_checkpoint_best.pt")
model = load_checkpoint(model=model, load_path=best_checkpoint_path).to(DEVICE).eval()

In [None]:
# Visualise the model's prediction for sample #1 of the test split.
# NOTE(review): `threshold` semantics live in helpy.visualize_prediction —
# confirm 0.99 is intended for a 3-class model.
prediction_kwargs = dict(
    model=model,
    dataset=test_dataset,
    index=1,
    device=DEVICE,
    threshold=0.99,
)
visualize_prediction(**prediction_kwargs)