# DQN model

## Load libraries and the environment

In [None]:
import numpy as np
from mlagents_envs.environment import UnityEnvironment, ActionTuple
from typing import Dict
import matplotlib.pyplot as plt
import os

# pytorch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch as T

In [None]:
# Open the Unity environment and reset it so it is ready for stepping.
# NOTE(review): per the ML-Agents docs, file_name=None connects to a running
# Unity Editor (press Play) instead of launching a standalone build — confirm.
env = UnityEnvironment(
    file_name=None,
    seed=1,
    side_channels=[],
)
env.reset()

## All DRL components

### Buffer

In [None]:
class ReplayBuffer:
    """Fixed-capacity circular buffer of (state, action, reward, next_state, done) transitions.

    Storage is preallocated as NumPy arrays. After ``max_size`` transitions
    have been stored, each new transition overwrites the oldest slot.

    Args:
        max_size: Capacity of the buffer (number of transitions).
        input_shape: Shape of a single state, e.g. ``(C, H, W)``.
        n_actions: Number of discrete actions. Not used by the buffer itself;
            kept so existing callers keep working.
    """

    def __init__(self, max_size, input_shape, n_actions):
        self.mem_size = max_size
        self.mem_count = 0  # total transitions ever stored (not capped at mem_size)
        self.state_memory = np.zeros((self.mem_size, *input_shape), dtype=np.float32)
        self.new_state_memory = np.zeros((self.mem_size, *input_shape), dtype=np.float32)
        self.action_memory = np.zeros(self.mem_size, dtype=np.int64)
        self.reward_memory = np.zeros(self.mem_size, dtype=np.float32)
        # Boolean so sampled flags can be used directly as a tensor mask
        # (indexing with uint8 masks is deprecated in PyTorch).
        self.terminal_memory = np.zeros(self.mem_size, dtype=np.bool_)

    def store_transition(self, state, action, reward, state_, done):
        """Write one transition into the next slot, overwriting the oldest when full."""
        index = self.mem_count % self.mem_size
        self.state_memory[index] = state
        self.action_memory[index] = action
        self.reward_memory[index] = reward
        self.new_state_memory[index] = state_
        self.terminal_memory[index] = done
        self.mem_count += 1

    def sample_buffer(self, batch_size):
        """Sample ``batch_size`` distinct stored transitions uniformly at random.

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones)`` of NumPy
            arrays, each with leading dimension ``batch_size``.

        Raises:
            ValueError: If fewer than ``batch_size`` transitions are stored
                (sampling is without replacement).
        """
        max_mem = min(self.mem_count, self.mem_size)
        if batch_size > max_mem:
            raise ValueError(
                f'cannot sample {batch_size} transitions from a buffer holding {max_mem}'
            )
        batch = np.random.choice(max_mem, batch_size, replace=False)

        states = self.state_memory[batch]
        actions = self.action_memory[batch]
        rewards = self.reward_memory[batch]
        states_ = self.new_state_memory[batch]
        dones = self.terminal_memory[batch]

        return states, actions, rewards, states_, dones

### Model

