In [1]:
import torch
from util.data import get_dl, get_train_ds, get_val_ds
from util.training_utils import train, check_accuracy, save, load, get_all_metrics, get_all_metrics_2
from models.Base3DUNet import Base3DUNet
from models.Attention3DUNet import Attention3UNet
from models.Residual3DUNet import Res3DUNet

# Evaluation configuration.
BATCH_SIZE = 4
EPOCHS = 100  # not referenced by the evaluation cells; mirrors the "100e" tag in the checkpoint names
NUM_WORKERS = 2  # passed to get_dl as `nw` (presumably the DataLoader worker count — confirm in util.data)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Validation loader shared by every evaluation cell below.
val_dl = get_dl(get_val_ds(full_masks=True), BATCH_SIZE, nw=NUM_WORKERS)

images: 148, masks: 148 


In [2]:
import gc


def clear():
    """Free unreachable Python objects, then return cached CUDA memory.

    Order matters: ``gc.collect()`` runs first so that tensors kept alive only
    by reference cycles are released to PyTorch's caching allocator, and
    ``torch.cuda.empty_cache()`` then hands those now-unused cached blocks back.
    Calling ``empty_cache`` first would leave anything freed by the collection
    sitting in the cache. Tensors still bound to live names (e.g. ``model``)
    are not released by either call — drop those references before calling.
    """
    gc.collect()
    torch.cuda.empty_cache()


clear()

In [3]:
# Base 3D U-Net, up_sample=True variant (labelled "Upsampler" in the evaluation cells).
model = Base3DUNet(3, 3, features=[64, 128, 256, 512], up_sample=True).to(DEVICE)
n_learnable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"total learnable parameters = {n_learnable}")


In [4]:
# Load the checkpoint labelled "Dice Loss" into the up_sample=True Base3DUNet and score it on val_dl.
# NOTE(review): in the recorded output the F1 for a class (e.g. TC ~37) disagrees with its Dice
# score (~0.78), although for binary masks Dice and F1 are the same quantity; the
# precision/recall/F1 computation inside get_all_metrics_2 should be checked.
ckpt_path = "weights/3D/3d_100e_adam_b4_dice_upsampler"
load(model, ckpt_path)
print("3DUNet Upsampler with Dice Loss")
get_all_metrics_2(val_dl, model, DEVICE)

3DUNet Upsampler with Dice Loss
 Accuracy (TC,ET,WT): 
 99.3601 , 99.6902, 99.0066
Dice Score (TC,ET,WT): 
 0.7821628451347351 , 0.68885737657547, 0.839487612247467
IoU Score (TC,ET,WT): 
 0.6852228045463562 , 0.5777105689048767, 0.7514512538909912
 Precision (TC,ET,WT): 
 --> 40.6819 , 80.0799, 16.7454
 Recall (TC,ET,WT): 
 --> 34.6550 , 78.0800, 14.8955
 F1-score (TC,ET,WT): 
 --> 37.4274 , 79.0673, 15.7664


In [5]:
# Same up_sample=True Base3DUNet instance, now loaded with the checkpoint labelled "BCE-Dice Loss".
ckpt_path = "weights/3D/3d_100e_adam_b4_bce_dice_upsampler"
load(model, ckpt_path)
print("3DUNet Upsampler with BCE-Dice Loss")
get_all_metrics_2(val_dl, model, DEVICE)

3DUNet Upsampler with BCE-Dice Loss
 Accuracy (TC,ET,WT): 
 99.3121 , 99.6928, 99.0178
Dice Score (TC,ET,WT): 
 0.7719240784645081 , 0.6890219449996948, 0.8466907143592834
IoU Score (TC,ET,WT): 
 0.6696090698242188 , 0.5770220756530762, 0.7578732371330261
 Precision (TC,ET,WT): 
 --> 41.4878 , 84.0877, 16.1345
 Recall (TC,ET,WT): 
 --> 34.6383 , 78.0820, 14.8971
 F1-score (TC,ET,WT): 
 --> 37.7549 , 80.9737, 15.4912


In [6]:
# Swap to the up_sample=False Base3DUNet. Drop the reference to the previous model *before*
# clearing: otherwise clear() cannot release its parameters, and the new model would be moved
# to DEVICE while the old one is still resident.
model = None
clear()
model = Base3DUNet(3, 3, features=[64, 128, 256, 512], up_sample=False).to(DEVICE)
print(f"total learnable parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad)}")


