# Train Pendulum problem(for continuous output) in OpenAI Env using Actor Critic method

In [162]:
import math
import random

import gym
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler

from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

In [147]:
env = gym.make("Pendulum-v0")

episodes = 1000
batch_size = 20
gamma = 0.99
tau = 0.95
goal_steps = 200
input_shape = env.observation_space.shape[0]
num_actions = env.action_space.shape[0]
# print(input_shape)
buffer_capacity = 200
epochs = 2
clip_param = 0.2

In [148]:
# Memory to save the experiences 
class Buffer(object):
    def __init__(self):
        self.buffer = []
        self.buffer_capacity = 1000
        self.batch = 32
        self.counter = 0
        self.check = 1
        
    def add(self, params):
        self.buffer.append(params)
            
    def reinit(self):
        self.buffer = []
        
    def length(self):
        return len(self.buffer)

In [167]:
# Network for Actor and Critic
class ActorCritic(nn.Module):
    def __init__(self, input_shape, num_actions):
        super(ActorCritic, self).__init__()
        
        self.fc1 = nn.Linear(input_shape, 512)
        self.fc2 = nn.Linear(512, 1)
        
        self.fc = nn.Linear(input_shape, 512)
        self.mu_head = nn.Linear(512,num_actions)
        
        self.log_std = nn.Parameter(torch.zeros(num_actions))
        
    def forward(self, x):
        c1 = F.relu(self.fc1(x))
        value = self.fc2(c1)
        
        x = F.relu(self.fc(x))
        mu = self.mu_head(c1)
        std   = self.log_std.exp().expand_as(mu)
        dist  = Normal(mu, std)
        return dist, value

In [161]:
def compute_returns(next_value, rewards, masks):
    R = next_value
    returns = []
    for step in reversed(range(len(rewards))):
        R = rewards[step] + gamma * R * masks[step]
        returns.insert(0, R)
    return returns

# Returns can also be calculated by Generalized Advantage Estimation method

def compute_gae(next_value, rewards, masks, values):
    values = torch.cat((values, next_value), 0)
    gae = 0
    returns = []
    for step in reversed(range(len(rewards))):
        delta = rewards[step] + gamma * values[step + 1] * masks[step] - values[step]
        gae = delta + gamma * tau * masks[step] * gae
        returns.insert(0, gae + values[step])
    return returns
    
def update(next_s, entropy):
    next_s = torch.FloatTensor(next_s)
    _, next_value = model(next_s)
    
    mem = memory.buffer
    v = torch.FloatTensor([m[0] for m in mem])
    log_probs = torch.FloatTensor([m[1] for m in mem])
    r = torch.FloatTensor([m[2] for m in mem])
    mask = torch.FloatTensor([m[3] for m in mem])
    
#     _, v = model(s)
    returns = torch.FloatTensor(compute_gae(next_value, r, mask, v))
    adv = returns - v
    
    for _ in range(epochs):
        for id in BatchSampler(SubsetRandomSampler(range(len(r))), batch_size, False):
            actor_loss  = -(log_probs[id] * adv[id]).mean()
            critic_loss = adv[id].pow(2).mean()

            loss = actor_loss + 0.5 * critic_loss - 0.001 * entropy
            opt_a.zero_grad()
            loss.backward(retain_graph=True)
            opt_a.step()

In [168]:
model = ActorCritic(input_shape, num_actions)

opt_a = optim.Adam(model.parameters())
memory = Buffer()

In [170]:
state = env.reset()
for idx in range(episodes):
    state = env.reset()
    score = 0
    entropy = 0
    done = False
    while not done:
        state = torch.FloatTensor(state)
        dist, v = model(state)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        entropy += dist.entropy().mean()
#         action = action.clamp(-2, 2)
#         print(action)
        next_state, reward, done, _ = env.step(action.numpy())
        score += reward
        memory.add([v, log_prob, reward, 1 - done])
        state = next_state
        
    update(next_state, entropy)
    memory.reinit()
    print("Episode = " + str(idx) + ", Score = " + str(score))

Episode = 0, Score = -1356.40999427
Episode = 1, Score = -1137.9588504
Episode = 2, Score = -1092.16985961
Episode = 3, Score = -1379.26368902
Episode = 4, Score = -873.167381107
Episode = 5, Score = -844.521414793
Episode = 6, Score = -1004.25892938
Episode = 7, Score = -992.950540286
Episode = 8, Score = -953.746750369
Episode = 9, Score = -864.441022979
Episode = 10, Score = -1439.41401407
Episode = 11, Score = -1478.26795255
Episode = 12, Score = -987.4746337
Episode = 13, Score = -1620.20363207
Episode = 14, Score = -1363.08621632
Episode = 15, Score = -1371.88482274
Episode = 16, Score = -1014.77686257
Episode = 17, Score = -1182.26932818
Episode = 18, Score = -1498.56799708
Episode = 19, Score = -1768.22969378
Episode = 20, Score = -752.714187625
Episode = 21, Score = -1532.6700347
Episode = 22, Score = -1190.82025026
Episode = 23, Score = -995.155272298
Episode = 24, Score = -1668.12272106
Episode = 25, Score = -1582.22841221
Episode = 26, Score = -1243.86233218
Episode = 27, S

