In [1]:
# model config params
N_HIDDEN_NODES = 64
LEARNING_RATE = 1e-3
DISCOUNT_FACTOR = 0.99
MAX_BUFFER_SIZE = 10_000
BATCH_SIZE = 64

In [2]:
from torch.nn import Linear, Module
from torch import Tensor


class DQN(Module):
    def __init__(self,
                 n_states: int,
                 n_actions: int,
                 n_hidden_nodes: int,
                 duelling: bool = False
                ):
        super(DQN, self).__init__()
        self.duelling = duelling
        self.l1 = Linear(n_states, n_hidden_nodes)
        self.l2 = Linear(n_hidden_nodes, n_hidden_nodes)
        self.l3 = Linear(n_hidden_nodes, n_actions)
        if duelling:
            self.l4 = Linear(n_hidden_nodes, 1) # state value estimation

    def forward(self, state: Tensor) -> Tensor:
        x = self.l1(state).relu()
        x = self.l2(x).relu()
        if self.duelling:
            return self.l4(x) + (self.l3(x) - self.l3(x).mean(dim=1, keepdim=True)[0])
        else:
            return self.l3(x)

In [3]:
from numpy import float32
from numpy.typing import NDArray
from typing import List, Tuple
from collections import deque


class ReplayBuffer:
    def __init__(self,
                 max_buffer_size: int
                ):
        self.buffer = deque(maxlen=max_buffer_size)


    def add(self,
            state: NDArray[float32],
            action: int,
            reward: float,
            next: NDArray[float32],
            terminal: bool
           ):
        self.buffer.append((state, action, reward, next, terminal))


    def sample(self, 
               batches: int
              ) -> Tuple[NDArray[float32], List[int], List[float], NDArray[float32], List[bool]]:
        samples = random.sample(self.buffer, batches)
        return zip(*samples)



In [4]:
import numpy as np

class PrioritisedReplayBuffer(ReplayBuffer):
    def __init__(self,
                 max_buffer_size: int,
                 bias_factor_start: float = 0.4,
                 bias_factor_end: float = 1.0,
                 bias_increment: float = 0.01,
                 priority_scale: float = 0.6
                ):
        super(PrioritisedReplayBuffer, self).__init__(max_buffer_size)
        self.priorities = deque(maxlen=max_buffer_size)
        self.min_priority = 0.01
        self.bias_factor = bias_factor_start
        self.bias_factor_end = bias_factor_end
        self.bias_increment = bias_increment
        self.priority_scale = priority_scale

    def add(self,
            state: NDArray[float32],
            action: int,
            reward: float,
            next: NDArray[float32],
            terminal: bool,
            error: float
           ):
        super(PrioritisedReplayBuffer, self).add(state, action, reward, next, terminal)
        self.priorities.append((np.abs(error) + self.min_priority)**self.priority_scale)

    def sample(self,
               batches: int
              ) -> Tuple[NDArray[float32], List[int], List[float], NDArray[float32], List[bool], List[int], List[float]]:
        probas = np.array(self.priorities)/np.sum(self.priorities)
        indices = np.random.choice(len(self.buffer), batches, p=probas)
        samples = [self.buffer[i] for i in indices]
        states, actions, rewards, nexts, terminals = zip(*samples)
        weights = (1/len(self.buffer) * 1/probas[indices]) ** self.bias_factor
        weights /= np.max(weights)
        return states, actions, rewards, nexts, terminals, indices, weights

    def update_priorities(self, idx, error):
        self.priorities[idx] = (np.abs(error) + self.min_priority) ** self.priority_scale

    def update_bias_factor(self):
        self.bias_factor = min(self.bias_factor + self.bias_increment, self.bias_factor_end)

In [12]:
import torch
from torch import cuda
from torch.backends import mps
from torch.optim import SGD, Adam
from torch.nn import MSELoss
from typing import List, Dict

def _get_torch_device() -> str:
    if cuda.is_available():
        return "cuda"
    elif mps.is_available():
        return "mps"
    else:
        return "cpu"


