In [1]:
from dataset import CT2US
from models.unet import UNet
from engine import trainer

import torch
from torch.utils.data import DataLoader
from torchvision import transforms

import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device: ", device)


LEARNING_RATE = 1e-3
BATCH_SIZE = 64
NUM_EPOCHS = 40

IMAGE_SIZE = 128
THRESHOLD = 0.5

Using device:  cuda


In [3]:
# Load dataset

image_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(IMAGE_SIZE),
    transforms.Normalize(mean=[0.5], std=[0.5])
])
mask_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(IMAGE_SIZE),
])

dataset = CT2US(root="datasets/CT2US", image_transform=image_transform, mask_transform=mask_transform)

train_set, test_set = train_test_split(dataset, test_size=0.2, random_state=42)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)



In [4]:
model = UNet(outSize=(IMAGE_SIZE, IMAGE_SIZE)).to(device)

In [5]:
loss = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

In [7]:
trainer(
    model=model,
    train_loader=train_loader,
    val_loader=test_loader,
    loss_fn=loss,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    device=device,
    epochs=NUM_EPOCHS,
)

Epoch 1:


100%|██████████| 58/58 [00:06<00:00,  9.20it/s]


Train Loss: 0.4575
Learning rate: 0.00100


100%|██████████| 15/15 [00:01<00:00, 14.04it/s]


val Loss: 0.3854

Epoch 2:


100%|██████████| 58/58 [00:04<00:00, 12.22it/s]


Train Loss: 0.3657
Learning rate: 0.00100


100%|██████████| 15/15 [00:01<00:00, 13.69it/s]


val Loss: 0.3436

Epoch 3:


100%|██████████| 58/58 [00:04<00:00, 12.25it/s]


Train Loss: 0.3292
Learning rate: 0.00099


100%|██████████| 15/15 [00:01<00:00, 14.65it/s]


val Loss: 0.3106

Epoch 4:


100%|██████████| 58/58 [00:04<00:00, 12.00it/s]


Train Loss: 0.3118
Learning rate: 0.00099


100%|██████████| 15/15 [00:01<00:00, 13.70it/s]


val Loss: 0.3247

Epoch 5:


100%|██████████| 58/58 [00:04<00:00, 12.26it/s]


Train Loss: 0.2907
Learning rate: 0.00098


100%|██████████| 15/15 [00:01<00:00, 14.27it/s]


val Loss: 0.2652

Epoch 6:


100%|██████████| 58/58 [00:04<00:00, 12.03it/s]


Train Loss: 0.2554
Learning rate: 0.00096


100%|██████████| 15/15 [00:01<00:00, 13.94it/s]


val Loss: 0.2484

Epoch 7:


100%|██████████| 58/58 [00:04<00:00, 12.03it/s]


Train Loss: 0.2306
Learning rate: 0.00095


100%|██████████| 15/15 [00:01<00:00, 13.58it/s]


val Loss: 0.2263

Epoch 8:


100%|██████████| 58/58 [00:04<00:00, 12.06it/s]


Train Loss: 0.2118
Learning rate: 0.00093


100%|██████████| 15/15 [00:01<00:00, 13.56it/s]


val Loss: 0.2028

Epoch 9:


100%|██████████| 58/58 [00:04<00:00, 11.91it/s]


Train Loss: 0.1971
Learning rate: 0.00090


100%|██████████| 15/15 [00:01<00:00, 12.99it/s]


val Loss: 0.1905

Epoch 10:


100%|██████████| 58/58 [00:04<00:00, 12.19it/s]


Train Loss: 0.1847
Learning rate: 0.00088


100%|██████████| 15/15 [00:01<00:00, 13.90it/s]


val Loss: 0.1885

Epoch 11:


100%|██████████| 58/58 [00:04<00:00, 12.24it/s]


Train Loss: 0.1785
Learning rate: 0.00085


100%|██████████| 15/15 [00:01<00:00, 13.81it/s]


val Loss: 0.1770

Epoch 12:


100%|██████████| 58/58 [00:04<00:00, 12.21it/s]


Train Loss: 0.1708
Learning rate: 0.00082


100%|██████████| 15/15 [00:01<00:00, 13.95it/s]


val Loss: 0.1685

Epoch 13:


