In [24]:
import gc
import numpy as np
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
import utils, segmentation
from tqdm.auto import tqdm

In [2]:
# Run on the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [22]:
class PartialDs(torch.utils.data.Dataset):
    """Read-only view of ``ds`` restricted to the positions listed in ``idxs``.

    Item ``i`` of the view is item ``idxs[i]`` of the wrapped dataset. ``idxs`` is stored by
    reference (not copied), so mutating that list afterwards changes what the view exposes.
    """

    def __init__(self, ds, idxs):
        self.ds, self.idxs = ds, idxs

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, i):
        source_position = self.idxs[i]
        return self.ds[source_position]

def compute_entropies(out):
    """Mean per-pixel predictive entropy of each sample in a batch of logits.

    Args:
        out: logits with the class dimension on dim 1; the last two dims are averaged over
            (presumably (batch, classes, H, W) as produced by the U-Net — TODO confirm).

    Returns:
        CPU tensor of entropies in nats, one per remaining leading index (per sample).
    """
    logits = out.detach().cpu()
    # Work in log space: softmax(...).log() turns probabilities that underflow to exactly 0
    # into -inf, and 0 * -inf = NaN, which would poison the entropy ranking. log_softmax stays
    # finite, so those terms correctly contribute 0.
    log_prob = logits.log_softmax(dim=1)
    prob = log_prob.exp()
    assert ((prob.sum(dim=1) - 1).abs() < 1e-5).all()
    entropy = -(prob * log_prob).sum(dim=1)  # sum over classes
    return entropy.mean(dim=(-1, -2))  # average over pixels

def create_model(encoder_name="resnet18", encoder_weights="imagenet", in_channels=3, classes=2):
    """Build a fresh ``smp.Unet`` segmentation model.

    All arguments are forwarded to ``smp.Unet``. The defaults reproduce the configuration this
    notebook was run with, so ``create_model()`` behaves exactly as before:
    ResNet-18 encoder, ImageNet-initialised encoder weights, 3 input channels, 2 output classes.

    Args:
        encoder_name: encoder backbone, e.g. "resnet18", "mobilenet_v2" or "efficientnet-b7".
        encoder_weights: pre-trained weights used to initialise the encoder (see smp docs).
        in_channels: model input channels (1 for gray-scale images, 3 for RGB, etc.).
        classes: model output channels; must match the number of classes the loss expects
            (the class-weighted loss in this notebook has 2 weights).
    """
    return smp.Unet(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=classes,
    )

def create_optimizer(model, lr=1e-1):
    """Build a plain SGD optimizer (no momentum) over all parameters of ``model``.

    Args:
        model: the ``nn.Module`` whose parameters are optimized.
        lr: learning rate; defaults to 0.1, the value this notebook was run with.
            NOTE(review): the recorded training logs show occasional large loss spikes at this
            rate, so a smaller ``lr`` may be worth trying.

    Returns:
        A ``torch.optim.SGD`` instance.
    """
    return torch.optim.SGD(model.parameters(), lr=lr)

In [8]:
# Whole training pool (the active-learning loop trains only on its labeled subset) and the
# held-out test set. The trailing 2 is forwarded to SegmentationDataset
# (presumably the number of classes — TODO confirm).
ds_full = segmentation.SegmentationDataset('./stenoses_data/train/', 2)
ds_test = segmentation.SegmentationDataset('./stenoses_data/test/', 2)

batch_size = 8
# Ordered, non-dropping loaders: every sample is visited exactly once per pass.
_full_pass_loader_kwargs = dict(batch_size=batch_size, shuffle=False, drop_last=False)
dl_full = torch.utils.data.DataLoader(ds_full, **_full_pass_loader_kwargs)
dl_test = torch.utils.data.DataLoader(ds_test, **_full_pass_loader_kwargs)

