In [1]:
import sys
sys.path.append('..')
import glob

import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.animation as animation
import numpy as np
import torch
from torchvision import transforms
from scipy.ndimage import binary_erosion
from skimage.color import label2rgb
from skimage import io

from mri_segmentation.model import UNet
from mri_segmentation.utils import dice_coefficient

mpl.rcParams.update(mpl.rcParamsDefault)

Running on cpu


In [2]:
validation_folders = [
    '../data/TCGA_DU_8168_19970503/',
    '../data/TCGA_DU_8167_19970402/',
    '../data/TCGA_DU_A5TP_19970614/',
    '../data/TCGA_HT_7856_19950831/',
    '../data/TCGA_FG_5962_20000626/',
]

In [3]:
validation_num = 4

files = glob.glob(validation_folders[validation_num] + '*.tif')

images_names = [name for name in files if not 'mask' in name]
sorted_image_names = sorted(images_names, key=lambda x: (x.split('_')[2], int(x.split('_')[-1].split('.')[0])))


masks_names = [name for name in files if 'mask' in name]
sorted_masks_names = sorted(masks_names, key=lambda x: (x.split('_')[2], int(x.split('_')[-2])))

images = [np.array(io.imread(image)) for image in sorted_image_names]
masks = [np.array(io.imread(mask)) for mask in sorted_masks_names]

print(len(images))

51


In [4]:
model = UNet()
model.load_state_dict(torch.load('../mri_segmentation/weights/model_UNet_num_epochs_85_seed_1234.pt', map_location=torch.device('cpu')))

inputs = torch.zeros((len(images), 3, 256, 256))
targets = torch.zeros((len(images), 256, 256))

for i, (image, mask) in enumerate(zip(images, masks)):
    tensor_image = transforms.ToTensor()(image)
    tensor_mask = transforms.ToTensor()(mask)
    tensor_mask = tensor_mask.long()
    inputs[i] = tensor_image
    targets[i] = tensor_mask

In [5]:
# Predict the segmentation
outputs = model(inputs)
dc = dice_coefficient(outputs, targets)
outputs = np.round(outputs.detach().squeeze(1).numpy())  # Round numbers for the binay mask

In [6]:
def get_image_with_masks(n):
    image = images[n]
    mask = masks[n] - masks[n] * binary_erosion(masks[n])
    pred_mask = outputs[n] - outputs[n] * binary_erosion(outputs[n])
    mask[pred_mask == 1] = 2
    return label2rgb(label=mask, image=image, bg_label=0, kind='overlay', colors=[(2, 0, 0), (0, 2, 0)])  # If colors at 1, low intensity

In [7]:
# Animation
interval = 5000 / len(images) # Make animations same duration (5s)
animation_file = f'../images/prediction_{validation_folders[validation_num].split("/")[2]}.gif'

fig = plt.figure(frameon=False)
mri = get_image_with_masks(0)
h = plt.imshow(mri)
plt.axis('off')
plt.title(f'DC = {dc:.4f}', fontsize=15)
fig.set_size_inches(4, 4, forward=True)

def update(i):
    mri = get_image_with_masks(i)
    h.set_data(mri)
    return h

anim = animation.FuncAnimation(fig, update, frames=range(len(images)), interval=interval)
anim.save(animation_file)

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
