In [None]:
import modelutils
import utils
import pickle
import numpy as np
import matplotlib.pyplot as plt
from keras import callbacks
import pprint
import glob
import pickle

# Find the latest preprocessed dataset paths.

In [None]:
# Resolve the newest preprocessed dataset file for each input representation
# and echo the chosen path so the run is traceable.
dataset_name_voxelgrid = utils.get_latest_preprocessed_dataset("voxelgrid-dataset")
print(dataset_name_voxelgrid)
dataset_name_pointcloud = utils.get_latest_preprocessed_dataset("pointcloud-dataset")
print(dataset_name_pointcloud)

# Shared setup: pretty-printer, TensorBoard callback, and model/history saving helper.

In [None]:
# Pretty-printer with 4-space indentation, used below to display each dataset's parameters.
pp = pprint.PrettyPrinter(indent=4)

In [None]:
# Training histories from every model trained below, collected for the final plot.
histories = []
# TensorBoard callback built with default arguments; both training cells pass this same instance.
# NOTE(review): with a single default-configured callback, both runs presumably log into the
# same directory — consider one callback per model with a distinct log_dir.
tensorboard_callback = callbacks.TensorBoard()

In [None]:
def save_model_and_history(model, history, name):
    """Save a trained model and its training history to timestamped files.

    Both files are written to the current working directory and share one
    timestamp prefix:
    ``<datetime>-<name>-model.h5`` and ``<datetime>-<name>-history.p``.

    Args:
        model: The trained model; its ``save`` method is called with the model file name.
        history: The object returned by ``fit``; its ``history`` mapping is pickled.
        name: Short identifier embedded in both file names (e.g. "voxnet").
    """
    print("Saving model and history...")

    # A single timestamp for both files keeps the model/history pair matched.
    datetime_string = utils.get_datetime_string()

    # Save the model that was passed in, not a notebook-global one.
    model_name = datetime_string + "-" + name + "-model.h5"
    model.save(model_name)
    print("Saved model to " + model_name)

    # Use a context manager so the history file is flushed and closed deterministically.
    history_name = datetime_string + "-" + name + "-history.p"
    with open(history_name, "wb") as history_file:
        pickle.dump(history.history, history_file)
    print("Saved history to " + history_name)

# Training VoxNet.

In [None]:
dataset_name = dataset_name_voxelgrid

print("Loading dataset...")
# NOTE: pickle.load can execute arbitrary code; only load dataset files you trust
# (e.g. ones produced by your own preprocessing). The context manager closes the file.
with open(dataset_name, "rb") as dataset_file:
    (x_input_train, y_output_train, _), (x_input_test, y_output_test, _), dataset_parameters = pickle.load(dataset_file)
pp.pprint(dataset_parameters)

In [None]:
input_shape = (32, 32, 32)
output_size = 2
model_voxnet = modelutils.create_voxnet_model_homepage(input_shape, output_size)
model_voxnet.summary()

# Compile: RMSprop optimizer, mean-squared-error loss, mean-absolute-error metric.
model_voxnet.compile(optimizer="rmsprop", loss="mse", metrics=["mae"])

# Train for 50 epochs, using the test split as validation data.
history = model_voxnet.fit(
    x_input_train,
    y_output_train,
    epochs=50,
    validation_data=(x_input_test, y_output_test),
    callbacks=[tensorboard_callback],
)

# Record the history for the comparison plot, then write model and history to disk.
histories.append(history)
save_model_and_history(model_voxnet, history, "voxnet")

# Training PointNet.

In [None]:
dataset_name = dataset_name_pointcloud

print("Loading dataset...")
# NOTE: pickle.load can execute arbitrary code; only load dataset files you trust
# (e.g. ones produced by your own preprocessing). The context manager closes the file.
with open(dataset_name, "rb") as dataset_file:
    (x_input_train, y_output_train, _), (x_input_test, y_output_test, _), dataset_parameters = pickle.load(dataset_file)
pp.pprint(dataset_parameters)

In [None]:
def transform(x_input, y_output, num_points=30000):
    """Keep only samples with exactly `num_points` rows, reduced to their first three columns.

    Args:
        x_input: Sequence of 2-D arrays, one per sample (rows = points).
        y_output: Sequence of targets, paired positionally with `x_input`.
        num_points: Required number of rows per sample; other samples are skipped
            (and their shape is printed).

    Returns:
        Tuple ``(x_input_transformed, y_output_transformed)`` of numpy arrays built
        from the kept samples only.
    """
    x_input_transformed = []
    y_output_transformed = []
    # Iterate over the arguments so each split is filtered on its own data.
    for input_sample, output_sample in zip(x_input, y_output):
        if input_sample.shape[0] == num_points:
            x_input_transformed.append(input_sample[:, 0:3])
            y_output_transformed.append(output_sample)
        else:
            # TODO maybe do some padding here?
            print("Ignoring shape:", input_sample.shape)

    return np.array(x_input_transformed), np.array(y_output_transformed)
    
# Filter each split down to samples with the expected point count, keeping the first three columns.
x_input_train, y_output_train = transform(x_input_train, y_output_train)
x_input_test, y_output_test = transform(x_input_test, y_output_test)

# Fail fast if filtering discarded every sample of a split; an empty array would
# otherwise only fail later, inside model creation or training.
assert len(x_input_train) > 0, "No training samples left after filtering by point count."
assert len(x_input_test) > 0, "No testing samples left after filtering by point count."

print("Training data input shape:", x_input_train.shape)
print("Training data output shape:", y_output_train.shape)
print("Testing data input shape:", x_input_test.shape)
print("Testing data output shape:", y_output_test.shape)
print("")

In [None]:
input_shape = (30000, 3)
output_size = 2
model_pointnet = modelutils.create_point_net(input_shape, output_size)
model_pointnet.summary()

# Compile: RMSprop optimizer, mean-squared-error loss, mean-absolute-error metric.
model_pointnet.compile(optimizer="rmsprop", loss="mse", metrics=["mae"])

# Train for 50 epochs with a batch size of 4, using the test split as validation data.
history = model_pointnet.fit(
    x_input_train,
    y_output_train,
    epochs=50,
    validation_data=(x_input_test, y_output_test),
    callbacks=[tensorboard_callback],
    batch_size=4,
)

# Record the history for the comparison plot, then write model and history to disk.
histories.append(history)
save_model_and_history(model_pointnet, history, "pointnet")

# Plot the histories.

In [None]:
def plot_histories(histories, names):
    """Plot every recorded curve of several training runs on a single figure.

    Args:
        histories: Objects with a ``history`` mapping of metric name to per-epoch
            values (e.g. what ``fit`` returns).
        names: Run labels, paired positionally with `histories`; any extra entries
            in either sequence are ignored by ``zip``.
    """
    fig, ax = plt.subplots()
    for history, name in zip(histories, names):
        for key, data in history.history.items():
            ax.plot(data, label=name + "-" + key)

    ax.set_xlabel("epoch")
    ax.set_ylabel("value")
    ax.set_title("Training histories")
    # The per-curve labels are only visible once a legend is drawn.
    ax.legend()

    # TODO consider: fig.savefig()
    plt.show()
    plt.close(fig)
    
# Compare both runs' curves; names line up with the order in which the models were trained.
plot_histories(histories, names=["voxnet", "pointnet"])