In [1]:
# IMPORTS
# Import everything the notebook needs. If any import fails, install the
# dependencies, enable the k3d notebook extension and shut the kernel down
# with restart requested so the new packages are visible on the next run.
try:
    import SimpleITK as sitk

    import k3d

    import matplotlib.pyplot as plt
    import matplotlib.animation as animation

    from IPython.display import HTML, display

    import numpy as np
    import pandas as pd

    import os

    from tqdm import tqdm
    
    # NOTE(review): `etl` looks like a project-local module (it is not in the
    # pip install list below) — confirm it is on the import path.
    import etl
    from etl import resample_directory
    from etl import animate_2d_plot
    from etl import validate_bucket_download
    from etl import k3d_plot

    
except ImportError:
    # NOTE(review): this branch runs for ANY failed import above, including
    # pandas, matplotlib and etl, none of which appear in the install list.
    # If one of those is what is missing, installing and restarting will not
    # fix it — confirm the list is complete.
    !python3 -m pip install numpy itk vtk torch SimpleITK gsutil k3d tqdm -q --no-warn-script-location
    
    ## MUST RUN THESE TO ENABLE K3D VISUALIZATION!
    !jupyter nbextension install --user --py k3d
    !jupyter nbextension enable k3d --user --py
    
    # Shut the running kernel down; the True argument asks for a restart.
    import IPython
    IPython.Application.instance().kernel.do_shutdown(True)

### Copy images and masks from GCS bucket to local directory structure

In [2]:
# LOCAL DIRECTORY TREE STRUCTURE

#########################
# \_ la-seg             #  <- GCS Bucket
#     \_ hi-res         #  <- Fetched from GCS
#         |_ masks      #
#           \ ...       #
#         \_ images     #
#           \ ...       #
#     \_ lo-res         #  -> Uploaded to GCS
#         |_ masks64    #
#           \ ...       #
#         \_ images64   #
#           \ ...       #
#########################

### DEFINE PATHS
# Root of the local mirror of the bucket layout sketched above.
data_path = './la-seg/'                                     # GLOBALVAR

# Hi-res branch (fetched from GCS): root, then its images/ and masks/ leaves.
hi_res_path = os.path.join(data_path, 'hi-res')             # GLOBALVAR
hi_res_images_path, hi_res_masks_path = (                   # GLOBALVARS
    os.path.join(hi_res_path, leaf) for leaf in ('images', 'masks')
)

# Lo-res branch (uploaded to GCS): root, then its images64/ and masks64/ leaves.
lo_res_path = os.path.join(data_path, 'lo-res')             # GLOBALVAR
lo_res_images_path, lo_res_masks_path = (                   # GLOBALVARS
    os.path.join(lo_res_path, leaf) for leaf in ('images64', 'masks64')
)

def safe_mkdir(path):
    """Create a directory, or each directory in a collection, if it is missing.

    Parameters
    ----------
    path : str or list/tuple of str
        A single directory path, or a (possibly nested) list or tuple of
        paths. Collections are processed in order.

    Notes
    -----
    Missing parent directories are created too. A path that already exists
    is left untouched.
    """
    if isinstance(path, (list, tuple)):
        # Plain loop rather than a comprehension: we only want the side effect.
        for entry in path:
            safe_mkdir(entry)
    elif not os.path.exists(path):
        # makedirs also creates missing parents; exist_ok tolerates the
        # directory appearing between the existence check and this call.
        os.makedirs(path, exist_ok=True)

# Build the local tree level by level: root, then the two resolution
# branches, then their leaf directories (parents always precede children).
safe_mkdir(
    [data_path]
    + [hi_res_path, lo_res_path]
    + [hi_res_images_path, hi_res_masks_path]
    + [lo_res_images_path, lo_res_masks_path]
)

