In [None]:
"""
- Rollouts
- Define networks
- GAE
- Utils : Normal distributions and log probability of an action
- train
- logging utilities

Experiment with shared network for both policy and value function
"""

using Flux, CuArrays
using OpenAIGym
import Reinforce.action
import Reinforce:run_episode
import Flux.params
using Flux.Tracker: grad, update!
using Flux: onehot
using Statistics
using Distributed
using Distributions
using LinearAlgebra
using Base.Iterators
using BSON:@save,@load
using JLD

"""
A few intricacies : 
The policy is a Normal distribution: the `policy_μ` network outputs the mean `μ`, while the log standard deviation `logσ` is a separate trainable parameter (`policy_Σ`).
Each action is assumed to be independent of the others.
Thus our covariance matrix is a diagonal matrix with each element representing the variance of
taking a particular action.
"""

"""
Utilities
"""
# weight initialization
# Initialiser that samples every entry from Normal(0, 0.1) and narrows the result to Float32.
function _random_normal(shape...)
    samples = rand(Normal(0, 0.1), shape...)
    return Float32.(samples)
end

# Initialiser that returns a Float32 array of the requested shape filled with 0.1.
function constant_init(shape...)
    return fill(Float32(0.1), shape...)
end

"""
    normal_log_prob(μ, log_std, a)

Return the element-wise log-density of `a` under independent Gaussians with means `μ`
and log standard deviations `log_std`:

    log p(a) = -(a - μ)² / (2σ²) - ½·log(2π) - log σ,    σ = exp(log_std)

The three arguments are broadcast against each other, so the result has the broadcast
shape of `μ`, `log_std` and `a`.
"""
function normal_log_prob(μ,log_std,a)
    σ = exp.(log_std)
    σ² = σ.^2
    # The normalisation constant is ½·log(2π) (= log √(2π)); `log σ` is `log_std` itself.
    return -(((a .- μ).^2) ./ (2 .* σ²)) .- 0.5 * log(2π) .- log_std
end

# Entropy of a Gaussian given its log standard deviation: 0.5 + 0.5*log(2π) + log_std,
# broadcast over `log_std`.
function normal_entropy(log_std)
    gaussian_const = 0.5 + 0.5 * log(2 * π)
    return gaussian_const .+ log_std
end

"""
    normalise(arr)

Standardise `arr` to zero mean and (approximately) unit sample variance.

A small constant (1e-10) is added to the variance to avoid division by zero. Inputs with
fewer than two elements have no defined sample variance, so they are returned as zeros
instead of NaN.
"""
function normalise(arr)
    # `var` of a single element is NaN, which would otherwise propagate into the result.
    length(arr) <= 1 && return map(zero, arr)
    centred = arr .- mean(arr)
    return centred ./ sqrt(var(arr) + 1e-10)
end

# Logging #
"""
Create logging utility
"""
mutable struct Logger
    # Maps a registered variable name to the list of values recorded for it.
    hist_dict

    # With no argument the logger starts from an empty, untyped dictionary.
    Logger(hist_dict = Dict()) = new(hist_dict)
end

# Start (or reset) the history stored under `name` with an empty list and return that list.
function register(l::Logger, name::String)
    fresh_history = []
    l.hist_dict[name] = fresh_history
    return fresh_history
end

"""
    add(l::Logger, name, value)

Append `value` to the history stored under `name` and return that history.

Throws an `ErrorException` when `name` has no entry in `l.hist_dict` (entries are created
by `register`).
"""
function add(l::Logger,name,value)
    # `haskey` is the public membership test; the Dict's internal `keys` field is a slot
    # array that must not be searched directly.
    haskey(l.hist_dict, name) ||
        error("Logger: no history registered under \"$name\"; call `register` first")
    return push!(l.hist_dict[name], value)
end

