# In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import gym
import gym_aero as ga
import scipy
from torch.distributions import Normal
from collections import namedtuple
import utilities as utils
from math import log

# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class IndependentGaussianPolicy(nn.Module):
    """
    Diagonal-Gaussian policy for continuous-control tasks.

    Two separate tanh MLPs (two hidden layers each) map an observation to the mean and to
    the log standard deviation of every action dimension. Action dimensions are treated as
    independent, i.e. the covariance matrix is diagonal.
    """
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(IndependentGaussianPolicy, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        def tanh_mlp():
            # input -> hidden -> hidden -> output, tanh after each hidden layer,
            # linear (unbounded) output.
            return nn.Sequential(
                nn.Linear(input_dim, hidden_dim), nn.Tanh(),
                nn.Linear(hidden_dim, hidden_dim), nn.Tanh(),
                nn.Linear(hidden_dim, output_dim),
            )

        # The mean head is built before the log-std head so parameter initialisation
        # consumes the RNG in the same order as before.
        self.mu = tanh_mlp()
        self.logsigma = tanh_mlp()

    def forward(self, x):
        """Return ``(mu, logsigma)``, the Gaussian parameters for observation(s) ``x``."""
        return self.mu(x), self.logsigma(x)

class ValueNet(nn.Module):
    """
    Small tanh MLP used as a learned value function.

    The same class serves for state-value and state-action-value estimates; the caller
    decides what goes into ``x`` and how many outputs are produced.
    """
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(ValueNet, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        layers = [
            nn.Linear(input_dim, hidden_dim), nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim), nn.Tanh(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.value = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network's value estimate(s) for ``x``."""
        return self.value(x)
    
class REINFORCE(nn.Module):
    """
    Basic Monte Carlo Policy Gradient. This implementation is single-threaded only, and uses
    the score function estimator to find the policy gradient. Several update steps are taken
    on each collected batch. The likelihood ratio pi_new(a|s) / pi_old(a|s) corrects for the
    stored actions having been sampled under the policy as it was before the first step. The
    critic-based policy-gradient weights are recomputed every step because the critic is
    trained jointly with the policy.

    Args:
        beta: policy network; ``beta(x)`` returns ``(mu, logsigma)`` of a diagonal Gaussian.
        v_fn: state-value network (critic); ``get_phi`` indexes its output as ``(T, 1)``.
    """
    def __init__(self, beta, v_fn):
        super(REINFORCE, self).__init__()
        self.beta = beta
        self.v_fn = v_fn

    def select_action(self, x):
        """
        Sample an action for observation ``x`` from the current policy.

        Returns:
            (action, log_prob, entropy): the sampled action (``sample`` is not
            differentiable), the per-dimension log-probability of that action, and the
            per-dimension entropy of the sampling distribution.
        """
        mu, logsigma = self.beta(x)
        sigma = torch.exp(logsigma)
        dist = Normal(mu, sigma)
        action = dist.sample()
        lp = dist.log_prob(action)
        entropy = dist.entropy()
        return action, lp, entropy

    def get_phi(self, trajectory, critic, gamma=0.99, tau=0):
        """
        Compute per-step policy-gradient weights and discounted returns for a batch.

        Args:
            trajectory: dict of per-step tensor lists as produced by ``rollout``. This uses
                the "states", "rewards" and "masks" entries. A mask of 0 marks the last step
                of an episode and stops discounting/bootstrapping across the boundary.
            critic: value network evaluated on the stacked states.
            gamma: discount factor.
            tau: GAE lambda. With the default ``tau=0`` the weights are exactly the one-step
                TD errors r_t + gamma * V(s_{t+1}) - V(s_t).

        Returns:
            (advantages, returns): two tensors of shape (T, 1) with no autograd history.
        """
        states = torch.stack(trajectory["states"]).to(device)
        rewards = torch.stack(trajectory["rewards"]).to(device)
        masks = torch.stack(trajectory["masks"]).to(device)

        # Only the critic's values are needed here, so don't build an autograd graph.
        with torch.no_grad():
            values = critic(states)
        n_steps = rewards.size(0)
        returns = torch.zeros(n_steps, 1, device=device)
        deltas = torch.zeros(n_steps, 1, device=device)
        advantages = torch.zeros(n_steps, 1, device=device)
        prev_return = 0.
        prev_value = 0.
        prev_advantage = 0.
        # Sweep backwards in time; masks[i] == 0 on an episode's final step cuts each
        # recursion at the episode boundary.
        for i in reversed(range(n_steps)):
            returns[i] = rewards[i] + gamma * prev_return * masks[i]
            deltas[i] = rewards[i] + gamma * prev_value * masks[i] - values[i]
            advantages[i] = deltas[i] + gamma * tau * prev_advantage * masks[i]
            prev_return = returns[i, 0]
            prev_value = values[i, 0]
            prev_advantage = advantages[i, 0]
        # Returning the GAE estimate rather than the raw TD errors is what makes ``tau`` take
        # effect; for the default tau=0 the two are identical.
        return advantages, returns

    def update(self, optim, trajectory, iters=4):
        """
        Take ``iters`` joint policy/critic gradient steps on one batch of experience.

        Args:
            optim: optimizer over the policy and critic parameters.
            trajectory: dict of per-step tensor lists as produced by ``rollout``.
            iters: number of gradient steps on this batch (must be >= 1).

        Returns:
            (policy_loss, value_loss) from the last step, as Python floats.

        Raises:
            ValueError: if ``iters`` < 1 (there would be no loss to report).
        """
        if iters < 1:
            raise ValueError("iters must be at least 1, got {}".format(iters))
        # Behaviour-policy log-probs recorded during the rollout; constants for every step.
        old_log_probs = torch.stack(trajectory["log_probs"]).to(device).detach()
        states = torch.stack(trajectory["states"]).to(device)
        actions = torch.stack(trajectory["actions"]).to(device)
        for _ in range(iters):
            # Re-estimate the weights with the current critic, which changes every step.
            advantages, returns = self.get_phi(trajectory, self.v_fn)
            # Scaled by the batch std of the returns (NaN for a single-step batch).
            phi = (advantages / returns.std()).detach()
            v_fn_loss = torch.mean((self.v_fn(states) - returns.detach()) ** 2)
            mu_p, logsigma_p = self.beta(states)
            dist_p = Normal(mu_p, torch.exp(logsigma_p))
            lp_p = dist_p.log_prob(actions)
            # Independent action dimensions: the joint ratio is exp of the summed log-ratios.
            ratio = torch.exp((lp_p - old_log_probs).sum(dim=-1, keepdim=True))
            pol_loss = -torch.mean(ratio * phi)
            loss = pol_loss + v_fn_loss
            optim.zero_grad()
            # Every step builds a fresh graph (all reused inputs are detached), so nothing
            # needs to be retained between backward passes.
            loss.backward()
            optim.step()
        return pol_loss.item(), v_fn_loss.item()

        
class PPO(REINFORCE):
    """
    Simple implementation of the proximal policy optimization algorithm (Schulman, 2017). The
    only difference between this and a standard REINFORCE algorithm is the use of the clipped
    objective, which helps keep the policy updates bound to a trust region (though there is some
    controversy around this).

    Action sampling and policy-gradient weights (``select_action`` / ``get_phi``) are inherited
    from REINFORCE. Only ``update`` differs: it replaces the plain importance-weighted
    objective with the clipped surrogate min(r * A, clip(r, 1 - eps, 1 + eps) * A).
    """
    def __init__(self, beta, v_fn):
        super(PPO, self).__init__(beta, v_fn)

    def update(self, optim, trajectory, iters=4, eps=0.2):
        """
        Take ``iters`` joint policy/critic gradient steps on one batch using the clipped objective.

        Args:
            optim: optimizer over the policy and critic parameters.
            trajectory: dict of per-step tensor lists as produced by ``rollout``.
            iters: number of gradient steps on this batch (must be >= 1).
            eps: clipping radius; the probability ratio is clipped to [1 - eps, 1 + eps].

        Returns:
            (policy_loss, value_loss) from the last step, as Python floats.

        Raises:
            ValueError: if ``iters`` < 1 (there would be no loss to report).
        """
        if iters < 1:
            raise ValueError("iters must be at least 1, got {}".format(iters))
        # Behaviour-policy log-probs recorded during the rollout; constants for every step.
        old_log_probs = torch.stack(trajectory["log_probs"]).to(device).detach()
        states = torch.stack(trajectory["states"]).to(device)
        actions = torch.stack(trajectory["actions"]).to(device)
        for _ in range(iters):
            # Re-estimate the weights with the current critic, which changes every step.
            deltas, returns = self.get_phi(trajectory, self.v_fn)
            phi = (deltas / returns.std()).detach()
            v_fn_loss = torch.mean((self.v_fn(states) - returns.detach()) ** 2)
            mu_p, logsigma_p = self.beta(states)
            dist_p = Normal(mu_p, torch.exp(logsigma_p))
            lp_p = dist_p.log_prob(actions)
            # Independent action dimensions: the joint ratio is exp of the summed log-ratios.
            ratio = torch.exp((lp_p - old_log_probs).sum(dim=-1, keepdim=True))
            # torch.clamp takes (min, max). Passing them swapped (min > max) sets every element
            # to max, i.e. a constant 1 - eps, which destroys the clipped objective.
            clipped_ratio = torch.clamp(ratio, 1 - eps, 1 + eps)
            clipped_objective = torch.min(ratio * phi, clipped_ratio * phi)
            pol_loss = -torch.mean(clipped_objective)
            loss = pol_loss + v_fn_loss
            optim.zero_grad()
            # Every step builds a fresh graph (all reused inputs are detached), so nothing
            # needs to be retained between backward passes.
            loss.backward()
            optim.step()
        return pol_loss.item(), v_fn_loss.item()
    
def test(env, agent):
    """
    Run one evaluation episode and return its undiscounted total reward.

    Actions come from ``agent.select_action`` (only the action is used). ``env`` is driven
    through ``reset()`` and a 4-tuple ``step()`` API. The episode runs under
    ``torch.no_grad()`` because evaluation never backpropagates, so recording an autograd
    graph for every step would be wasted work.
    """
    total_reward = 0.
    with torch.no_grad():
        state = torch.Tensor(env.reset()).to(device)
        done = False
        while not done:
            action, _, _ = agent.select_action(state)
            next_state, reward, done, _ = env.step(action.detach().cpu().numpy())
            total_reward += reward
            state = torch.Tensor(next_state).to(device)
    return total_reward


def rollout(env, agent, batch_size):
    """
    Collect whole episodes with ``agent`` until at least ``batch_size`` steps are gathered.

    Episodes are never cut short, so the batch may hold more than ``batch_size`` transitions.

    Returns:
        dict of per-step tensor lists with keys "states", "actions", "rewards",
        "next_states", "masks" (holding ``not done``, so 0 on an episode's final step) and
        "log_probs" (as returned by ``agent.select_action``).
    """
    buffers = {key: [] for key in ("states", "actions", "rewards", "next_states", "masks", "log_probs")}
    steps_collected = 0
    while steps_collected < batch_size:
        obs = torch.Tensor(env.reset()).to(device)
        finished = False
        while not finished:
            act, logp, _entropy = agent.select_action(obs)
            nxt, rew, finished, _info = env.step(act.cpu().data.numpy())
            nxt = torch.Tensor(nxt).to(device)
            buffers["states"].append(obs)
            buffers["actions"].append(act)
            buffers["next_states"].append(nxt)
            buffers["rewards"].append(torch.Tensor([rew]).to(device))
            buffers["log_probs"].append(logp)
            buffers["masks"].append(torch.Tensor([not finished]).to(device))
            obs = nxt
            # The outer condition is only re-checked once the episode ends.
            steps_collected += 1
    return buffers


def train_offline(env, agent, opt, batch_size=1024, iterations=500, log_interval=10, t_runs=10):
    """
    Alternate batch collection and agent updates, evaluating the agent periodically.

    Each iteration collects at least ``batch_size`` steps with ``rollout`` and calls
    ``agent.update(opt, trajectory)``. ``rollout`` only stops after finishing an episode, so a
    batch is usually a little larger. Before training, and then every ``log_interval``
    iterations, the mean reward of ``t_runs`` ``test`` episodes is printed together with the
    number of environment steps actually collected so far.

    Returns:
        list of mean evaluation rewards: the pre-training score followed by one entry per
        logging interval.
    """
    def _evaluate():
        # Mean undiscounted reward over t_runs evaluation episodes.
        return np.mean([test(env, agent) for _ in range(t_runs)])

    def _report(iteration, steps, reward):
        # One progress record followed by a blank separator line.
        print("Iterations: ", iteration)
        print("Time steps: ", steps)
        print("Reward: ", reward)
        print()

    initial_reward = _evaluate()
    data = [initial_reward]
    print()
    _report(0, 0, initial_reward)
    total_steps = 0
    for ep in range(1, iterations + 1):
        trajectory = rollout(env, agent, batch_size)
        # Count what was really collected; batch_size * ep under-reports it.
        total_steps += len(trajectory["rewards"])
        agent.update(opt, trajectory)
        if ep % log_interval == 0:
            test_rew = _evaluate()
            data.append(test_rew)
            _report(ep, total_steps, test_rew)
    return data