In [None]:
import os
import numpy as np
import skimage.measure

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# --- Configuration: everything a reader might want to tune -------------------

# Epoch number that selects which nodule-probability directory is read
EPOCH = 100

# Directory holding the '<UID>_lung_region.npz' archives
DATA_DIR = './data/preprocessed_masks/lung_region/'
# Directory holding the '<UID>_nodule_prob.npz' archives for the chosen epoch
PROB_DIR = './data/preprocessed_masks/nodule_prob_epoch_%08d/' % EPOCH

# Volume of a sphere of radius 20; compared against a cluster's voxel count in
# the (currently commented-out) cluster-splitting cell.
# Float literals keep the 4/3 factor from truncating to 1 under Python 2
# integer division, so the value is the same on every interpreter.
MAX_SIZE = 4.0/3.0*np.pi*(20**3)
# Voxels with probability >= this value are considered nodule candidates
PROB_THRESHOLD = 0.8
# Threshold increment used when re-splitting oversized clusters
# (only referenced by the commented-out splitting cell)
THRESHOLD_STEP = 0.05

# Number of highest-ranked clusters to keep
TOP_NODULE_NUM = 20
# Edge length of the cubic crop taken around each kept cluster
CROP_SIZE = 64

In [None]:
# Identifier of the scan to inspect; it is used as the file-name prefix of the
# '<UID>_lung_region.npz' and '<UID>_nodule_prob.npz' archives loaded below.
UID = '87cdb87db24528fdb8479220a1854b83'

In [None]:
# Load the image and the lung mask.
# The context manager closes the archive even if a key is missing or the read
# fails part-way, instead of relying on reaching an explicit close().
with np.load(os.path.join(DATA_DIR, UID + '_lung_region.npz')) as f:
    lung_img = f['lung_img']
    # Coerce to bool so that `~` below is a logical NOT. On an integer mask it
    # would be a bitwise NOT and the assignment would become integer
    # fancy-indexing. This is a no-op when the stored mask is already boolean.
    lung_mask = np.asarray(f['lung_mask'], dtype=bool)
# Overwrite every voxel outside the lung mask with -1024
lung_img[~lung_mask] = -1024

# Load the 3D probability map
with np.load(os.path.join(PROB_DIR, UID + '_nodule_prob.npz')) as f:
    prob3d = f['prob3d']

# Axis order of the volume is (z, y, x)
Z, Y, X = lung_img.shape
print('CT scan size in ZYX:', lung_img.shape)

In [None]:
# Binarize the probability map: a voxel is a nodule candidate when its
# probability reaches the configured threshold.
threshold = PROB_THRESHOLD
nodule_mask = prob3d >= threshold

# Connected-component labelling of the candidate voxels.
# `cluster_ids` holds one integer label per voxel; the second return value is
# the number of labelled (non-background) components.
cluster_ids, num_foreground_clusters = skimage.measure.label(
    nodule_mask,
    background=0,
    return_num=True,
)

# The count returned above equals the largest label in `cluster_ids`;
# one more accounts for label 0 (the background), so that
# range(cluster_num) enumerates every label.
cluster_num = num_foreground_clusters + 1
print('number of clusters found:', cluster_num)

In [None]:
# i = 1
# rm_inds = []
# split_ind = [0]*cluster_num
# while i < cluster_num:
#     # print(i)
#     volume = np.sum(cluster_ids == i)
#     if(volume > MAX_SIZE):
#         new_threshold = PROB_THRESHOLD + THRESHOLD_STEP * (split_ind[i]+1)
#         # print('splitting cluster %d with volume %d, threshold %f' % (i, volume, new_threshold))
        
#         new_mask = np.logical_and(cluster_ids == i, prob3d >= new_threshold)
#         new_cluster_ids, new_cluster_num = skimage.measure.label(new_mask, return_num=True, background=0)
#         # print('number of new clusters:', new_cluster_num)
        
#         # update the cluster indices: new clusters are appended at the end
#         split_ind += [split_ind[i] + 1] * new_cluster_num
#         cluster_ids[new_cluster_ids >= 1] = cluster_num + new_cluster_ids[new_cluster_ids >= 1] - 1

#         # remove the old cluster
#         cluster_ids[cluster_ids == i] = 0  
#         rm_inds.append(i)
#         cluster_num += new_cluster_num
    
#     i += 1

In [None]:
# Take the top-N clusters, ranked by their total nodule probability, and cut a
# CROP_SIZE^3 crop of the image and of the probability map around the
# probability-weighted centroid of each one.

