In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from unet3d import UNet3D
import time
import numpy as np

In [2]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# 3D U-Net with one input channel and two output classes; `block_channels` lists the
# channel width of each encoder level (see the "in and out" lines printed below).
model = UNet3D(in_channels=1, num_classes=2, block_channels=[4, 8, 16, 32]).to(device)
# Alternative starting points: resume from a previously saved whole-model checkpoint.
#model_2 = torch.load("model_kidney_1_dense_5epochs.pth").to(device)
#model = torch.load("model_kidney_1_dense_weighted_8epochs.pth").to(device)

in and out: (1, 4)
in and out: (4, 8)
in and out: (8, 16)
in and out: (16, 32)
in, res and out: (32, 16, 16)
in, res and out: (16, 8, 8)
in, res and out: (8, 4, 4)
output numbner of classes: 2


In [3]:
### data size
# Shape (three axis lengths, in voxels) of each available volume.

kidney_1_dense = (2279, 1303, 912)       # dense, all
kidney_1_voi_p = (1397, 1928, 1928)      # high-resolution subset
kidney_2_sparse_65 = (2217, 1041, 1511)  # sparsely segmented, 65%
kidney_3_sparse_85 = (1035, 1706, 1510)  # sparsely segmented, 85%
kidney_3_dense_p_l = (501, 1706, 1510)   # densely segmented subset, label ONLY

# Memory footprint if every voxel is stored as a 4-byte float32, derived from the
# shape tuples above instead of repeating the numbers.
_dense_voxels = kidney_1_dense[0] * kidney_1_dense[1] * kidney_1_dense[2]
_voi_voxels = kidney_1_voi_p[0] * kidney_1_voi_p[1] * kidney_1_voi_p[2]
print(f'size of kidney_1_dense: {_dense_voxels*4/(1e9):.4} GB as float32')
print(f'size of kidney_1_voi: {_voi_voxels*4/(1e9):.4} GB as float32')

size of kidney_1_dense: 10.83 GB as float32
size of kidney_1_voi: 20.77 GB as float32


# Dataset

In [4]:
# dataset class

class Kidney3D(Dataset):
    """Dataset of cubic patches cut from one kidney volume and its label volume.

    Both ``.npy`` volumes are read with ``np.load`` and padded so that a regular grid
    of ``size``-voxel cubes, spaced ``stride`` voxels apart, covers every raw voxel.
    The image is padded by repeating its edge values, the label with zeros.
    Item ``idx`` is the ``idx``-th cube of that grid, the last axis varying fastest.
    """

    def __init__(self, kidney_path, label_path, size=128, stride=-1):
        """Load both volumes, pad them, and lay out the patch grid.

        Args:
            kidney_path: path of the ``.npy`` image volume (3-D array).
            label_path: path of the ``.npy`` label volume; must match the image shape.
            size: edge length of the cubic patches, in voxels.
            stride: distance between neighbouring patches; ``-1`` means ``size``
                (non-overlapping patches).

        Raises:
            ValueError: if ``size`` or ``stride`` is not positive, if the image is not
                3-D, or if the image and label shapes differ.
        """
        if stride == -1:
            stride = size
        if size <= 0 or stride <= 0:
            raise ValueError(f'size and stride must be positive, got size={size}, stride={stride}')

        self.kidney = np.load(kidney_path)
        self.label = np.load(label_path)
        if self.kidney.ndim != 3:
            raise ValueError(f'expected a 3-D volume, got shape {self.kidney.shape}')
        if self.kidney.shape != self.label.shape:
            raise ValueError(
                f'image shape {self.kidney.shape} does not match label shape {self.label.shape}'
            )

        self.stride = stride
        self.size = size

        # pad the data to cover all raw data
        # not necessary for training but good for inference
        pad_dim = [self._pad_widths(extent, size, stride) for extent in self.kidney.shape]
        self.kidney = np.pad(self.kidney, pad_dim, 'edge')
        self.label = np.pad(self.label, pad_dim, 'constant', constant_values=0)

        # Number of patch positions along each axis, and in total.
        self.n_h = (self.kidney.shape[0] - size) // stride + 1
        self.n_w = (self.kidney.shape[1] - size) // stride + 1
        self.n_l = (self.kidney.shape[2] - size) // stride + 1
        self.n = self.n_h * self.n_w * self.n_l

    @staticmethod
    def _pad_widths(extent, size, stride):
        """Return the (before, after) padding that makes one axis fit the patch grid.

        An axis that is already aligned gets no padding (previously it received a full
        extra stride), and an axis shorter than ``size`` is padded up to exactly one
        patch (previously it could end up shorter than ``size``, giving zero patches).
        """
        if extent <= size:
            total = size - extent
        else:
            total = -(extent - size) % stride
        before = total // 2
        return before, total - before

    def __len__(self):
        """Return the number of patches in the grid."""
        return self.n

    def __getitem__(self, idx):
        """Return ``(data, label)`` for patch ``idx``.

        ``data`` is a float tensor with a leading channel axis of length 1, ``label``
        a long tensor without channel axis; both are ``size`` voxels along each
        spatial axis.  Negative indices count from the end.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.  Without
                this, plain ``for item in dataset`` iteration would never terminate.
        """
        if idx < 0:
            idx = idx + self.n
        if not 0 <= idx < self.n:
            raise IndexError(f'patch index out of range for dataset of length {self.n}')

        # Unravel the flat index into grid coordinates (last axis fastest).
        per_slab = self.n_w * self.n_l
        h = idx // per_slab
        w = idx % per_slab // self.n_l
        d = idx % per_slab % self.n_l

        window = (
            slice(h * self.stride, h * self.stride + self.size),
            slice(w * self.stride, w * self.stride + self.size),
            slice(d * self.stride, d * self.stride + self.size),
        )

        data = torch.from_numpy(self.kidney[window])
        data = torch.unsqueeze(data, 0).to(torch.float)

        label = torch.from_numpy(self.label[window]).to(torch.long)
        return data, label

