In [25]:
using OpenAIGym
using Flux
import Statistics: mean
import StatsBase:sample 
using Flux: onehotbatch
import Distributions: Multinomial

In [26]:
env = GymEnv(:Acrobot, :v1)

# env parameters
n_actions = length(env.actions)
n_states = length(env.state)

#models
mlpvalue = Chain(Dense(n_states, 64, tanh), Dense(64, 64, tanh), Dense(64, 1))
mlppolicy = Chain(Dense(n_states, 64, tanh), Dense(64, 64, tanh), Dense(64, n_actions))

Chain(Dense(6, 64, tanh), Dense(64, 64, tanh), Dense(64, 3))

In [27]:
function mlp_categorical_policy(x, hidden_sizes, activation, output_activation, act_dim, a = nothing)    
    logits = mlppolicy(x)
    logp_all = logsoftmax(logits)
    
    if typeof(logits.data) != Array{Float64,2}
        dist = Multinomial(1, softmax(logits.data))
        sample = rand(dist, 1)

        pi = Flux.argmax(dropdims(sample, dims=2))
        if a == nothing
            logp = nothing
        else
            logp = sum(Flux.onehot(a, 1:act_dim) .* logp_all, dims=1)
        end
        logp_pi = sum(Flux.onehot(pi, 1:act_dim) .* logp_all, dims = 1)
        return pi, logp, logp_pi
    else
        pi = []
        for i in 1:size(logits.data)[2]
            k = logits[:, i]
            dist = Multinomial(1, softmax(k.data))
            sample = rand(dist, 1)
            push!(pi, Flux.argmax(dropdims(sample, dims=2)))
        end
        if a == nothing
            logp = nothing
        else
            logp = sum(Flux.onehotbatch(a, 1:act_dim) .* logp_all, dims = 1)
        end
        logp_pi = sum(Flux.onehotbatch(pi, 1:act_dim) .* logp_all, dims = 1)
        return pi, logp', logp_pi'
    end
end

mlp_categorical_policy (generic function with 2 methods)

In [28]:
function mlp_actor_critic(x, act_dim, a = nothing, hidden_sizes=(64,64), 
                          activation=tanh, output_activation=nothing, 
                          policy=nothing)

    # default policy builder depends on action space
    if policy == nothing 
        policy = mlp_categorical_policy
    end
    
    pi, logp, logp_pi = policy(x, hidden_sizes, activation, output_activation, act_dim, a)
    
    v = mlpvalue(x)
    return pi, logp, logp_pi, v
    end

mlp_actor_critic (generic function with 6 methods)

In [29]:
function pi_loss(x_ph, act_dim, a_ph, steps_per_epoch, adv_ph, logp_old_ph)
    pi, logp, logp_pi, v = mlp_actor_critic(x_ph, act_dim, a_ph)
    ratio = exp.(logp - logp_old_ph)
    min_adv = Matrix{Float32}(undef, steps_per_epoch, 1)
    for i in 1:length(adv_ph)
        if adv_ph[i, 1] > 0
            min_adv[i, 1] = (1+0.2)*adv_ph[i, 1]
        else
            min_adv[i, 1] = (1-0.2)*adv_ph[i, 1]
        end
    end
    return -mean(min.(ratio .* adv_ph, min_adv))
end

function v_loss(x_ph, act_dim, a_ph, ret_ph)
    pi, logp, logp_pi, v = mlp_actor_critic(x_ph, act_dim, a_ph)
    return mean((dropdims(ret_ph, dims = 2) - dropdims(v, dims = 1)).^2)
end

function update(x_ph, a_ph, adv_ph, ret_ph, logp_old_ph, 
                act_dim, train_pi_iters, train_v_iters, target_kl, steps_per_epoch, pi_lr, vf_lr) 
    train_pi = Flux.ADAM(Flux.params(mlppolicy), pi_lr)
    train_v = Flux.ADAM(Flux.params(mlpvalue), vf_lr)


    # Training
    for i in 1:train_pi_iters
        Flux.train!(pi_loss, [(x_ph, act_dim, a_ph, steps_per_epoch, adv_ph, logp_old_ph)], train_pi)
    end
        
    for _ in 1:train_v_iters
        Flux.train!(v_loss, [(x_ph, act_dim, a_ph, ret_ph)], train_v)

    end
end

update (generic function with 1 method)

In [30]:
function discount_cumsum(arr, discount)
    """
    magic from rllab for computing discounted cumulative sums of vectors.
    input: 
        vector x, 
        [x0, 
         x1, 
         x2]
    output:
        [x0 + discount * x1 + discount^2 * x2,  1
         x1 + discount * x2,                    2
         x2]                                    3
    """
    rtg = Array{Float32}(undef, length(arr))
    for i in length(arr):-1:1
        if i == length(arr)
            rtg[i] = arr[i]
        else
            rtg[i] = arr[i] + discount * rtg[i+1]
        end
    end
    return rtg
