In [2]:
from torch.distributions import Categorical
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [3]:
# Hyperparameters: discount factor used by train() when computing returns,
# and the number of training episodes run by main().
gamma, episodes = 0.99, 500

In [4]:
class Pi(nn.Module):
    """Stochastic policy network for REINFORCE.

    A one-hidden-layer MLP (``in_dim -> hidden_dim -> out_dim`` with ReLU)
    whose outputs are used as the logits of a ``Categorical`` distribution
    over discrete actions. The module also holds the per-episode memory:
    ``act`` appends each sampled action's log-probability to
    ``self.log_probs``, the training loop appends rewards to ``self.rewards``,
    and ``onpolicy_reset`` clears both.

    Args:
        in_dim: size of the observation vector.
        out_dim: number of discrete actions.
        hidden_dim: width of the hidden layer (defaults to the original 64).
    """

    def __init__(self, in_dim, out_dim, hidden_dim=64):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )
        self.onpolicy_reset()
        self.train()  # nn.Module.train(): put the module in training mode

    def onpolicy_reset(self):
        """Clear the per-episode buffers; on-policy data is used for one update only."""
        self.log_probs = []
        self.rewards = []

    def forward(self, x):
        """Return the unnormalized action logits for the observation tensor ``x``."""
        return self.model(x)

    def act(self, state):
        """Sample an action for ``state`` and record its log-probability.

        Args:
            state: observation as a NumPy array or any sequence of numbers;
                it is copied into a float32 tensor.

        Returns:
            The sampled action index as a Python int.
        """
        # np.array (not asarray) always copies, matching the original astype().
        x = torch.as_tensor(np.array(state, dtype=np.float32))
        pdparam = self.forward(x)
        pd = Categorical(logits=pdparam)
        action = pd.sample()
        # Keep the graph-connected log pi(a|s) so a later backward() reaches the weights.
        self.log_probs.append(pd.log_prob(action))
        return action.item()

In [5]:
def train(pi, optimizer, discount=None):
    """Run one REINFORCE policy-gradient update from a finished episode.

    Computes the discounted return ``G_t = r_t + discount * G_{t+1}`` for each
    step recorded in ``pi.rewards``. It then forms the surrogate loss
    ``-sum_t log pi(a_t|s_t) * G_t`` from ``pi.log_probs`` and takes one
    optimizer step. Minimizing this loss performs gradient ascent on the
    expected return.

    Args:
        pi: policy holding the episode's ``rewards`` (numbers) and
            ``log_probs`` (scalar tensors attached to the autograd graph),
            one entry per step.
        optimizer: torch optimizer over the policy's parameters.
        discount: discount factor; defaults to the notebook-level ``gamma``.

    Returns:
        The scalar loss tensor, detached from the autograd graph.

    Raises:
        ValueError: if ``pi.rewards`` and ``pi.log_probs`` differ in length
            (otherwise a single log-prob would silently broadcast).
    """
    if discount is None:
        discount = gamma  # notebook-level hyperparameter
    T = len(pi.rewards)
    if len(pi.log_probs) != T:
        raise ValueError(
            f'got {T} rewards but {len(pi.log_probs)} log-probs; '
            'each step must record exactly one of each')
    rets = np.empty(T, dtype=np.float32)
    future_ret = 0.0
    # Walk backwards so each return reuses the one after it: O(T) overall.
    for t in reversed(range(T)):
        future_ret = pi.rewards[t] + discount * future_ret
        rets[t] = future_ret
    rets = torch.from_numpy(rets)
    log_probs = torch.stack(pi.log_probs)
    loss = torch.sum(-log_probs * rets)  # negated: optimizers minimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Detach so the caller's reference does not keep the episode graph alive.
    return loss.detach()

In [6]:
def main(render=True):
    """Train a REINFORCE policy on CartPole-v0 and print one line per episode.

    Uses the notebook-level ``episodes`` count, the ``Pi`` policy class and
    the ``train`` update function.

    Args:
        render: draw the environment on every step (the original behavior).
            Rendering dominates wall-clock time; pass ``False`` to train faster.
    """
    env = gym.make('CartPole-v0')
    try:
        in_dim = env.observation_space.shape[0]  # 4 for CartPole
        out_dim = env.action_space.n  # 2 for CartPole
        pi = Pi(in_dim, out_dim)  # policy pi_theta for REINFORCE
        optimizer = optim.Adam(pi.parameters(), lr=0.01)
        for epi in range(episodes):
            state = env.reset()
            if isinstance(state, tuple):  # gym>=0.26 returns (obs, info)
                state = state[0]
            for t in range(200):  # CartPole-v0 episodes are capped at 200 steps
                action = pi.act(state)
                step = env.step(action)
                if len(step) == 5:
                    # gym>=0.26: (obs, reward, terminated, truncated, info)
                    state, reward, terminated, truncated, _ = step
                    done = terminated or truncated
                else:
                    # legacy gym: (obs, reward, done, info)
                    state, reward, done, _ = step
                pi.rewards.append(reward)
                if render:
                    env.render()
                if done:
                    break
            loss = train(pi, optimizer)  # one update per episode
            total_reward = sum(pi.rewards)
            # Per-episode check against the 195 threshold; gym's official
            # CartPole-v0 criterion averages over 100 consecutive episodes.
            solved = total_reward > 195.0
            pi.onpolicy_reset()  # on-policy: discard this episode's data
            print(f'Episode {epi}, loss: {loss}, '
                  f'total_reward: {total_reward}, solved: {solved}')
    finally:
        env.close()  # release the render window / env resources

In [7]:
# Script-style entry point. Inside a Jupyter/IPython kernel __name__ is
# '__main__', so executing this cell starts training immediately.
if __name__ == '__main__':
    main()