In [5]:
# unbalanced issue of the data

# label_list = []
# for i in range(ds.n):
#     label_list.append(ds.__getitem__(i)[1].sum().item())

# print(len(label_list))
# print(sum([1 for i in label_list if i > 0]))

# pos = train_ds.label.sum()
# total = train_ds.label.shape[0] * train_ds.label.shape[1] * train_ds.label.shape[2]
# print(f'{pos/total:.5f}')

When size = 128, 509 of 1190 patches are non-empty.  
When size = 128 and stride = 20, 120291 of 254880 patches are non-empty.  
When size = 512, 71 of 120 patches are non-empty.  
Voxel-wise, a fraction of 0.00503 of the labels is positive.

# Loss, optimizer, accuracy, train & test functions

In [6]:
# TensorBoard logging: train() writes its per-batch scalars through this
# module-level `writer`, so this cell must run before any training cell.
from torch.utils.tensorboard import SummaryWriter
# Presumably uncommented when re-running this cell, to close the previous writer first.
#writer.close()
writer = SummaryWriter('runs/unet_3d')
# Shell command for viewing the logs:
#tensorboard --logdir='/home/ziyu/Projects/blood-vessel-segmentation/runs'

2024-01-05 10:02:35.008887: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [7]:
# Weighted cross-entropy: class 1 counts 20x as much as class 0, to counter the
# class imbalance recorded in the notes above (about 0.5% positive voxels).
loss_fn = nn.CrossEntropyLoss(
    weight=torch.tensor([1.0, 20.0]).to(device),
)
# Adam over all model parameters, with the optimizer's default hyper-parameters.
optimizer = torch.optim.Adam(model.parameters())

In [8]:
def dice_accuracy_batch(pred, y):
    """Return the raw counts needed to compute a Dice score for one batch.

    A batch may contain no positive label at all, so the ratio itself is left to the
    caller, which can accumulate the counts over many batches first.

    Args:
        pred: model output with the class scores along dimension 1.
        y: integer label tensor, same shape as ``pred`` without the class dimension.

    Returns:
        Tuple ``(overlap, pred_positive, label_positive)`` of Python ints: the number
        of elements that are class 1 in both prediction and label, in the prediction,
        and in the label, respectively.
    """
    # Compare boolean masks rather than and-ing raw class indices bitwise: the bitwise
    # form miscounts as soon as an index other than 0/1 appears (e.g. 3 & 1 == 1).
    pred_mask = pred.argmax(1) == 1
    label_mask = y == 1

    overlap = (pred_mask & label_mask).sum().item()
    pred_positive = pred_mask.sum().item()
    label_positive = label_mask.sum().item()

    return overlap, pred_positive, label_positive

