In [1]:
import copy  # To make a deepcopy of the mainVQN
import os
import random
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils as utils
import torchvision.transforms as transforms
from PIL.Image import BICUBIC
from torchvision.datasets import MNIST

from ram import RecurrentAttentionModel

In [2]:
# ---- Image parameters ----
num_classes = 10  # One class per digit, 0 through 9
image_size = 28   # MNIST images are 28x28 black-and-white images
in_channels = 1

# ---- Basic training hyperparameters ----
batch_size = 1
lr = 0.01
grad_clip = 1

# ---- Model hyperparameters ----
glimpse_sizes = [5, 7, 10]
# Border width added on every side of an image; derived from the last
# entry of glimpse_sizes.
padding = (glimpse_sizes[-1] - 1) // 2
padded_img_size = image_size + padding + padding
# Images are padded ahead of time (to speed things up), so the model is
# told not to pad them itself.
pad_imgs = False
# Size of the transformed glimpse representation in the GlimpseNetwork
glimpse_h_size = 128
# Size of the transformed location representation in the GlimpseNetwork
loc_h_size = 128
# Size of the GlimpseNetwork output
glimpse_network_size = 256
rnn_state_size = 256
# Classification task: one action per class
action_state_size = num_classes
num_rnn_layers = 1
learn_kernels = False
# Glimpse locations are represented as pairs of floats in the range (0, 1)
continuous_location = True
dropout = 0.15

In [3]:
# Two networks are used:
#   * mainVQN  - predicts the best actions to take
#   * valueVQN - estimates the Q-values of the actions chosen by mainVQN
mainVQN = RecurrentAttentionModel(
    image_size=image_size,
    in_channels=in_channels,
    glimpse_h_size=glimpse_h_size,
    loc_h_size=loc_h_size,
    glimpse_network_size=glimpse_network_size,
    rnn_state_size=rnn_state_size,
    action_state_size=action_state_size,
    num_rnn_layers=num_rnn_layers,
    glimpse_sizes=glimpse_sizes,
    pad_imgs=pad_imgs,
    learn_kernels=learn_kernels,
    continuous_location=continuous_location,
)

# The valueVQN starts out as an independent, identical copy of the mainVQN.
valueVQN = copy.deepcopy(mainVQN)

# Its parameters are updated by hand rather than through backpropagation,
# so none of them should track gradients.
for value_param in valueVQN.parameters():
    value_param.requires_grad = False

In [4]:
# Load MNIST, downloading it into data/mnist first if it is not there yet
mnist_data_loc = os.path.join('data', 'mnist')

# Preprocessing pipeline (kept pretty basic for now); steps run in order.
data_transforms = transforms.Compose([
    # 1. Rotate by a random angle in the range (-65, 65) degrees; the image
    #    size is unchanged
    transforms.RandomRotation(degrees=65),
    # 2. Take a 24x24 crop at a random location
    transforms.RandomCrop(size=24),
    # 3. Scale the crop back up to image_size using bicubic interpolation
    transforms.Resize(image_size, interpolation=BICUBIC),
    # 4. Zero-pad here so that the GlimpseSensor does not have to
    transforms.Pad(padding=padding),
    # 5. PIL Image / numpy.ndarray (H x W x C) with values in [0, 255]
    #    -> torch.FloatTensor (C x H x W) with values in [0.0, 1.0]
    transforms.ToTensor(),
])

# Training split: dataset plus a shuffling dataloader
mnist_train_dataset = MNIST(
    root=mnist_data_loc,
    train=True,
    download=True,
    transform=data_transforms,
)
train_data = utils.data.DataLoader(
    mnist_train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Test split: dataset plus an in-order dataloader
# NOTE(review): the test images go through the same random rotation/crop
# pipeline as the training images -- confirm that this is intended.
mnist_test_dataset = MNIST(
    root=mnist_data_loc,
    train=False,
    download=True,
    transform=data_transforms,
)
test_data = utils.data.DataLoader(
    mnist_test_dataset,
    batch_size=batch_size,
)

In [5]:
# ---- RL training hyperparameters ----

# Maximum number of glimpses per episode
max_ep_len = 30
# Number of initial steps spent purely exploring with random glimpse locations
explore_steps = 25
# Number of steps over which the Boltzmann temperature is annealed
pretrain_steps = 10000
# One episode per batch produced by train_data, so every epoch goes through
# the whole training set once (a batch is a single example while batch_size is 1)
num_episodes = len(train_data)
# Number of passes over all of the training examples
num_epochs = 10

# Penalty applied each time the agent makes an incorrect classification (>= 0.)
error_penalty = 0.
# Reward discount factor
gamma = 0.99
# Weight given to the main RAM's parameters when updating the
# value-predicting RAM's parameters
tau = 0.001

# Probability of taking a random action during the initial exploration
# steps; annealed from the start value to the end value over that period
epsilon_start = 1
epsilon_end = 0.1
# Boltzmann temperature used to encourage exploration; annealed over the
# course of training so that a regular softmax is used by the end
boltz_temp_start = 100.
boltz_temp_end = 1.

In [6]:
# Experience replay buffer
class ExperienceReplayBuffer:
    """Fixed-capacity, first-in-first-out store of episodes for experience replay.

    Every stored item is one episode: a sequence of experiences. ``sample``
    reshapes its result to 5 columns, so each experience must have 5 fields.
    NOTE(review): the fields are presumably (state, action, reward,
    next_state, done) -- confirm against the training loop.
    """

    def __init__(self, buffer_size):
        """Create an empty buffer that keeps at most ``buffer_size`` episodes."""
        self.buffer = []
        self.buffer_size = buffer_size

    def add(self, experience):
        """Append an episode, evicting the oldest ones when the buffer is full.

        Args:
            experience: The episode (sequence of experiences) to store.
        """
        # Number of old episodes that have to go so that, once the new one
        # is appended, no more than ``buffer_size`` episodes remain.
        overflow = len(self.buffer) + 1 - self.buffer_size
        if overflow > 0:
            del self.buffer[:overflow]
        self.buffer.append(experience)

    def sample(self, batch_sz, trace_ln):
        """Draw one random contiguous trace from each of ``batch_sz`` episodes.

        Args:
            batch_sz: Number of distinct stored episodes to draw (without
                replacement).
            trace_ln: Number of consecutive experiences taken from each
                drawn episode.

        Returns:
            A numpy array of shape ``(batch_sz * trace_ln, 5)``. The rows of
            one trace are adjacent and keep their order within the episode.

        Raises:
            ValueError: If more episodes are requested than are stored, or
                if a drawn episode holds fewer than ``trace_ln`` experiences.
        """
        sampled_eps = random.sample(self.buffer, batch_sz)
        sampled_traces = []
        for ep in sampled_eps:
            # Count of valid start offsets for a window of trace_ln steps.
            num_starts = len(ep) - trace_ln + 1
            if num_starts <= 0:
                # Fail loudly: carrying on would either crash on an unbound
                # start index or silently reuse the previous episode's one.
                raise ValueError(
                    'Cannot sample a trace of length %d from an episode of length %d'
                    % (trace_ln, len(ep))
                )
            t_s = np.random.randint(0, num_starts)
            sampled_traces.append(ep[t_s : t_s + trace_ln])
        # NOTE(review): np.array needs the sampled experiences to stack into
        # a regular array -- verify this holds for the stored field types.
        sampled_traces = np.array(sampled_traces)
        return np.reshape(sampled_traces, [batch_sz*trace_ln, 5])