In [1]:
#----------------__CHECKING__------------------- #
using Flux, CuArrays
using OpenAIGym
import Reinforce.action
import Reinforce:run_episode
import Flux.params
using Flux.Tracker: grad, update!
using Flux: onehot
using Statistics
using Distributed
using Distributions
using LinearAlgebra
using Base.Iterators
using BSON:@save,@load
using JLD

include("policy.jl")

│ Try running `] pin CuArrays@0.9`.
└ @ Flux.CUDA /home/shreyas/.julia/packages/Flux/WSB7k/src/cuda/cuda.jl:12


value_fn (generic function with 2 methods)

In [2]:
"""
HYPERPARAMETERS
"""
# Environment Creation #
env_name = "CartPole-v0"
MODE = "CAT" # Can be either "CON" (Continuous) or "CON" (Categorical)

# Environment Variables #
STATE_SIZE = 4
ACTION_SIZE = 2
MIN_RANGE = -2.0f0
MAX_RANGE = 2.0f0
EPISODE_LENGTH = 100
TEST_STEPS = 10000
# Policy parameters #
η = 3e-4 # Learning rate
STD = 0.0 # Standard deviation
HIDDEN_SIZE = 256
# GAE parameters
γ = 0.99
λ = 0.95
# Optimization parameters
PPO_EPOCHS = 10
NUM_EPISODES = 15000
BATCH_SIZE = 5
c₀ = 1.0
c₁ = 0.5
c₂ = 0.001
# PPO parameters
ϵ = 0.2
# FREQUENCIES
SAVE_FREQUENCY = 50
VERBOSE_FREQUENCY = 5
global_step = 0

# Global variable to monitor losses
reward_hist = []
policy_l = 0.0
entropy_l = 0.0
value_l = 0.0

0.0

In [3]:
"""
    scale_rewards(rewards)

Hook for rescaling an episode's rewards before advantages and returns are computed from them.
Currently the identity: `rewards` itself is returned unchanged.
"""
scale_rewards(rewards) = rewards

scale_rewards (generic function with 1 method)

In [4]:
"""
Define the networks
"""

if MODE == "CON"
	policy_μ,policy_Σ = gaussian_policy(STATE_SIZE,HIDDEN_SIZE,ACTION_SIZE)
	value = value_fn(STATE_SIZE,HIDDEN_SIZE,ACTION_SIZE,tanh)
elseif MODE == "CAT"
	policy = categorical_policy(STATE_SIZE,HIDDEN_SIZE,ACTION_SIZE)
	value = value_fn(STATE_SIZE,HIDDEN_SIZE,ACTION_SIZE,relu)
else 
	error("MODE can only be (CON) or (CAT)...")
end

opt = ADAM(η)

ADAM(0.0003, (0.9, 0.999), IdDict{Any,Any}())

In [5]:
"""
Functions to get rollouts
"""

function action(state)
    # Acccounting for the element type
    state = reshape(Array(state),length(state),1) 

    a = nothing
    if MODE == "CON"
	    # Our policy outputs the parameters of a Normal distribution
	    μ = policy_μ(state)
	    μ = reshape(μ,ACTION_SIZE)
	    log_std = policy_Σ
	    
	    σ² = (exp.(log_std)).^2
	    Σ = diagm(0=>σ².data)
	    
	    dis = MvNormal(μ.data,Σ)
	    
	    a = rand(dis,ACTION_SIZE)
	else
		action_probs = policy(state).data
        action_probs = reshape(action_probs,ACTION_SIZE)
    	a = sample(1:ACTION_SIZE,Weights(action_probs)) - 1
    end
    a
end

"""
    run_episode(env)

Play one episode of at most `EPISODE_LENGTH` steps in `env`, choosing actions with `action`.
Returns the visited transitions as a vector of `(state, action, reward, next_state)` tuples.
Stops early as soon as `env.done` is set after a step.
"""
function run_episode(env)
    transitions = []
    obs = reset!(env)
    for _ in 1:EPISODE_LENGTH
        act = action(obs)
        # Continuous actions are handed to the environment as a flat vector.
        if MODE == "CON"
            act = reshape(act, ACTION_SIZE)
        end
        reward, next_obs = step!(env, act)
        push!(transitions, (obs, act, reward, next_obs))
        obs = next_obs
        env.done && break
    end
    return transitions
end

run_episode (generic function with 2 methods)

