In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np

import os
import sys
import h5py
import time
import tempfile
from PIL import Image
import matplotlib.pyplot as plt
from IPython.display import display as ipy_display, clear_output

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# setting path
sys.path.append('../')

In [3]:
unet_3D = __import__('3D-Unet-4')
import networks

In [4]:
import dataset

# parameters

In [5]:
num_classes = 5
batch_size = 1
suffix = 'run4'

dataset_name = 'abdomen'

# CUDA

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

# Path

In [7]:
source_mr_train_dir = "../../data2/preprocessed_data/"
source_mr_test_dir = "../../data2/preprocessed_data/"

In [8]:
target_ct_train_dir = "../../data2/preprocessed_data/"
target_ct_test_dir = "../../data2/preprocessed_data/"

# label_ids_abdomen

In [9]:
label_ids_abdomen = {"ignore": 0,
    "liver": 1,
    "right_kidney": 2,
    "left_kidney": 3,
    "spleen": 4,
}

label_ids = label_ids_abdomen

# Dataset

In [10]:
def sample_batch(dataset, batch_size=20, seed=None):
    if seed is not None:
        np.random.seed(seed)
    
    data_dir = dataset.data_dir
    num_samples = len(dataset)
    sample_indices = np.random.choice(num_samples, batch_size, replace=False) # replace=True allow repeat

    images = []
    labels = []

    for idx in sample_indices:
        
        data_vol, label_vol = dataset[idx]
        
        images.append(data_vol)
        labels.append(label_vol)

    images = torch.stack(images)
    labels = torch.stack(labels)

    return images, labels

#  sliding_window

In [11]:
def sliding_window(input_volume, window_size=(32, 256, 256), stride=(16, 128, 128)):
    
    z_max = input_volume.shape[0] - window_size[0] + 1 # z_max = 269
    y_range = range(0, input_volume.shape[1] - window_size[1] + 1, stride[1]) # 0, 128, 256
    x_range = range(0, input_volume.shape[2] - window_size[2] + 1, stride[2]) # 0, 128, 256

    windows = []

    for y in y_range:
        for x in x_range:
            # Loop through the z slices with stride
            # z_range: 0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256
            for z in range(0, z_max, stride[0]):
                window = input_volume[z:z+window_size[0], y:y+window_size[1], x:x+window_size[2]]
                windows.append(window)

            # Add an additional window for the remaining depth
            z_remaining = input_volume.shape[0] - window_size[0] # z_remaining = 78 - 32 = 46
            window = input_volume[z_remaining:, y:y+window_size[1], x:x+window_size[2]]
            windows.append(window)

    return windows

In [12]:
def combine_windows(window_outputs, input_volume_shape, window_size=(32, 256, 256), stride=(16, 128, 128)):
    num_classes = window_outputs[0].shape[1] # 5
    combined_prob = torch.zeros((num_classes,) + input_volume_shape).to(device)
    count_matrix = torch.zeros(input_volume_shape).to(device)

    z_max = input_volume_shape[0] - window_size[0] + 1
    y_range = range(0, input_volume_shape[1] - window_size[1] + 1, stride[1])
    x_range = range(0, input_volume_shape[2] - window_size[2] + 1, stride[2])

    idx = 0

    for y in y_range:
        for x in x_range:
            # Loop through the z slices with stride
            for z in range(0, z_max, stride[0]):
                output = window_outputs[idx].squeeze() # output.cpu().numpy().shape: (5, 32, 256, 256)
                combined_prob[:, z:z+window_size[0], y:y+window_size[1], x:x+window_size[2]] += output
                count_matrix[z:z+window_size[0], y:y+window_size[1], x:x+window_size[2]] += 1
                idx += 1

            # Add an additional window for the remaining depth
            z_remaining = input_volume_shape[0] - window_size[0]
            output = window_outputs[idx].squeeze()
            combined_prob[:, z_remaining:, y:y+window_size[1], x:x+window_size[2]] += output
            count_matrix[z_remaining:, y:y+window_size[1], x:x+window_size[2]] += 1
            idx += 1

    # Normalize the class probabilities
    combined_prob /= count_matrix

    # Take the argmax of the accumulated probabilities
    combined_output = torch.argmax(combined_prob, dim=0)

    return combined_output

# compute_miou

In [13]:
def compute_miou(images, labels, Unet, classifier, label_ids, id_to_ignore=0):
    N = len(images)

    intersection = dict()
    union = dict()
    for label in label_ids:
        intersection[label] = union[label] = 0
    
    Unet.eval()
    classifier.eval()
    
    for i in range(N):
        X = images[i].unsqueeze(0).to(device) # (1,3,256,256)
        y_true = labels[i].view(-1).cpu().numpy()

        with torch.no_grad():
            myans1 = Unet(X)
            myans2 = classifier(myans1) # output: [1, 5, 256, 256]

            # Apply softmax to the output logits
            myans2_softmax = F.softmax(myans2, dim=1) # out size is: torch.Size([1, 5, 256, 256])

            # Get class predictions by selecting the class with the highest probability
            myans2_pred = torch.argmax(myans2_softmax, dim=1) # out size is: torch.Size([1, 256, 256])
            
            y_hat = myans2_pred.view(-1).cpu().numpy()

        for label in label_ids:
            if label_ids[label] == id_to_ignore:
                continue

            curr_id = label_ids[label]

            idx_gt = y_true == curr_id
            idx_hat = y_hat == curr_id

            intersection[label] += np.sum(idx_gt & idx_hat)
            union[label] += np.sum(idx_gt | idx_hat)

    mIoU = []
    res = dict()
    for label in label_ids:
        if label_ids[label] == id_to_ignore:
            continue

        if union[label] != 0:
            res[label] = intersection[label] / union[label]
        else:
            res[label] = np.float64(0)

        mIoU.append(res[label])
    
    return res, np.mean(mIoU)

