In [1]:
import gc
import numpy as np
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
import utils, segmentation, AL, presentation
from tqdm.auto import tqdm

In [2]:
# Run on the first GPU when CUDA is available, otherwise on the CPU.
# (For CPU-only debugging, override with: device = torch.device('cpu'))
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [3]:
# Training pool (the active-learning candidates) and the held-out test split.
# NOTE(review): the second argument (2) is presumably the number of classes — TODO confirm
# against segmentation.SegmentationDataset.
ds_full = segmentation.SegmentationDataset('./cut_stenoses_data/train/', 2)
ds_test = segmentation.SegmentationDataset('./cut_stenoses_data/test/', 2)

batch_size = 8
# Deterministic, non-dropping loaders: each sample is visited exactly once per pass, in order.
loader_kwargs = dict(batch_size=batch_size, shuffle=False, drop_last=False)
dl_full = torch.utils.data.DataLoader(ds_full, **loader_kwargs)
dl_test = torch.utils.data.DataLoader(ds_test, **loader_kwargs)

In [4]:
# Pixel-wise cross-entropy. Set USE_CLASS_WEIGHTS = True to weight the two classes by the
# inverse of their pixel frequency in the training pool (assumes masks hold 0/1 labels).
USE_CLASS_WEIGHTS = False

if USE_CLASS_WEIGHTS:
    # Accumulate exact pixel counts instead of averaging per-batch fractions: the last batch
    # is smaller than the others, so equal-weight averaging biases the estimate.
    fg_pixels, total_pixels = 0, 0
    for _, masks in dl_full:
        fg_pixels += masks.sum().item()
        total_pixels += masks.numel()
    fg_frac = fg_pixels / total_pixels
    # Both weights are 1/frequency, so neither class may be absent.
    assert 0 < fg_frac < 1, f'degenerate foreground fraction: {fg_frac}'
    class_weights = torch.tensor([1 / (1 - fg_frac), 1 / fg_frac])
    loss_fn = nn.CrossEntropyLoss(weight=class_weights).to(device)
else:
    loss_fn = nn.CrossEntropyLoss()

In [5]:
# Initial model/optimizer pair built by the project's `presentation` helpers.
model = presentation.create_model()
model = model.to(device)
optimizer = presentation.create_optimizer(model)

In [6]:
# Pool size, and acquisition size per active-learning round (10% of the pool, truncated).
N = len(ds_full)
N_step = int(0.1 * N)
N, N_step

(598, 59)

In [7]:
# Reproducible random split of the pool: the first N_step shuffled indices form the initial
# labeled set and the remainder are the unlabeled candidates.
np.random.seed(0)
idxs = list(range(N))
np.random.shuffle(idxs)
labeled_idxs = idxs[:N_step]
unlabeled_idxs = idxs[N_step:]

In [8]:
N_ROUNDS = 10    # number of active-learning rounds
num_epochs = 30  # epochs per round; every round trains a fresh model from scratch


def _train_round(labeled, round_idx):
    """Train a freshly initialised model on the `labeled` subset of `ds_full`.

    Prints train loss, test loss and test IoUs after every epoch and returns
    ``(model, optimizer, val_loss, val_IoUs)``, the metrics being those of the last epoch.
    """
    torch.manual_seed(round_idx)  # reproducible init + shuffling, distinct per round
    ds_train = AL.PartialDs(ds_full, labeled)
    dl_train = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, shuffle=True, drop_last=True)

    model = presentation.create_model().to(device)
    optimizer = presentation.create_optimizer(model)
    val_loss = val_IoUs = None  # stays None only if num_epochs == 0
    for epoch in tqdm(range(num_epochs), desc='Epoch'):
        trn_loss = utils.train(dl_train, loss_fn, model, optimizer)
        val_loss, val_IoUs = segmentation.evaluate(dl_test, loss_fn, model)
        print(epoch, f'{trn_loss:.4f}', f'{val_loss:.4f}', val_IoUs)
    print()
    return model, optimizer, val_loss, val_IoUs


def _rank_by_entropy(model, pool, n_select):
    """Split `pool` (indices into `ds_full`) into the `n_select` highest-entropy indices and the rest.

    NOTE(review): assumes AL.PartialDs yields samples in `pool` order and that
    AL.compute_entropies returns one score per sample — the original code relied on the same.
    """
    ds_pool = AL.PartialDs(ds_full, pool)
    dl_pool = torch.utils.data.DataLoader(ds_pool, batch_size=batch_size, shuffle=False, drop_last=False)
    model.eval()
    scores = []
    with torch.no_grad():
        for inp, _ in dl_pool:
            # Move to CPU so the ranking below never mixes devices.
            scores.append(AL.compute_entropies(model(inp.to(device))).cpu())
    order = torch.cat(scores).argsort(descending=True).tolist()
    ranked = [pool[i] for i in order]
    return ranked[:n_select], ranked[n_select:]


