In [None]:
import torch
from torch.utils.data import Subset
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present as consume_prefix
from itkwidgets import view

import scipy
import numpy as np
import random

import model as mdl
import Kyle_model as kmdl
import dataset as dtst

In [None]:
# Folder handed to dtst.UnetDataset below; prob_map.npy is also read from here.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
dataset_directory = "/home/sci/kyle.anderson/lymph_nodes/Dataset"
# Folder containing the saved checkpoint (.tar) that is loaded into the model below.
checkpoint_directory = "/home/sci/kyle.anderson/lymph_nodes/Lymph-Node-Segmentation"

In [None]:
# Alternative architecture from model.py, kept here for quick switching:
# model = mdl.UNet64()
model = kmdl.UNet()
# Move the network to the GPU (requires a CUDA device); as the last expression
# of the cell this also displays the module.
model.cuda()

In [None]:
# Add up the element counts of every parameter that takes part in training.
num_train_params = sum(
    weights.numel()
    for weights in filter(lambda w: w.requires_grad, model.parameters())
)
print(f"Trainable parameters: {num_train_params}.")

In [None]:
# Deserialize the checkpoint onto the CPU instead of onto whichever device the
# tensors were saved from (which may not exist on this machine);
# load_state_dict() then copies the values into the model's own parameters.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
ddp_state_dict = torch.load(
    f"{checkpoint_directory}/KyleUNetWeightedBCE.tar", map_location="cpu"
)["model_state_dict"]
# Strip the "module." prefix from the keys in place (presumably added by a
# DistributedDataParallel wrapper during training) so they match the bare model.
consume_prefix(ddp_state_dict, "module.")
model.load_state_dict(ddp_state_dict)

In [None]:
dataset = dtst.UnetDataset(dataset_directory, patch_size=64, min_probability=0.05)
print(f"Number of samples in dataset: {len(dataset)}.")

# Deterministic shuffle (fixed seed 13) so the split is the same on every run.
all_indices = [*range(len(dataset))]
random.Random(13).shuffle(all_indices)

# The last int(0.8 * N) shuffled indices form the training set; everything
# before that boundary is held out for testing.
num_test = len(all_indices) - int(0.8 * len(all_indices))
test_set = Subset(dataset, all_indices[:num_test])
train_set = Subset(dataset, all_indices[num_test:])

In [None]:
# Display image and mask for entry 2. The dataset object is called `dataset`;
# the previous name `data` is never defined in this notebook and raised a
# NameError on a fresh kernel.
# NOTE(review): presumably returns the full, un-patched volume, judging by the
# method name — confirm against dataset.py.
whole_sample = dataset.get_whole_volume_and_mask(2)
view(whole_sample['img'], whole_sample['mask'])

In [None]:
# Read the saved map straight from its path; np.load opens and closes the
# .npy file itself.
p_map = np.load(f"{dataset_directory}/prob_map.npy")

In [None]:
# Rescale by the largest entry so that the maximum of the map becomes 1.
p_map = p_map / p_map.max()

In [None]:
# p_map is already a NumPy array at this point, so it can be handed to the
# viewer directly.
view(p_map)

# Test the model on a few images

In [None]:
def test(model, sample):
    model.eval()
    with torch.no_grad():
        pred = model(sample["img"].unsqueeze(0).cuda(non_blocking=True).type(torch.cuda.FloatTensor))
        # img = torch.from_numpy(scipy.ndimage.zoom(sample["img"], [1.0, 0.5, 0.5, 0.5])).unsqueeze(0)
        # pred = model(img.cuda(non_blocking=True).type(torch.cuda.FloatTensor))

    return pred.detach().cpu().numpy()