Episode = 220, Score = -1303.17816708
Episode = 221, Score = -1152.42677009
Episode = 222, Score = -1016.32420306
Episode = 223, Score = -1152.93668866
Episode = 224, Score = -1795.76453363
Episode = 225, Score = -666.350173792
Episode = 226, Score = -1517.42033029
Episode = 227, Score = -1402.90458033
Episode = 228, Score = -968.027996352
Episode = 229, Score = -923.973708133
Episode = 230, Score = -1005.82774722
Episode = 231, Score = -1331.9551426
Episode = 232, Score = -1260.66136918
Episode = 233, Score = -938.72139068
Episode = 234, Score = -1410.68356658
Episode = 235, Score = -1198.46919898
Episode = 236, Score = -1496.22421908
Episode = 237, Score = -975.870851244
Episode = 238, Score = -1211.4500592
Episode = 239, Score = -896.275388358
Episode = 240, Score = -1170.83586162
Episode = 241, Score = -1039.03482187
Episode = 242, Score = -1143.74247769
Episode = 243, Score = -968.813444375
Episode = 244, Score = -1126.35023997
Episode = 245, Score = -1063.91791041
Episode = 246, 

Episode = 438, Score = -1617.00726044
Episode = 439, Score = -952.019378536
Episode = 440, Score = -979.476520894
Episode = 441, Score = -917.46844318
Episode = 442, Score = -1075.5157458
Episode = 443, Score = -1464.25149681
Episode = 444, Score = -1445.78193901
Episode = 445, Score = -1061.28077431
Episode = 446, Score = -872.159212843
Episode = 447, Score = -1049.92656068
Episode = 448, Score = -978.441983475
Episode = 449, Score = -1196.04183291
Episode = 450, Score = -1364.80745217
Episode = 451, Score = -872.144130781
Episode = 452, Score = -1066.19203926
Episode = 453, Score = -962.840672472
Episode = 454, Score = -1282.5964472
Episode = 455, Score = -1440.81916972
Episode = 456, Score = -998.71043364
Episode = 457, Score = -1573.77555239
Episode = 458, Score = -865.696628487
Episode = 459, Score = -963.206947319
Episode = 460, Score = -1068.25452605
Episode = 461, Score = -971.943724534
Episode = 462, Score = -1370.52816889
Episode = 463, Score = -1187.97167078
Episode = 464, S

Episode = 656, Score = -1596.15709353
Episode = 657, Score = -1089.18478825
Episode = 658, Score = -997.366634725
Episode = 659, Score = -1644.18579187
Episode = 660, Score = -1074.91953757
Episode = 661, Score = -873.419879262
Episode = 662, Score = -1411.35503692
Episode = 663, Score = -1054.03014491
Episode = 664, Score = -1690.9764778
Episode = 665, Score = -1516.80060864
Episode = 666, Score = -851.034911896
Episode = 667, Score = -1054.28863713
Episode = 668, Score = -1534.4202236
Episode = 669, Score = -1082.63867959
Episode = 670, Score = -991.209705703
Episode = 671, Score = -1383.99872502
Episode = 672, Score = -1762.84255858
Episode = 673, Score = -982.697541852
Episode = 674, Score = -1277.88662051
Episode = 675, Score = -1073.30418097
Episode = 676, Score = -1508.20821444
Episode = 677, Score = -1228.0854182
Episode = 678, Score = -1377.24661032
Episode = 679, Score = -1025.38099699
Episode = 680, Score = -1084.47993827
Episode = 681, Score = -1606.7614436
Episode = 682, S

Episode = 874, Score = -1331.62775198
Episode = 875, Score = -1226.73731547
Episode = 876, Score = -1353.23500854
Episode = 877, Score = -1406.38320551
Episode = 878, Score = -1683.08478614
Episode = 879, Score = -974.722066967
Episode = 880, Score = -1639.34251636
Episode = 881, Score = -1587.3244124
Episode = 882, Score = -1038.79734736
Episode = 883, Score = -1596.61956223
Episode = 884, Score = -869.606646872
Episode = 885, Score = -1366.9688365
Episode = 886, Score = -1089.45876318
Episode = 887, Score = -1519.83530061
Episode = 888, Score = -1505.53072434
Episode = 889, Score = -1397.20905505
Episode = 890, Score = -1473.01249777
Episode = 891, Score = -1096.66815002
Episode = 892, Score = -1370.85693716
Episode = 893, Score = -1192.26894997
Episode = 894, Score = -1333.79969421
Episode = 895, Score = -1002.67620744
Episode = 896, Score = -959.890022753
Episode = 897, Score = -1334.66936313
Episode = 898, Score = -1332.98409179
Episode = 899, Score = -1180.80467021
Episode = 900,