In [43]:
"""
Create a buffer object to store experiences.
Log mean and standard deviation to buffer.
"""

using Pkg
Pkg.activate("./Project.toml")

using Flux
using OpenAIGym
import Reinforce.action
import Reinforce:run_episode
import Flux.params
using Flux.Tracker: grad, update!
using Flux: onehot
using Statistics
using Distributed
using Distributions
using LinearAlgebra
using Base.Iterators
using Random
using BSON:@save,@load
using JLD

include("policies.jl")

"""
Utilities
"""

# Concatenate the gradients of every TrackedArray leaf of `models` into one column.
#
# gradients : gradient collection looked up per parameter as `gradients[p]`
#             (e.g. the result of `Tracker.gradient(f, params)`)
# models... : Flux models / tracked arrays, walked with `mapleaves` in argument order
#
# The layout (traversal order, then a column-major `reshape` of each leaf) matches the
# one used by `get_flat_params` / `set_flat_params`, so flat vectors from all three line up.
function get_flat_grads(gradients,models...)
    """
    Flattens out the gradients and concatenates them
    
    Returns : Tracker Array of shape (NUM_PARAMS,1)
    """

    flat_grads = []

    function flatten!(p)
        # Only TrackedArray leaves are parameters; every other leaf is skipped.
        if typeof(p) <: TrackedArray
            prod_size = prod(size(p))
            push!(flat_grads,reshape(gradients[p],prod_size))
        end
    end
    
    # `mapleaves` is used only for its side effect on `flat_grads`; the rebuilt model it
    # returns is discarded.
    for model in models
        mapleaves(flatten!,model)
    end
    
    flat_grads = cat(flat_grads...,dims=1)
    flat_grads = reshape(flat_grads,length(flat_grads),1)
    
    return flat_grads
end

# Concatenate every TrackedArray leaf of `models` into one (NUM_PARAMS, 1) column.
#
# models... : Flux models / tracked arrays, walked with `mapleaves` in argument order
#
# The leaves are reshaped, not unwrapped with `.data`, so the result keeps tracking.
# `set_flat_params` consumes exactly this layout; keep the two in sync.
function get_flat_params(models...)
    """
    Flattens out the parameters and concatenates them
    
    Returns : Tracker Array of shape (NUM_PARAMS,1)
    """

    flat_params = []
    
    function flatten!(p)
        # Only TrackedArray leaves are parameters; every other leaf is skipped.
        if typeof(p) <: TrackedArray
            prod_size = prod(size(p))
            push!(flat_params,reshape(p,prod_size))
        end
    end
    
    # Called for the side effect on `flat_params`; the model `mapleaves` rebuilds is unused.
    for model in models
        mapleaves(flatten!,model)
    end
    
    flat_params = cat(flat_params...,dims=1)
    flat_params = reshape(flat_params,length(flat_params),1)
    
    return flat_params
end

# Write the flat vector `parameters` back into the TrackedArray leaves of `models`, in
# place, consuming it in the same `mapleaves` order that `get_flat_params` produces.
#
# parameters : flat parameters, indexed as `parameters[a:b, :]` (a (NUM_PARAMS, 1) column)
# models...  : Flux models / tracked arrays to overwrite via `p.data .= ...`
#
# Returns `nothing` (the trailing `print("")` writes an empty string).
function set_flat_params(parameters,models...)
    """
    Sets values of `parameters` to the `model`
    
    parameters : flattened out array of model parameters
    """
    # Index of the next unread entry of `parameters`; advanced by `assign!` below.
    ptr = 1
    
    function assign!(p)
        if typeof(p) <: TrackedArray
            prod_size = prod(size(p))
            
            # NOTE(review): `.data` needs the broadcast result to be tracked, i.e. this
            # assumes `parameters` is a TrackedArray — confirm for plain arrays.
            p.data .= Float32.(reshape(parameters[ptr : ptr + prod_size - 1,:],size(p)...)).data
            ptr += prod_size
        end
    end
    
    # Side effect only: the rebuilt model returned by `mapleaves` is discarded.
    for model in models
        mapleaves(assign!,model)
    end
    
    print("")
end

# Categorical KL divergence KL(π_old ‖ π_new), one value per column of `states`.
#
# states        : batch of states, one column per state, fed to `model`
# old_log_probs : log action probabilities under the old policy, same shape as
#                 `model(states)` (ACTION_SIZE × batch)
# model         : network returning action probabilities; defaults to the global
#                 categorical `policy` (only defined when MODE == "CAT")
#
# Returns `sum(p_old .* (log p_old .- log p_new), dims=1)`, i.e. a 1 × batch row.
function categorical_kl(states,old_log_probs,model=policy)
    action_probs = model(states)
    log_probs = log.(action_probs)

    # KL(old ‖ new) = Σₐ p_old(a) · (log p_old(a) − log p_new(a)); the opposite sign
    # would give −KL.
    log_ratio = old_log_probs .- log_probs
    kl_div = (exp.(old_log_probs)) .* log_ratio
    return sum(kl_div,dims=1)