In [6]:
"""
Multi-threaded parallel rollout collection
"""

num_processes = 9
addprocs(num_processes) 

@everywhere function collect(env)
    run_episode(env)
end

@everywhere function rollout()
  env = GymEnv(env_name)
  return collect(env)
end

function get_rollouts()
    g = []
    for  w in workers()
      push!(g, rollout())
    end

    fetch.(g)
end

get_rollouts (generic function with 1 method)

In [135]:
"""
Generalized Adavantage Estimation
"""

function gae(states,actions,rewards,next_states)
    """
    Returns a Generalized Advantage Estimate for an episode
    """
    Â = []
    A = 0.0
    for i in reverse(1:length(states))
        δ = rewards[i] + γ*cpu.(value(next_states[i]).data[1]) - cpu.(value(states[i]).data[1])
        A = δ + (γ*λ*A)
        push!(Â,A)
    end
    
    Â = reverse(Â)
    return Â
end

"""
    disconunted_returns(rewards)

Return the discounted return-to-go `Gₜ = rₜ + γ Gₜ₊₁` (with `G = 0` after the last reward)
for every step, in step order. The misspelled name is kept because callers use it.
"""
function disconunted_returns(rewards)
    G = Vector{Any}(undef, length(rewards))
    acc = 0.0
    for t in length(rewards):-1:1
        acc = rewards[t] + γ*acc
        G[t] = acc
    end
    return G
end

"""
Calculate Log Probabilities
"""
function log_prob_from_actions(states,actions)
    """
    Returns log probabilities of the actions taken
    
    states,actions : episode vairbles in the form of a list
    """
    log_probs = []
    
    for i in 1:length(states)
    	if MODE == "CON"
	        μ = reshape(policy_μ(states[i]),ACTION_SIZE).data
	        logΣ = policy_Σ.data |> cpu
        	push!(log_probs,normal_log_prob(μ,logΣ,actions[i]))
        else
        	action_probs = policy(states[i])
        	prob = action_probs[actions[i],:].data
        	push!(log_probs,log.(prob))
        end
    end
    
    log_probs
end


log_prob_from_actions

In [136]:
"""
Process and extraction information from rollouts
"""

function process_rollouts(rollouts)
    """
    rollouts : variable returned by calling `get_rollouts`
    
    Returns : 
    states, actions, rewards for minibatch processing
    """
    # Process the variables
    states = []
    actions = []
    rewards = []
    next_states = []
    advantages = []
    returns = []
    log_probs = []
    
    # Logging statistics
    episode_mean_returns = []
    
    for ro in rollouts
        episode_states = []
        episode_actions = []
        episode_rewards = []
        episode_next_states = []
        
        for i in 1:length(ro)
             push!(episode_states,Array(ro[i][1]))
             
             if MODE == "CON"
                 push!(episode_actions,ro[i][2])
             else
                 push!(episode_actions,ro[i][2] + 1)
             end
             
             push!(episode_rewards,ro[i][3])
             push!(episode_next_states,ro[i][4])
        end
        
        episode_rewards = scale_rewards(episode_rewards)
        episode_advantages = gae(episode_states,episode_actions,episode_rewards,episode_next_states)
        # episode_rewards = normalise(episode_rewards)
        
        episode_returns = disconunted_returns(episode_rewards)
        
        push!(episode_mean_returns,mean(episode_returns))
        
        push!(states,episode_states)
        push!(actions,episode_actions)
        push!(rewards,episode_rewards)
        push!(advantages,episode_advantages)
        push!(returns,episode_returns)
        push!(log_probs,log_prob_from_actions(episode_states,episode_actions))
    end
    
    states = cat(states...,dims=1)
    actions = cat(actions...,dims=1)
    rewards = cat(rewards...,dims=1)
    advantages = cat(advantages...,dims=1)
    returns = cat(returns...,dims=1)
    log_probs = cat(log_probs...,dims=1)
    
    push!(reward_hist,mean(episode_mean_returns))
    
    if length(reward_hist) <= 100
        println("RETURNS : $(mean(episode_mean_returns))")
    else
        println("MEAN RETURNS : $(mean(reward_hist))")
        println("LAST 100 RETURNS : $(mean(reward_hist[end-100:end]))")
    end
    
    return hcat(states...),hcat(actions...),hcat(rewards...),hcat(advantages...),hcat(returns...),hcat(log_probs...)
end

