In [21]:
import numpy as np
import torch
import torch.nn as nn
import gym
from copy import deepcopy

In [2]:
class ReplayBuffer:
    """Fixed-capacity ring buffer of (state, action, reward, done) records.

    Each call to ``store`` appends one time step. The successor state of a
    record is the state stored immediately after it, so ``sample_batch`` never
    returns the newest record (its successor has not been stored yet) and it
    wraps indices around the ring once old records start being overwritten.

    Args:
        state_dim: Length of each stored state vector.
        max_size: Maximum number of records kept; the oldest are overwritten.
        batch_size: Number of transitions returned by ``sample_batch``.
        action_dim: Number of scalars used to store one action (1 for a
            discrete action index).
    """

    def __init__(self, state_dim, max_size, batch_size, action_dim=1):
        self.max_size = max_size
        self.state_dim = state_dim
        self.action_dim = action_dim
        # Row layout of ``other``: [action (action_dim values), reward, done].
        self.other_dim = action_dim + 1 + 1
        self.batch_size = batch_size

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.state = np.empty((self.max_size, self.state_dim), dtype=np.float32)
        self.other = np.empty((self.max_size, self.other_dim), dtype=np.float32)
        self.size = 0
        self.current_index = 0

    def store(self, state, action, reward, done):
        """Append one time step, overwriting the oldest record when full."""
        self.state[self.current_index] = state
        # hstack accepts Python scalars, 0-d arrays and 1-d action vectors alike.
        self.other[self.current_index] = np.hstack((action, reward, done))
        self.current_index = (self.current_index + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample_batch(self):
        """Sample ``batch_size`` transitions uniformly, with replacement.

        Returns:
            Tuple ``(state, next_state, action, reward, done)`` of tensors on
            ``self.device``: states are float32 ``(batch, state_dim)``, action
            is int64 ``(batch, action_dim)``, reward and done are float32
            ``(batch, 1)``.

        Raises:
            ValueError: if fewer than two steps are stored, since then no
                record has a stored successor state.
        """
        if self.size < 2:
            raise ValueError("need at least two stored steps to sample a transition")
        # Slot of the oldest record: 0 until the ring wraps, then current_index.
        oldest = (self.current_index - self.size) % self.max_size
        # Offsets 0 .. size-2 skip the newest record, whose successor is unknown.
        offsets = np.random.randint(0, self.size - 1, size=self.batch_size)
        ptr = (oldest + offsets) % self.max_size
        next_ptr = (ptr + 1) % self.max_size

        a = self.action_dim
        return (torch.as_tensor(self.state[ptr], dtype=torch.float32, device=self.device),
                torch.as_tensor(self.state[next_ptr], dtype=torch.float32, device=self.device),
                torch.as_tensor(self.other[ptr, :a], dtype=torch.long, device=self.device),
                torch.as_tensor(self.other[ptr, a:a + 1], dtype=torch.float32, device=self.device),
                torch.as_tensor(self.other[ptr, a + 1:a + 2], dtype=torch.float32, device=self.device))

    def __len__(self):
        return self.size

In [5]:
class QNetwork(nn.Module):
    """Fully connected Q-network: three ReLU hidden layers of width ``mid_dim``
    followed by a linear head with ``action_dim`` outputs (one value per action).

    Args:
        state_dim: Size of the input state vector.
        mid_dim: Width of every hidden layer.
        action_dim: Number of output values.
    """

    def __init__(self, state_dim, mid_dim, action_dim):
        super().__init__()
        hidden_widths = [state_dim, mid_dim, mid_dim, mid_dim]
        layers = []
        # Each hidden stage is Linear followed by ReLU; built in the same order
        # (and hence the same Sequential indices) as a hand-written list.
        for fan_in, fan_out in zip(hidden_widths[:-1], hidden_widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(mid_dim, action_dim))
        self.network = nn.Sequential(*layers)

    def forward(self, state):
        """Map a batch (or a single vector) of states to per-action values."""
        q_values = self.network(state)
        return q_values

In [None]:
class Agent(object):
    """Deep Q-learning agent with experience replay and a soft-updated target network.

    Args:
        env_name: Gym environment id passed to ``gym.make``. The state size is
            read from ``observation_space.shape[0]`` and the number of actions
            from ``action_space.n``.
        mid_dim: Hidden-layer width of the Q-network.
    """

    def __init__(self, env_name, mid_dim=256):
        self.learning_rate = 1e-4
        self.explore_rate = 0.1
        self.gamma = 0.99
        self.target_update = 300  # NOTE(review): unused; the target net is soft-updated every step
        self.soft_update_tau = 5e-3
        self.update_step = 300  # gradient steps per update() call
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.env = gym.make(env_name)
        # Read sizes from this agent's own env rather than a notebook global.
        self.state_dim = self.env.observation_space.shape[0]
        self.mid_dim = mid_dim
        self.action_dim = self.env.action_space.n
        self.max_size = 100000
        self.batch_size = 256

        self.network = QNetwork(self.state_dim, self.mid_dim, self.action_dim).to(self.device)
        self.target_network = deepcopy(self.network)
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=self.learning_rate)
        self.criterion = torch.nn.MSELoss()

        # The buffer stores one scalar per (discrete) action, so its own
        # action_dim keeps its default; keywords avoid positional mix-ups.
        self.replay_buffer = ReplayBuffer(state_dim=self.state_dim,
                                          max_size=self.max_size,
                                          batch_size=self.batch_size)

    def epsilon(self, episode, min_epsilon=0.05, max_epsilon=1.0, epsilon_decay=800):
        """Exploration probability for ``episode``.

        Decays exponentially from ``max_epsilon`` (episode 0) toward
        ``min_epsilon`` with time constant ``epsilon_decay`` episodes.
        """
        return float(min_epsilon + (max_epsilon - min_epsilon) * np.exp(-episode / epsilon_decay))

    def select_action(self, episode, state):
        """Epsilon-greedy action selection.

        With probability ``1 - epsilon(episode)`` returns the index of the
        largest Q-network output for ``state`` (as a Python int); otherwise
        returns ``self.env.action_space.sample()``.
        """
        if np.random.random_sample() > self.epsilon(episode):
            with torch.no_grad():  # inference only; no autograd graph needed
                q_values = self.network(torch.as_tensor(state, dtype=torch.float32, device=self.device))
            return q_values.argmax().item()
        return self.env.action_space.sample()

    def update(self):
        """Run ``update_step`` DQN gradient steps on minibatches from the replay buffer.

        After every step the target network is soft-updated toward the online
        network with rate ``soft_update_tau``.

        Returns:
            The loss of the last step as a float (NaN if ``update_step`` is 0).
        """
        loss = None
        for _ in range(self.update_step):
            with torch.no_grad():
                state, next_state, action, reward, done = self.replay_buffer.sample_batch()
                next_Q = self.target_network(next_state).max(dim=1, keepdim=True)[0]
                # Terminal transitions (done == 1) must not bootstrap from next_state.
                target = reward + (1 - done) * self.gamma * next_Q

            current_Q = self.network(state).gather(1, action)
            loss = self.criterion(current_Q, target)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            self.soft_update(self.target_network, self.network, self.soft_update_tau)

        return float("nan") if loss is None else loss.item()

    @staticmethod
    def soft_update(target_net, current_net, tau):
        """In-place Polyak averaging: target <- tau * current + (1 - tau) * target (parameters only)."""
        with torch.no_grad():
            for tar, cur in zip(target_net.parameters(), current_net.parameters()):
                tar.mul_(1 - tau).add_(cur, alpha=tau)

    def load_model(self):
        """Placeholder: not implemented yet (currently a no-op)."""
        pass

    def save_model(self):
        """Placeholder: not implemented yet (currently a no-op; nothing is saved)."""
        pass
    

In [18]:
env_name = "CartPole-v0"
env = gym.make(env_name)
state_dim = env.observation_space.shape[0]
# Number of discrete actions = width of the Q-network's output layer.
n_actions = env.action_space.n
# A discrete action is stored as one scalar in the replay buffer, independent
# of n_actions (the old `n - 1` only coincided with 1 for 2-action envs).
action_dim = 1
print(state_dim, action_dim)

4 1