end

# KL(N(μ0, σ0²) ‖ N(μ1, σ1²)) for diagonal Gaussians whose log standard deviations are
# logΣ0 / logΣ1. The per-dimension terms are summed along dims=1.
# A 1e-8 term in the denominator guards against a vanishing σ1².
function gaussian_kl(μ0,logΣ0,μ1,logΣ1)
    σ0² = @. exp(2 * logΣ0)
    σ1² = @. exp(2 * logΣ1)
    per_dim = @. 0.5 * (((μ0 - μ1)^2 + σ0²) / (σ1² + 1e-8) - 1.0f0) + logΣ1 - logΣ0
    return sum(per_dim, dims=1)
end

# KL-divergence loss between the old and new policies (presumably meant for the TRPO
# trust-region constraint, given `gvp`/`Hvp` mention D_kl).
# NOTE(review): both branches are unimplemented TODOs, so this currently returns
# `nothing`; `gvp` and `Hvp` differentiate `policy_loss` instead.
function kl_loss(states,actions,advantages,returns,old_log_probs)
    if MODE == "CON"
        # TODO
    else
        # TODO
    end
end

"""
HYPERPARAMETERS
"""
# Environment Creation #
env_name = "Pendulum-v0"
MODE = "CON" # Can be either "CON" (Continuous) or "CAT" (Categorical)

# Environment Variables #
# STATE_SIZE / ACTION_SIZE size the networks and must match `env_name`.
STATE_SIZE = 3
ACTION_SIZE = 1
# Step cap per episode; also written into the Gym env's `_max_episode_steps` by `rollout`.
EPISODE_LENGTH = 2000
TEST_STEPS = 10000
# Divisor used by `scale_rewards`; presumably the magnitude of Pendulum's most negative
# per-step reward — verify if the environment changes.
REWARD_SCALING = 16.2736044
# Policy parameters #
η = 3e-4 # Learning rate
STD = 0.0 # Standard deviation
HIDDEN_SIZE = 30
# GAE parameters
γ = 0.99
λ = 0.95
# Optimization parameters
PPO_EPOCHS = 10
NUM_EPISODES = 100000
BATCH_SIZE = 256
# NOTE(review): c₀/c₁/c₂ and ϵ are not referenced by the TRPO code in this cell.
c₀ = 1.0
c₁ = 1.0
c₂ = 0.001
# PPO parameters
ϵ = 0.2
# FREQUENCIES
SAVE_FREQUENCY = 50
VERBOSE_FREQUENCY = 5
global_step = 0

# Global variable to monitor losses
# `reward_hist` receives one mean return per `process_rollouts` call; the three loss
# trackers below are not updated anywhere in this cell.
reward_hist = []
policy_l = 0.0
entropy_l = 0.0
value_l = 0.0

#---------Scale rewards-------#
# Rescale raw rewards elementwise: divide by the global REWARD_SCALING, then shift by +2.
scale_rewards(rewards) = @. rewards / REWARD_SCALING + 2.0f0

# Standardise `arr` to zero mean and unit sample variance.
#
# jitter : added to the variance before the square root so a constant input does not
#          divide by zero (default 1e-10).
#
# Inputs with fewer than two elements are returned as zeros: the corrected (n - 1)
# variance is NaN for a single sample (which would turn a one-step episode's advantage
# into NaN), and `mean` of an empty untyped vector throws.
function normalise(arr; jitter=1e-10)
    length(arr) < 2 && return zero.(arr)
    return (arr .- mean(arr)) ./ sqrt(var(arr) + jitter)
end

"""
Define the networks
"""

# Build the networks for the selected MODE; the constructors are presumably provided by
# the included policies.jl.
# NOTE(review): the recorded run failed here with `UndefVarError: gaussian_policy not
# defined`, so that policies.jl did not (yet) define it.
if MODE == "CON"
    # Mean network plus a separate log-std parameter (`action` exponentiates policy_Σ).
    policy_μ,policy_Σ = gaussian_policy(STATE_SIZE,HIDDEN_SIZE,ACTION_SIZE)
    # NOTE(review): ACTION_SIZE is passed to the value-function constructor; confirm the
    # intended output width (`gae` only reads element [1] of its output).
    value = value_fn(STATE_SIZE,HIDDEN_SIZE,ACTION_SIZE,tanh)
elseif MODE == "CAT"
    policy = categorical_policy(STATE_SIZE,HIDDEN_SIZE,ACTION_SIZE)
    value = value_fn(STATE_SIZE,HIDDEN_SIZE,ACTION_SIZE,relu)