"""
HYPERPARAMETERS
"""
# Policy parameters #
η = 3e-4 # Learning rate handed to the ADAM optimiser
STD = 0.0 # Initial fill value of `policy_Σ`, which the code treats as a LOG standard deviation (0.0 => σ = 1)
HIDDEN_SIZE = 256 # Width of the hidden layer in both the policy and the value network
# Environment Variables #
STATE_SIZE = 3 # Input size of the networks
ACTION_SIZE = 1 # Output size of the policy network / length of an action vector
# NOTE(review): MIN_RANGE / MAX_RANGE are not referenced by the visible code — presumably the
# action bounds of the environment; confirm.
MIN_RANGE = -2.0f0
MAX_RANGE = 2.0f0
EPISODE_LENGTH = 20 # Maximum number of steps collected per rollout in `run_episode`
TEST_STEPS = 10000 # Maximum number of steps in `test_run`
# GAE parameters
γ = 0.99 # Discount factor (used by `gae` and `disconunted_returns`)
λ = 0.95 # GAE decay applied together with γ in `gae`
# Optimization parameters
PPO_EPOCHS = 4 # Passes over the collected batch per `train_step`
NUM_EPISODES = 2000 # Number of `train_step` iterations run by `train`
BATCH_SIZE = 5 # Number of columns (time steps) per minibatch
c₀ = 1.0 # Weight of the policy term in `loss`
c₁ = 0.5 # Weight of the value term in `loss`
c₂ = 0.001 # Entropy weight; only appears in a commented-out term of `loss`
# PPO parameters
ϵ = 0.2 # Clipping range: the probability ratio is clamped to [1 - ϵ, 1 + ϵ]
# FREQUENCIES
SAVE_FREQUENCY = 50 # `train` writes checkpoints every this many iterations
VERBOSE_FREQUENCY = 50 # Only referenced from commented-out logging in `loss`
global_step = 0 # Incremented on every call of `loss`
# policy_type = "gaussian"

# Mean return of each processed batch of rollouts; appended to by `process_rollouts`.
reward_hist = []

# Policy type carrying a single `train` flag; the flag defaults to `true`.
mutable struct PendulumPolicy <: Reinforce.AbstractPolicy
  train::Bool

  PendulumPolicy(train = true) = new(train)
end

"""
Define the networks
"""
# if policy_type == "gaussian"
# Mean network: STATE_SIZE -> HIDDEN_SIZE (relu) -> ACTION_SIZE, squashed by tanh and then
# multiplied by 2.0.
# NOTE(review): `param(2.0)` is constructed anew inside the closure on every forward pass,
# so it looks like a fixed output scale rather than a trainable weight — confirm the intent.
policy_μ = Chain(Dense(STATE_SIZE,HIDDEN_SIZE,relu;initW = _random_normal,initb=constant_init),
                 Dense(HIDDEN_SIZE,ACTION_SIZE;initW = _random_normal,initb=constant_init),
                 x->tanh.(x),
                 x->param(2.0) .* x) 
# State-independent trainable vector, one entry per action dimension; `action` and `loss`
# read it as the log standard deviation of the policy.
policy_Σ = param(ones(ACTION_SIZE) * STD)

# Value network: STATE_SIZE -> HIDDEN_SIZE (relu) -> 1.
value = Chain(Dense(STATE_SIZE,HIDDEN_SIZE,relu),
                  Dense(HIDDEN_SIZE,1))

# elseif policy_type == "linear"
#     policy = Chain(Dense(STATE_SIZE,HIDDEN_SIZE,relu;initW = _random_normal,initb=constant_init),
#                      Dense(HIDDEN_SIZE,ACTION_SIZE;initW = _random_normal,initb=constant_init),
#                      x->tanh.(x))
    
#     value = Chain(Dense(STATE_SIZE,HIDDEN_SIZE,relu),
#                   Dense(HIDDEN_SIZE,1))
# end

# Optimizer
# A single ADAM instance with learning rate η; `ppo_update` uses it for all parameters.
opt = ADAM(η)

"""
    action(state)

Sample one action for `state` from the current Gaussian policy.

The mean is `policy_μ(state)` and the per-dimension log standard deviation is `policy_Σ`;
the action dimensions are treated as independent (diagonal covariance matrix).

Returns an `ACTION_SIZE × 1` matrix holding the sampled action.
"""
function action(state)
    # Copy the state into a plain column matrix, whatever container type it arrived in.
    state = reshape(Array(state),length(state),1)

    μ = reshape(policy_μ(state),ACTION_SIZE)
    log_std = policy_Σ

    σ² = (exp.(log_std)).^2
    Σ = diagm(0=>σ².data)

    dis = MvNormal(μ.data,Σ)

    # Draw exactly one sample. Asking for `ACTION_SIZE` samples would yield an
    # ACTION_SIZE × ACTION_SIZE matrix, which is only a valid action when ACTION_SIZE == 1.
    return reshape(rand(dis),ACTION_SIZE,1)
