In [2]:
import torch
from util.data import get_dl, get_train_ds, get_val_ds
from util.training_utils import train, check_accuracy, save, load, check_accuracy_v2,  check_accuracy_v3, get_all_metrics, get_all_metrics_2
from models.Base3DUNet import Base3DUNet
from models.Attention3DUNet import Attention3UNet
from models.Residual3DUNet import Res3DUNet

# Evaluation configuration shared by every cell below.
BATCH_SIZE = 4
# NOTE(review): EPOCHS is not referenced by any cell visible in this notebook;
# it looks like a leftover from the training notebook — confirm before removing.
EPOCHS = 100
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Validation loader reused for every checkpoint evaluated below.
# `nw` is presumably the number of DataLoader workers and `full_masks` a
# dataset option — both are defined in util.data; confirm there.
val_dl = get_dl(get_val_ds(full_masks=True), BATCH_SIZE, nw=2)

images: 148, masks: 148 


In [8]:
import gc

def clear():
    """Free unreachable Python objects and release cached CUDA memory.

    The garbage collector runs first: tensors that are only kept alive by
    reference cycles are destroyed there, which returns their blocks to
    PyTorch's caching allocator. ``torch.cuda.empty_cache()`` is called
    afterwards so that those freshly freed blocks are released as well;
    calling it before ``gc.collect()`` would leave them cached.

    Memory that is still referenced (for example by a live ``model``
    variable) is not affected. Returns ``None``.
    """
    gc.collect()
    torch.cuda.empty_cache()

# Release cached/unreferenced memory before the first model is built.
clear()

In [4]:
# Baseline 3D U-Net built with up_sample=True.
# NOTE(review): the two leading 3s are presumably input/output channel counts
# (the metrics below report three regions: TC, ET, WT) — confirm in
# models/Base3DUNet.
model = Base3DUNet(
    3,
    3,
    features=[64, 128, 256, 512],
    up_sample=True,
)
model = model.to(DEVICE)

In [6]:
# Evaluate the "dice_upsampler" checkpoint with the up_sample=True Base3DUNet.
# `load` is a project helper; presumably it restores the saved weights into
# `model` in place — confirm in util.training_utils.
load(model,"weights/3D/3d_100e_adam_b4_dice_upsampler")
print("3DUNet Upsampler with Dice Loss")
# NOTE(review): in the outputs recorded in this notebook, Recall (TC,ET,WT) is
# ~34.6 / ~78.1 / ~14.9 for every checkpoint and F1 is far from the Dice score.
# That suggests the precision/recall path of get_all_metrics_2 deserves a check.
# NOTE(review): assumes get_all_metrics_2 puts the model in eval mode and
# disables autograd — confirm.
get_all_metrics_2(val_dl,model,DEVICE)

3DUNet Upsampler with Dice Loss
 Accuracy (TC,ET,WT): 
 99.3601 , 99.6902, 99.0066
Dice Score (TC,ET,WT): 
 0.7821627855300903 , 0.6888573169708252, 0.8394877314567566
IoU Score (TC,ET,WT): 
 0.6852228045463562 , 0.5777104496955872, 0.7514510154724121
 Precision (TC,ET,WT): 
 --> 40.6819 , 80.0799, 16.7454
 Recall (TC,ET,WT): 
 --> 34.6550 , 78.0800, 14.8955
 F1-score (TC,ET,WT): 
 --> 37.4274 , 79.0673, 15.7664


In [7]:
# Evaluate the "bce_dice_upsampler" checkpoint, reusing the up_sample=True
# Base3DUNet instance created earlier (only the weights are replaced,
# presumably in place by `load` — confirm in util.training_utils).
load(model,"weights/3D/3d_100e_adam_b4_bce_dice_upsampler")
print("3DUNet Upsampler with BCE-Dice Loss")
# NOTE(review): assumes get_all_metrics_2 puts the model in eval mode and
# disables autograd — confirm.
get_all_metrics_2(val_dl,model,DEVICE)

3DUNet Upsampler with BCE-Dice Loss
 Accuracy (TC,ET,WT): 
 99.3121 , 99.6928, 99.0178
Dice Score (TC,ET,WT): 
 0.7719242572784424 , 0.6890219449996948, 0.8466905951499939
