In [1]:
using Pkg
Pkg.activate("../../Project.toml")

using Flux
using Gym
import Reinforce.action
import Reinforce:run_episode
import Flux.params
using Flux.Tracker: grad, update!
using Flux: onehot
using Statistics
using Distributed
using Distributions
using LinearAlgebra
using Base.Iterators
using Random
using BSON
using BSON:@save,@load
using JLD

┌ Info: activating new environment at /media/shreyas/Data/GSoC/PPO.jl/Project.toml.
└ @ Pkg.API /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.1/Pkg/src/API.jl:519
┌ Info: Recompiling stale cache file /home/shreyas/.julia/compiled/v1.1/Flux/QdkVy.ji for Flux [587475ba-b771-5e3f-ad9e-33799f191a9c]
└ @ Base loading.jl:1184
┌ Info: Precompiling Gym [56b9baea-2481-11e9-37ae-75904354ad8c]
└ @ Base loading.jl:1186


imported




In [2]:
# Number of worker processes.
# NOTE(review): not referenced anywhere in the visible cells — presumably read by one of
# the included files below; confirm.
num_processes = 1
# Shared helper code. Names used later in this notebook but not defined in it (e.g.
# Buffer, register, clear, EnvWrap, CategoricalPolicy, DiagonalGaussianPolicy, log_prob,
# kl_divergence, get_flat_params, set_flat_params, conjugate_gradients, Hvp,
# collect_and_process_rollouts) presumably come from these files — verify.
include("../common/policies.jl")
include("../common/utils.jl")
include("../common/buffer.jl")
include("rollout.jl")
include("trpo_utils.jl")

"""
    initialize_episode_buffer()

Create a fresh `Buffer` and register, in order, every per-rollout field that the
training code stores: states, actions, rewards, next states, done flags, returns,
advantages, log-probabilities and the KL parameters.

Returns the newly created buffer.
"""
function initialize_episode_buffer()
    buffer = Buffer()

    # Registration order is kept identical to the original field list.
    field_names = (
        "states",
        "actions",
        "rewards",
        "next_states",
        "dones",
        "returns",
        "advantages",
        "log_probs",
        "kl_params",
    )
    for field in field_names
        register(buffer,field)
    end

    return buffer
end

"""
    initialize_stats_buffer()

Create a fresh `Buffer` holding a single registered field, `"rollout_returns"`,
and return it.
"""
function initialize_stats_buffer()
    stats = Buffer()
    register(stats,"rollout_returns")
    return stats
end

"""
    get_policy(env_wrap::EnvWrap)

Build the policy matching the action space of the wrapped Gym environment:

- `Gym.Space.Discrete` action space -> `CategoricalPolicy(env_wrap)`
- `Gym.Space.Box` action space      -> `DiagonalGaussianPolicy(env_wrap)`

Raises an error ("Policy type not supported") for any other action space type.
"""
function get_policy(env_wrap::EnvWrap)
    action_space = env_wrap.env._env.action_space

    # Guard clauses: return as soon as a supported action space type is recognised.
    action_space isa Gym.Space.Discrete && return CategoricalPolicy(env_wrap)
    action_space isa Gym.Space.Box && return DiagonalGaussianPolicy(env_wrap)

    error("Policy type not supported")
end

get_policy (generic function with 1 method)

In [3]:
# ----------------Hyperparameters----------------#
# Everything in this cell is a plain (non-const) global; the functions defined in the
# later cells read these names directly instead of taking them as arguments.
# Environment Variables #
ENV_NAME = "Pendulum-v0"
# Passed to collect_and_process_rollouts by train_step.
EPISODE_LENGTH = 100
# Policy parameters #
η = 3e-4 # Learning rate (used below for the value-function ADAM optimizer)
STD = 0.0 # Standard deviation (NOTE(review): not referenced in the visible cells — confirm where it is consumed)
# GAE parameters
# NOTE(review): γ and λ are not referenced in the visible cells — presumably read by
# the rollout/advantage code in the included files; confirm.
γ = 0.99
λ = 0.95
# Optimization parameters
δ = 0.01 # KL-Divergence constraint (read by linesearch and trpo_update)
# Number of passes train_step makes over the minibatches of one rollout.
PPO_EPOCHS = 10
# Number of train_step calls made by train().
NUM_EPISODES = 100000
# Minibatch size used when partitioning the shuffled sample indices.
BATCH_SIZE = 256
# NOTE(review): c₀, c₁, c₂ and ϵ are not referenced in the visible cells.
c₀ = 1.0
c₁ = 1.0
c₂ = 0.001
# PPO parameters
ϵ = 0.2
# FREQUENCIES
# NOTE(review): SAVE_FREQUENCY, VERBOSE_FREQUENCY and global_step are not referenced
# in the visible cells.
SAVE_FREQUENCY = 50
VERBOSE_FREQUENCY = 5
global_step = 0