# The crop window is clamped to stay inside the volume, which is only possible
# when every axis is at least CROP_SIZE long. Fail early with a clear message
# instead of a shape-mismatch error from the crop assignment further down.
if min(Z, Y, X) < CROP_SIZE:
    raise ValueError('CT scan of size %s is smaller than the crop size %d'
                     % ((Z, Y, X), CROP_SIZE))

# Sum of nodule probability within each cluster, accumulated in a single pass
# over the volume (one weighted histogram over the label ids) rather than one
# full-volume comparison per cluster.
prob_sums = np.bincount(cluster_ids.ravel(), weights=prob3d.ravel(),
                        minlength=cluster_num)[:cluster_num]
prob_sums[0] = 0  # skip 0, the background

# Rank clusters by probability mass, largest first. Ids with zero mass are
# dropped: that removes the background (zeroed above) and any id that has no
# voxels, so they can never be picked when fewer than TOP_NODULE_NUM real
# clusters exist (an empty id would otherwise produce a 0/0 centroid).
ranked_cluster_inds = np.argsort(prob_sums)[::-1]
ranked_cluster_inds = ranked_cluster_inds[prob_sums[ranked_cluster_inds] > 0]
topN_cluster_inds = ranked_cluster_inds[:TOP_NODULE_NUM]
# Number of rows actually filled below; rows beyond it stay all-zero
topN_found = len(topN_cluster_inds)

# # See how much mass the top-N encodes
# prob_sum_topN = np.sum(prob_sums[topN_cluster_inds])
# prob_sum_thresh = np.sum(prob3d[prob3d >= PROB_THRESHOLD])
# print('fraction of probs in the top %d: %f' % (TOP_NODULE_NUM, prob_sum_topN/prob_sum_thresh))

# compute the centroid of each cluster in the top-N and take crops
half_size = CROP_SIZE // 2
topN_centroids = np.zeros((TOP_NODULE_NUM, 3))
topN_lung_img_crop = np.zeros((TOP_NODULE_NUM, CROP_SIZE, CROP_SIZE, CROP_SIZE))
topN_prob3d_crop = np.zeros((TOP_NODULE_NUM, CROP_SIZE, CROP_SIZE, CROP_SIZE))
# Sparse coordinate grids of shape (Z,1,1), (1,Y,1) and (1,1,X): they broadcast
# against the volume exactly like dense grids but avoid allocating three
# full-volume integer arrays.
Z_mesh, Y_mesh, X_mesh = np.meshgrid(np.arange(Z), np.arange(Y), np.arange(X),
                                     indexing='ij', sparse=True)
# # check whether the meshgrid is correctly computed
# z, y, x = 20, 30, 40
# assert(np.all(Z_mesh[z, :, :] == z))
# assert(np.all(Y_mesh[:, y, :] == y))
# assert(np.all(X_mesh[:, :, x] == x))
for n_cluster, i in enumerate(topN_cluster_inds):
    # Probability map restricted to this cluster (zero everywhere else)
    cluster_prob3d = prob3d * (cluster_ids == i)
    # normalize to have sum equal to 1, so the weighted sums below are means
    cluster_prob3d = cluster_prob3d / np.sum(cluster_prob3d)
    # Probability-weighted centroid, truncated to integer voxel indices
    z = int(np.sum(Z_mesh*cluster_prob3d))
    y = int(np.sum(Y_mesh*cluster_prob3d))
    x = int(np.sum(X_mesh*cluster_prob3d))
    # Shift the centre inwards where needed so the whole crop fits in the scan
    z = int(np.clip(z, half_size, Z-half_size))
    y = int(np.clip(y, half_size, Y-half_size))
    x = int(np.clip(x, half_size, X-half_size))
    topN_centroids[n_cluster] = [z, y, x]

    # Take crop from the centroid
    z_begin, z_end = z - half_size, z + half_size
    y_begin, y_end = y - half_size, y + half_size
    x_begin, x_end = x - half_size, x + half_size
    topN_lung_img_crop[n_cluster] = lung_img[z_begin:z_end, y_begin:y_end, x_begin:x_end]
    topN_prob3d_crop[n_cluster] = prob3d[z_begin:z_end, y_begin:y_end, x_begin:x_end]

# Visualizing the results

In [None]:
# Split the (N, 3) centroid table into one coordinate list per axis.
# Columns are ordered (z, y, x); the lists are used as marker positions in the
# plots below.
zs, ys, xs = (list(axis_coords) for axis_coords in topN_centroids.T)

