# PET Signal for Labelling
A script to show how to use the PET Signal for labeling tumors.
Based on https://www.kaggle.com/kmader/d/4quant/soft-tissue-sarcoma/superpixels-on-petct-for-labeling

In [None]:
import os
import h5py
import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
%matplotlib inline

from skimage.segmentation import slic
from skimage.segmentation import mark_boundaries

# utility functions
def make_proj(x):
    """Sum `x` along axis 1 and return the result reversed along axis 0."""
    summed = np.sum(x, axis=1)
    return summed[::-1]


def make_mip(x):
    """Take the maximum of `x` along axis 1 and return it reversed along axis 0."""
    peak = np.max(x, axis=1)
    return peak[::-1]

# Loading and Displaying PET and CT
Here we load the PET, CT and Label data from a _single_ patient and show the projection image for CT, the Maximum Intensity Projection (MIP) view for the PET data and the label data.

In [None]:
# Position (within the file's list of patient ids) of the patient to load
patient_index = 2
with h5py.File(os.path.join('..', 'input', 'lab_petct_vox_5.00mm.h5'), 'r') as p_data:
    print('Available keys', list(p_data.keys()))
    id_list = list(p_data['ct_data'].keys())
    patient_id = id_list[patient_index]
    # `dataset[()]` reads the whole dataset into memory as a NumPy array. It
    # replaces the deprecated `.value` attribute, which h5py 3.0 removed.
    ct_image = p_data['ct_data'][patient_id][()]
    pet_image = p_data['pet_data'][patient_id][()]
    label_image = p_data['label_data'][patient_id][()]
    # ask kevin why label_image as below (note: it binarises the labels and
    # reads patient 0 rather than patient_index)
    #label_image = (p_data['label_data'][id_list[0]][()]>0).astype(np.uint8)

### Projection
Show projections of the images, i.e. 2D views of the 3D image data.

In [None]:
# Reduce each volume to a 2D image: summed projection for CT and labels,
# maximum-intensity projection for PET.
ct_proj = make_proj(ct_image)
suv_max = make_mip(pet_image)
lab_proj = make_proj(label_image)

# One row with three panels: CT | PET | labels
fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=3, figsize=(12, 4))

ax1.imshow(ct_proj, cmap='bone')
ax1.set_title('CT Image projection')

# the square root is applied for display only; suv_max itself is unchanged
ax2.imshow(np.sqrt(suv_max), cmap='magma')
ax2.set_title('SUV Image projection')

ax3.set_title('Tumor Labels projection')
ax3.imshow(lab_proj, cmap='magma')

## Full 3D Superpixels
Here we compute full 3D superpixels for the combined PET/CT volume and show a simple rendering of them.

In [None]:
pet_weight = 1.0 # how strongly to weight the pet_signal (1.0 is the same as CT)
# based on experience how a radiologist is using the PET window
pet_window = 5
# CT channel: shift by 1024, clip to [0, 2048] and scale by 1/2048.
ct_channel = (ct_image + 1024).clip(0, 2048) / 2048
# PET channel: clip to [0, pet_window], scale by 1/pet_window and apply the weight.
pet_channel = pet_weight * pet_image.clip(0, pet_window) / pet_window
# Stack both channels along a new last axis in a single vectorised step
# (same values as stacking slice by slice; mismatched volume shapes now raise
# instead of being silently truncated by zip).
petct_vol = np.stack([ct_channel, pet_channel], -1)

In [None]:
%%time
# Cluster the PET/CT volume into superpixels with SLIC. The last axis of
# petct_vol holds the two channels, hence multichannel = True.
# NOTE(review): newer scikit-image releases appear to replace `multichannel`
# with `channel_axis` -- confirm against the installed version.
petct_segs = slic(petct_vol, 
                  n_segments = 5050,
                  compactness = 1,
                  multichannel = True)

In [None]:
# 2D views of the segmentation and of the two channels of petct_vol
petct_max_segs = make_mip(petct_segs)
ct_proj = make_proj(petct_vol[..., 0])
suv_mip = make_mip(petct_vol[..., 1])

# Panels: PET projection | segment ids | PET projection with segment boundaries
fig, (ax1, ax2, ax3) = plt.subplots(nrows=1, ncols=3, figsize=(14, 6))

