In [1]:
import torch

In [2]:
from util.data import get_dl, get_train_ds, get_val_ds
from util.training_utils import train, check_accuracy, save, load
from models.Base3DUNet import Base3DUNet
from util.loss import DiceLoss, BCEDiceLoss

In [3]:
BATCH_SIZE = 4
EPOCHS = 50
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
LR = 0.0001
# DEVICE = 'cpu'

In [5]:
train_dl = get_dl(get_train_ds(), BATCH_SIZE,nw=2)
val_dl = get_dl(get_val_ds(), BATCH_SIZE, nw=2)

images: 221, masks: 221 
images: 148, masks: 148 


In [6]:
import gc
torch.cuda.empty_cache()
gc.collect()

0

In [7]:
model = Base3DUNet(3, 3, features=[64, 128, 256, 512]).to(DEVICE)
print(model)

Base3DUNet(
  (pooling): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (downs): ModuleList(
    (0): DoubleConv3D(
      (conv): Sequential(
        (0): Conv3d(3, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (4): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (1): DoubleConv3D(
      (conv): Sequential(
        (0): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1),

In [9]:
load(model, "weights/3D/3d_50e_adam_b4_dice")

In [10]:
model.train()

Base3DUNet(
  (pooling): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (downs): ModuleList(
    (0): DoubleConv3D(
      (conv): Sequential(
        (0): Conv3d(3, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (4): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (1): DoubleConv3D(
      (conv): Sequential(
        (0): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1),

In [8]:

print(f"total parameters = {sum(p.numel() for p in model.parameters())}")
print(f"total learnable parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

total parameters = 90302147
total learnable parameters = 90302147


In [11]:
opt = torch.optim.Adam(model.parameters(), lr=LR)
# loss = torch.nn.BCEWithLogitsLoss # cannot be used because there's a lot of imbalance anyway, so it is better to combine it with dice
loss = DiceLoss()

In [12]:
print(DEVICE)


cuda


In [13]:
train(model, epochs=EPOCHS, training_loader=train_dl, loss_fn=loss, device=DEVICE, optimizer=opt)

At epoch [1/50]: 100%|██████████| 56/56 [01:31<00:00,  1.63s/it, loss=0.209] 
At epoch [2/50]: 100%|██████████| 56/56 [01:37<00:00,  1.75s/it, loss=0.166] 
At epoch [3/50]: 100%|██████████| 56/56 [01:37<00:00,  1.75s/it, loss=0.523] 
At epoch [4/50]: 100%|██████████| 56/56 [01:37<00:00,  1.75s/it, loss=0.188] 
At epoch [5/50]: 100%|██████████| 56/56 [01:37<00:00,  1.75s/it, loss=0.117]
At epoch [6/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.121] 
At epoch [7/50]: 100%|██████████| 56/56 [01:37<00:00,  1.75s/it, loss=0.21]  
At epoch [8/50]: 100%|██████████| 56/56 [01:37<00:00,  1.75s/it, loss=0.0908]
At epoch [9/50]: 100%|██████████| 56/56 [01:37<00:00,  1.75s/it, loss=0.246] 
At epoch [10/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.133] 
At epoch [11/50]: 100%|██████████| 56/56 [01:37<00:00,  1.75s/it, loss=0.196] 
At epoch [12/50]: 100%|██████████| 56/56 [01:37<00:00,  1.75s/it, loss=0.108] 
At epoch [13/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/

In [14]:
check_accuracy(val_dl,model,DEVICE)
# 50 Results: 115540617/116391936 with accuracy 99.2686 Dice score: 0.7475918531417847
# 100 Results: Results: 115510440/116391936 with accuracy 99.2427 Dice score: 0.7560997009277344

Results: 115510440/116391936 with accuracy 99.2427
Dice score: 0.7560997009277344


In [15]:
# saving sample
save(model,"weights/3D/3d_100e_adam_b4_dice")

In [14]:
train(model, epochs=EPOCHS, training_loader=train_dl, loss_fn=loss, device=DEVICE, optimizer=opt)

At epoch [1/50]: 100%|██████████| 56/56 [01:36<00:00,  1.73s/it, loss=0.201] 
At epoch [2/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.118]
At epoch [3/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.171] 
At epoch [4/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.188]
At epoch [5/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.135]
At epoch [6/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.0845]
At epoch [7/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.27]  
At epoch [8/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.218] 
At epoch [9/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.225] 
At epoch [10/50]:  12%|█▎        | 7/56 [00:14<01:28,  1.81s/it, loss=0.128]

In [15]:
check_accuracy(val_dl,model,DEVICE)

In [19]:
# loading sample
model1 = Base3DUNet(3,3)
load(model,"weights/3D/t1")