In [1]:
import torch

In [2]:
from util.data import get_dl, get_train_ds, get_val_ds
from util.training_utils import train, check_accuracy, save, load
from models.Attention3DUNet import Attention3UNet
from util.loss import DiceLoss, BCEDiceLoss

In [3]:
# Run configuration: the batch size given to the data loaders, the number of
# epochs given to train(), and the torch device (CUDA when torch reports it
# as available, otherwise CPU).
BATCH_SIZE, EPOCHS = 3, 100
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
# Learning rate passed to the Adam optimizer further down.
LR = 1e-4
# DEVICE = 'cpu'

In [5]:
# Wrap the training and validation datasets in loaders through the project's
# get_dl helper. Both use BATCH_SIZE and nw=2 (presumably the number of
# loader worker processes -- confirm in util.data).
train_dl = get_dl(
    get_train_ds(),
    BATCH_SIZE,
    nw=2,
)
val_dl = get_dl(
    get_val_ds(),
    BATCH_SIZE,
    nw=2,
)

images: 221, masks: 221 
images: 148, masks: 148 


In [6]:
# Release memory left over from earlier runs in this kernel before the model
# is built. Order matters: run the garbage collector first so that objects
# kept alive only by reference cycles are destroyed, and only then ask the
# CUDA caching allocator to hand its now-unused blocks back to the driver.
# (Calling empty_cache() first cannot release blocks that are freed later
# by the collection.)
import gc

gc.collect()
torch.cuda.empty_cache()

0

In [7]:
# Instantiate the attention 3D U-Net, move it to DEVICE and print its module
# tree. The two leading positional 3s are presumably the input and output
# channel counts -- confirm against models.Attention3DUNet.
model = Attention3UNet(
    3,
    3,
    features=[64, 128, 256, 512],
    up_sample=True,
)
model = model.to(DEVICE)
print(model)

Attention3UNet(
  (pooling): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (downs): ModuleList(
    (0): DoubleConv3D(
      (conv): Sequential(
        (0): Conv3d(3, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (4): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (1): DoubleConv3D(
      (conv): Sequential(
        (0): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1,

In [8]:
# load(model, "weights/3D/3d_50e_adam_b4_dice")

In [9]:
# model.train()

In [10]:

# Report the model size: first every parameter, then only those that have
# requires_grad set.
# Author's note: 5m more params than base3DUnet
print(f"total parameters = {sum(param.numel() for param in model.parameters())}")
print(f"total learnable parameters = {sum(param.numel() for param in model.parameters() if param.requires_grad)}")

total parameters = 95182159
total learnable parameters = 95182159


In [11]:
# Optimizer: Adam over all model parameters with learning rate LR.
opt = torch.optim.Adam(params=model.parameters(), lr=LR)

# Loss: the project's BCEDiceLoss. Alternatives kept for reference:
# loss = torch.nn.BCEWithLogitsLoss  # not used on its own: the data is heavily imbalanced, so it is better combined with Dice
# loss = DiceLoss()
loss = BCEDiceLoss()

In [12]:
# Show which device was selected for this run.
print(DEVICE)


cuda


In [None]:
# Train the model for EPOCHS epochs on the training loader, using the loss
# and optimizer configured above, on DEVICE.
train(
    model,
    epochs=EPOCHS,
    training_loader=train_dl,
    loss_fn=loss,
    device=DEVICE,
    optimizer=opt,
)

At epoch [1/100]: 100%|██████████| 74/74 [02:35<00:00,  2.10s/it, loss=1.39]
At epoch [2/100]:  55%|█████▌    | 41/74 [01:32<01:12,  2.20s/it, loss=1.35]

In [None]:
# Evaluate the model on the validation loader.
# Results noted from earlier runs (the author's label, then the reported summary;
# "100" is presumably the epoch count -- confirm):
#   100 DICE:     115510440/116391936 with accuracy 99.3187, Dice score: 0.7661410570144653
#   100 BCE-DICE: 115626234/116391936 with accuracy 99.3421, Dice score: 0.7765376567840576
check_accuracy(
    val_dl,
    model,
    DEVICE,
)

In [15]:
# saving sample
# The checkpoint name encodes the epoch count and batch size. Build it from
# EPOCHS and BATCH_SIZE so the name always matches the configuration of this
# notebook, instead of a hand-typed "100e"/"b4" that goes stale when the
# constants are changed (BATCH_SIZE is currently 3, not 4).
# NOTE(review): the "adam" and "bce-dice" parts are still literal; update them
# by hand if the optimizer or loss cell is changed.
checkpoint_path = f"weights/attention/att3d_{EPOCHS}e_adam_b{BATCH_SIZE}_bce-dice"
save(model, checkpoint_path)
print("saved the model...")

saved the model...


In [14]:
# Second training invocation with the same arguments as the earlier train
# cell.
# NOTE(review): the saved output under this cell shows 50 epochs and 56
# batches per epoch, which does not match the current EPOCHS = 100 /
# BATCH_SIZE = 3 configuration -- it looks like stale output from an earlier
# run; confirm whether this cell is still needed.
train(
    model,
    epochs=EPOCHS,
    training_loader=train_dl,
    loss_fn=loss,
    device=DEVICE,
    optimizer=opt,
)

At epoch [1/50]: 100%|██████████| 56/56 [01:36<00:00,  1.73s/it, loss=0.201] 
At epoch [2/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.118]
At epoch [3/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.171] 
At epoch [4/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.188]
At epoch [5/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.135]
At epoch [6/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.0845]
At epoch [7/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.27]  
At epoch [8/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.218] 
At epoch [9/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.225] 
At epoch [10/50]:  12%|█▎        | 7/56 [00:14<01:28,  1.81s/it, loss=0.128]

In [15]:
# Evaluate the current model on the validation loader.
check_accuracy(
    val_dl,
    model,
    DEVICE,
)

In [19]:
# loading sample
model1 = Attention3UNet(3,3)
load(model,"weights/attention/t1")