In [3]:
# PREREQUISITE NOTES:
# 1. IF YOU DO NOT HAVE THE GCLOUD SDK INSTALLED, INSTALL IT NOW
# 2. IF GCLOUD IS NOT YET AUTHENTICATED TO YOUR GOOGLE ACCOUNT,
# RUN THIS COMMAND IN A TERMINAL ENVIRONMENT TO AUTHENTICATE:
# !gcloud auth login --activate --launch-browser

# GCP project the gcloud CLI is pointed at by the command below.
project_id = 'medicalimagerepresentation01'
# Shell escape; IPython substitutes {project_id} before running the command.
!gcloud config set project {project_id}

# UNCOMMENT THESE LINES TO FETCH DATA FROM GCS; LEAVE COMMENTED IF
# ALREADY FETCHED ONCE. IF FETCHING, BE SURE TO RUN
# remove_undersized_images() BELOW
# !gsutil -m cp -r gs://la-seg/hi-res/masks/*.gz {hi_res_masks_path}
# !gsutil -m cp -r gs://la-seg/hi-res/images/*.gz {hi_res_images_path}

# REMOVE IMAGES WITH DEPTH < 125 mm:
# ./la-seg/hi-res/images/MAIN-751.nii.gz and ./la-seg/hi-res/masks/SEG-751.nii.gz
# ./la-seg/hi-res/images/MAIN-752.nii.gz and ./la-seg/hi-res/masks/SEG-752.nii.gz

# def remove_undersized_images(min_millimeters=(0, 0, 125)):
#     img_files = sorted(os.listdir(hi_res_images_path))
#     msk_files = sorted(os.listdir(hi_res_masks_path))
#     for img_name, msk_name in zip(img_files, msk_files):
#         img_path = os.path.join(hi_res_images_path, img_name)
#         msk_path = os.path.join(hi_res_masks_path, msk_name)

#         img = sitk.ReadImage(img_path)
#         msk = sitk.ReadImage(msk_path)

#         img_size_vox = np.array(img.GetSize())
#         msk_size_vox = np.array(msk.GetSize())

#         assert(np.array_equal(img_size_vox, msk_size_vox))

#         img_spacing = np.array(img.GetSpacing())
#         msk_spacing = np.array(msk.GetSpacing())

#         assert(np.array_equal(img_spacing, msk_spacing))

#         img_size_mm = img_size_vox * img_spacing
#         msk_size_mm = msk_size_vox * msk_spacing

#         is_undersized = (np.less(img_size_mm, min_millimeters)).any()

#         if is_undersized:
#             print(f'Removing {img_path} and {msk_path}')
#             os.remove(img_path)
#             os.remove(msk_path)

# remove_undersized_images()

Updated property [core/project].


Updates are available for some Cloud SDK components.  To install them,
please run:
  $ gcloud components update



In [4]:
# Abort early when the hi-res download fails validation; otherwise report success.
if not validate_bucket_download(hi_res_images_path, hi_res_masks_path):
    raise ValueError('Bucket download did not pass all tests.')
print('Bucket download passed all tests.')

Bucket download passed all tests.


# Resampling

In [5]:
def _case_id(filename):
    """Extract the case identifier from a file name.

    Takes the text before the first '.', drops the first '-'-separated
    token and concatenates the rest, e.g. 'MAIN-751.nii.gz' -> '751'.
    """
    return ''.join(filename.split('.')[0].split('-')[1:])


