## REINFORCE Example (CartPole)

Adapted from Algorithm 2.1 (REINFORCE) in "Foundations of Deep Reinforcement Learning".

In [None]:
import sys
import numpy as np
import gym
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

### Policy Calculation

In [None]:
class PolicyNetwork(nn.Module):
    """Two-layer MLP that maps a state vector to unnormalized action logits.

    Args:
        state_dim: Size of the input state vector.
        h_dim: Width of the single hidden layer.
        action_dim: Number of output logits (one per action).
    """

    def __init__(self, state_dim, h_dim, action_dim):
        super().__init__()
        # Layers are created in input-to-output order so that parameter
        # initialization consumes the RNG in the same sequence as before.
        hidden_layer = nn.Linear(state_dim, h_dim)
        logits_layer = nn.Linear(h_dim, action_dim)
        self.policy_nn = nn.Sequential(hidden_layer, nn.ReLU(), logits_layer)

    def forward(self, state):
        """Return the action logits produced for ``state``."""
        return self.policy_nn(state)

def sample_action(policy_nn, state):
    """Sample an action from the categorical policy defined by ``policy_nn``.

    Args:
        policy_nn: Callable that maps a float32 state tensor to action logits.
        state: Array-like state; it is converted to a float32 tensor.

    Returns:
        Tuple ``(action, log_prob)`` where ``action`` is a Python int and
        ``log_prob`` is the tensor log-probability of that action.
    """
    state_tensor = torch.tensor(state, dtype=torch.float32)
    distribution = Categorical(logits=policy_nn(state_tensor))
    sampled = distribution.sample()
    # The log-probability is returned as a tensor (not a float) so that it
    # can later be used for backpropagation.
    return sampled.item(), distribution.log_prob(sampled)

### Value Calculation

In [None]:
def mc_cumulative_discounted_reward(trajectory, gamma=.99):
    """Append Monte Carlo discounted returns to a SARSA-style trajectory.

    Args:
        trajectory: 7-element sequence
            ``(s, a, a_logprob, r, s_next, a_next, d)``; only the rewards
            ``r`` and done flags ``d`` are read here.
        gamma: Discount factor.

    Returns:
        The seven input elements, unchanged, followed by ``cdr``: a NumPy
        array holding the discounted return from each step onward.
    """
    states, actions, log_probs, rewards, next_states, next_actions, dones = trajectory

    running_return = 0
    backward_returns = []
    # Walk the episode from its last step to its first; a done flag zeroes
    # the carried-over return so rewards never leak across episode ends.
    for idx in range(len(rewards) - 1, -1, -1):
        running_return = rewards[idx] + gamma * (1 - dones[idx]) * running_return
        backward_returns.append(running_return)

    cdr = np.array(backward_returns[::-1])
    return (states, actions, log_probs, rewards,
            next_states, next_actions, dones, cdr)

### Trajectory Generation

In [None]:
def sample_trajectory(env, policy_fn, max_steps=200, break_when_done=True):
    """Roll out one episode and return it as SARSA-style arrays.

    Args:
        env: Environment whose ``reset()`` returns ``(state, info)`` and
            whose ``step(action)`` returns
            ``(state, reward, terminated, truncated, info)``.
        policy_fn: Callable mapping a state to ``(action, log_prob)``.
            ``log_prob`` may be ``None`` for policies that do not track it.
        max_steps: Maximum number of policy queries made for the rollout.
        break_when_done: If true, stop once the episode has ended
            (terminated or truncated), after one last policy query that
            supplies ``a_next`` for the final transition. If false, keep
            stepping the environment until ``max_steps`` is reached.

    Returns:
        A list ``[s, a, a_logprob, r, s_next, a_next, d]``. Every element is
        a NumPy array with one entry per transition, except ``a_logprob``,
        which is a stacked tensor (or ``None`` when the policy returned
        ``None`` log-probabilities). ``d`` holds the ``terminated`` flag as
        0/1. N collected steps yield N - 1 transitions.

    Raises:
        ValueError: If fewer than two steps were collected, so that no
            transition can be formed.
    """
    s, _ = env.reset()

    episode_over = False
    steps = []
    for _ in range(max_steps):
        a, a_logprob = policy_fn(s)
        if episode_over and break_when_done:
            # Only the state and the action chosen in it are needed to
            # complete the last transition, so the environment is not
            # stepped again after the episode has ended. The reward and
            # done placeholders of this final entry are never read.
            steps.append([list(s), a, a_logprob, 0.0, 1])
            break
        s_next, r, terminated, truncated, _ = env.step(a)
        steps.append([list(s), a, a_logprob, r, int(terminated)])
        # A time-limit truncation ends the rollout as well, but is not
        # recorded as a terminal state in ``d``.
        if terminated or truncated:
            episode_over = True
        s = s_next

    if len(steps) < 2:
        raise ValueError(
            "sample_trajectory needs at least two steps to form a "
            "transition; got %d (max_steps=%d)" % (len(steps), max_steps)
        )

    # Pair each step with its successor to obtain (s, a, r, s', a') tuples.
    transitions = [
        [cur_s, cur_a, cur_lp, cur_r, nxt[0], nxt[1], cur_d]
        for (cur_s, cur_a, cur_lp, cur_r, cur_d), nxt in zip(steps, steps[1:])
    ]

    s, a, a_logprob, r, s_next, a_next, d = zip(*transitions)
    return [
        np.array(s), np.array(a),
        None if a_logprob[0] is None else torch.stack(a_logprob),
        np.array(r),
        np.array(s_next), np.array(a_next),
        np.array(d),
    ]