In [13]:
# Class-weighted cross-entropy: each class is weighted by the inverse of its pixel frequency
# over the whole training pool.
# NOTE(review): assumes masks hold 0/1 class indices, so masks.sum() counts foreground pixels —
# confirm against segmentation.SegmentationDataset.
# Count pixels globally instead of averaging per-batch fractions; the last batch is smaller, so
# an unweighted mean over batches is not the true pixel fraction.
n_fg_pixels = 0
n_pixels = 0
for imgs, masks in dl_full:
    n_fg_pixels += masks.sum().item()
    n_pixels += masks.numel()
w = n_fg_pixels / n_pixels  # foreground pixel fraction
# A fraction of 0 or 1 (or masks that are not 0/1) would give infinite/negative weights.
assert 0 < w < 1, f'degenerate foreground fraction {w}; check mask encoding'
loss_fn = nn.CrossEntropyLoss(
    weight=torch.tensor([1 / (1 - w), 1 / w], dtype=torch.float32)
).to(device)

In [16]:
# Initial model/optimizer; the active-learning loop below recreates both at the start of
# every round.
model = create_model()
model = model.to(device)  # Module.to returns the same module, moved to `device`
optimizer = create_optimizer(model)

In [36]:
ACQUISITION_FRACTION = 0.1  # fraction of the pool labeled initially and added per round

N = len(ds_full)                          # size of the whole training pool
N_step = int(N * ACQUISITION_FRACTION)    # samples moved into the labeled set per step
(N, N_step)

(598, 59)

In [37]:
# Reproducible random split: the first N_step shuffled indices form the initial labeled set,
# the rest is the unlabeled pool. Seeds NumPy's legacy global RNG.
np.random.seed(0)
idxs = list(range(N))
np.random.shuffle(idxs)
labeled_idxs = idxs[:N_step]
unlabeled_idxs = idxs[N_step:]

In [38]:
num_iters = 10   # active-learning rounds
num_epochs = 30  # training epochs per round (the model is retrained from scratch each round)

for n_iter in range(num_iters):
    print(f'>>>>>>>>>>>> Iter {n_iter}')
    # Release the previous round's model/optimizer before allocating new ones so their memory
    # can be reclaimed. Rebinding to None (instead of `del`) also works on a fresh kernel where
    # no earlier cell defined `model`/`optimizer`.
    model = optimizer = None
    gc.collect()

    ds_train = PartialDs(ds_full, labeled_idxs)
    dl_train = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, shuffle=True, drop_last=True)

    model = create_model().to(device)
    optimizer = create_optimizer(model)

    for epoch in tqdm(range(num_epochs), desc='Epoch'):
        trn_loss = utils.train(dl_train, loss_fn, model, optimizer)
        val_loss, val_IoUs = segmentation.evaluate(dl_test, loss_fn, model)
        print(epoch, f'{trn_loss:.4f}', f'{val_loss:.4f}', val_IoUs)
    print()

    if not unlabeled_idxs:
        # Pool exhausted: nothing left to score (torch.cat of an empty list would raise).
        break

    # Score every unlabeled sample by its mean per-pixel predictive entropy. shuffle=False is
    # required: entropies[k] must correspond to unlabeled_idxs[k].
    ds_unlabeled = PartialDs(ds_full, unlabeled_idxs)
    dl_unlabeled = torch.utils.data.DataLoader(ds_unlabeled, batch_size=batch_size, shuffle=False, drop_last=False)
    model.eval()
    entropies = []
    with torch.no_grad():
        for inp, _ in dl_unlabeled:
            entropies.append(compute_entropies(model(inp.to(device))))
    entropies = torch.cat(entropies)

    # Rank the pool by descending entropy; the N_step most uncertain samples get labeled.
    order = entropies.sort(descending=True).indices
    ranked = torch.tensor(unlabeled_idxs)[order].tolist()
    # Rebind rather than `+=`: an in-place extend would also grow the list still referenced by
    # this round's ds_train.
    labeled_idxs = labeled_idxs + ranked[:N_step]
    unlabeled_idxs = ranked[N_step:]

