## Imports

In [19]:
import os
import time
from time import sleep

import torch
from torch.utils.data import DataLoader, random_split
from itkwidgets import view

import model as md
import dataset as dtst

In [2]:
# Pick the compute device and make newly created tensors default to float64 on it.
# The default tensor type now follows `device`; previously it was forced to CUDA even
# when `device` had fallen back to CPU, which made this cell crash on CPU-only hosts.
# NOTE: later cells still move models/batches to CUDA explicitly, so a GPU is required
# for training and inference.
torch.cuda.empty_cache()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_tensor_type(
    'torch.cuda.DoubleTensor' if device.type == 'cuda' else 'torch.DoubleTensor')
print(f"{device = }")

device = device(type='cuda')


## Model and Optimizer

In [3]:
# Build the 3-D U-Net; nn.Module.train() returns the module itself, so it can be chained.
model = md.UNet64().train()
# Last expression of the cell: the notebook renders the module's architecture.
model.cuda()

# params 464849, # conv layers 30


UNet64(
  (conv1_8): Conv3d(1, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (conv8_8): Conv3d(8, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (conv8_16): Conv3d(8, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (conv16_16): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (conv16_32): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (conv32_32): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (conv32_64): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (conv64_64): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (conv64_32): Conv3d(64, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (conv32_16): Conv3d(32, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (conv16_8): Conv3d(16, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
  (conv8_1): C

In [4]:
# Replicate the model across the visible GPUs (capped at the 4 this notebook was written
# for) instead of hard-coding device ids 0-3, which fails on hosts with fewer GPUs.
n_gpus = min(4, torch.cuda.device_count())
net = torch.nn.DataParallel(model, device_ids=list(range(n_gpus)))
# Restore the weights written by the training cell (saved from the DataParallel wrapper,
# so the keys match `net`, not the bare `model`).
net.load_state_dict(torch.load("UNet64_cube_sz_64_p_05.pt"))
criterion = torch.nn.BCEWithLogitsLoss()
# NOTE(review): torch's Adadelta default lr is 1.0; 0.005 scales updates down heavily.
learning_rate = 0.005
# `net` wraps `model`, so these are exactly the parameters `net` trains.
# NOTE: only net.state_dict() is checkpointed, so Adadelta's running averages restart
# from zero every time training is resumed.
optimizer = torch.optim.Adadelta(model.parameters(), lr=learning_rate)

## Datasets and Dataloaders

In [5]:
DATA_DIR = "/home/sci/kyle.anderson/lymph_nodes/Dataset"
# Fixed seed so the train/test partition is identical across kernel restarts. Unseeded,
# each restart reshuffled samples between train_set and test_set, so resuming from the
# checkpoint trained on samples that later landed in the test set.
# NOTE: this split still differs from whatever unseeded split earlier sessions used.
SPLIT_SEED = 0

dataset = dtst.UnetDataset(root_dir=DATA_DIR, patch_size=64, min_probability=0.5)
n_train = int(len(dataset) * 0.8)
# A CUDA generator is presumably required because the default tensor type is CUDA,
# so random_split's randperm runs on the GPU.
split_generator = torch.Generator('cuda').manual_seed(SPLIT_SEED)
train_set, test_set = random_split(dataset,
                                   [n_train, len(dataset) - n_train],
                                   generator=split_generator)

batch_size = 4
# NOTE(review): shuffle=False, presumably because the sampler's CPU generator conflicts
# with the CUDA default tensor type — confirm before enabling shuffling.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)
val_loader = DataLoader(test_set, batch_size=1, shuffle=False, num_workers=0)

## Training

In [6]:
def train_loop(train_loader, model, loss_fn, optimizer, log_file=None, device="cuda"):
    """Train ``model`` for one epoch and return the summed per-batch loss.

    Args:
        train_loader: iterable of dict batches holding ``"img"`` and ``"mask"`` tensors;
            it must expose ``.dataset`` (its length is used for progress reporting).
        model: module to train; it is switched to training mode first.
        loss_fn: callable ``loss_fn(pred, truth)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
        log_file: optional path; one line with a timestamp and the epoch loss is appended.
        device: device each batch's tensors are moved to (default ``"cuda"``).

    Returns:
        float: the SUM (not the mean) of this epoch's batch losses, which is also the
        value written to ``log_file``.
    """
    # Inference/evaluation cells may have left the model in eval mode.
    model.train()
    epoch_loss = 0.0
    size = len(train_loader.dataset)
    # Roughly 20 progress lines per epoch; max() avoids modulo-by-zero when size < 20.
    log_every = max(1, size // 20)
    seen = 0  # samples processed before the current batch (correct for a short last batch)

    for idx, batch in enumerate(train_loader):
        optimizer.zero_grad()

        # Force grad mode on even if the caller is inside a no_grad() context.
        with torch.set_grad_enabled(True):
            input_data = batch["img"].to(device)
            truth = batch["mask"].to(device)
            pred = model(input_data)

            loss = loss_fn(pred, truth)
            loss.backward()

        optimizer.step()

        batch_loss = loss.item()
        epoch_loss += batch_loss

        if idx % log_every == 0:
            print(f"Loss: {batch_loss:06.5f}\t[{seen:3d}/{size:3d}]")
        seen += len(batch["img"])

    if log_file is not None:
        time_stamp = time.ctime(time.time())
        with open(log_file, "a") as f:
            f.write(f"Time: {time_stamp}\tLoss = {epoch_loss:.5f}\n")

    return epoch_loss

In [7]:
def test_loop(test_loader, model, loss_fn, log_file=None):
    size = len(test_loader.dataset)
    num_batches = len(test_loader)
    test_loss = 0.0
    
    with torch.no_grad():
        for batch in test_loader:
            pred = model(batch["img"].cuda())
            test_loss += loss_fn(pred, batch["mask"].cuda()).item()
            
    avg_loss = test_loss / num_batches
    print(f"Average test loss: {avg_loss:.5f}")
    
    if log_file is not None:
        time_stamp = time.ctime(time.time())
        with open(log_file, "a") as f:
            f.write(f"Time: {time_stamp}\tAverage loss = {avg_loss:.5f}\n")

# Only run the following section if you are training the model! It resumes from the saved checkpoint and overwrites it after every epoch.

In [10]:
CHECKPOINT_PATH = "UNet64_cube_sz_64_p_05.pt"
TRAIN_LOG = "UNet64_cube_sz_64_p_05_train_loss.txt"
TEST_LOG = "UNet64_cube_sz_64_p_05_test_loss.txt"

# Resume training: start_epoch should equal the number of epochs already in the
# checkpoint. Epochs are 0-based here and printed 1-based.
epoch_losses = []
start_epoch, end_epoch = 40, 50

for epoch in range(start_epoch, end_epoch):
    print(f"Epoch {epoch+1} of {end_epoch}\n{'-'*40}")
    start_time = time.time()
    epoch_loss = train_loop(train_loader, net, criterion, optimizer, TRAIN_LOG)
    epoch_losses.append(epoch_loss)
    test_loop(test_loader, net, criterion, TEST_LOG)
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Epoch elapsed time: {int(elapsed_time//3600):2d}h:{int((elapsed_time%3600)//60):2d}m:{elapsed_time%60:6.3f}s")

    # Write to a temporary file, then atomically swap it in, so an interrupt during
    # torch.save cannot corrupt the only copy of the trained weights.
    tmp_path = CHECKPOINT_PATH + ".tmp"
    torch.save(net.state_dict(), tmp_path)
    os.replace(tmp_path, CHECKPOINT_PATH)

Epoch 41 of 50
----------------------------------------
Loss: 0.10450	[  0/332]
Loss: 0.11286	[ 64/332]
Loss: 0.10182	[128/332]
Loss: 0.10775	[192/332]
Loss: 0.10018	[256/332]
Loss: 0.09956	[320/332]
Average test loss: 0.11296
Epoch elapsed time:  0h:38m: 7.673s
Epoch 42 of 50
----------------------------------------
Loss: 0.10283	[  0/332]
Loss: 0.11106	[ 64/332]
Loss: 0.10017	[128/332]
Loss: 0.10617	[192/332]
Loss: 0.09837	[256/332]
Loss: 0.09793	[320/332]
Average test loss: 0.11135
Epoch elapsed time:  0h:37m:43.960s
Epoch 43 of 50
----------------------------------------
Loss: 0.10118	[  0/332]
Loss: 0.10930	[ 64/332]
Loss: 0.09855	[128/332]
Loss: 0.10463	[192/332]
Loss: 0.09659	[256/332]
Loss: 0.09633	[320/332]
Average test loss: 0.10976
Epoch elapsed time:  0h:37m:25.144s
Epoch 44 of 50
----------------------------------------
Loss: 0.09956	[  0/332]
Loss: 0.10758	[ 64/332]
Loss: 0.09696	[128/332]
Loss: 0.10312	[192/332]
Loss: 0.09486	[256/332]
Loss: 0.09478	[320/332]
Average tes

In [37]:
def test_inference(dataset, model, num_samples):
    predictions = {}
    model.eval()
    with torch.no_grad():
        for sample in range(num_samples):
            pred = model(dataset[sample]["img"].cuda().unsqueeze(0))
            predictions[dataset[sample]["name"]] = pred.cpu().detach().numpy().squeeze()
    return predictions

In [38]:
# Qualitative check on the first four samples. NOTE: these come from the full `dataset`,
# which also contains training samples; `test_set` holds the held-out split.
predictions = test_inference(dataset=dataset, model=net, num_samples=4)

In [46]:
# Show the first prediction next to dataset[0]'s ground-truth mask in the ITK viewer.
# The prediction is presumably raw logits (training uses BCEWithLogitsLoss), so its
# values are not on the mask's 0/1 scale.
# NOTE(review): dataset[0] is fetched again here; if UnetDataset samples patches
# randomly, this mask may not match the patch behind preds_list[0] — TODO confirm.
preds_list = [*predictions.values()]
mask_0 = dataset[0]["mask"]
first_pred = preds_list[0]
view(first_pred, mask_0.squeeze())

Viewer(geometries=[], gradient_opacity=0.22, interpolation=False, point_sets=[], rendered_image=<itk.itkImageP…