else 
    error("MODE can only be (CON) or (CAT)...")
end

opt = ADAM(η)

"""
Functions to get rollouts
"""

# Sample an action for a single observation from the current global policy
# (extends `Reinforce.action`).
#
# state : one observation; it is flattened into a single column before the forward pass.
#
# Returns, for MODE == "CON", one sample (a length-ACTION_SIZE vector) from
# N(μ(state), diag(exp(policy_Σ)²)); otherwise a 0-based discrete action id drawn from
# the categorical policy's probabilities.
function action(state)
    state = reshape(Array(state),length(state),1)

    a = nothing
    if MODE == "CON"
        # Our policy outputs the parameters of a Normal distribution
        μ = reshape(policy_μ(state),ACTION_SIZE)
        log_std = policy_Σ

        σ² = (exp.(log_std)).^2
        Σ = diagm(0=>σ².data)

        dis = MvNormal(μ.data,Σ)

        # Draw exactly ONE sample of dimension ACTION_SIZE. `rand(dis, ACTION_SIZE)` would
        # draw ACTION_SIZE samples (an ACTION_SIZE × ACTION_SIZE matrix), which only
        # coincided with a single action when ACTION_SIZE == 1.
        a = rand(dis)
    else
        action_probs = policy(state).data
        action_probs = reshape(action_probs,ACTION_SIZE)
        # `sample` returns a 1-based index; shift to a 0-based action id.
        a = sample(1:ACTION_SIZE,Weights(action_probs)) - 1
    end
    return a
end

# Play one episode of at most EPISODE_LENGTH steps with the global policy
# (extends `Reinforce.run_episode`).
#
# Returns the list of `(state, action, reward, next_state)` transitions, stopping early
# once `env.done` is set.
function run_episode(env)
    transitions = []
    state = reset!(env)

    for _ in 1:EPISODE_LENGTH
        act = action(state)
        # Continuous actions are flattened to a plain vector before stepping the env.
        MODE == "CON" && (act = reshape(act, ACTION_SIZE))

        reward, next_state = step!(env, act)
        push!(transitions, (state, act, reward, next_state))
        state = next_state

        env.done && break
    end

    return transitions
end

"""
Rollout collection: one episode per worker id (currently executed sequentially on the calling process)
"""

# Spawn worker processes.
# NOTE(review): `get_rollouts` calls `rollout()` directly on this process once per worker
# id, so these workers currently do no rollout work.
num_processes = 3
addprocs(num_processes) 

# Run one episode in `env`; thin wrapper over `run_episode`, defined on every process.
# NOTE(review): `run_episode` itself is not defined with @everywhere, so calling this on a
# worker would fail. The name also shadows `Base.collect` in this module; a more specific
# name would avoid surprises.
@everywhere function collect(env)
    run_episode(env)
end

# Create a fresh Gym environment and play one episode in it.
# The wrapped Python env's `_max_episode_steps` is overridden to EPISODE_LENGTH so both
# step caps agree.
@everywhere function rollout()
    gym = GymEnv(env_name)
    gym.pyenv._max_episode_steps = EPISODE_LENGTH
    episode = collect(gym)
    return episode
end

# Collect one episode per worker id.
# NOTE(review): `rollout()` is called directly here, so episodes are gathered
# sequentially on the calling process; `fetch` on a plain value just returns it.
function get_rollouts()
    episodes = [rollout() for _ in workers()]
    return fetch.(episodes)
end

"""
Generalized Advantage Estimation
"""

# Generalized Advantage Estimate for one episode, using the global critic `value`.
#
# δ_t = r_t + γ V(s_{t+1}) − V(s_t),   A_t = δ_t + γ λ A_{t+1}
#
# An episode shorter than EPISODE_LENGTH is treated as terminal at its last step, so that
# step does not bootstrap from V(s_{t+1}). `actions` is accepted but unused.
# Returns a Vector{Any} of advantages in time order.
function gae(states,actions,rewards,next_states)
    n = length(states)
    ended_early = n < EPISODE_LENGTH
    v(s) = cpu.(value(s).data[1])

    advantages = Vector{Any}(undef, n)
    running = 0.0
    for t in n:-1:1
        δ = if ended_early && t == n
            rewards[t] - v(states[t])
        else
            rewards[t] + γ * v(next_states[t]) - v(states[t])
        end
        running = δ + γ * λ * running
        advantages[t] = running
    end

    return advantages
end

# Discounted return-to-go G_t = r_t + γ G_{t+1} for every step, using the global γ.
# Returns a Vector{Any} in time order.
function disconunted_returns(rewards)
    out = Vector{Any}(undef, length(rewards))
    acc = 0.0
    for t in length(rewards):-1:1
        acc = rewards[t] + γ * acc
        out[t] = acc
    end
    return out
