In [1]:
import gc
import numpy as np
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
import utils, segmentation
from tqdm.auto import tqdm

In [2]:
device = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')
# device = torch.device('cpu')
print(device)

cuda:1


In [3]:
class PartialDs(torch.utils.data.Dataset):
    def __init__(self, ds, idxs):
        self.ds = ds
        self.idxs = idxs
    
    def __len__(self):
        return len(self.idxs)
    
    def __getitem__(self, i): 
        return self.ds[self.idxs[i]]

def create_model():
    return smp.Unet(
        encoder_name="resnet18",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
        encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
        in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
        classes=2,                      # model output channels (number of classes in your dataset)
    )

def create_optimizer(model):
    return torch.optim.SGD(model.parameters(), lr=1e-1)

In [4]:
ds_full = segmentation.SegmentationDataset('./stenoses_data/train/', 2)
ds_test = segmentation.SegmentationDataset('./stenoses_data/test/', 2)

batch_size = 8
dl_full = torch.utils.data.DataLoader(ds_full, batch_size=batch_size, shuffle=False, drop_last=False)
dl_test = torch.utils.data.DataLoader(ds_test, batch_size=batch_size, shuffle=False, drop_last=False)

In [5]:
w = 0.
for imgs, masks in dl_full:
    w += masks.sum() / (masks.numel() * len(dl_full))
loss_fn = nn.CrossEntropyLoss(weight=torch.Tensor([1 / (1-w), 1/ w])).to(device)

In [6]:
model = create_model().to(device)
optimizer = create_optimizer(model)

In [7]:
N = len(ds_full)
N_step = int(N*0.1)
N, N_step

(598, 59)

In [8]:
idxs = list(range(N))
np.random.seed(0)
np.random.shuffle(idxs)

In [9]:
for n_iter in range(10):
    print(f'>>>>>>>>>>>> Iter {n_iter}')
    del model
    del optimizer
    gc.collect()
    labeled_idxs = idxs[:N_step*(n_iter+1)]
    ds_train = PartialDs(ds_full, labeled_idxs)
    dl_train = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, shuffle=True, drop_last=True)

    model = create_model().to(device)
    optimizer = create_optimizer(model)

    num_epochs = 30
    for epoch in tqdm(range(num_epochs), desc='Epoch'):
        trn_loss = utils.train(dl_train, loss_fn, model, optimizer)
        val_loss, val_IoUs = segmentation.evaluate(dl_test, loss_fn, model)
        print(epoch, f'{trn_loss:.4f}', f'{val_loss:.4f}', val_IoUs)
    print()

