# Implementation of a 3D Unet Architecture


In [2]:
import numpy as np 
import matplotlib.pyplot as plt 
import os, time
from importlib import reload

# 3D visualization tools
from mayavi import mlab
mlab.init_notebook(backend='ipy')

import tensorflow as tf
import model, utilities

Notebook initialized with ipy backend.


In [3]:
# Import modules providing tools for image manipulation
import sys
sys.path.append('../tools/')
import mosaic, deformation

In [67]:
reload(utilities)
reload(model)
reload(deformation)

<module 'deformation' from '../tools\\deformation.py'>

In [20]:
# Fix for tensorflow-gpu issues that I found online... (don't ask me what it does)
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
  try:
    # Currently, memory growth needs to be the same across GPUs
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
    logical_gpus = tf.config.experimental.list_logical_devices('GPU')
    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
  except RuntimeError as e:
    # Memory growth must be set before GPUs have been initialized
    print(e)

1 Physical GPUs, 1 Logical GPUs


In [4]:
# Get permissible network input sized (cube lengths)
n_blocks = 2
valid_inputs = [n for n in range(512) if utilities.check_size(n, n_blocks=n_blocks)[0]]
print(valid_inputs)
print('Output shape at {} is {}'.format(252,utilities.check_size(252, n_blocks)[1]))

[92, 100, 108, 116, 124, 132, 140, 148, 156, 164, 172, 180, 188, 196, 204, 212, 220, 228, 236, 244, 252, 260, 268, 276, 284, 292, 300, 308, 316, 324, 332, 340, 348, 356, 364, 372, 380, 388, 396, 404, 412, 420, 428, 436, 444, 452, 460, 468, 476, 484, 492, 500, 508]
Output shape at 252 is 164.0


In [21]:
unet = model.Unet(n_blocks=2, initial_filters=8)

In [23]:
unet.summary()

Model: "Unet"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_block (InputBlock)     multiple                  3696      
_________________________________________________________________
downsample_block (Downsample multiple                  20784     
_________________________________________________________________
downsample_block_1 (Downsamp multiple                  83040     
_________________________________________________________________
bottleneck_block (Bottleneck multiple                  463168    
_________________________________________________________________
upsample_block (UpsampleBloc multiple                  475328    
_________________________________________________________________
upsample_block_1 (UpsampleBl multiple                  118880    
_________________________________________________________________
output_block (OutputBlock)   multiple                  27714  

In [29]:
# Load some slices from the dataset

# Locate the sample directory on the computer
base_dir = 'C:\\Users\\Linus Meienberg\\Documents\\ML Datasets\\FruSingleNeuron_20190707\\SampleCrops'
samples = os.listdir(base_dir)

#Load the first sample
sample = utilities.load_volume(os.path.join(base_dir, samples[0]))

In [47]:
utilities.show3DImage(sample['image'])

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\xf4\x00\x00\x01\xf4\x08\x02\x00\x00\x00D\xb4H\xd…

In [54]:
utilities.show3DImage(sample['mask'], mode='mask')

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\xf4\x00\x00\x01\xf4\x08\x02\x00\x00\x00D\xb4H\xd…

In [33]:
np.histogram(sample['mask'], bins=3)
#NOTE The mask seems to contain integer values up to 3 how does that come? what does it signify? 

(array([28756806,  3184260,     2934], dtype=int64),
 array([0., 1., 2., 3.], dtype=float32))

In [71]:
displacement = deformation.displacementGridField3D((220,220,220), n_lines=5)
sample_img_deformed = deformation.applyDisplacementField3D(sample['image'], *displacement)
sample_mask_deformed = deformation.applyDisplacementField3D(sample['mask'], *displacement)

In [72]:
utilities.show3DImage(sample_img_deformed)

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\xf4\x00\x00\x01\xf4\x08\x02\x00\x00\x00D\xb4H\xd…

In [70]:
displacement = deformation.displacementGridField3D((220,220,220), n_lines=5)
mlab.figure(size=(500,500))
#plot = mlab.contour3d(sample_img_deformed[:,:,:,0],contours = 4, transparent=True)
plot = mlab.pipeline.vector_field(*displacement)
mlab.pipeline.vectors(plot, mask_points=2000, scale_factor=50.)
plot

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01\xf4\x00\x00\x01\xf4\x08\x02\x00\x00\x00D\xb4H\xd…

In [36]:
pred = unet(np.expand_dims(sample_img,axis=0))

In [37]:
pred.shape

TensorShape([1, 132, 132, 132, 2])