In [1]:
%matplotlib inline

In [2]:

import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
from scipy import signal, ndimage, misc

from keras.layers import Input, Dense, Lambda, Flatten, Reshape, BatchNormalization, Dropout, GaussianNoise
from keras.layers import Convolution2D, Deconvolution2D, MaxPooling2D, UpSampling2D
from keras.layers import Convolution3D, UpSampling3D, MaxPooling3D
from keras.models import Model
from keras import regularizers
from keras import backend as K_backend
from keras import objectives
from keras.datasets import mnist


Using TensorFlow backend.


In [3]:
from matplotlib.pyplot import imshow
from skimage import measure, morphology
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

In [4]:
def nimshow(img):
    """Display `img` with matplotlib, with pixel interpolation switched off."""
    render_opts = dict(interpolation='none')
    plt.imshow(img, **render_opts)
    
def conv(img, kern, renorm=False):
    """Convolve `img` with `kern` via FFT, keeping the output the size of `img`.

    Parameters
    ----------
    img, kern : array_like
        Passed straight to ``scipy.signal.fftconvolve`` with mode ``'same'``.
    renorm : bool, optional
        When True, divide the result by its maximum value so the peak is 1.

    Returns
    -------
    ndarray
        The convolution result, rescaled when `renorm` is set.
    """
    ift = signal.fftconvolve(img, kern, 'same')
    if renorm:
        peak = np.amax(ift)
        # A zero peak (e.g. an all-zero input) would turn every element into
        # NaN on division, so the result is returned unscaled in that case.
        if peak != 0:
            ift = ift / peak
    return ift

def ricker2d(x, y, f=np.pi, n=0.5):
    """Evaluate a radially symmetric Ricker-style wavelet at (x, y).

    The radius is ``(x**2 + y**2)**n`` (the Euclidean distance for the default
    ``n=0.5``) and the value is ``(1 - 2*a) * exp(-a)`` with
    ``a = pi**2 * f**2 * radius**2``.
    """
    radius = (x**2 + y**2)**n
    arg = (np.pi**2)*(f**2)*(radius**2)
    envelope = np.exp(-arg)
    return (1.0 - 2.0*arg) * envelope

def gauss2d(x, y, f=1, sig=1, n=0.5):
    """Evaluate a radially symmetric Gaussian-shaped bump at (x, y).

    The radius is ``(x**2 + y**2)**n`` (the Euclidean distance for the default
    ``n=0.5``). The value is ``exp(-(f*radius)**2 / (0.25*sig**2))``; note the
    denominator is ``0.25*sig**2`` rather than the textbook ``2*sig**2``.
    """
    radius = (x**2 + y**2)**n
    numerator = (f*radius)**2
    denominator = .25*sig**2
    return np.exp(-numerator/denominator)

In [5]:
def plot_3d(image, threshold=-300, azim=45, elev=45):
    """Render an isosurface of a 3D array as a translucent triangle mesh.

    Parameters
    ----------
    image : 3D array
        Volume to mesh. Its axes are reversed with ``transpose(2, 1, 0)``
        before the surface is extracted.
    threshold : float, optional
        Iso-level handed to ``measure.marching_cubes``.
        NOTE(review): the default of -300 looks inherited from a CT-scan
        example; pass an explicit level suited to the data's value range.
    azim, elev : float, optional
        Camera azimuth and elevation for the 3D axes.
    """
    # Reverse the axis order before meshing.
    p = image.transpose(2, 1, 0)
    # Optional flip along the last axis (left disabled): p = p[:, :, ::-1]

    # Older scikit-image returns (verts, faces); newer releases return
    # (verts, faces, normals, values). Indexing the first two entries works
    # with both, whereas two-name unpacking fails on the 4-tuple.
    surface = measure.marching_cubes(p, threshold)
    verts, faces = surface[0], surface[1]

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')

    # Fancy indexing: `verts[faces]` to generate a collection of triangles
    mesh = Poly3DCollection(verts[faces], alpha=0.1)
    face_color = [0.5, 0.5, 1]
    mesh.set_facecolor(face_color)
    ax.add_collection3d(mesh)
    ax.view_init(elev=elev, azim=azim)

    # add_collection3d does not autoscale, so set the limits to the volume extent.
    ax.set_xlim(0, p.shape[0])
    ax.set_ylim(0, p.shape[1])
    ax.set_zlim(0, p.shape[2])

    plt.show()
    
