In [0]:
import os
# CUDA_VISIBLE_DEVICES is only honoured if it is set before CUDA is initialised,
# so set it before importing torch or any project module that might touch the
# GPU at import time.
os.environ['CUDA_VISIBLE_DEVICES']='0'
import numpy as np
import SimpleITK as sitk
import torch
from torch.autograd import Variable
from unet3d import UNet
import progressbar

In [0]:
from iterator import BasicVolumeIterator
from pred_utils import calc_pad_for_fit, calc_pad_for_pred_loss, crop_pad_width

In [0]:
# Network input window and the (smaller) block it predicts, ordered as the
# axes of sitk.GetArrayFromImage: depth first, then the two in-plane axes.
# Input exceeds output by 88 voxels on every axis (44 per side of context).
patch_size = [116] + [132] * 2
out_size = [28] + [44] * 2

In [0]:
# 3D U-Net with one input channel and two output classes, placed on the GPU.
model = UNet(n_ch = 1, n_class = 2).cuda()

# Restore the trained weights, then switch to inference mode.
# Trailing ';' keeps the notebook from displaying the return values.
ckpt_path = './models/model_epoch95.bin'
ckpt = torch.load(ckpt_path)
model.load_state_dict(ckpt);
model.eval();

In [0]:
import scipy.ndimage as snd
from utils import resample

img_path = '/mnt/data/LiverCT/Parenchyma/LITS/val/volume-100.nii'

# Load the CT volume, resample it to 1 mm isotropic spacing, then halve the
# resolution along every axis with linear (order=1) interpolation.
img = resample(sitk.ReadImage(img_path), (1.0, 1.0, 1.0), interpolator = sitk.sitkLinear)
img_arr = snd.zoom(sitk.GetArrayFromImage(img), zoom = (0.5, 0.5, 0.5), order = 1)
print('img arr shape : ', img_arr.shape)

# pad_size1: presumably the padding that lets the volume be tiled by out_size
# blocks (see pred_utils.calc_pad_for_fit).
pad_size1 = calc_pad_for_fit(img_arr.shape, out_size)
print('pad size 1 : ', pad_size1)
# pad_size2: presumably the extra context margin needed because the network
# input (patch_size) is larger than its output (out_size).
pad_size2 = calc_pad_for_pred_loss(patch_size, out_size)
print('pad size 2 : ', pad_size2)

# Only the "fit" padding is cropped from the prediction later; the input gets
# both paddings added per axis as (before, after) pairs.
pad_size_to_crop = pad_size1
tot_pad_size = [(fit[0] + ctx[0], fit[1] + ctx[1]) for fit, ctx in zip(pad_size1, pad_size2)]

# Intensity window: clip to [-100, 400] and map that range linearly onto [-1, 1].
hu_low, hu_high = -100, 400
half_width = (hu_high - hu_low) / 2   # 250.0
img_arr = (np.float32(np.clip(img_arr, hu_low, hu_high)) - hu_low) / half_width - 1

# NOTE(review): mode='constant' pads with 0, which after the normalisation above
# corresponds to intensity 150 (mid-window), not the window floor (-1).
# Confirm this matches how training patches were padded.
img_arr_padded = np.pad(img_arr, tot_pad_size, mode = 'constant')

# Blank canvas (image shape plus fit padding) that the output iterator fills in.
zeros_arr = np.zeros(img_arr.shape)
dummy_output_arr_to_be_predicted = np.pad(zeros_arr, pad_size1, mode = 'constant')

In [0]:
# Output iterator: walks the prediction canvas in out_size blocks.
out_itr = BasicVolumeIterator(dummy_output_arr_to_be_predicted, out_size, out_size)
# Input iterator: presumably yields patch_size windows from the context-padded
# image, advancing by the same out_size step so each input window lines up
# with one output block (TODO confirm the meaning of the third argument).
inp_itr = BasicVolumeIterator(img_arr_padded, patch_size, out_size)