# Define policy
env_wrap = EnvWrap(ENV_NAME)
policy = get_policy(env_wrap)

# Define buffers
episode_buffer = initialize_episode_buffer()
stats_buffer = initialize_stats_buffer()

# Define optimizer for optimizing value function neural network
opt_value = ADAM(η)

ADAM(0.0003, (0.9, 0.999), IdDict{Any,Any}())

In [73]:
"""
    kl_loss(policy, states::Array, kl_vars)

Return the mean of `kl_divergence(policy, kl_vars, states)`.

`kl_vars` is forwarded unchanged to `kl_divergence`.
NOTE(review): presumably it holds the distribution parameters recorded at rollout
time — confirm against `kl_divergence`.
"""
function kl_loss(policy,states::Array,kl_vars)
    divergences = kl_divergence(policy,kl_vars,states)
    return mean(divergences)
end

"""
    policy_loss(policy, states::Array, actions::Array, advantages::Array,
                old_log_probs::Array)

Compute the importance-weighted surrogate objective

    mean( exp(log π_new(a|s) - log π_old(a|s)) .* advantages )

where the new log-probabilities come from `log_prob(policy, states, actions)` and
`old_log_probs` are the log-probabilities recorded when the data was collected.

The previous version multiplied the advantages by the raw *difference of
log-probabilities* (stored in a variable named `ratio`) without exponentiating it.
The two forms have the same gradient only while the policy still equals the
data-collecting policy; once the parameters move (later minibatches/epochs, or
candidate steps evaluated by the line search) the log-difference is no longer the
likelihood ratio.

Returns the scalar surrogate value (callers read `.data` from it, so it is expected
to remain a tracked value).
"""
function policy_loss(policy,states::Array,actions::Array,advantages::Array,old_log_probs::Array)
    # Log-probabilities of the taken actions under the current policy parameters.
    new_log_probs = log_prob(policy,states,actions)

    # Likelihood ratio π_new / π_old, computed in log-space then exponentiated.
    ratio = exp.(new_log_probs .- old_log_probs)

    return mean(ratio .* advantages)
end

"""
    value_loss(policy, states::Array, returns::Array)

Mean squared error between the predictions `policy.value_net(states)` and the
target `returns`.
"""
function value_loss(policy,states::Array,returns::Array)
    residuals = policy.value_net(states) .- returns
    return mean(residuals .^ 2)
end

"""
    linesearch(policy, step_dir, states, actions, advantages, old_log_probs, kl_vars,
               num_steps=10; α=0.5)

Backtracking line search along `step_dir` for the policy network parameters.

For `i = 1, …, num_steps` the candidate `old_params .+ α^i .* step_dir` is written
into the policy network and evaluated. A candidate is accepted when the surrogate
(`policy_loss`) is strictly larger than at the starting parameters AND the mean KL
divergence (`kl_loss`) does not exceed the global `δ`.

Returns `true` when a candidate was accepted (the policy network then keeps those
parameters) and `false` when none was (the starting parameters are restored).

Fix relative to the previous version: the loop kept going after an acceptable
candidate and the function then unconditionally restored `old_params`, so the policy
parameters could never change. The search now stops at the first (i.e. largest)
acceptable step and only restores the starting parameters on failure.

NOTE(review): the first candidate uses `α^1`, so the full step `step_dir` itself is
never tried — confirm this is intended.
NOTE(review): this assumes `get_flat_params` returns a copy that is not modified by
later `set_flat_params` calls — confirm.
"""
function linesearch(policy,step_dir,states,actions,advantages,old_log_probs,kl_vars,num_steps=10;α=0.5)
    old_loss = policy_loss(policy,states,actions,advantages,old_log_probs).data
    old_params = get_flat_params(get_policy_net(policy))

    for i in 1:num_steps
        # Candidate parameters: geometrically shrinking fraction of the proposed step.
        new_params = old_params .+ ((α^i) .* step_dir)

        # The surrogate and the KL are evaluated through the network itself, so the
        # candidate has to be written into the policy before measuring.
        set_flat_params(new_params,get_policy_net(policy))

        new_loss = policy_loss(policy,states,actions,advantages,old_log_probs).data
        kl_div = kl_loss(policy,states,kl_vars).data

        if new_loss > old_loss && (kl_div <= δ)
            # Candidate is already installed in the network: keep it and stop.
            return true
        end
    end

    # No candidate satisfied both conditions: roll back to the starting parameters.
    set_flat_params(old_params,get_policy_net(policy))
    return false