# Combined model

In [14]:
class Combined_Model(nn.Module):
    
    def __init__(self, Unet, classifier):
        super(Combined_Model, self).__init__()
        
        self.Unet = Unet
        self.classifier = classifier
        
    def forward(self, x):
        
        output1 = self.Unet(x)
        output2 = self.classifier(output1)
        
        return output2

# Initialize

In [15]:
dataloader = dataset.get_dataloader( target_ct_train_dir,  target_ct_test_dir, num_classes, batch_size,  domain = 'target' )

train_dataset = dataloader["train"].dataset
test_dataset = dataloader["test"].dataset

In [16]:
Unet = unet_3D.UNet(1, num_classes, 16)
classifier = networks.Classifier(num_classes)

Unet = Unet.to(device)
classifier = classifier.to(device)

# Create the combined model
#combined = Combined_Model(Unet, Classifier)
#combined = combined.to(device)

# parallel
#Unet = torch.nn.DataParallel(Unet, device_ids=[0, 1, 2, 3])
#Classifier = torch.nn.DataParallel(Classifier, device_ids=[0, 1, 2, 3])

Unet = torch.nn.DataParallel(Unet)
classifier = torch.nn.DataParallel(classifier)

In [17]:
Unet_checkpoint = torch.load('../record-data/' + 'Unet_weights_' + suffix  + '.pth')
classifier_checkpoint = torch.load('../record-data/' + 'classifier_weights_' + suffix + '.pth')

Unet.load_state_dict(Unet_checkpoint)
classifier.load_state_dict(classifier_checkpoint)
print("Loaded model weights")

Loaded model weights


# Show results

In [18]:
test_output = []

for img_idx in range(len(test_dataset)): # 0, 1, 2, 3
    
    data_vol, label_vol = test_dataset[img_idx] # data_vol: torch.Size([1, 60, 512, 512])
    data_vol = data_vol.to(device)
    label_vol = label_vol.to(device)
    
    data_vol = torch.squeeze(data_vol, 0) # data_vol:  torch.Size([60, 512, 512])
    windows = sliding_window(data_vol) # slice 3D image based on window size and stride
    
    
    
    window_outputs = []
    
    Unet.eval()
    classifier.eval() 
    with torch.no_grad():
        for window in windows:
            window = window.unsqueeze(0)  # Add a channel dimension: torch.Size([1, 32, 256, 256])
            window = torch.unsqueeze(window, 0)  # Add a batch dimension: torch.Size([1, 1, 32, 256, 256])
            
            # inference
            output = Unet(window)
            output = classifier(output) # torch.Size([1, 5, 32, 256, 256])
            
            # collect outputs
            window_outputs.append(output)  # len(window_outputs) = 27
            # window_outputs[0].cpu().numpy().shape： (1, 5, 32, 256, 256)

    combined_output = combine_windows(window_outputs, data_vol.size())
    test_output.append(combined_output)

# evaluate

In [19]:
numpy_arrays = [tensor.cpu().numpy() for tensor in test_output]

In [20]:
numpy_arrays[0].shape

(63, 512, 512)

In [21]:
numpy_arrays[1].shape

(75, 512, 512)

In [22]:
numpy_arrays[2].shape

(50, 512, 512)

In [23]:
numpy_arrays[3].shape

(51, 512, 512)

In [24]:
id_to_ignore = 0
intersection = dict()
total = dict()
for label in label_ids:
    intersection[label] = total[label] = 0


for img_idx in range(len(test_dataset)): # 0, 1, 2, 3
    
    _, y_true = test_dataset[img_idx] # data_vol: torch.Size([1, 60, 512, 512])
    
    y_hat = numpy_arrays[img_idx]
    y_true = y_true.cpu().numpy() 
    
    print(y_hat.shape)
    print(y_true.shape)
    
    for label in label_ids:
        if label_ids[label] == id_to_ignore:
            continue

        curr_id = label_ids[label]

        idx_gt = y_true == curr_id
        idx_hat = y_hat == curr_id

        intersection[label] += 2 * np.sum(idx_gt & idx_hat)
        total[label] += np.sum(idx_gt) + np.sum(idx_hat)
        
    dice = []
    res = dict()
    for label in label_ids:
        if label_ids[label] == id_to_ignore:
            continue
            
        if total[label] != 0:
            res[label] = intersection[label] / total[label]
        else:
            print('total is zero')
            res[label] = np.float64(0)

        dice.append(res[label])

(63, 512, 512)
(63, 512, 512)
(75, 512, 512)
(75, 512, 512)
(50, 512, 512)
(50, 512, 512)
(51, 512, 512)
(51, 512, 512)
(61, 512, 512)
(61, 512, 512)
(56, 512, 512)
(56, 512, 512)


In [25]:
np.mean(dice)

0.018396567250762587

In [26]:
for k in res:
    print(k, res[k])

liver 0.0
right_kidney 0.0
left_kidney 0.03390754831696667
spleen 0.03967872068608367
