In [None]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

In [None]:
# Prefer the GPU when CUDA is usable; otherwise everything runs on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hyperparameters
lr    = 0.0002  # step size handed to the Adam optimizer
gamma = 0.98    # per-step discount applied when accumulating returns

In [None]:
class Policy(nn.Module):
  """REINFORCE-style policy network that owns its optimizer and episode buffer.

  The network maps a 4-feature state to a probability distribution over
  2 actions. During an episode the caller stores `[reward, prob_of_taken_action]`
  pairs with `put_data`; `train_net` then performs one policy-gradient update
  from the stored episode and clears the buffer.

  Args:
    learning_rate: Adam step size. When None, the module-level `lr` is used.
    discount: Per-step discount factor for returns. When None, the
      module-level `gamma` is read at training time.
  """

  def __init__(self, learning_rate=None, discount=None):
    super().__init__()
    # Buffer of [reward, action-probability tensor] pairs for the current episode.
    self.data = []
    # None means "look up the global `gamma` when train_net runs".
    self.discount = discount

    self.fc = nn.Sequential(
        nn.Linear(4, 128),
        nn.ReLU(),
        nn.Linear(128, 2),
        # Normalise over the action axis. For a single 1-D state this is the
        # same as dim=0, but it stays correct if a (batch, 4) input is passed.
        nn.Softmax(dim=-1),
    )

    step_size = lr if learning_rate is None else learning_rate
    self.optimizer = optim.Adam(self.parameters(), step_size)

  def forward(self, x):
    """Return action probabilities for state tensor `x` (last dim of size 4)."""
    return self.fc(x)

  def put_data(self, item):
    """Append one `[reward, prob]` pair, where `prob` is the probability tensor
    of the action that was taken (it must still be attached to the graph)."""
    self.data.append(item)

  def train_net(self):
    """Run one gradient step from the buffered episode, then empty the buffer.

    Walks the episode backwards so the discounted return R_t = r_t + g * R_{t+1}
    can be accumulated in a single pass; each step contributes a loss of
    -R_t * log(prob_t). Does nothing when no transitions are buffered.
    """
    if not self.data:
      # Nothing to learn from; skip the optimizer step entirely.
      return

    discount = gamma if self.discount is None else self.discount

    R = 0
    self.optimizer.zero_grad()
    for r, prob in reversed(self.data):
      R = r + discount * R
      loss = -R * torch.log(prob)
      # Each step has its own forward graph, so gradients simply accumulate.
      loss.backward()
    self.optimizer.step()
    self.data = []

In [None]:
def main(n_episodes=10000, print_interval=20):
  """Train a `Policy` on CartPole-v1 and periodically print the average score.

  Args:
    n_episodes: Number of episodes to run.
    print_interval: A report is printed whenever the episode index is a
      non-zero multiple of this value.

  The environment is always closed, even if training is interrupted.
  """
  env = gym.make('CartPole-v1')
  try:
    pi = Policy().to(DEVICE)

    score = 0.0
    # Episodes accumulated into `score` since the last report. The first
    # window spans episodes 0..print_interval inclusive, so dividing by
    # print_interval would overstate the first average.
    episodes_in_window = 0

    for n_epi in range(n_episodes):
      # Newer gym versions return (observation, info) from reset(); older
      # ones return just the observation.
      reset_out = env.reset()
      s = reset_out[0] if isinstance(reset_out, tuple) else reset_out
      done = False

      while not done:
        prob = pi(torch.from_numpy(s).float().to(DEVICE))
        a = Categorical(prob).sample()

        # Newer gym versions return a 5-tuple with separate terminated and
        # truncated flags; older ones return a 4-tuple with a single done flag.
        step_out = env.step(a.item())
        if len(step_out) == 5:
          s_prime, r, terminated, truncated, _ = step_out
          done = terminated or truncated
        else:
          s_prime, r, done, _ = step_out

        # Keep the probability of the sampled action for the policy-gradient loss.
        pi.put_data([r, prob[a]])
        s = s_prime
        score += r

      pi.train_net()
      episodes_in_window += 1

      if (n_epi % print_interval == 0) and (n_epi != 0):
        print(f"# of episode: {n_epi}, avg score: {score/episodes_in_window}")
        score = 0.0
        episodes_in_window = 0
  finally:
    env.close()

In [None]:
# Start training: main() runs 10000 CartPole-v1 episodes and prints the
# average score every 20 episodes.
main()