In [1]:
import math
import glob
import time
import json
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Input, Dense, MaxPooling3D, Dropout, Flatten, concatenate, Reshape, UpSampling3D, Lambda, Conv3D, Conv3DTranspose
from tensorflow.keras.layers import BatchNormalization

# from tensorflow_probability.python.layers import MixtureNormal
from tensorflow.keras.losses import BinaryCrossentropy

In [2]:
from sample_data_structure import get_surface_nodes, get_voxel_shapes_from_nodes_dicts, get_conditions, get_surface_nodes_dispositions, get_disposed_nodes, reshape_voxel_grid_into_np

Variational AutoEncoder is a combination of an Encoder and a Decoder learning to encode the input instances into a (less dimentional) regularized latent space and to decode them back with minimal reconstruction error. 

Having a trained VAE we can sample new instances from distribution in the latent space, decompress them and get a quite realistic data instance.

For our problem, we use encoder to get the representation of an input shape in the latent space, then concatenate it with the condition vector (e.g. [% coords of the force application point, relative size of applicator, box thickness]) so that it represents a deformed shape in the latent space - and then decode deformed shape representation into a voxel shape. 

## Build the model


voxel_grid_input = Input(shape=(600, 100, 100, 1))  # (batch/None, depth, height, width, channels)
cond_vec_input = Input(shape=(4,))

In [4]:
enc_model = Sequential()

enc_model.add(Conv3D(16, (3, 3, 3), strides=(2, 2, 2), activation='relu'))
enc_model.add(MaxPooling3D(pool_size=(2, 2, 2)))
enc_model.add(BatchNormalization(center=True, scale=True))
enc_model.add(Dropout(0.5))

enc_model.add(Conv3D(8, (3, 3, 3), strides=(2, 2, 2), activation='relu'))
enc_model.add(MaxPooling3D(pool_size=(2, 2, 2)))
enc_model.add(BatchNormalization(center=True, scale=True))
enc_model.add(Dropout(0.5))

enc_model.add(Conv3D(8, (3, 3, 3), strides=(2, 2, 2), activation='sigmoid'))
enc_model.add(MaxPooling3D(pool_size=(2, 2, 2)))
enc_model.add(BatchNormalization(center=True, scale=True))
enc_model.add(Dropout(0.5))

enc_model.add(Flatten())

enc_model.add(Dense(8, activation='relu'))
# enc_model.add(Dense(8, activation='relu', activity_regularizer=tfp.python.layers.MixtureNormal))


encoded_box = enc_model(voxel_grid_input)