# Work on copies so that re-running this cell restarts from the initial split
# (labeled_idxs / unlabeled_idxs) instead of continuing from an already-enlarged labeled set.
labeled_pool, unlabeled_pool = list(labeled_idxs), list(unlabeled_idxs)
al_history = []  # one entry per round: labeled-set size and last-epoch test metrics

for n_iter in range(N_ROUNDS):
    print(f'>>>>>>>>>>>> Iter {n_iter}')
    # Release the previous round's model before building the next one, so two models never
    # coexist on the device; rebinding (not `del`) also works when no model exists yet.
    model = optimizer = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    model, optimizer, val_loss, val_IoUs = _train_round(labeled_pool, n_iter)
    al_history.append({'n_labeled': len(labeled_pool), 'val_loss': val_loss, 'val_IoUs': val_IoUs})

    if not unlabeled_pool:  # pool exhausted; torch.cat on no scores would fail
        break
    acquired, unlabeled_pool = _rank_by_entropy(model, unlabeled_pool, N_step)
    labeled_pool += acquired

>>>>>>>>>>>> Iter 0


Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.2718 0.1588 tensor([0.])
1 0.0182 0.0289 tensor([0.])
2 0.0138 0.0277 tensor([0.])
3 0.0114 0.0255 tensor([0.])
4 0.0099 0.0271 tensor([0.])
5 0.0089 0.0267 tensor([0.])
6 0.0081 0.0269 tensor([0.])
7 0.0080 0.0278 tensor([0.])
8 0.0074 0.0272 tensor([0.])
9 0.0074 0.0273 tensor([0.])
10 0.0072 0.0286 tensor([0.])
11 0.0071 0.0257 tensor([0.])
12 0.0070 0.0288 tensor([0.])
13 0.0070 0.0296 tensor([0.])
14 0.0068 0.0281 tensor([0.])
15 0.0068 0.0261 tensor([0.])
16 0.0067 0.0255 tensor([0.])
17 0.0064 0.0267 tensor([0.])
18 0.0064 0.0291 tensor([0.])
19 0.0063 0.0266 tensor([0.])
20 0.0062 0.0271 tensor([0.])
21 0.0060 0.0283 tensor([0.])
22 0.0061 0.0300 tensor([0.])
23 0.0060 0.0297 tensor([0.])
24 0.0058 0.0268 tensor([0.])
25 0.0058 0.0275 tensor([0.])
26 0.0057 0.0269 tensor([0.])
27 0.0056 0.0294 tensor([0.])
28 0.0056 0.0289 tensor([0.])
29 0.0056 0.0285 tensor([0.])

>>>>>>>>>>>> Iter 1


Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.0941 0.0310 tensor([0.])
1 0.0148 0.0203 tensor([0.])
2 0.0103 0.0210 tensor([0.])
3 0.0090 0.0229 tensor([0.])
4 0.0082 0.0235 tensor([0.])
5 0.0077 0.0231 tensor([0.])
6 0.0075 0.0228 tensor([0.])
7 0.0072 0.0203 tensor([0.])
8 0.0068 0.0226 tensor([0.])
9 0.0067 0.0232 tensor([0.])
10 0.0065 0.0232 tensor([0.])
11 0.0062 0.0241 tensor([0.])
12 0.0061 0.0180 tensor([0.])
13 0.0059 0.0242 tensor([0.])
14 0.0058 0.0244 tensor([0.])
15 0.0057 0.0235 tensor([0.])
16 0.0055 0.0252 tensor([0.1380])
17 0.0054 0.0212 tensor([0.1814])
18 0.0052 0.0240 tensor([0.1838])
19 0.0052 0.0242 tensor([0.2184])
20 0.0050 0.0279 tensor([0.1946])
21 0.0049 0.0276 tensor([0.2058])
22 0.0050 0.0196 tensor([0.3214])
23 0.0048 0.0288 tensor([0.1957])
24 0.0048 0.0270 tensor([0.2479])
25 0.0046 0.0292 tensor([0.2335])
26 0.0045 0.0282 tensor([0.2466])
27 0.0044 0.0262 tensor([0.2773])
28 0.0043 0.0281 tensor([0.2538])
29 0.0043 0.0292 tensor([0.2630])

