In [1]:
import gc
import numpy as np
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
import utils, segmentation, AL, presentation
from tqdm.auto import tqdm

In [2]:
device = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')
# device = torch.device('cpu')
print(device)

cuda:1


In [3]:
ds_full = segmentation.SegmentationDataset('./cut_stenoses_data/train/', 2)
ds_test = segmentation.SegmentationDataset('./cut_stenoses_data/test/', 2)

batch_size = 8
dl_full = torch.utils.data.DataLoader(ds_full, batch_size=batch_size, shuffle=False, drop_last=False)
dl_test = torch.utils.data.DataLoader(ds_test, batch_size=batch_size, shuffle=False, drop_last=False)

In [4]:
# w = 0.
# for imgs, masks in dl_full:
#     w += masks.sum() / (masks.numel() * len(dl_full))
# loss_fn = nn.CrossEntropyLoss(weight=torch.Tensor([1 / (1-w), 1/ w])).to(device)
loss_fn = nn.CrossEntropyLoss()

In [5]:
model = presentation.create_model().to(device)
optimizer = presentation.create_optimizer(model)

In [6]:
N = len(ds_full)
N_step = int(N*0.1)
N, N_step

(598, 59)

In [7]:
idxs = list(range(N))
np.random.seed(0)
np.random.shuffle(idxs)
labeled_idxs, unlabeled_idxs = idxs[:N_step], idxs[N_step:]

In [8]:
for n_iter in range(10):
    print(f'>>>>>>>>>>>> Iter {n_iter}')
    del model
    del optimizer
    gc.collect()
    ds_train = AL.PartialDs(ds_full, labeled_idxs)
    dl_train = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, shuffle=True, drop_last=True)

    model = presentation.create_model().to(device)
    optimizer = presentation.create_optimizer(model)

    num_epochs = 30
    for epoch in tqdm(range(num_epochs), desc='Epoch'):
        trn_loss = utils.train(dl_train, loss_fn, model, optimizer)
        val_loss, val_IoUs = segmentation.evaluate(dl_test, loss_fn, model)
        print(epoch, f'{trn_loss:.4f}', f'{val_loss:.4f}', val_IoUs)
    print()

    ds_unlabeled = AL.PartialDs(ds_full, unlabeled_idxs)
    dl_unlabeled = torch.utils.data.DataLoader(ds_unlabeled, batch_size=batch_size, shuffle=False, drop_last=False)
    model.eval()    
    entropies = []
    for inp, tgt in dl_unlabeled:
        with torch.no_grad():
            out = model(inp.to(device))
        entropies.append(AL.compute_entropies(out))
    entropies = torch.cat(entropies)
    order = entropies.sort(descending=False).indices
    labeled_idxs += torch.tensor(unlabeled_idxs)[order[:N_step]].tolist()
    unlabeled_idxs = torch.tensor(unlabeled_idxs)[order[N_step:]].tolist()

>>>>>>>>>>>> Iter 0


Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.1729 0.0454 tensor([0.])
1 0.0233 0.0320 tensor([0.])
2 0.0156 0.0268 tensor([0.])
3 0.0118 0.0240 tensor([0.])
4 0.0099 0.0238 tensor([0.])
5 0.0095 0.0220 tensor([0.])
6 0.0090 0.0234 tensor([0.])
7 0.0083 0.0251 tensor([0.])
8 0.0076 0.0253 tensor([0.])
9 0.0073 0.0249 tensor([0.])
10 0.0072 0.0264 tensor([0.])
11 0.0072 0.0254 tensor([0.])
12 0.0069 0.0251 tensor([0.])
13 0.0065 0.0247 tensor([0.])
14 0.0065 0.0244 tensor([0.])
15 0.0063 0.0249 tensor([0.])
16 0.0063 0.0247 tensor([0.])
17 0.0061 0.0263 tensor([0.])
18 0.0059 0.0270 tensor([0.])
19 0.0060 0.0276 tensor([0.])
20 0.0059 0.0263 tensor([0.])
21 0.0059 0.0254 tensor([0.])
22 0.0058 0.0267 tensor([0.])
23 0.0057 0.0245 tensor([0.])
24 0.0056 0.0289 tensor([0.])
25 0.0055 0.0253 tensor([0.])
26 0.0055 0.0278 tensor([0.])
27 0.0057 0.0266 tensor([0.])
28 0.0054 0.0252 tensor([0.])
29 0.0054 0.0270 tensor([0.])

