In [76]:
from collections import defaultdict, deque
import random

import numpy as np
import torch

class ReplayMemory:
    """Fixed-capacity FIFO buffer of transitions with uniform random sampling.

    Once ``size`` transitions are stored, each new one evicts the oldest
    (``deque(maxlen=size)`` semantics).
    """

    def __init__(self, size):
        """Create an empty buffer holding at most ``size`` transitions."""
        self.current_size = 0
        self.queue = deque(maxlen=size)

    def _get_current_size(self):
        """Return the number of transitions currently held."""
        return self.current_size

    def can_sample(self, size):
        """Return True when at least ``size`` transitions are stored."""
        return len(self.queue) >= size

    def store(self, transition):
        """Append ``transition``, evicting the oldest entry when full."""
        self.queue.append(transition)
        # Track the real occupancy: the deque silently drops old items once
        # it reaches maxlen, so a plain counter would overshoot.
        self.current_size = len(self.queue)

    def sample(self, size):
        """Return ``size`` distinct stored transitions chosen uniformly.

        Raises:
            ValueError: if fewer than ``size`` transitions are stored.
        """
        if not self.can_sample(size):
            raise ValueError('Cannot sample, not enough experience')

        return random.sample(self.queue, size)

class DQN:
    def __init__(
        self,
        env,
        target_net,
        update_net,
        optimiser,
        loss_func,
        w_sync_freq=10,
        batch_size=10,
        memory_size=5000,
        gamma=0.95,
        step_size=0.001,
        episodes=1000,
        eval_episodes=50,
        epsilon_start=0.3,
        epsilon_decay=0.9996,
        epsilon_min=0.01,
        negative_rewards=None,
    ):
        """Store the agent's collaborators and hyper-parameters.

        Args:
            env: Environment. Must expose ``observation_space.n``, ``states``,
                ``action_space.n`` and ``actions`` (all read here).
            target_net: Network used to compute bootstrap targets.
            update_net: Network that is trained and used to pick actions.
            optimiser: Optimiser stepped after each training batch.
            loss_func: Callable ``loss_func(predictions, targets)``.
            w_sync_freq: Copy ``update_net`` weights into ``target_net``
                every this many episodes.
            batch_size: Number of transitions sampled per training step.
            memory_size: Capacity of the replay memory.
            gamma: Discount factor (stored as ``np.float64``).
            step_size: Stored on the instance; not used by the code visible
                here.
            episodes: Number of training episodes.
            eval_episodes: Stored on the instance; not used by the code
                visible here.
            epsilon_start: Initial exploration rate.
            epsilon_decay: Multiplicative decay applied to epsilon each
                episode.
            epsilon_min: Lower bound for epsilon.
            negative_rewards: Stored on the instance; not used by the code
                visible here. Defaults to ``[-0.75, -0.85, -15.0]``.
        """
        # Build the default per instance so it is never shared between agents.
        if negative_rewards is None:
            negative_rewards = [-0.75, -0.85, -15.0]

        self.env = env
        self.gamma = np.float64(gamma)
        self.n_states = self.env.observation_space.n
        self.states = self.env.states
        self.n_actions = self.env.action_space.n
        self.actions = self.env.actions
        self.update_net = update_net
        self.target_net = target_net
        # Both of these are read by the training loop, so they must be kept.
        self.w_sync_freq = w_sync_freq
        self.batch_size = batch_size
        self.memory_size = memory_size
        self.replay_memory = ReplayMemory(size=memory_size)
        self.optimiser = optimiser
        self.loss_func = loss_func
        self.step_size = step_size
        self.episodes = episodes
        self.epsilon_start = epsilon_start
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.negative_rewards = negative_rewards
        self.eval_episodes = eval_episodes

        # initialize action-value function
        self.Q = defaultdict(
            lambda: np.zeros(self.n_actions),
        )

        def _new_log_entry():
            # Fresh record for an episode that has not been logged yet.
            return {
                'reward': 0,
                'cumulative_reward': 0,
                'epsilon': None
            }

        # initialize training logs (keyed by episode number)
        self.logs = defaultdict(_new_log_entry)

        # initialize evaluation logs (same record layout as the training logs)
        self.eval_logs = defaultdict(_new_log_entry)
       
    @staticmethod
    def _clip_reward(reward):
        return (2 * (
            reward - self.env.min_reward
        ) / (self.env.max_reward - self.env.min_reward)) - 1
    
    def _get_action_probs(self, state, epsilon):
        state = torch.FloatTensor(state)
        # initialize episilon probability to all the actions
        probs = np.ones(self.n_actions) * (epsilon / self.n_actions)
        print(f'state: {state}, type: {type(state)}')
        # CHANGE
        action_values = self.update_net.forward(state)
        best_action = torch.max(action_values, 1)[1].data.numpy()
        best_action = best_action[0] if ENV_A_SHAPE == 0 else best_action.reshape(ENV_A_SHAPE)
        # initialize 1-epsilon probability to the greedy action
        probs[best_action] = 1 - epsilon + (epsilon / self.n_actions)
        return probs
        
    def _get_action(self, state, epsilon):
        action = np.random.choice(
            self.actions, 
            p=self._get_action_probs(
                state,
                epsilon,
            ),
        ) 
        
        return action, self.actions.index(action)
    
    def _store_transition(self, transition):
        self.replay_memory.store(transition)
        
    def _train_one_batch(self, transitions, epsilon):
        states, actions, rewards, next_states = transitions
        
        Q_states = self.update_net(states).gather(1, actions)
        Q_targets = rewards + self.gamma * self.target_net(next_states).detach().max(1)[0]
        
        loss = self.loss_func(Q_states, Q_targets)
        self.optimizer.zero_grad()
        loss.backwards(retain_variables = True)
        self.optimizer.step()
        
    def _sync_weights(self):
        self.target_net.load_state_dict(self.update_net.state_dict())
        
    def run(self):
        """Train ``update_net`` for ``self.episodes`` episodes.

        Per episode: decay epsilon (floored at ``epsilon_min``), reset the
        environment, then repeatedly act epsilon-greedily, store the
        transition and, once the replay memory holds at least ``batch_size``
        transitions, train on a sampled batch. Every ``w_sync_freq`` episodes
        the target network is synchronised with the update network.

        ``self.env.step(action=...)`` is expected to return a 5-tuple whose
        2nd, 4th and 5th items are the reward, the next state and the
        end-of-episode flag.

        Results are written to ``self.logs[episode_no]`` under the keys
        ``'epsilon'``, ``'reward'`` and ``'cumulative_reward'`` (running sum
        of episode rewards up to and including that episode).
        """
        epsilon = self.epsilon_start
        cumulative_reward = 0
        for episode_no in range(self.episodes):
            print(f'Episode: {episode_no}')
            epsilon = max(epsilon*self.epsilon_decay, self.epsilon_min)
            episode_ended = False
            self.logs[episode_no]['epsilon'] = epsilon
            episode_reward = 0
            state = self.env.reset()

            while not episode_ended:
                action, action_idx = self._get_action(state, epsilon)
                _, reward, _goal, next_state, episode_ended = self.env.step(action=action)

                episode_reward += reward
                # ensure gradients are well conditioned
                clipped_reward = self._clip_reward(reward)

                # Build tensors from the values themselves. torch.FloatTensor
                # treats a bare int as a size and rejects a bare float.
                # torch.tensor copies, so later in-place changes by the
                # environment cannot alter stored experience.
                self._store_transition([
                    torch.tensor(np.asarray(state), dtype=torch.float32),
                    torch.tensor(np.asarray(action_idx), dtype=torch.int64),
                    torch.tensor(np.asarray(clipped_reward), dtype=torch.float32),
                    torch.tensor(np.asarray(next_state), dtype=torch.float32),
                ])

                if self.replay_memory.can_sample(size=self.batch_size):
                    transitions = self.replay_memory.sample(size=self.batch_size)
                    self._train_one_batch(transitions, epsilon)

                # Advance, otherwise every step would act from the reset state.
                state = next_state

            if episode_no % self.w_sync_freq == 0:
                self._sync_weights()

            # save logs for analysis
            cumulative_reward += episode_reward
            self.logs[episode_no]['reward'] = episode_reward
            self.logs[episode_no]['cumulative_reward'] = cumulative_reward
        