>>>>>>>>>>>> Iter 2


Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.0391 0.0293 tensor([2.2148e-05])
1 0.0125 0.0257 tensor([0.0002])
2 0.0091 0.0271 tensor([8.0074e-05])
3 0.0079 0.0281 tensor([0.0002])
4 0.0072 0.0274 tensor([0.0004])
5 0.0067 0.0297 tensor([0.0014])
6 0.0064 0.0286 tensor([0.0050])
7 0.0061 0.0292 tensor([0.0182])
8 0.0060 0.0308 tensor([0.1692])
9 0.0057 0.0296 tensor([0.2072])
10 0.0054 0.0313 tensor([0.2064])
11 0.0052 0.0303 tensor([0.2629])
12 0.0050 0.0320 tensor([0.2518])
13 0.0048 0.0319 tensor([0.2652])
14 0.0047 0.0315 tensor([0.2716])
15 0.0045 0.0324 tensor([0.2769])
16 0.0044 0.0326 tensor([0.2756])
17 0.0043 0.0335 tensor([0.2801])
18 0.0043 0.0337 tensor([0.2891])
19 0.0042 0.0325 tensor([0.2946])
20 0.0040 0.0330 tensor([0.2907])
21 0.0039 0.0323 tensor([0.2961])
22 0.0038 0.0329 tensor([0.3026])
23 0.0037 0.0332 tensor([0.3036])
24 0.0036 0.0330 tensor([0.3051])
25 0.0035 0.0325 tensor([0.3033])
26 0.0035 0.0336 tensor([0.3061])
27 0.0034 0.0339 tensor([0.3103])
28 0.0033 0.0335 tensor([0.3118])
29 0.0033 0.0328

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.0782 0.0217 tensor([0.])
1 0.0102 0.0243 tensor([0.])
2 0.0088 0.0245 tensor([0.])
3 0.0081 0.0241 tensor([0.])
4 0.0075 0.0268 tensor([0.])
5 0.0071 0.0282 tensor([0.])
6 0.0069 0.0245 tensor([0.])
7 0.0064 0.0267 tensor([0.])
8 0.0061 0.0254 tensor([0.])
9 0.0058 0.0242 tensor([0.])
10 0.0056 0.0280 tensor([0.2072])
11 0.0054 0.0176 tensor([0.3240])
12 0.0052 0.0273 tensor([0.2683])
13 0.0050 0.0269 tensor([0.2695])
14 0.0047 0.0255 tensor([0.2831])
15 0.0046 0.0274 tensor([0.2830])
16 0.0044 0.0290 tensor([0.2763])
17 0.0043 0.0261 tensor([0.2950])
18 0.0041 0.0252 tensor([0.3121])
19 0.0040 0.0259 tensor([0.3081])
20 0.0039 0.0259 tensor([0.3090])
21 0.0037 0.0278 tensor([0.3073])
22 0.0036 0.0266 tensor([0.3143])
23 0.0035 0.0283 tensor([0.3065])
24 0.0034 0.0270 tensor([0.3134])
25 0.0034 0.0233 tensor([0.3389])
26 0.0033 0.0273 tensor([0.3177])
27 0.0033 0.0260 tensor([0.3318])
28 0.0031 0.0257 tensor([0.3302])
29 0.0030 0.0251 tensor([0.3335])

>>>>>>>>>>>> Iter 4


Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.0396 0.0253 tensor([0.])
1 0.0101 0.0261 tensor([0.])
2 0.0081 0.0280 tensor([0.])
3 0.0073 0.0315 tensor([0.])
4 0.0067 0.0301 tensor([0.])
5 0.0063 0.0298 tensor([0.])
6 0.0059 0.0287 tensor([0.0577])
7 0.0056 0.0258 tensor([0.1103])
8 0.0052 0.0307 tensor([0.0969])
9 0.0050 0.0294 tensor([0.1370])
10 0.0049 0.0313 tensor([0.1422])
11 0.0046 0.0311 tensor([0.1547])
12 0.0044 0.0301 tensor([0.1892])
13 0.0042 0.0298 tensor([0.1890])
14 0.0040 0.0323 tensor([0.1654])
15 0.0040 0.0313 tensor([0.1784])
16 0.0038 0.0314 tensor([0.1910])
17 0.0037 0.0299 tensor([0.2326])
18 0.0035 0.0327 tensor([0.1914])
19 0.0034 0.0300 tensor([0.2470])
20 0.0034 0.0317 tensor([0.2160])
21 0.0032 0.0294 tensor([0.2523])
22 0.0031 0.0298 tensor([0.2545])
23 0.0030 0.0309 tensor([0.2466])
24 0.0030 0.0295 tensor([0.2851])
25 0.0028 0.0331 tensor([0.2419])
26 0.0035 0.0277 tensor([0.3030])
27 0.0028 0.0285 tensor([0.3025])
28 0.0027 0.0280 tensor([0.3124])
29 0.0026 0.0299 tensor([0.2905])