end

"""
Calculate Log Probabilities
"""
function log_prob_from_actions(states,actions)
    # Log probability of each taken action under the current global policy.
    #
    # states, actions : per-step episode variables, given as lists
    # Continuous mode scores the action with `normal_log_prob` under N(μ(s), exp(policy_Σ));
    # categorical mode takes the log of the probability at the (1-based) action index.
    out = []
    for idx in eachindex(states)
        s = states[idx]
        if MODE == "CON"
            mean_action = reshape(policy_μ(s),ACTION_SIZE).data
            log_std = policy_Σ.data |> cpu
            push!(out,normal_log_prob(mean_action,log_std,actions[idx]))
        else
            chosen = policy(s)[actions[idx],:].data
            push!(out,log.(chosen))
        end
    end
    return out
end


"""
Process and extract information from rollouts
"""

# Flatten a batch of rollouts into training matrices and log return statistics.
#
# rollouts : value returned by `get_rollouts` — one entry per episode, each a list of
#            `(state, action, reward, next_state)` tuples as pushed by `run_episode`.
#
# Returns `(states, actions, rewards, advantages, returns, log_probs)`, each `hcat`-ed so
# every time step is one column. Rewards are rescaled with `scale_rewards`, advantages
# come from `gae` and are normalised per episode, and returns are the discounted sums of
# the scaled rewards.
#
# Side effects: pushes this batch's mean episode return onto the global `reward_hist`
# and prints a progress line.
function process_rollouts(rollouts)
    # Per-episode lists, concatenated after the loop.
    states = []
    actions = []
    rewards = []
    advantages = []
    returns = []
    log_probs = []

    # Mean discounted return of each episode (logging only).
    episode_mean_returns = []

    for ro in rollouts
        episode_states = []
        episode_actions = []
        episode_rewards = []
        episode_next_states = []

        for (s, a, r, s_) in ro
            push!(episode_states,Array(s))

            if MODE == "CON"
                push!(episode_actions,a)
            else
                # `action` returns 0-based ids (it subtracts 1); shift back to 1-based
                # so the ids can index probability arrays.
                push!(episode_actions,a + 1)
            end

            push!(episode_rewards,r)
            push!(episode_next_states,s_)
        end

        episode_rewards = scale_rewards(episode_rewards)
        episode_advantages = gae(episode_states,episode_actions,episode_rewards,episode_next_states)
        episode_advantages = normalise(episode_advantages)

        episode_returns = disconunted_returns(episode_rewards)

        push!(episode_mean_returns,mean(episode_returns))

        push!(states,episode_states)
        push!(actions,episode_actions)
        push!(rewards,episode_rewards)
        push!(advantages,episode_advantages)
        push!(returns,episode_returns)
        push!(log_probs,log_prob_from_actions(episode_states,episode_actions))
    end

    states = cat(states...,dims=1)
    actions = cat(actions...,dims=1)
    rewards = cat(rewards...,dims=1)
    advantages = cat(advantages...,dims=1)
    returns = cat(returns...,dims=1)
    log_probs = cat(log_probs...,dims=1)

    push!(reward_hist,mean(episode_mean_returns))

    if length(reward_hist) <= 100
        println("RETURNS : $(mean(episode_mean_returns))")
    else
        println("MEAN RETURNS : $(mean(reward_hist))")
        # `end-99:end` is exactly the 100 most recent entries (`end-100:end` would be 101).
        println("LAST 100 RETURNS : $(mean(reward_hist[end-99:end]))")
    end

    return hcat(states...),hcat(actions...),hcat(rewards...),hcat(advantages...),hcat(returns...),hcat(log_probs...)
end


"""
Loss function definition
"""

# Importance-weighted policy objective mean(ratio .* advantages), where
# ratio = exp(new_log_probs - old_log_probs).
#
# This is the unclipped surrogate: ϵ is not applied, and `returns` is unused.
# In categorical mode `actions` is a 1 × batch row of 1-based action indices.
function policy_loss(states,actions,advantages,returns,old_log_probs)
    if MODE == "CON"
        μ = policy_μ(states)
        logΣ = policy_Σ 
        
        new_log_probs = normal_log_prob(μ,logΣ,actions)
    else
        action_probs = policy(states) # ACTION_SIZE x BATCH_SIZE
        # One-hot mask selecting the taken action in each column (plain, untracked array).
        actions_one_hot = zeros(ACTION_SIZE,size(action_probs)[end])
        
        for i in 1:size(action_probs)[end]
            actions_one_hot[actions[:,i][1],i] = 1.0                
        end
        
        # The 1f-5 offset keeps the log finite if a probability underflows to zero.
        new_log_probs = log.(sum((action_probs .+ 1f-5) .* actions_one_hot,dims=1))
    end
    
    # Surrogate loss computations
    ratio = exp.(new_log_probs .- old_log_probs)
    π_loss = mean(ratio .* advantages)
    return π_loss