end

# Roll out the sampling policy (`action`) in `env` for at most EPISODE_LENGTH steps.
# Returns a list of `(state, action, reward, next_state)` tuples, one per step taken.
function run_episode(env)
    transitions = []
    state = reset!(env)

    for _ in 1:EPISODE_LENGTH
        # The environment is stepped with a Float64 vector of length ACTION_SIZE.
        act = reshape(convert.(Float64, action(state)), ACTION_SIZE)

        reward, next_state = step!(env, act)
        push!(transitions, (state, act, reward, next_state))

        state = next_state
        env.done && break
    end

    return transitions
end


# Render `env` while acting with the policy mean (no sampling) for at most TEST_STEPS
# steps, printing every action. Returns the summed reward.
function test_run(env)
    total_reward = 0.0
    state = reset!(env)

    for _ in 1:TEST_STEPS
        OpenAIGym.render(env)

        a = reshape(convert.(Float64, policy_μ(state).data), ACTION_SIZE)
        println("Action : $a")

        reward, state = step!(env, a)
        total_reward += reward

        env.done && break
    end

    return total_reward
end

In [None]:
"""
Rollout collection
"""
# Start `num_processes` additional worker processes; `get_rollouts` performs one rollout
# per entry of `workers()`.
num_processes = 9
addprocs(num_processes) 

# Defined on every process: gather one episode of experience from `env`.
# NOTE(review): this name coincides with `Base.collect` — confirm the shadowing is intended.
@everywhere collect(env) = run_episode(env)

# Defined on every process: build a fresh Pendulum-v0 environment and collect one episode.
@everywhere function rollout()
  return collect(GymEnv(:Pendulum,:v0))
end

# Collect one rollout per entry of `workers()` and return them as a list.
# `rollout()` is called directly here (no `remotecall`/`@spawn`), so the episodes are
# produced one after another by the calling process; `fetch` is applied to each result.
function get_rollouts()
    episodes = Any[rollout() for _ in workers()]
    return fetch.(episodes)
end

In [None]:
# Generalized Advantage Estimate for one episode.
# Walks the episode backwards, accumulating the TD errors
#   δ_t = r_t + γ·V(s_{t+1}) - V(s_t)   as   A_t = δ_t + γ·λ·A_{t+1}.
# Returns one advantage per step, in chronological order. `actions` is not used.
function gae(states,actions,rewards,next_states)
    n = length(states)
    advantages = Vector{Any}(undef, n)
    running = 0.0

    for t in n:-1:1
        v_next = value(next_states[t]).data[1]
        v_now = value(states[t]).data[1]
        td_error = rewards[t] + γ*v_next - v_now
        running = td_error + (γ*λ*running)
        advantages[t] = running
    end

    return advantages
end

# Discounted return for every step of an episode: G_t = r_t + γ·G_{t+1}, with the return
# after the last step taken as 0. Output is in chronological order.
function disconunted_returns(rewards)
    n = length(rewards)
    returns = Vector{Any}(undef, n)
    acc = 0.0

    for t in n:-1:1
        acc = rewards[t] + γ*acc
        returns[t] = acc
    end

    return returns
end

# Log-probabilities of the actions actually taken, under the current policy.
# `states` and `actions` are per-step lists of one episode; the result holds one entry per
# step. Plain (untracked) data is used, so no gradient flows through these values.
function log_prob_from_actions(states,actions)
    log_std = policy_Σ.data
    log_probs = Vector{Any}(undef, length(states))

    for (t, s) in enumerate(states)
        mean_action = reshape(policy_μ(s),ACTION_SIZE).data
        log_probs[t] = normal_log_prob(mean_action,log_std,actions[t])
    end

    return log_probs
end

