In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from unet3d import UNet3D
import time
import os
import cv2

In [2]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs end-to-end (slowly) on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 3D U-Net taking a single-channel input and producing 2 output classes.
model = UNet3D(in_channels=1, num_classes=2, block_channels=[8, 16, 32, 64]).to(device)

in and out: (1, 8)
in and out: (8, 16)
in and out: (16, 32)
in and out: (32, 64)
in, res and out: (64, 32, 32)
in, res and out: (32, 16, 16)
in, res and out: (16, 8, 8)
output numbner of classes: 2


In [3]:
### data size

def volume_gb(shape, bytes_per_voxel=4):
    """Return the in-memory size, in GB (1e9 bytes), of a 3D volume.

    Args:
        shape: 3-tuple with the number of voxels along each axis.
        bytes_per_voxel: storage per voxel; 4 corresponds to float32.
    """
    return shape[0] * shape[1] * shape[2] * bytes_per_voxel / 1e9


# Shapes of the available volumes.
kidney_1_dense = (2279, 1303, 912)  # dense, all
kidney_1_voi_p = (1397, 1928, 1928)  # high resolution subset
kidney_2_sparse_65 = (2217, 1041, 1511)  # sparsely segmented, 65%
kidney_3_sparse_85 = (1035, 1706, 1510)  # sparsely segmented, 85%
kidney_3_dense_p_l = (501, 1706, 1510)  # dense segmented subset, label ONLY

# Derive the sizes from the tuples above so the numbers cannot drift apart.
print(f'size of kidney_1_dense: {volume_gb(kidney_1_dense):.4} GB as float32')
print(f'size of kidney_1_voi: {volume_gb(kidney_1_voi_p):.4} GB as float32')

size of kidney_1_dense: 10.83 GB as float32
size of kidney_1_voi: 20.77 GB as float32