class DQNAgent:
    def __init__(self,
                 n_states: int,
                 n_actions: int,
                 n_hidden_nodes: int,
                 learning_rate: float,
                 discount_factor: float,
                 max_buffer_size: int,
                 batch_size: int,
                 modifications: List[str] = None,
                 per_params: Dict[str, float] = None
                ):
        self.device = torch.device(_get_torch_device())

        self.n_actions = n_actions
        self.discount_factor = discount_factor
        self.batch_size = batch_size

        self.qnet = DQN(n_states, n_actions, n_hidden_nodes, 'duelling' in modifications).to(self.device)
        self.modifications = modifications

        # self.optimiser = SGD(self.qnet.parameters(), lr=learning_rate)
        self.optimiser = Adam(self.qnet.parameters(), lr=learning_rate)
        if 'per' in modifications:
            self.replay_buffer = PrioritisedReplayBuffer(max_buffer_size, **per_params)
        else:
            self.replay_buffer = ReplayBuffer(max_buffer_size)
        self.target_qnet = DQN(n_states, n_actions, n_hidden_nodes, 'duelling' in modifications).to(self.device)
        self.target_qnet.load_state_dict(self.qnet.state_dict())
        self.target_qnet.eval()

    
    def get_td_error(self,
                     state: float32,
                     action: int,
                     reward: float,
                     next: float32,
                     terminal: bool
                    ):
        state_tensor = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        action_tensor = torch.tensor([action], device=self.device)
        reward_tensor = torch.tensor([reward], dtype=torch.float32, device=self.device)
        next_tensor = torch.from_numpy(next).float().unsqueeze(0).to(self.device)
        terminal_tensor = torch.tensor([terminal], device=self.device)
        curr_q = self.qnet(state_tensor).gather(1, action_tensor.unsqueeze(-1)).squeeze(-1)
        next_q = self.target_qnet(next_tensor).max(1)[0]
        expected_q = reward_tensor + self.discount_factor * next_q * (1 - terminal_tensor)
        return MSELoss()(current_q, expected_q).item()


    def _step_no_per(self,
             state: NDArray[float32],
             action: int,
             reward: float,
             next: NDArray[float32],
             terminal: bool
            ):
        self.replay_buffer.add(state, action, reward, next, terminal)
        if len(self.replay_buffer.buffer) > self.batch_size:
            self.update_model()


    def _step_per(self,
             state: NDArray[float32],
             action: int,
             reward: float,
             next: NDArray[float32],
             terminal: bool
            ):
        error = self.get_td_error(state, action, reward, next, terminal)
        self.replay_buffer.add(state, action, reward, next, terminal, error)
        if len(self.replay_buffer.buffer) > self.batch_size:
            self.update_model()


    def step(self,
             state: NDArray[float32],
             action: int,
             reward: float,
             next: NDArray[float32],
             terminal: bool
            ):
        if 'per' in self.modifications:
            self._step_per(state, action, reward, next, terminal)
        else:
            self._step_no_per(state, action, reward, next, terminal)


    def act(self,
            state: NDArray[float32],
            exploration_chance: float
           ) -> int:
        if random.random() > exploration_chance:
            self.qnet.eval()
            state_tensor_batched = torch.from_numpy(state).float().unsqueeze(0)
            state = state_tensor_batched.to(self.device)
            with torch.no_grad():
                action_values = self.qnet(state)
            chosen_action = np.argmax(action_values.cpu().detach().numpy())
            self.qnet.train()
            return chosen_action
        else:
            return random.choice(np.arange(self.n_actions))


    def update_model(self):
        if 'per' in self.modifications:
            states, actions, rewards, nexts, terminals, indices, weights = self.replay_buffer.sample(self.batch_size)
        else:
            states, actions, rewards, nexts, terminals = self.replay_buffer.sample(self.batch_size)

        states = torch.from_numpy(np.stack(states)).float().to(self.device)
        actions = torch.from_numpy(np.array(actions)).long().to(self.device)
        rewards = torch.from_numpy(np.array(rewards)).float().to(self.device)
        nexts = torch.from_numpy(np.stack(nexts)).float().to(self.device)
        terminals = torch.from_numpy(np.array(terminals)).float().to(self.device)

        q_values = self.qnet(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)

        if 'double' in self.modifications:
            next_q_values = self.target_qnet(nexts).gather(1, self.qnet(nexts).max(1)[1].unsqueeze(-1)).detach().squeeze(-1)
        else:
            next_q_values = self.target_qnet(nexts).max(1)[0].detach()

        expected_q_values = rewards + self.discount_factor * next_q_values * (1 - terminals)

        loss = MSELoss()(q_values, expected_q_values)
        if 'per' in self.modifications:
            loss = (loss * torch.from_numpy(weights).float().to(self.device)).mean()

        self.optimiser.zero_grad()

        loss.backward()

        self.optimiser.step()

        if 'per' in self.modifications:
            for i, (state,
                      action,
                      reward,
                      next,
                      terminal
                     ) in enumerate(zip(states.cpu().numpy(),
                                        actions.cpu().numpy(),
                                        rewards.cpu().numpy(),
                                        nexts.cpu().numpy(),
                                        terminals.cpu().numpy())
                                   ):
                error = self.get_td_error(state, action, reward, next, terminal)
                self.replay_buffer.update_priorities(i, error)
            self.replay_buffer.update_bias_factor()


    def update_target_network(self):
        self.target_qnet.load_state_dict(self.qnet.state_dict())

