In [1]:
import tensorflow as tf
from model.ram import RecurrentAttentionModel
import numpy as np
import time

from data.augmented_mnist import minibatcher
from data.augmented_mnist import get_mnist

In [2]:
(X_train, y_train),(X_test, y_test) = get_mnist(True, True, False)
print(X_train.shape, y_train.shape, np.max(X_train), np.min(X_train))
print(X_test.shape, y_test.shape, np.max(X_test), np.min(X_test))

(60000, 28, 28, 1) (60000, 10) 1.0 0.0
(10000, 28, 28, 1) (10000, 10) 1.0 0.0


In [3]:
# trainings step
loss = tf.keras.metrics.Mean(name='hybrid_loss')
reward = tf.keras.metrics.Mean(name='reward')
baseline_mse = tf.keras.metrics.Mean(name='baseline_mse')
classification_loss = tf.keras.metrics.Mean(name='classification_loss')
test_accuracy = tf.keras.metrics.Mean(name='test_accuracy')               

In [4]:
def train(learning_rate, std, batch_size=200, epochs=10):
    ram = RecurrentAttentionModel(time_steps=7,
                              n_glimpses=1, 
                              glimpse_size=8,
                              num_classes=10,
                              max_gradient_norm=5.0,
                              std=std)
    optimizer = tf.keras.optimizers.Adam(learning_rate)
    for e in range(epochs):
        # trainings step
        batcher = minibatcher(X_train, y_train, batch_size, True)
        for X, y in batcher:
            with tf.GradientTape() as tape:
                logits = ram(X)
                hybrid_loss, c_loss, r, b_mse = ram.hybrid_loss(logits, y)

                gradients = tape.gradient(hybrid_loss, ram.trainable_variables)
                optimizer.apply_gradients(zip(gradients, ram.trainable_variables))

            loss(hybrid_loss)
            classification_loss(c_loss)
            baseline_mse(b_mse)
            reward(r)

        # testing step
        batcher = minibatcher(X_test, y_test, batch_size, True)
        for X, y in batcher:
            logits = ram(X)
            accuracy, _, _ = ram.predict(logits, y)
            test_accuracy(accuracy)

        # Get the metric results
        current_loss = loss.result().numpy()
        current_reward = reward.result().numpy()
        current_baseline_mse = baseline_mse.result().numpy()
        current_classification_loss = classification_loss.result().numpy()
        current_test_accuracy = test_accuracy.result().numpy()
        print("Epoch:", e, "loss:", current_loss, "reward:", current_reward, "baseline mse:", current_baseline_mse, "classification loss:", current_classification_loss, "accuracy:", current_test_accuracy)

In [5]:
train(1e-3, 0.22, batch_size=20)

Epoch: 0 loss: 2.061417 reward: 0.36596745 baseline mse: 0.3574243 classification loss: 1.7768238 accuracy: 0.44439995
Epoch: 1 loss: 1.9319688 reward: 0.440009 baseline mse: 0.42970714 classification loss: 1.5906212 accuracy: 0.49264964
Epoch: 2 loss: 1.8503592 reward: 0.48399484 baseline mse: 0.47263905 classification loss: 1.4758449 accuracy: 0.53376675
Epoch: 3 loss: 1.7894381 reward: 0.5165669 baseline mse: 0.50444 classification loss: 1.3894031 accuracy: 0.559875
Epoch: 4 loss: 1.7383448 reward: 0.5432736 baseline mse: 0.5305085 classification loss: 1.3180213 accuracy: 0.58172005
Epoch: 5 loss: 1.6962256 reward: 0.56528866 baseline mse: 0.5519965 classification loss: 1.258963 accuracy: 0.6000836
Epoch: 6 loss: 1.6597849 reward: 0.58328986 baseline mse: 0.5695638 classification loss: 1.2086772 accuracy: 0.61245733
Epoch: 7 loss: 1.6282768 reward: 0.5991296 baseline mse: 0.58502704 classification loss: 1.165122 accuracy: 0.6275124
Epoch: 8 loss: 1.599892 reward: 0.613419 baseline m