IoU Score (TC,ET,WT): 
 0.6696090698242188 , 0.5770222544670105, 0.7578733563423157
 Precision (TC,ET,WT): 
 --> 41.4878 , 84.0877, 16.1345
 Recall (TC,ET,WT): 
 --> 34.6383 , 78.0820, 14.8971
 F1-score (TC,ET,WT): 
 --> 37.7549 , 80.9737, 15.4912


In [9]:
# Drop the reference to the previous network first so that clear() can actually
# reclaim its memory; otherwise the old and the new model coexist on the device
# until `model` is rebound.
model = None
clear()
# Same baseline architecture, this time with up_sample=False.
model = Base3DUNet(3, 3, features=[64, 128, 256, 512], up_sample=False).to(DEVICE)

In [10]:
# Evaluate the "dice" checkpoint with the up_sample=False Base3DUNet.
# `load` presumably restores the weights into `model` in place — confirm in
# util.training_utils.
load(model,"weights/3D/3d_100e_adam_b4_dice")
print("3DUNet with Dice Loss")
# NOTE(review): assumes get_all_metrics_2 puts the model in eval mode and
# disables autograd — confirm.
get_all_metrics_2(val_dl,model,DEVICE)

3DUNet with Dice Loss
 Accuracy (TC,ET,WT): 
 99.2633 , 99.6768, 98.9967
Dice Score (TC,ET,WT): 
 0.7423750162124634 , 0.6710378527641296, 0.8460586071014404
IoU Score (TC,ET,WT): 
 0.6411216259002686 , 0.5546414852142334, 0.7568053007125854
 Precision (TC,ET,WT): 
 --> 46.9100 , 97.4532, 14.9651
 Recall (TC,ET,WT): 
 --> 34.6212 , 78.0695, 14.8940
 F1-score (TC,ET,WT): 
 --> 39.8395 , 86.6910, 14.9294


In [11]:
# Evaluate the "bce_dice" checkpoint, reusing the up_sample=False Base3DUNet
# instance (only the weights change, presumably in place via `load`).
load(model,"weights/3D/3d_100e_adam_b4_bce_dice")
print("3DUNet with BCE-Dice Loss")
# NOTE(review): assumes get_all_metrics_2 puts the model in eval mode and
# disables autograd — confirm.
get_all_metrics_2(val_dl, model, DEVICE)

3DUNet with BCE-Dice Loss
 Accuracy (TC,ET,WT): 
 99.3517 , 99.7012, 99.0492
Dice Score (TC,ET,WT): 
 0.7789778709411621 , 0.698265552520752, 0.8524792194366455
IoU Score (TC,ET,WT): 
 0.6808264255523682 , 0.5862983465194702, 0.7661637663841248
 Precision (TC,ET,WT): 
 --> 38.6881 , 85.5733, 16.1231
 Recall (TC,ET,WT): 
 --> 34.6521 , 78.0886, 14.9019
 F1-score (TC,ET,WT): 
 --> 36.5590 , 81.6598, 15.4884


In [12]:
# Drop the reference to the previous network first so that clear() can actually
# reclaim its memory; otherwise the old and the new model coexist on the device
# until `model` is rebound.
model = None
clear()
# Attention variant built with up_sample=False (labelled "ConvT version" in the
# evaluation cells that use it).
model = Attention3UNet(3, 3, features=[64, 128, 256, 512], up_sample=False).to(DEVICE)

In [15]:
# Evaluate the "dice_c" checkpoint with the up_sample=False Attention3UNet.
# `load` presumably restores the weights into `model` in place — confirm in
# util.training_utils.
load(model,"weights/attention/att3d_100e_adam_b4_dice_c")
print("Attention3DUNet with Dice Loss - (ConvT version)")
# NOTE(review): assumes get_all_metrics_2 puts the model in eval mode and
# disables autograd — confirm.
get_all_metrics_2(val_dl, model, DEVICE)

Attention3DUNet with Dice Loss - (ConvT version)
 Accuracy (TC,ET,WT): 
 99.3544 , 99.6905, 99.0364
Dice Score (TC,ET,WT): 
 0.7816533446311951 , 0.6932132840156555, 0.8522657155990601
IoU Score (TC,ET,WT): 
 0.6859501004219055 , 0.5824636220932007, 0.7657946348190308
 Precision (TC,ET,WT): 
 --> 42.0900 , 82.9900, 16.1025
 Recall (TC,ET,WT): 
 --> 34.6530 , 78.0802, 14.8999
 F1-score (TC,ET,WT): 
 --> 38.0111 , 80.4603, 15.4779


