In [None]:
import pickle
import torch
import torch.nn as nn
import numpy as np
from collections import namedtuple
import matplotlib
import matplotlib.pyplot as plt
import torch.optim as optim
import torch.nn.functional as F

In [None]:
# Set up matplotlib: IPython's display helpers are only imported when the
# active backend is an inline one.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every random source used by the notebook from one constant.
# np.random.seed() only affects NumPy's legacy global generator; a Generator
# created by default_rng() has to be given the seed explicitly, otherwise the
# shuffling and sampling done through random_gen differ on every run.
SEED = 57
np.random.seed(SEED)
torch.manual_seed(SEED)
random_gen = np.random.default_rng(SEED)

In [None]:
def process_state(state):
    """Convert a board state (or a batch of them) into a float32 tensor.

    The result is laid out as (batch, channel, height, width) with a single
    channel:

    * a 2-D state (H x W) becomes (1 x 1 x H x W);
    * a 3-D batch of states (B x H x W) becomes (B x 1 x H x W);
    * a 4-D tensor is assumed to be in the final layout already and is
      returned with its shape unchanged, so calling this function twice on
      the same state is harmless.

    Args:
        state: ``None``, a ``torch.Tensor``, or anything ``numpy.asarray``
            can turn into a numeric array (nested lists, ndarray, ...).

    Returns:
        ``None`` if ``state`` is ``None`` (used for terminal next states),
        otherwise a ``torch.float32`` tensor with four dimensions.

    Raises:
        ValueError: if ``state`` has a number of dimensions other than
            2, 3 or 4.
    """
    if state is None:
        return None

    if isinstance(state, torch.Tensor):
        state = state.to(torch.float32)
    else:
        # Going through NumPy handles nested lists and lists of arrays alike;
        # torch.tensor copies, so the caller's data is never aliased.
        state = torch.tensor(np.asarray(state), dtype=torch.float32)

    if state.dim() == 2:
        # Single board: add batch and channel dimensions.
        return state.unsqueeze(0).unsqueeze(0)
    if state.dim() == 3:
        # Batch of boards: only the channel dimension is missing.
        return state.unsqueeze(1)
    if state.dim() == 4:
        return state
    raise ValueError(
        'process_state expects a 2-D, 3-D or 4-D state, got {} dimension(s)'
        .format(state.dim()))

In [None]:
# Set up transition and ReplayMemory classes
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):
    """Fixed-capacity ring buffer of ``Transition`` tuples.

    Once ``capacity`` transitions have been stored, each new ``push``
    overwrites the oldest entry.
    """

    def __init__(self, capacity):
        """Create an empty buffer that holds at most ``capacity`` transitions."""
        self.capacity = capacity
        self.memory = []
        self.position = 0  # index the next push will write to

    def push(self, *args):
        """Saves a transition.

        ``args`` are the four ``Transition`` fields in order:
        state, action, next_state, reward.
        """
        if len(self.memory) < self.capacity:
            # Still growing: reserve the slot that is filled in just below.
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size, rng=None):
        """Return a minibatch of stored transitions, drawn with replacement.

        Indices are sampled rather than the transitions themselves: handing
        the list of namedtuples straight to ``Generator.choice`` makes NumPy
        try to coerce it into an array, which breaks (or yields mangled
        rows) when the fields hold tensors or ``None``.

        Args:
            batch_size: number of transitions to draw.
            rng: optional ``numpy.random.Generator``; defaults to the
                notebook-level ``random_gen``.

        Returns:
            A list of ``batch_size`` ``Transition`` objects.

        Raises:
            ValueError: raised by NumPy if the memory is empty and
                ``batch_size`` is non-zero.
        """
        if rng is None:
            rng = random_gen
        indices = rng.choice(len(self.memory), size=batch_size)
        return [self.memory[index] for index in indices]

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [None]:
# Load dataset and shuffle
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load a dataset.p that comes from a trusted source.
with open('dataset.p', 'rb') as dataset_file:
    raw_data = pickle.load(dataset_file)

# Flatten the per-episode experience lists into one list of experiences.
dataset = [
    exp
    for episode in raw_data
    for exp in raw_data[episode]['experiences']
]

random_gen.shuffle(dataset)
print('Dataset has {} experiences'.format(len(dataset)))

In [None]:
# Push entire dataset into ReplayMemory
memory = ReplayMemory(len(dataset))

for exp in dataset:
    state, action, nextstate, reward = exp
    state = process_state(state)
    # Convert the *next* state; converting `state` a second time would store
    # the current state as its own successor and lose terminal (None) next
    # states, which process_state passes through unchanged.
    nextstate = process_state(nextstate)
    memory.push(state, action, nextstate, reward)

# Test sample method
print(memory.sample(2))

In [None]:
# Define network

class DQN(nn.Module):
    """Two conv + batch-norm stages followed by one linear output layer.

    Args:
        h: height of the input board.
        w: width of the input board.
        outputs: number of values produced per input (size of the head).
    """

    def __init__(self, h, w, outputs):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8,
                               kernel_size=4, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=8)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16,
                               kernel_size=2, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=16)

        # The width of the linear layer depends on the spatial size left
        # after both convolutions, so push (w, h) through the same kernel
        # sizes the conv layers above use.
        out_w, out_h = w, h
        for kernel in (4, 2):
            out_w = self._conv_out(out_w, kernel)
            out_h = self._conv_out(out_h, kernel)
        self.head = nn.Linear(out_w * out_h * 16, outputs)

    @staticmethod
    def _conv_out(size, kernel_size, padding=1, stride=1):
        """Length of one spatial dimension after a single Conv2d layer."""
        return (size + 2 * padding - kernel_size) // stride + 1

    def forward(self, x):
        """Map a (batch x 1 x h x w) input to a (batch x outputs) tensor.

        Used both for a single state (batch of one) and for a full
        minibatch during optimisation.
        """
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        flat = hidden.view(hidden.size(0), -1)
        return self.head(flat)

