In [1]:
import tensorflow as tf
from st_resnet import ST_ResNet
from params import Params as param
from utils import batch_generator

In [2]:
# Generate synthetic random inputs for the closeness, period and trend branches
# (each with sequence_length * nb_flow channels) plus a random target, then
# concatenate the three inputs along the last (channel) axis into one tensor X.
NUM_SAMPLES = 1000     # number of synthetic examples
TRAIN_FRACTION = 0.8   # share of examples used for training
SEED = 42              # fixes the random data (and later random ops) across runs

tf.random.set_seed(SEED)

x_closeness = tf.random.normal(shape=(NUM_SAMPLES, param.map_height, param.map_width, param.closeness_sequence_length * param.nb_flow))
x_period = tf.random.normal(shape=(NUM_SAMPLES, param.map_height, param.map_width, param.period_sequence_length * param.nb_flow))
x_trend = tf.random.normal(shape=(NUM_SAMPLES, param.map_height, param.map_width, param.trend_sequence_length * param.nb_flow))
y = tf.random.normal(shape=(NUM_SAMPLES, param.map_height, param.map_width, param.num_of_output))

# Concatenate the three inputs along the depth (channel) dimension.
X = tf.concat([x_closeness, x_period, x_trend], axis=-1)

# Sequential (unshuffled) train/test split: the first TRAIN_FRACTION of rows
# go to training and the rest to testing.
train_index = int(round(TRAIN_FRACTION * len(X)))
xtrain = X[:train_index]
ytrain = y[:train_index]
xtest = X[train_index:]
ytest = y[train_index:]

# Sanity check: the split neither drops nor duplicates examples.
assert xtrain.shape[0] + xtest.shape[0] == NUM_SAMPLES
assert ytrain.shape[0] == xtrain.shape[0] and ytest.shape[0] == xtest.shape[0]

In [3]:
# Create one batch iterator per split; the training loop pulls batches from
# them with next(...).
# NOTE(review): they are consumed across several epochs, so batch_generator is
# presumably an endless/cycling generator — confirm in utils.
train_batch_generator, test_batch_generator = [
    batch_generator(features, targets, param.batch_size)
    for features, targets in [(xtrain, ytrain), (xtest, ytest)]
]

In [6]:
# Instantiate the model and call it once on random single-example inputs.
# NOTE(review): this forward pass presumably exists to create the model's
# variables before train_step()/save() are used — confirm against ST_ResNet.
# The spatial size is taken from params so it matches the training data
# generated earlier, rather than a fixed 32x32 grid.
g = ST_ResNet()

closeness_shape = (1, param.map_height, param.map_width, param.closeness_sequence_length * param.nb_flow)
period_shape = (1, param.map_height, param.map_width, param.period_sequence_length * param.nb_flow)
trend_shape = (1, param.map_height, param.map_width, param.trend_sequence_length * param.nb_flow)

dummy_closeness = tf.random.normal(closeness_shape)
dummy_period = tf.random.normal(period_shape)
dummy_trend = tf.random.normal(trend_shape)

dummy_output = g(dummy_closeness, dummy_period, dummy_trend)


In [7]:
# Training / validation loop.
# Each epoch: run num_train_batches training steps (tracking an exponential
# moving average of the training loss), run num_val_batches evaluation steps
# (averaging the validation loss), log every per-batch loss to TensorBoard,
# print both losses and save the model.
# NOTE(review): the batch generators are consumed with next() across epochs,
# so they are assumed to repeat indefinitely — confirm in utils.batch_generator.

# Channel boundaries of the concatenated input X: [closeness | period | trend].
closeness_end = param.closeness_sequence_length * param.nb_flow
period_end = closeness_end + param.period_sequence_length * param.nb_flow


def split_inputs(x_batch):
    """Split a concatenated batch into its (closeness, period, trend) parts.

    Inverts the ``tf.concat([x_closeness, x_period, x_trend], axis=-1)`` used
    to build ``X`` by slicing the last axis at ``closeness_end`` and
    ``period_end``. Returns a 3-tuple, the input structure train_step/test_step
    receive. Using a helper also avoids overwriting the module-level
    ``x_closeness`` / ``x_period`` / ``x_trend`` data tensors with batch slices.
    """
    return (
        x_batch[:, :, :, :closeness_end],
        x_batch[:, :, :, closeness_end:period_end],
        x_batch[:, :, :, period_end:],
    )


# Batch counts are the same every epoch (integer division: floor(N / batch_size)
# batches are drawn per epoch).
num_train_batches = xtrain.shape[0] // param.batch_size
num_val_batches = xtest.shape[0] // param.batch_size

# Create each summary writer once and reuse it, instead of opening a new writer
# (and a new event file) for every batch.
train_writer = tf.summary.create_file_writer(param.log_dir + '/train')
val_writer = tf.summary.create_file_writer(param.log_dir + '/val')

for epoch in range(param.num_epochs):
    loss_train = 0
    loss_val = 0
    print("Epoch: {}\t".format(epoch))

    # Training
    for b in range(num_train_batches):
        x_batch, y_batch = next(train_batch_generator)
        result = g.train_step((split_inputs(x_batch), y_batch))
        loss_tr = result['loss_train']
        # Exponential moving average (EMA) of the training loss, seeded with the
        # first batch's loss so the estimate is not biased toward 0.
        if b == 0:
            loss_train = loss_tr
        else:
            loss_train = loss_tr * param.delta + loss_train * (1 - param.delta)

        with train_writer.as_default():
            tf.summary.scalar('loss', loss_tr, step=epoch * num_train_batches + b)

    # Validation
    for b in range(num_val_batches):
        x_batch, y_batch = next(test_batch_generator)
        result = g.test_step((split_inputs(x_batch), y_batch))
        loss_v = result['loss_test']
        loss_val += loss_v

        with val_writer.as_default():
            tf.summary.scalar('loss', loss_v, step=epoch * num_val_batches + b)

    # Mean validation loss over the evaluated batches (stays 0 if there were none).
    if num_val_batches:
        loss_val /= num_val_batches

    print("loss: {:.3f}, val_loss: {:.3f}".format(loss_train, loss_val))

    # Persist this epoch's summaries, then save the model after every epoch.
    train_writer.flush()
    val_writer.flush()
    g.save(param.model_path)


Epoch: 0	
loss: 2.309, val_loss: 1.003




INFO:tensorflow:Assets written to: model_logs/20230326-221051/assets


INFO:tensorflow:Assets written to: model_logs/20230326-221051/assets


Epoch: 1	
loss: 1.006, val_loss: 0.999




INFO:tensorflow:Assets written to: model_logs/20230326-221051/assets


INFO:tensorflow:Assets written to: model_logs/20230326-221051/assets


Epoch: 2	
loss: 1.004, val_loss: 0.999




INFO:tensorflow:Assets written to: model_logs/20230326-221051/assets


INFO:tensorflow:Assets written to: model_logs/20230326-221051/assets
