# LunksNN

In [None]:
import numpy as np
import albumentations
import random
import torch
import os

from pycocotools.coco import COCO
from albumentations.pytorch.transforms import ToTensorV2
from accelerate import Accelerator
from torch.utils.data import DataLoader
from PIL import Image
from torchvision import datasets
from torchvision.datasets import VOCSegmentation

from main import create_masks_for_all_images
from dataset import LunksDataset
from Unet import UNet, count_model_params
from train import (
    CheckpointSaver,
    IoUMetric,
    MulticlassCrossEntropyLoss,
    MulticlassDiceLoss,
    load_checkpoint,
    train
)

In [None]:
def seed_everything(seed: int = 314159, torch_deterministic: bool = False) -> None:
    """Seed the Python, NumPy and PyTorch RNGs so runs are reproducible.

    Args:
        seed: Seed applied to ``random``, ``numpy.random`` and PyTorch
            (CPU and every visible CUDA device).
        torch_deterministic: If True, make PyTorch and cuDNN use deterministic
            kernels only; ops without a deterministic implementation will then
            raise. If False, that restriction is switched off again.
    """
    # Only affects Python processes started after this call; the hash seed of
    # the already-running interpreter was fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not just the current one; silently ignored without CUDA.
    torch.cuda.manual_seed_all(seed)

    if torch_deterministic:
        # cuBLAS needs a fixed workspace config for deterministic results;
        # without it, use_deterministic_algorithms(True) makes CUDA matmuls
        # raise. Respect a value the user already exported.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        # Autotuning may pick different (non-deterministic) conv algorithms.
        torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = torch_deterministic
    torch.use_deterministic_algorithms(torch_deterministic)


seed_everything(42, torch_deterministic=False)

In [None]:
# Load the COCO-format annotation export, then generate mask files under
# Dataset/Masks with the project helper (presumably one mask per image —
# see create_masks_for_all_images in main.py).
annFile = "instances_default.json"
coco = COCO(annotation_file=annFile)
create_masks_for_all_images(coco, "Dataset/Masks")

In [None]:
IMAGE_SIZE = 512

# Random augmentations, each applied independently with probability 0.5,
# in this order, after the unconditional resize.
_RANDOM_AUGMENTATIONS = (
    albumentations.AdvancedBlur,
    albumentations.GaussNoise,
    albumentations.HorizontalFlip,
    albumentations.CLAHE,
    albumentations.RandomBrightnessContrast,
    albumentations.RandomGamma,
    albumentations.ColorJitter,
)

# NOTE(review): there is no Normalize step and ToTensorV2 does not rescale
# pixel values; presumably LunksDataset or the model handles dtype/scaling —
# confirm.
transforms = albumentations.Compose(
    [
        albumentations.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
        *(augmentation(p=0.5) for augmentation in _RANDOM_AUGMENTATIONS),
        ToTensorV2(),
    ]
)

In [None]:
# Training dataset read from Dataset/ through the augmentation pipeline above.
train_dataset = LunksDataset(transforms=transforms, root_dir="Dataset")

In [None]:
# Number of samples available for training.
len(train_dataset)

In [None]:
# Pull the first sample to sanity-check what the pipeline produces.
(image, mask) = train_dataset[0]

In [None]:
# ToTensorV2 yields channel-first tensors; expected (C, 512, 512) if the
# dataset returns HWC images — confirm.
image.shape

In [None]:
# Masks are not transposed by ToTensorV2 by default; expected (512, 512) if the
# dataset returns a 2-D label mask — confirm.
mask.shape

In [None]:
# NOTE(review): 21 output channels matches the Pascal VOC class count; confirm
# it equals the number of classes present in the generated masks.
model = UNet(out_channels=21, in_channels=3)

In [None]:
# Report the model size via the helper from Unet.py.
count_model_params(model)

In [None]:
# Use the GPU whenever one is available. fp16 mixed precision is a GPU
# feature (depending on the accelerate version, requesting it on CPU either
# raises or is ignored), so fall back to full precision on CPU-only machines.
_use_gpu = torch.cuda.is_available()
accelerator = Accelerator(
    cpu=not _use_gpu,
    mixed_precision="fp16" if _use_gpu else "no",
)