end

# Mean squared error between the global critic's predictions and the target returns.
function value_loss(states,returns)
    residual = value(states) .- returns
    return mean(residual .^ 2)
end

"""
Optimization Function
"""

# Inner product of `x` with the flattened gradient of `policy_loss`, built with
# `nest=true` so the result can itself be differentiated (as `Hvp` does).
#
# NOTE(review): the docstring below describes Σ∇D_kl·x, but the closure differentiates
# `policy_loss`, not a KL divergence (`kl_loss` is still a stub). Confirm before relying
# on this for the TRPO constraint.
function gvp(states,actions,advantages,returns,log_probs,x)
    """
    Intermediate utility function, calculates Σ∇D_kl*x
    
    x : Variable to be estimated using conjugate gradient (Hx = g); (NUM_PARAMS,1)
    """
    if MODE == "CON"
        model_params = params(params(policy_μ)...,params(policy_Σ)...)
        gs = Tracker.gradient(() -> policy_loss(states,actions,advantages,returns,log_probs),model_params;nest=true)
        flat_grads = get_flat_grads(gs,policy_μ,policy_Σ)
    else
        model_params = params(policy)
        gs = Tracker.gradient(() -> policy_loss(states,actions,advantages,returns,log_probs),model_params;nest=true)
        flat_grads = get_flat_grads(gs,policy)
    end
    
    # x' * flat_grads is 1×1; `sum` reduces it to a scalar for the outer gradient.
    return sum(x' * flat_grads)
end

# Hessian-vector product H*x, obtained by differentiating `gvp` (the gradient · x) once
# more, flattened to (NUM_PARAMS, 1) with `get_flat_grads`.
#
# NOTE(review): because `gvp` differentiates `policy_loss`, H here is the Hessian of the
# surrogate loss, not of the KL divergence the docstring describes.
function Hvp(states,actions,advantages,returns,log_probs,x)
    """
    Computes the Hessian Vector Product
    Hessian is that of the kl divergence between the old and the new policies wrt the policy parameters
    
    Returns : Hx; H = ∇²D_kl
    """
    if MODE == "CON"
        model_params = params(params(policy_μ)...,params(policy_Σ)...)
        hessian = Tracker.gradient(() -> gvp(states,actions,advantages,returns,log_probs,x),model_params)
        return get_flat_grads(hessian,policy_μ,policy_Σ)
    else
        model_params = params(policy)
        hessian = Tracker.gradient(() -> gvp(states,actions,advantages,returns,log_probs,x),model_params)
        return get_flat_grads(hessian,policy)
    end
end

# Approximately solve H x = b with the conjugate-gradient method, using only
# Hessian-vector products.
#
# states, actions, advantages, returns, log_probs : minibatch, forwarded to `Hvp`
# Hvp    : callable `Hvp(states, actions, advantages, returns, log_probs, v)` returning an
#          object whose `.data` is H*v with the shape of `v`
# b      : right-hand side, array of shape (NUM_PARAMS, 1)
# nsteps : maximum number of CG iterations
# err    : stop once the squared residual norm r'r drops below this value
#
# Returns x with the shape of `b` (zeros when nsteps == 0).
function conjugate_gradients(states,actions,advantages,returns,log_probs,Hvp,b,nsteps,err=1e-10)
    x = zeros(size(b))

    r = copy(b)   # residual b - H*x (x starts at zero)
    p = copy(b)   # current search direction

    rdotr = r' * r

    for i in 1:nsteps
        # `Hvp` needs the whole minibatch plus the direction `p`.
        hvp = Hvp(states,actions,advantages,returns,log_probs,p).data # (NUM_PARAMS,1)

        α = rdotr ./ (p' * hvp)

        x = x .+ (α .* p)
        r = r .- (α .* hvp)

        new_rdotr = r' * r
        β = new_rdotr ./ rdotr
        p = r .+ (β .* p)

        rdotr = new_rdotr

        if rdotr[1] < err
            break
        end
    end

    return x
end

# Compute the TRPO natural-gradient step direction for one minibatch.
#
# Builds g = ∇ policy_loss, solves H x = -g with 10 conjugate-gradient iterations, and
# rescales x by sqrt(2δ / xᵀHx) so that ½ stepᵀ H step = δ, with δ = 0.01.
#
# Returns the step direction (NUM_PARAMS, 1).
# NOTE(review): the step is not applied to the policy parameters yet (e.g. via
# `set_flat_params`), and there is no backtracking line search.
function trpo_update(states,actions,advantages,returns,log_probs)
    # The two modes differ only in which networks hold the policy parameters.
    if MODE == "CON"
        nets = (policy_μ, policy_Σ)
        model_params = params(params(policy_μ)...,params(policy_Σ)...)
    else
        nets = (policy,)
        model_params = params(policy)
    end

    surrogate_grads = Tracker.gradient(() ->
            policy_loss(states,actions,advantages,returns,log_probs),model_params)
    flat_policy_grads = get_flat_grads(surrogate_grads,nets...).data

    # Obtain an estimate of H_inv * (-g); the minibatch is forwarded so CG can call Hvp.
    x = conjugate_gradients(states,actions,advantages,returns,log_probs,
                            Hvp,-1.0 .* flat_policy_grads,10)

    δ = 0.01
    step_dir = sqrt.((2 * δ) ./ (x' * Hvp(states,actions,advantages,returns,log_probs,x))) .* x
    return step_dir
end

"""
Train
"""

# One training iteration: gather fresh rollouts, then run PPO_EPOCHS passes of
# minibatch TRPO updates. The index shuffle is drawn once and reused for every epoch.
function train_step()
    rollouts = get_rollouts()
    states, actions, rewards, advantages, returns, log_probs = process_rollouts(rollouts)

    batches = partition(shuffle(1:size(states)[end]), BATCH_SIZE)

    for _ in 1:PPO_EPOCHS, batch in batches
        trpo_update(states[:, batch], actions[:, batch], advantages[:, batch],
                    returns[:, batch], log_probs[:, batch])
    end
end

# Top-level loop: NUM_EPISODES calls to `train_step`, bracketed by progress prints.
function train()
    foreach(1:NUM_EPISODES) do episode
        println("EP : $episode")
        train_step()
        println("Ep done")
    end
end

┌ Info: activating new environment at /media/shreyas/Data/GSoC/TRPO.jl/Project.toml.
└ @ Pkg.API /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.1/Pkg/src/API.jl:519


UndefVarError: UndefVarError: gaussian_policy not defined

In [1]:
using Flux
using OpenAIGym
import Reinforce.action
import Reinforce:run_episode
import Flux.params
using Flux.Tracker: grad, update!
using Flux: onehot
using Statistics
using Distributed
using Distributions
using LinearAlgebra
using Base.Iterators
using Random
using BSON:@save,@load
using JLD



In [19]:
# TODO : DiagonalGaussian

"""
policies.jl
"""

# Bundle of an environment with the sizes the networks are built from.
#
# env         : the wrapped environment object (type parameter T)
# STATE_SIZE  : input width of the policy/value networks (see `CategoricalPolicy`)
# ACTION_SIZE : output width of the policy network, i.e. the number of actions sampled from
# Both sizes share the type parameter V.
mutable struct EnvWrap{T,V}
    env::T
    STATE_SIZE::V
    ACTION_SIZE::V
end

# Rescale rewards for the Pendulum environment: divide by 16.2736044 and add 2.
# Rewards for any other environment are returned unchanged.
#
# The `EnvWrap` defined in this notebook has no `name` field, so the name is consulted
# only when the wrapper exposes one; otherwise rewards pass through unchanged instead of
# raising a "no field name" error.
function scale_rewards(env_wrap::EnvWrap,rewards)
    if :name in propertynames(env_wrap) && env_wrap.name == :Pendulum
        rewards = rewards ./ 16.2736044 .+ 2.0f0
    end

    return rewards
end

"""
Define all policies here
"""

# Actor-critic container for discrete action spaces.
# Fields are untyped, so any callable networks can be stored.
mutable struct CategoricalPolicy
    π # Neural network for the policy
    value_net # Neural network for the value function
    env_wrap # A wrapper for environment variables
end

# Build a CategoricalPolicy for `env_wrap`, creating default networks for any that are
# not supplied:
#   policy: STATE_SIZE → 30 (relu) → ACTION_SIZE, followed by softmax
#   value : STATE_SIZE → 30 (relu) → 30 (relu) → 1
function CategoricalPolicy(env_wrap::EnvWrap, policy_net = nothing, value_net = nothing)
    n_in, n_out = env_wrap.STATE_SIZE, env_wrap.ACTION_SIZE

    if policy_net === nothing
        policy_net = Chain(Dense(n_in, 30, relu; initW = _random_normal, initb = constant_init),
                           Dense(30, n_out; initW = _random_normal, initb = constant_init),
                           x -> softmax(x))
    end

    if value_net === nothing
        value_net = Chain(Dense(n_in, 30, relu; initW = _random_normal),
                          Dense(30, 30, relu; initW = _random_normal),
                          Dense(30, 1; initW = _random_normal))
    end

    return CategoricalPolicy(policy_net, value_net, env_wrap)
end

"""
Define the following for each policy:
    `action` : a function taking the policy and a state and returning an action for the environment
    `log_prob` : a function giving the log probability of an action under the current policy parameters
    `entropy` : a function defining the entropy of the policy distribution

Populate each function with its appropriate distribution.
"""

function action(policy,state)
    # policy : a policy type defined in `policy.jl` (only CategoricalPolicy is supported)
    # state  : output of reset!(env) or step!(env,action)
    #
    # Returns a 0-based action id sampled from the policy's action probabilities.
    column = reshape(Array(state), length(state), 1)

    policy isa CategoricalPolicy || error("Policy type not yet implemented")

    n_actions = policy.env_wrap.ACTION_SIZE
    probs = reshape(policy.π(column).data, n_actions)
    # `sample` yields a 1-based index; subtract 1 for a 0-based action id.
    return sample(1:n_actions, Weights(probs)) - 1
end

# Log probability of each taken action under the current policy parameters.
#
# policy  : a CategoricalPolicy (other types raise "Not Implemented")
# states  : per-step states, indexed as `states[i]`
# actions : 1-based action indices aligned with `states`
#
# Returns a vector with one `log.(prob)` entry per step.
function log_prob(policy,states,actions)
    log_probs = []

    for i in eachindex(states)
        # `isa` tests the instance's type; `policy <: CategoricalPolicy` throws a
        # TypeError because `<:` expects a type on its left-hand side.
        if policy isa CategoricalPolicy
            action_probs = policy.π(states[i])
            prob = action_probs[actions[i],:].data
            push!(log_probs,log.(prob))
        else
            error("Not Implemented")
        end
    end

    return log_probs
end

# Entropy H = -Σₐ π(a|s) log π(a|s) of the policy's action distribution for each state.
#
# policy : a CategoricalPolicy whose `π` network outputs action probabilities
# states : batch passed straight to `policy.π` (presumably one column per state)
#
# Returns a row of entropies (`sum(...; dims=1)`). The two-argument form matches the call
# `entropy(policy,states)` in `loss`. The probabilities must come from evaluating the
# network; broadcasting over the network object itself is meaningless.
#
# NOTE(review): `Distributions` (imported in this notebook) also exports `entropy`; this
# definition shadows it in this module.
function entropy(policy, states)
    if policy isa CategoricalPolicy
        probs = policy.π(states)
        # The 1f-10 offset keeps log finite when a probability is exactly zero.
        return -sum(probs .* log.(probs .+ 1f-10), dims=1)
    else
        error("Not Implemented")
    end
end

# Return the trainable parameters of the policy (actor) network.
# Only CategoricalPolicy is handled so far; other policy types fall through to `nothing`.
function get_policy_params(policy)
    # `isa` checks the instance's type (`<:` expects a type and would throw here).
    if policy isa CategoricalPolicy
        return params(policy.π)
    end
end

# Return the policy (actor) network(s) wrapped in a vector.
# Only CategoricalPolicy is handled so far; other policy types fall through to `nothing`.
function get_policy_net(policy)
    # `isa` checks the instance's type (`<:` expects a type and would throw here).
    if policy isa CategoricalPolicy
        return [policy.π]
    end
end

ErrorException: invalid redefinition of constant CategoricalPolicy

In [1]:
using Pkg
Pkg.activate("../../Project.toml")

using Flux
using Gym
import Reinforce.action
import Reinforce:run_episode
import Flux.params
using Flux.Tracker: grad, update!
using Flux: onehot
using Statistics
using Distributed
using Distributions
using LinearAlgebra
using Base.Iterators
using Random
using BSON:@save,@load
using JLD

num_processes = 1
include("../common/policies.jl")
include("../common/utils.jl")
include("../common/buffer.jl")
include("rollout.jl")

┌ Info: activating new environment at /media/shreyas/Data/GSoC/TRPO.jl/Project.toml.
└ @ Pkg.API /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.1/Pkg/src/API.jl:519


collect_and_process_rollouts (generic function with 1 method)

In [2]:
# Create a Buffer with one registered slot per per-step quantity collected in an episode.
function initialize_episode_buffer()
    eb = Buffer()
    for key in ("states", "actions", "rewards", "next_states", "dones",
                "returns", "advantages", "log_probs")
        register(eb, key)
    end
    return eb
end

# Create a Buffer for logging statistics; `train` reads its "rollout_returns" slot.
function initialize_stats_buffer()
    stats = Buffer()
    register(stats, "rollout_returns")
    return stats
end

# Choose a policy type from the environment's action space:
# Discrete → CategoricalPolicy, Box → DiagonalGaussianPolicy, anything else errors.
function get_policy(env_wrap::EnvWrap)
    space = env_wrap.env._env.action_space
    if space isa Gym.Space.Discrete
        return CategoricalPolicy(env_wrap)
    elseif space isa Gym.Space.Box
        return DiagonalGaussianPolicy(env_wrap)
    end
    error("Policy type not supported")
end

#----------------Hyperparameters----------------#
# Environment Variables #
ENV_NAME = "CartPole-v0"
# Step cap passed to `collect_and_process_rollouts` by `train_step`.
EPISODE_LENGTH = 100
# Policy parameters #
η = 3e-4 # Learning rate
# NOTE(review): STD is not referenced in the cells shown.
STD = 0.0 # Standard deviation
# GAE parameters
γ = 0.99
λ = 0.95
# Optimization parameters
PPO_EPOCHS = 10
NUM_EPISODES = 100000
BATCH_SIZE = 256
# Loss weights used by `loss`: policy, value and entropy terms respectively.
c₀ = 1.0
c₁ = 1.0
c₂ = 0.001
# PPO parameters
# Clipping range for the probability ratio in `loss`.
ϵ = 0.2
# FREQUENCIES
SAVE_FREQUENCY = 50
VERBOSE_FREQUENCY = 5
global_step = 0

# Define policy
# NOTE(review): `EnvWrap(ENV_NAME)` uses a one-argument constructor, unlike the
# three-field EnvWrap draft in the policies.jl cell — presumably provided by
# ../common/policies.jl; confirm.
env_wrap = EnvWrap(ENV_NAME)
policy = get_policy(env_wrap)

# Define buffers
episode_buffer = initialize_episode_buffer()
stats_buffer = initialize_stats_buffer()

# Define optimizer
opt = ADAM(η)

ADAM(0.0003, (0.9, 0.999), IdDict{Any,Any}())

In [3]:
# Combined PPO objective to minimise:
#   -c₀ · clipped surrogate + c₁ · value MSE - c₂ · mean entropy
# where the surrogate is mean(min(r·A, clamp(r, 1-ϵ, 1+ϵ)·A)) and
# r = exp(new_log_probs - old_log_probs).
function loss(policy,states::Array,actions::Array,advantages::Array,returns::Array,old_log_probs::Array)
    ratio = exp.(log_prob(policy, states, actions) .- old_log_probs)
    clipped_ratio = clamp.(ratio, (1.0 - ϵ), (1.0 + ϵ))
    surrogate = mean(min.(ratio .* advantages, clipped_ratio .* advantages))

    critic_error = mean((policy.value_net(states) .- returns) .^ 2)
    bonus = mean(entropy(policy, states))

    return -c₀ * surrogate + c₁ * critic_error - c₂ * bonus
end

loss (generic function with 1 method)

In [4]:
# One optimiser step on the combined `loss`, updating actor and critic parameters
# jointly with the global optimiser `opt`.
function ppo_update(policy,states::Array,actions::Array,advantages::Array,returns::Array,old_log_probs::Array)
    ps = params(get_policy_params(policy)..., get_value_params(policy)...)
    objective() = loss(policy, states, actions, advantages, returns, old_log_probs)
    grads = Tracker.gradient(objective, ps)
    return update!(opt, ps, grads)
end

ppo_update (generic function with 1 method)

In [5]:
# One training iteration: refill the episode buffer with fresh rollouts, then run
# PPO_EPOCHS passes of minibatch PPO updates over a single shuffled index order.
function train_step()
    clear(episode_buffer)
    collect_and_process_rollouts(policy, episode_buffer, EPISODE_LENGTH, stats_buffer)

    data = episode_buffer.exp_dict
    batches = partition(shuffle(1:size(data["states"])[end]), BATCH_SIZE)

    for _ in 1:PPO_EPOCHS, mb in batches
        ppo_update(policy,
                   data["states"][:, mb],
                   data["actions"][:, mb],
                   data["advantages"][:, mb],
                   data["returns"][:, mb],
                   data["log_probs"][:, mb])
    end
end

train_step (generic function with 1 method)

In [6]:
# Run NUM_EPISODES training iterations, printing the iteration number and the mean of
# the logged rollout returns after each one.
function train()
    foreach(1:NUM_EPISODES) do iteration
        println(iteration)
        train_step()
        println(mean(stats_buffer.exp_dict["rollout_returns"]))
    end
end

train (generic function with 1 method)

In [7]:
# Start training; each iteration prints its number and the mean logged rollout return.
train()

1
10.733765292570691
2
8.73065678820008
3
8.666273437248606
4
9.609581457465662
5
10.091116830384955
6
10.053253043759009
7
9.578152363769075
8
9.335697576204243
9
9.197230719398695
10
10.492536008272543
11
10.514465943208739
12
10.460254957160451
13
10.243197965321219
14
10.057149115173305
15
9.835175726417049
16
9.612197191378899
17
9.627005245452452
18
9.41463988919647
19
9.24900195638648
20
9.572547373266705
21
9.586422932188734
22
9.371993354907653
23


InterruptException: InterruptException: