In [1]:
import tensorflow as tf
import logging
tf.get_logger().setLevel(logging.ERROR)

from model.ram import RecurrentAttentionModel
import numpy as np
import time

from data.augmented_mnist import minibatcher
from data.augmented_mnist import get_mnist

In [2]:
(X_train, y_train),(X_test, y_test) = get_mnist(True, True, False)
print(X_train.shape, y_train.shape, np.max(X_train), np.min(X_train))
print(X_test.shape, y_test.shape, np.max(X_test), np.min(X_test))

(60000, 28, 28, 1) (60000, 10) 1.0 0.0
(10000, 28, 28, 1) (10000, 10) 1.0 0.0


In [3]:
# trainings step
loss = tf.keras.metrics.Mean(name='hybrid_loss')
reward = tf.keras.metrics.Mean(name='reward')
baseline_mse = tf.keras.metrics.Mean(name='baseline_mse')
classification_loss = tf.keras.metrics.Mean(name='classification_loss')
test_accuracy = tf.keras.metrics.Mean(name='test_accuracy')               

In [4]:
def train(learning_rate, std, batch_size=20, epochs=10):
    ram = RecurrentAttentionModel(time_steps=7,
                              n_glimpses=1, 
                              glimpse_size=8,
                              num_classes=10,
                              max_gradient_norm=5.0,
                              std=std)
    optimizer = tf.keras.optimizers.Adam(learning_rate)
    for e in range(epochs):
        # trainings step
        batcher = minibatcher(X_train, y_train, batch_size, True)
        for X, y in batcher:
            with tf.GradientTape() as tape:
                logits = ram.call(X)
                hybrid_loss, c_loss, r, b_mse = ram.hybrid_loss(logits, y)

                gradients = tape.gradient(hybrid_loss, ram.trainable_variables)
                optimizer.apply_gradients(zip(gradients, ram.trainable_variables))

            loss(hybrid_loss)
            classification_loss(c_loss)
            baseline_mse(b_mse)
            reward(r)

        # testing step
        batcher = minibatcher(X_test, y_test, batch_size, True)
        for X, y in batcher:
            logits = ram(X)
            accuracy, _, _ = ram.predict(logits, y)
            test_accuracy(accuracy)

        # Get the metric results
        current_loss = loss.result().numpy()
        current_reward = reward.result().numpy()
        current_baseline_mse = baseline_mse.result().numpy()
        current_classification_loss = classification_loss.result().numpy()
        current_test_accuracy = test_accuracy.result().numpy()
        print("Epoch:", e, "loss:", current_loss, "reward:", current_reward, "baseline mse:", current_baseline_mse, "classification loss:", current_classification_loss, "accuracy:", current_test_accuracy)

In [5]:
train(1e-3, 0.22, batch_size=500)

Epoch: 0 loss: 2.136445 reward: 0.31918344 baseline mse: 0.3101171 classification loss: 1.8851024 accuracy: 0.44660002
Epoch: 1 loss: 1.9870595 reward: 0.40228355 baseline mse: 0.39081162 classification loss: 1.6713425 accuracy: 0.45994997
Epoch: 2 loss: 1.8962046 reward: 0.44974485 baseline mse: 0.4368791 classification loss: 1.5440432 accuracy: 0.49766666
Epoch: 3 loss: 1.8231125 reward: 0.48652118 baseline mse: 0.47257695 classification loss: 1.4423995 accuracy: 0.5281001
Epoch: 4 loss: 1.7551184 reward: 0.52142024 baseline mse: 0.50644195 classification loss: 1.3475928 accuracy: 0.5621401
Epoch: 5 loss: 1.7083069 reward: 0.54399765 baseline mse: 0.52835685 classification loss: 1.2832794 accuracy: 0.59061664
Epoch: 6 loss: 1.6594535 reward: 0.56884295 baseline mse: 0.5524692 classification loss: 1.215414 accuracy: 0.60555714
Epoch: 7 loss: 1.622904 reward: 0.586942 baseline mse: 0.57004255 classification loss: 1.1648072 accuracy: 0.62403744
Epoch: 8 loss: 1.5880995 reward: 0.6042686