In [1]:
import torch

In [2]:
from util.data import get_dl, get_train_ds, get_val_ds
from util.training_utils import train, check_accuracy, save, load, check_accuracy_v2,  check_accuracy_v3, get_all_metrics, get_all_metrics_2
from models.Base3DUNet import Base3DUNet
from util.loss import DiceLoss, BCEDiceLoss

In [3]:
# Run-wide settings: mini-batch size, epoch count, and the compute device
# (CUDA when it is available, CPU otherwise).
BATCH_SIZE, EPOCHS = 4, 100
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
# Learning rate passed to the optimizer.
LR = 1e-4
# DEVICE = 'cpu'  # uncomment to force CPU execution

In [5]:
# Build one loader per split. Datasets are created right before their loader
# so the construction order (and the "images/masks" printout) is unchanged.
train_ds = get_train_ds()
train_dl = get_dl(train_ds, BATCH_SIZE, nw=2)

val_ds = get_val_ds()
val_dl = get_dl(val_ds, BATCH_SIZE, nw=2)

# Same validation split, requested with full_masks=True.
val_ds_full_masks = get_val_ds(full_masks=True)
val_dl_o = get_dl(val_ds_full_masks, BATCH_SIZE, nw=2)

images: 221, masks: 221 
images: 148, masks: 148 
images: 148, masks: 148 


In [6]:
import gc

# Free memory before building the model.
# Run the Python garbage collector first: tensors kept alive only by reference
# cycles still occupy their CUDA blocks until they are collected, and
# empty_cache() can only return *unoccupied* cached blocks to the driver.
freed_objects = gc.collect()
torch.cuda.empty_cache()

# Keep the collector's count as the cell's displayed value, as before.
freed_objects

89

In [7]:
# Instantiate the 3D U-Net, move it onto the selected device, and print its
# layer structure.
model = Base3DUNet(
    3,
    3,
    features=[64, 128, 256, 512],
    up_sample=True,
)
model = model.to(DEVICE)
print(model)