end


discount_cumsum (generic function with 1 method)

In [31]:
epoch_performance = []
function ppo(env, actor_critic=mlp_actor_critic, ac_kwargs=Dict(), seed=0, 
        steps_per_epoch=4000, epochs=25, gamma=0.99, clip_ratio=0.2, pi_lr=3e-4,
        vf_lr=1e-3, train_pi_iters=80, train_v_iters=80, lam=0.97, max_ep_len=1000,
        target_kl=0.01)
    
    Random.seed!(1234567)
    
    act_dim = length(env.actions)
    obs_dim = length(env.state)
    
    #initializing buffer
    obs_buf = Matrix{Float32}(undef, obs_dim, steps_per_epoch)
    act_buf = zeros(Float32, steps_per_epoch)
    adv_buf = Matrix{Float32}(undef, steps_per_epoch, 1)
    rew_buf = Matrix{Float32}(undef, steps_per_epoch, 1)
    ret_buf = Matrix{Float32}(undef, steps_per_epoch, 1)
    val_buf = Matrix{Float32}(undef, steps_per_epoch, 1)
    logp_buf = Matrix{Float32}(undef, steps_per_epoch, 1)
    
    
    
    ptr, path_start_idx, max_size = 1, 1, steps_per_epoch
    
    adv_ph, ret_ph, logp_old_ph = nothing, nothing, nothing

    o, r, d, ep_ret, ep_len = reset!(env), 0, false, 0, 0
    for epoch ∈ 1:epochs
        epoch_p = []
        for t ∈ 1: steps_per_epoch
            
            x_ph = reshape(o, :, 1)
            temp = []
            for i in 1:length(x_ph)
                push!(temp, x_ph[i])
            end
            x_ph = temp
            a, _, logp_t, v_t = mlp_actor_critic(x_ph, act_dim)
            v_t = v_t.data
            logp_t = logp_t.data

            # store in buffer
            obs_buf[:, ptr] = convert(Array{Float32},o)
            
            act_buf[ptr] = a 
            rew_buf[ptr] = r
            val_buf[ptr] = v_t[1]
            logp_buf[ptr] = logp_t[1]
            ptr += 1
            
            r, o = step!(env, a-1)
            
            ep_ret += r
            ep_len += 1
            
            terminal = env.done | (ep_len == max_ep_len)
            if terminal || (t==steps_per_epoch)
                
                if ~terminal
                    println("Warning: trajectory cut off by epoch at $ep_len steps.")
                end
                # if trajectory didn't reach terminal state, bootstrap value target
                if env.done
                    last_val = r
                else
                    temp = []
                    temp1 = reshape(o, :, 1)
                    for i in 1:length(temp1)
                        push!(temp, temp1[i])
                    end
                    _, _, _, last_val = mlp_actor_critic(temp, act_dim)
                    last_val = last_val.data
                end
                rews = push!(rew_buf[path_start_idx: ptr-1], last_val[1])
                vals = push!(val_buf[path_start_idx: ptr-1], last_val[1])
                
        
                # the next two lines implement GAE-Lambda advantage calculation
                deltas = rews[1:end-1] + gamma * vals[2:end] - vals[1:end-1]
                adv_buf[path_start_idx: ptr-1] = discount_cumsum(deltas, gamma * lam)
        
                # the next line computes rewards-to-go, to be targets for the value function
                ret_buf[path_start_idx: ptr-1] = discount_cumsum(rews, gamma)[1:end-1]
                
                path_start_idx = ptr
                
#                 if terminal
#                     println("Reward $ep_ret , Length $ep_len")
#                 end
                push!(epoch_p, ep_ret)
                o, r, ep_ret, ep_len = reset!(env), 0, false, 0, 0
            end
            
        end
        ptr, path_start_idx = 1, 1
        println(epoch, " ", mean(epoch_p[1:end-1]), " ", length(epoch_p))
        push!(epoch_performance, mean(epoch_p[1:end-1]))
        update(obs_buf, act_buf, adv_buf, ret_buf, logp_buf, act_dim, 
               train_pi_iters, train_v_iters, target_kl, steps_per_epoch,
               pi_lr, vf_lr)   
    end 
end

ppo (generic function with 15 methods)

In [None]:
env = GymEnv(:Acrobot, :v1)
ppo(env)



│   caller = mlp_categorical_policy(::Array{Any,1}, ::Tuple{Int64,Int64}, ::Function, ::Nothing, ::Int64, ::Nothing) at In[27]:9
└ @ Main ./In[27]:9


1 -491.125 9


│   caller = mlp_categorical_policy(::Array{Float32,2}, ::Tuple{Int64,Int64}, ::Function, ::Nothing, ::Int64, ::Array{Float32,1}) at In[27]:23
└ @ Main ./In[27]:23