"""
Loss function definition
"""
function loss(states,actions,advantages,returns,old_log_probs)
    global global_step,policy_l,entropy_l,value_l
    global_step += 1
    
    if MODE == "CON"
	    μ = policy_μ(states)
	    logΣ = policy_Σ 
        
	    new_log_probs = normal_log_prob(μ,logΣ,actions)
	else
		action_probs = policy(states) # ACTION_SIZE x BATCH_SIZE
		actions_one_hot = zeros(ACTION_SIZE,size(action_probs)[end])
        
		for i in 1:size(action_probs)[end]
			actions_one_hot[actions[:,i][1],i] = 1.0				
		end
        
		new_log_probs = log.(sum((action_probs .+ 1f-5) .* actions_one_hot,dims=1))
    end
    
    # Surrogate loss computations
    ratio = exp.(new_log_probs .- old_log_probs)
    surr1 = ratio .* advantages
    surr2 = clamp.(ratio,(1.0 - ϵ),(1.0 + ϵ)) .* advantages
    policy_loss = mean(min.(surr1,surr2))
    
    value_predicted = value(states)
    value_loss = mean((value_predicted .- returns).^2)
    
    if MODE == "CON"
        entropy_loss = mean(normal_entropy(logΣ))
    else
        entropy_loss = mean(categorical_entropy(action_probs))
    end
    
    policy_l = policy_loss.data
    entropy_l = entropy_loss.data
    value_l = value_loss.data
    
    -c₀*policy_loss + c₁*value_loss - c₂*entropy_loss
end

"""
Optimization Function
"""
function ppo_update(states,actions,advantages,returns,old_log_probs)
    # Define model parameters
    if MODE == "CON"
        model_params = params(params(policy_μ)...,params(policy_Σ)...,params(value)...)
    else
        model_params = params(params(policy)...,params(value)...)
    end

    # Calculate gradients
    gs = Tracker.gradient(() -> loss(states,actions,advantages,returns,old_log_probs),model_params)

    # Take a step of optimisation
    update!(opt,model_params,gs)
end

ppo_update

In [137]:
"""
Train
"""

function train_step()    
    routs = get_rollouts()
    states,actions,rewards,advantages,returns,log_probs = process_rollouts(routs)
    
    idxs = partition(1:size(states)[end],BATCH_SIZE)
    
    for epoch in 1:PPO_EPOCHS
        for i in idxs
            mb_states = states[:,i] 
            mb_actions = actions[:,i] 
            mb_advantages = advantages[:,i] 
            mb_returns = returns[:,i] 
            mb_log_probs = log_probs[:,i]
            
            ppo_update(mb_states,mb_actions,mb_advantages,mb_returns,mb_log_probs)
        end
    end
end

"""
    train(; num_episodes=NUM_EPISODES, anneal_every=300, anneal_factor=3.0, min_eta=1e-6)

Run `num_episodes` iterations of `train_step`.

- Every `anneal_every` iterations, the learning rate `opt.eta` is divided by `anneal_factor`
  while it is still above `min_eta`.
- Every `VERBOSE_FREQUENCY` iterations, the latest loss terms (globals written by `loss`)
  are printed.
- Every `SAVE_FREQUENCY` iterations, the networks are saved to `weights/*.bson` (creating
  the directory if needed) and `reward_hist` is saved to `stats.jld`.

The keyword defaults reproduce the original schedule, so `train()` behaves as before.
"""
function train(; num_episodes = NUM_EPISODES, anneal_every = 300, anneal_factor = 3.0,
               min_eta = 1e-6)
    anneal_every > 0 || throw(ArgumentError("anneal_every must be positive, got $anneal_every"))

    for i in 1:num_episodes
        println("EP : $i")
        train_step()
        println("Ep done")

        # Anneal the learning rate until it reaches the floor.
        if i % anneal_every == 0 && opt.eta > min_eta
            opt.eta = opt.eta / anneal_factor
        end

        if i % VERBOSE_FREQUENCY == 0
            println("-----___Stats___-----")
            if MODE == "CON"
                println("Entropy : $(normal_entropy(policy_Σ))")
            end
            println("Policy Loss : $(policy_l)")
            println("Entropy Loss : $(entropy_l)")
            println("Value Loss : $(value_l)")
        end

        if i % SAVE_FREQUENCY == 0
            # The checkpoints below are written into `weights/`; make sure it exists
            # so the first save does not fail after SAVE_FREQUENCY iterations.
            mkpath("weights")
            if MODE == "CON"
                @save "weights/policy_mu.bson" policy_μ
                @save "weights/policy_sigma.bson" policy_Σ
                @save "weights/value.bson" value
            else
                @save "weights/policy_cat.bson" policy
                @save "weights/value.bson" value
            end

            save("stats.jld", "rewards", reward_hist)
            println("\n\n\n----MAX REWRD SO FAR : $(maximum(reward_hist))---\n\n\n")
        end
    end
