In [1]:
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from keras.optimizers import Adam
from datagen import SequenceDataGenerator
from embedding_model import SequenceEmbeddingModel, sequence_loss_with_params
from params import Params
from IPython.display import clear_output
from utils import *
from evaluate import sequence_eval

Using TensorFlow backend.


In [2]:
# Number of iterations of the outer training loop further down in this notebook.
EPOCHS = 100
# Value handed to the Adam optimizer as its `lr` argument.
LEARNING_RATE = 1e-4

In [3]:
# Shared hyper-parameter container: it is handed to the model constructor, the
# loss factory, the data-preparation helpers and the evaluation routine.
params = Params()

# The training loop slices channels [NUM_CLASSES, NUM_CLASSES + EMBEDDING_DIM)
# out of the model prediction; NUM_SHAPE is also given to the data generator.
params.EMBEDDING_DIM = 6
params.BATCH_SIZE = 1
params.NUM_CLASSES = 4
params.NUM_SHAPE = 3
params.NUM_FILTER = [256, 256, 128]

# NOTE(review): these three are only consumed outside this notebook; they look
# like a clustering threshold and loss margins -- confirm in the project code.
params.ETH_MEAN_SHIFT_THRESHOLD = 1.5
params.DELTA_VAR = 0.5
params.DELTA_D = 1.5

# IMG_SIZE and SEQUENCE_LEN are forwarded to SequenceDataGenerator.
params.IMG_SIZE = 128
params.OUTPUT_SIZE = 32
params.SEQUENCE_LEN = 20

params.BACKBONE = 'xception'
params.TASK = 'sequence'

# Unseeded random array with NUM_SHAPE + 1 rows of three values in [0, 1);
# presumably one RGB colour per shape plus one extra entry -- TODO confirm.
params.COLORS = np.random.random((params.NUM_SHAPE + 1, 3))

In [None]:
# Build the network from the shared params object.
model = SequenceEmbeddingModel(params)

# Adam with the configured learning rate, and the loss bound to the same params.
optim = Adam(lr=LEARNING_RATE)
loss_function = sequence_loss_with_params(params)
model.compile(optim, loss=loss_function)

# Wipe the output this cell has produced so far.
clear_output()

W0730 17:07:42.237943 15732 deprecation_wrapper.py:119] From c:\users\38909\appdata\local\conda\conda\envs\ml\lib\site-packages\keras\backend\tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.

W0730 17:07:42.249912 15732 deprecation_wrapper.py:119] From c:\users\38909\appdata\local\conda\conda\envs\ml\lib\site-packages\keras\backend\tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W0730 17:07:42.250937 15732 deprecation_wrapper.py:119] From c:\users\38909\appdata\local\conda\conda\envs\ml\lib\site-packages\keras\backend\tensorflow_backend.py:4138: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.

W0730 17:07:45.720660 15732 deprecation_wrapper.py:119] From c:\users\38909\appdata\local\conda\conda\envs\ml\lib\site-packages\keras\backend\tensorflow_backend.py:3980: The name tf.nn.avg_pool is deprecated. Please use tf.nn.avg_

In [None]:
# Sequence source for training and for the periodic evaluation, sized from params.
dg = SequenceDataGenerator(
    num_shape=params.NUM_SHAPE,
    image_size=params.IMG_SIZE,
    sequence_len=params.SEQUENCE_LEN,
)

# Training state threaded through fit_xy: the count of fit calls made so far and
# the loss recorded for each of them. Re-running this cell resets both.
step = 0
loss_history = []

In [None]:
def fit_xy(model, x, y, loss_history, params, dg, step, report_every=100):
    """Run a single fit call on (x, y), record its loss and report progress.

    Parameters
    ----------
    model
        Object exposing Keras-style ``fit(x, y, batch_size=..., verbose=...)``
        whose return value carries a ``history['loss']`` list.
    x, y
        Inputs and targets, forwarded unchanged to ``model.fit``.
    loss_history : list
        Appended to in place with the last loss value of this fit call.
    params
        Supplies ``BATCH_SIZE`` for the fit call (1 is used when the attribute
        is missing) and is forwarded to ``sequence_eval``.
    dg
        Data generator; ``get_sequence()`` is called on it only on steps where
        the periodic report is produced.
    step : int
        Number of fit calls performed before this one.
    report_every : int, optional
        Every ``report_every``-th call clears the cell output, plots the loss
        history and evaluates the model on a fresh sequence. Defaults to 100,
        which matches the interval that used to be hard-coded.

    Returns
    -------
    tuple
        ``(loss_history, step + 1)``; ``loss_history`` is the same list object
        that was passed in.

    Raises
    ------
    ValueError
        If ``report_every`` is smaller than 1.
    """
    if report_every < 1:
        raise ValueError('report_every must be a positive integer')

    # Honour the configured batch size instead of a literal 1, so that this
    # call stays consistent with the params the model was built from.
    batch_size = getattr(params, 'BATCH_SIZE', 1)
    history = model.fit(x, y, batch_size=batch_size, verbose=0)
    loss_history.append(history.history['loss'][-1])

    # Report on the last step of each window (steps 99, 199, ... by default).
    if step % report_every == report_every - 1:
        clear_output()
        visualize_history(loss_history, 'loss')
        sequence = dg.get_sequence()
        sequence_eval(model, sequence, params)

    step += 1
    # Fraction of the current reporting window completed; this wraps to 0.0
    # right after a report has been produced.
    update_progress((step % report_every) / report_every)
    return loss_history, step

In [None]:
class_num = params.NUM_CLASSES
embedding_dim = params.EMBEDDING_DIM
# Last-axis range [class_num, class_num + embedding_dim) of the squeezed
# prediction; presumably the embedding channels -- confirm against the model.
emb_channels = slice(class_num, class_num + embedding_dim)

for epoch in range(EPOCHS):
    for _ in range(100):
        sequence = dg.get_sequence()

        # Opening frame: fitted on its own through the half-pair preparation.
        image_info = sequence[0]
        x, y = prep_half_pair_for_model(image_info, params)
        loss_history, step = fit_xy(model, x, y, loss_history, params, dg, step)

        # Every later frame is fitted together with its predecessor.
        for i in range(1, len(sequence)):
            prev_image_info, image_info = sequence[i - 1], sequence[i]

            # Run the current model on the previous frame (the prepared targets
            # are not used here) and keep the selected channel range.
            x, y = prep_half_pair_for_model(prev_image_info, params)
            outputs = np.squeeze(model.predict(x))
            emb = outputs[:, :, emb_channels]

            x, y = prep_pair_for_model(image_info, params, prev_image_info, emb)
            loss_history, step = fit_xy(model, x, y, loss_history, params, dg, step)

    # Saved once per epoch, overwriting the same file each time.
    model.save('sequence.h5')