In [1]:

import numpy as np
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import matplotlib.pyplot as plt
from torch.distributions import Categorical


In [2]:
# Policy network
class PolicyNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(64, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        probs = F.softmax(self.fc2(x), dim=-1)
        return probs

# Value function network
class ValueNetwork(nn.Module):
    def __init__(self, input_dim):
        super(ValueNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(64, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        value = self.fc2(x)
        return value

In [3]:
def get_advantages(values, masks, rewards, gamma=0.99, tau=0.95):
    rewards = torch.cat(rewards)
    values = torch.cat(values)
    masks = torch.cat(masks)
    targets = torch.zeros(rewards.size())
    gae = torch.zeros(rewards.size())
    for t in reversed(range(rewards.size(0))):
        if t == rewards.size(0) - 1:
            delta = rewards[t] - values[t]
            gae[t] = delta
        else:
            delta = rewards[t] + gamma * values[t + 1] * masks[t + 1] - values[t]
            gae[t] = delta + gamma * tau * masks[t + 1] * gae[t + 1]
    targets = values + gae
    return targets, (gae - gae.mean()) / (gae.std() + 1e-8)

In [13]:
def conjugate_gradient(Avp, b, nsteps=10, residual_tol=1e-10):
    p = b.clone().detach()
    r = b.clone().detach()
    x = torch.zeros_like(b).detach()
    rdotr = torch.dot(r, r)
    for i in range(nsteps):
        _tmp = Avp(p).detach()
        alpha = rdotr / torch.dot(p, _tmp)
        x += alpha * p
        r -= alpha * _tmp
        new_rdotr = torch.dot(r, r)
        betta = new_rdotr / rdotr
        p = r + betta * p
        rdotr = new_rdotr
        if rdotr < residual_tol:
            break
    return x

In [14]:
def kl_divergence(old_probs, new_probs):
    """
    Compute the KL divergence between old and new probabilities.
    """
    # Ensure the probabilities do not become 0 or 1
    epsilon = 1e-5
    old_probs = torch.clamp(old_probs, epsilon, 1 - epsilon)
    new_probs = torch.clamp(new_probs, epsilon, 1 - epsilon)
    
    return (old_probs * (torch.log(old_probs) - torch.log(new_probs))).sum(1, keepdim=True)

def hessian_vector_product(tau, p_net, v, damping_factor=1e-2):
    """
    Compute product of Hessian of KL divergence with vector v.
    """
    kl = kl_divergence(tau, p_net(tau))
    
    # Compute the first derivative
    grads = torch.autograd.grad(kl.sum(), p_net.parameters(), create_graph=True)
    flat_grads = torch.cat([grad.view(-1) for grad in grads])

    # Compute the product of the first derivative with the vector v
    grad_v_product = torch.sum(flat_grads * v)

    # Compute the second derivative
    hessian_v = torch.autograd.grad(grad_v_product, p_net.parameters())
    hessian_v_product = torch.cat([grad.contiguous().view(-1) for grad in hessian_v]).data

    return hessian_v_product + v * damping_factor


In [15]:
def linesearch(model, f, x, fullstep, expected_improve_rate, max_backtracks=10, accept_ratio=0.1):
    """
    Perform backtracking linesearch.
    """
    fval = f(True).data
    for step_frac in 0.5**np.arange(max_backtracks):
        xnew = x + step_frac * fullstep
        set_flat_params_to(model, xnew)
        newfval = f(True).data
        actual_improve = fval - newfval
        expected_improve = expected_improve_rate * step_frac
        ratio = actual_improve / expected_improve

        if ratio.item() > accept_ratio and actual_improve.item() > 0:
            return True, xnew
    return False, x

In [16]:
def set_flat_params_to(model, flat_params):
    """
    Set the parameters of the model using a flat parameter tensor.
    """
    prev_ind = 0
    for param in model.parameters():
        flat_size = int(np.prod(list(param.size())))
        param.data.copy_(flat_params[prev_ind:prev_ind+flat_size].view(param.size()))
        prev_ind += flat_size


In [17]:
def surrogate_objective(old_probs, new_probs, actions, advantages):
    """    
    Compute the surrogate objective, which is used to derive the policy update.
    """
    ratio = new_probs.gather(1, actions) / old_probs.gather(1, actions)
    return (ratio * advantages).mean()

In [18]:
def get_flat_params_from(model):
    """
    Retrieve the parameters from a model as a single flat tensor.
    """
    params = []
    for param in model.parameters():
        params.append(param.data.view(-1))
    return torch.cat(params)


In [19]:
env = gym.make('CartPole-v1')
in_dim = env.observation_space.shape[0]
out_dim = env.action_space.n

policy_net = PolicyNetwork(in_dim, out_dim)
value_net = ValueNetwork(in_dim)
value_optimizer = optim.Adam(value_net.parameters(), lr=1e-3)

num_epochs = 1000
max_steps = 200
num_trajectories = 10
gamma = 0.99
tau = 0.95

In [21]:
for epoch in range(num_epochs):
    log_probs = []
    values = []
    states = []
    actions = []
    rewards = []
    masks = []
    entropy = 0
    cumulative_rewards = 0
    for _ in range(num_trajectories):
        state, _ = env.reset()
        for _ in range(max_steps):
            state = torch.FloatTensor(state).unsqueeze(0)
            probs = policy_net(state)
            action = probs.multinomial(num_samples=1).detach()
            next_state, reward, done, truncated, _ = env.step(action.numpy()[0][0])
            done = done or truncated
            
            cumulative_rewards += reward
            
            log_prob = torch.log(probs.squeeze(0)[action])
            entropy += -(log_prob * probs.squeeze(0)).sum()
            
            log_probs.append(log_prob)
            values.append(value_net(state))
            rewards.append(torch.FloatTensor([reward]))
            masks.append(torch.FloatTensor([1 - done]))

            states.append(state)
            actions.append(action)

            state = next_state
            if done:
                break

    next_state = torch.FloatTensor(next_state).unsqueeze(0)
    next_value = value_net(next_state)
    returns, advantages = get_advantages(values, masks, rewards, gamma, tau)

    log_probs = torch.cat(log_probs)
    
    
    # 1. Compute the surrogate objective.
    old_probs = policy_net(torch.cat(states)).detach()
    new_probs = policy_net(torch.cat(states))
    surr_obj = surrogate_objective(old_probs, new_probs, torch.cat(actions), advantages)

    # 2. Compute the gradient of the surrogate objective.
    policy_net.zero_grad()
    surr_obj.backward(retain_graph=True)
    policy_grad = [param.grad for param in policy_net.parameters()]
    flat_policy_grad = torch.cat([grad.view(-1) for grad in policy_grad])
    
    # 3. Compute the step direction using the conjugate gradient algorithm.
    hvp = lambda v: hessian_vector_product(torch.cat(states), policy_net, v)
    step_direction = conjugate_gradient(hvp, flat_policy_grad)

    # 4. Compute the step size using the line search.
    max_kl = 0.01
    step_size = torch.sqrt(2 * max_kl / (torch.dot(step_direction, hvp(step_direction))))
    full_step = step_size * step_direction

    # 5. Perform line search to satisfy the KL constraint.
    expected_improve = torch.dot(-full_step, policy_grad)
    success, new_params = linesearch(policy_net, surrogate_objective, get_flat_params_from(policy_net), full_step, expected_improve)

    if success:
        set_flat_params_to(policy_net, new_params)

    # Update the value function
    values = torch.cat(values)
    value_loss = (values - returns).pow(2).mean()

    value_optimizer.zero_grad()
    value_loss.backward()
    value_optimizer.step()
    
    avg_reward = cumulative_rewards / num_trajectories
    print(f"Epoch {epoch+1}, Average Reward: {avg_reward}")


RuntimeError: The size of tensor a (4) must match the size of tensor b (2) at non-singleton dimension 1