In [None]:
# Quick look at one axial slice plus the z-projection of the probability map.
plt.close('all')

# Slice index along z to display
z = 125

# Image slice, displayed with a fixed intensity window of [-1000, 400]
fig, ax = plt.subplots()
shown = ax.imshow(lung_img[z, :, :], vmin=-1000, vmax=400, cmap=plt.cm.bone)
fig.colorbar(shown, ax=ax)

# Probability slice at the same z, with all centroids overlaid.
# Every centroid is drawn regardless of its own z coordinate.
fig, ax = plt.subplots()
shown = ax.imshow(prob3d[z, :, :], vmin=0, vmax=1)
fig.colorbar(shown, ax=ax)
ax.plot(xs, ys, 'ro')

# Probability summed over z, with the centroids overlaid
fig, ax = plt.subplots()
shown = ax.imshow(np.sum(prob3d, axis=0))
fig.colorbar(shown, ax=ax)
ax.plot(xs, ys, 'ro')

In [None]:
# Slices perpendicular to z: image on the left, probability map on the right.
plt.close('all')

for z in range(100, Z-50, 25):
    fig, (ax_img, ax_prob) = plt.subplots(1, 2, figsize=(10, 4))

    # Left panel: image slice with a fixed intensity window
    shown = ax_img.imshow(lung_img[z, :, :], vmin=-1000, vmax=400, cmap=plt.cm.bone)
    ax_img.set_title('lung image (slice along z=%d)' % z)
    fig.colorbar(shown, ax=ax_img)

    # Right panel: probability slice; centroids are projected onto the
    # (x, y) plane, so all of them appear on every slice
    shown = ax_prob.imshow(prob3d[z, :, :])
    ax_prob.set_title('prob3d (slice along z=%d)' % z)
    fig.colorbar(shown, ax=ax_prob)
    ax_prob.plot(xs, ys, 'ro')

# Projection of the whole probability volume along z
fig, ax = plt.subplots(figsize=(10, 4))
shown = ax.imshow(np.sum(prob3d, axis=0))
ax.set_title('prob3d (sum along axis z)')
fig.colorbar(shown, ax=ax)
ax.plot(xs, ys, 'ro')

In [None]:
# Slices perpendicular to y: image on the left, probability map on the right.
plt.close('all')

for y in range(50, Y-50, 25):
    fig, (ax_img, ax_prob) = plt.subplots(1, 2, figsize=(10, 4))

    # Left panel: image slice with a fixed intensity window
    shown = ax_img.imshow(lung_img[:, y, :], vmin=-1000, vmax=400, cmap=plt.cm.bone)
    ax_img.set_title('lung image (slice along y=%d)' % y)
    fig.colorbar(shown, ax=ax_img)

    # Right panel: probability slice; rows are z and columns are x, so the
    # centroids are drawn at (x, z)
    shown = ax_prob.imshow(prob3d[:, y, :])
    ax_prob.set_title('prob3d (slice along y=%d)' % y)
    fig.colorbar(shown, ax=ax_prob)
    ax_prob.plot(xs, zs, 'ro')

# Projection of the whole probability volume along y
fig, ax = plt.subplots(figsize=(10, 4))
shown = ax.imshow(np.sum(prob3d, axis=1))
ax.set_title('prob3d (sum along axis y)')
fig.colorbar(shown, ax=ax)
ax.plot(xs, zs, 'ro')

In [None]:
# Slices perpendicular to x: image on the left, probability map on the right.
plt.close('all')

for x in range(50, X-50, 25):
    fig, (ax_img, ax_prob) = plt.subplots(1, 2, figsize=(10, 4))

    # Left panel: image slice with a fixed intensity window
    shown = ax_img.imshow(lung_img[:, :, x], vmin=-1000, vmax=400, cmap=plt.cm.bone)
    ax_img.set_title('lung image (slice along x=%d)' % x)
    fig.colorbar(shown, ax=ax_img)

    # Right panel: probability slice; rows are z and columns are y, so the
    # centroids are drawn at (y, z)
    shown = ax_prob.imshow(prob3d[:, :, x])
    ax_prob.set_title('prob3d (slice along x=%d)' % x)
    fig.colorbar(shown, ax=ax_prob)
    ax_prob.plot(ys, zs, 'ro')

# Projection of the whole probability volume along x
fig, ax = plt.subplots(figsize=(10, 4))
shown = ax.imshow(np.sum(prob3d, axis=2))
ax.set_title('prob3d (sum along axis x)')
fig.colorbar(shown, ax=ax)
ax.plot(ys, zs, 'ro')