"""
    process_rollouts(rollouts)

Turn the episodes returned by `get_rollouts` into batched arrays for minibatch training.

Each element of `rollouts` is a list of `(state, action, reward, next_state)` tuples. For
every episode the GAE advantages (normalised per episode), the discounted returns and the
log-probabilities of the taken actions under the current policy are computed.

Returns the tuple `(states, actions, rewards, advantages, returns, log_probs)`; each entry
is built with `hcat` from the per-step values, i.e. one column per collected step.

Side effects: appends the mean episode return of this batch to the global `reward_hist`
and prints either that value or the running mean over the last 100 batches.
"""
function process_rollouts(rollouts)
    states = []
    actions = []
    rewards = []
    advantages = []
    returns = []
    log_probs = []

    # Logging statistics
    episode_mean_returns = []

    for ro in rollouts
        episode_states = [Array(tr[1]) for tr in ro]
        episode_actions = [tr[2] for tr in ro]
        episode_rewards = [tr[3] for tr in ro]
        episode_next_states = [tr[4] for tr in ro]

        episode_advantages = gae(episode_states,episode_actions,episode_rewards,episode_next_states)
        episode_returns = disconunted_returns(episode_rewards)

        push!(episode_mean_returns,mean(episode_returns))

        push!(states,episode_states)
        push!(actions,episode_actions)
        push!(rewards,episode_rewards)
        push!(advantages,normalise(episode_advantages))
        push!(returns,episode_returns)
        push!(log_probs,log_prob_from_actions(episode_states,episode_actions))
    end

    # Flatten the per-episode lists into one list of per-step entries.
    states = cat(states...,dims=1)
    actions = cat(actions...,dims=1)
    rewards = cat(rewards...,dims=1)
    advantages = cat(advantages...,dims=1)
    returns = cat(returns...,dims=1)
    log_probs = cat(log_probs...,dims=1)

    push!(reward_hist,mean(episode_mean_returns))

    window = 100
    if length(reward_hist) < window
        println("RETURNS : $(mean(episode_mean_returns))")
    else
        # `end-window+1:end` is exactly `window` entries (`end-100:end` would be 101).
        println("LAST 100 RETURNS : $(mean(reward_hist[end-window+1:end]))")
    end

    return hcat(states...),hcat(actions...),hcat(rewards...),hcat(advantages...),hcat(returns...),hcat(log_probs...)
end

In [None]:
# Print the three loss components, framed by separator lines.
function print_losses(pl,vl,el)
    foreach(println, (
        "------",
        "Policy Loss : $pl",
        "Value Loss : $vl",
        "Entropy Loss : $el",
        "------",
    ))
end

"""
    loss(states, actions, advantages, returns, old_log_probs)

PPO objective for one minibatch (one column per time step).

* Policy term: the clipped surrogate `mean(min(r·A, clamp(r, 1-ϵ, 1+ϵ)·A))`, where
  `r = exp(new_log_prob - old_log_prob)` and `A` are the advantages.
* Value term: mean squared error between `value(states)` and `returns`.

Returns `-c₀·policy_term + c₁·value_term` (to be minimised). The entropy bonus weighted by
`c₂` is currently disabled. Every call increments the global `global_step`.
"""
function loss(states,actions,advantages,returns,old_log_probs)
    global global_step
    global_step += 1

    μ = policy_μ(states)
    logΣ = policy_Σ

    new_log_probs = normal_log_prob(μ,logΣ,actions)

    # Surrogate loss computation. Both branches must be scaled by the advantages; without
    # that, `min` would compare an advantage-weighted ratio with a bare ratio.
    ratio = exp.(new_log_probs .- old_log_probs)
    surr1 = ratio .* advantages
    surr2 = clamp.(ratio,1.0 - ϵ,1.0 + ϵ) .* advantages
    policy_loss = mean(min.(surr1,surr2))

    value_predicted = value(states)
    value_loss = mean((value_predicted .- returns).^2)

    return -c₀*policy_loss + c₁*value_loss
end

"""
    ppo_update(states, actions, advantages, returns, old_log_probs)

Take one optimiser step on the PPO `loss` for the given minibatch, updating the parameters
of `policy_μ`, `policy_Σ` and `value` with the shared optimiser `opt`.
"""
function ppo_update(states,actions,advantages,returns,old_log_probs)
    # All trainable parameters of the policy mean, the policy log-std and the value network.
    model_params = params(params(policy_μ)...,params(policy_Σ)...,params(value)...)

    # Gradients of the minibatch loss with respect to those parameters.
    gs = Tracker.gradient(() -> loss(states,actions,advantages,returns,old_log_probs),model_params)

    # Take a step of optimisation
    update!(opt,model_params,gs)