>>>>>>>>>>>> 

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.0440 0.0229 tensor([0.])
1 0.0086 0.0212 tensor([0.])
2 0.0074 0.0216 tensor([0.])
3 0.0067 0.0224 tensor([0.])
4 0.0062 0.0213 tensor([0.])
5 0.0057 0.0237 tensor([0.])
6 0.0054 0.0204 tensor([0.2712])
7 0.0051 0.0232 tensor([0.2746])
8 0.0048 0.0231 tensor([0.3003])
9 0.0046 0.0192 tensor([0.3702])
10 0.0043 0.0217 tensor([0.3546])
11 0.0041 0.0228 tensor([0.3437])
12 0.0039 0.0196 tensor([0.3873])
13 0.0038 0.0179 tensor([0.4217])
14 0.0036 0.0219 tensor([0.3682])
15 0.0034 0.0213 tensor([0.3831])
16 0.0033 0.0203 tensor([0.4057])
17 0.0031 0.0212 tensor([0.3979])
18 0.0031 0.0195 tensor([0.4213])
19 0.0030 0.0196 tensor([0.4254])
20 0.0029 0.0207 tensor([0.4146])
21 0.0027 0.0218 tensor([0.4099])
22 0.0027 0.0215 tensor([0.4178])
23 0.0026 0.0221 tensor([0.4199])
24 0.0026 0.0237 tensor([0.3987])
25 0.0025 0.0192 tensor([0.4480])
26 0.0024 0.0199 tensor([0.4456])
27 0.0023 0.0220 tensor([0.4266])
28 0.0023 0.0215 tensor([0.4327])
29 0.0022 0.0212 tensor([0.4475])

>>>>>>>>>>>> 

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.0354 0.0217 tensor([0.])
1 0.0083 0.0226 tensor([0.])
2 0.0071 0.0249 tensor([0.])
3 0.0065 0.0232 tensor([0.])
4 0.0059 0.0234 tensor([0.])
5 0.0054 0.0242 tensor([0.2724])
6 0.0051 0.0251 tensor([0.2964])
7 0.0048 0.0246 tensor([0.3183])
8 0.0046 0.0240 tensor([0.3212])
9 0.0043 0.0261 tensor([0.3242])
10 0.0041 0.0252 tensor([0.3326])
11 0.0039 0.0224 tensor([0.3560])
12 0.0037 0.0204 tensor([0.3773])
13 0.0035 0.0249 tensor([0.3451])
14 0.0034 0.0247 tensor([0.3516])
15 0.0032 0.0238 tensor([0.3597])
16 0.0031 0.0253 tensor([0.3548])
17 0.0030 0.0229 tensor([0.3773])
18 0.0029 0.0218 tensor([0.3788])
19 0.0028 0.0248 tensor([0.3698])
20 0.0027 0.0242 tensor([0.3758])
21 0.0026 0.0251 tensor([0.3700])
22 0.0025 0.0251 tensor([0.3761])
23 0.0024 0.0235 tensor([0.3817])
24 0.0024 0.0254 tensor([0.3817])
25 0.0023 0.0281 tensor([0.3555])
26 0.0022 0.0235 tensor([0.3901])
27 0.0022 0.0237 tensor([0.3973])
28 0.0021 0.0239 tensor([0.3980])
29 0.0021 0.0247 tensor([0.3950])

>>>>>>>>>

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.0304 0.0206 tensor([0.])
1 0.0078 0.0247 tensor([0.])
2 0.0067 0.0230 tensor([0.])
3 0.0059 0.0192 tensor([0.])
4 0.0054 0.0218 tensor([0.2243])
5 0.0050 0.0225 tensor([0.2773])
6 0.0046 0.0230 tensor([0.2957])
7 0.0043 0.0221 tensor([0.3247])
8 0.0041 0.0218 tensor([0.3336])
9 0.0038 0.0208 tensor([0.3657])
10 0.0037 0.0207 tensor([0.3869])
11 0.0034 0.0190 tensor([0.4153])
12 0.0032 0.0209 tensor([0.3903])
13 0.0031 0.0173 tensor([0.4407])
14 0.0030 0.0206 tensor([0.4049])
15 0.0028 0.0199 tensor([0.4159])
16 0.0027 0.0212 tensor([0.4059])
17 0.0026 0.0196 tensor([0.4295])
18 0.0025 0.0197 tensor([0.4278])
19 0.0024 0.0186 tensor([0.4482])
20 0.0023 0.0196 tensor([0.4394])
21 0.0023 0.0163 tensor([0.4848])
22 0.0022 0.0192 tensor([0.4543])
23 0.0021 0.0196 tensor([0.4501])
24 0.0021 0.0181 tensor([0.4679])
25 0.0020 0.0183 tensor([0.4744])
26 0.0020 0.0182 tensor([0.4788])
27 0.0019 0.0189 tensor([0.4687])
28 0.0019 0.0188 tensor([0.4782])
29 0.0018 0.0191 tensor([0.4738])

>>>>>

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.0304 0.0271 tensor([0.])
1 0.0074 0.0272 tensor([0.])
2 0.0064 0.0246 tensor([0.])
3 0.0057 0.0253 tensor([0.])
4 0.0052 0.0272 tensor([0.2015])
5 0.0048 0.0273 tensor([0.2556])
6 0.0044 0.0247 tensor([0.3117])
7 0.0041 0.0259 tensor([0.3157])
8 0.0039 0.0242 tensor([0.3398])
9 0.0036 0.0252 tensor([0.3548])
10 0.0034 0.0246 tensor([0.3722])
11 0.0033 0.0224 tensor([0.4061])
12 0.0031 0.0279 tensor([0.3506])
13 0.0030 0.0237 tensor([0.4026])
14 0.0028 0.0247 tensor([0.3965])
15 0.0027 0.0239 tensor([0.4138])
16 0.0025 0.0240 tensor([0.4092])
17 0.0025 0.0233 tensor([0.4228])
18 0.0024 0.0267 tensor([0.3901])
19 0.0024 0.0231 tensor([0.4249])
20 0.0024 0.0242 tensor([0.4160])
21 0.0022 0.0242 tensor([0.4171])
22 0.0021 0.0246 tensor([0.4202])
23 0.0020 0.0227 tensor([0.4396])
24 0.0020 0.0230 tensor([0.4421])
25 0.0019 0.0231 tensor([0.4449])
26 0.0019 0.0237 tensor([0.4312])
27 0.0018 0.0230 tensor([0.4477])
28 0.0018 0.0238 tensor([0.4376])
29 0.0018 0.0225 tensor([0.4473])

>>>>>

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.0202 0.0174 tensor([8.2320e-05])
1 0.0069 0.0202 tensor([8.0599e-05])
2 0.0058 0.0237 tensor([0.0027])
3 0.0052 0.0205 tensor([0.2344])
4 0.0046 0.0205 tensor([0.2941])
5 0.0043 0.0216 tensor([0.3072])
6 0.0040 0.0212 tensor([0.3334])
7 0.0037 0.0210 tensor([0.3530])
8 0.0035 0.0210 tensor([0.3717])
9 0.0033 0.0229 tensor([0.3516])
10 0.0031 0.0209 tensor([0.3908])
11 0.0029 0.0205 tensor([0.4163])
12 0.0028 0.0187 tensor([0.4497])
13 0.0026 0.0202 tensor([0.4277])
14 0.0025 0.0200 tensor([0.4369])
15 0.0021 0.0189 tensor([0.4483])
16 0.0016 0.0228 tensor([0.3996])
17 0.0015 0.0191 tensor([0.4535])
18 0.0013 0.0199 tensor([0.4535])
19 0.0013 0.0190 tensor([0.4636])
20 0.0013 0.0187 tensor([0.4698])
21 0.0012 0.0177 tensor([0.4853])
22 0.0011 0.0181 tensor([0.4808])
23 0.0011 0.0180 tensor([0.4893])
24 0.0010 0.0163 tensor([0.5125])
25 0.0010 0.0179 tensor([0.4968])
26 0.0010 0.0183 tensor([0.4951])
27 0.0010 0.0159 tensor([0.5280])
28 0.0010 0.0188 tensor([0.4912])
29 0.0009 0.0171