In [7]:
# up_sample=False Base3DUNet with the checkpoint labelled "Dice Loss".
ckpt_path = "weights/3D/3d_100e_adam_b4_dice"
load(model, ckpt_path)
print("3DUNet with Dice Loss")
get_all_metrics_2(val_dl, model, DEVICE)

3DUNet with Dice Loss
 Accuracy (TC,ET,WT): 
 99.2633 , 99.6768, 98.9967
Dice Score (TC,ET,WT): 
 0.7423750162124634 , 0.6710379719734192, 0.8460584282875061
IoU Score (TC,ET,WT): 
 0.6411218643188477 , 0.5546413660049438, 0.7568052411079407
 Precision (TC,ET,WT): 
 --> 46.9100 , 97.4532, 14.9651
 Recall (TC,ET,WT): 
 --> 34.6212 , 78.0695, 14.8940
 F1-score (TC,ET,WT): 
 --> 39.8395 , 86.6910, 14.9294


In [8]:
# up_sample=False Base3DUNet with the checkpoint labelled "BCE-Dice Loss".
ckpt_path = "weights/3D/3d_100e_adam_b4_bce_dice"
load(model, ckpt_path)
print("3DUNet with BCE-Dice Loss")
get_all_metrics_2(val_dl, model, DEVICE)

3DUNet with BCE-Dice Loss
 Accuracy (TC,ET,WT): 
 99.3517 , 99.7012, 99.0492
Dice Score (TC,ET,WT): 
 0.7789777517318726 , 0.6982654929161072, 0.8524793386459351
IoU Score (TC,ET,WT): 
 0.6808264851570129 , 0.5862982869148254, 0.7661640048027039
 Precision (TC,ET,WT): 
 --> 38.6881 , 85.5733, 16.1231
 Recall (TC,ET,WT): 
 --> 34.6521 , 78.0886, 14.9019
 F1-score (TC,ET,WT): 
 --> 36.5590 , 81.6598, 15.4884


In [9]:
# Swap to the up_sample=False Attention3UNet ("ConvT version" in the labels below).
# Release the previous model first so clear() can actually free its memory.
model = None
clear()
model = Attention3UNet(3, 3, features=[64, 128, 256, 512], up_sample=False).to(DEVICE)
print(f"total learnable parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad)}")


In [10]:
# up_sample=False Attention3UNet with the "_c" checkpoint labelled "Dice Loss".
ckpt_path = "weights/attention/att3d_100e_adam_b4_dice_c"
load(model, ckpt_path)
print("Attention3DUNet with Dice Loss - (ConvT version)")
get_all_metrics_2(val_dl, model, DEVICE)

Attention3DUNet with Dice Loss - (ConvT version)
 Accuracy (TC,ET,WT): 
 99.3544 , 99.6905, 99.0364
Dice Score (TC,ET,WT): 
 0.7816538214683533 , 0.6932134032249451, 0.8522657752037048
IoU Score (TC,ET,WT): 
 0.6859500408172607 , 0.5824635624885559, 0.765794575214386
 Precision (TC,ET,WT): 
 --> 42.0900 , 82.9900, 16.1025
 Recall (TC,ET,WT): 
 --> 34.6530 , 78.0802, 14.8999
 F1-score (TC,ET,WT): 
 --> 38.0111 , 80.4603, 15.4779


In [11]:
# up_sample=False Attention3UNet with the "_c" checkpoint labelled "BCE-Dice Loss".
ckpt_path = "weights/attention/att3d_100e_adam_b4_bce_dice_c"
load(model, ckpt_path)
print("Attention3DUNet with BCE-Dice Loss - (ConvT version)")
get_all_metrics_2(val_dl, model, DEVICE)

Attention3DUNet with BCE-Dice Loss - (ConvT version)
 Accuracy (TC,ET,WT): 
 99.2950 , 99.6922, 99.0070
Dice Score (TC,ET,WT): 
 0.760830819606781 , 0.6907649040222168, 0.8438195586204529
IoU Score (TC,ET,WT): 
 0.661676287651062 , 0.5782598853111267, 0.7573460936546326
 Precision (TC,ET,WT): 
 --> 42.0382 , 86.3042, 16.2116
 Recall (TC,ET,WT): 
 --> 34.6323 , 78.0815, 14.8955
 F1-score (TC,ET,WT): 
 --> 37.9776 , 81.9872, 15.5257


In [12]:
# Swap to the up_sample=True Attention3UNet ("Upsampler version" in the labels below).
# Release the previous model first so clear() can actually free its memory.
model = None
clear()
model = Attention3UNet(3, 3, features=[64, 128, 256, 512], up_sample=True).to(DEVICE)
print(f"total learnable parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad)}")