end

train (generic function with 1 method)

In [138]:
# Run the full training loop (NUM_EPISODES iterations of train_step).
train()

EP : 1
MEAN RETURNS : 5.984589584907436
LAST 100 RETURNS : 5.025881888292809
Ep done
EP : 2
MEAN RETURNS : 5.980573360504491
LAST 100 RETURNS : 5.025885332573053
Ep done
EP : 3
MEAN RETURNS : 5.976307051952417
LAST 100 RETURNS : 5.028993236547197
Ep done
EP : 4
MEAN RETURNS : 5.9711984559149025
LAST 100 RETURNS : 5.030545466394145
Ep done
EP : 5
MEAN RETURNS : 5.965854959408277
LAST 100 RETURNS : 5.029000099367136
Ep done
-----___Stats___-----
Policy Loss : 0.3179458056853174
Entropy Loss : -5.3645862e-11
Value Loss : 0.053309712047081764
EP : 6


InterruptException: InterruptException:

In [None]:
# Debug: policy output for an all-ones 4×1 state input.
policy(ones(4,1))

In [71]:
# Debug: a standalone environment instance (training creates its own inside rollout()).
env = GymEnv(env_name)

GymEnv CartPole-v0
  TimeLimit
  r  = 0.0
  ∑r = 0.0

│   caller = show(::IOContext{Base.GenericIOBuffer{Array{UInt8,1}}}, ::GymEnv{PyCall.PyArray{Float64,1}}) at OpenAIGym.jl:64
└ @ OpenAIGym /home/shreyas/.julia/packages/OpenAIGym/wZkkM/src/OpenAIGym.jl:64
│   caller = show(::IOContext{Base.GenericIOBuffer{Array{UInt8,1}}}, ::GymEnv{PyCall.PyArray{Float64,1}}) at OpenAIGym.jl:65
└ @ OpenAIGym /home/shreyas/.julia/packages/OpenAIGym/wZkkM/src/OpenAIGym.jl:65


In [75]:
# Debug: reset the standalone environment; `s` holds the initial observation.
s = reset!(env)

4-element PyCall.PyArray{Float64,1}:
 -0.04849816918922241 
  0.013628561471056858
  0.006237321324826253
  0.04166733079767969 

In [81]:
# Debug: a one-argument `loss` method (added alongside the 5-argument PPO `loss`,
# not replacing it) equal to the mean of the policy output, presumably used to
# check that gradients reach the policy parameters.
loss(x) = mean(policy(x))

gs = Tracker.gradient(params(policy)) do
    loss(s)
end

Grads(...)


In [82]:
# Debug: inspect the weight matrix of the policy network's first layer.
policy.layers[1].W

Tracked 256×4 Array{Float32,2}:
  0.0463625    -0.0288894    0.0691652    -0.0266886 
  0.124687      0.0818793   -0.129135      0.204254  
 -0.106862      0.193595    -0.0882367    -0.0241373 
 -0.0685683     0.0427704    0.0284292     0.0874056 
 -0.176015      0.104862     0.0349711    -0.124782  
  0.0458958    -0.0517615    0.0304131     0.00226418
  0.0698005     0.0420365   -0.162521     -0.0820992 
 -0.0710014    -0.0629289    0.0741035     0.173773  
  0.036564      0.0159403    0.0240178    -0.0337437 
 -0.0883487    -0.074553     0.121145     -0.183758  
 -0.0945337    -0.00958645  -0.00795938   -0.0662095 
  0.0517599    -0.0620799   -0.000125911  -0.0471872 
 -0.0486801     0.006256     0.178224      0.092455  
  ⋮                                                  
  0.0060984     0.0074       0.149015      0.0837609 
 -0.12016      -0.0510779    0.182024      0.0255303 
 -0.142491     -0.138318    -0.078126     -0.124178  
  0.179972      0.124144    -0.0903058    -0.07220