100%|██████████| 58/58 [00:04<00:00, 12.17it/s]


Train Loss: 0.1632
Learning rate: 0.00079


100%|██████████| 15/15 [00:01<00:00, 12.72it/s]


val Loss: 0.1705

Epoch 14:


100%|██████████| 58/58 [00:04<00:00, 12.11it/s]


Train Loss: 0.1578
Learning rate: 0.00076


100%|██████████| 15/15 [00:01<00:00, 13.44it/s]


val Loss: 0.1607

Epoch 15:


100%|██████████| 58/58 [00:04<00:00, 12.02it/s]


Train Loss: 0.1532
Learning rate: 0.00073


100%|██████████| 15/15 [00:01<00:00, 14.16it/s]


val Loss: 0.1627

Epoch 16:


100%|██████████| 58/58 [00:04<00:00, 12.11it/s]


Train Loss: 0.1502
Learning rate: 0.00069


100%|██████████| 15/15 [00:01<00:00, 13.66it/s]


val Loss: 0.1495

Epoch 17:


100%|██████████| 58/58 [00:04<00:00, 12.20it/s]


Train Loss: 0.1466
Learning rate: 0.00065


100%|██████████| 15/15 [00:01<00:00, 13.31it/s]


val Loss: 0.1604

Epoch 18:


100%|██████████| 58/58 [00:04<00:00, 12.07it/s]


Train Loss: 0.1449
Learning rate: 0.00062


100%|██████████| 15/15 [00:01<00:00, 12.15it/s]


val Loss: 0.1445

Epoch 19:


100%|██████████| 58/58 [00:04<00:00, 12.12it/s]


Train Loss: 0.1393
Learning rate: 0.00058


100%|██████████| 15/15 [00:01<00:00, 13.75it/s]


val Loss: 0.1416

Epoch 20:


100%|██████████| 58/58 [00:04<00:00, 12.09it/s]


Train Loss: 0.1358
Learning rate: 0.00054


100%|██████████| 15/15 [00:01<00:00, 13.81it/s]


val Loss: 0.1492

Epoch 21:


100%|██████████| 58/58 [00:04<00:00, 11.97it/s]


Train Loss: 0.1352
Learning rate: 0.00050


100%|██████████| 15/15 [00:01<00:00, 14.16it/s]


val Loss: 0.1352

Epoch 22:


100%|██████████| 58/58 [00:04<00:00, 12.01it/s]


Train Loss: 0.1310
Learning rate: 0.00046


100%|██████████| 15/15 [00:01<00:00, 14.30it/s]


val Loss: 0.1355

Epoch 23:


100%|██████████| 58/58 [00:04<00:00, 12.34it/s]


Train Loss: 0.1293
Learning rate: 0.00042


100%|██████████| 15/15 [00:01<00:00, 13.32it/s]


val Loss: 0.1315

Epoch 24:


100%|██████████| 58/58 [00:04<00:00, 12.00it/s]


Train Loss: 0.1291
Learning rate: 0.00038


100%|██████████| 15/15 [00:01<00:00, 13.10it/s]


val Loss: 0.1313

Epoch 25:


100%|██████████| 58/58 [00:04<00:00, 11.85it/s]


Train Loss: 0.1273
Learning rate: 0.00035


100%|██████████| 15/15 [00:01<00:00, 13.49it/s]


val Loss: 0.1309

Epoch 26:


100%|██████████| 58/58 [00:04<00:00, 12.03it/s]


Train Loss: 0.1248
Learning rate: 0.00031


100%|██████████| 15/15 [00:01<00:00, 13.93it/s]


val Loss: 0.1305

Epoch 27:


100%|██████████| 58/58 [00:04<00:00, 12.07it/s]


Train Loss: 0.1236
Learning rate: 0.00027


100%|██████████| 15/15 [00:01<00:00, 12.45it/s]


val Loss: 0.1276

Epoch 28:


100%|██████████| 58/58 [00:04<00:00, 11.97it/s]


Train Loss: 0.1228
Learning rate: 0.00024


100%|██████████| 15/15 [00:01<00:00, 13.50it/s]


val Loss: 0.1272

Epoch 29:


100%|██████████| 58/58 [00:04<00:00, 12.06it/s]


