In [5]:
import collections
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Exploration Policy

In [2]:
class ExplorationPolicy:
    """Epsilon-greedy exploration schedule with linear decay.

    Every call to :meth:`explore` draws one number from ``random.random()``,
    compares it against the *current* epsilon, and then lowers epsilon by
    ``epsilon_dec`` without letting it fall below ``epsilon_min``.
    """

    def __init__(self, epsilon=1.0, epsilon_min=0.01, epsilon_dec=0.00001):
        self.epsilon, self.epsilon_min, self.epsilon_dec = epsilon, epsilon_min, epsilon_dec

    def explore(self):
        """Return True when the caller should take a random action this step."""
        draw = random.random()
        take_random = draw < self.epsilon
        decayed = self.epsilon - self.epsilon_dec
        # Clamp at the floor (same tie-breaking as max(decayed, floor)).
        self.epsilon = self.epsilon_min if self.epsilon_min > decayed else decayed
        return take_random

# Replay Buffer

In [None]:
class ReplayBuffer:
    """Fixed-capacity FIFO store of ``(state, action, reward, state_, done)`` transitions.

    Backed by ``collections.deque(maxlen=max_size)``, so once the buffer is
    full each new transition evicts the oldest one.
    """

    def __init__(self, max_size, batch_size, input_shape):
        self.memory = collections.deque(maxlen=max_size)
        self.batch_size = batch_size
        # Not used by this class itself; stored for callers that inspect it.
        self.input_shape = input_shape

    def save(self, state, action, reward, state_, done):
        """Append a single transition."""
        self.memory.append((state, action, reward, state_, done))

    def sample(self):
        """Draw a batch of distinct transitions uniformly at random.

        Returns:
            Five NumPy arrays ``(state, action, reward, state_, done)`` with
            dtypes float32, int32, float32, float32 and bool. The batch holds
            ``min(batch_size, len(memory))`` transitions.

        Raises:
            ValueError: if the buffer is empty.
        """
        if not self.memory:
            raise ValueError("cannot sample from an empty ReplayBuffer")
        sample_size = min(self.batch_size, len(self.memory))
        batch = random.sample(self.memory, sample_size)
        state, action, reward, state_, done = zip(*batch)
        return (
            np.array(state, dtype=np.float32),
            np.array(action, dtype=np.int32),
            np.array(reward, dtype=np.float32),
            np.array(state_, dtype=np.float32),
            # np.bool was removed in NumPy 1.24; np.bool_ is the boolean scalar type.
            np.array(done, dtype=np.bool_),
        )

# Agent