ax1.set_title('SUV Image')
ax1.imshow(suv_mip, cmap='magma')

ax2.set_title('Segmented Image')
ax2.imshow(petct_max_segs, cmap='gray')

ax3.imshow(mark_boundaries(suv_mip, petct_max_segs))

## Compare Segments to Labels

We look at each superpixel and see how many different labels are inside it. We want each superpixel to be an 'atomic' unit of the image.

In [None]:
# Count label pixels
# A single np.unique pass returns every label together with its pixel count,
# instead of re-scanning the volume once per label.
unique_labels, label_pixel_counts = np.unique(label_image, return_counts=True)
for i, n_pixels in zip(unique_labels, label_pixel_counts):
    print('label', i, 'pixel count:\t', n_pixels)

In [None]:
# Number of distinct values per volume (CT and PET printed with thousands separators)
print('Distinct values in CT:\t', format(np.unique(ct_image).size, ','))
print('Distinct values in PET:\t', format(np.unique(pet_image).size, ','))
print('Distinct values in Label:\t', np.unique(label_image).size)

In [None]:
# Report every superpixel that contains more than one ground-truth label.
for idx in np.unique(petct_segs):
    cur_region_mask = (petct_segs == idx)
    labels_in_region = label_image[cur_region_mask]
    labeled_region_inside = np.unique(labels_in_region)
    if len(labeled_region_inside) > 1:
        # Count once per region; Series.value_counts sorts by descending
        # count, so the first row is the majority label. The top-level
        # pd.value_counts helper is deprecated in recent pandas.
        region_counts = pd.Series(labels_in_region).value_counts()
        print(labeled_region_inside)
        print('\nSuperpixel id', idx, 'regions', len(labeled_region_inside))
        print(region_counts)
        # .iloc makes the "drop the majority label" slice positional
        # regardless of the label dtype used as the Series index.
        print('Missclassified Pixels:', np.sum(region_counts.iloc[1:].values))

In [None]:
# Labels to draw, one figure row each (label 0 is titled as the background below)
nz_labels = [i for i in np.unique(label_image) if i>=0]
# squeeze=False keeps m_axs two-dimensional, so every row still unpacks into
# (ax1, ax2) when there is only a single label to show.
fig, m_axs = plt.subplots(len(nz_labels), 2, figsize = (5, 15), squeeze=False)
for (ax1, ax2), i_label in zip(m_axs, nz_labels):
    out_sp = np.zeros_like(petct_segs)
    cur_label_mask = label_image == i_label
    labels_in_region = petct_segs[cur_label_mask]

    # Renumber the superpixels touching this label as 1..k; everything else stays 0
    superpixels_in_region = np.unique(labels_in_region)
    for i, sp_idx in enumerate(superpixels_in_region):
        out_sp[petct_segs == sp_idx] = i+1

    # Left panel: projection of the label mask
    ax1.imshow(make_proj(cur_label_mask), cmap = 'bone')
    ax1.set_title('Label Map {}'.format(i_label) if i_label>0 else 'Background Label')
    ax1.axis('off')

    # Right panel: projection of the renumbered superpixels overlapping the label
    ax2.imshow(make_proj(out_sp), cmap = 'gist_earth')
    ax2.set_title('Superpixels ({})'.format(len(superpixels_in_region)))
    ax2.axis('off')

## Show the superpixels for each label
Here we can show which superpixels are inside each label.

In [None]:
for idx in np.unique(label_image):
    # distinct superpixel ids found on the voxels carrying this label
    superpixel_ids = np.unique(petct_segs[label_image == idx])
    print('Label id:', idx, 'superpixels inside', len(superpixel_ids))

# Optimize Superpixel Size