In [88]:
# dataset class
import numpy as np
class Kidney3D(Dataset):
    """Dataset of cubic patches cut from one kidney volume and its label volume.

    Both volumes are loaded with ``np.load`` and padded so that a sliding
    window of edge ``size`` moving by ``stride`` covers every original voxel.
    The image is padded by repeating its edge values, the label with zeros.

    Attributes:
        kidney, label: the padded 3D numpy arrays.
        size, stride: window edge length and step, in voxels.
        n_h, n_w, n_l: number of window positions along axes 0, 1 and 2.
        n: total number of patches (``n_h * n_w * n_l``).
    """
    def __init__(self, kidney_path, label_path, size=128, stride=-1):
        """Load the volumes and pre-compute the patch grid.

        Args:
            kidney_path: path to the ``.npy`` image volume.
            label_path: path to the ``.npy`` label volume (same shape).
            size: edge length of each cubic patch.
            stride: step between neighbouring patches; ``-1`` means
                ``stride = size`` (non-overlapping patches).

        Raises:
            ValueError: if ``size``/``stride`` are not positive, or if the
                two volumes are not 3D arrays of identical shape.
        """
        if stride == -1:
            stride = size
        if size <= 0 or stride <= 0:
            raise ValueError(f'size and stride must be positive, got size={size}, stride={stride}')
        self.stride = stride
        self.size = size

        self.kidney = np.load(kidney_path)
        self.label = np.load(label_path)
        if self.kidney.ndim != 3 or self.kidney.shape != self.label.shape:
            raise ValueError(
                f'expected 3D volumes of identical shape, got image {self.kidney.shape} '
                f'and label {self.label.shape}')

        # Pad each axis up to the smallest extent of the form size + k*stride
        # (k >= 0) that is not smaller than the raw extent, so the windows tile
        # the whole volume. Not necessary for training but needed for
        # inference. An axis that already fits gets no padding at all.
        pad_dim = []
        for extent in self.kidney.shape:
            if extent <= size:
                target = size
            else:
                steps = -(-(extent - size) // stride)  # ceil division
                target = size + steps * stride
            comp = target - extent
            # Split the padding evenly; the odd voxel goes to the far side.
            pad_dim.append((comp // 2, comp - comp // 2))
        self.kidney = np.pad(self.kidney, pad_dim, 'edge')
        self.label = np.pad(self.label, pad_dim, 'constant', constant_values=0)

        self.n_h = (self.kidney.shape[0] - size) // stride + 1
        self.n_w = (self.kidney.shape[1] - size) // stride + 1
        self.n_l = (self.kidney.shape[2] - size) // stride + 1
        self.n = self.n_h * self.n_w * self.n_l

    def __len__(self):
        """Return the total number of patches."""
        return self.n

    def __getitem__(self, idx):
        """Return the ``idx``-th patch as ``(data, label)``.

        Patches are ordered with axis 2 varying fastest, then axis 1, then
        axis 0. Negative indices count from the end.

        Returns:
            data: float tensor of shape ``(1, size, size, size)``.
            label: long tensor of shape ``(size, size, size)``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
                Raising here is also what lets a plain ``for`` loop over the
                dataset terminate.
        """
        if idx < 0:
            idx += self.n
        if not 0 <= idx < self.n:
            raise IndexError(f'patch index out of range for dataset of length {self.n}')

        # Decompose the flat index into grid coordinates (h, w, l).
        h, rest = divmod(idx, self.n_w * self.n_l)
        w, l = divmod(rest, self.n_l)
        window = tuple(slice(k * self.stride, k * self.stride + self.size) for k in (h, w, l))

        data = torch.from_numpy(self.kidney[window])
        data = torch.unsqueeze(data, 0).to(torch.float)  # add the channel axis
        label = torch.from_numpy(self.label[window]).to(torch.long)
        return data, label



In [5]:
# Build the dataset from the kidney_1_dense files using 128-voxel patches taken
# every 64 voxels, then display how many patches it yields.
kidney_path = '/home/ziyu/Projects/blood-vessel-segmentation/data/processed/conv/kidney_1_dense_images.npy'
label_path = '/home/ziyu/Projects/blood-vessel-segmentation/data/processed/conv/kidney_1_dense_labels.npy'
train_ds = Kidney3D(kidney_path=kidney_path, label_path=label_path, size=128, stride=64)
len(train_ds)

8398

In [91]:
# Rebuild the dataset with the default stride (-1, i.e. stride == size, so the
# 128-voxel patches do not overlap) and display the resulting patch count.
kidney_path = '/home/ziyu/Projects/blood-vessel-segmentation/data/processed/conv/kidney_1_dense_images.npy'
label_path = '/home/ziyu/Projects/blood-vessel-segmentation/data/processed/conv/kidney_1_dense_labels.npy'
train_ds = Kidney3D(kidney_path=kidney_path, label_path=label_path, size=128, stride=-1)
len(train_ds)

1584

In [92]:
# Class imbalance of the data.
#
# Voxel-wise ratio: sum of the (padded) label volume divided by its total
# number of voxels. A per-patch count of non-empty patches can be obtained by
# summing the label of every dataset item instead.
pos = train_ds.label.sum()
total = train_ds.label.size
print(f'{pos/total:.5f}')

0.00503


When size = 128, 509 of 1190 patches are non-empty  
When size = 128 and stride = 20, 120291 of 254880 patches are non-empty  
When size = 512, 71 of 120 patches are non-empty  
Voxel-wise, a fraction of 0.00503 (about 0.5%) of the data is positive

In [8]:
# Serve the patches in shuffled mini-batches of 4.
train_dataloader = DataLoader(dataset=train_ds, batch_size=4, shuffle=True)

In [9]:
# Weighted cross-entropy: class 0 gets weight 0.2 and class 1 gets weight 0.8.
loss_fn = nn.CrossEntropyLoss(weight=torch.tensor([0.2, 0.8], device=device))
# Adam with its default hyper-parameters over all model parameters.
optimizer = torch.optim.Adam(params=model.parameters())

In [10]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard writer shared by the training code; logs go under runs/unet_3d.
writer = SummaryWriter(log_dir='runs/unet_3d')

# Logging the model graph is currently disabled:
# X, y = next(iter(train_dataloader))
# writer.add_graph(model, X)
# writer.close()

2023-12-28 05:06:00.118488: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [11]:
def dice_accuracy_batch(pred, y):
    """Return the raw counts needed to compute a Dice score for class 1.

    Since one batch may not contain any positive label, the counts are
    returned instead of a ratio so the caller can accumulate them and compute
    the score at the end.

    Args:
        pred: per-class scores with the class axis at dimension 1.
        y: integer class labels with the same shape as ``pred.argmax(1)``.

    Returns:
        Tuple ``(overlap, pred_positive, label_positive)`` of Python ints:
        voxels that are class 1 in both prediction and label, voxels
        predicted as class 1, and voxels labelled as class 1.
    """
    # Work on boolean masks so the counts stay correct even if a label or a
    # predicted class index is something other than 0/1 (a bitwise AND of raw
    # class indices is only meaningful for 0/1 values).
    pred_mask = pred.argmax(1) == 1
    label_mask = y == 1

    overlap = (pred_mask & label_mask).sum().item()
    pred_positive = pred_mask.sum().item()
    label_positive = label_mask.sum().item()

    return overlap, pred_positive, label_positive

In [12]:
def train(dataloader, model, loss_fn, optimizer, print_gap=20, track_accuracy=True, epoch=0):
    """Train ``model`` for one pass over ``dataloader``.

    Uses the module-level ``device`` for tensor placement and the module-level
    TensorBoard ``writer`` for logging.

    Args:
        dataloader: yields ``(X, y)`` batches.
        model: network to optimise; switched to training mode.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: optimizer over the parameters of ``model``.
        print_gap: print progress every ``print_gap`` batches.
        track_accuracy: also compute, log and print Dice scores.
        epoch: zero-based epoch index, used to offset the TensorBoard step.

    Side effects:
        Logs ``'loss'`` for every batch and ``'Batch_Dice_Accuracy'`` for
        every batch whose Dice score is defined; prints progress and the
        epoch averages. Returns nothing.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    running_loss, overlap, pred_positive, label_positive = 0, 0, 0, 0
    seen = 0  # samples processed so far (correct even for a partial last batch)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        global_step = batch + epoch * num_batches

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation. Gradients are cleared *before* the backward pass so
        # that leftovers from an interrupted earlier run cannot leak into the
        # first update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss = loss.item()
        writer.add_scalar('loss', loss, global_step)
        running_loss += loss
        seen += len(X)

        # None means "undefined": neither prediction nor label has a positive.
        batch_dice_accuracy = None
        if track_accuracy:
            overlap_batch, pred_positive_batch, label_positive_batch = dice_accuracy_batch(pred, y)
            overlap += overlap_batch
            pred_positive += pred_positive_batch
            label_positive += label_positive_batch
            batch_denominator = pred_positive_batch + label_positive_batch
            if batch_denominator > 0:
                batch_dice_accuracy = 2 * overlap_batch / batch_denominator
                writer.add_scalar('Batch_Dice_Accuracy', batch_dice_accuracy, global_step)

        if batch % print_gap == 0:
            print(f"loss: {loss:>7f}  [{seen:>5d}/{size:>5d}]")
            if track_accuracy:
                if batch_dice_accuracy is None:
                    print("batch Dice Accuracy: n/a (no positive voxels)")
                else:
                    print(f"batch Dice Accuracy: {batch_dice_accuracy:.2%}")

    if num_batches == 0:
        print("Epoch Avg loss: n/a (empty dataloader)")
        return

    avg_loss = running_loss / num_batches
    print(f"Epoch Avg loss: {avg_loss:>8f}")
    if track_accuracy:
        denominator = pred_positive + label_positive
        if denominator > 0:
            dice_accuracy = 2 * overlap / denominator
            print(f"Epoch Dice Accuracy: {dice_accuracy:.2%}")
        else:
            print("Epoch Dice Accuracy: n/a (no positive voxels)")


In [13]:
def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and print Dice score and mean loss.

    Uses the module-level ``device`` for tensor placement. The model is
    switched to evaluation mode and gradients are disabled for the pass.

    Args:
        dataloader: yields ``(X, y)`` batches.
        model: network to evaluate.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.

    The Dice score is computed from counts accumulated over all batches; it is
    reported as ``n/a`` when neither predictions nor labels contain a
    positive. Returns nothing.
    """
    num_batches = len(dataloader)
    model.eval()
    test_loss, overlap, pred_positive, label_positive = 0, 0, 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            overlap_batch, pred_positive_batch, label_positive_batch = dice_accuracy_batch(pred, y)
            overlap += overlap_batch
            pred_positive += pred_positive_batch
            label_positive += label_positive_batch

    if num_batches == 0:
        print("Dice Accuracy: n/a, Avg loss: n/a (empty dataloader) \n")
        return

    test_loss /= num_batches
    denominator = pred_positive + label_positive
    if denominator > 0:
        dice_accuracy = 2 * overlap / denominator
        print(f"Dice Accuracy: {dice_accuracy:.2%}, Avg loss: {test_loss:>8f} \n")
    else:
        print(f"Dice Accuracy: n/a, Avg loss: {test_loss:>8f} \n")

In [17]:
# Train for a fixed number of epochs, reporting the duration of each one.
epochs = 3
for t in range(epochs):
    # perf_counter is monotonic, so the measured duration cannot be distorted
    # by system clock adjustments the way time.time() differences can.
    start_time = time.perf_counter()
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer, print_gap=100, track_accuracy=True, epoch=t)
    print(f"Epoch {t+1} training done, time: {time.perf_counter() - start_time:.2f}")
print("Done!")

Epoch 1
-------------------------------
loss: 0.006644  [    4/ 8398]
batch Dice Accuracy: 78.41%
loss: 0.031050  [  404/ 8398]
batch Dice Accuracy: 62.07%
loss: 0.005226  [  804/ 8398]
batch Dice Accuracy: 86.98%
loss: 0.004521  [ 1204/ 8398]
batch Dice Accuracy: 0.09%
loss: 0.007121  [ 1604/ 8398]
batch Dice Accuracy: 81.50%
loss: 0.009353  [ 2004/ 8398]
batch Dice Accuracy: 90.06%
loss: 0.000826  [ 2404/ 8398]
batch Dice Accuracy: 33.20%
loss: 0.006756  [ 2804/ 8398]
batch Dice Accuracy: 85.45%
loss: 0.014820  [ 3204/ 8398]
batch Dice Accuracy: 62.88%
loss: 0.003228  [ 3604/ 8398]
batch Dice Accuracy: 68.23%
loss: 0.013954  [ 4004/ 8398]
batch Dice Accuracy: 70.32%
loss: 0.016927  [ 4404/ 8398]
batch Dice Accuracy: 77.39%
loss: 0.010809  [ 4804/ 8398]
batch Dice Accuracy: 89.19%
loss: 0.001397  [ 5204/ 8398]
batch Dice Accuracy: 1.56%
loss: 0.013443  [ 5604/ 8398]
batch Dice Accuracy: 83.12%
loss: 0.002059  [ 6004/ 8398]
batch Dice Accuracy: 82.57%
loss: 0.006753  [ 6404/ 8398]
batc

In [20]:
# Close the TensorBoard writer now that training has finished.
writer.close()

In [16]:
# NOTE(review): this evaluates on `train_dataloader`, so the printed score is a
# training-set metric, not a held-out one.
test(train_dataloader, model, loss_fn)

Dice Accuracy: 88.61%, Avg loss: 0.011191 



# Test on another kidney

In [None]:
# NOTE(review): the heading says "another kidney", but both paths below still
# point at the kidney_1_dense files and the result overwrites `train_ds`.
# TODO: point these at a different kidney and bind the dataset to a separate
# name before running this cell.
kidney_path = '/home/ziyu/Projects/blood-vessel-segmentation/data/processed/conv/kidney_1_dense_images.npy'
label_path = '/home/ziyu/Projects/blood-vessel-segmentation/data/processed/conv/kidney_1_dense_labels.npy'
train_ds = Kidney3D(kidney_path, label_path, size=128, stride=64)
train_ds.__len__()

# Save & Load model

In [19]:
#torch.save(model, "model_kidney_1_dense_weighted_8epochs.pth")
#model = torch.load("model_kidney_1_dense_weighted_8epochs.pth")

In [17]:
# SECURITY: torch.load unpickles arbitrary Python objects, so only load
# checkpoints from a trusted source.
# - map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
# - weights_only=False is explicit because the file is expected to hold a whole
#   pickled model (`.to(device)` is called on the result), which newer PyTorch
#   releases reject under their weights_only=True default.
model_2 = torch.load(
    "model_kidney_1_dense_5epochs.pth", map_location=device, weights_only=False
).to(device)

In [18]:
# Evaluate the reloaded checkpoint with the same loader and loss as above.
# NOTE(review): `train_dataloader` is the training data, so this is again a
# training-set metric.
test(train_dataloader, model_2, loss_fn)

Dice Accuracy: 87.90%, Avg loss: 0.016385 