In [None]:
class Agent:
    """DQN agent: epsilon-greedy acting, experience replay and a target network.

    ``q_eval`` is the network being trained. ``q_next`` is the target network:
    it starts as an exact copy of ``q_eval`` and is re-synced every
    ``replace_rate`` calls to :meth:`learn`.

    All constructor arguments are keyword-only. Required parameters follow
    defaulted ones here, which Python only allows for keyword-only arguments.

    Args:
        gamma: discount factor applied to the bootstrapped next-state value.
        lr: learning rate forwarded to the model constructor.
        input_shape: observation shape. It is forwarded unchanged to the
            buffer; ``int(np.prod(input_shape))`` is used as the model's input
            width, so both ``4`` and ``(4,)`` work.
        n_actions: number of discrete actions.
        epsilon, epsilon_min, epsilon_dec: forwarded positionally to
            ``exploration_policy``.
        mem_size, batch_size: forwarded (with ``input_shape``) to ``buffer``.
        model: callable ``model(n_inputs, n_actions, lr)`` returning an
            ``nn.Module`` that exposes ``optimizer``, ``loss`` and ``device``.
            Defaults to ``DQN``, looked up at construction time.
        buffer: replay-buffer class; defaults to ``ReplayBuffer``.
        exploration_policy: exploration class; defaults to ``ExplorationPolicy``.
        replace_rate: number of learn calls between target-network syncs.
    """

    def __init__(
        self,
        *,
        gamma=0.99, lr=0.001,
        input_shape, n_actions,
        epsilon, epsilon_min, epsilon_dec,
        mem_size, batch_size,
        model=None, buffer=None, exploration_policy=None,
        replace_rate,
    ):
        # Resolve class defaults lazily. A default of ``DQN`` in the signature
        # would be evaluated when this cell runs, before the DQN cell exists.
        if model is None:
            model = DQN
        if buffer is None:
            buffer = ReplayBuffer
        if exploration_policy is None:
            exploration_policy = ExplorationPolicy

        self.gamma = gamma
        self.lr = lr

        self.n_actions = n_actions
        self.input_shape = input_shape
        self.action_space = list(range(n_actions))

        self.replace_rate = replace_rate
        self.learn_step_cnt = 0

        self.exploration_policy = exploration_policy(epsilon, epsilon_min, epsilon_dec)
        self.buffer = buffer(mem_size, batch_size, input_shape)

        # nn.Linear needs an integer width; assumes flat observations.
        n_inputs = int(np.prod(input_shape))
        self.q_eval = model(n_inputs, n_actions, lr)
        self.q_next = model(n_inputs, n_actions, lr)
        # Start the target network in sync with the online network.
        self.q_next.load_state_dict(self.q_eval.state_dict())

    def act(self, state):
        """Choose an action for a single ``state`` using the exploration policy.

        Returns:
            A Python ``int`` in ``[0, n_actions)`` on both the random and the
            greedy branch.
        """
        if self.exploration_policy.explore():
            return int(np.random.choice(self.action_space))
        # Add a leading batch dimension and place the input on the model's device.
        state_t = torch.as_tensor(
            np.asarray(state, dtype=np.float32), device=self.q_eval.device
        ).unsqueeze(0)
        # Acting never backpropagates, so skip building an autograd graph.
        with torch.no_grad():
            return self.q_eval(state_t).argmax(dim=1).item()

    def save(self, state, action, reward, state_, done):
        """Store one transition in the replay buffer."""
        self.buffer.save(state, action, reward, state_, done)

    def sample(self):
        """Draw a batch from the buffer as tensors on ``q_eval``'s device.

        Actions come back as ``torch.long`` so they can index Q-value rows.
        Done flags come back as ``torch.bool`` so they can be used as a mask.
        """
        state, action, reward, state_, done = self.buffer.sample()
        device = self.q_eval.device
        return (
            torch.as_tensor(state, dtype=torch.float32, device=device),
            torch.as_tensor(action, dtype=torch.long, device=device),
            torch.as_tensor(reward, dtype=torch.float32, device=device),
            torch.as_tensor(state_, dtype=torch.float32, device=device),
            torch.as_tensor(done, dtype=torch.bool, device=device),
        )

    def replace_target_network(self):
        """Count one learn step and copy ``q_eval`` into ``q_next`` every ``replace_rate`` steps."""
        self.learn_step_cnt += 1
        if self.learn_step_cnt >= self.replace_rate:
            self.learn_step_cnt = 0
            self.q_next.load_state_dict(self.q_eval.state_dict())

    def learn(self):
        """Run one gradient step of Q-learning on a sampled batch.

        The target is ``reward + gamma * max_a' Q_next(s', a')``, with the
        bootstrap term zeroed for terminal transitions.

        Returns:
            The scalar loss value as a Python float.
        """
        self.q_eval.optimizer.zero_grad()

        self.replace_target_network()

        states, actions, rewards, states_, dones = self.sample()

        # Select Q(s, a) per row. ``[:, actions]`` would build a (batch, batch) matrix.
        batch_index = torch.arange(actions.shape[0], device=actions.device)
        q_pred = self.q_eval(states)[batch_index, actions]

        # The target is a fixed regression label: no gradients into q_next.
        with torch.no_grad():
            q_next = self.q_next(states_).max(dim=1)[0]
            q_next[dones] = 0.0  # no bootstrapped value after a terminal state
            q_target = rewards + self.gamma * q_next

        # nn.MSELoss expects (input, target).
        loss = self.q_eval.loss(q_pred, q_target)
        loss.backward()

        self.q_eval.optimizer.step()
        return loss.item()

In [None]:
class DQN(nn.Module):
    """Fully connected Q-network: ``inputs -> 128 -> 256 -> 256 -> outputs``.

    Bundles its own optimizer (RMSprop) and loss (MSE). It moves itself to
    ``cuda:0`` when CUDA is available, otherwise to the CPU, and exposes the
    chosen device as ``self.device``.

    Args:
        inputs: width of the input feature vector (passed to ``nn.Linear``).
        outputs: number of output values (one Q-value per action).
        lr: RMSprop learning rate.
    """

    def __init__(self, inputs, outputs, lr):
        super().__init__()

        self.fc1 = nn.Linear(inputs, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, 256)
        self.fc4 = nn.Linear(256, outputs)

        # Fully qualified so this cell does not depend on a separate
        # ``import torch.optim as optim``.
        self.optimizer = torch.optim.RMSprop(self.parameters(), lr=lr)
        self.loss = nn.MSELoss()

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)

    def forward(self, x):
        """Map a batch ``x`` of shape ``(..., inputs)`` to Q-values of shape ``(..., outputs)``."""
        out = F.relu(self.fc1(x))
        out = F.relu(self.fc2(out))
        out = F.relu(self.fc3(out))
        return self.fc4(out)

# Data