# LunksNN

In [1]:
import numpy as np
import albumentations
import random
import torch
import os

from pycocotools.coco import COCO
from albumentations.pytorch.transforms import ToTensorV2
from accelerate import Accelerator
from torch.utils.data import DataLoader
from torch.nn import functional as F
from matplotlib import pyplot as plt
from os.path import join as pjoin
from PIL import Image
from torchvision import datasets
from torchvision.datasets import VOCSegmentation

from main import create_masks_for_all_images
from dataset import LunksDataset
from Unet import UNet, count_model_params
from train import (
    CheckpointSaver,
    IoUMetric,
    MulticlassCrossEntropyLoss,
    MulticlassDiceLoss,
    load_checkpoint,
    train
)

In [2]:
def seed_everything(seed: int = 314159, torch_deterministic: bool = False) -> None:
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.use_deterministic_algorithms(torch_deterministic)


seed_everything(42, torch_deterministic=False)

In [3]:
annFile = 'instances_default.json'
coco = COCO(annFile)

create_masks_for_all_images(coco, "Dataset/Masks")

loading annotations into memory...
Done (t=0.00s)
creating index...
index created!


In [4]:
IMAGE_SIZE = 256
transforms = albumentations.Compose(
    [
        albumentations.Resize(height=IMAGE_SIZE, width=IMAGE_SIZE),
        albumentations.HorizontalFlip(p=0.5),
        ToTensorV2(),
    ]
)

In [5]:
train_dataset = LunksDataset(root_dir="Dataset", transforms=transforms)

In [6]:
len(train_dataset)

27

In [7]:
image, mask = train_dataset[0]

In [8]:
image.shape

torch.Size([3, 512, 512])

In [9]:
mask.shape

torch.Size([1, 512, 512])

In [10]:
model = UNet(in_channels=3, out_channels=21)

In [11]:
count_model_params(model)

17264277

In [12]:
accelerator = Accelerator(cpu=False, mixed_precision="fp16")

In [13]:
LEARNING_RATE = 1e-4
BATCH_SIZE = 4
NUM_WORKERS = 2
EPOCH_NUM = 20
CHECKPOINTS_DIR = "checkpoints"
TENSORBOARD_DIR = "tensorboard"
RM_CHECKPOINTS_DIR = False

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [14]:
train_dataloader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True
)
val_dataloader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True
)

model = UNet(in_channels=3, out_channels=1)

loss_fn = MulticlassDiceLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optimizer, step_size=5, gamma=0.8
)
metric_fn = loss_fn

os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
checkpointer = CheckpointSaver(
    accelerator=accelerator,
    model=model,
    metric_name="DICE",
    save_dir=CHECKPOINTS_DIR,
    rm_save_dir=RM_CHECKPOINTS_DIR,
    max_history=5,
    should_minimize=True,
)

In [15]:
os.makedirs(TENSORBOARD_DIR, exist_ok=True)
tensorboard_logger = torch.utils.tensorboard.SummaryWriter(log_dir=TENSORBOARD_DIR)

In [16]:
model, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, val_dataloader, lr_scheduler
)

In [17]:
train(
    model=model,
    optimizer=optimizer,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    loss_function=loss_fn,
    metric_function=metric_fn,
    lr_scheduler=lr_scheduler,
    accelerator=accelerator,
    epoch_num=EPOCH_NUM,
    checkpointer=checkpointer,
    tb_logger=tensorboard_logger,
    save_on_val=True,
)

  0%|          | 0/20 [00:00<?, ?it/s]

------------------------------
Epoch 0/20


Training:   0%|          | 0/7 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
model = UNet(in_channels=3, out_channels=1)
model = load_checkpoint(
    model=model, load_path=pjoin(CHECKPOINTS_DIR, "model_checkpoint_best.pt")
)
model = model.to(DEVICE)
model.eval()

In [ ]:
sample_idx = 2
image, target = train_dataset[sample_idx]

thresold = 0.5
preds = F.sigmoid(model(image.unsqueeze(0).to(DEVICE))).squeeze(0)
binary_preds = (preds > thresold)

fig, ax = plt.subplots(1, 3, figsize=(9, 18))
ax[0].imshow(image.numpy().transpose(1, 2, 0).astype(np.uint8))
ax[1].imshow(target.numpy().transpose(1, 2, 0).astype(np.uint8))
ax[2].imshow(binary_preds.cpu()[0]);