In [1]:
import torch

In [2]:
from util.data import get_dl, get_train_ds, get_val_ds
from util.training_utils import train, check_accuracy, save, load
from models.Base3DUNet import Base3DUNet
from util.loss import DiceLoss

In [3]:
BATCH_SIZE = 4
EPOCHS = 10
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
LR = 0.0001
# DEVICE = 'cpu'

In [5]:
train_dl = get_dl(get_train_ds(), BATCH_SIZE,nw=2)
val_dl = get_dl(get_val_ds(), BATCH_SIZE, nw=2)

images: 221, masks: 221 
images: 148, masks: 148 


In [6]:
import gc
torch.cuda.empty_cache()
gc.collect()

89

In [7]:
model = Base3DUNet(3, 3, features=[64, 128, 256, 512]).to(DEVICE)
print(model)

Base3DUNet(
  (pooling): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (downs): ModuleList(
    (0): DoubleConv3D(
      (conv): Sequential(
        (0): Conv3d(3, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (4): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (1): DoubleConv3D(
      (conv): Sequential(
        (0): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1),

In [8]:

print(f"total parameters = {sum(p.numel() for p in model.parameters())}")
print(f"total learnable parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

total parameters = 90302147
total learnable parameters = 90302147


In [9]:
opt = torch.optim.Adam(model.parameters(), lr=LR)
# loss = torch.nn.BCEWithLogitsLoss # cannot be used because there's a lot of imbalance anyway, so it is better to combine it with dice
loss = DiceLoss()

In [10]:
print(DEVICE)


cuda


In [11]:
train(model, epochs=EPOCHS, training_loader=train_dl, loss_fn=loss, device=DEVICE, optimizer=opt)

At epoch [1/10]: 100%|██████████| 56/56 [01:38<00:00,  1.76s/it, loss=0.958]
At epoch [2/10]: 100%|██████████| 56/56 [01:38<00:00,  1.75s/it, loss=0.944]
At epoch [3/10]: 100%|██████████| 56/56 [01:37<00:00,  1.75s/it, loss=0.949]
At epoch [4/10]: 100%|██████████| 56/56 [01:37<00:00,  1.75s/it, loss=0.827]
At epoch [5/10]: 100%|██████████| 56/56 [01:37<00:00,  1.75s/it, loss=0.978]
At epoch [6/10]: 100%|██████████| 56/56 [01:37<00:00,  1.75s/it, loss=0.931]
At epoch [7/10]: 100%|██████████| 56/56 [01:37<00:00,  1.75s/it, loss=0.961]
At epoch [8/10]: 100%|██████████| 56/56 [01:37<00:00,  1.75s/it, loss=0.887]
At epoch [9/10]: 100%|██████████| 56/56 [01:37<00:00,  1.75s/it, loss=0.919]
At epoch [10/10]: 100%|██████████| 56/56 [01:37<00:00,  1.75s/it, loss=0.746]


In [12]:
check_accuracy(val_dl,model,DEVICE)

Results: 112503139/116391936 with accuracy 96.6589
Dice score: 0.47057220339775085


In [18]:
# saving sample
save(model,"weights/3D/t2")

In [19]:
# loading sample
model1 = Base3DUNet(3,3)
load(model,"weights/3D/t1")