In [None]:
class DQNetwork(nn.Module):
    """Convolutional Q-network mapping an observation to one Q-value per action.

    Three conv layers (8x8 stride 4, 4x4 stride 2, 3x3 stride 1) feed two
    fully connected layers. The network owns its RMSprop optimizer and MSE
    loss, and moves itself to CUDA when available.

    Args:
        input_dims: Shape of one observation; ``input_dims[0]`` is the number
            of input channels for ``conv1``.
        n_actions: Number of discrete actions (size of the output layer).
        lr: Learning rate for RMSprop.
        name: Checkpoint file name inside ``chkpt_dir``.
        chkpt_dir: Directory used by ``save_checkpoint``/``load_checkpoint``.
    """

    def __init__(self, input_dims, n_actions, lr, name, chkpt_dir):
        super().__init__()
        self.checkpoint_dir = chkpt_dir
        self.checkpoint_file = os.path.join(self.checkpoint_dir, name)

        self.conv1 = nn.Conv2d(input_dims[0], 32, (8, 8), stride=(4, 4))
        self.conv2 = nn.Conv2d(32, 64, (4, 4), stride=(2, 2))
        self.conv3 = nn.Conv2d(64, 64, (3, 3), stride=(1, 1))

        fc_input_dims = self.calculate_conv_output_dims(input_dims)
        self.fc1 = nn.Linear(fc_input_dims, 512)
        self.fc2 = nn.Linear(512, n_actions)

        self.optimizer = optim.RMSprop(self.parameters(), lr=lr)
        self.loss = nn.MSELoss()
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)

    def calculate_conv_output_dims(self, input_dims):
        """Return the flattened feature count the conv stack yields for one observation."""
        # Probe with a dummy batch of one on the layers' current device;
        # no autograd graph is needed just to read the output shape.
        probe_device = self.conv1.weight.device
        with T.no_grad():
            probe = T.zeros(1, *input_dims, device=probe_device)
            out = self.conv3(self.conv2(self.conv1(probe)))
        return int(np.prod(out.size()))

    def forward(self, state):
        """Return Q-values of shape ``(batch, n_actions)`` for a batch of observations.

        ``state`` must be a float tensor of shape ``(batch, *input_dims)`` on
        ``self.device``.
        """
        x = F.relu(self.conv1(state))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size()[0], -1)  # flatten conv features per sample
        x = F.relu(self.fc1(x))
        return self.fc2(x)

    def save_checkpoint(self):
        """Save the parameters to ``checkpoint_file``, creating the directory if needed."""
        print('... saving checkpoint ...')
        if self.checkpoint_dir:
            os.makedirs(self.checkpoint_dir, exist_ok=True)
        T.save(self.state_dict(), self.checkpoint_file)

    def load_checkpoint(self):
        """Load parameters from ``checkpoint_file`` onto this network's device."""
        print('... loading checkpoint ...')
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
        self.load_state_dict(T.load(self.checkpoint_file, map_location=self.device))

### Agent

In [None]:
class Agent:
    """Base agent: hyper-parameters, replay buffer, epsilon schedule and target syncing.

    Subclasses must create ``self.q_eval`` and ``self.q_next``. ``sample_memory``
    reads ``q_eval.device``; ``replace_target_network`` copies ``q_eval``'s
    state dict into ``q_next``; ``save_models``/``load_models`` call each
    network's ``save_checkpoint``/``load_checkpoint``.

    Args:
        gamma: Discount factor, stored for subclasses' target computation.
        epsilon: Initial exploration rate.
        lr: Learning rate, stored for subclasses' networks.
        n_actions: Number of discrete actions.
        input_dims: Shape of a single observation.
        mem_size: Replay buffer capacity.
        batch_size: Number of transitions per sampled minibatch.
        eps_min: Lower bound for epsilon.
        eps_dec: Amount subtracted from epsilon per ``decrement_epsilon`` call.
        replace: Target network is synced when ``learn_step_counter`` is a
            multiple of this value.
        algo: Algorithm name, stored for subclasses (e.g. checkpoint names).
        env_name: Environment name, stored for subclasses.
        chkpt_dir: Checkpoint directory, stored for subclasses.
    """

    def __init__(self, gamma, epsilon, lr, n_actions, input_dims,
                 mem_size, batch_size, eps_min=0.01, eps_dec=5e-7,
                 replace=1000, algo=None, env_name=None, chkpt_dir='tmp/dqn'):
        self.gamma = gamma
        self.epsilon = epsilon
        self.lr = lr
        self.n_actions = n_actions
        self.input_dims = input_dims
        self.mem_size = mem_size
        self.batch_size = batch_size
        self.eps_min = eps_min
        self.eps_dec = eps_dec
        self.replace_target_count = replace
        self.algo = algo
        self.env_name = env_name
        self.chkpt_dir = chkpt_dir
        self.action_space = list(range(self.n_actions))
        self.learn_step_counter = 0

        self.memory = ReplayBuffer(self.mem_size, self.input_dims, self.n_actions)

    def decrement_epsilon(self):
        """Linearly decay epsilon by ``eps_dec``, never going below ``eps_min``."""
        # max() clamps exactly at eps_min; a plain "subtract if above the floor"
        # can overshoot below eps_min by up to eps_dec on the last step.
        self.epsilon = max(self.epsilon - self.eps_dec, self.eps_min)

    def store_transition(self, state, action, reward, state_, done):
        """Append one transition to the replay buffer."""
        self.memory.store_transition(state, action, reward, state_, done)

    def sample_memory(self):
        """Sample a minibatch and return it as tensors on ``q_eval``'s device.

        Returns:
            ``(states, actions, rewards, next_states, dones)``; ``dones`` is a
            boolean tensor usable directly as a mask.
        """
        state, action, reward, new_state, done = self.memory.sample_buffer(self.batch_size)
        device = self.q_eval.device

        states = T.tensor(state).to(device)
        actions = T.tensor(action).to(device)
        rewards = T.tensor(reward).to(device)
        states_ = T.tensor(new_state).to(device)
        # Boolean mask: indexing with uint8 tensors is deprecated in PyTorch.
        dones = T.tensor(done).to(device).bool()

        return states, actions, rewards, states_, dones

    def replace_target_network(self):
        """Copy ``q_eval``'s weights into ``q_next`` every ``replace_target_count`` learn steps."""
        if self.learn_step_counter % self.replace_target_count == 0:
            self.q_next.load_state_dict(self.q_eval.state_dict())

    def save_models(self):
        """Checkpoint both the online and the target network."""
        self.q_eval.save_checkpoint()
        self.q_next.save_checkpoint()

    def load_models(self):
        """Restore both the online and the target network from their checkpoints."""
        self.q_eval.load_checkpoint()
        self.q_next.load_checkpoint()


