In [223]:
import math
import random
import collections
import copy
import gym
import numpy as np
from operator import mul
from functools import reduce
from torch import Tensor
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal
from torch.nn.utils.convert_parameters import vector_to_parameters, parameters_to_vector
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler

from IPython.display import clear_output
import matplotlib.pyplot as plt
%matplotlib inline

In [59]:
env = gym.make("Pendulum-v0")

episodes = 1000
batch_size = 32
gamma = 0.99
goal_steps = 200
input_shape = env.observation_space.shape[0]
num_actions = env.action_space.shape[0]
# print(num_actions)
buffer_capacity = 1000
epochs = 5
clip_param = 0.2
tau = 0.97
max_kl=0.001
damping=0.001
iters=10
residual_tol=1e-10
ent_coeff=0.00

In [12]:
class Buffer(object):
    def __init__(self):
        self.buffer = []
        self.buffer_capacity = 1000
        self.batch = 32
    
    def add(self, params):
        self.buffer.append(params)
        
    def reinit(self):
        self.buffer = []
        
    def length(self):
        return len(self.buffer)

In [29]:
class Policy(nn.Module):
    def __init__(self, input_shape, num_actions):
        super(Policy, self).__init__()
        self.fc1 = nn.Linear(input_shape, 64)
        self.fc2 = nn.Linear(64, 128)
        self.mu_head = nn.Linear(128,num_actions)
        self.log_std = nn.Parameter(torch.zeros(num_actions))
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        mu = self.mu_head(x)
        std = self.log_std.exp().expand_as(mu)
        dist = Normal(mu, std)
        return dist

