Trains a Neural Network using the Data-Generator.

In [1]:
from keras import models, layers
import numpy as np
import matplotlib.pyplot as plt
import os
import random
from datagenerator import DataGenerator
import datetime

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


# Hyper-parameters.

In [2]:
# Location of the dataset. An optional "datasetpath.txt" in the working
# directory overrides the default relative path.
if os.path.exists("datasetpath.txt"):
    # Read through a context manager so the file handle is closed right away
    # instead of being left to the garbage collector.
    with open("datasetpath.txt", "r") as dataset_path_file:
        # Newlines are removed so a trailing line break does not end up in the path.
        dataset_path = dataset_path_file.read().replace("\n", "")
else:
    dataset_path = "../data"

# Input representation handed to the data-generator; the commented lines are
# the alternative settings kept for quick switching.
input_type = "image"
#input_type = "voxelgrid"
#input_type = "pointcloud"
#train_size = 500
#validate_size = 100

# Instantiate the data-generator.

In [3]:
# Build the data-generator for the configured dataset and input type; it is
# asked for two output targets, height and weight (in that order).
data_generator = DataGenerator(
    dataset_path=dataset_path,
    input_type=input_type,
    output_targets=["height", "weight"],
)

# Print the length of each path list exposed by the generator.
for attribute_name in ("jpg_paths", "pcd_paths", "json_paths_personal", "json_paths_measures"):
    print(attribute_name, len(getattr(data_generator, attribute_name)))

# Print every QR-code known to the generator, one per line.
print("QR-Codes:\n" + "\n".join(data_generator.qrcodes))

print("Done.")

jpg_paths 4511
pcd_paths 1360
json_paths_personal 40
json_paths_measures 74
QR-Codes:
SAM-02-003-01
SAM-GOV-001
SAM-GOV-002
SAM-GOV-003
SAM-GOV-004
SAM-GOV-005
SAM-GOV-008
SAM-GOV-011
SAM-GOV-012
SAM-GOV-013
SAM-GOV-014
SAM-GOV-023
SAM-GOV-025
SAM-GOV-026
SAM-GOV-033
SAM-GOV-034
SAM-GOV-035
SAM-GOV-036
SAM-GOV-037
SAM-GOV-038
SAM-GOV-041
SAM-GOV-042
SAM-GOV-043
SAM-GOV-044
SAM-GOV-099
SAM-SNG-011
SAM-SNG-012
SAM-SNG-013
SAM-SNG-014
SAM-SNG-015
SAM-SNG-016
SAM-SNG-021
SAM-SNG-036
SAM-SNG-066
SAM-SNG-067
SAM-SNG-072
SAM-SNG-091
SAM-SNG-096
Done.


# Do the training-validation split on QR-codes.

In [4]:
# Split by QR-code (not by individual sample) so that all samples sharing a
# QR-code land on the same side of the training/validation split.

# Only QR-codes starting with "SAM-GOV" take part. Sorting first makes the
# starting order independent of the order of data_generator.qrcodes.
qrcodes_shuffle = sorted(
    qrcode for qrcode in data_generator.qrcodes if qrcode.startswith("SAM-GOV")
)

# Shuffle with a dedicated, seeded RNG so the split is identical on every
# "Restart & Run All". A private random.Random instance is used instead of
# random.seed() so the global random state is left untouched.
split_seed = 42
random.Random(split_seed).shuffle(qrcodes_shuffle)

# First 80% of the shuffled codes go to training, the rest to validation.
split_index = int(0.8 * len(qrcodes_shuffle))
qrcodes_train = qrcodes_shuffle[:split_index]
qrcodes_validate = qrcodes_shuffle[split_index:]

print("QR-Codes train:")
print(" ".join(qrcodes_train))
print("")

print("QR-Codes validate:")
print(" ".join(qrcodes_validate))
print("")

print("Done.")

QR-Codes train:
SAM-GOV-014 SAM-GOV-011 SAM-GOV-008 SAM-GOV-037 SAM-GOV-025 SAM-GOV-099 SAM-GOV-035 SAM-GOV-001 SAM-GOV-005 SAM-GOV-034 SAM-GOV-036 SAM-GOV-041 SAM-GOV-043 SAM-GOV-004 SAM-GOV-042 SAM-GOV-044 SAM-GOV-026 SAM-GOV-012 SAM-GOV-033

