In [120]:
import gym
import numpy as np
import torch
from functools import namedtuple
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

In [132]:
# Discount factor used when accumulating returns G_t = r_t + gamma * G_{t+1}.
gamma = 0.8

# env.step() is unpacked into four values below (next state, reward, done flag,
# extra info); this namedtuple gives those positions readable names.
info = namedtuple('info', 'next_state reward done others')

env = gym.make('CartPole-v0')

In [133]:
class Network(nn.Module):
    """Two-layer MLP policy that maps a state vector to action probabilities.

    The defaults (4 inputs, 2 outputs) match the CartPole-v0 environment used in
    this notebook; pass other sizes to reuse the policy for different envs.

    Args:
        in_dim: length of the input (state) vector.
        hidden_dim: width of the single ReLU hidden layer.
        out_dim: number of discrete actions; ``forward`` returns a softmax over them.
    """

    def __init__(self, in_dim=4, hidden_dim=16, out_dim=2):
        super(Network, self).__init__()
        # Attribute names are kept as fc1/fc2 so state_dict keys stay stable.
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return action probabilities for ``x`` of shape ``(..., in_dim)``.

        The output has shape ``(..., out_dim)`` and sums to 1 along the last axis.
        """
        hidden = self.relu(self.fc1(x))
        return self.softmax(self.fc2(hidden))

In [None]:
# Policy network and its optimizer.
model = Network()
optim = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Return of one greedy evaluation episode, appended after every training episode.
test_r_ls = []

# REINFORCE with a mean-return baseline.
# Each iteration:
#   1. rolls out one episode, sampling actions from the current policy;
#   2. computes discounted returns and subtracts their mean (baseline);
#   3. takes one gradient step on -sum_t log pi(a_t|s_t) * A_t;
#   4. runs one greedy (argmax) evaluation episode and records its return.
N_EPISODES = 10000

for episode in tqdm(range(N_EPISODES)):
    state_ls = []
    action_ls = []
    reward_ls = []

    # --- 1. On-policy rollout -------------------------------------------------
    s = env.reset()
    while True:
        state = torch.as_tensor(s, dtype=torch.float32)
        # REINFORCE is on-policy: actions must be drawn from pi(a|s). Uniform
        # random actions (env.action_space.sample()) make the gradient estimate
        # meaningless for the policy being trained.
        with torch.no_grad():
            probs = model(state)
        action = torch.multinomial(probs, num_samples=1).item()

        state_ls.append(state)
        action_ls.append(action)
        message = info(*env.step(action))
        reward_ls.append(message.reward)

        if message.done:
            break
        s = message.next_state

    # --- 2. Discounted returns, computed backwards, then baseline -------------
    gt_ls = []
    g = 0.0
    for r in reversed(reward_ls):
        g = gamma * g + r
        gt_ls.append(g)
    gt_ls.reverse()

    # Subtracting the mean return reduces variance without biasing the gradient.
    g_mean = sum(gt_ls) / len(gt_ls)
    adv_ls = [g_t - g_mean for g_t in gt_ls]

    adv_ten = torch.tensor(adv_ls, dtype=torch.float32).unsqueeze(-1)   # (T, 1)
    state_ten = torch.stack(state_ls)                                   # (T, obs_dim)
    action_ten = torch.tensor(action_ls, dtype=torch.long).unsqueeze(-1)  # (T, 1)

    # --- 3. Policy-gradient step ----------------------------------------------
    action_probs = model(state_ten).gather(-1, action_ten)             # pi(a_t|s_t), (T, 1)
    # log(p), not log1p(p): the score-function gradient needs grad log pi.
    # The clamp only guards against float underflow producing log(0) = -inf.
    log_probs = torch.log(action_probs.clamp(min=1e-8))
    loss = -(log_probs * adv_ten).sum()

    optim.zero_grad()
    loss.backward()
    optim.step()

    # --- 4. Greedy evaluation episode -----------------------------------------
    test_r = 0
    with torch.no_grad():
        s = env.reset()
        while True:
            a = model(torch.as_tensor(s, dtype=torch.float32)).argmax().item()
            n_s, r_, done_, _ = env.step(a)
            test_r += r_
            if done_:
                break
            s = n_s

    test_r_ls.append(test_r)

# Greedy-evaluation return recorded after each training episode.
fig, ax = plt.subplots(figsize=(40, 8))
ax.plot(test_r_ls)
ax.set(
    title='Greedy evaluation return per training episode',
    xlabel='training episode',
    ylabel='evaluation return',
)
plt.show()

 83%|████████▎ | 8338/10000 [05:14<01:09, 23.85it/s]  