>>>>>>>>>>>> Iter 0


Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.4571 1.8557 tensor([0.0024])
1 0.2324 1.4524 tensor([0.0052])
2 0.1826 0.3398 tensor([0.0085])
3 0.1570 0.2998 tensor([0.0103])
4 0.1024 0.2925 tensor([0.0272])
5 0.0667 0.1586 tensor([0.0863])
6 0.0413 0.3316 tensor([0.1308])
7 0.0288 0.2890 tensor([0.1666])
8 0.0280 0.1672 tensor([0.0607])
9 0.0223 0.1043 tensor([0.0771])
10 0.0195 0.4539 tensor([0.1302])
11 0.0398 0.0819 tensor([0.0641])
12 0.0423 0.6755 tensor([0.0098])
13 0.0612 0.1405 tensor([0.0487])
14 0.0237 0.5028 tensor([0.0989])
15 0.0427 0.0821 tensor([0.0293])
16 0.0298 0.1822 tensor([0.0858])
17 0.0220 0.2859 tensor([0.1123])
18 0.0184 0.3837 tensor([0.1588])
19 0.0157 0.6191 tensor([0.1589])
20 0.0131 0.7149 tensor([0.1495])
21 0.0114 0.9659 tensor([0.1259])
22 0.0104 0.8656 tensor([0.1391])
23 0.0096 0.7203 tensor([0.1507])
24 0.0084 1.2477 tensor([0.1197])
25 0.0075 1.2970 tensor([0.1323])
26 0.1855 7.5379 tensor([0.0026])
27 0.0981 0.3428 tensor([0.0106])
28 0.0596 0.1245 tensor([0.0450])
29 0.0420 0.1521 tensor(

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.6060 45.0837 tensor([0.0018])
1 0.3484 0.1699 tensor([0.0140])
2 0.0936 0.1066 tensor([0.0341])
3 0.0505 0.0817 tensor([0.0396])
4 0.0551 0.1432 tensor([0.0967])
5 0.0234 0.3747 tensor([0.2221])
6 0.0542 0.1971 tensor([0.0138])
7 0.0374 0.0871 tensor([0.0970])
8 0.0205 0.0872 tensor([0.1718])
9 0.0184 0.1300 tensor([0.2656])
10 0.0157 0.1563 tensor([0.3081])
11 0.0109 0.1112 tensor([0.2042])
12 0.0087 0.1234 tensor([0.2875])
13 0.0079 0.1859 tensor([0.3405])
14 0.0071 0.0997 tensor([0.2087])
15 0.0070 0.0950 tensor([0.2298])
16 0.0086 0.1654 tensor([0.3693])
17 0.0067 0.2082 tensor([0.3692])
18 0.0055 0.2116 tensor([0.3937])
19 0.0054 0.2618 tensor([0.4202])
20 0.0050 0.2577 tensor([0.4057])
21 0.0050 0.1985 tensor([0.3630])
22 0.0045 0.1554 tensor([0.3625])
23 0.0042 0.1872 tensor([0.3794])
24 0.0122 0.1670 tensor([0.3295])
25 0.0059 0.2948 tensor([0.3939])
26 0.0045 0.2873 tensor([0.4076])
27 0.0046 0.2941 tensor([0.4267])
28 0.0040 0.3315 tensor([0.4364])
29 0.0039 0.2775 tensor

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.2553 0.5063 tensor([0.0048])
1 0.1111 0.1315 tensor([0.0215])
2 0.0589 0.1776 tensor([0.0408])
3 0.0372 0.1424 tensor([0.1741])
4 0.0296 0.1392 tensor([0.1628])
5 0.0272 0.0962 tensor([0.0352])
6 0.0176 0.0488 tensor([0.2248])
7 0.0270 0.0225 tensor([0.2117])
8 0.0118 0.0156 tensor([0.1631])
9 0.0109 0.0187 tensor([0.2890])
10 0.0070 0.0316 tensor([0.3660])
11 0.0064 0.0146 tensor([0.2187])
12 0.0059 0.0313 tensor([0.4004])
13 0.0050 0.0523 tensor([0.3796])
14 0.0131 0.0535 tensor([0.3863])
15 0.0059 0.0367 tensor([0.3708])
16 0.0075 0.0394 tensor([0.4025])
17 0.0049 0.0461 tensor([0.4243])
18 0.0045 0.0233 tensor([0.3300])
19 0.0047 0.0584 tensor([0.4689])
20 0.0045 0.0579 tensor([0.4622])
21 0.0036 0.0582 tensor([0.4514])
22 0.0035 0.0585 tensor([0.4879])
23 0.0035 0.0483 tensor([0.4356])
24 0.0034 0.0957 tensor([0.5359])
25 0.0033 0.0951 tensor([0.5457])
26 0.0030 0.0850 tensor([0.4620])
27 0.0030 0.0840 tensor([0.5132])
28 0.0027 0.1121 tensor([0.5778])
29 0.0027 0.1148 tensor(

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.3044 0.1383 tensor([0.0416])
1 0.0931 0.0882 tensor([0.0276])
2 0.0321 0.1179 tensor([0.2183])
3 0.0153 0.4643 tensor([0.2540])
4 0.0158 0.3728 tensor([0.2900])
5 0.0087 0.5119 tensor([0.2892])
6 0.0082 0.2687 tensor([0.2796])
7 0.0096 0.7925 tensor([0.3441])
8 0.0069 0.6048 tensor([0.3434])
9 0.0068 0.4666 tensor([0.3544])
10 0.0053 0.2910 tensor([0.4215])
11 0.0046 0.6281 tensor([0.3931])
12 0.0043 0.3650 tensor([0.4118])
13 0.0037 0.6641 tensor([0.3942])
14 0.0036 0.9369 tensor([0.3694])
15 0.0041 0.7727 tensor([0.3962])
16 0.0057 0.7786 tensor([0.4056])
17 0.0036 0.7663 tensor([0.4311])
18 0.0034 0.9707 tensor([0.3751])
19 0.0034 0.6383 tensor([0.3723])
20 0.0030 1.0223 tensor([0.3843])
21 0.0030 0.7072 tensor([0.4158])
22 0.0027 0.8574 tensor([0.4001])
23 0.0026 0.7635 tensor([0.4131])
24 0.0027 0.3968 tensor([0.4220])
25 0.0037 0.7192 tensor([0.4215])
26 0.0029 0.8785 tensor([0.4261])
27 0.0027 0.9254 tensor([0.4096])
28 0.0025 0.9964 tensor([0.4279])
29 0.0023 0.8699 tensor(

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.2109 0.1857 tensor([0.0553])
1 0.0469 0.0519 tensor([0.1047])
2 0.0420 0.2211 tensor([0.0135])
3 0.0343 0.0896 tensor([0.2556])
4 0.0886 0.2761 tensor([0.2324])
5 0.0473 0.2420 tensor([0.2700])
6 0.0158 0.1369 tensor([0.2299])
7 0.0116 0.1035 tensor([0.3630])
8 0.0085 0.1207 tensor([0.4284])
9 0.0076 0.1402 tensor([0.4451])
10 0.0981 0.0847 tensor([0.0230])
11 0.0466 0.0698 tensor([0.0746])
12 0.0307 0.0830 tensor([0.0778])
13 0.0210 0.3573 tensor([0.1472])
14 0.0167 0.1145 tensor([0.2524])
15 0.0105 0.2852 tensor([0.0324])
16 0.0122 0.1557 tensor([0.2769])
17 0.0109 0.0354 tensor([0.1472])
18 0.0076 0.1220 tensor([0.3219])
19 0.0367 0.1988 tensor([0.2303])
20 0.0143 0.1825 tensor([0.2777])
21 0.0083 0.2645 tensor([0.3087])
22 0.0072 0.2976 tensor([0.3742])
23 0.0060 0.3207 tensor([0.3701])
24 0.0059 0.2800 tensor([0.4124])
25 0.0051 0.3498 tensor([0.4155])
26 0.0047 0.3212 tensor([0.3706])
27 0.0057 0.3393 tensor([0.4631])
28 0.0043 0.2908 tensor([0.4321])
29 0.0042 0.3515 tensor(

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.1995 0.0893 tensor([0.0856])
1 0.0398 0.3438 tensor([0.0138])
2 0.0304 0.7154 tensor([0.1613])
3 0.0547 1.6217 tensor([0.0058])
4 0.0703 0.0709 tensor([0.1360])
5 0.0220 0.1787 tensor([0.2835])
6 0.0148 0.0968 tensor([0.1650])
7 0.0206 0.2362 tensor([0.2551])
8 0.0096 0.0582 tensor([0.2274])
9 0.0102 0.1471 tensor([0.0302])
10 0.0197 0.0858 tensor([0.2087])
11 0.0088 0.2233 tensor([0.3419])
12 0.0061 0.2812 tensor([0.3413])
13 0.0061 0.2950 tensor([0.2645])
14 0.0050 0.4791 tensor([0.3772])
15 0.0042 0.5004 tensor([0.4251])
16 0.0043 0.2760 tensor([0.3489])
17 0.0037 0.4666 tensor([0.4103])
18 0.0038 0.2848 tensor([0.3841])
19 0.0041 0.5977 tensor([0.4635])
20 0.0034 0.4945 tensor([0.4604])
21 0.0077 0.3496 tensor([0.4472])
22 0.0038 0.4338 tensor([0.4382])
23 0.0032 0.5123 tensor([0.4540])
24 0.0034 0.4757 tensor([0.4687])
25 0.0031 0.5806 tensor([0.4425])
26 0.0029 0.6128 tensor([0.4722])
27 0.0039 0.4910 tensor([0.4778])
28 0.0030 0.5622 tensor([0.5070])
29 0.0032 0.2409 tensor(

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.1753 0.5943 tensor([0.1725])
1 0.0817 0.0699 tensor([0.0563])
2 0.0281 0.0297 tensor([0.1321])
3 0.0159 0.0846 tensor([0.4163])
4 0.0106 0.0551 tensor([0.4572])
5 0.0070 0.0225 tensor([0.3293])
6 0.0055 0.1287 tensor([0.5579])
7 0.0073 0.1158 tensor([0.5196])
8 0.0052 0.1889 tensor([0.5982])
9 0.0065 0.0937 tensor([0.5506])
10 0.0041 0.1716 tensor([0.5766])
11 0.0036 0.1946 tensor([0.5781])
12 0.0034 0.1757 tensor([0.5566])
13 0.0037 0.0725 tensor([0.4164])
14 0.0034 0.1545 tensor([0.5426])
15 0.0033 0.2449 tensor([0.5610])
16 0.0030 0.2017 tensor([0.5573])
17 0.0032 0.1209 tensor([0.5316])
18 0.0027 0.2649 tensor([0.6010])
19 0.0030 0.1558 tensor([0.5795])
20 0.0247 0.2110 tensor([0.4343])
21 0.0060 0.0804 tensor([0.4678])
22 0.0051 0.1343 tensor([0.5296])
23 0.0041 0.1907 tensor([0.5267])
24 0.0038 0.2486 tensor([0.5424])
25 0.0033 0.1773 tensor([0.5233])
26 0.0037 0.1636 tensor([0.5633])
27 0.0032 0.1352 tensor([0.5935])
28 0.0031 0.2537 tensor([0.5529])
29 0.0027 0.1950 tensor(

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.1370 0.2510 tensor([0.2826])
1 0.0745 0.0587 tensor([0.1066])
2 0.0207 0.1725 tensor([0.4162])
3 0.0373 0.0544 tensor([0.0624])
4 0.0176 0.0671 tensor([0.3339])
5 0.0169 0.0802 tensor([0.4093])
6 0.0090 0.0298 tensor([0.1298])
7 0.0072 0.0217 tensor([0.3237])
8 0.0055 0.0633 tensor([0.4791])
9 0.0051 0.0320 tensor([0.3792])
10 0.0051 0.0554 tensor([0.4731])
11 0.0044 0.0887 tensor([0.5258])
12 0.0054 0.0862 tensor([0.5340])
13 0.0037 0.0367 tensor([0.4755])
14 0.0045 0.1145 tensor([0.5530])
15 0.0037 0.1043 tensor([0.5810])
16 0.0036 0.0441 tensor([0.4274])
17 0.0037 0.0500 tensor([0.3956])
18 0.0033 0.1604 tensor([0.5882])
19 0.0031 0.0692 tensor([0.5577])
20 0.0027 0.1070 tensor([0.5638])
21 0.0030 0.2465 tensor([0.6534])
22 0.0030 0.1622 tensor([0.6280])
23 0.0032 0.1008 tensor([0.5759])
24 0.0029 0.0734 tensor([0.5164])
25 0.0027 0.1936 tensor([0.6265])
26 0.0024 0.2202 tensor([0.6212])
27 0.0024 0.1765 tensor([0.6285])
28 0.0027 0.2307 tensor([0.6362])
29 0.0024 0.2489 tensor(

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.1372 0.1356 tensor([0.3113])
1 0.0144 0.0510 tensor([0.3007])
2 0.0096 0.1763 tensor([0.5076])
3 0.0069 0.0716 tensor([0.3916])
4 0.0058 0.0411 tensor([0.4242])
5 0.0051 0.1128 tensor([0.5068])
6 0.0043 0.1320 tensor([0.5372])
7 0.0035 0.1800 tensor([0.5835])
8 0.7794 0.6610 tensor([0.0167])
9 0.2081 0.5551 tensor([0.1203])
10 0.0446 0.1358 tensor([0.0847])
11 0.0191 0.3081 tensor([0.2146])
12 0.0132 0.2954 tensor([0.1351])
13 0.0101 0.5533 tensor([0.3189])
14 0.0082 0.1937 tensor([0.1872])
15 0.0082 0.2067 tensor([0.2819])
16 0.0060 0.2688 tensor([0.3819])
17 0.0075 0.3131 tensor([0.3250])
18 0.0059 0.2340 tensor([0.3652])
19 0.0055 0.2872 tensor([0.3469])
20 0.0050 0.2836 tensor([0.3696])
21 0.0045 0.2492 tensor([0.3893])
22 0.0045 0.3696 tensor([0.4056])
23 0.0041 0.2401 tensor([0.3137])
24 0.0047 0.2754 tensor([0.3991])
25 0.0044 0.3383 tensor([0.4447])
26 0.0039 0.3196 tensor([0.4235])
27 0.0039 0.3303 tensor([0.3975])
28 0.0036 0.3452 tensor([0.4272])
29 0.0036 0.3066 tensor(

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.1634 0.1159 tensor([0.0213])
1 0.0286 0.1450 tensor([0.0299])
2 0.0155 0.1590 tensor([0.2512])
3 0.0090 0.1300 tensor([0.0718])
4 0.0069 0.2851 tensor([0.3832])
5 0.0058 0.3977 tensor([0.4145])
6 0.0047 0.3593 tensor([0.3964])
7 0.0043 0.4222 tensor([0.4334])
8 0.0042 0.5226 tensor([0.4178])
9 0.0036 0.3227 tensor([0.3927])
10 0.0037 0.5332 tensor([0.4202])
11 0.0031 0.4034 tensor([0.4740])
12 0.0030 0.4113 tensor([0.3995])
13 0.0029 0.4095 tensor([0.4219])
14 0.0032 0.3531 tensor([0.4627])
15 0.0030 0.3750 tensor([0.4657])
16 0.0027 0.4949 tensor([0.4585])
17 0.0026 0.3753 tensor([0.4707])
18 0.0023 0.2968 tensor([0.4146])
19 0.0023 0.5324 tensor([0.4551])
20 0.0022 0.4745 tensor([0.4748])
21 0.0022 0.4011 tensor([0.4663])
22 0.0022 0.5349 tensor([0.4823])
23 0.0021 0.5446 tensor([0.4719])
24 0.0020 0.5418 tensor([0.4960])
25 0.0020 0.4030 tensor([0.4569])
26 0.0026 0.3308 tensor([0.2549])
27 0.0030 0.4416 tensor([0.4807])
28 0.0035 0.2308 tensor([0.4232])
29 0.0025 0.3406 tensor(