Base3DUNet(
  (pooling): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (downs): ModuleList(
    (0): DoubleConv3D(
      (conv): Sequential(
        (0): Conv3d(3, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (4): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (1): DoubleConv3D(
      (conv): Sequential(
        (0): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1),

In [8]:
# Restore previously saved weights into `model` before evaluating it.
checkpoint_path = "weights/3D/3d_100e_adam_b4_dice_upsampler"
load(model, checkpoint_path)

In [9]:
# model.train()

In [10]:

# Report the model size: all parameters first, then only the ones that
# require gradients.
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
print(f"total parameters = {sum(size for size, _ in param_sizes)}")
print(f"total learnable parameters = {sum(size for size, trainable in param_sizes if trainable)}")

total parameters = 94130947
total learnable parameters = 94130947


In [11]:
# Training criterion. Alternatives kept for reference:
# loss = torch.nn.BCEWithLogitsLoss  # plain BCE is avoided: the classes are heavily imbalanced, so it is better combined with Dice
# loss = DiceLoss()
loss = BCEDiceLoss()

# Adam over every model parameter at the configured learning rate.
opt = torch.optim.Adam(params=model.parameters(), lr=LR)

In [12]:
# Show which device was selected for this run.
print(str(DEVICE))


cuda


In [13]:
# train(model, epochs=EPOCHS, training_loader=train_dl, loss_fn=loss, device=DEVICE, optimizer=opt)

In [10]:
# Evaluate the loaded model on the validation loader.
check_accuracy_v2(val_dl, model, DEVICE)
# Results logged from earlier runs (100 epochs each), kept for reference.
# NOTE(review): the correct/total counts are identical between the convT and
# upsample entries although the accuracies differ, so the counts look
# copy-pasted — re-check them before quoting these numbers.
# convT
#   100 DICE     Results: 115510440/116391936 with accuracy 99.3155 Dice score: 0.7647337317466736
#   100 BCE-DICE Results: 115626234/116391936 with accuracy 99.3421 Dice score: 0.7765376567840576
# upsample
#   100 DICE     Results: 115510440/116391936 with accuracy 99.3394 Dice score: 0.7728903293609619
#   100 BCE-DICE Results: 115626234/116391936 with accuracy 99.3208 Dice score: 0.7685501575469971

 Accuracy (TC,ET,WT): 
 --> 99.3720 , 99.7061, 99.0616
Dice Score (TC,ET,WT): 
 0.820264995098114 , 0.8163532018661499, 0.8860259652137756


In [14]:
# Same evaluation, run on the loader built with full_masks=True.
check_accuracy_v3(val_dl_o, model, DEVICE)

 Accuracy (TC,ET,WT): 
 --> 99.3601 , 99.6902, 99.0066
Dice Score (TC,ET,WT): 
 0.8251846432685852 , 0.8126899003982544, 0.8778696656227112


In [11]:
# Full metric report on the full-mask validation loader (batch size BATCH_SIZE).
# NOTE(review): the printed Precision/Recall/F1 values exceed 100 and change
# with the loader's batch size in the runs recorded in this notebook, which
# suggests they are accumulated per batch rather than averaged — verify inside
# get_all_metrics before relying on them.
get_all_metrics(val_dl_o, model, DEVICE)

 Accuracy (TC,ET,WT): 
 99.3601 , 99.6902, 99.0066
Dice Score (TC,ET,WT): 
 0.8231386542320251 , 0.8071435689926147, 0.8804966807365417
IoU Score (TC,ET,WT): 
 0.7136725187301636 , 0.6832652688026428, 0.7896050810813904
 Precision (TC,ET,WT): 
 --> 162.7276 , 320.3194, 66.9818
 Recall (TC,ET,WT): 
 --> 138.6200 , 312.3200, 59.5818
 F1-score (TC,ET,WT): 
 --> 149.7095 , 316.2691, 63.0655


In [12]:
# Repeat the metric report with a batch size of 1 to compare against the
# BATCH_SIZE run.
full_mask_val_ds = get_val_ds(full_masks=True)
val_dl_oo = get_dl(full_mask_val_ds, 1, nw=2)
get_all_metrics(val_dl_oo, model, DEVICE)

images: 148, masks: 148 
 Accuracy (TC,ET,WT): 
 99.3601 , 99.6902, 99.0066
Dice Score (TC,ET,WT): 
 0.782164990901947 , 0.6888605356216431, 0.8394848108291626
IoU Score (TC,ET,WT): 
 0.6852257251739502 , 0.577714741230011, 0.7514474987983704
 Precision (TC,ET,WT): 
 --> 40.6821 , 80.0816, 16.7456
 Recall (TC,ET,WT): 
 --> 34.6550 , 78.0800, 14.8954
 F1-score (TC,ET,WT): 
 --> 37.4275 , 79.0681, 15.7664


In [13]:
# Repeat the metric report with a batch size of 2 (rebinds `val_dl_oo`).
full_mask_val_ds = get_val_ds(full_masks=True)
val_dl_oo = get_dl(full_mask_val_ds, 2, nw=2)
get_all_metrics(val_dl_oo, model, DEVICE)

images: 148, masks: 148 
 Accuracy (TC,ET,WT): 
 99.3601 , 99.6902, 99.0066
Dice Score (TC,ET,WT): 
 0.8139834403991699 , 0.7855698466300964, 0.8720433115959167
IoU Score (TC,ET,WT): 
 0.7086735963821411 , 0.6589915156364441, 0.7794755101203918
 Precision (TC,ET,WT): 
 --> 81.3637 , 160.1585, 33.4909
 Recall (TC,ET,WT): 
 --> 69.3100 , 156.1600, 29.7909
 F1-score (TC,ET,WT): 
 --> 74.8547 , 158.1340, 31.5327


In [14]:
# Second metrics implementation, run on the BATCH_SIZE full-mask loader.
get_all_metrics_2(val_dl_o, model, DEVICE)

 Accuracy (TC,ET,WT): 
 99.3601 , 99.6902, 99.0066
Dice Score (TC,ET,WT): 
 0.782162606716156 , 0.6888573169708252, 0.8394874334335327
IoU Score (TC,ET,WT): 
 0.6852228045463562 , 0.5777103900909424, 0.751451313495636
 Precision (TC,ET,WT): 
 --> 40.6819 , 80.0799, 16.7454
 Recall (TC,ET,WT): 
 --> 34.6550 , 78.0800, 14.8955
 F1-score (TC,ET,WT): 
 --> 37.4274 , 79.0673, 15.7664


In [15]:
# saving sample
# The checkpoint is only written when SAVE_CHECKPOINT is switched on, and the
# message reflects what actually happened. Previously the cell printed
# "saved the model..." even though the save() call was commented out.
SAVE_CHECKPOINT = False
SAVE_PATH = "weights/3D/3d_100e_adam_b4_bce_dice_upsampler"

if SAVE_CHECKPOINT:
    save(model, SAVE_PATH)
    print("saved the model...")
else:
    print("model not saved (SAVE_CHECKPOINT is False)")

saved the model...


In [None]:
# save(model,"weights/3D/3d_100e_adam_b4_dice_upsampler")

In [14]:
# train(model, epochs=EPOCHS, training_loader=train_dl, loss_fn=loss, device=DEVICE, optimizer=opt)

At epoch [1/50]: 100%|██████████| 56/56 [01:36<00:00,  1.73s/it, loss=0.201] 
At epoch [2/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.118]
At epoch [3/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.171] 
At epoch [4/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.188]
At epoch [5/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.135]
At epoch [6/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.0845]
At epoch [7/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.27]  
At epoch [8/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.218] 
At epoch [9/50]: 100%|██████████| 56/56 [01:37<00:00,  1.74s/it, loss=0.225] 
At epoch [10/50]:  12%|█▎        | 7/56 [00:14<01:28,  1.81s/it, loss=0.128]

In [15]:
# Accuracy check on the validation loader.
check_accuracy(val_dl, model, DEVICE)

In [19]:
# loading sample
# Build a fresh network and restore the saved weights into that new instance.
model1 = Base3DUNet(3,3)
# Fix: the weights used to be loaded into `model`, which left `model1` at its
# initial weights and touched the network evaluated in the cells above.
load(model1,"weights/3D/t1")