>>>>>>>>>>>> Iter 0


Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.3922 0.7044 tensor([0.0018])
1 0.2308 1.1152 tensor([0.0031])
2 0.1208 0.2910 tensor([0.0761])
3 0.2856 0.7441 tensor([0.0055])
4 0.1518 1.2459 tensor([0.0169])
5 0.8364 10.1640 tensor([0.0023])
6 0.4502 1.7940 tensor([0.0140])
7 0.2337 0.6182 tensor([0.0358])
8 0.1190 0.1933 tensor([0.0569])
9 0.0760 0.6274 tensor([0.0807])
10 0.0628 0.2620 tensor([0.2193])
11 0.0333 0.1016 tensor([0.0539])
12 0.0487 0.1719 tensor([0.0658])
13 0.0355 0.2528 tensor([0.2424])
14 0.0225 0.1768 tensor([0.0688])
15 0.0236 0.2066 tensor([0.2889])
16 0.0168 0.1270 tensor([0.2406])
17 0.0368 2.9822 tensor([0.0090])
18 0.1015 0.0970 tensor([0.0948])
19 0.0254 0.1762 tensor([0.2325])
20 0.0250 0.0662 tensor([0.1401])
21 0.0183 0.2365 tensor([0.3025])
22 0.0153 0.1806 tensor([0.3047])
23 0.0116 0.1244 tensor([0.1906])
24 0.0105 0.1744 tensor([0.3327])
25 0.0095 0.2087 tensor([0.3619])
26 0.0109 0.1736 tensor([0.3120])
27 0.0088 0.1582 tensor([0.2989])
28 0.0097 0.1771 tensor([0.2748])
29 0.0081 0.1760 tensor

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.4297 1.9850 tensor([0.0027])
1 0.2042 0.2709 tensor([0.0460])
2 0.0568 0.1632 tensor([0.0197])
3 0.0473 0.1089 tensor([0.0584])
4 0.2227 1.4695 tensor([0.0104])
5 0.1012 0.2929 tensor([0.1568])
6 0.0511 0.6965 tensor([0.0289])
7 0.1381 0.2184 tensor([0.0187])
8 0.0715 0.2270 tensor([0.0279])
9 0.0638 0.1094 tensor([0.0314])
10 0.0418 0.2563 tensor([0.1261])
11 0.0751 0.5806 tensor([0.0128])
12 0.1658 0.6892 tensor([0.0509])
13 0.0702 0.4384 tensor([0.0948])
14 0.0369 0.1046 tensor([0.0872])
15 0.0221 0.3407 tensor([0.1548])
16 0.0182 0.2752 tensor([0.2010])
17 0.0169 0.1274 tensor([0.1426])
18 0.0141 0.4090 tensor([0.1952])
19 0.5389 0.6774 tensor([0.0058])
20 0.1499 0.2618 tensor([0.0866])
21 0.1163 0.1605 tensor([0.0720])
22 0.0638 0.0963 tensor([0.1215])
23 0.0382 0.6074 tensor([0.3053])
24 0.0526 0.0941 tensor([0.0221])
25 0.0408 0.5295 tensor([0.1478])
26 0.0310 0.7797 tensor([0.2325])
27 0.0210 0.2351 tensor([0.3157])
28 0.0161 0.7076 tensor([0.3258])
29 0.0130 0.0656 tensor(

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.3675 0.2903 tensor([0.0503])
1 0.1687 0.1024 tensor([0.0196])
2 0.0748 0.0434 tensor([0.1083])
3 0.0239 0.0646 tensor([0.3530])
4 0.0177 0.0178 tensor([0.2337])
5 0.0115 0.0238 tensor([0.3377])
6 0.0093 0.0259 tensor([0.4037])
7 0.0077 0.0314 tensor([0.4237])
8 0.0064 0.0398 tensor([0.4865])
9 0.0069 0.0670 tensor([0.4786])
10 0.0054 0.0152 tensor([0.3625])
11 0.0047 0.1532 tensor([0.5971])
12 0.0136 0.0138 tensor([0.1793])
13 0.0082 0.0740 tensor([0.5949])
14 0.0060 0.0122 tensor([0.3940])
15 0.0050 0.0365 tensor([0.5338])
16 0.0051 0.0412 tensor([0.5498])
17 0.0045 0.0091 tensor([0.3855])
18 0.0041 0.0454 tensor([0.5675])
19 0.0037 0.0270 tensor([0.5301])
20 0.0043 0.0267 tensor([0.5090])
21 0.0034 0.0547 tensor([0.6000])
22 0.0033 0.0604 tensor([0.5981])
23 0.0031 0.0500 tensor([0.5839])
24 0.0029 0.0712 tensor([0.6100])
25 0.0030 0.0481 tensor([0.5803])
26 0.0028 0.0925 tensor([0.6367])
27 0.0027 0.1124 tensor([0.6325])
28 0.0026 0.1587 tensor([0.6558])
29 0.0027 0.0587 tensor(

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.4479 0.1937 tensor([0.0704])
1 0.2461 1.0223 tensor([0.0208])
2 0.0835 0.1894 tensor([0.1178])
3 0.0273 0.0460 tensor([0.1311])
4 0.0220 0.0715 tensor([0.1272])
5 0.0153 0.0884 tensor([0.3619])
6 0.0152 0.2609 tensor([0.1841])
7 0.0459 0.1636 tensor([0.3047])
8 0.0129 0.0674 tensor([0.3215])
9 0.0097 0.0807 tensor([0.3509])
10 0.2718 0.1337 tensor([0.0996])
11 0.0542 0.0628 tensor([0.1253])
12 0.0331 0.0653 tensor([0.2378])
13 0.0167 0.0292 tensor([0.2581])
14 0.0621 0.0338 tensor([0.1198])
15 0.0204 0.0305 tensor([0.1697])
16 0.0152 0.0203 tensor([0.1317])
17 0.0119 0.0214 tensor([0.2340])
18 0.0097 0.0185 tensor([0.2104])
19 0.0087 0.0280 tensor([0.2460])
20 0.0081 0.0294 tensor([0.2923])
21 0.0108 0.0496 tensor([0.3202])
22 0.0066 0.0356 tensor([0.3408])
23 0.0059 0.0438 tensor([0.3631])
24 0.0058 0.0237 tensor([0.3014])
25 0.0053 0.0474 tensor([0.1454])
26 0.0061 0.0285 tensor([0.3707])
27 0.0046 0.0402 tensor([0.3733])
28 0.0044 0.0524 tensor([0.3920])
29 0.0043 0.0561 tensor(

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.2050 0.2399 tensor([0.0696])
1 0.0431 0.0778 tensor([0.1175])
2 0.0226 0.0291 tensor([0.1098])
3 0.0142 0.0687 tensor([0.2993])
4 0.0093 0.0401 tensor([0.2535])
5 0.0075 0.1064 tensor([0.3818])
6 0.0061 0.1212 tensor([0.4334])
7 0.0059 0.0695 tensor([0.3397])
8 0.0050 0.1326 tensor([0.4442])
9 0.0052 0.0727 tensor([0.4157])
10 0.0043 0.1081 tensor([0.4603])
11 0.0065 0.1139 tensor([0.3733])
12 0.0040 0.1293 tensor([0.4087])
13 0.0040 0.1104 tensor([0.3954])
14 0.0035 0.1543 tensor([0.4818])
15 0.0034 0.1722 tensor([0.4909])
16 0.0035 0.2043 tensor([0.4937])
17 0.0033 0.1522 tensor([0.4636])
18 0.0032 0.1456 tensor([0.4783])
19 0.0030 0.0736 tensor([0.3791])
20 0.0034 0.1793 tensor([0.5000])
21 0.0028 0.1318 tensor([0.4577])
22 0.0028 0.1279 tensor([0.4700])
23 0.0025 0.1972 tensor([0.5160])
24 0.0024 0.1634 tensor([0.4990])
25 0.0023 0.1810 tensor([0.5221])
26 0.0024 0.1383 tensor([0.5317])
27 0.0022 0.1756 tensor([0.5349])
28 0.0023 0.1972 tensor([0.5307])
29 0.0023 0.1515 tensor(

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.2077 0.2174 tensor([0.1378])
1 0.0435 0.0908 tensor([0.0673])
2 0.0262 0.0924 tensor([0.1825])
3 0.0147 0.0891 tensor([0.2205])
4 0.0095 0.0831 tensor([0.2123])
5 0.0083 0.1440 tensor([0.2808])
6 0.0068 0.1299 tensor([0.3184])
7 0.0072 0.2920 tensor([0.3800])
8 0.0097 0.0819 tensor([0.2098])
9 0.0079 0.2769 tensor([0.2905])
10 0.0060 0.2445 tensor([0.3656])
11 0.0053 0.1410 tensor([0.3945])
12 0.0046 0.1130 tensor([0.3046])
13 0.0044 0.2412 tensor([0.4270])
14 0.0062 0.1464 tensor([0.3416])
15 0.0041 0.1660 tensor([0.4456])
16 0.0037 0.1508 tensor([0.3856])
17 0.0038 0.2152 tensor([0.4372])
18 0.0034 0.2085 tensor([0.4169])
19 0.0032 0.3585 tensor([0.4926])
20 0.0034 0.2342 tensor([0.4033])
21 0.0032 0.2409 tensor([0.4256])
22 0.0028 0.2147 tensor([0.4380])
23 0.0029 0.3103 tensor([0.5178])
24 0.0028 0.3869 tensor([0.5155])
25 0.0028 0.2823 tensor([0.4944])
26 0.0026 0.3302 tensor([0.4830])
27 0.0025 0.3755 tensor([0.4940])
28 0.0024 0.3755 tensor([0.4757])
29 0.0026 0.4255 tensor(

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.1495 0.1456 tensor([0.0180])
1 0.0431 0.2839 tensor([0.1871])
2 0.0147 0.1432 tensor([0.2421])
3 0.0125 0.2884 tensor([0.4327])
4 0.0189 6.0143 tensor([0.0028])
5 0.0515 0.2270 tensor([0.0978])
6 0.0187 0.0621 tensor([0.2892])
7 0.0101 0.0281 tensor([0.2572])
8 0.0087 0.0569 tensor([0.3552])
9 0.0084 0.0875 tensor([0.3418])
10 0.0067 0.1263 tensor([0.3918])
11 0.0056 0.0497 tensor([0.2040])
12 0.0051 0.1055 tensor([0.4203])
13 0.0042 0.0927 tensor([0.4528])
14 0.0039 0.1458 tensor([0.5328])
15 0.0045 0.1625 tensor([0.4703])
16 0.0039 0.1381 tensor([0.5143])
17 0.0040 0.1557 tensor([0.4174])
18 0.0036 0.1382 tensor([0.4049])
19 0.0034 0.1256 tensor([0.5213])
20 0.0031 0.1711 tensor([0.5703])
21 0.0030 0.1834 tensor([0.5862])
22 0.0029 0.1422 tensor([0.5773])
23 0.0031 0.1372 tensor([0.5034])
24 0.0027 0.1841 tensor([0.5497])
25 0.0026 0.1362 tensor([0.5379])
26 0.0025 0.1559 tensor([0.5275])
27 0.0025 0.2505 tensor([0.5815])
28 0.0024 0.2239 tensor([0.5693])
29 0.0023 0.2471 tensor(

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.1606 0.0810 tensor([0.0443])
1 0.0284 0.1458 tensor([0.2030])
2 0.0188 0.2733 tensor([0.3417])
3 0.0095 0.2998 tensor([0.0291])
4 0.0106 0.2782 tensor([0.3776])
5 0.0057 0.3300 tensor([0.4260])
6 0.0053 0.3601 tensor([0.4158])
7 0.0043 0.3525 tensor([0.4504])
8 0.0049 0.1521 tensor([0.3459])
9 0.0041 0.5079 tensor([0.4671])
10 0.0036 0.2957 tensor([0.4137])
11 0.0034 0.6182 tensor([0.4775])
12 0.0031 0.6310 tensor([0.4735])
13 0.0035 0.6776 tensor([0.4630])
14 0.0029 0.6520 tensor([0.4961])
15 0.0027 0.3855 tensor([0.4768])
16 0.0026 0.7256 tensor([0.5117])
17 0.0024 0.6518 tensor([0.5029])
18 0.0023 0.7386 tensor([0.5070])
19 0.0025 0.5225 tensor([0.4928])
20 0.0023 0.7149 tensor([0.4927])
21 0.0022 0.7670 tensor([0.5097])
22 0.0021 0.7692 tensor([0.5145])
23 0.0021 0.6943 tensor([0.5122])
24 0.0021 0.7594 tensor([0.5069])
25 0.0023 0.6402 tensor([0.5023])
26 0.0020 0.8671 tensor([0.5044])
27 0.0018 0.7340 tensor([0.5164])
28 0.0022 0.7211 tensor([0.5361])
29 0.0019 0.6843 tensor(

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.1794 0.0822 tensor([0.0618])
1 0.0672 0.0844 tensor([0.0408])
2 0.0341 0.0859 tensor([0.1254])
3 0.0167 0.2901 tensor([0.2183])
4 0.0106 0.3729 tensor([0.1873])
5 0.0078 0.4210 tensor([0.2930])
6 0.0065 0.2550 tensor([0.2824])
7 0.0053 0.6577 tensor([0.3247])
8 0.0045 0.3382 tensor([0.3375])
9 0.0046 0.8443 tensor([0.3088])
10 0.0041 0.6376 tensor([0.3326])
11 0.0037 0.8353 tensor([0.3111])
12 0.0034 0.9003 tensor([0.3375])
13 0.0033 0.5608 tensor([0.3278])
14 0.0030 0.7179 tensor([0.3691])
15 0.0031 0.8090 tensor([0.3777])
16 0.0028 0.5004 tensor([0.3334])
17 0.0028 1.0730 tensor([0.3610])
18 0.0027 1.0672 tensor([0.3506])
19 0.0027 0.8493 tensor([0.3565])
20 0.0024 0.4560 tensor([0.3059])
21 0.0027 0.7683 tensor([0.4012])
22 0.0025 0.8096 tensor([0.3793])
23 0.0022 1.2691 tensor([0.3515])
24 0.0023 1.2566 tensor([0.3704])
25 0.0022 1.0252 tensor([0.3818])
26 0.0021 1.1999 tensor([0.3624])
27 0.0020 1.3967 tensor([0.3586])
28 0.0022 1.1502 tensor([0.3648])
29 0.0021 0.7052 tensor(

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.2554 0.1258 tensor([0.0367])
1 0.0280 0.0633 tensor([0.1762])
2 0.0119 0.0838 tensor([0.1499])
3 0.0073 0.1754 tensor([0.3337])
4 0.0059 0.1678 tensor([0.3370])
5 0.0049 0.2482 tensor([0.3903])
6 0.0053 0.1894 tensor([0.3112])
7 0.0043 0.1980 tensor([0.3741])
8 0.0054 0.1261 tensor([0.3195])
9 0.0049 0.2779 tensor([0.4366])
10 0.0033 0.2323 tensor([0.4338])
11 0.0030 0.2441 tensor([0.5050])
12 0.0028 0.2354 tensor([0.4853])
13 0.0027 0.1960 tensor([0.3966])
14 0.0025 0.2905 tensor([0.5087])
15 0.0024 0.2212 tensor([0.4538])
16 0.0024 0.2778 tensor([0.5047])
17 0.0025 0.2971 tensor([0.5027])
18 0.0023 0.3434 tensor([0.4996])
19 0.0022 0.3454 tensor([0.4820])
20 0.0021 0.2983 tensor([0.5111])
21 0.0022 0.2656 tensor([0.5018])
22 0.0020 0.3390 tensor([0.5132])
23 0.0019 0.3138 tensor([0.5060])
24 0.0019 0.3455 tensor([0.5164])
25 0.0019 0.3905 tensor([0.5284])
26 0.0019 0.3573 tensor([0.5253])
27 0.0018 0.3719 tensor([0.5126])
28 0.0017 0.3298 tensor([0.5132])
29 0.0017 0.3247 tensor(