end

# One training iteration: collect fresh rollouts, batch them, then run PPO_EPOCHS passes
# over the batch in consecutive (unshuffled) minibatches of BATCH_SIZE columns.
function train_step()
    states,actions,rewards,advantages,returns,log_probs = process_rollouts(get_rollouts())

    n_samples = last(size(states))
    minibatches = partition(1:n_samples,BATCH_SIZE)

    for _ in 1:PPO_EPOCHS, cols in minibatches
        ppo_update(states[:,cols],
                   actions[:,cols],
                   advantages[:,cols],
                   returns[:,cols],
                   log_probs[:,cols])
    end
end

"""
    train()

Run `NUM_EPISODES` training iterations (`train_step`). Every `SAVE_FREQUENCY` iterations
the policy mean network, the policy log-std parameter and the value network are saved
under `weights/`, and the reward history is written to `stats.jld`.
"""
function train()
    for i in 1:NUM_EPISODES
        println("EP : $i")
        train_step()
        println("Ep done")

        if i%SAVE_FREQUENCY == 0
            # Make sure the checkpoint directory exists so that saving cannot fail (and
            # abort training) just because the folder was never created.
            mkpath("weights")
            @save "weights/policy_mu.bson" policy_μ
            @save "weights/policy_sigma.bson" policy_Σ
            @save "weights/value.bson" value

            save("stats.jld","rewards",reward_hist)
            println("\n\n\n----MAX REWRD SO FAR : $(maximum(reward_hist))---\n\n\n")
        end
    end
end

In [None]:
# Run the full training loop (NUM_EPISODES iterations, with periodic checkpoints).
train()

In [None]:
# Test the policy #
# Evaluate the policy mean in a fresh environment with rendering.
env = GymEnv("Pendulum-v0")
# NOTE(review): presumably lifts the step limit of the wrapped Python environment so the
# episode can run for TEST_STEPS steps — confirm against the gym version in use.
env.pyenv._max_episode_steps = 50000
# Overrides the TEST_STEPS value set with the other hyperparameters.
TEST_STEPS = 50000

# Total reward accumulated over the test episode.
r = test_run(env)

In [None]:
# NOTE(review): scratch cell — neither `policy_base` nor `ro` is defined in the visible
# source; presumably left over from the shared-network experiment. Confirm before running.
policy_base(Array(ro[1][1][1]))

In [None]:
# Scratch objective shaped like the clipped surrogate, used to exercise gradients on the
# GPU: exp(a - 1) is scaled by random values and compared with `a` clamped to [0.9, 1.1].
function loss(a)
    ratio = exp.(a .- gpu(ones(size(a))))
    scaled = ratio .* gpu(rand(size(a)))
    clipped = clamp.(a,0.9,1.1)
    return -1.0 * mean(min.(scaled,clipped))
end

In [None]:
# Gradient of the scratch loss with respect to the parameters of `m`.
# NOTE(review): `out` is not defined in the visible source — confirm where it comes from.
gs = Tracker.gradient(() -> loss(out),params(m))

In [None]:
# Inspect the gradient stored for the weights of the first layer of `m`.
gs[m.layers[1].W]

In [None]:
using Flux,CuArrays
using Flux:Tracker
using Statistics

# Benchmark fixture: a single 3x3 convolution (3 -> 64 channels) and a random
# 256x256x3x1 input, both moved with `gpu`.
m = Chain(Conv((3,3),3=>64)) |> gpu
x = rand(256,256,3,1) |> gpu

In [None]:
# Scalar benchmark objective: mean of the output of the global model `m` on `x`.
loss(x) = mean(m(x))

In [None]:
# Time one forward pass and one gradient computation of `loss` on the global `x`;
# returns the gradients.
function test()
    @time forward_value = loss(x)
    grads = @time Tracker.gradient(() -> loss(x),params(m))
    return grads
end

test()

In [None]:
using Flux,CuArrays
using Flux:Tracker
using Flux:@treelike
using Statistics

In [None]:
# Thin wrapper around a model stored in the field `u`.
struct Net
    u
end

# Register `Net` with Flux's `@treelike`.
@treelike Net

# Default network: a single 3x3 convolution mapping 3 channels to 64.
Net() = Net(Chain(Conv((3,3),3=>64)))