class DQNAgent(Agent):
    """Vanilla DQN agent with an online network ``q_eval`` and target network ``q_next``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # f-strings so the default algo/env_name of None do not raise a TypeError.
        prefix = f'{self.env_name}_{self.algo}'
        self.q_eval = DQNetwork(self.input_dims, self.n_actions, self.lr,
                                name=f'{prefix}_q_eval',
                                chkpt_dir=self.chkpt_dir)
        self.q_next = DQNetwork(self.input_dims, self.n_actions, self.lr,
                                name=f'{prefix}_q_next',
                                chkpt_dir=self.chkpt_dir)

    def choose_action(self, observation):
        """Pick an action epsilon-greedily for a single observation.

        With probability ``1 - epsilon`` returns the argmax of ``q_eval`` on
        ``observation``; otherwise a uniformly random action index.
        """
        if np.random.random() > self.epsilon:
            # Convert via np.asarray: building a tensor from a list of ndarrays is very slow.
            state = T.as_tensor(np.asarray(observation), dtype=T.float)
            state = state.unsqueeze(0).to(self.q_eval.device)
            with T.no_grad():  # inference only; no graph needed
                q_values = self.q_eval(state)
            return int(T.argmax(q_values).item())
        return int(np.random.choice(self.action_space))

    def learn(self):
        """Run one DQN update on a sampled minibatch.

        Does nothing until the replay buffer holds at least ``batch_size``
        transitions (sampling is without replacement). Each update syncs the
        target network on schedule, regresses Q(s, a) toward
        ``r + gamma * max_a' Q_target(s', a')`` (zero bootstrap on terminal
        transitions), then advances the step counter and decays epsilon.
        """
        # Gate on buffer fill, not on learn calls: the old counter-based warm-up
        # could sample an under-filled buffer and delayed the first target sync.
        if self.memory.mem_count < self.batch_size:
            return

        self.q_eval.optimizer.zero_grad()
        self.replace_target_network()
        states, actions, rewards, states_, dones = self.sample_memory()

        # Q(s, a) for the actions actually taken.
        q_pred = self.q_eval(states).gather(1, actions.long().unsqueeze(1)).squeeze(1)

        # Bootstrapped target; gradients must not flow through the target network.
        with T.no_grad():
            q_next = self.q_next(states_).max(dim=1)[0]
            q_next = q_next.masked_fill(dones.bool(), 0.0)
            q_target = rewards + self.gamma * q_next

        loss = self.q_eval.loss(q_pred, q_target)
        loss.backward()
        self.q_eval.optimizer.step()
        self.learn_step_counter += 1
        self.decrement_epsilon()