In [None]:
# Smoke test: run one stored state through a freshly initialised network.
network = DQN(4, 4, 4)
batch = memory.sample(1)
print(batch)
# States are converted by process_state before they are pushed into memory,
# so the stored state is fed to the network as-is; converting it again would
# add two more leading dimensions and Conv2d would reject the input.
network.eval()  # inference only: use batch-norm running statistics
with torch.no_grad():
    results = network(batch[0][0])
print(results)

In [None]:
# --- Training hyperparameters ---
# BATCH_SIZE:    transitions drawn from memory per optimisation step
# GAMMA:         discount applied to the next-state value
# TARGET_UPDATE: number of training iterations between target-net syncs
BATCH_SIZE, GAMMA, TARGET_UPDATE = 128, 0.999, 20

# --- Problem dimensions ---
board_width = board_height = n_actions = 4

# --- Networks ---
# Two identically shaped networks, created policy first and target second;
# the target starts as a copy of the policy net and stays in eval mode.
policy_net, target_net = (
    DQN(board_height, board_width, n_actions).to(device) for _ in range(2)
)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

# Only the policy network is optimised.
optimizer = optim.RMSprop(policy_net.parameters())

# Counter incremented by select_action on every call.
steps_done = 0


def select_action(state):
    """Pick the action with the highest policy_net output for ``state``.

    Increments the global ``steps_done`` counter on every call and returns
    the chosen action index as a (1 x 1) tensor.
    """
    global steps_done
    steps_done += 1
    with torch.no_grad():
        q_values = policy_net(state)
        # max over dim 1 yields (values, indices); the index of the largest
        # value in the row is the greedy action.
        _, best_action = q_values.max(1)
    return best_action.view(1, 1)

episode_durations = []

def plot_durations():
    """Redraw figure 2 with the values in ``episode_durations``.

    Plots the raw series and, once at least 100 values exist, a 100-value
    moving average (left-padded with zeros so both curves align).
    """
    window = 100
    plt.figure(2)
    plt.clf()
    durations = torch.tensor(episode_durations, dtype=torch.float)
    plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations.numpy())

    # Moving average over sliding windows, shown only when a full window
    # of data is available.
    if not len(durations) < window:
        rolling = durations.unfold(0, window, 1).mean(1).view(-1)
        padded = torch.cat((torch.zeros(window - 1), rolling))
        plt.plot(padded.numpy())

    # Brief pause so the figure gets a chance to refresh.
    plt.pause(0.001)
    if is_ipython:
        display.clear_output(wait=True)
        display.display(plt.gcf())

In [None]:
def optimize_model():
    """Run one optimisation step of ``policy_net`` on a sampled minibatch.

    Draws ``BATCH_SIZE`` transitions from ``memory``, computes the Huber loss
    between Q(s_t, a_t) from ``policy_net`` and
    reward + GAMMA * max_a Q(s_{t+1}, a) from ``target_net`` (0 for terminal
    transitions), and applies one clipped-gradient optimizer step.

    Returns:
        The loss of this step as a Python float, or ``None`` if ``memory``
        holds fewer than ``BATCH_SIZE`` transitions and no step was taken.
    """
    if len(memory) < BATCH_SIZE:
        return None
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Mask of transitions whose next state is not terminal (terminal next
    # states are stored as None).
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)

    # The networks live on `device`, so every batch is moved there as well.
    state_batch = torch.cat(batch.state).to(device)
    # Actions and rewards may be stored as plain numbers or as tensors;
    # normalise both to the shapes gather() and the loss need.
    # NOTE(review): assumes exactly one action index and one reward per
    # transition - confirm against the dataset format.
    action_batch = torch.cat([
        torch.as_tensor(a, dtype=torch.long, device=device).view(1, 1)
        for a in batch.action
    ])
    reward_batch = torch.cat([
        torch.as_tensor(r, dtype=torch.float32, device=device).view(1)
        for r in batch.reward
    ])

    # Compute Q(s_t, a): the model computes Q(s_t) for every action, then the
    # column of the action stored with each transition is selected.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) with the "older" target_net: the best action value
    # for non-terminal next states, 0 for terminal ones. The guard avoids
    # torch.cat on an empty list when every sampled transition is terminal.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    if non_final_mask.any():
        non_final_next_states = torch.cat(
            [s for s in batch.next_state if s is not None]).to(device)
        with torch.no_grad():
            next_state_values[non_final_mask] = \
                target_net(non_final_next_states).max(1)[0]

    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    loss = F.smooth_l1_loss(state_action_values,
                            expected_state_action_values.unsqueeze(1))
    loss_value = loss.item()
    print(loss_value)

    # Optimize the model, clamping every gradient element to [-1, 1].
    optimizer.zero_grad()
    loss.backward()
    nn.utils.clip_grad_value_(policy_net.parameters(), 1)
    optimizer.step()
    return loss_value

In [None]:
train_iterations = 200
for i_train in range(train_iterations):
    # Update policy network
    optimize_model()

    # Update the target network, copying all weights and biases in DQN
    if i_train % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())

print('Complete')
# No cell in this notebook creates `env` (training runs purely from the
# stored dataset), so only render/close it if one happens to be defined;
# an unconditional call raises NameError on a fresh kernel.
if 'env' in globals():
    env.render()
    env.close()
plt.ioff()
plt.show()