In [1]:
import torch

In [2]:
from util.data import get_dl, get_train_ds, get_val_ds
from util.training_utils import train, check_accuracy, save, load, check_accuracy_v2
from models.Base3DUNet import Base3DUNet
from util.loss import DiceLoss, BCEDiceLoss

In [3]:
BATCH_SIZE = 4
EPOCHS = 100
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
LR = 0.0001
# DEVICE = 'cpu'

In [5]:
train_dl = get_dl(get_train_ds(), BATCH_SIZE,nw=2)
val_dl = get_dl(get_val_ds(), BATCH_SIZE, nw=2)

images: 221, masks: 221 
images: 148, masks: 148 


In [6]:
import gc
torch.cuda.empty_cache()
gc.collect()

445

In [7]:
model = Base3DUNet(3, 3, features=[64, 128, 256, 512], up_sample=True).to(DEVICE)
print(model)

Base3DUNet(
  (pooling): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (downs): ModuleList(
    (0): DoubleConv3D(
      (conv): Sequential(
        (0): Conv3d(3, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (4): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (1): DoubleConv3D(
      (conv): Sequential(
        (0): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1),

In [9]:
load(model, "weights/3D/3d_100e_adam_b4_dice_upsampler")

In [9]:
# model.train()

In [10]:

print(f"total parameters = {sum(p.numel() for p in model.parameters())}")
print(f"total learnable parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

total parameters = 94130947
total learnable parameters = 94130947


In [11]:
opt = torch.optim.Adam(model.parameters(), lr=LR)
# loss = torch.nn.BCEWithLogitsLoss # cannot be used because there's a lot of imbalance anyway, so it is better to combine it with dice
loss = BCEDiceLoss()
# loss = DiceLoss()

In [12]:
print(DEVICE)


cuda


In [13]:
# train(model, epochs=EPOCHS, training_loader=train_dl, loss_fn=loss, device=DEVICE, optimizer=opt)

At epoch [1/100]: 100%|██████████| 56/56 [02:29<00:00,  2.67s/it, loss=1.5] 
At epoch [2/100]: 100%|██████████| 56/56 [02:23<00:00,  2.56s/it, loss=1.48]
At epoch [3/100]: 100%|██████████| 56/56 [02:23<00:00,  2.56s/it, loss=1.48]
At epoch [4/100]: 100%|██████████| 56/56 [02:23<00:00,  2.56s/it, loss=1.4] 
At epoch [5/100]: 100%|██████████| 56/56 [02:23<00:00,  2.56s/it, loss=1.31]
At epoch [6/100]: 100%|██████████| 56/56 [02:23<00:00,  2.57s/it, loss=1.29]
At epoch [7/100]: 100%|██████████| 56/56 [02:23<00:00,  2.56s/it, loss=1.25]
At epoch [8/100]: 100%|██████████| 56/56 [02:23<00:00,  2.57s/it, loss=1.21]
At epoch [9/100]: 100%|██████████| 56/56 [02:23<00:00,  2.56s/it, loss=1.24]
At epoch [10/100]: 100%|██████████| 56/56 [02:23<00:00,  2.56s/it, loss=1.2] 
At epoch [11/100]: 100%|██████████| 56/56 [02:23<00:00,  2.56s/it, loss=1]   
At epoch [12/100]: 100%|██████████| 56/56 [02:23<00:00,  2.56s/it, loss=0.948]
At epoch [13/100]: 100%|██████████| 56/56 [02:23<00:00,  2.56s/it, loss=

In [11]:
check_accuracy_v2(val_dl,model,DEVICE)
# convT
# 100 DICE Results: Results: 115510440/116391936 with accuracy 99.3155 Dice score: 0.7647337317466736
# 100 BCE-DICE Results: 115626234/116391936 with accuracy 99.3421 Dice score: 0.7765376567840576
# upsample
# 100 DICE Results: Results: 115510440/116391936 with accuracy 99.3394 Dice score: 0.7728903293609619
# 100 BCE-DICE Results: 115626234/116391936 with accuracy 99.3208 Dice score: 0.7685501575469971

 Accuracy (TC,ET,WT): 
 --> 99.3720 , 99.7061, 97.4900
Dice Score (TC,ET,WT): 
 2.0 , 2.0, 2.0


In [15]:
# saving sample
# save(model,"weights/3D/3d_100e_adam_b4_bce_dice_upsampler")
print("saved the model...")

saved the model...


In [None]:
# save(model,"weights/3D/3d_100e_adam_b4_dice_upsampler")

In [14]:
# train(model, epochs=EPOCHS, training_loader=train_dl, loss_fn=loss, device=DEVICE, optimizer=opt)

At epoch [1/50]: 100%|██████████| 56/56 [01:36<00:00,  1.73s/it, loss=0.201] 
At epoch [2/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.118]
At epoch [3/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.171] 
At epoch [4/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.188]
At epoch [5/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.135]
At epoch [6/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.0845]
At epoch [7/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.27]  
At epoch [8/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.218] 
At epoch [9/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.225] 
At epoch [10/50]:  12%|█▎        | 7/56 [00:14<01:28,  1.81s/it, loss=0.128]

In [15]:
check_accuracy(val_dl,model,DEVICE)

In [19]:
# loading sample
model1 = Base3DUNet(3,3)
load(model,"weights/3D/t1")