In [8]:
# training config params

N_EPISODES = 10_000
EXPLORATION_CHANCE_START = 1.0
EXPLORATION_CHANCE_END = 1e-4
EXPLORATION_CHANCE_DECAY = 0.995
TARGET_UPDATE_FREQ = 10
FINISH_CHECK_FREQ = 100
FINISH_SCORE = 250
PER_PARAMS = {
     'bias_factor_start': 0.4,
     'bias_factor_end': 1.0,
     'bias_increment': 0.01,
     'priority_scale': 0.6
}

In [13]:
import gym
import random

def train(modifications) -> DQNAgent:
    env = gym.make('LunarLander-v2')
    n_states = env.observation_space.shape[0]
    n_actions = env.action_space.n

    agent = DQNAgent(n_states,
                     n_actions,
                     N_HIDDEN_NODES,
                     LEARNING_RATE,
                     DISCOUNT_FACTOR,
                     MAX_BUFFER_SIZE,
                     BATCH_SIZE,
                     modifications=modifications,
                     per_params=PER_PARAMS if 'per' in modifications else None
                    )

    scores = []
    latest_scores = deque(maxlen=FINISH_CHECK_FREQ)
    
    exploration_chance = EXPLORATION_CHANCE_START

    for episode_n in range(1, N_EPISODES + 1):
        state, _ = env.reset()
        score = 0
    
        while True:
            action = agent.act(state, exploration_chance)
            next, reward, terminated, truncated, info = env.step(action)
            
            terminal = terminated or truncated
            agent.step(state, action, reward, next, terminal)
            
            state = next
            score += reward
    
            if terminal:
                break
    
        scores.append(score)
        latest_scores.append(score)
    
        exploration_chance = max(EXPLORATION_CHANCE_END, EXPLORATION_CHANCE_DECAY * exploration_chance)
    
        if episode_n % TARGET_UPDATE_FREQ == 0:
            agent.update_target_network()
    
        if episode_n % FINISH_CHECK_FREQ == 0:
            print(f'Average score of {np.mean(latest_scores):0.3f} @ {episode_n}/{N_EPISODES}')
            if np.mean(latest_scores) >= FINISH_SCORE:
                print(f'Average score was above {FINISH_SCORE} over last {FINISH_CHECK_FREQ} episodes. Ending training...')
                break
    env.close()
    return agent

agent = train(['double', 'duelling', 'per'])

RuntimeError: Subtraction, the `-` operator, with a bool tensor is not supported. If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.

In [None]:
def visual_run_agent(agent):
    env = gym.make('LunarLander-v2', render_mode='human')

    state, _ = env.reset()
    
    score = 0
    steps = 0
    
    while True:
        action = agent.act(state, 0)
        state, reward, terminated, truncated, info = env.step(action)
        steps += 1
    
        terminal = terminated or truncated
    
        score += reward
    
        if terminal:
            break
    
    print(f'Score achieved on test: {score}')
    print(f'Steps taken until termination: {steps}')
    env.close()

visual_run_agent(agent)