def resample_masks_and_images(validate=True):
    """Resample the hi-res masks and images into the lo-res folders and
    optionally validate the result.

    Reads the module-level paths ``hi_res_path``, ``hi_res_masks_path``,
    ``hi_res_images_path``, ``lo_res_path``, ``lo_res_masks_path`` and
    ``lo_res_images_path``.

    Parameters
    ----------
    validate : bool, default True
        When False, return right after resampling without any checks.

    Returns
    -------
    True when validation ran and every check passed; None when
    ``validate`` is False.

    Raises
    ------
    FileNotFoundError
        If the lo-res masks and images differ in number or do not pair up
        by case id, or if the lo-res files do not correspond one-to-one
        with the hi-res files.
    ValueError
        If any lo-res mask contains values other than exactly {0, 1}.
    """
    from etl import resample_directory
    print('Resampling masks...')
    resample_directory(hi_res_masks_path, lo_res_masks_path,
                       image='NiftiImageIO', is_label=True,
                       n_jobs=-1)

    print('Resampling images...')
    resample_directory(hi_res_images_path, lo_res_images_path,
                       image='NiftiImageIO', is_label=False,
                       n_jobs=-1)

    if not validate:
        return

    ### LOOK FOR MASKS AND IMAGE FILES
    masks_list = sorted(os.listdir(lo_res_masks_path))
    images_list = sorted(os.listdir(lo_res_images_path))

    ############## VERIFY DATA INTEGRITY ###############
    print('Testing data integrity...', end='')

    ### 1. VERIFY MATCHING NUMBER OF MASKS AND IMAGES
    if len(masks_list) != len(images_list):
        raise FileNotFoundError(f"Unequal number of masks and images in {lo_res_path}")

    # Each pair is (image file, mask file).
    image_mask_pairs = list(zip(images_list, masks_list))

    ### 2. VERIFY 1-1 CORRESPONDENCE BETWEEN MASKS AND IMAGES
    for pair in image_mask_pairs:
        name_1 = _case_id(pair[1])
        name_2 = _case_id(pair[0])
        if name_1 != name_2:
            msg = f'Incomplete correspondence between masks and images in {lo_res_path}: '
            msg += f'found non-matching (mask, image) pair {name_1, name_2} from files {pair}.'
            raise FileNotFoundError(msg)

    ### 3. VERIFY 1-1 CORRESPONDENCE BETWEEN HI-RES AND LO-RES
    hi_res_images = sorted(os.listdir(hi_res_images_path))
    hi_res_masks = sorted(os.listdir(hi_res_masks_path))

    # zip() stops at the shortest input, so a partially resampled lo-res
    # folder would slip through the pairwise loop; compare the counts first.
    if not (len(hi_res_images) == len(hi_res_masks) == len(image_mask_pairs)):
        msg = f'Incomplete correspondence between hi-res and lo-res in {lo_res_path} and {hi_res_path}. '
        msg += f'Found {len(hi_res_images)} hi-res images and {len(hi_res_masks)} hi-res masks '
        msg += f'but {len(image_mask_pairs)} lo-res (image, mask) pairs. Did you finish resampling?'
        raise FileNotFoundError(msg)

    # Same (image, mask) ordering as the lo-res pairs so like is compared with like.
    hi_res_pairs = list(zip(hi_res_images, hi_res_masks))

    for hi_res_pair, lo_res_pair in zip(hi_res_pairs, image_mask_pairs):
        images_match = _case_id(lo_res_pair[0]) == _case_id(hi_res_pair[0])
        masks_match = _case_id(lo_res_pair[1]) == _case_id(hi_res_pair[1])
        if not (images_match and masks_match):
            msg = f'Incomplete correspondence between hi-res and lo-res in {lo_res_path} and {hi_res_path}. '
            msg += f'Hi-res pair {hi_res_pair} does not match {lo_res_pair}. Did you finish resampling?'
            raise FileNotFoundError(msg)

    print('Passed!')

    ################# VERIFY LABEL QUANTIZATION ##################

    print('Testing label quantization...', end='')

    # Bound up front so the summary line still works if there are no masks.
    unique_values = []
    for mask in masks_list:
        lo_res_mask_path = os.path.join(lo_res_masks_path, mask)
        sample_mask_lo_res = sitk.ReadImage(lo_res_mask_path)
        # np.unique returns the values sorted, so the comparison does not
        # depend on which label happens to appear first in the volume.
        unique_values = np.unique(sitk.GetArrayFromImage(sample_mask_lo_res)).tolist()
        if unique_values != [0, 1]:
            raise ValueError('Masks arrays not quantized to the set {0, 1}. Did you pass is_label=True to resample_directory()?')

    print(f'Passed! Quantized to {unique_values}')
    return True