In [0]:
# Sliding-window inference: feed each input patch through the model and write
# the per-voxel argmax class into the matching output block.
patch_count = out_itr.get_num_patches()
bar = progressbar.ProgressBar(patch_count).start()
count = 0
print('Patch count : ', patch_count)
# Inference only: no_grad() stops autograd from recording a graph per patch.
# (Variable(..., volatile=True) has been a no-op since PyTorch 0.4.)
with torch.no_grad():
    while inp_itr.is_not_over():
        # Add batch and channel axes: (1, 1, *patch_size).
        input_arr = inp_itr.get_patch().reshape(1, 1, *patch_size)
        inp_itr.move_coords()
        input_tensor = torch.from_numpy(input_arr).float().cuda()
        # Class index per voxel = argmax over the class axis (dim 1).
        pred_arr = model(input_tensor).argmax(dim = 1).cpu().numpy()
        # Periodic sanity check that some foreground is being predicted;
        # np.unique is only computed when it will actually be used.
        if count % 100 == 0:
            unq = np.unique(pred_arr)
            if unq.max() > 0:
                print(unq)
        out_itr.set_patch(pred_arr[0])
        out_itr.move_coords()
        count += 1
        bar.update(count)
bar.finish()
# Input and output iterators must visit the same number of patches, otherwise
# predictions were written to misaligned blocks.
assert count == patch_count, 'processed {} patches, expected {}'.format(count, patch_count)

output_arr = out_itr.vol_array
print(output_arr.shape)
# Remove the fit padding so the prediction aligns with img_arr.
output_arr = crop_pad_width(output_arr, pad_size_to_crop)
# np.save('/mnt/sdb1/intern_data/pix2pix_wbce_pet3d/test_outputs/0_pred.npy', output_arr)

In [0]:
# After cropping the fit padding, the prediction should align voxel-for-voxel
# with the preprocessed image.
print(img_arr.shape, output_arr.shape, pad_size1)
assert output_arr.shape == img_arr.shape, 'prediction shape does not match image shape'

In [0]:
# Extract the ground-truth label array for the SAME case as img_path, so it can
# be compared with the prediction. In LiTS, volume-<id>.nii is paired with
# segmentation-<id>.nii, so derive the label path instead of hard-coding a
# second (previously mismatched) case id.
lbl_path = os.path.join(
    os.path.dirname(img_path),
    os.path.basename(img_path).replace('volume-', 'segmentation-', 1),
)

lbl = sitk.ReadImage(lbl_path)
# Labels are categorical: nearest-neighbour interpolation keeps them integral.
# Linear interpolation would blend class ids at region boundaries.
lbl = resample(lbl, (1.0, 1.0, 1.0), interpolator = sitk.sitkNearestNeighbor)

lbl_arr = sitk.GetArrayFromImage(lbl)
# Collapse label 2 into label 1 to get a binary mask for the 2-class model
# (in LiTS, 1 is liver and 2 is tumour).
lbl_arr[lbl_arr == 2] = 1
# Same 0.5x downsampling as the image, with order=0 (nearest) to keep labels crisp.
lbl_arr = np.uint8(snd.zoom(lbl_arr, zoom = (0.5, 0.5, 0.5), order = 0))
assert lbl_arr.shape == img_arr.shape, 'label shape {} != image shape {}'.format(lbl_arr.shape, img_arr.shape)

In [0]:
# Shape of the preprocessed ground-truth volume.
print(np.shape(lbl_arr))

In [0]:
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

print(np.unique(output_arr))
# print(output_arr.sum())
for idx in range(output_arr.shape[0]):
    if idx %10 == 0:
        print(idx)
        plt.figure()
        slc = output_arr[idx]
        print(np.unique(slc))
        plt.subplot(1,2,1)
        plt.imshow(slc, cmap = 'gray')
        plt.subplot(1,2,2)
        plt.imshow(lbl_arr[idx], cmap = 'gray')
        plt.show()