In [10]:
import numpy as np
from collections import deque


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

import gym

In [11]:
env = gym.make("CartPole-v1")
s_size = env.observation_space.shape[0]
a_size = env.action_space.n

class Policy(nn.Module):
    def __init__(self, s_size, a_size, h_size):
        super(Policy, self).__init__()
        self.fc1 = nn.Linear(s_size, h_size)
        self.fc2 = nn.Linear(h_size, a_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.softmax(x, dim=1)
    
    def act(self, state):
        state = torch.from_numpy(state).float().unsqueeze(0)
        probs = self.forward(state)
        m = Categorical(probs)
        action = m.sample()
        return action.item(), m.log_prob(action)

In [22]:
def reinforce(policy, optimizer, n_training_episodes, max_t, gamma, print_every):
    # Help us to calculate the score during the training
    scores_deque = deque(maxlen=100)
    scores = []
    # Line 3 of pseudocode
    for i_episode in range(1, n_training_episodes+1):
        saved_log_probs = []
        rewards = []
        state = env.reset()
        # Line 4 of pseudocode
        for t in range(max_t):
            action, log_prob = policy.act(state)
            saved_log_probs.append(log_prob)
            state, reward, done, _ = env.step(action)
            rewards.append(reward)
            if done:
                break 
        scores_deque.append(sum(rewards))
        scores.append(sum(rewards))
        
        # Line 6 of pseudocode: calculate the return
        ## Here, we calculate discounts for instance [0.99^1, 0.99^2, 0.99^3, ..., 0.99^len(rewards)]
        discounts = []
        for k in range(len(rewards)+1):
            discounts.append(gamma**k)
        ## We calculate the return by sum(gamma[t] * reward[t]) 
        rt = []
        for g,r in zip(discounts,rewards):
            rt.append(g*r)
        R = sum(rt)
        
        # Line 7:
        policy_loss = []
        for log_prob in saved_log_probs:
            policy_loss.append(log_prob)
        policy_loss = torch.cat(policy_loss).sum()
        
        # Line 8:
        optimizer.zero_grad()
        policy_loss.backward()
        optimizer.step()
        
        if i_episode % print_every == 0:
            print('Episode {}\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_deque)))
        
    return policy

In [23]:
params = {
    "h_size": 16,
    "n_training_episodes": 3000,
    "n_evaluation_episodes": 10,
    "max_t": 1000,
    "gamma": 1,
    "lr": 1e-3,
    # "env_id": env_id,
    "state_space": s_size,
    "action_space": a_size,
}

policy_net = Policy(params['state_space'],params['action_space'],params['h_size'])
optimizer = optim.Adam(policy_net.parameters(), lr=params['lr'])


In [24]:
policy = reinforce(
    policy_net, 
    optimizer, 
    params['n_training_episodes'],
    params['max_t'],
    params['gamma'],
    100)

Episode 100	Average Score: 17.20
Episode 200	Average Score: 19.32
Episode 300	Average Score: 17.82
Episode 400	Average Score: 17.13
Episode 500	Average Score: 16.96
Episode 600	Average Score: 15.35
Episode 700	Average Score: 16.99
Episode 800	Average Score: 15.82
Episode 900	Average Score: 15.88
Episode 1000	Average Score: 15.49
Episode 1100	Average Score: 15.40
Episode 1200	Average Score: 15.86
Episode 1300	Average Score: 14.61
Episode 1400	Average Score: 16.07
Episode 1500	Average Score: 16.02
Episode 1600	Average Score: 15.90
Episode 1700	Average Score: 16.11
Episode 1800	Average Score: 16.71
Episode 1900	Average Score: 17.09
Episode 2000	Average Score: 16.36
Episode 2100	Average Score: 15.40
Episode 2200	Average Score: 17.03
Episode 2300	Average Score: 17.77
Episode 2400	Average Score: 16.96
Episode 2500	Average Score: 18.00
Episode 2600	Average Score: 21.69
Episode 2700	Average Score: 20.96
Episode 2800	Average Score: 19.99
Episode 2900	Average Score: 18.86
Episode 3000	Average Sc

In [26]:
total_rewards = 0
state = env.reset()
done = False
while not done:
    action,_ = policy.act(state)    
    new_state,reward,done,_ = env.step(action)
    total_rewards += reward
    state = new_state
    env.render()
    
print(total_rewards)
env.close()