2 -466.5 9
3 -493.625 9
4 -485.75 9
5 -477.125 9
6 -486.625 9
7 -361.5 11
8 -437.125 9
9 -481.0 9
10 -402.0 10
11 -483.375 9
12 -467.375 9
13 -497.125 9
14 -485.875 9
15 -411.0 10
16 -358.09090909090907 12
17 -299.0 14
18 -214.33333333333334 19

In [8]:
env = GymEnv(:CartPole, :v0)
ppo(env)

  likely near /Users/parthshah/.julia/packages/IJulia/DL02A/src/kernel.jl:41
  likely near /Users/parthshah/.julia/packages/IJulia/DL02A/src/kernel.jl:41
  likely near /Users/parthshah/.julia/packages/IJulia/DL02A/src/kernel.jl:41
  likely near /Users/parthshah/.julia/packages/IJulia/DL02A/src/kernel.jl:41
│   caller = mlp_categorical_policy(::Array{Any,1}, ::Tuple{Int64,Int64}, ::Function, ::Nothing, ::Int64, ::Nothing) at In[3]:9
└ @ Main ./In[3]:9


1 30.16030534351145 132Any[14.0, 64.0, 38.0, 41.0, 37.0, 22.0, 15.0, 26.0, 37.0, 48.0, 23.0, 31.0, 15.0, 15.0, 96.0, 13.0, 22.0, 66.0, 18.0, 12.0, 42.0, 17.0, 23.0, 22.0, 17.0, 19.0, 13.0, 58.0, 37.0, 23.0, 15.0, 34.0, 11.0, 42.0, 16.0, 30.0, 35.0, 57.0, 14.0, 17.0, 21.0, 37.0, 25.0, 22.0, 27.0, 30.0, 26.0, 67.0, 25.0, 18.0, 37.0, 20.0, 27.0, 28.0, 46.0, 19.0, 15.0, 37.0, 16.0, 23.0, 21.0, 36.0, 30.0, 22.0, 86.0, 22.0, 27.0, 17.0, 15.0, 54.0, 38.0, 17.0, 94.0, 23.0, 32.0, 37.0, 15.0, 47.0, 39.0, 16.0, 19.0, 9.0, 21.0, 63.0, 73.0, 22.0, 37.0, 21.0, 37.0, 37.0, 16.0, 45.0, 22.0, 51.0, 24.0, 22.0, 37.0, 64.0, 36.0, 52.0, 15.0, 22.0, 14.0, 13.0, 67.0, 25.0, 19.0, 20.0, 26.0, 29.0, 19.0, 36.0, 18.0, 34.0, 37.0, 16.0, 11.0, 24.0, 13.0, 43.0, 16.0, 43.0, 30.0, 39.0, 38.0, 21.0, 31.0, 16.0, 14.0, 22.0, 23.0, 49.0]


│   caller = mlp_categorical_policy(::Array{Float32,2}, ::Tuple{Int64,Int64}, ::Function, ::Nothing, ::Int64, ::Array{Float32,1}) at In[3]:23
└ @ Main ./In[3]:23


2 50.23376623376623 78Any[35.0, 35.0, 36.0, 33.0, 79.0, 19.0, 21.0, 81.0, 38.0, 29.0, 32.0, 29.0, 88.0, 41.0, 129.0, 51.0, 58.0, 37.0, 100.0, 103.0, 23.0, 108.0, 44.0, 36.0, 24.0, 99.0, 84.0, 13.0, 58.0, 29.0, 52.0, 12.0, 46.0, 27.0, 60.0, 21.0, 72.0, 20.0, 113.0, 18.0, 108.0, 110.0, 58.0, 51.0, 64.0, 55.0, 60.0, 60.0, 33.0, 26.0, 56.0, 33.0, 121.0, 64.0, 32.0, 83.0, 36.0, 85.0, 29.0, 36.0, 33.0, 60.0, 49.0, 24.0, 23.0, 62.0, 14.0, 16.0, 50.0, 21.0, 56.0, 33.0, 60.0, 15.0, 72.0, 26.0, 21.0, 132.0]
3 147.37037037037038 28Any[170.0, 134.0, 112.0, 88.0, 164.0, 144.0, 200.0, 198.0, 128.0, 147.0, 40.0, 111.0, 200.0, 170.0, 34.0, 142.0, 103.0, 200.0, 141.0, 41.0, 200.0, 200.0, 200.0, 157.0, 200.0, 162.0, 193.0, 21.0]
4 185.66666666666666 22Any[200.0, 161.0, 200.0, 187.0, 200.0, 200.0, 187.0, 139.0, 140.0, 200.0, 200.0, 200.0, 178.0, 198.0, 200.0, 168.0, 200.0, 200.0, 184.0, 194.0, 163.0, 101.0]
5 187.95238095238096 22Any[200.0, 200.0, 178.0, 200.0, 200.0, 180.0, 171.0, 200.0, 200.0, 200.0, 2