In [13]:
# up_sample=True Attention3UNet with the checkpoint labelled "Dice Loss".
ckpt_path = "weights/attention/att3d_100e_adam_b4_dice"
load(model, ckpt_path)
print("Attention3DUNet with Dice Loss - (Upsampler version)")
get_all_metrics_2(val_dl, model, DEVICE)

Attention3DUNet with Dice Loss - (Upsampler version)
 Accuracy (TC,ET,WT): 
 99.3225 , 99.6896, 99.0118
Dice Score (TC,ET,WT): 
 0.7768449783325195 , 0.6793325543403625, 0.844255805015564
IoU Score (TC,ET,WT): 
 0.6746198534965515 , 0.5670403242111206, 0.7549707293510437
 Precision (TC,ET,WT): 
 --> 42.1787 , 92.9405, 16.8040
 Recall (TC,ET,WT): 
 --> 34.6419 , 78.0795, 14.8962
 F1-score (TC,ET,WT): 
 --> 38.0406 , 84.8643, 15.7927


In [14]:
# up_sample=True Attention3UNet with the checkpoint labelled "BCE-Dice Loss".
ckpt_path = "weights/attention/att3d_100e_adam_b4_bce_dice"
load(model, ckpt_path)
print("Attention3DUNet with BCE-Dice Loss - (Upsampler version)")
get_all_metrics_2(val_dl, model, DEVICE)

Attention3DUNet with BCE-Dice Loss - (Upsampler version)
 Accuracy (TC,ET,WT): 
 99.3204 , 99.6839, 98.9921
Dice Score (TC,ET,WT): 
 0.7743036150932312 , 0.6768759489059448, 0.8426199555397034
IoU Score (TC,ET,WT): 
 0.6734686493873596 , 0.5659208297729492, 0.7533087730407715
 Precision (TC,ET,WT): 
 --> 43.0446 , 89.8057, 16.7121
 Recall (TC,ET,WT): 
 --> 34.6411 , 78.0750, 14.8933
 F1-score (TC,ET,WT): 
 --> 38.3884 , 83.5305, 15.7503


In [15]:
# Swap to Res3DUNet (constructor defaults for everything but the two leading args).
# Release the previous model first so clear() can actually free its memory.
model = None
clear()
model = Res3DUNet(3, 3).to(DEVICE)
print(f"total learnable parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad)}")


In [16]:
# Res3DUNet with the checkpoint labelled "Dice Loss".
# NOTE(review): the recorded output for this checkpoint shows Dice/IoU of 0.0 for TC and ET,
# precision `inf` and F1 `nan` for those classes. This suggests the model predicts no voxels
# there and get_all_metrics_2 divides by a zero count. Confirm, and guard the metric.
ckpt_path = "weights/residual/3d_res_4layer_100e_adam_dice"
load(model, ckpt_path)
print("Residual3DUNet with Dice Loss")
get_all_metrics_2(val_dl, model, DEVICE)

Residual3DUNet with Dice Loss
 Accuracy (TC,ET,WT): 
 98.0628 , 99.1373, 97.7190
Dice Score (TC,ET,WT): 
 0.0 , 0.0, 0.655088484287262
IoU Score (TC,ET,WT): 
 0.0 , 0.0, 0.5106545090675354
 Precision (TC,ET,WT): 
 --> inf , inf, 22.2087
 Recall (TC,ET,WT): 
 --> 34.2025 , 77.6469, 14.7017
 F1-score (TC,ET,WT): 
 --> nan , nan, 17.6918


In [17]:
# Res3DUNet with the checkpoint labelled "BCE-Dice Loss".
ckpt_path = "weights/residual/3d_res_4layer_100e_adam_bce_dice"
load(model, ckpt_path)
print("Residual3DUNet with BCE-Dice Loss")
get_all_metrics_2(val_dl, model, DEVICE)

Residual3DUNet with BCE-Dice Loss
 Accuracy (TC,ET,WT): 
 99.2630 , 99.6664, 98.9079
Dice Score (TC,ET,WT): 
 0.7453449368476868 , 0.6737334132194519, 0.8274587988853455
IoU Score (TC,ET,WT): 
 0.6367512345314026 , 0.5583897233009338, 0.7322196364402771
 Precision (TC,ET,WT): 
 --> 44.4874 , 89.2285, 16.5262
 Recall (TC,ET,WT): 
 --> 34.6211 , 78.0613, 14.8806
 F1-score (TC,ET,WT): 
 --> 38.9390 , 83.2722, 15.6603