In [None]:
def dice_coeff(prediction, truth):
    """Dice coefficient between a thresholded prediction and a ground-truth mask.

    A voxel counts as predicted foreground when ``prediction >= 0.5``, as true
    foreground when ``truth`` is non-zero, and as overlap when
    ``prediction * truth >= 0.5``.
    NOTE(review): the 0.5 threshold assumes ``prediction`` holds probabilities
    and ``truth`` is a 0/1 mask — confirm against the model's output layer.

    Parameters
    ----------
    prediction, truth : torch.Tensor
        Tensors of identical shape.

    Returns
    -------
    torch.Tensor
        0-dim floating-point tensor ``2 * |overlap| / (|pred| + |truth|)``.
        When both the prediction and the mask are empty the score is defined
        as 1.0 (perfect agreement) instead of the NaN produced by 0 / 0.

    Raises
    ------
    ValueError
        If the two tensors do not have the same shape.
    """
    if prediction.shape != truth.shape:
        # Fail loudly: silently returning None only defers the error to the
        # caller, which then fails with an unrelated TypeError.
        raise ValueError(
            f"Incompatible shapes: {prediction.shape = }, {truth.shape = }"
        )

    intersection = torch.count_nonzero(prediction * truth >= 0.5)
    tot_voxels = (torch.count_nonzero(prediction >= 0.5)
                  + torch.count_nonzero(truth))
    dice = 2 * intersection / tot_voxels
    # Both volumes empty -> 0 / 0; report perfect agreement instead of NaN so a
    # single empty patch cannot poison an average over many samples.
    return torch.where(tot_voxels == 0, torch.ones_like(dice), dice)

In [None]:
def find_avg_dice(model, dataset):
    """Average the Dice coefficient of ``model`` over every sample in ``dataset``.

    Parameters
    ----------
    model : torch.nn.Module
        Network to evaluate. It is switched to eval mode and left in that mode.
    dataset : iterable with ``len()``
        Yields dicts holding an ``"img"`` and a ``"mask"`` tensor. Each sample
        is evaluated on its own as a batch of size 1.

    Returns
    -------
    The sum of the per-sample ``dice_coeff`` values divided by ``len(dataset)``.

    Raises
    ------
    ZeroDivisionError
        If ``dataset`` is empty.
    """
    model.eval()
    # Evaluate on the device the model lives on rather than assuming CUDA.
    first_param = next(model.parameters(), None)
    running_avg = 0.0
    with torch.no_grad():
        for sample in dataset:
            img = sample['img'].unsqueeze(0)
            mask = sample['mask'].unsqueeze(0)
            # A parameter-free model falls back to the input's own device.
            device = first_param.device if first_param is not None else img.device
            preds = model(img.to(device=device, dtype=torch.float32, non_blocking=True))
            target = mask.to(device=device, dtype=torch.float32, non_blocking=True)
            # squeeze() drops all singleton dimensions so both tensors are
            # compared at the same shape.
            running_avg += dice_coeff(preds.squeeze(), target.squeeze())

    running_avg /= len(dataset)
    return running_avg

In [None]:
# Mean Dice over every sample in `dataset`; this covers the samples in
# train_set as well as those in test_set.
avg_dice = find_avg_dice(model, dataset)

In [None]:
# Reference result: with patch_size=64 and min_probability=0.05 the average
# Dice score was 0.12046.
print(f"Average dice score: {avg_dice:.5f}")

In [17]:
idx = 10
# Fetch the sample once so the displayed mask is guaranteed to belong to the
# image that was fed to the model. Indexing twice loads the sample twice and,
# if the dataset samples patches randomly, could pair mismatched image and mask.
sample = dataset[idx]
pred = test(model, sample)
# Binarise the prediction at 0.5, keeping the prediction's dtype.
binary_pred = (pred >= 0.5).astype(pred.dtype)
view(binary_pred.squeeze(), sample["mask"].cpu().numpy().squeeze())

Viewer(geometries=[], gradient_opacity=0.22, interpolation=False, point_sets=[], rendered_image=<itk.itkImageP…

In [None]:
idx = 100
# Fetch the held-out sample once so prediction and mask come from the same
# item. Indexing twice loads it twice and could mismatch them if the dataset
# samples patches randomly.
sample = test_set[idx]
pred = test(model, sample)
view(pred.squeeze(), sample["mask"].cpu().numpy().squeeze())

In [None]:
# The dataset object is called `dataset`; `data` is never defined in this
# notebook and raised a NameError on a fresh kernel. The sample is fetched
# once so prediction and mask come from the same item.
sample = dataset[200]
pred = test(model, sample)
view(pred.squeeze(), sample["mask"].cpu().numpy().squeeze())

In [None]:
# The dataset object is called `dataset`; `data` is never defined in this
# notebook and raised a NameError on a fresh kernel. The sample is fetched
# once so prediction and mask come from the same item.
sample = dataset[100]
pred = test(model, sample)
view(pred.squeeze(), sample["mask"].cpu().numpy().squeeze())