In [None]:
# Optimisation hyper-parameters.
LEARNING_RATE = 1e-4
BATCH_SIZE = 4
EPOCH_NUM = 20

# Worker processes per DataLoader.
NUM_WORKERS = 2

# Output locations; RM_CHECKPOINTS_DIR is passed to CheckpointSaver as
# rm_save_dir.
CHECKPOINTS_DIR = "checkpoints"
TENSORBOARD_DIR = "tensorboard"
RM_CHECKPOINTS_DIR = False

# NOTE(review): DEVICE is not referenced by the training cells; device
# placement goes through `accelerator` instead.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [None]:
# Hold out a fixed fraction of the images for validation. Validating on the
# training set itself only measures training fit, which also makes the
# checkpoint selected with save_on_val meaningless.
from torch.utils.data import Subset

VAL_FRACTION = 0.2
SPLIT_SEED = 42

num_samples = len(train_dataset)
if num_samples < 2:
    raise ValueError(
        f"Need at least 2 samples to hold out a validation split, got {num_samples}."
    )
num_val = min(max(1, int(num_samples * VAL_FRACTION)), num_samples - 1)
# Dedicated generator so the split is stable regardless of other RNG use.
split_order = torch.randperm(
    num_samples, generator=torch.Generator().manual_seed(SPLIT_SEED)
).tolist()

# Validation images get only the deterministic part of the training pipeline
# (resize + tensor conversion), not the random augmentations.
val_transforms = albumentations.Compose(
    [
        albumentations.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
        ToTensorV2(),
    ]
)
# NOTE(review): the split indices are shared by two LunksDataset instances,
# which assumes both enumerate Dataset/ in the same order — confirm in
# dataset.py.
val_source_dataset = LunksDataset(root_dir="Dataset", transforms=val_transforms)

train_dataloader = DataLoader(
    Subset(train_dataset, split_order[num_val:]),
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=True,
)
# Evaluation does not need shuffling; a fixed order keeps runs comparable.
val_dataloader = DataLoader(
    Subset(val_source_dataset, split_order[:num_val]),
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=False,
)

# Re-create the model so training starts from freshly initialised weights.
# NOTE(review): out_channels=21 is the Pascal VOC class count; confirm it
# matches the classes written by create_masks_for_all_images.
model = UNet(in_channels=3, out_channels=21)

# ignore_index=0 presumably excludes label 0 (background?) from the loss —
# confirm in train.py.
loss_fn = MulticlassCrossEntropyLoss(ignore_index=0)  # MulticlassDiceLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# Multiply the learning rate by 0.8 every 5 lr_scheduler.step() calls.
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optimizer, step_size=5, gamma=0.8
)
# The validation metric is the loss itself, hence should_minimize=True below.
metric_fn = loss_fn

os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
checkpointer = CheckpointSaver(
    accelerator=accelerator,
    model=model,
    # NOTE(review): labelled "DICE" but metric_fn is the cross-entropy loss;
    # left unchanged in case checkpoint file names depend on it.
    metric_name="DICE",
    save_dir=CHECKPOINTS_DIR,
    rm_save_dir=RM_CHECKPOINTS_DIR,
    max_history=5,
    should_minimize=True,
)

In [None]:
# `import torch` does not load the torch.utils.tensorboard subpackage, so
# attribute access on torch.utils fails unless it is imported explicitly.
from torch.utils.tensorboard import SummaryWriter

os.makedirs(TENSORBOARD_DIR, exist_ok=True)
tensorboard_logger = SummaryWriter(log_dir=TENSORBOARD_DIR)

In [None]:
# Hand every training object to accelerate, which returns wrapped versions
# (device placement, mixed precision) in the same order they were passed.
(
    model,
    optimizer,
    train_dataloader,
    val_dataloader,
    lr_scheduler,
) = accelerator.prepare(
    model, optimizer, train_dataloader, val_dataloader, lr_scheduler
)

In [ ]:
# Launch the training loop (train.train) with the objects prepared above.
train(
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    loss_function=loss_fn,
    metric_function=metric_fn,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    accelerator=accelerator,
    epoch_num=EPOCH_NUM,
    checkpointer=checkpointer,
    tb_logger=tensorboard_logger,
    save_on_val=True,
)