In [None]:
from keras import models
from pprint import pprint
import json
import os
from gtsrb_loader.load_data import load_bounding_boxes_generator
from gtsrb_loader.get_folderpath import get_folderpath
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.patches as patches

# This is a bit of magic to make matplotlib figures appear inline in the notebook
# rather than in a new window.
%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

# Some more magic so that the notebook will reload external python modules;
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

In [None]:
# Folder for the training subset; original_images=False is passed through to the loader
# (see gtsrb_loader.get_folderpath for exactly which copy of the images that selects).
path = get_folderpath(original_images=False, subset='train')
path

In [None]:
# Load the previously saved localization model; it is evaluated and then trained further below.
model_path = os.path.join('model_files', '12-07-17-localization.h5')
model = models.load_model(model_path)

In [None]:
# Pull just the first batch of 30 samples for a visual sanity check; only the images
# (element 0 of the batch) are kept, the labels are not needed here.
first_batch = next(load_bounding_boxes_generator(path, batch_size=30))
pics = np.array(first_batch[0])
pics.shape

In [None]:
# Run the model over the sample images, showing Keras' progress bar (verbose=1).
predictions = model.predict(x=pics, verbose=1)

These are the model's predicted bounding boxes, one row per sample image.

In [None]:
# Display the raw predicted box coordinates (one row per image in `pics`).
predictions

In [None]:
# Overlay each predicted bounding box on its image in a 5x6 grid.
# Prediction columns are treated as (x_min, y_min, x_max, y_max): columns 0-1 give the
# rectangle's anchor corner and columns 2-3 the opposite corner.
grid_rows, grid_cols = 5, 6
# Bound the loop by the data actually present, not a hard-coded 30: a short batch would
# otherwise raise IndexError, and the grid cannot hold more than grid_rows * grid_cols images.
n_shown = min(len(pics), len(predictions), grid_rows * grid_cols)
width = predictions[:, 2] - predictions[:, 0]
height = predictions[:, 3] - predictions[:, 1]
for i in range(n_shown):
    ax = plt.subplot(grid_rows, grid_cols, i + 1)
    ax.imshow(pics[i])
    ax.axis('off')
    rect = patches.Rectangle((predictions[i, 0], predictions[i, 1]),
                             width[i], height[i], linewidth=1, edgecolor='y', facecolor='none')
    # Add the patch to the Axes
    ax.add_patch(rect)
plt.show()

## Train the model

In [None]:
def pipe_generator(gen):
    """Adapt a loader generator into the (inputs, targets) arrays Keras expects.

    Args:
        gen: iterable yielding (images, boxes) pairs, as produced by the loader.

    Yields:
        (image_batch, box_batch): ``np.array(images)`` and ``np.array(boxes)``
        reshaped to ``(-1, 4)`` so there is exactly one 4-value box per row.
    """
    for images, boxes in gen:
        image_batch = np.array(images)
        box_batch = np.array(boxes).reshape((-1, 4))
        yield image_batch, box_batch

In [None]:
# Images per training batch; the training cell also derives steps_per_epoch from it.
# (No generator is built here: the training loop creates a fresh one for every epoch,
# so one constructed in this cell was never consumed.)
batch_size = 10

In [None]:
# Approximate number of training images seen per epoch.
# NOTE(review): hard-coded — TODO confirm against the actual size of the training folder.
n_train_images = 39000
# Keras documents steps_per_epoch as an integer; `39000 / batch_size` was a float.
# Ceiling division keeps the same number of steps a float bound would effectively run
# (and covers a final partial batch when batch_size does not divide n_train_images).
steps_per_epoch = -(-n_train_images // batch_size)

# Three epochs, one fit call each, with a fresh loader generator per epoch —
# presumably because load_bounding_boxes_generator stops after a single pass; TODO confirm.
for _ in range(3):
    gen = pipe_generator(load_bounding_boxes_generator(path, batch_size=batch_size))
    model.fit_generator(gen, steps_per_epoch=steps_per_epoch, verbose=2, epochs=1)

In [None]:
# Persist the further-trained model. NOTE: this is the same path the model was loaded from
# earlier in this notebook, so the starting checkpoint is overwritten.
save_path = os.path.join('model_files', '12-07-17-localization.h5')
model.save(save_path)