print(enc_model.summary())

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv3d (Conv3D)              (None, 299, 49, 49, 16)   448       
_________________________________________________________________
max_pooling3d (MaxPooling3D) (None, 149, 24, 24, 16)   0         
_________________________________________________________________
batch_normalization (BatchNo (None, 149, 24, 24, 16)   64        
_________________________________________________________________
dropout (Dropout)            (None, 149, 24, 24, 16)   0         
_________________________________________________________________
conv3d_1 (Conv3D)            (None, 74, 11, 11, 8)     3464      
_________________________________________________________________
max_pooling3d_1 (MaxPooling3 (None, 37, 5, 5, 8)       0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 37, 5, 5, 8)       3

Now that we have a way to encode sample shapes, we concatenate this representation vector with a condition vector to form a latent representation of a deformed shape.

In [5]:
deformed_box_repr = concatenate([encoded_box, cond_vec_input], axis=-1)

#### Next, Decoder (future Generator):

In [6]:
deformed_box_vec_input = Input(shape=(deformed_box_repr.shape[-1], ))

dec_model = Sequential()

dec_model.add(Dense(2*2*3*1*1, activation='relu'))
dec_model.add(Reshape((2,2,3,1)))

dec_model.add(Conv3DTranspose(16, (3, 3, 3), strides=(2, 2, 2), activation='relu'))
dec_model.add(UpSampling3D())
dec_model.add(Conv3DTranspose(8, (3, 3, 3), strides=(2, 2, 2), activation='relu'))
dec_model.add(UpSampling3D())
dec_model.add(Conv3DTranspose(8, (3, 3, 3), strides=(2, 2, 2), activation='sigmoid'))
dec_model.add(UpSampling3D())
dec_model.add(Conv3DTranspose(8, (3, 3, 3), strides=(2, 2, 2), activation='sigmoid'))
dec_model.add(Reshape((600, 100, 100)))

decoded_deformed_box = dec_model(deformed_box_repr)

dec_model.summary()

assert(len(dec_model.input_shape) == len(deformed_box_repr.shape) == 2)
assert(dec_model.input_shape[-1] == deformed_box_repr.shape[-1])

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_1 (Dense)              (None, 12)                156       
_________________________________________________________________
reshape (Reshape)            (None, 2, 2, 3, 1)        0         
_________________________________________________________________
conv3d_transpose (Conv3DTran (None, 5, 5, 7, 16)       448       
_________________________________________________________________
up_sampling3d (UpSampling3D) (None, 10, 10, 14, 16)    0         
_________________________________________________________________
conv3d_transpose_1 (Conv3DTr (None, 21, 21, 29, 8)     3464      
_________________________________________________________________
up_sampling3d_1 (UpSampling3 (None, 42, 42, 58, 8)     0         
_________________________________________________________________
conv3d_transpose_2 (Conv3DTr (None, 85, 85, 117, 8)   

### Finally, combine everything into a single VAE model.

In [7]:
vae_model = Model(inputs=[voxel_grid_input, cond_vec_input], outputs=decoded_deformed_box)

VAE model is built. Now we need to define a way to evaluate its performace and enable it to learn. 

In [8]:
ground_truth_voxel_grid = Input(shape=(600, 100, 100, 1))

discr_model = Sequential()

discr_model.add(Conv3D(16, (3, 3, 3), strides=(2, 2, 2), activation='relu', input_shape=(600, 100, 100, 1))) 
discr_model.add(MaxPooling3D(pool_size=(2, 2, 2)))
discr_model.add(BatchNormalization(center=True, scale=True))
discr_model.add(Dropout(0.5))

discr_model.add(Conv3D(8, (3, 3, 3), strides=(2, 2, 2), activation='relu'))
discr_model.add(MaxPooling3D(pool_size=(2, 2, 2)))
discr_model.add(BatchNormalization(center=True, scale=True))
discr_model.add(Dropout(0.5))

discr_model.add(Conv3D(8, (3, 3, 3), strides=(2, 2, 2), activation='sigmoid'))
discr_model.add(MaxPooling3D(pool_size=(2, 2, 2)))
discr_model.add(BatchNormalization(center=True, scale=True))
discr_model.add(Dropout(0.5))

discr_model.add(Flatten())
discr_model.add(Dense(1))

discr_output = discr_model(ground_truth_voxel_grid)

# Need to make our model of class tf...training.Model in order to use Checkpoint later (because the model class must inherit from a Trackable base).
discriminator_model = Model(inputs=[ground_truth_voxel_grid], outputs=discr_output)

# print(discr_model.summary())

## Train our model
using https://www.tensorflow.org/tutorials/generative/dcgan

In [9]:
CONFIG = {
    'num_epochs': 50,
    'test_percent': 0.2,
    'num_samples_total': 383,
    'num_samples_using': 5,
    'batch_size': 3
}

\* cross-entropy loss (aka log loss) is measuring difference between probabilities.

In [10]:
def generator_loss(discriminator_decision):
    """
    Compute cross-entropy loss between the discriminator's prediction on generated shapes only, and an array of ones.
    """
    return BinaryCrossentropy(tf.ones_like(discriminator_decision), discriminator_decision)


def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss


def wasserstein_loss(y_true, y_pred):
    return K.mean(y_true * y_pred)


generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=vae_model,
                                 discriminator=discriminator_model)

In [11]:
@tf.function
def train_step(x_shape, x_cond, ground_truth):

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        print("vae_model.input_shape: ", vae_model.input_shape)
        print("x [0] = ", len(x_shape[0]))
        print("X: {} x {} x {} --- Cond: {}".format(len(x_shape), len(x_shape[0]), len(x_shape[0][0]), len(x_cond)))
        print(type(x_shape), type(x_cond))
        
        generated_deformation = vae_model([[x_shape], [x_cond]], training=True)

        real_output = disc_model(ground_truth, training=True)
        fake_output = disc_model(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

        print(wasserstein_loss(real_output, fake_output))

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

In [30]:
def train(X, X_cond, Y):
    for epoch in range(CONFIG['num_epochs']):
        start = time.time()

        print(X)
        

        X, X_cond, Y = X.batch(CONFIG['batch_size']), X_cond.batch(CONFIG['batch_size']), Y.batch(CONFIG['batch_size'])
            
        for sample in X
        for i in range(len(X_cond)):

            
            X_cond.as_numpy_iterator()
            
#             print(list(X_cond[i].values()))
            print(X[i].shape, len(list(X_cond[i].values())), Y[i].shape)
            
            train_step(list(X.as_numpy_iterator())[i], list(X_cond.as_numpy_iterator())[i], list(Y.as_numpy_iterator()[i]))

        if (epoch + 1) % 15 == 0:
            checkpoint.save(file_prefix = checkpoint_prefix)

        print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
    print('Yay!')

### Training

In [12]:
test_ids = np.random.default_rng().choice(range(1, CONFIG['num_samples_total']), math.floor(CONFIG['test_percent'] * CONFIG['num_samples_total']), replace=False)

test_ids = set(test_ids)
train_ids = set(range(1, CONFIG['num_samples_total'])) - test_ids

if CONFIG['num_samples_using']:
    train_ids = [train_ids.pop() for i in range(CONFIG['num_samples_using'])]
    test_ids = [test_ids.pop() for i in range(CONFIG['num_samples_using'])]

In [21]:
X_nodes_dicts = get_surface_nodes(train_ids)

X = get_voxel_shapes_from_nodes_dicts(X_nodes_dicts)
print("X done")

X_cond = get_conditions(train_ids)
X_cond = [list(cond.values()) for cond in X_cond]
print("cond done")

Y_nodes_disp_dicts = get_surface_nodes_dispositions(train_ids)
Y_nodes_dicts = get_disposed_nodes(X_nodes_dicts, Y_nodes_disp_dicts)
Y = get_voxel_shapes_from_nodes_dicts(Y_nodes_dicts)
print("Y done")

X done
cond done
Y done


In [22]:
   
# Reshape to fit model input.
X = reshape_voxel_grid_into_np(X)
Y = reshape_voxel_grid_into_np(Y)

# Convert into Dataset class in order to use tf.batch function later.
train_data = tf.data.Dataset.from_tensor_slices(zip(X, X_cond, Y))

# X = tf.data.Dataset.from_tensor_slices(X)
# X_cond = tf.data.Dataset.from_tensor_slices(X_cond)
# Y = tf.data.Dataset.from_tensor_slices(Y)

In [31]:
train(X, X_cond, Y)

<TensorSliceDataset shapes: (600, 100, 100, 1), types: tf.int32>
[array([[4.00000000e+00, 1.96078431e-03, 3.74107021e-01, 7.46797547e-01],
       [4.00000000e+00, 3.92156863e-03, 5.65283492e-01, 7.46797547e-01],
       [1.00000000e+00, 1.96078431e-03, 5.65283492e-01, 7.46797547e-01]]), array([[2.00000000e+00, 1.96078431e-03, 3.44695256e-01, 7.46797547e-01],
       [1.00000000e+00, 1.58730159e-03, 7.64519508e-01, 7.46797547e-01]])]


TypeError: object of type 'BatchDataset' has no len()

In [None]:
Xt_nodes_dicts = get_surface_nodes(test_ids)

Xt = get_voxel_shapes_from_nodes_dicts(X_nodes_dicts)
print("Xt done")

Xt_cond = get_conditions(test_ids)
print("t-cond done")

vae_model([Xt[0], Xt_cond[0]], training=False)