end

"""
    trpo_update(policy, states, actions, advantages, returns, log_probs, kl_vars)

Run one update on a batch:

1. differentiate `policy_loss` with respect to the policy parameters and flatten
   the gradient;
2. call `conjugate_gradients` with `Hvp` and the negated flat gradient (10 is passed
   as the last argument) to obtain a search direction;
3. scale that direction by `sqrt(2δ / (xᵀ·Hvp(x)))`;
4. hand the scaled step to `linesearch`;
5. take one `opt_value` optimiser step on `value_loss` for the value parameters.

Reads the globals `δ` and `opt_value`. Returns whatever `update!` returns.
"""
function trpo_update(policy,states,actions,advantages,returns,log_probs,kl_vars)
    # --- policy gradient -------------------------------------------------------
    policy_ps = get_policy_params(policy)
    surrogate_grads = Tracker.gradient(policy_ps) do
        policy_loss(policy,states,actions,advantages,log_probs)
    end
    flat_grad = get_flat_grads(surrogate_grads,get_policy_net(policy)).data

    # --- search direction and step scaling --------------------------------------
    search_dir = conjugate_gradients(policy,states,kl_vars,Hvp,-1.0 .* flat_grad,10)
    curvature = search_dir' * Hvp(policy,states,kl_vars,search_dir)
    step_scale = sqrt.((2 * δ) ./ curvature)
    scaled_step = step_scale .* search_dir

    # --- policy parameter update via line search --------------------------------
    linesearch(policy,scaled_step,states,actions,advantages,log_probs,kl_vars)

    # --- value function update --------------------------------------------------
    value_ps = get_value_params(policy)
    value_grads = Tracker.gradient(value_ps) do
        value_loss(policy,states,returns)
    end
    return update!(opt_value,value_ps,value_grads)
end

trpo_update (generic function with 1 method)

In [63]:
"""
    train_step()

Collect one fresh set of rollouts and run the TRPO update over it.

The global `episode_buffer` is cleared and refilled by
`collect_and_process_rollouts`. Sample indices (taken along the last dimension of
the stored states) are shuffled once and split into chunks of `BATCH_SIZE`; the same
chunks are then visited `PPO_EPOCHS` times, calling `trpo_update` on each.

Operates entirely on globals (`policy`, `episode_buffer`, `stats_buffer`,
`EPISODE_LENGTH`, `BATCH_SIZE`, `PPO_EPOCHS`) and returns `nothing`.
"""
function train_step()
    clear(episode_buffer)
    collect_and_process_rollouts(policy,episode_buffer,EPISODE_LENGTH,stats_buffer)

    num_samples = last(size(episode_buffer.exp_dict["states"]))
    minibatches = partition(shuffle(1:num_samples),BATCH_SIZE)

    for _ in 1:PPO_EPOCHS, batch in minibatches
        # Select the columns of a stored field that belong to this minibatch.
        columns(key) = episode_buffer.exp_dict[key][:,batch]

        batch_states = columns("states")
        batch_actions = columns("actions")
        batch_advantages = columns("advantages")
        batch_returns = columns("returns")
        batch_log_probs = columns("log_probs")
        # "kl_params" is indexed directly with the batch indices, not by column.
        batch_kl_vars = episode_buffer.exp_dict["kl_params"][batch]

        trpo_update(
            policy,
            batch_states,
            batch_actions,
            batch_advantages,
            batch_returns,
            batch_log_probs,
            batch_kl_vars,
        )
    end
end

"""
    train()

Run `train_step()` `NUM_EPISODES` times, printing the 1-based iteration number
before each step. Returns `nothing`.
"""
function train()
    foreach(1:NUM_EPISODES) do iteration
        println(iteration)
        train_step()
    end
end

train (generic function with 1 method)

In [75]:
# Start training: calls train_step() up to NUM_EPISODES times. The recorded output
# below ends in an InterruptException, i.e. this run was stopped manually.
train()

1
[0.0271148] (tracked)
1
[0.0163902] (tracked)
0.0672244845767022
0.0
-0.04306243018707241
2
[0.0217525] (tracked)
0.007863635989935109
0.0
-0.017083452819085496
3
[0.0244336] (tracked)
0.0009252320712554929
0.0
-0.007093330147669706
4
[0.0257742] (tracked)
0.00011227930553915133
0.0
-0.0029040975922650283
5
[0.0264445] (tracked)
1.677673288771886e-5
0.0
-0.0012419467774290565
6
[0.0267796] (tracked)
3.1245598895490724e-6
0.0
-0.0005620099318288073
7
[0.0269472] (tracked)
6.689130622167249e-7
0.0
-0.00026548310947069494
8
[0.027031] (tracked)
1.5170308272138834e-7
0.0
-0.0001287683146364951
9
[0.0270729] (tracked)
3.29785660052595e-8
0.0
-6.338020083998606e-5
10
[0.0270938] (tracked)
4.5244595081594954e-9
0.0
-3.1437703516758945e-5
[0.0271148] (tracked)
1
[0.0163902] (tracked)
0.0672244845767022
0.0
-0.04306243018707241
2
[0.0217525] (tracked)
0.007863635989935109
0.0
-0.017083452819085496
3
[0.0244336] (tracked)
0.0009252320712554929
0.0
-0.007093330147669706
4
[0.0257742] (tracked)


