In [54]:
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Input, Dense, MaxPooling3D, Dropout, Flatten, concatenate, Reshape, UpSampling3D, Lambda, Conv3D, Conv3DTranspose
from tensorflow.keras.layers import BatchNormalization
import tensorflow as tf

In [55]:
# Symbolic network inputs; the batch axis is implicit (None).
# Voxel grid layout: (depth, height, width, channels).
voxel_grid_input = Input(
    dtype='float32',
    shape=(600, 100, 100, 1),
)
# Conditioning vector: 4 entries followed by a singleton axis.
cond_vec_input = Input(
    dtype='float32',
    shape=(4, 1),
)

In [56]:
# Encoder: three identical down-sampling stages
# (Conv3D -> MaxPooling3D -> BatchNormalization -> Dropout), then a flatten
# and an 8-unit dense layer that yields the encoded box vector.
enc_layers = []
for n_filters, conv_activation in ((16, 'relu'), (8, 'relu'), (8, 'sigmoid')):
    enc_layers += [
        Conv3D(n_filters, (3, 3, 3), strides=(2, 2, 2), activation=conv_activation),
        MaxPooling3D(pool_size=(2, 2, 2)),
        BatchNormalization(center=True, scale=True),
        Dropout(0.5),
    ]
enc_layers += [
    Flatten(),
    Dense(8, activation='relu'),
]

enc_model = Sequential(enc_layers)

# Apply the encoder to the symbolic voxel-grid input.
encoded_box = enc_model(voxel_grid_input)


In [63]:
# Drop the trailing singleton axis of the conditioning vector, then join it
# with the encoder output along the last (feature) axis.
cond_vec_flat = tf.squeeze(cond_vec_input, axis=-1)
deformed_box_repr = concatenate([encoded_box, cond_vec_flat], axis=-1)

<dtype: 'float32'> <dtype: 'float32'>
<dtype: 'float32'>


In [58]:
# Stand-alone input placeholder sized like the decoder's input vector.
# NOTE(review): it is not wired into any model in this cell; it is kept
# unchanged in case other cells reference it.
deformed_box_vec_input = Input(shape=(deformed_box_repr.shape[-1], ), dtype='float32')

# Decoder: expands the conditioned code vector into a (600, 100, 100) grid.
#
# The previous layer stack could not be built. With the default 'valid'
# padding, every stride-2 / kernel-3 Conv3DTranspose maps n -> 2n + 1, so
# (together with the UpSampling3D layers) the (2, 2, 3) seed grew to
# (682, 341, 469, 8) and the closing Reshape((600, 100, 100)) was rejected
# because the element counts differ.
#
# With padding='same' a transposed convolution scales each axis by exactly
# its stride, and the seed grid below is chosen so the strides multiply out
# to the target size:
#   depth          75 -> 150 -> 300 -> 600
#   height/width   25 ->  50 -> 100 -> 100
dec_seed_shape = (75, 25, 25, 1)
dec_output_shape = (600, 100, 100)

dec_model = Sequential()

# Project the code vector onto the seed volume.
dec_model.add(Dense(75 * 25 * 25 * 1, activation='relu'))
dec_model.add(Reshape(dec_seed_shape))

# -> (150, 50, 50, 16)
dec_model.add(Conv3DTranspose(16, (3, 3, 3), strides=(2, 2, 2), padding='same', activation='relu'))
# -> (300, 100, 100, 8)
dec_model.add(Conv3DTranspose(8, (3, 3, 3), strides=(2, 2, 2), padding='same', activation='relu'))
# Final stage doubles the depth axis only and collapses to one channel;
# the sigmoid bounds the output to [0, 1].  -> (600, 100, 100, 1)
dec_model.add(Conv3DTranspose(1, (3, 3, 3), strides=(2, 1, 1), padding='same', activation='sigmoid'))
# Drop the singleton channel axis so the output shape stays (600, 100, 100).
dec_model.add(Reshape(dec_output_shape))

decoded_deformed_box = dec_model(deformed_box_repr)

# Fail fast if the layer arithmetic above is ever changed inconsistently.
assert tuple(decoded_deformed_box.shape[1:]) == dec_output_shape, decoded_deformed_box.shape

### Finally, combine everything into a single VAE model.

In [59]:
# End-to-end model: (voxel grid, conditioning vector) -> decoded deformed box.
vae_inputs = [voxel_grid_input, cond_vec_input]
vae_model = Model(
    inputs=vae_inputs,
    outputs=decoded_deformed_box,
)

The VAE model is built. Now we need to define a way to evaluate its performance and enable it to learn.

In [60]:
# Discriminator: takes a (600, 100, 100, 1) voxel grid and produces a single
# output value with no activation applied.
ground_truth_voxel_grid = Input(shape=(600, 100, 100, 1))

# (filters, activation) of the Conv3D layer that opens each stage.
discr_stage_specs = [(16, 'relu'), (8, 'relu'), (8, 'sigmoid')]

discr_model = Sequential()
for n_filters, conv_activation in discr_stage_specs:
    # One stage: Conv3D -> MaxPooling3D -> BatchNormalization -> Dropout.
    discr_model.add(Conv3D(n_filters, (3, 3, 3), strides=(2, 2, 2), activation=conv_activation))
    discr_model.add(MaxPooling3D(pool_size=(2, 2, 2)))
    discr_model.add(BatchNormalization(center=True, scale=True))
    discr_model.add(Dropout(0.5))

discr_model.add(Flatten())
discr_model.add(Dense(1))

discr_output = discr_model(ground_truth_voxel_grid)

# Wrap the stack in a functional Model: Checkpoint (used later) requires a
# model class that inherits from a Trackable base.
discriminator_model = Model(
    inputs=[ground_truth_voxel_grid],
    outputs=discr_output,
)

In [None]:
# NOTE(review): this unexecuted cell duplicates the debug cell that follows
# it; consider deleting one of the two.
DBG = True
if DBG:
    # Model.summary() prints the table itself and returns None, so it is
    # called directly; wrapping it in print() appended a stray "None".
    for dbg_model in (enc_model, discriminator_model, dec_model):
        dbg_model.summary()
        print('\n\n')

    vae_model.summary()

In [61]:
# Debug switch: when enabled, print the layer table of every model built so far.
DBG = True
if DBG:
    # Model.summary() prints the table itself and returns None, so it is
    # called directly; wrapping it in print() appended a stray "None".
    for dbg_model in (enc_model, discriminator_model, dec_model):
        dbg_model.summary()
        print('\n\n')

    vae_model.summary()

Model: "sequential_16"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv3d_36 (Conv3D)           (None, 299, 49, 49, 16)   448       
_________________________________________________________________
max_pooling3d_36 (MaxPooling (None, 149, 24, 24, 16)   0         
_________________________________________________________________
batch_normalization_36 (Batc (None, 149, 24, 24, 16)   64        
_________________________________________________________________
dropout_36 (Dropout)         (None, 149, 24, 24, 16)   0         
_________________________________________________________________
conv3d_37 (Conv3D)           (None, 74, 11, 11, 8)     3464      
_________________________________________________________________
max_pooling3d_37 (MaxPooling (None, 37, 5, 5, 8)       0         
_________________________________________________________________
batch_normalization_37 (Batc (None, 37, 5, 5, 8)     