In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import math
from keras.optimizers import Adam
from datagen import SequenceDataGenerator
from embedding_model import SequenceEmbeddingModel, sequence_loss_with_params
from params import Params
from IPython.display import clear_output
from utils import *
from evaluate import sequence_eval

In [None]:
params = Params()

# --- Task / synthetic data generation ---
params.TASK = 'sequence'
params.NUM_SHAPE = 3
# NOTE(review): presumably one class per shape plus one extra (background?) — confirm.
params.NUM_CLASSES = params.NUM_SHAPE + 1
params.IMG_SIZE = 256
params.SEQUENCE_LEN = 20
params.BATCH_SIZE = 1

# --- Network architecture ---
params.BACKBONE = 'xception'
params.NUM_FILTER = [256, 256, 128]
params.OUTPUT_SIZE = 64
params.EMBEDDING_DIM = 12

# --- Embedding loss / clustering thresholds ---
params.DELTA_VAR = 0.5
params.DELTA_D = 1.5
params.ETH_MEAN_SHIFT_THRESHOLD = 1.5

# --- Visualization: one random RGB colour per class (no seed, so colours vary per run) ---
params.COLORS = np.random.random((params.NUM_SHAPE + 1, 3))

In [None]:
# Build and compile the sequence-embedding model, then create the synthetic
# sequence generator used by the training cell.
# LEARNING_RATE is defined here (not in the later training cell) so this cell
# works on a fresh kernel under Restart & Run All.
LEARNING_RATE = 1e-4

model = SequenceEmbeddingModel(params)
# `lr` is the argument name used by older standalone Keras; newer Keras
# releases call it `learning_rate`.
optim = Adam(lr=LEARNING_RATE)
loss_function = sequence_loss_with_params(params)
model.compile(optim, loss=loss_function)
clear_output()  # hide anything printed while building/compiling the model

dg = SequenceDataGenerator(
    num_shape=params.NUM_SHAPE,
    image_size=params.IMG_SIZE,
    sequence_len=params.SEQUENCE_LEN)

In [None]:
# Training loop: every optimizer step fits the model on one pair of adjacent
# frames from a freshly generated synthetic sequence. Every LOG_EVERY steps the
# loss curve is redrawn and the model is evaluated on a separate sequence.
EPOCHS = 100
SEQUENCES_PER_EPOCH = 100  # number of generated sequences per epoch
LOG_EVERY = 100            # plot loss / run evaluation every N optimizer steps
nan = float('nan')
step = 0
loss_history = []

for epoch in range(EPOCHS):
    for _ in range(SEQUENCES_PER_EPOCH):
        sequence = dg.get_sequence()
        for i in range(params.SEQUENCE_LEN - 1):
            image_info = sequence[i]
            # NOTE(review): frame i+1 is passed as the "previous" frame; confirm
            # this matches the frame order prep_double_frame expects.
            prev_image_info = sequence[i + 1]
            x, y = prep_double_frame(image_info, prev_image_info, params)
            history = model.fit(x, y, batch_size=params.BATCH_SIZE, verbose=False)
            loss_history.append(history.history['loss'][-1])
            # Without this increment the logging condition below never fires.
            step += 1
            if step % LOG_EVERY == 0:
                clear_output()
                visualize_history(loss_history, 'loss')
                # Use a separate name so the training `sequence` being iterated
                # above is not replaced mid-loop.
                eval_sequence = dg.get_sequence()
                sequence_eval(model, eval_sequence, params)
    model.save('sequence.h5')