def splay(volume, rows=5, cols=5):
    """Show the leading slices of `volume` on a rows x cols grid of subplots.

    Slice ``k`` (``volume[k]``) is drawn in grid position ``k + 1``, filling
    the grid row by row. If `volume` holds fewer than ``rows * cols`` slices,
    the remaining grid positions are left empty instead of raising IndexError.

    Parameters
    ----------
    volume : sequence of 2D images (e.g. a 3D array indexed along axis 0)
    rows, cols : int, optional
        Grid dimensions.
    """
    fig = plt.figure(figsize=(10, 10))
    n_panels = min(rows * cols, len(volume))
    for k in range(n_panels):
        fig.add_subplot(rows, cols, k + 1)
        # plt.imshow draws into the subplot just added (the current axes).
        plt.imshow(volume[k])



In [6]:
# Fetch the MNIST train/test splits (images and labels).
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Cast both image arrays to float32 and divide by 255.
x_train, x_test = (
    images.astype('float32') / 255.
    for images in (x_train, x_test)
)
# Disabled: reshape to add a leading channel axis.
# x_train = np.reshape(x_train, (len(x_train), 1, 28, 28))
# x_test = np.reshape(x_test, (len(x_test), 1, 28, 28))

In [7]:
# Shape check on the loaded training images (shown as the cell output).
x_train.shape

(60000, 28, 28)

In [9]:
# Load the 3D training volumes from a local .npy file, then display their shape.
# NOTE(review): the path is relative to the working directory, and how this
# file was generated is not shown in this notebook.
x_train3d = np.load('x_train_mnist3d_1k.npy')
x_train3d.shape

(1000, 28, 28, 28)

In [12]:
class Convo3d(object):
    """3D convolutional autoencoder graph for single-channel 28x28x28 volumes.

    Attributes set by ``__init__``:
        input_img: the Keras ``Input`` placeholder the graph is built on.
        encoded: output tensor of the encoder (after the third max-pooling).
        decoded: output tensor of the decoder (sigmoid-activated convolution).
        autoencoder: ``Model`` mapping ``input_img`` to ``decoded``. It is not
            compiled here; call ``autoencoder.compile(...)`` before training.
    """

    def __init__(self):
        """
        Build the encoder/decoder graph and wrap it in a Keras Model.

        Convolution3D expects a
        5D tensor with shape: (samples, channels, conv_dim1, conv_dim2, conv_dim3) if dim_ordering='th' or
        5D tensor with shape: (samples, conv_dim1, conv_dim2, conv_dim3, channels) if dim_ordering='tf'.

        NOTE(review): the Input shape below is written channels-first ('th');
        confirm the configured dim_ordering matches, otherwise the leading 1
        is treated as a spatial dimension.
        """
        input_img = Input(shape=(1, 28, 28, 28))  # (nChan, z, x, y)
        # Kept on the instance so encoder-only models can be built later.
        self.input_img = input_img

        # Encoder: three conv + pool stages. Assuming 'th' ordering, each
        # 'same' pooling halves the spatial size, rounding up: 28 -> 14 -> 7 -> 4.
        x = Convolution3D(16, 3, 3, 3, activation='relu', border_mode='same')(input_img)
        x = MaxPooling3D((2, 2, 2), border_mode='same')(x)
        x = Convolution3D(8, 3, 3, 3, activation='relu', border_mode='same')(x)
        x = MaxPooling3D((2, 2, 2), border_mode='same')(x)
        x = Convolution3D(8, 3, 3, 3, activation='relu', border_mode='same')(x)
        self.encoded = MaxPooling3D((2, 2, 2), border_mode='same')(x)

        # at this point the representation is (8, 4, 4, 4) i.e. 512-dimensional
        # (assuming 'th' ordering)

        # Decoder: mirror of the encoder. It must start from self.encoded;
        # the bare name `encoded` is not defined in this scope.
        x = Convolution3D(8, 3, 3, 3, activation='relu', border_mode='same')(self.encoded)
        x = UpSampling3D((2, 2, 2))(x)
        x = Convolution3D(8, 3, 3, 3, activation='relu', border_mode='same')(x)
        x = UpSampling3D((2, 2, 2))(x)
        # No border_mode here, so the Keras default ('valid') applies: the 3x3x3
        # kernel trims 16 -> 14, which lets the last upsampling land back on 28.
        # Using border_mode='same' instead would produce a 32-wide output.
        x = Convolution3D(16, 3, 3, 3, activation='relu')(x)
        x = UpSampling3D((2, 2, 2))(x)
        self.decoded = Convolution3D(1, 3, 3, 3, activation='sigmoid', border_mode='same')(x)

        self.autoencoder = Model(input_img, self.decoded)
        
        

In [None]:
# `autoencoder` only exists as an attribute of a Convo3d instance, so build the
# model first; the bare name is otherwise undefined on a fresh kernel.
autoencoder = Convo3d().autoencoder
autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')