In [None]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- hyperparameters ---
lr = 2e-4        # Adam learning rate for the shared actor-critic network
gamma = 0.98     # discount applied to V(s') in the one-step TD target
n_rollout = 10   # max environment steps buffered between gradient updates

In [None]:
class AdvantageActorCritic(nn.Module):
  """One-step advantage actor-critic with a shared hidden layer.

  Layout (from the layer sizes): 4 inputs -> 256 ReLU units (``fc1``), then a
  2-way policy head (``fc_pi``) and a scalar value head (``fc_v``).
  Transitions are buffered with ``put_data`` and consumed by ``train_net``,
  which performs a single Adam step per call and clears the buffer.
  """

  def __init__(self):
    super().__init__()
    self.data = []  # buffered [s, a, r, s_prime, done] transitions

    self.fc1 = nn.Linear(4, 256)
    self.fc_pi = nn.Linear(256, 2)
    self.fc_v = nn.Linear(256, 1)
    self.optimizer = optim.Adam(self.parameters(), lr)

  def _hidden(self, x):
    """Shared ReLU feature layer used by both the policy and value heads."""
    return F.relu(self.fc1(x))

  def pi(self, x, softmax_dim=0):
    """Return action probabilities, normalised over ``softmax_dim``.

    Use ``softmax_dim=0`` for a single unbatched state and ``softmax_dim=1``
    for a batch of states (as ``train_net`` does).
    """
    return F.softmax(self.fc_pi(self._hidden(x)), dim=softmax_dim)

  def v(self, x):
    """Return the state-value estimate (one value per input row)."""
    return self.fc_v(self._hidden(x))

  def put_data(self, transition):
    """Append one ``[s, a, r, s_prime, done]`` transition to the buffer."""
    self.data.append(transition)

  @staticmethod
  def _stack_states(rows):
    """Convert a list of observations into one float tensor on DEVICE."""
    if not rows:
      return torch.empty(0, device=DEVICE, dtype=torch.float)
    # torch.tensor() on a list of ndarrays is very slow (PyTorch warns about
    # it); converting each row and stacking avoids that path.
    return torch.stack([torch.as_tensor(row, dtype=torch.float) for row in rows]).to(DEVICE)

  def make_batch(self):
    """Turn the buffered transitions into batched tensors and clear the buffer.

    Returns:
      (s, a, r, s_prime, done_mask). ``a``, ``r`` and ``done_mask`` have one
      column per transition row. Rewards are scaled by 1/100.
      ``done_mask`` is 0.0 for terminal transitions and 1.0 otherwise, so it
      can multiply the bootstrap term directly.
    """
    states, actions, rewards, next_states, masks = [], [], [], [], []
    for s, a, r, s_prime, done in self.data:
      states.append(s)
      actions.append([a])
      rewards.append([r / 100.0])  # reward scaling kept from the original recipe
      next_states.append(s_prime)
      masks.append([0.0 if done else 1.0])

    s_batch = self._stack_states(states)
    a_batch = torch.tensor(actions, device=DEVICE)
    r_batch = torch.tensor(rewards, device=DEVICE, dtype=torch.float)
    s_prime_batch = self._stack_states(next_states)
    done_batch = torch.tensor(masks, device=DEVICE, dtype=torch.float)

    self.data = []
    return s_batch, a_batch, r_batch, s_prime_batch, done_batch

  def train_net(self):
    """Run one actor-critic update on the buffered transitions.

    Loss = -log pi(a|s) * delta  +  SmoothL1(V(s), td_target), where
    td_target = r + gamma * V(s') * done_mask and delta = td_target - V(s).
    Does nothing when the buffer is empty.
    """
    if not self.data:
      return  # an empty batch would make the linear layers fail
    s, a, r, s_prime, done = self.make_batch()

    # The bootstrap target is never differentiated through, so skip the graph.
    with torch.no_grad():
      td_target = r + gamma * self.v(s_prime) * done

    # One shared forward pass feeds both heads; V(s) was computed twice before.
    hidden = self._hidden(s)
    v_s = self.fc_v(hidden)
    delta = td_target - v_s

    # log_softmax is the numerically stable form of log(softmax(.)): it cannot
    # produce log(0) = -inf when a probability underflows.
    log_pi_a = F.log_softmax(self.fc_pi(hidden), dim=1).gather(1, a)
    actor_loss = -log_pi_a * delta.detach()
    critic_loss = F.smooth_l1_loss(v_s, td_target)
    loss = actor_loss + critic_loss

    self.optimizer.zero_grad()
    loss.mean().backward()
    self.optimizer.step()

In [None]:
def _reset(env):
  """Reset ``env`` and return only the observation (old or new gym API)."""
  out = env.reset()
  # gym >= 0.26 returns (obs, info); older versions return obs alone.
  return out[0] if isinstance(out, tuple) else out


def _step(env, action):
  """Step ``env`` and return (s_prime, r, episode_over, terminal).

  ``episode_over`` ends the episode loop. ``terminal`` is True only for a real
  terminal state, not a time-limit truncation, so truncated steps still
  bootstrap from V(s').
  """
  out = env.step(action)
  if len(out) == 5:  # gym >= 0.26: (obs, reward, terminated, truncated, info)
    s_prime, r, terminated, truncated, _ = out
    return s_prime, r, terminated or truncated, terminated
  s_prime, r, done, info = out  # older gym: (obs, reward, done, info)
  truncated = bool(info.get("TimeLimit.truncated", False))
  return s_prime, r, done, done and not truncated


def main(n_episodes=10000, print_interval=20):
  """Train AdvantageActorCritic on CartPole-v1 using n-step rollouts.

  Each episode is processed in chunks of at most ``n_rollout`` steps. After
  every chunk the buffered transitions drive one ``train_net`` update. Every
  ``print_interval`` episodes, the mean score of the episodes since the last
  report is printed.

  Args:
    n_episodes: number of episodes to run (previously hard-coded to 10000).
    print_interval: episodes between progress reports (previously 20).
  """
  env = gym.make('CartPole-v1')
  model = AdvantageActorCritic().to(DEVICE)

  score = 0.0
  # Count episodes explicitly: the first window spans episodes 0..print_interval,
  # i.e. print_interval + 1 episodes, so dividing by print_interval was wrong.
  episodes_in_window = 0

  try:
    for n_epi in range(n_episodes):
      s = _reset(env)
      done = False

      while not done:
        for _ in range(n_rollout):
          # Acting needs no autograd graph; train_net recomputes pi with grads.
          with torch.no_grad():
            prob = model.pi(torch.as_tensor(s, dtype=torch.float, device=DEVICE))
          a = Categorical(prob).sample().item()

          s_prime, r, done, terminal = _step(env, a)
          # Store the true-terminal flag so a time-limit cut still bootstraps.
          model.put_data([s, a, r, s_prime, terminal])

          s = s_prime
          score += r
          if done:
            break

        model.train_net()

      episodes_in_window += 1
      if n_epi % print_interval == 0 and n_epi != 0:
        print(f"# of episode :{n_epi}, avg score : {score/episodes_in_window:.1f}")
        score = 0.0
        episodes_in_window = 0
  finally:
    # Release the environment even if training is interrupted mid-run.
    env.close()

In [None]:
# Launch the training loop defined in `main`: it builds the CartPole-v1 env and
# the actor-critic model, trains for many episodes, and periodically prints the
# average score. Long-running; interrupt the kernel to stop early.
main()