>>>>>>>>>>>> Iter 1


Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.1108 0.0346 tensor([0.])
1 0.0143 0.0214 tensor([0.])
2 0.0107 0.0210 tensor([0.])
3 0.0088 0.0220 tensor([0.])
4 0.0080 0.0216 tensor([0.])
5 0.0071 0.0225 tensor([0.])
6 0.0069 0.0215 tensor([0.])
7 0.0064 0.0231 tensor([0.])
8 0.0063 0.0239 tensor([0.])
9 0.0062 0.0209 tensor([0.])
10 0.0061 0.0218 tensor([0.])
11 0.0059 0.0228 tensor([0.])
12 0.0056 0.0230 tensor([0.])
13 0.0054 0.0223 tensor([0.])
14 0.0054 0.0227 tensor([0.])
15 0.0053 0.0225 tensor([0.])
16 0.0052 0.0216 tensor([0.])
17 0.0050 0.0223 tensor([0.])
18 0.0049 0.0214 tensor([0.])
19 0.0048 0.0217 tensor([0.])
20 0.0047 0.0201 tensor([0.])
21 0.0047 0.0223 tensor([2.3981e-05])
22 0.0046 0.0194 tensor([0.0011])
23 0.0045 0.0191 tensor([0.3019])
24 0.0044 0.0202 tensor([0.2979])
25 0.0043 0.0205 tensor([0.3114])
26 0.0042 0.0207 tensor([0.3061])
27 0.0042 0.0208 tensor([0.3223])
28 0.0041 0.0206 tensor([0.3101])
29 0.0042 0.0202 tensor([0.3324])

>>>>>>>>>>>> Iter 2


Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.0856 0.0271 tensor([0.])
1 0.0116 0.0261 tensor([0.])
2 0.0087 0.0254 tensor([0.])
3 0.0075 0.0271 tensor([0.])
4 0.0066 0.0272 tensor([0.])
5 0.0062 0.0265 tensor([0.])
6 0.0058 0.0277 tensor([0.])
7 0.0055 0.0284 tensor([0.])
8 0.0054 0.0245 tensor([0.])
9 0.0052 0.0273 tensor([0.])
10 0.0050 0.0284 tensor([0.])
11 0.0048 0.0279 tensor([0.])
12 0.0047 0.0243 tensor([0.])
13 0.0045 0.0279 tensor([0.])
14 0.0044 0.0268 tensor([0.])
15 0.0043 0.0281 tensor([0.])
16 0.0042 0.0271 tensor([0.])
17 0.0041 0.0275 tensor([0.2080])
18 0.0041 0.0255 tensor([0.2551])
19 0.0039 0.0235 tensor([0.2960])
20 0.0038 0.0266 tensor([0.2643])
21 0.0038 0.0256 tensor([0.2891])
22 0.0036 0.0267 tensor([0.2911])
23 0.0035 0.0257 tensor([0.3020])
24 0.0034 0.0250 tensor([0.3195])
25 0.0034 0.0254 tensor([0.3188])
26 0.0033 0.0260 tensor([0.3188])
27 0.0032 0.0256 tensor([0.3225])
28 0.0032 0.0256 tensor([0.3258])
29 0.0031 0.0245 tensor([0.3397])