class Value(nn.Module):
    def __init__(self, input_shape, num_actions):
        super(Value, self).__init__()
        self.fc1 = nn.Linear(input_shape, 64)
        self.fc2 = nn.Linear(64, 128)
        self.fc3 = nn.Linear(128, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        value = self.fc3(x)
        return value

In [240]:
class Agent:
    def __init__(self,input_shape, num_actions):
        self.device   = torch.device("cpu")
        self.policy = Policy(input_shape, num_actions)#.to(self.device)
        self.value = Value(input_shape, num_actions)#.to(self.device)
        self.optimizer = optim.Adam(self.value.parameters())
        self.buffer_capacity = 1000
        self.batch_size = 32
        self.policy_model_properties = collections.OrderedDict()
        for k, v in self.policy.state_dict().items():
            self.policy_model_properties[k] = v.size()

    def kl_divergence(self, model):
#         observations_tensor = torch.cat([Tensor(observation).unsqueeze(0) for observation in self.observations])
        dist_n = model(self.s)
        actprob = dist_n.log_prob(self.a.unsqueeze(1))
        old_dist = self.policy(self.s)
        old_actprob = old_dist.log_prob(self.a.unsqueeze(1))
        return torch.sum(torch.exp(old_actprob) * (old_actprob / actprob)).mean()

    def hessian_vector_product(self, vector):
        self.policy.zero_grad()
        kl_div = self.kl_divergence(self.policy)
        kl_div.backward(create_graph=True)
        gradient = flatten_model_params([p.grad for p in self.policy.parameters()]).squeeze(0)
        gradient_vector_product = torch.sum(gradient * vector)
        gradient_vector_product.backward()#torch.ones(gradient.size())
        return (flatten_model_params([p.grad for p in self.policy.parameters()]).squeeze(0) - gradient).data 

    def conjugate_gradient(self, b):
        p = b.clone().data
        r = b.clone().data
        x = np.zeros_like(b.data.numpy())
        rdotr = r.dot(r)
        for i in range(iters):
            z = self.hessian_vector_product(p)
            v = rdotr / p.dot(z)
            x += v * p
            r -= v * z
            newrdotr = r.dot(r)
            mu = newrdotr / rdotr
            p = r + mu * p
            rdotr = newrdotr
            if rdotr < residual_tol:
                break
        return x
    
    def linesearch(self, x, fullstep, expected_improve_rate):
        accept_ratio = .1
        max_backtracks = 10
        
        fval = self.surrogate_loss(x)
        for (_n_backtracks, stepfrac) in enumerate(.5**np.arange(max_backtracks)):
            xnew = x.data.numpy() + stepfrac * fullstep
            newfval = self.surrogate_loss(torch.from_numpy(xnew))
            actual_improve = fval - newfval
            expected_improve = expected_improve_rate * stepfrac
            ratio = actual_improve / expected_improve
            if ratio > accept_ratio and actual_improve > 0:
                return torch.from_numpy(xnew)
        return x
    
    def surrogate_loss(self, theta):
        new_model = self.construct_model_from_theta(theta.data)
#         observations_tensor = torch.cat([Tensor(observation).unsqueeze(0) for observation in self.observations])
        dist_new = new_model(self.s)
        prob_new = dist_new.log_prob(self.a.unsqueeze(1))
        
        dist_old = new_model(self.s)
        prob_old = dist_old.log_prob(self.a.unsqueeze(1))
        
        return -torch.sum(torch.exp(prob_new / prob_old) * self.advantages)
    
    def construct_model_from_theta(self, theta):
        theta = theta.squeeze(0)
        new_model = copy.deepcopy(self.policy)
        state_dict = collections.OrderedDict()
        start_index = 0
        for k, v in self.policy_model_properties.items():
            param_length = reduce(mul, v, 1)
            state_dict[k] = theta[start_index : start_index + param_length].view(v)
            start_index += param_length
        new_model.load_state_dict(state_dict)
        return new_model

    def update(self, entropy):
        mem = memory.buffer
        self.s = torch.FloatTensor([m[0].numpy() for m in mem])
        self.a = torch.FloatTensor([m[1] for m in mem]) 
        old_log_a = torch.FloatTensor([m[2] for m in mem])
        r = torch.FloatTensor([m[3] for m in mem])
        masks = torch.FloatTensor([m[4] for m in mem])
        values = self.value(self.s)
#         print(a.size(),self.s.size())
        returns = torch.Tensor(self.a.size(0),1)
        deltas = torch.Tensor(self.a.size(0),1)
        self.advantages = torch.Tensor(self.a.size(0),1)
        prev_return = 0
        prev_value = 0
        prev_advantage = 0
        for i in reversed(range(len(r))):
            returns[i] = r[i] + gamma * prev_return * masks[i]
            deltas[i] = r[i] + gamma * prev_value * masks[i] - values[i]
            self.advantages[i] = deltas[i] + gamma * tau * prev_advantage * masks[i]
            prev_return = returns[i, 0]
            prev_value = values.data[i, 0]
            prev_advantage = self.advantages[i, 0]
        self.advantages = self.advantages.squeeze(1)
#         print(advantages.size())
        for _ in range(epochs):
            for id in BatchSampler(SubsetRandomSampler(range(200)), batch_size, False):
                dist = self.policy(self.s[id])
                new_log_a = dist.log_prob(self.a[id].unsqueeze(1))
                ratio = torch.exp(new_log_a.squeeze(1) - old_log_a[id])
                surrogate_loss = -torch.mean(ratio * self.advantages[id]) - (ent_coeff * entropy)

                self.policy.zero_grad()
                surrogate_loss.backward(retain_graph=True)
                policy_gradient = flatten_model_params([p.grad for p in self.policy.parameters()]).squeeze(0)
                
                step_direction = self.conjugate_gradient(policy_gradient)
                step_direction_variable = torch.from_numpy(step_direction)#.unsqueeze(1)
                
                shs = .5 * step_direction.dot(self.hessian_vector_product(step_direction_variable).numpy().T)
                lm = np.sqrt(shs / max_kl)
                fullstep = step_direction / lm
                gdotstepdir = policy_gradient.dot(step_direction_variable)#.data[0]
#                 print(list(self.policy.parameters()))
                theta = self.linesearch(flatten_model_params(list(self.policy.parameters())), fullstep, gdotstepdir / lm)

                # Update parameters of policy model
                old_model = copy.deepcopy(self.policy)
                old_model.load_state_dict(self.policy.state_dict())
                self.policy = self.construct_model_from_theta(theta.data)
                kl_old_new = self.kl_divergence(old_model)

                self.fit(returns)

        memory.reinit()
    
    def fit(self, labels):
        def closure():
            predicted = self.value(self.s)
            loss = torch.pow(predicted - labels, 2).sum()
            self.optimizer.zero_grad()
            loss.backward()
            return loss
        self.optimizer.step(closure)

In [241]:
# policy = Policy(input_shape, num_actions)#.to(self.device)

def flatten_model_params(parameters):
    return torch.cat([param.view(1, -1) for param in parameters], 1)

# grad = list(policy.parameters())
# x = flatten_model_params(grad)
# print(x)
agent = Agent(input_shape, num_actions)
# value = Value(input_shape, num_actions)#.to(self.device)
# optimizer = optim.Adam(value.parameters())

# policy_model_properties = collections.OrderedDict()
# for k, v in policy.state_dict().items():
#     policy_model_properties[k] = v.size()
# print(policy_model_properties)
    

In [242]:
memory = Buffer()
state = env.reset()
for idx in range(episodes):
    state = env.reset()
    score = 0
    entropy = 0
    done = False
    while not done:
        state = torch.FloatTensor(state)
        dist = agent.policy(state)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        entropy += dist.entropy().mean()
#         action = action.clamp(-2, 2)
#         print(action)
        next_state, reward, done, _ = env.step(action.numpy())
        score += reward
        memory.add([state, action, log_prob, reward, 1 - done])
        state = next_state
        
    agent.update(entropy)
    memory.reinit()
    print("Episode = " + str(idx) + ", Score = " + str(score))



Episode = 0, Score = -1143.93262294
Episode = 1, Score = -874.171427413
Episode = 2, Score = -1322.27626545
Episode = 3, Score = -1475.88116508
Episode = 4, Score = -1247.32785124
Episode = 5, Score = -1667.22229589
Episode = 6, Score = -1254.95440065
Episode = 7, Score = -1268.68848168
Episode = 8, Score = -1275.98459647
Episode = 9, Score = -916.64800343
Episode = 10, Score = -1372.89634395
Episode = 11, Score = -1218.54830486
Episode = 12, Score = -1177.36439435
Episode = 13, Score = -1576.38362215
Episode = 14, Score = -1141.50814625
Episode = 15, Score = -1363.90907364
Episode = 16, Score = -1449.83774515
Episode = 17, Score = -1040.214042
Episode = 18, Score = -1694.99835944
Episode = 19, Score = -1437.34522919
Episode = 20, Score = -1292.74964927
Episode = 21, Score = -964.911276204
Episode = 22, Score = -1504.7347425
Episode = 23, Score = -829.341913425
Episode = 24, Score = -1602.93754158
Episode = 25, Score = -963.683507892
Episode = 26, Score = -1081.94721951
Episode = 27, S

KeyboardInterrupt: 