# Calling a `Net` forwards the input to the wrapped model.
(n::Net)(x) = n.u(x)

In [None]:
# Same benchmark fixture as before, but with the model wrapped in the `Net` struct.
m = Net() |> gpu
x = rand(256,256,3,1) |> gpu

# Scalar benchmark objective: mean of the output of the global model `m` on `x`.
loss(x) = mean(m(x))

# Time one forward pass and one gradient computation of `loss` on the global `x`;
# returns the gradients.
function test()
    @time forward_value = loss(x)
    grads = @time Tracker.gradient(() -> loss(x),params(m))
    return grads
end

test()

In [None]:
# Scratch: multiply a tracked parameter vector by a scalar to see what comes back.
a = param(ones(2)) * 1.0

In [None]:
# Display the current policy log-std parameter.
policy_Σ

In [None]:
using Pkg
# Install the Plots package into the active environment.
Pkg.add("Plots")

In [None]:
using JLD

In [None]:
# Display the per-iteration mean returns recorded so far.
reward_hist

In [None]:
# Mean of the window that ends 100 entries before the end of the history.
# NOTE(review): `end-200:end-100` spans 101 entries, not 100 — confirm that is intended.
mean(reward_hist[end-100-100:end-100])

In [None]:
using JLD
# Persist the reward history to `stats.jld` under the key "rewards".
save("stats.jld","rewards",reward_hist)

In [None]:
# Load everything stored in `stats.jld`.
rh = load("stats.jld")

In [None]:
using Pkg
# The Pkg API takes the path literally, so `~` has to be expanded explicitly; otherwise a
# directory literally named "~" would be used relative to the working directory.
Pkg.activate(expanduser("~/envs/test"))

In [None]:
using JLD
using Plots

In [None]:
# Display the statistics loaded from `stats.jld`.
rh

In [None]:
# Display the current policy log-std parameter.
policy_Σ

In [None]:
# Scratch: only the taken branch assigns its variable (`a` here; `b` stays undefined).
i = 1
if i == 1
    a = 3
else
    b = 5
end

In [None]:
using OpenAIGym
# Demo: run 20 rendered episodes with a random policy, counting the steps of each episode
# and reporting the value returned by `run_episode`.
env = GymEnv(:Pendulum, :v0)
for i ∈ 1:20
  T = 0
  # The do-block is invoked with each (state, action, reward, next state) tuple.
  R = run_episode(env, RandomPolicy()) do (s, a, r, s′)
    render(env)
    T += 1
  end
  @info("Episode $i finished after $T steps. Total reward: $R")
end

In [3]:
"""
Test
"""

using Flux, CuArrays
using OpenAIGym
import Reinforce.action
import Reinforce:run_episode
import Flux.params
using Flux.Tracker: grad, update!
using Flux: onehot
using Statistics
using Distributed
using Distributions
using LinearAlgebra
using Base.Iterators
using BSON:@save,@load
using JLD

# Constants needed by `test_run` in this standalone test cell.
ACTION_SIZE = 1
TEST_STEPS = 10000

# Load the policy
# Restores the saved policy mean network and log-std parameter into `policy_μ` / `policy_Σ`.
@load "./weights/pendulum-working/policy_mu.bson" policy_μ
@load "./weights/pendulum-working/policy_sigma.bson" policy_Σ

# Test Run Function
# Render `env` while acting with the policy mean (no sampling) for at most TEST_STEPS
# steps. Returns the summed reward.
function test_run(env)
    total_reward = 0.0
    state = reset!(env)

    for _ in 1:TEST_STEPS
        OpenAIGym.render(env)

        a = reshape(convert.(Float64, policy_μ(state).data), ACTION_SIZE)

        reward, state = step!(env, a)
        total_reward += reward

        env.done && break
    end

    return total_reward
end

# Evaluate the loaded policy in a fresh environment and report the total reward.
env = GymEnv("Pendulum-v0")
# NOTE(review): presumably raises the step limit of the wrapped Python environment to
# TEST_STEPS — confirm against the gym version in use.
env.pyenv._max_episode_steps = TEST_STEPS

r = test_run(env)
println("---Total Steps : $TEST_STEPS ::: Total Reward : $r---")

---Total Steps : 10000 ::: Total Reward : -344.9227815301468---