# if resample_masks_and_images(validate=True):
#     print('Resampling validation passed all tests.')

In [6]:
# sample_mask_lo_res = sitk.ReadImage(os.path.join(lo_res_masks_path, sorted(os.listdir(lo_res_masks_path))[1]))
# sample_mask_hi_res = sitk.ReadImage(os.path.join(hi_res_masks_path, sorted(os.listdir(hi_res_masks_path))[1]))
# sample_image_lo_res = sitk.ReadImage(os.path.join(lo_res_images_path, sorted(os.listdir(lo_res_images_path))[1]))
# sample_image_hi_res = sitk.ReadImage(os.path.join(hi_res_images_path, sorted(os.listdir(hi_res_images_path))[1]))

In [7]:
import etl
from etl import get_images

# Number of files requested from get_images via its `n` argument.
# NOTE(review): presumably an upper bound — the pipeline run below reports 43 items; confirm.
num_imgs = 45

# Image loading is disabled (empty list); only the hi-res masks are loaded.
sample_images_hi_res = []#get_images(hi_res_images_path, n=num_imgs)
sample_masks_hi_res = get_images(hi_res_masks_path, n=num_imgs)

In [8]:
from etl import resampling_pipeline
# Run the pipeline on the masks only (the image list passed in is empty), in a single job.
# The captured log below shows three stages: "Resampling mask spacing", "Cropping masks", "Resampling mask size".
aug_masks = resampling_pipeline(sample_images_hi_res, sample_masks_hi_res, masks_only=True, n_jobs=1)

 16%|█▋        | 7/43 [00:00<00:00, 66.85it/s]

Resampling mask spacing


100%|██████████| 43/43 [00:00<00:00, 101.21it/s]
3it [00:00, 20.57it/s]

Cropping masks
0: 1, 0, 1
180 180 134
1: 1, 0, 0
180 180 134
2: 1, 0, 0
180 180 134
3: 1, 0, 1
180 180 134
4: 

9it [00:00, 23.41it/s]

0, 0, 0
180 180 134
5: 1, 0, 1
180 180 134
6: 1, 0, 0
180 180 134
7: 1, 1, 1
180 180 134
8: 1, 0, 0
180 180 134
9: 1, 0, 0
180 180 134
10: 

15it [00:00, 24.89it/s]

1, 0, 1
180 180 134
11: 1, 0, 0
180 180 134
12: 1, 0, 1
180 180 134
13: 1, 0, 1
180 180 134
14: 1, 0, 0
180 180 134
15: 1, 0, 0
180 180 134


21it [00:00, 24.79it/s]

16: 1, 0, 0
180 180 134
17: 1, 0, 0
180 180 134
18: 1, 0, 1
180 180 134
19: 1, 0, 0
180 180 134
20: 1, 0, 0
180 180 134
21: 

24it [00:00, 24.47it/s]

1, 0, 0
180 180 134
22: 1, 0, 0
180 180 134
23: 1, 0, 0
180 180 134
24: 1, 0, 0
180 180 134
25: 1, 0, 0
180 180 134
26: 

30it [00:01, 22.60it/s]

1, 0, 0
180 180 134
27: Bounding box on mask 27 does not meet minimum size requirement for cropping; cropping will trespass minimum bounding box.
1, 0, 0
180 180 134
28: 1, 0, 0
180 180 134
29: Bounding box on mask 29 does not meet minimum size requirement for cropping; cropping will trespass minimum bounding box.
1, 0, 0
180 180 134
30: 1, 0, 1
180 180 134
31: 

36it [00:01, 22.42it/s]

1, 0, 0
180 180 134
32: 1, 0, 0
180 180 134
33: Bounding box on mask 33 does not meet minimum size requirement for cropping; cropping will trespass minimum bounding box.
1, 0, 0
180 180 134
34: 0, 1, 0
180 180 134
35: 1, 0, 0
180 180 134
36: 

39it [00:01, 22.50it/s]

