In [None]:
import random
from collections import deque

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# hyperparameters
lr            = 5e-4    # Adam learning rate
gamma         = 0.98    # discount factor in the TD target
buffer_limit  = 50000   # replay-buffer capacity (oldest transitions evicted first)
batch_size    = 32      # transitions per gradient step

# Seed Python's and PyTorch's RNGs so exploration, replay sampling and weight
# initialisation repeat across Restart & Run All. This is best-effort: the gym
# environment itself is not seeded here, and CUDA kernels may be nondeterministic.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
class ReplayBuffer():
  """Bounded FIFO store of transitions for experience replay.

  Each transition is a tuple ``(s, a, r, s_prime, done_mask)``. The backing
  ``deque(maxlen=...)`` drops the oldest transition once capacity is reached.
  """
  def __init__(self, capacity=None, device=None):
    """Create an empty buffer.

    Args:
      capacity: maximum number of stored transitions; ``None`` means the
        module-level ``buffer_limit``.
      device: torch device for the tensors returned by ``sample``; ``None``
        means the module-level ``DEVICE``.
    """
    self.buffer = deque(maxlen=buffer_limit if capacity is None else capacity)
    self.device = DEVICE if device is None else device

  def put(self, transition):
    """Append one ``(s, a, r, s_prime, done_mask)`` tuple (evicts the oldest when full)."""
    self.buffer.append(transition)

  def sample(self, n):
    """Draw ``n`` distinct transitions uniformly at random and batch them as tensors.

    Returns:
      ``(s, a, r, s_prime, done_mask)``: ``s``/``s_prime`` are float32 of shape
      ``(n, obs_dim)``; ``a`` is int64 of shape ``(n, 1)`` (ready for
      ``gather``); ``r`` and ``done_mask`` are float32 of shape ``(n, 1)``.

    Raises:
      ValueError: if ``n`` exceeds the number of stored transitions
        (propagated from ``random.sample``).
    """
    mini_batch = random.sample(self.buffer, n)

    # Stack observations into a single ndarray before handing them to torch:
    # torch.tensor() on a Python list of ndarrays converts element by element
    # and warns that this is extremely slow.
    s = np.array([t[0] for t in mini_batch], dtype=np.float32)
    s_prime = np.array([t[3] for t in mini_batch], dtype=np.float32)
    a = [[t[1]] for t in mini_batch]
    r = [[t[2]] for t in mini_batch]
    done_mask = [[t[4]] for t in mini_batch]

    # Explicit dtypes: actions must be int64 for gather(); rewards and masks
    # are pinned to float32 so a numpy float64 reward cannot silently promote
    # the TD target to float64 and mismatch the float32 Q-values.
    return (torch.as_tensor(s, device=self.device),
            torch.tensor(a, dtype=torch.long, device=self.device),
            torch.tensor(r, dtype=torch.float, device=self.device),
            torch.as_tensor(s_prime, device=self.device),
            torch.tensor(done_mask, dtype=torch.float, device=self.device))

  def size(self):
    """Return the number of transitions currently stored."""
    return len(self.buffer)

In [None]:
class Qnet(nn.Module):
  """Two-hidden-layer MLP mapping an observation to one Q-value per action.

  The defaults (4 inputs, 2 actions, 128 hidden units) reproduce the original
  hard-coded CartPole sizes, so existing ``state_dict``s still load.
  """
  def __init__(self, obs_dim=4, n_actions=2, hidden_dim=128):
    super(Qnet, self).__init__()
    self.n_actions = n_actions
    self.fc = nn.Sequential(
        nn.Linear(obs_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, n_actions)
    )

  def forward(self, x):
    """Return Q-values of shape ``(..., n_actions)`` for ``x`` of shape ``(..., obs_dim)``."""
    return self.fc(x)

  def sample_action(self, obs, epsilon):
    """Pick an epsilon-greedy action for a single observation.

    Args:
      obs: one observation tensor of shape ``(obs_dim,)``; a batch would make
        ``argmax()`` index the flattened Q-values.
      epsilon: probability of choosing a uniformly random action.

    Returns:
      The chosen action index as a Python ``int`` in ``[0, n_actions)``.
    """
    if random.random() < epsilon:
      return random.randint(0, self.n_actions - 1)
    # Acting is never back-propagated through, so skip building an autograd
    # graph; the greedy forward pass is only needed on the exploit branch.
    with torch.no_grad():
      return self.forward(obs).argmax().item()