Episode 0, loss: 147.98931884765625,         tota_reward: 21.0, solved: False
Episode 1, loss: 145.2624053955078,         tota_reward: 21.0, solved: False
Episode 2, loss: 157.56776428222656,         tota_reward: 22.0, solved: False
Episode 3, loss: 47.44414138793945,         tota_reward: 12.0, solved: False
Episode 4, loss: 309.7136535644531,         tota_reward: 31.0, solved: False
Episode 5, loss: 111.8620376586914,         tota_reward: 19.0, solved: False
Episode 6, loss: 133.58294677734375,         tota_reward: 20.0, solved: False
Episode 7, loss: 130.39015197753906,         tota_reward: 20.0, solved: False
Episode 8, loss: 489.1654052734375,         tota_reward: 41.0, solved: False
Episode 9, loss: 238.98651123046875,         tota_reward: 28.0, solved: False
Episode 10, loss: 50.865577697753906,         tota_reward: 12.0, solved: False
Episode 11, loss: 137.66729736328125,         tota_reward: 21.0, solved: False
Episode 12, loss: 203.55758666992188,         tota_reward: 26.0, so

Episode 105, loss: 1447.0069580078125,         tota_reward: 84.0, solved: False
Episode 106, loss: 2159.10888671875,         tota_reward: 132.0, solved: False
Episode 107, loss: 3911.618896484375,         tota_reward: 200.0, solved: True
Episode 108, loss: 2595.260498046875,         tota_reward: 143.0, solved: False
Episode 109, loss: 4647.32421875,         tota_reward: 191.0, solved: False
Episode 110, loss: 3264.12109375,         tota_reward: 188.0, solved: False
Episode 111, loss: 2720.032470703125,         tota_reward: 155.0, solved: False
Episode 112, loss: 3266.90673828125,         tota_reward: 172.0, solved: False
Episode 113, loss: 4184.08203125,         tota_reward: 200.0, solved: True
Episode 114, loss: 1790.2537841796875,         tota_reward: 128.0, solved: False
Episode 115, loss: 3786.34130859375,         tota_reward: 200.0, solved: True
Episode 116, loss: 3303.7734375,         tota_reward: 159.0, solved: False
Episode 117, loss: 2928.91015625,         tota_reward: 179.0, 

Episode 210, loss: 6069.39599609375,         tota_reward: 200.0, solved: True
Episode 211, loss: 4875.95458984375,         tota_reward: 177.0, solved: False
Episode 212, loss: 6055.115234375,         tota_reward: 200.0, solved: True
Episode 213, loss: 5973.6103515625,         tota_reward: 200.0, solved: True
Episode 214, loss: 3909.597412109375,         tota_reward: 150.0, solved: False
Episode 215, loss: 2481.341552734375,         tota_reward: 117.0, solved: False
Episode 216, loss: 3799.739501953125,         tota_reward: 140.0, solved: False
Episode 217, loss: 3081.357666015625,         tota_reward: 130.0, solved: False
Episode 218, loss: 47.93174743652344,         tota_reward: 13.0, solved: False
Episode 219, loss: 1900.9716796875,         tota_reward: 99.0, solved: False
Episode 220, loss: 984.9493408203125,         tota_reward: 67.0, solved: False
Episode 221, loss: 87.2222900390625,         tota_reward: 18.0, solved: False
Episode 222, loss: 37.24396896362305,         tota_reward

Episode 315, loss: 138.868408203125,         tota_reward: 29.0, solved: False
Episode 316, loss: 77.98979187011719,         tota_reward: 22.0, solved: False
Episode 317, loss: 75.33306884765625,         tota_reward: 23.0, solved: False
Episode 318, loss: 174.49082946777344,         tota_reward: 29.0, solved: False
Episode 319, loss: 57.90510559082031,         tota_reward: 21.0, solved: False
Episode 320, loss: 185.66781616210938,         tota_reward: 24.0, solved: False
Episode 321, loss: 137.16363525390625,         tota_reward: 24.0, solved: False
Episode 322, loss: 96.50952911376953,         tota_reward: 22.0, solved: False
Episode 323, loss: 132.90982055664062,         tota_reward: 25.0, solved: False
Episode 324, loss: 46.78466033935547,         tota_reward: 13.0, solved: False
Episode 325, loss: 69.84103393554688,         tota_reward: 21.0, solved: False
Episode 326, loss: 138.7362823486328,         tota_reward: 30.0, solved: False
Episode 327, loss: 98.54566955566406,         tot

Episode 420, loss: 563.1002197265625,         tota_reward: 59.0, solved: False
Episode 421, loss: 1069.60400390625,         tota_reward: 87.0, solved: False
Episode 422, loss: 1888.4443359375,         tota_reward: 110.0, solved: False
Episode 423, loss: 161.18228149414062,         tota_reward: 30.0, solved: False
Episode 424, loss: 97.60592651367188,         tota_reward: 29.0, solved: False
Episode 425, loss: 126.26350402832031,         tota_reward: 31.0, solved: False
Episode 426, loss: 588.6298828125,         tota_reward: 48.0, solved: False
Episode 427, loss: 5527.61962890625,         tota_reward: 200.0, solved: True
Episode 428, loss: 5226.52978515625,         tota_reward: 200.0, solved: True
Episode 429, loss: 1832.986328125,         tota_reward: 103.0, solved: False
Episode 430, loss: 4189.787109375,         tota_reward: 200.0, solved: True
Episode 431, loss: 5988.828125,         tota_reward: 200.0, solved: True
Episode 432, loss: 5470.640625,         tota_reward: 200.0, solved: 