In [17]:
# Evaluate the "bce_dice_c" checkpoint, reusing the up_sample=False
# Attention3UNet instance (only the weights change, presumably in place via
# `load`).
load(model,"weights/attention/att3d_100e_adam_b4_bce_dice_c")
print("Attention3DUNet with BCE-Dice Loss - (ConvT version)")
# NOTE(review): assumes get_all_metrics_2 puts the model in eval mode and
# disables autograd — confirm.
get_all_metrics_2(val_dl, model, DEVICE)

Attention3DUNet with BCE-Dice Loss - (ConvT version)
 Accuracy (TC,ET,WT): 
 99.2950 , 99.6922, 99.0070
Dice Score (TC,ET,WT): 
 0.7608307003974915 , 0.6907647252082825, 0.8438194394111633
IoU Score (TC,ET,WT): 
 0.6616765856742859 , 0.5782598853111267, 0.7573457956314087
 Precision (TC,ET,WT): 
 --> 42.0382 , 86.3042, 16.2116
 Recall (TC,ET,WT): 
 --> 34.6323 , 78.0815, 14.8955
 F1-score (TC,ET,WT): 
 --> 37.9776 , 81.9872, 15.5257


In [18]:
# Drop the reference to the previous network first so that clear() can actually
# reclaim its memory; otherwise the old and the new model coexist on the device
# until `model` is rebound.
model = None
clear()
# Attention variant built with up_sample=True (labelled "Upsampler version" in
# the evaluation cells that use it).
model = Attention3UNet(3, 3, features=[64, 128, 256, 512], up_sample=True).to(DEVICE)

In [19]:
# Evaluate the "dice" checkpoint with the up_sample=True Attention3UNet.
# `load` presumably restores the weights into `model` in place — confirm in
# util.training_utils.
load(model,"weights/attention/att3d_100e_adam_b4_dice")
print("Attention3DUNet with Dice Loss - (Upsampler version)")
# NOTE(review): assumes get_all_metrics_2 puts the model in eval mode and
# disables autograd — confirm.
get_all_metrics_2(val_dl, model, DEVICE)

Attention3DUNet with Dice Loss - (Upsampler version)
 Accuracy (TC,ET,WT): 
 99.3225 , 99.6896, 99.0118
Dice Score (TC,ET,WT): 
 0.7768449783325195 , 0.6793321967124939, 0.844255805015564
IoU Score (TC,ET,WT): 
 0.674619734287262 , 0.5670403242111206, 0.7549707889556885
 Precision (TC,ET,WT): 
 --> 42.1787 , 92.9405, 16.8040
 Recall (TC,ET,WT): 
 --> 34.6419 , 78.0795, 14.8962
 F1-score (TC,ET,WT): 
 --> 38.0406 , 84.8643, 15.7927


In [20]:
# Evaluate the "bce_dice" checkpoint, reusing the up_sample=True Attention3UNet
# instance (only the weights change, presumably in place via `load`).
load(model,"weights/attention/att3d_100e_adam_b4_bce_dice")
# The label must name the BCE-Dice run: the checkpoint loaded above is the
# bce_dice one, and the previous cell already reports the Dice-only run.
print("Attention3DUNet with BCE-Dice Loss - (Upsampler version)")
# NOTE(review): assumes get_all_metrics_2 puts the model in eval mode and
# disables autograd — confirm.
get_all_metrics_2(val_dl, model, DEVICE)

Attention3DUNet with BCE-Dice Loss - (Upsampler version)
 Accuracy (TC,ET,WT): 
 99.3204 , 99.6839, 98.9921
Dice Score (TC,ET,WT): 
 0.7743038535118103 , 0.6768757700920105, 0.8426200747489929
IoU Score (TC,ET,WT): 
 0.6734684109687805 , 0.5659209489822388, 0.7533087730407715
 Precision (TC,ET,WT): 
 --> 43.0446 , 89.8057, 16.7121
 Recall (TC,ET,WT): 
 --> 34.6411 , 78.0750, 14.8933
 F1-score (TC,ET,WT): 
 --> 38.3884 , 83.5305, 15.7503


In [None]:
# Drop the reference to the previous network first so that clear() can actually
# reclaim its memory; otherwise the old and the new model coexist on the device
# until `model` is rebound.
model = None
clear()
# Residual 3D U-Net built with its default options.
# NOTE(review): the two 3s are presumably input/output channel counts —
# confirm in models/Residual3DUNet.
model = Res3DUNet(3, 3).to(DEVICE)