In [1]:
# import numpy as np
# import pickle, os, torch
# from pathlib import Path
# from PIL import Image
# from segmentation_models_pytorch.encoders import get_preprocessing_fn
# import torchvision
# import torchvision.transforms as TF
# import torchvision.transforms.v2 as TF2

# from tqdm.auto import tqdm
# import matplotlib.pyplot as plt
import os
import utils, segmentation
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp

In [2]:
# Mini-batch size shared by the training and test loaders.
batch_size = 8

# Train / test splits of the stenosis segmentation data.
# NOTE(review): the second argument (2) is presumably the number of classes,
# matching `classes=2` on the model -- confirm against SegmentationDataset.
ds_train = segmentation.SegmentationDataset('./stenoses_data/train/', 2)
ds_test = segmentation.SegmentationDataset('./stenoses_data/test/', 2)

# Training loader: reshuffled every epoch; the last incomplete batch is dropped
# so every training batch holds exactly `batch_size` samples.
dl_train = torch.utils.data.DataLoader(
    ds_train,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

# Test loader: fixed order, and the trailing partial batch is kept so every
# test sample is evaluated.
dl_test = torch.utils.data.DataLoader(
    ds_test,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
)

In [3]:
# Number of batches per epoch in the (train, test) loaders.
tuple(len(loader) for loader in (dl_train, dl_test))

(74, 19)

In [4]:
# Select the compute device.
# The notebook was written for the second GPU ('cuda:1'); that choice is kept
# whenever at least two GPUs are visible. On a single-GPU machine fall back to
# 'cuda:0', and on a CPU-only machine to 'cpu', so the notebook still runs
# instead of failing later when tensors are moved to a non-existent device.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [5]:
# U-Net segmentation network:
#   - encoder_name:    backbone architecture (other options include
#                      mobilenet_v2 or efficientnet-b7)
#   - encoder_weights: initialise the encoder from ImageNet pre-trained weights
#   - in_channels:     input channels (1 for gray-scale images, 3 for RGB)
#   - classes:         output channels, i.e. number of classes in the dataset
model = smp.Unet(
    encoder_name="resnet18",
    encoder_weights="imagenet",
    in_channels=3,
    classes=2,
)
# Move all parameters and buffers to the selected device.
model = model.to(device)

In [6]:
# Plain SGD over every model parameter with learning rate 0.1; momentum and
# weight decay are left at the optimizer's defaults.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.1)

In [7]:
# Build a class-weighted cross-entropy loss from the training masks.
#
# `w` is the fraction of mask values that are "on" over one pass of the
# training loader (sum of mask values / number of mask elements).
# NOTE(review): this assumes masks hold 0/1 labels, so that the sum counts
# foreground pixels -- confirm against SegmentationDataset.
#
# Exact totals are accumulated as Python numbers instead of averaging
# per-batch means, which avoids float32 rounding drift and does not rely on
# every batch having the same number of elements.
mask_sum = 0
mask_count = 0
for _, masks in dl_train:
    mask_sum += masks.sum().item()
    mask_count += masks.numel()

if mask_count == 0:
    raise ValueError("dl_train yielded no mask elements; cannot compute class weights")

# NOTE: `w` is now a Python float (it used to end up as a 0-dim tensor).
w = mask_sum / mask_count

# A fraction of exactly 0 or 1 would produce an infinite class weight and a
# NaN loss, so fail early with an explicit message instead.
if not 0.0 < w < 1.0:
    raise ValueError(
        f"foreground fraction is {w}; both classes must occur in the training masks"
    )

# Inverse-frequency weights: class 0 gets 1 / (1 - w), class 1 gets 1 / w,
# so the rarer class contributes more to the loss.
class_weights = torch.tensor([1.0 / (1.0 - w), 1.0 / w], dtype=torch.float32)
loss_fn = nn.CrossEntropyLoss(weight=class_weights).to(device)

In [8]:
# Train for a fixed number of epochs, evaluating on the test loader after each.
# Each printed row is: epoch, training loss, validation loss, validation IoU(s).
# NOTE(review): in the recorded run the training loss keeps falling while the
# validation loss trends upward after the first few epochs and IoU levels off
# around 0.6 -- this looks like overfitting; consider early stopping or
# checkpointing the best epoch.
num_epochs = 50
for epoch in range(num_epochs):
    # One optimisation pass over the training loader.
    trn_loss = utils.train(dl_train, loss_fn, model, optimizer)
    # Loss and IoU on the held-out loader.
    val_loss, val_IoUs = segmentation.evaluate(dl_test, loss_fn, model)
    print(
        epoch,
        format(trn_loss, '.4f'),
        format(val_loss, '.4f'),
        val_IoUs,
    )

0 0.1524 0.3343 tensor([0.0193])
1 0.0458 0.1384 tensor([0.0743])
2 0.0329 0.0797 tensor([0.2065])
3 0.0099 0.0599 tensor([0.3718])
4 0.0115 0.1356 tensor([0.4679])
5 0.0057 0.0765 tensor([0.4250])
6 0.0053 0.0998 tensor([0.4504])
7 0.0049 0.1408 tensor([0.5068])
8 0.0039 0.0850 tensor([0.4850])
9 0.0037 0.1979 tensor([0.5076])
10 0.0044 0.2669 tensor([0.5244])
11 0.0036 0.2974 tensor([0.5341])
12 0.0034 0.1722 tensor([0.5702])
13 0.0037 0.3034 tensor([0.5395])
14 0.0029 0.1029 tensor([0.5147])
15 0.0031 0.2141 tensor([0.5979])
16 0.0029 0.1973 tensor([0.5651])
17 0.0026 0.2019 tensor([0.5863])
18 0.0026 0.2082 tensor([0.6056])
19 0.0025 0.2655 tensor([0.6453])
20 0.0025 0.2314 tensor([0.5976])
21 0.0023 0.2352 tensor([0.6204])
22 0.0023 0.2559 tensor([0.6186])
23 0.0022 0.3349 tensor([0.6192])
24 0.0022 0.2432 tensor([0.6294])
25 0.0021 0.3590 tensor([0.6132])
26 0.0021 0.3663 tensor([0.6115])
27 0.0020 0.3706 tensor([0.6164])
28 0.0019 0.3451 tensor([0.6219])
29 0.0020 0.4271 tensor(