QR-Codes validate:
SAM-GOV-002 SAM-GOV-003 SAM-GOV-023 SAM-GOV-013 SAM-GOV-038

Done.


# Do a plausibility check on training-data.

In [5]:
# TODO

# Create a net.

In [None]:
# Input and output dimensions are dictated by the data-generator.
input_shape = data_generator.input_shape
output_size = data_generator.output_size

# Plain fully-connected regressor: flatten the input, two hidden ReLU layers,
# and a linear output layer with one unit per output value.
model = models.Sequential([
    layers.Flatten(input_shape=input_shape),
    layers.Dense(128, activation="relu"),
    layers.Dense(64, activation="relu"),
    layers.Dense(output_size),
])
model.summary()

# Mean squared error as the loss, mean absolute error reported as a metric.
model.compile(optimizer="rmsprop", loss="mse", metrics=["mae"])

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_1 (Flatten)          (None, 43200)             0         
_________________________________________________________________
dense_1 (Dense)              (None, 128)               5529728   
_________________________________________________________________
dense_2 (Dense)              (None, 64)                8256      
_________________________________________________________________
dense_3 (Dense)              (None, 2)                 130       
Total params: 5,538,114
Trainable params: 5,538,114
Non-trainable params: 0
_________________________________________________________________


# Train the net.

In [None]:
# One generator per split, each restricted to its own QR-codes and asked for
# size=32 per call.
training_batches = data_generator.generate(size=32, qrcodes_to_use=qrcodes_train)
validation_batches = data_generator.generate(size=32, qrcodes_to_use=qrcodes_validate)

# 100 epochs of 32 training steps, with 8 validation steps after each epoch.
# The returned history object is kept for the plots further down.
history = model.fit_generator(
    training_batches,
    steps_per_epoch=32,
    epochs=100,
    validation_data=validation_batches,
    validation_steps=8,
)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100

# Visualize results.

In [None]:
# One figure per recorded quantity, each showing the training curve next to
# its validation counterpart (stored in the history under a "val_" prefix).
for history_key in ("loss", "mean_absolute_error"):
    validation_key = "val_" + history_key
    plt.plot(history.history[history_key], label=history_key)
    plt.plot(history.history[validation_key], label=validation_key)
    plt.legend()
    plt.show()
    plt.close()

# Save model.

In [None]:
# Timestamp (minute resolution) used to name the saved model; the evaluation
# cell reuses datetime_string for its figure file names.
datetime_string = "{:%Y%m%d-%H%M}".format(datetime.datetime.now())
model_name = "{}.h5".format(datetime_string)

# A single-argument join returns the name unchanged, so the model is written
# relative to the current working directory.
model_path = os.path.join(model_name)
model.save(model_path)

# Test the model.

In [None]:
# Draw one batch (size=32) restricted to the validation QR-codes.
x_input, y_output = next(data_generator.generate(size=32, qrcodes_to_use=qrcodes_validate))

# Score the model on that batch.
loss, metric = model.evaluate(x_input, y_output)
print("Loss:", loss)
print("Metric:", metric)
print("")

# Predict on the same batch so truth and prediction can be compared.
y_output_pred = model.predict(x_input)

# Column 0 is plotted as height and column 1 as weight. Each figure shows the
# truth, the prediction and the absolute error, and is also saved as a JPG
# named after the training timestamp and the input type.
for column, target in enumerate(("height", "weight")):
    plt.title(target.capitalize() + " " + input_type)
    plt.plot(y_output[:, column], label="Truth")
    plt.plot(y_output_pred[:, column], label="Predicted")
    plt.plot(np.abs(y_output_pred - y_output)[:, column], label="Error")
    plt.legend()
    plt.savefig(datetime_string + "-" + input_type + "-" + target + ".jpg")
    plt.show()
    plt.close()

# Ratio of column 1 to column 0 (weight divided by height), truth vs. prediction.
plt.title("Weight for height " + input_type)
plt.plot(y_output[:, 1] / y_output[:, 0], label="Truth")
plt.plot(y_output_pred[:, 1] / y_output_pred[:, 0], label="Predicted")
plt.legend()
plt.savefig(datetime_string + "-" + input_type + "-weight_for_height.jpg")
plt.show()
plt.close()