InterruptException: InterruptException:

In [65]:
# Scratch cell: rebuild the minibatch index partition exactly as train_step does —
# shuffle the sample indices (last dimension of the stored states) and split them
# into chunks of BATCH_SIZE.
idxs = partition(shuffle(1:size(episode_buffer.exp_dict["states"])[end]),BATCH_SIZE)

Base.Iterators.PartitionIterator{Array{Int64,1}}([49, 10, 62, 18, 85, 63, 52, 87, 70, 79  …  15, 38, 98, 67, 91, 95, 77, 20, 4, 27], 256)

In [74]:
# Scratch cell: one pass over the minibatches in `idxs`, mirroring the inner loop of
# train_step (no rollout collection and no epoch loop here).
for i in idxs
    # Column-slice every stored field with the indices of this minibatch.
    mb_states = episode_buffer.exp_dict["states"][:,i] 
    mb_actions = episode_buffer.exp_dict["actions"][:,i] 
    mb_advantages = episode_buffer.exp_dict["advantages"][:,i] 
    mb_returns = episode_buffer.exp_dict["returns"][:,i] 
    mb_log_probs = episode_buffer.exp_dict["log_probs"][:,i]
    # "kl_params" is indexed directly with the minibatch indices, not by column.
    mb_kl_vars = episode_buffer.exp_dict["kl_params"][i]

    trpo_update(policy,mb_states,mb_actions,mb_advantages,mb_returns,mb_log_probs,mb_kl_vars)
end

[0.0271148] (tracked)
1
[0.0214509] (tracked)
0.28235330333371783
0.0
-0.020369339185023407
2
[0.0242829] (tracked)
0.029566831754306248
0.0
-0.0006483389404778503
3
[0.0256988] (tracked)
0.004127729653248764
0.0
-0.0012264620685809933
4
[0.0264068] (tracked)
0.0005165995353730601
0.0
-0.00174077584381878
5
[0.0267608] (tracked)
4.065471548573285e-5
0.0
-0.0011430875293847258
6
[0.0269378] (tracked)
3.98564437463722e-6
0.0
-0.0006328873572621704
7
[0.0270263] (tracked)
6.636894742007015e-7
0.0
-0.0003313537789526239
8
[0.0270705] (tracked)
1.467643970515242e-7
0.0
-0.0001694026168658624
9
[0.0270927] (tracked)
3.260467021037883e-8
0.0
-8.564096522885785e-5
10
[0.0271037] (tracked)
4.613465157676799e-9
0.0
-4.305914912324087e-5


In [64]:
# Scratch cell: horizontally concatenate the first and the second component of every
# entry of `kl_vars`.
# NOTE(review): `kl_vars` is not defined in any visible cell; going by the variable
# names the components are presumably the policy's means and log standard
# deviations — confirm.
mus = hcat([kl_vars[i][1] for i in 1:length(kl_vars)]...)
logsigmas = hcat([kl_vars[i][2] for i in 1:length(kl_vars)]...)

1×300 Array{Float64,2}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0

In [48]:
# Scratch cell: inspect the flattened policy-network parameters; `.data` takes the
# underlying array (the recorded output below is a 152×1 Array{Float64,2}).
old_params = get_flat_params(get_policy_net(policy)).data

152×1 Array{Float64,2}:
 -0.0030182378832250834
  0.014334198087453842 
  0.007296884898096323 
 -0.0030412396881729364
  0.023370012640953064 
  0.0070809233002364635
 -0.012167193926870823 
  0.029113518074154854 
 -0.01813005842268467  
 -0.014507287181913853 
 -0.002038281876593828 
 -0.03368284925818443  
 -0.006013437639921904 
  ⋮                    
  0.052237290889024734 
  0.025447048246860504 
  0.02046864479780197  
  0.049293071031570435 
 -0.02386072650551796  
 -0.013537056744098663 
  0.048183321952819824 
  0.026413224637508392 
 -0.004033444449305534 
  0.035017382353544235 
 -0.19217334687709808  
  0.027114788070321083 