>>>>>>>>>>>> Iter 3


Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.0416 0.0246 tensor([0.])
1 0.0097 0.0272 tensor([0.])
2 0.0071 0.0249 tensor([0.])
3 0.0063 0.0243 tensor([0.])
4 0.0057 0.0236 tensor([0.])
5 0.0052 0.0236 tensor([0.])
6 0.0050 0.0242 tensor([0.])
7 0.0048 0.0253 tensor([0.])
8 0.0046 0.0252 tensor([0.])
9 0.0044 0.0249 tensor([3.5802e-05])
10 0.0042 0.0254 tensor([0.0011])
11 0.0042 0.0235 tensor([0.1899])
12 0.0040 0.0266 tensor([0.2172])
13 0.0039 0.0221 tensor([0.2733])
14 0.0038 0.0240 tensor([0.2756])
15 0.0036 0.0250 tensor([0.2752])
16 0.0035 0.0253 tensor([0.2899])
17 0.0034 0.0260 tensor([0.2798])
18 0.0034 0.0267 tensor([0.2857])
19 0.0033 0.0249 tensor([0.3126])
20 0.0032 0.0261 tensor([0.3015])
21 0.0032 0.0240 tensor([0.3284])
22 0.0030 0.0233 tensor([0.3374])
23 0.0029 0.0255 tensor([0.3242])
24 0.0029 0.0242 tensor([0.3447])
25 0.0028 0.0217 tensor([0.3707])
26 0.0028 0.0245 tensor([0.3539])
27 0.0027 0.0230 tensor([0.3658])
28 0.0026 0.0209 tensor([0.3893])
29 0.0026 0.0233 tensor([0.3707])

>>>>>>>>>>>> Iter 4


Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.0515 0.0224 tensor([0.])
1 0.0079 0.0237 tensor([0.])
2 0.0065 0.0233 tensor([0.])
3 0.0058 0.0233 tensor([0.])
4 0.0053 0.0206 tensor([0.])
5 0.0051 0.0201 tensor([0.])
6 0.0048 0.0200 tensor([0.])
7 0.0046 0.0178 tensor([0.])
8 0.0044 0.0215 tensor([0.])
9 0.0043 0.0202 tensor([0.0007])
10 0.0040 0.0214 tensor([0.2591])
11 0.0038 0.0207 tensor([0.3080])
12 0.0037 0.0215 tensor([0.3153])
13 0.0036 0.0206 tensor([0.3358])
14 0.0035 0.0216 tensor([0.3429])
15 0.0034 0.0200 tensor([0.3736])
16 0.0032 0.0225 tensor([0.3462])
17 0.0031 0.0235 tensor([0.3427])
18 0.0030 0.0216 tensor([0.3691])
19 0.0029 0.0202 tensor([0.4032])
20 0.0029 0.0212 tensor([0.3930])
21 0.0028 0.0201 tensor([0.4069])
22 0.0027 0.0204 tensor([0.4023])
23 0.0026 0.0216 tensor([0.3978])
24 0.0026 0.0225 tensor([0.3916])
25 0.0025 0.0219 tensor([0.3995])
26 0.0025 0.0209 tensor([0.4132])
27 0.0024 0.0200 tensor([0.4262])
28 0.0023 0.0203 tensor([0.4312])
29 0.0022 0.0196 tensor([0.4396])

>>>>>>>>>>>> Iter 5


Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.0305 0.0252 tensor([0.])
1 0.0088 0.0231 tensor([0.])
2 0.0067 0.0204 tensor([0.])
3 0.0057 0.0209 tensor([0.])
4 0.0052 0.0217 tensor([0.])
5 0.0048 0.0198 tensor([0.])
6 0.0048 0.0188 tensor([0.])
7 0.0045 0.0202 tensor([0.2282])
8 0.0042 0.0179 tensor([0.2959])
9 0.0039 0.0192 tensor([0.3278])
10 0.0037 0.0218 tensor([0.2879])
11 0.0036 0.0212 tensor([0.3212])
12 0.0035 0.0208 tensor([0.3361])
13 0.0033 0.0227 tensor([0.3290])
14 0.0032 0.0210 tensor([0.3518])
15 0.0031 0.0224 tensor([0.3406])
16 0.0030 0.0190 tensor([0.4022])
17 0.0029 0.0196 tensor([0.3923])
18 0.0028 0.0214 tensor([0.3786])
19 0.0027 0.0232 tensor([0.3626])
20 0.0026 0.0237 tensor([0.3617])
21 0.0026 0.0207 tensor([0.4010])
22 0.0024 0.0204 tensor([0.4153])
23 0.0024 0.0208 tensor([0.4196])
24 0.0024 0.0221 tensor([0.4072])
25 0.0023 0.0193 tensor([0.4490])
26 0.0022 0.0199 tensor([0.4437])
27 0.0021 0.0213 tensor([0.4303])
28 0.0021 0.0198 tensor([0.4536])
29 0.0020 0.0190 tensor([0.4722])

>>>>>>>>>>>> Iter

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.0607 0.0236 tensor([0.])
1 0.0075 0.0223 tensor([0.])
2 0.0064 0.0230 tensor([0.])
3 0.0058 0.0216 tensor([0.])
4 0.0053 0.0223 tensor([0.])
5 0.0050 0.0181 tensor([0.])
6 0.0047 0.0187 tensor([0.])
7 0.0044 0.0205 tensor([0.])
8 0.0042 0.0197 tensor([0.2589])
9 0.0040 0.0196 tensor([0.3140])
10 0.0037 0.0194 tensor([0.3354])
11 0.0036 0.0209 tensor([0.3408])
12 0.0034 0.0194 tensor([0.3752])
13 0.0032 0.0177 tensor([0.4151])
14 0.0031 0.0185 tensor([0.4228])
15 0.0030 0.0184 tensor([0.4237])
16 0.0029 0.0186 tensor([0.4304])
17 0.0028 0.0148 tensor([0.5062])
18 0.0027 0.0180 tensor([0.4506])
19 0.0026 0.0183 tensor([0.4511])
20 0.0025 0.0189 tensor([0.4583])
21 0.0024 0.0186 tensor([0.4637])
22 0.0023 0.0175 tensor([0.4757])
23 0.0022 0.0161 tensor([0.5130])
24 0.0022 0.0173 tensor([0.4841])
25 0.0021 0.0176 tensor([0.4904])
26 0.0021 0.0178 tensor([0.4909])
27 0.0020 0.0180 tensor([0.4802])
28 0.0019 0.0169 tensor([0.5096])
29 0.0019 0.0174 tensor([0.5030])

>>>>>>>>>>>> Iter 7


Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.0267 0.0170 tensor([0.])
1 0.0070 0.0140 tensor([0.])
2 0.0057 0.0153 tensor([0.])
3 0.0051 0.0136 tensor([0.0002])
4 0.0047 0.0130 tensor([0.3560])
5 0.0044 0.0124 tensor([0.4550])
6 0.0041 0.0129 tensor([0.4719])
7 0.0038 0.0121 tensor([0.5211])
8 0.0037 0.0113 tensor([0.5549])
9 0.0035 0.0117 tensor([0.5438])
10 0.0033 0.0111 tensor([0.5591])
11 0.0031 0.0099 tensor([0.5957])
12 0.0030 0.0117 tensor([0.5512])
13 0.0028 0.0113 tensor([0.5708])
14 0.0027 0.0098 tensor([0.6020])
15 0.0026 0.0113 tensor([0.5727])
16 0.0025 0.0107 tensor([0.5858])
17 0.0024 0.0111 tensor([0.5825])
18 0.0023 0.0115 tensor([0.5768])
19 0.0022 0.0114 tensor([0.5842])
20 0.0022 0.0107 tensor([0.5983])
21 0.0021 0.0107 tensor([0.6010])
22 0.0020 0.0111 tensor([0.6006])
23 0.0020 0.0113 tensor([0.5951])
24 0.0019 0.0112 tensor([0.6015])
25 0.0019 0.0119 tensor([0.5897])
26 0.0018 0.0109 tensor([0.6068])
27 0.0018 0.0101 tensor([0.6179])
28 0.0017 0.0107 tensor([0.6111])
29 0.0017 0.0107 tensor([0.6145])

>

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.0275 0.0202 tensor([0.])
1 0.0068 0.0209 tensor([0.])
2 0.0058 0.0210 tensor([0.])
3 0.0053 0.0210 tensor([0.])
4 0.0048 0.0205 tensor([0.2618])
5 0.0044 0.0203 tensor([0.3399])
6 0.0041 0.0195 tensor([0.3819])
7 0.0038 0.0198 tensor([0.4035])
8 0.0036 0.0137 tensor([0.4964])
9 0.0035 0.0160 tensor([0.4635])
10 0.0032 0.0162 tensor([0.4821])
11 0.0031 0.0160 tensor([0.4867])
12 0.0029 0.0180 tensor([0.4686])
13 0.0027 0.0175 tensor([0.4790])
14 0.0026 0.0184 tensor([0.4722])
15 0.0025 0.0197 tensor([0.4641])
16 0.0024 0.0202 tensor([0.4628])
17 0.0024 0.0188 tensor([0.4919])
18 0.0023 0.0195 tensor([0.4823])
19 0.0022 0.0173 tensor([0.4998])
20 0.0021 0.0184 tensor([0.4947])
21 0.0021 0.0181 tensor([0.5058])
22 0.0020 0.0181 tensor([0.5148])
23 0.0019 0.0191 tensor([0.5114])
24 0.0018 0.0159 tensor([0.5422])
25 0.0018 0.0183 tensor([0.5257])
26 0.0017 0.0185 tensor([0.5234])
27 0.0017 0.0173 tensor([0.5322])
28 0.0017 0.0182 tensor([0.5316])
29 0.0016 0.0175 tensor([0.5432])

>>>>>

Epoch:   0%|          | 0/30 [00:00<?, ?it/s]

0 0.0208 0.0238 tensor([0.])
1 0.0069 0.0218 tensor([0.])
2 0.0058 0.0256 tensor([0.0002])
3 0.0052 0.0238 tensor([0.2376])
4 0.0047 0.0217 tensor([0.3202])
5 0.0043 0.0218 tensor([0.3272])
6 0.0040 0.0214 tensor([0.3508])
7 0.0037 0.0206 tensor([0.3731])
8 0.0035 0.0199 tensor([0.3954])
9 0.0033 0.0207 tensor([0.4039])
10 0.0031 0.0183 tensor([0.4272])
11 0.0029 0.0197 tensor([0.4187])
12 0.0028 0.0224 tensor([0.3972])
13 0.0026 0.0252 tensor([0.3772])
14 0.0025 0.0197 tensor([0.4314])
15 0.0024 0.0214 tensor([0.4147])
16 0.0023 0.0223 tensor([0.4218])
17 0.0022 0.0233 tensor([0.4144])
18 0.0021 0.0233 tensor([0.4230])
19 0.0021 0.0234 tensor([0.4168])
20 0.0020 0.0213 tensor([0.4417])
21 0.0019 0.0220 tensor([0.4379])
22 0.0019 0.0219 tensor([0.4433])
23 0.0018 0.0233 tensor([0.4364])
24 0.0018 0.0205 tensor([0.4657])
25 0.0017 0.0218 tensor([0.4561])
26 0.0017 0.0223 tensor([0.4478])
27 0.0016 0.0216 tensor([0.4602])
28 0.0016 0.0223 tensor([0.4568])
29 0.0016 0.0213 tensor([0.4697]