In [None]:
def label_score(gt_labels, sp_segs):
    # type: (np.ndarray, np.ndarray) -> int
    """
    Score how well the superpixels match to the ground truth labels.
    Here we use a simple penalty of number of pixels misclassified: inside
    each superpixel the most frequent label counts as correct and every
    other pixel adds one to the score. Superpixels whose label values do not
    sum to more than zero are skipped.
    :param gt_labels: the ground truth labels (from an annotation tool)
    :param sp_segs: the superpixel segmentation (same shape as gt_labels)
    :return: the score (lower is better)
    :raises ValueError: if gt_labels and sp_segs differ in shape
    """
    gt_flat = np.asarray(gt_labels)
    sp_flat = np.asarray(sp_segs)
    if gt_flat.shape != sp_flat.shape:
        raise ValueError('gt_labels and sp_segs must have the same shape, got '
                         '{} and {}'.format(gt_flat.shape, sp_flat.shape))
    gt_flat = gt_flat.ravel()
    sp_flat = sp_flat.ravel()
    if sp_flat.size == 0:
        return 0

    # Replace raw label values / superpixel ids by dense 0..k-1 codes
    label_values, label_codes = np.unique(gt_flat, return_inverse=True)
    sp_ids, sp_codes = np.unique(sp_flat, return_inverse=True)
    n_labels = len(label_values)
    n_superpixels = len(sp_ids)

    # counts[s, l] = number of pixels of superpixel s carrying label l, built
    # in one pass instead of masking the whole volume once per superpixel
    pair_codes = sp_codes * n_labels + label_codes
    counts = np.bincount(pair_codes, minlength=n_superpixels * n_labels)
    counts = counts.reshape(n_superpixels, n_labels)

    # Sum of the label values inside each superpixel (the skip criterion)
    label_sums = counts.dot(label_values)
    # Everything that is not the majority label counts as misclassified
    misclassified = counts.sum(axis=1) - counts.max(axis=1)
    return int(misclassified[label_sums > 0].sum())

In [None]:
# Make new superpixels
def make_superpixel(pet_weight = 1.0, # how strongly to weight the pet_signal (1.0 is the same as CT)
                    n_segments = 1000, # number of segments
                    compactness = 1, # how compact the segments are
                    pet_window = 5): # upper clip value applied to the PET signal
    """
    Run SLIC on a two-channel volume built from the notebook-level ct_image
    and pet_image arrays.
    :param pet_weight: factor applied to the scaled PET channel
    :param n_segments: forwarded to slic
    :param compactness: forwarded to slic
    :param pet_window: PET values are clipped to [0, pet_window] and divided by it
    :return: the segmentation returned by slic
    :raises ValueError: if pet_window is not positive
    """
    if pet_window <= 0:
        raise ValueError('pet_window must be positive, got {}'.format(pet_window))

    # CT channel: shift by 1024, clip to [0, 2048] and scale by 1/2048
    ct_channel = (ct_image + 1024).clip(0, 2048) / 2048
    # PET channel: clip to the window, scale by 1/window, then weight
    pet_channel = pet_weight * pet_image.clip(0, pet_window) / float(pet_window)
    # channels go on a new last axis, built for the whole volume at once
    t_petct_vol = np.stack([ct_channel, pet_channel], -1)

    # NOTE(review): newer scikit-image releases appear to replace
    # `multichannel` with `channel_axis` -- confirm against the installed version.
    return slic(t_petct_vol,
                n_segments=n_segments,
                compactness=compactness,
                multichannel=True)

def make_and_score(*args, **kwargs):
    """
    Build superpixels with make_superpixel (all arguments are forwarded) and
    return their label_score against the notebook-level label_image.
    """
    return label_score(label_image, make_superpixel(*args, **kwargs))

def f_make(n):
    """
    Objective for the optimizer: score the superpixels obtained with
    n * 1000 segments (truncated to an integer).
    :param n: a scalar or a single-element array (as passed by the optimizer)
    :return: the value of make_and_score for that number of segments
    """
    print('calling f_make with:', n)
    # Extract the single value explicitly: calling int() on a one-element
    # array is deprecated in NumPy, while .item() works for scalars and
    # one-element arrays alike.
    scale = np.asarray(n).item()
    return make_and_score(n_segments=int(scale * 1000))

In [None]:
# Sweep n_segments from 500 to 1900 in steps of 100 and record the score
# obtained at each setting to see how the performance changes
n_segments = range(500, 2000, 100)
n_score = [make_and_score(n_segments=seg_count) for seg_count in n_segments]
# 'r*' draws red star markers
plt.plot(n_segments, n_score, 'r*')
plt.show()

In [None]:
# Optimize the values
from scipy.optimize import fmin

# Minimise f_make starting from n = 1; f_make converts n into int(n*1000)
# segments, so the search starts at 1000 segments. The call is the last
# expression of the cell, so its return value is shown as the cell output.
fmin(f_make, x0=[1], full_output=True, disp=True)