# Training a Recurrent Attention Model (RAM) on the BACH dataset

## Requirements

### Imports

In [None]:
from sklearn.model_selection import train_test_split

from tqdm import tqdm

# 2019041500 - use this tf nightly version
import tensorflow as tf
import logging
tf.get_logger().setLevel(logging.ERROR)

import numpy as np
import matplotlib.pyplot as plt

from model.ram import RecurrentAttentionModel

from data.bach_loader import minibatcher
from data.bach_loader import load_bach_images

### Data

In [None]:
# Load the BACH images; the third return value of the loader is unused here.
(X_train, y_train), X_test, _ = load_bach_images("/tf/BACH")

# Bring both splits to shape (N, 1536, 2048, 3).
X_train = X_train.reshape(-1, 1536, 2048, 3)
X_test = X_test.reshape(-1, 1536, 2048, 3)

# Scale pixel values to [0, 1] as float32. Casting first and dividing in place
# avoids the float64 temporary that `(X / 255).astype(np.float32)` allocates,
# which is twice the size of the final array for these large images.
# `astype` returns a copy, so the in-place division never touches loader data.
X_train = X_train.astype(np.float32)
X_train /= 255
X_test = X_test.astype(np.float32)
X_test /= 255

# One-hot encode with an explicit class count so the label width always
# matches the model's num_classes=4, even if a class is missing from y_train.
y_train = tf.keras.utils.to_categorical(y_train, num_classes=4)

print(X_train.shape, y_train.shape, np.max(X_train), np.min(X_train), X_train.dtype)
print(X_test.shape, np.max(X_test), np.min(X_test), X_test.dtype)

## Training
### Hyperparameters

In [None]:
# Optimisation settings shared by the training loop and the plotting helpers.
batch_size = 10
learning_rate = 1e-4

# Passed to the model as `std` (presumably the spread of the stochastic
# glimpse-location policy — TODO confirm against model.ram).
std = 0.3

# Model architecture, gathered in one place so it is easy to read and tweak.
ram_config = dict(
    time_steps=8,
    n_glimpses=3,
    glimpse_size=64,
    num_classes=4,
    max_gradient_norm=1.0,
    input_channels=3,
    std=std,
)
ram = RecurrentAttentionModel(**ram_config)
adam_opt = tf.keras.optimizers.Adam(learning_rate)

### Training loop

In [None]:
# Per-epoch (mean hybrid loss, mean classification loss, mean reward).
history = []
for epoch in tqdm(range(500)):
    losses = []
    rewards = []
    classification_losses = []

    # One pass over the training set; the final True is forwarded to
    # minibatcher (presumably enables shuffling — TODO confirm).
    batcher = minibatcher(X_train, y_train, batch_size, True)
    for X, y in batcher:
        # Record only the forward pass and loss computation on the tape.
        with tf.GradientTape() as tape:
            logits = ram(X)
            loss, classification_loss, reward, _ = ram.hybrid_loss(logits, y)

        # Gradient computation and the optimizer step happen outside the tape
        # context, as in the tf.GradientTape documentation.
        gradients = tape.gradient(loss, ram.trainable_variables)
        adam_opt.apply_gradients(zip(gradients, ram.trainable_variables))

        # Bookkeeping for the per-epoch summary.
        losses.append(loss.numpy())
        classification_losses.append(classification_loss.numpy())
        rewards.append(reward.numpy())

    # Compute each epoch mean once and reuse it for printing and history.
    mean_loss = np.mean(losses)
    mean_classification_loss = np.mean(classification_losses)
    mean_reward = np.mean(rewards)
    print("step:", epoch, "loss:", mean_loss, "classification_loss:", mean_classification_loss,
          "reward:", mean_reward)
    # A flat 3-tuple per epoch; the visualization cell reshapes to (-1, 3).
    history.append((mean_loss, mean_classification_loss, mean_reward))

## Visualization

In [None]:
# Turn the per-epoch records into an (epochs, 3) array:
# column 0 = hybrid loss, 1 = classification loss, 2 = reward.
history = np.reshape(np.array(history), (-1, 3))

In [None]:
# Column 0 of history: mean hybrid loss per epoch.
plt.title("hybrid loss")
plt.plot(history[:, 0])

In [None]:
# Column 1 of history: mean classification loss per epoch.
plt.title("classification loss")
plt.plot(history[:, 1])

In [None]:
# Column 2 of history: mean reward per epoch.
plt.title("reward")
plt.plot(history[:, 2])

In [None]:
def plot_path_of(number, batch, n_glimpses=3, glimpse_size=64):
    """Plot glimpse paths of one training minibatch, restricted to class ``number``.

    Runs ``ram`` on minibatch ``batch`` of ``X_train`` (``batch_size`` images in
    stored order). For every image whose predicted class equals ``number``, it
    prints "right" if the prediction matches the label and "wrong" otherwise.
    It then draws that image's glimpse path with ``plot_prediction_path_3d``.

    Relies on the notebook globals ``X_train``, ``y_train``, ``batch_size`` and
    ``ram``.

    Args:
        number: class index to filter on (compared with the predicted class).
        batch: minibatch index. Images ``batch * batch_size`` up to
            ``(batch + 1) * batch_size`` (exclusive) are used.
        n_glimpses: forwarded to ``plot_prediction_path_3d``. The default 3
            matches the ``n_glimpses`` the model is built with in this notebook.
        glimpse_size: forwarded to ``plot_prediction_path_3d``. The default 64
            matches the model's ``glimpse_size``.
    """
    from visualization.model import plot_prediction_path_3d

    start = batch * batch_size
    imgs = X_train[start:start + batch_size]
    labels = y_train[start:start + batch_size]
    logits = ram(imgs)
    _, prediction, location = ram.predict(logits, labels)
    true_classes = np.argmax(labels, 1)
    for i, (pred, label) in enumerate(zip(prediction.numpy(), true_classes)):
        # Plain comparisons only: `a == b & c == d` would parse as the chained
        # comparison `a == (b & c) == d`, because `&` binds tighter than `==`.
        # NOTE(review): this filters on the *predicted* class, following the
        # argument order of the original comparisons. Use `label != number`
        # if the intent is to show every image whose true class is `number`.
        if pred != number:
            continue
        print("right" if pred == label else "wrong")
        plot_prediction_path_3d(imgs[i], location[i].numpy(), n_glimpses, glimpse_size)

In [None]:
# Class 0: minibatches 1 through 10.
for batch_idx in range(1, 11):
    plot_path_of(0, batch_idx)

In [None]:
# Class 1: minibatches 1 through 10.
for batch_idx in range(1, 11):
    plot_path_of(1, batch_idx)

In [None]:
# Class 2: minibatches 11 through 26.
for batch_idx in range(11, 27):
    plot_path_of(2, batch_idx)

In [None]:
# Class 3: minibatches 21 through 32, plus minibatch 39.
for batch_idx in [*range(21, 33), 39]:
    plot_path_of(3, batch_idx)