0, 0, 0
180 180 134
37: 1, 0, 0
180 180 134
38: 1, 0, 0
180 180 134
39: 1, 0, 0
180 180 134
40: 1, 0, 0
180 180 134
41: 

43it [00:01, 23.86it/s]
100%|██████████| 43/43 [00:00<00:00, 1283.35it/s]

1, 0, 0
180 180 134
42: 1, 0, 0
180 180 134
Resampling mask size





In [9]:
# for i in range(num_imgs)[0:5]:
#     print(f'Augmented mask #{i}')
#     display(HTML(animate_2d_plot(aug_masks[i]).to_jshtml()))

In [10]:
# Index of the sample inspected in the cells below (indexes aug_masks and sample_masks_hi_res).
n = 6

In [11]:
# Plot the n-th entry of aug_masks (the color_range argument is left disabled).
k3d_plot(aug_masks[n])#, color_range=[100, 300])

Output()

In [12]:
# Plot the same mask after resample_image_standardize with an explicit out_size of (180, 180, 134).
k3d_plot(etl.resample_image_standardize(aug_masks[n], out_size=(180, 180, 134), is_label=True))

Output()

In [13]:
# Resample the FIRST hi-res mask (index 0, not n) with is_label=True and show its size.
# NOTE: `img` and `shape` are reassigned by later cells, so their values depend on execution order.
img = etl.resample_image(sample_masks_hi_res[0], is_label=True)
shape = img.GetSize()
shape

(241, 241, 228)

In [14]:
import IPython
from ipywidgets import interact, interactive, fixed, interact_manual

def imshow_sitk(img):
    """Convert a SimpleITK image to an array and draw it with ``plt.imshow``.

    Returns whatever ``plt.imshow`` returns.
    """
    pixels = sitk.GetArrayFromImage(img)
    return plt.imshow(pixels)

def this_image(img, x0, x1, y0, y1, x):
    """Show a cropped window of one slice of ``img``.

    ``x0:x1`` and ``y0:y1`` select the window along the first two axes
    (the upper bounds are clamped to the image size); ``x`` is the index
    along the third axis.
    """
    size = img.GetSize()

    # Clamp the upper bounds so the window never extends past the image.
    x_stop = min(size[0], x1)
    y_stop = min(size[1], y1)

    window = img[x0:x_stop, y0:y_stop, x]
    imshow_sitk(window)

#   imshow_sitk(etl.resample_image(sample_masks_hi_res[0], is_label=True)[13:193,x,46:180])
# Browse the untouched n-th hi-res mask.
img0 = sample_masks_hi_res[n]
shape = img0.GetSize()