Train Loss: 0.1209
Learning rate: 0.00021


100%|██████████| 15/15 [00:01<00:00, 13.16it/s]


val Loss: 0.1260

Epoch 30:


100%|██████████| 58/58 [00:04<00:00, 12.09it/s]


Train Loss: 0.1197
Learning rate: 0.00018


100%|██████████| 15/15 [00:01<00:00, 13.59it/s]


val Loss: 0.1262

Epoch 31:


100%|██████████| 58/58 [00:04<00:00, 11.96it/s]


Train Loss: 0.1192
Learning rate: 0.00015


100%|██████████| 15/15 [00:01<00:00, 13.62it/s]


val Loss: 0.1249

Epoch 32:


100%|██████████| 58/58 [00:04<00:00, 12.07it/s]


Train Loss: 0.1179
Learning rate: 0.00012


100%|██████████| 15/15 [00:01<00:00, 13.23it/s]


val Loss: 0.1250

Epoch 33:


100%|██████████| 58/58 [00:04<00:00, 11.91it/s]


Train Loss: 0.1176
Learning rate: 0.00010


100%|██████████| 15/15 [00:01<00:00, 13.88it/s]


val Loss: 0.1245

Epoch 34:


100%|██████████| 58/58 [00:04<00:00, 11.96it/s]


Train Loss: 0.1164
Learning rate: 0.00007


100%|██████████| 15/15 [00:01<00:00, 14.30it/s]


val Loss: 0.1248

Epoch 35:


100%|██████████| 58/58 [00:04<00:00, 11.80it/s]


Train Loss: 0.1167
Learning rate: 0.00005


100%|██████████| 15/15 [00:01<00:00, 13.97it/s]


val Loss: 0.1237

Epoch 36:


100%|██████████| 58/58 [00:04<00:00, 11.99it/s]


Train Loss: 0.1163
Learning rate: 0.00004


100%|██████████| 15/15 [00:01<00:00, 12.68it/s]


val Loss: 0.1242

Epoch 37:


100%|██████████| 58/58 [00:04<00:00, 12.18it/s]


Train Loss: 0.1155
Learning rate: 0.00002


100%|██████████| 15/15 [00:01<00:00, 13.40it/s]


val Loss: 0.1232

Epoch 38:


100%|██████████| 58/58 [00:04<00:00, 12.20it/s]


Train Loss: 0.1153
Learning rate: 0.00001


100%|██████████| 15/15 [00:01<00:00, 13.15it/s]


val Loss: 0.1232

Epoch 39:


100%|██████████| 58/58 [00:04<00:00, 12.20it/s]


Train Loss: 0.1154
Learning rate: 0.00001


100%|██████████| 15/15 [00:01<00:00, 13.88it/s]


val Loss: 0.1233

Epoch 40:


100%|██████████| 58/58 [00:04<00:00, 14.27it/s]


Train Loss: 0.1152
Learning rate: 0.00000


100%|██████████| 15/15 [00:01<00:00, 13.34it/s]

val Loss: 0.1232






{'train_loss': [0.4574529215179641,
  0.365660075483651,
  0.3292431368910033,
  0.3118077439480814,
  0.29067368368650304,
  0.25536975290240915,
  0.23060589771846246,
  0.21182438294435368,
  0.1971040535075911,
  0.1847405544128911,
  0.17848936391287837,
  0.17079083878418494,
  0.16315401576716324,
  0.15779059328909578,
  0.1531525845157689,
  0.15018921165630736,
  0.14658085770648102,
  0.1449247350723579,
  0.13930254365349637,
  0.13583225281587963,
  0.13518705424563637,
  0.1309901838158739,
  0.12930574525019217,
  0.1290703297689043,
  0.12729345236358972,
  0.12484996264864659,
  0.1235561402964181,
  0.1227778584278863,
  0.12090829875448654,
  0.11970421350721655,
  0.11923232597523722,
  0.11793183496799962,
  0.11761961106596322,
  0.11642057255938135,
  0.11666524538706088,
  0.11628232665103057,
  0.11551977360043032,
  0.11533481165252883,
  0.1154256048130578,
  0.1151610424035582],
 'train_acc': [],
 'val_loss': [0.38543855349222816,
  0.34358327388763427,
  0.