In [None]:
def train(q, q_target, memory, optimizer, n_updates=10):
  """Run ``n_updates`` DQN gradient steps on minibatches drawn from ``memory``.

  Each step regresses Q(s, a) from the online network ``q`` towards the
  one-step TD target ``r + gamma * max_a' Q_target(s', a') * done_mask`` with a
  smooth-L1 (Huber) loss, then applies one ``optimizer`` step. Uses the
  module-level ``batch_size`` and ``gamma`` hyperparameters.

  Args:
    q: online Q-network being trained.
    q_target: target Q-network, used only to build the TD target.
    memory: replay buffer whose ``sample(n)`` returns
      ``(s, a, r, s_prime, done_mask)`` tensors.
    optimizer: optimizer expected to wrap ``q.parameters()`` (as ``main`` does).
    n_updates: number of minibatch updates per call (default 10, as before).
  """
  for _ in range(n_updates):
    s, a, r, s_prime, done_mask = memory.sample(batch_size)

    # Q-value of the action actually taken in each sampled transition.
    q_a = q(s).gather(1, a)

    # The TD target is a fixed regression label. Building it without autograd
    # stops backward() from computing (and endlessly accumulating) .grad on
    # q_target's parameters, which the optimizer never zeroes or uses.
    with torch.no_grad():
      max_q_prime = q_target(s_prime).max(dim=1, keepdim=True).values
      target = r + gamma * max_q_prime * done_mask

    loss = F.smooth_l1_loss(q_a, target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
def _reset_obs(env):
  """Reset ``env`` and return only the initial observation (old and new gym APIs)."""
  out = env.reset()
  # gym>=0.26 returns (obs, info); older gym returns the observation alone.
  if isinstance(out, tuple) and len(out) == 2 and isinstance(out[1], dict):
    return out[0]
  return out


def _step_env(env, action):
  """Step ``env`` and return ``(s_prime, r, terminated, truncated)`` for either gym API."""
  out = env.step(action)
  if len(out) == 5:
    # gym>=0.26: separate terminated / truncated flags.
    s_prime, r, terminated, truncated, _ = out
    return s_prime, r, terminated, truncated
  # Older gym: a single done flag; the TimeLimit wrapper reports a time-limit
  # cut-off through info["TimeLimit.truncated"].
  s_prime, r, done, info = out
  truncated = bool(info.get("TimeLimit.truncated", False))
  return s_prime, r, done and not truncated, truncated


def main(env_id='CartPole-v1', n_episodes=1000, print_interval=10, train_start=2000):
  """Train a DQN agent on a gym environment, printing progress periodically.

  Acting is epsilon-greedy with epsilon = max(1%, 8% - 1% per 200 episodes).
  After every episode, once the buffer holds more than ``train_start``
  transitions, ``train`` is called. Every ``print_interval`` episodes the
  target network is synced to the online network and the mean episode score
  of the window is printed. The environment is closed once, at the end, even
  if training raises.

  Args:
    env_id: gym environment id (default matches the original CartPole-v1).
    n_episodes: number of episodes to run.
    print_interval: episodes between target-network syncs / progress lines.
    train_start: buffer size that must be exceeded before training starts.
  """
  env = gym.make(env_id)
  try:
    q = Qnet().to(DEVICE)
    q_target = Qnet().to(DEVICE)
    q_target.load_state_dict(q.state_dict())
    memory = ReplayBuffer()
    optimizer = optim.Adam(q.parameters(), lr)

    score = 0.0
    episodes_in_window = 0  # the first window spans episodes 0..print_interval

    for n_epi in range(n_episodes):
      epsilon = max(0.01, 0.08 - 0.01*(n_epi/200))
      s = _reset_obs(env)
      done = False

      while not done:
        with torch.no_grad():  # acting never needs an autograd graph
          a = q.sample_action(torch.from_numpy(s).float().to(DEVICE), epsilon)
        s_prime, r, terminated, truncated = _step_env(env, a)
        done = terminated or truncated
        # Only a genuine terminal state stops bootstrapping; a time-limit
        # cut-off is not terminal, so its next state is still valued.
        done_mask = 0.0 if terminated else 1.0
        memory.put((s, a, r/100.0, s_prime, done_mask))
        s = s_prime
        score += r
      episodes_in_window += 1

      if memory.size() > train_start:
        train(q, q_target, memory, optimizer)

      if (n_epi % print_interval == 0) and (n_epi != 0):
        q_target.load_state_dict(q.state_dict())
        print(f"n_epi: {n_epi}, score: {score/episodes_in_window:.1f}, n_buffer: {memory.size()}, eps: {epsilon*100:.1f}%")
        score = 0.0
        episodes_in_window = 0
  finally:
    # Previously closed inside the episode loop, i.e. after every episode.
    env.close()

In [None]:
# Launch training; progress lines are printed every print_interval episodes.
main();