# Lab 02 : Policy Gradient Network - demo

Author: FAIR<br>
https://github.com/pytorch/examples/blob/master/reinforcement_learning/actor_critic.py

CartPole environment:<br>
https://github.com/openai/gym/wiki/CartPole-v0

In [1]:
import argparse
import gym
import numpy as np
from itertools import count

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

# Create the CartPole-v0 environment through OpenAI Gym. The resulting `env`
# object is the one that the later cells seed, reset, step, render and close.
env = gym.make('CartPole-v0')

# Dictionary whose entries can also be read and written as attributes
class DotDict(dict):
    """A ``dict`` whose keys double as attributes (``d.key`` is ``d['key']``)."""

    def __init__(self, **kwds):
        # store the keyword arguments as the initial entries
        super().__init__(**kwds)
        # Alias the instance namespace to the mapping itself, so attribute
        # reads and writes operate directly on the dictionary entries.
        self.__dict__ = self
        

In [2]:
# Policy network: maps a state to a probability for each action
class Policy(nn.Module):
    """Two-layer perceptron that outputs a probability for each of 2 actions.

    Besides its layers, the module carries two plain Python lists,
    ``saved_log_probs`` and ``rewards``, which start empty and are filled
    and emptied by code outside this class.
    """

    def __init__(self):
        super().__init__()
        # 4 input features -> 128 hidden units -> 2 action scores
        self.affine1 = nn.Linear(in_features=4, out_features=128)
        self.affine2 = nn.Linear(in_features=128, out_features=2)

        # per-episode buffers
        self.rewards = []
        self.saved_log_probs = []

    def forward(self, x):
        """Return the softmax (taken over dim 1) of the action scores for ``x``."""
        hidden = self.affine1(x)
        hidden = F.relu(hidden)
        scores = self.affine2(hidden)
        probabilities = F.softmax(scores, dim=1)
        return probabilities
    

In [3]:

# Sample an action (left or right move) from the policy network
def select_action(state):
    """Draw one action for ``state`` from the global ``policy`` network.

    The log-probability of the drawn action is appended to
    ``policy.saved_log_probs`` so it can be used when the policy is updated.
    Returns the sampled action converted to a plain Python number.
    """
    # state s: numpy array -> float tensor with a leading batch dimension
    batch = torch.from_numpy(state).float().unsqueeze(0)

    # categorical distribution over the actions, given state s
    distribution = Categorical(policy(batch))

    # sample action a from that distribution
    sampled = distribution.sample()

    # remember log pi(a|s) of the action that was actually taken
    policy.saved_log_probs.append(distribution.log_prob(sampled))

    return sampled.item()


# Function that computes the discounted return of every step once an episode
# is done, and backpropagates to update the policy network
eps = np.finfo(np.float32).eps.item()  # keeps the division by the std finite
def finish_episode():
    """Update the global ``policy`` from the episode stored in its buffers.

    Reads ``policy.rewards`` and ``policy.saved_log_probs``, builds the
    discounted return of every time step with discount ``args.gamma``,
    standardizes those returns, and takes one step of the global
    ``optimizer`` on the loss ``sum(-log_prob * return)``. Both buffers are
    emptied before returning.

    An episode without any recorded reward is skipped (buffers are still
    emptied). A one-step episode is used with its raw return, because the
    sample standard deviation of a single value is NaN and would otherwise
    turn the loss, and then the network weights, into NaN.
    """
    # Nothing recorded: there is nothing to learn from, and torch.cat below
    # would raise on an empty list.
    if not policy.rewards:
        del policy.rewards[:]
        del policy.saved_log_probs[:]
        return

    # compute the return at each time step, walking the episode backwards;
    # append + reverse keeps this linear in the episode length
    R = 0
    returns = []
    for r in reversed(policy.rewards):
        R = r + args.gamma * R
        returns.append(R)
    returns.reverse()
    returns = torch.tensor(returns)

    # center the returns and scale them to unit variance
    # (only defined when there are at least two of them)
    if returns.numel() > 1:
        returns = (returns - returns.mean()) / (returns.std() + eps)

    policy_loss = [
        -log_prob * ret
        for log_prob, ret in zip(policy.saved_log_probs, returns)
    ]

    # backpropagate
    optimizer.zero_grad()
    policy_loss = torch.cat(policy_loss).sum()
    policy_loss.backward()
    optimizer.step()

    # delete rewards and log_probs to prepare for next episode
    del policy.rewards[:]
    del policy.saved_log_probs[:]
    

In [4]:
# hyper-parameters, gathered in one attribute-style dictionary
args = DotDict(
    gamma=0.99,       # discount applied to future rewards
    seed=1,           # value used to seed the environment and torch
    render=True,      # whether the environment is rendered at every step
    log_interval=10,  # a progress line is printed every this many episodes
)
print(args)

# seed the environment and torch with the same value
env.seed(args.seed)
torch.manual_seed(args.seed)

# policy network (created after seeding torch) and its Adam optimizer
policy = Policy()
optimizer = optim.Adam(policy.parameters(), lr=1e-2)


{'gamma': 0.99, 'seed': 1, 'render': True, 'log_interval': 10}


In [None]:
def main():
    """Train the policy episode after episode until the task counts as solved.

    Each episode is played with actions from ``select_action`` for at most
    10000 steps, its rewards are stored in ``policy.rewards``, and
    ``finish_episode`` then updates the network. An exponential moving
    average of the episode length (weight 0.01 on the newest episode,
    starting value 10) is kept; training stops as soon as that average
    exceeds ``env.spec.reward_threshold``. A progress line is printed every
    ``args.log_interval`` episodes.
    """
    running_reward = 10
    for i_episode in count(1):  # consider multiple episodes

        state = env.reset()  # reset the environment

        # play one episode until it finishes, or stop after 10000 steps
        episode_length = 0
        for _step in range(10000):
            action = select_action(state)  # select action=a from state=s
            state, reward, done, _ = env.step(action)  # next state=s' and reward=r
            if args.render:
                env.render()  # see the state
            policy.rewards.append(reward)  # store immediate rewards
            # Count the steps actually taken: the loop index of the last step
            # is one less than the number of steps in the episode.
            episode_length += 1
            if done:
                break

        # running reward = 99% of previous running reward + 1% last episode length
        running_reward = running_reward * 0.99 + episode_length * 0.01

        # the episode is done: compute the discounted returns and
        # backpropagate to update the policy network
        finish_episode()

        # display results every log_interval episodes
        if i_episode % args.log_interval == 0:
            print('Episode {}\tLast length: {:5d}\tAverage length: {:.2f}'.format(
                i_episode, episode_length, running_reward))

        # stop the training when the running reward is high enough
        if running_reward > env.spec.reward_threshold:
            print("Solved! Running reward is now {} and "
                  "the last episode runs to {} time steps!".format(
                      running_reward, episode_length))
            break


            
# Start training; main() only returns once its stopping condition is met.
main()

In [6]:
env.close()  # release the environment, which closes the render window