In [9]:
def train(dataloader, model, loss_fn, optimizer, print_gap=20, track_accuracy=True, epoch=0):
    """Train ``model`` for one pass over ``dataloader``.

    Uses two notebook-level globals: ``device`` (every batch is moved there) and
    ``writer`` (receives the per-batch ``'loss'`` and ``'Batch_Dice_Accuracy'``
    scalars at step ``batch + epoch * len(dataloader)``).

    Args:
        dataloader: yields ``(X, y)`` batches.
        model: network to optimise; switched to training mode.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: optimizer bound to ``model``'s parameters.
        print_gap: print a progress line every ``print_gap`` batches.
        track_accuracy: also accumulate and report Dice counts.
        epoch: zero-based epoch number, used only to offset the logging step.

    Returns:
        None.  The epoch's average loss and, if tracked, Dice score are printed.  A
        batch without any positive prediction or label is logged with the sentinel
        ``-1``; an epoch without any is reported as ``nan``.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    running_loss, overlap, pred_positive, label_positive = 0, 0, 0, 0
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        global_step = batch + epoch * num_batches

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation. Gradients are cleared *before* the backward pass so that
        # anything left over from an interrupted run cannot leak into this update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss = loss.item()
        writer.add_scalar('loss', loss, global_step)
        running_loss += loss

        if track_accuracy:
            overlap_batch, pred_positive_batch, label_positive_batch = dice_accuracy_batch(pred, y)
            overlap += overlap_batch
            pred_positive += pred_positive_batch
            label_positive += label_positive_batch
            if (pred_positive_batch + label_positive_batch) == 0:
                # Dice is undefined for this batch; -1 marks it in the logs.
                batch_dice_accuracy = -1
            else:
                batch_dice_accuracy = 2*overlap_batch / (pred_positive_batch + label_positive_batch)
            writer.add_scalar('Batch_Dice_Accuracy', batch_dice_accuracy, global_step)

        if batch % print_gap == 0:
            current = (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
            if track_accuracy: print(f"batch Dice Accuracy: {batch_dice_accuracy:.2%}")

    # Guard both ratios: an empty loader or an epoch without a single positive
    # prediction/label would otherwise raise ZeroDivisionError after all the work.
    avg_loss = running_loss / num_batches if num_batches else float('nan')
    print(f"Epoch Avg loss: {avg_loss:>8f}")
    if track_accuracy:
        positives = pred_positive + label_positive
        dice_accuracy = 2*overlap / positives if positives else float('nan')
        print(f"Epoch Dice Accuracy: {dice_accuracy:.2%}")


In [10]:
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, overlap, pred_positive, label_positive = 0,0,0,0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            overlap_batch, pred_positive_batch, label_positive_batch = dice_accuracy_batch(pred, y)
            overlap += overlap_batch 
            pred_positive += pred_positive_batch
            label_positive += label_positive_batch
    test_loss /= num_batches
    dice_accuracy = 2*overlap / (pred_positive + label_positive)
    print(f"Dice Accuracy: {dice_accuracy:.2%}, Avg loss: {test_loss:>8f} \n")

# Training

In [11]:
# create training datasets
image_size = 128
batch_size = 8


def _make_train_loader(images_path, labels_path):
    """Build one kidney's overlapping-patch dataset (stride 64), print its length,
    and return it together with a shuffled DataLoader over it."""
    dataset = Kidney3D(images_path, labels_path, size=image_size, stride=64)
    print(dataset.__len__())
    return dataset, DataLoader(dataset, batch_size=batch_size, shuffle=True)


kidney_1_path = '/home/ziyu/Projects/blood-vessel-segmentation/data/processed/conv/kidney_1_dense_images.npy'
label_1_path = '/home/ziyu/Projects/blood-vessel-segmentation/data/processed/conv/kidney_1_dense_labels.npy'
train_ds_1, train_dataloader_1 = _make_train_loader(kidney_1_path, label_1_path)

kidney_2_path = '/home/ziyu/Projects/blood-vessel-segmentation/data/processed/conv/kidney_2_images.npy'
label_2_path = '/home/ziyu/Projects/blood-vessel-segmentation/data/processed/conv/kidney_2_labels.npy'
train_ds_2, train_dataloader_2 = _make_train_loader(kidney_2_path, label_2_path)

kidney_3_path = '/home/ziyu/Projects/blood-vessel-segmentation/data/processed/conv/kidney_3_sparse_images.npy'
label_3_path = '/home/ziyu/Projects/blood-vessel-segmentation/data/processed/conv/kidney_3_sparse_labels.npy'
train_ds_3, train_dataloader_3 = _make_train_loader(kidney_3_path, label_3_path)

9800
12512
9568


In [12]:
#torch.cuda.empty_cache()

In [13]:
# from sparse to dense, kidney 2, kidney 3, kidney 1, each 10 epochs
# The same model and optimizer are trained on each loader in turn.
# NOTE(review): `epoch=t` restarts at 0 for every loader, and train() derives its
# TensorBoard step from `batch + epoch * num_batches`, so the three stages are
# logged over overlapping step ranges.
# NOTE(review): `train_dataloader` stays bound after the loop and is read again by
# a later cell (`test(train_dataloader, model_2, loss_fn)`).
epochs = 10
for train_dataloader in (train_dataloader_2, train_dataloader_3, train_dataloader_1):
    for t in range(epochs):
        start_time = time.time()  # wall-clock start for the per-epoch timing printed below
        print(f"Epoch {t+1}\n-------------------------------")
        train(train_dataloader, model, loss_fn, optimizer, print_gap=1000, track_accuracy=True, epoch=t)
        print(f"Epoch {t+1} training done, time: {time.time() - start_time:.2f}")
    print("Done!")  # printed once per loader

Epoch 1
-------------------------------
loss: 0.621270  [    8/12512]
batch Dice Accuracy: 1.89%
loss: 0.035701  [ 8008/12512]
batch Dice Accuracy: 38.47%
Epoch Avg loss: 0.117618
Epoch Dice Accuracy: 37.08%
Epoch 1 training done, time: 946.52
Epoch 2
-------------------------------
loss: 0.025255  [    8/12512]
batch Dice Accuracy: 51.03%
loss: 0.019502  [ 8008/12512]
batch Dice Accuracy: 31.50%
Epoch Avg loss: 0.032898
Epoch Dice Accuracy: 56.48%
Epoch 2 training done, time: 961.29
Epoch 3
-------------------------------
loss: 0.024734  [    8/12512]
batch Dice Accuracy: 37.47%
loss: 0.011250  [ 8008/12512]
batch Dice Accuracy: 86.11%
Epoch Avg loss: 0.022147
Epoch Dice Accuracy: 64.39%
Epoch 3 training done, time: 960.94
Epoch 4
-------------------------------
loss: 0.012362  [    8/12512]
batch Dice Accuracy: 23.87%
loss: 0.031201  [ 8008/12512]
batch Dice Accuracy: 97.68%
Epoch Avg loss: 0.018255
Epoch Dice Accuracy: 69.60%
Epoch 4 training done, time: 960.91
Epoch 5
-------------

In [14]:
# Evaluation datasets: stride=-1 makes Kidney3D use stride == size, i.e. non-overlapping
# patches. The loaders are not shuffled: evaluation gains nothing from a random order,
# and a fixed order keeps repeated runs directly comparable.
test_ds_1 = Kidney3D(kidney_1_path, label_1_path, size=image_size, stride=-1)
test_dataloader_1 = DataLoader(test_ds_1, batch_size=batch_size, shuffle=False)

test_ds_2 = Kidney3D(kidney_2_path, label_2_path, size=image_size, stride=-1)
test_dataloader_2 = DataLoader(test_ds_2, batch_size=batch_size, shuffle=False)

test_ds_3 = Kidney3D(kidney_3_path, label_3_path, size=image_size, stride=-1)
test_dataloader_3 = DataLoader(test_ds_3, batch_size=batch_size, shuffle=False)

In [15]:
# Evaluation with image size 128 and the "4..." network
# (presumably block_channels starting at 4 — TODO confirm).
# Order: kidney 2, kidney 3, kidney 1.
for _eval_loader in (test_dataloader_2, test_dataloader_3, test_dataloader_1):
    test(_eval_loader, model, loss_fn)

Dice Accuracy: 38.43%, Avg loss: 0.171806 

Dice Accuracy: 0.01%, Avg loss: 0.481434 

Dice Accuracy: 82.27%, Avg loss: 0.007162 



In [15]:
# Evaluation with image size 128 and the "8.." network
# (presumably block_channels starting at 8 — TODO confirm).
# Order: kidney 2, kidney 3, kidney 1.
for _eval_loader in (test_dataloader_2, test_dataloader_3, test_dataloader_1):
    test(_eval_loader, model, loss_fn)

Dice Accuracy: 65.59%, Avg loss: 0.138016 

None
Dice Accuracy: 0.01%, Avg loss: 0.382423 

None
Dice Accuracy: 88.28%, Avg loss: 0.005435 

None


In [16]:
# Re-run the evaluation on the kidney 2 test patches only.
test(test_dataloader_2, model, loss_fn)

Dice Accuracy: 65.59%, Avg loss: 0.138883 



In [17]:
# Re-run the evaluation on the kidney 3 test patches only.
test(test_dataloader_3, model, loss_fn)

Dice Accuracy: 0.01%, Avg loss: 0.380592 

None


In [None]:
# The model performs very badly on kidney_3, so train one more epoch on kidney_3.
# NOTE(review): `epoch=t` is 0 here, so train() logs this extra epoch over the same
# TensorBoard step range as the first epoch of earlier runs.
for t in range(1):
    start_time = time.time()  # wall-clock start for the timing printed below
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader_3, model, loss_fn, optimizer, print_gap=1000, track_accuracy=True, epoch=t)
    print(f"Epoch {t+1} training done, time: {time.time() - start_time:.2f}")
print("Done!")

In [20]:
# After the extra training on kidney_3, check how the model does on every kidney.
# Order: kidney 2, kidney 3, kidney 1.
for _eval_loader in (test_dataloader_2, test_dataloader_3, test_dataloader_1):
    test(_eval_loader, model, loss_fn)

Dice Accuracy: 2.70%, Avg loss: 1.218644 

None
Dice Accuracy: 89.70%, Avg loss: 0.006915 

None
Dice Accuracy: 36.56%, Avg loss: 0.137333 

None


In [20]:
# Close the TensorBoard SummaryWriter created earlier in the notebook.
writer.close()

# Test on another kidney

In [14]:
# create test dataset on kidney_3 (these paths load the *sparse* kidney_3 volume)
kidney_path = '/home/ziyu/Projects/blood-vessel-segmentation/data/processed/conv/kidney_3_sparse_images.npy'
label_path = '/home/ziyu/Projects/blood-vessel-segmentation/data/processed/conv/kidney_3_sparse_labels.npy'
# stride=-1 -> non-overlapping 128-voxel patches
test_ds = Kidney3D(kidney_path, label_path, size=128, stride=-1)
print(len(test_ds))
# Evaluation needs no shuffling; a fixed order keeps repeated runs comparable.
test_dataloader = DataLoader(test_ds, batch_size=4, shuffle=False)


1512


In [15]:
# Evaluate on whatever `test_dataloader` currently holds (the cell just above builds it).
test(test_dataloader, model, loss_fn)

Dice Accuracy: 0.01%, Avg loss: 0.091088 



In [17]:
# create test dataset on kidney_2 (these paths load the kidney_2 images/labels files)
kidney_path = '/home/ziyu/Projects/blood-vessel-segmentation/data/processed/conv/kidney_2_images.npy'
label_path = '/home/ziyu/Projects/blood-vessel-segmentation/data/processed/conv/kidney_2_labels.npy'
# stride=-1 -> non-overlapping 128-voxel patches
test_ds = Kidney3D(kidney_path, label_path, size=128, stride=-1)
print(len(test_ds))
# Evaluation needs no shuffling; a fixed order keeps repeated runs comparable.
test_dataloader = DataLoader(test_ds, batch_size=4, shuffle=False)

1944


In [18]:
# Evaluate on whatever `test_dataloader` currently holds (the cell just above builds it).
test(test_dataloader, model, loss_fn)

Dice Accuracy: 1.24%, Avg loss: 0.130242 



In [23]:
# NOTE(review): scratch arithmetic; the notebook does not say what 2000 and 3.5
# stand for — label the quantities or remove the cell.
2000 / 3.5

571.4285714285714

# Save & Load model

In [17]:
# Saves the whole model object, not just its state_dict.
# NOTE(review): a whole-model checkpoint is pickled, so loading it back needs the
# UNet3D class importable under the same module path — consider saving
# model.state_dict() instead.
torch.save(model, "model_kidney_231_64_10epochs_3_1epoch.pth")
# Alternative: load a previously saved whole-model checkpoint instead.
#model = torch.load("model_kidney_1_dense_weighted_8epochs.pth")

In [17]:
# Load an earlier whole-model checkpoint for comparison and move it to `device`.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code —
# only load checkpoints from a trusted source. Recent PyTorch releases may also
# require an explicit `weights_only=False` for whole-model files — confirm against
# the installed version.
model_2 = torch.load("model_kidney_1_dense_5epochs.pth").to(device)

In [18]:
# NOTE(review): `train_dataloader` is the loop variable left behind by the training
# loop cell (its last value there is train_dataloader_1), so this evaluates model_2
# on shuffled, overlapping *training* patches and fails on a fresh kernel unless
# that loop has been run first.
test(train_dataloader, model_2, loss_fn)

Dice Accuracy: 87.90%, Avg loss: 0.016385 