interact(
    this_image,
    # Lower bounds range over the first half of each axis, upper bounds over the second half.
    x0=(0, shape[0] // 2),
    x1=(shape[0] // 2, shape[0]),
    y0=(0, shape[1] // 2),
    y1=(shape[1] // 2, shape[1]),
    img=fixed(img0),
    x=(0, shape[2] - 1),
)


interactive(children=(IntSlider(value=128, description='x0', max=256), IntSlider(value=384, description='x1', …

<function __main__.this_image(img, x0, x1, y0, y1, x)>

In [15]:
# Browse the n-th hi-res mask after etl.resample_image (is_label=True).
img = etl.resample_image(sample_masks_hi_res[n], is_label=True)
shape = img.GetSize()

interact(
    this_image,
    # Lower bounds range over the first half of each axis, upper bounds over the second half.
    x0=(0, shape[0] // 2),
    x1=(shape[0] // 2, shape[0]),
    y0=(0, shape[1] // 2),
    y1=(shape[1] // 2, shape[1]),
    img=fixed(img),
    x=(0, shape[2] - 1),
)

interactive(children=(IntSlider(value=55, description='x0', max=110), IntSlider(value=165, description='x1', m…

<function __main__.this_image(img, x0, x1, y0, y1, x)>

In [16]:
# Browse the n-th mask produced by resampling_pipeline.
shape = aug_masks[n].GetSize()
interact(
    this_image,
    # Lower bounds range over the first half of each axis, upper bounds over the second half.
    x0=(0, shape[0] // 2),
    x1=(shape[0] // 2, shape[0]),
    y0=(0, shape[1] // 2),
    y1=(shape[1] // 2, shape[1]),
    img=fixed(aug_masks[n]),
    x=(0, shape[2] - 1),
)

interactive(children=(IntSlider(value=16, description='x0', max=32), IntSlider(value=48, description='x1', max…

<function __main__.this_image(img, x0, x1, y0, y1, x)>

In [17]:
# Standardize `img` (whatever an earlier cell last assigned to it) and browse the result.
img3 = etl.resample_image_standardize(img, is_label=True)
shape = img3.GetSize()
interact(
    this_image,
    # Lower bounds range over the first half of each axis, upper bounds over the second half.
    x0=(0, shape[0] // 2),
    x1=(shape[0] // 2, shape[0]),
    y0=(0, shape[1] // 2),
    y1=(shape[1] // 2, shape[1]),
    img=fixed(img3),
    x=(0, shape[2] - 1),
)

interactive(children=(IntSlider(value=16, description='x0', max=32), IntSlider(value=48, description='x1', max…

<function __main__.this_image(img, x0, x1, y0, y1, x)>

In [18]:
# Same two steps as above, computed from the n-th hi-res mask in one go:
# resample_image first, then resample_image_standardize.
img2 = etl.resample_image_standardize(
    etl.resample_image(sample_masks_hi_res[n], is_label=True),
    is_label=True,
)
shape = img2.GetSize()
interact(
    this_image,
    # Lower bounds range over the first half of each axis, upper bounds over the second half.
    x0=(0, shape[0] // 2),
    x1=(shape[0] // 2, shape[0]),
    y0=(0, shape[1] // 2),
    y1=(shape[1] // 2, shape[1]),
    img=fixed(img2),
    x=(0, shape[2] - 1),
)

interactive(children=(IntSlider(value=16, description='x0', max=32), IntSlider(value=48, description='x1', max…

<function __main__.this_image(img, x0, x1, y0, y1, x)>

In [19]:
# Size of `img` as last assigned by an earlier cell; the recorded output matches the size given to GridSource below.
img.GetSize()

(220, 220, 175)

## Testing with dummy image

In [20]:
# Synthetic grid volume of size (220, 220, 175), used as a known pattern
# to run through resample_image_standardize.
grid = sitk.GridSource(
    outputPixelType=sitk.sitkUInt16,
    size=(220, 220, 175),
    sigma=(8, 8, 8),
    gridSpacing=(32, 32, 32),
)
k3d_plot(grid)

# Standardize the grid and browse the result slice by slice.
grid_standard = etl.resample_image_standardize(grid, is_label=True)
shape = grid_standard.GetSize()
interact(
    this_image,
    # Lower bounds range over the first half of each axis, upper bounds over the second half.
    x0=(0, shape[0] // 2),
    x1=(shape[0] // 2, shape[0]),
    y0=(0, shape[1] // 2),
    y1=(shape[1] // 2, shape[1]),
    img=fixed(grid_standard),
    x=(0, shape[2] - 1),
)

Output()

interactive(children=(IntSlider(value=16, description='x0', max=32), IntSlider(value=48, description='x1', max…

<function __main__.this_image(img, x0, x1, y0, y1, x)>

# Visualizations


In [21]:
from etl import k3d_plot

In [22]:
# Index of the sample used by the visualisation cells below (all of which are currently commented out).
heart_id = 1

## Segmentations
### Lo-res

In [23]:
#k3d_plot(sample_masks_lo_res[heart_id])

### Hi-res

In [24]:
#k3d_plot(sample_masks_hi_res[heart_id])

## Images
### Lo-res

In [25]:
#k3d_plot(sample_images_lo_res[heart_id], color_range=[100, 300])

### Hi-res

In [26]:
#k3d_plot(sample_images_hi_res[heart_id], color_range=[100, 300]) 