### Training

In [None]:
class REINFORCE:
    """REINFORCE agent: a policy network trained with an Adam optimizer.

    Args:
        state_dim: Size of the state vector fed to the policy network.
        action_dim: Number of actions the policy network outputs logits for.
        h_dim: Hidden-layer width of the policy network.
        lr_alpha: Learning rate passed to Adam.
    """

    def __init__(self, state_dim, action_dim, h_dim=64, lr_alpha=.01):
        self.p_network = PolicyNetwork(state_dim, h_dim, action_dim)
        self.optimizer = optim.Adam(self.p_network.parameters(), lr=lr_alpha)

    def train(self, batch):
        """Run one policy-gradient update and return the scalar loss.

        Args:
            batch: 8-element sequence
                ``(s, a, a_logprob, r, s_next, a_next, d, cdr)``. Only the
                log-probabilities ``a_logprob`` and the returns ``cdr`` are
                used by the update.

        Returns:
            The loss value as a Python float.
        """
        _, _, log_probs, _, _, _, _, returns = batch

        # Weight each log-probability by its return; the sum is negated
        # because the optimizer minimizes.
        weighted_log_probs = torch.tensor(returns) * log_probs
        loss = -weighted_log_probs.sum()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

In [None]:
# Agent sized for a 4-dimensional state and 2 discrete actions.
reinforce = REINFORCE(state_dim=4, action_dim=2)


def policy_fn(state):
    """Sample ``(action, log_prob)`` from the agent's current policy."""
    return sample_action(reinforce.p_network, state)


def value_fn(trajectory):
    """Attach Monte Carlo discounted returns to ``trajectory``."""
    return mc_cumulative_discounted_reward(trajectory)


env = gym.make("CartPole-v0")
count = 0  # cumulative (not consecutive) number of full-length episodes
for episode in range(1, 1001):
    # One on-policy rollout, then a single gradient update on it.
    trajectory = value_fn(sample_trajectory(env, policy_fn))
    loss = reinforce.train(trajectory)

    trajectory_len = len(trajectory[0])
    print("%5d %5d %5d %7.2f" % (episode, trajectory_len, count, loss))

    # 199 transitions is the longest trajectory sample_trajectory returns
    # with its default max_steps=200; stop after more than five of them.
    if trajectory_len != 199:
        continue
    count += 1
    if count > 5:
        break

### Cartpole Visualization

In [None]:
%env HV_DOC_HTML=true # required for colab
import holoviews as hv

def cartpole_visualize(states, step=4):
    """Build a HoloViews ``HoloMap`` that animates a sequence of states.

    Args:
        states: 2-D array of states, one row per time step. Column 0 is
            used as the cart position and column 2 as the pole angle.
        step: Only every ``step``-th state is drawn.

    Returns:
        An ``hv.HoloMap`` keyed by the original time-step index
        (``0, step, 2 * step, ...``), one frame per drawn state.
    """
    hv.extension("bokeh") # required for colab
    sampled_states = states[::step]

    def pole_segment(position, angle):
        # Unit-length pole drawn from the cart at (position, 0).
        tip = (position + np.sin(angle), np.cos(angle))
        return [(position, 0), tip]

    def draw_frame(segment):
        track = [(-2.4, 0), (2.4, 0)]
        overlay = hv.Overlay([hv.Curve(track), hv.Curve(segment)])
        # Fix the axis ranges so every frame shares the same view.
        overlay = overlay.redim(
            x=hv.Dimension("x", range=(-2.4, 2.4)),
            y=hv.Dimension("y", range=(-.05, 1.05)),
        )
        return overlay.opts(height=150, width=400)

    columns = zip(sampled_states[:, 0], sampled_states[:, 2])
    frames = {
        frame_idx * step: draw_frame(pole_segment(position, angle))
        for frame_idx, (position, angle) in enumerate(columns)
    }
    return hv.HoloMap(frames)

# Animate the states (element 0) of the most recently assigned `trajectory`.
cartpole_visualize(trajectory[0])