In [2]:
using OpenAIGym
using Flux
using Reinforce
using StatsBase
import Reinforce.action
import Flux.params
using Flux: onehotbatch
import Statistics: mean
import Distributions: Multinomial

In [3]:
# Gym environment the agent is trained on
env = GymEnv(:Acrobot, :v1)

# sizes read from the environment
n_actions = length(env.actions)
n_states = length(env.state)

# action label arrays
actions_arr = collect(0:(n_actions - 1)) # ARR [0, .., A-1]
actions_arr_ = actions_arr .+ 1          # ARR [1, .., A]

η = 0.01f0 # learning rate

# network: state vector in, one output per action out, 32 tanh units in between
model = Chain(
    Dense(n_states, 32, tanh),
    Dense(32, n_actions),
)

"""
    loss(weights, actions, observations)

Weighted negative log-likelihood of the taken actions under the global `model`.

# Arguments
- `weights`: batch * 1 array of per-step weights.
- `actions`: length-batch array of action labels drawn from `actions_arr_`.
- `observations`: batch * states matrix of observations.

# Returns
The negated mean over the batch of `log p(action | observation) * weight`.
"""
function loss(weights, actions, observations)
    action_masks = onehotbatch(actions, actions_arr_)     # actions * batch
    log_probs = logsoftmax(model(observations'))          # actions * batch
    # The one-hot mask zeroes every row except the action actually taken, so
    # the column sum picks out that action's log-probability.
    taken_log_probs = sum(log_probs .* action_masks, dims=1) # 1 * batch
    return -mean(taken_log_probs .* weights')
end

# ADAM optimiser over the model's parameters, using learning rate η
opt = ADAM(params(model), η)

#43 (generic function with 1 method)

In [4]:
"""
    get_action(state)

Sample an action for `state` from the softmax of the global `model`'s outputs.

# Returns
The sampled action as a 0-based integer index (sampled position minus one).
"""
function get_action(state)
    # `.data` unwraps the tracked output so sampling works on plain numbers.
    probs = softmax(model(state).data)
    # A single multinomial trial yields a one-hot count vector.
    draw = rand(Multinomial(1, probs))
    # Base.argmax replaces Flux.argmax here; the result is shifted to 0-based.
    return Base.argmax(draw) - 1
end

"""
    get_matrix(arr, batch_size, dim)

Pack the items of `arr` into the rows of a `batch_size` * `dim` `Float32` matrix.

Item `i` of `arr` fills row `i`; an item may be a scalar (written to every column
of its row, which for `dim == 1` is the single entry) or a collection of `dim`
values. Rows beyond the number of items in `arr` are zero rather than
uninitialised memory.

# Returns
A `Matrix{Float32}` of size `(batch_size, dim)`.
"""
function get_matrix(arr, batch_size, dim)
    # Zero-initialise so that a short `arr` never exposes uninitialised memory.
    mat = zeros(Float32, batch_size, dim)
    for (i, item) in enumerate(arr)
        # Broadcast assignment handles both scalar and vector items.
        mat[i, :] .= item
    end
    return mat
end

"""
    train_one_epoch(env, batch_size=5000, render=false)

Collect one batch of experience with the current policy and take a single
policy-gradient training step on it.

Episodes are played until an episode ends with more than `batch_size` steps
collected in total. Every step of an episode is weighted by that episode's
total return. The episode that closes the batch is included in the returned
statistics but is not printed.

# Arguments
- `env`: environment driven through `reset!`, `step!` and `env.done`.
- `batch_size`: number of collected steps that must be exceeded before training.
- `render`: currently unused.

# Returns
`(batch_lens, batch_rets)`: the length and the return of each collected episode.
"""
function train_one_epoch(env, batch_size=5000, render=false)
    batch_obs = Vector{Float32}[]   # for observations
    batch_acts = Int[]              # for actions (1-based labels)
    batch_weights = Float64[]       # for R(tau) weighting in policy gradient
    batch_rets = Float64[]          # for measuring episode returns
    batch_lens = Int[]              # for measuring episode lengths

    # first obs comes from starting distribution
    o = reset!(env)
    ep_len, ep_ret = 0, 0.0

    while true
        # Copy the observation into a concretely typed vector so the same
        # Float32 values are used for acting and, later, for training.
        obs = Float32[o[i] for i in 1:length(o)]
        a = get_action(obs)
        push!(batch_obs, obs)
        push!(batch_acts, a + 1)    # get_action is 0-based; labels are 1-based

        r, o = step!(env, a)
        ep_len += 1
        ep_ret += r
        if env.done
            push!(batch_rets, ep_ret)
            push!(batch_lens, ep_len)
            # every step of the episode shares the episode return as weight
            append!(batch_weights, fill(ep_ret, ep_len))

            if length(batch_obs) > batch_size
                break
            end
            println("Episode Length: ", ep_len, " Episode Reward: ", ep_ret)
            # reset episode-specific variables
            o = reset!(env)
            ep_len, ep_ret = 0, 0.0
        end
    end

    batch = length(batch_weights)
    mat_obs = get_matrix(batch_obs, batch, n_states)
    mat_weights = get_matrix(batch_weights, batch, 1)
    # a single (weights, actions, observations) batch => one optimiser step
    Flux.train!(loss, [(mat_weights, batch_acts, mat_obs)], opt)
    return batch_lens, batch_rets
end

# Training history filled in by `train`: epoch indices and the mean episode
# return of each epoch. Concrete element types avoid `Vector{Any}` storage.
epoch = Int[]
epoch_performance = Float64[]
# Run `epochs` rounds of `train_one_epoch` on `env`. After each round the mean
# episode return is appended to the global `epoch_performance`, the round index
# to the global `epoch`, and a summary line (index, mean length, mean return)
# is printed.
function train(env, epochs = 10)
    for round_idx in 1:epochs
        lens, rets = train_one_epoch(env)
        avg_ret = mean(rets)
        push!(epoch_performance, avg_ret)
        push!(epoch, round_idx)
        println(round_idx, " ", mean(lens), " ", avg_ret)
    end
end

train (generic function with 2 methods)

In [5]:
# Train for 50 epochs; per-episode and per-epoch statistics are printed as it runs.
train(env, 50)

│   caller = get_action(::Array{Any,1}) at In[4]:5
└ @ Main ./In[4]:5


Episode Length: 500 Episode Reward: -500.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 500 Episode Reward: -500.0
1 500.0 -500.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 500 Episode Reward: -500.0
2 500.0 -500.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 500 Epis

Episode Length: 183 Episode Reward: -182.0
Episode Length: 293 Episode Reward: -292.0
Episode Length: 158 Episode Reward: -157.0
Episode Length: 207 Episode Reward: -206.0
Episode Length: 229 Episode Reward: -228.0
Episode Length: 234 Episode Reward: -233.0
Episode Length: 227 Episode Reward: -226.0
Episode Length: 196 Episode Reward: -195.0
Episode Length: 193 Episode Reward: -192.0
Episode Length: 273 Episode Reward: -272.0
Episode Length: 200 Episode Reward: -199.0
Episode Length: 213 Episode Reward: -212.0
15 215.875 -214.875
Episode Length: 226 Episode Reward: -225.0
Episode Length: 192 Episode Reward: -191.0
Episode Length: 203 Episode Reward: -202.0
Episode Length: 179 Episode Reward: -178.0
Episode Length: 185 Episode Reward: -184.0
Episode Length: 167 Episode Reward: -166.0
Episode Length: 332 Episode Reward: -331.0
Episode Length: 189 Episode Reward: -188.0
Episode Length: 212 Episode Reward: -211.0
Episode Length: 162 Episode Reward: -161.0
Episode Length: 236 Episode Reward

Episode Length: 187 Episode Reward: -186.0
22 164.16129032258064 -163.16129032258064
Episode Length: 164 Episode Reward: -163.0
Episode Length: 173 Episode Reward: -172.0
Episode Length: 213 Episode Reward: -212.0
Episode Length: 154 Episode Reward: -153.0
Episode Length: 207 Episode Reward: -206.0
Episode Length: 148 Episode Reward: -147.0
Episode Length: 149 Episode Reward: -148.0
Episode Length: 175 Episode Reward: -174.0
Episode Length: 234 Episode Reward: -233.0
Episode Length: 252 Episode Reward: -251.0
Episode Length: 181 Episode Reward: -180.0
Episode Length: 219 Episode Reward: -218.0
Episode Length: 217 Episode Reward: -216.0
Episode Length: 227 Episode Reward: -226.0
Episode Length: 149 Episode Reward: -148.0
Episode Length: 180 Episode Reward: -179.0
Episode Length: 227 Episode Reward: -226.0
Episode Length: 184 Episode Reward: -183.0
Episode Length: 167 Episode Reward: -166.0
Episode Length: 214 Episode Reward: -213.0
Episode Length: 213 Episode Reward: -212.0
Episode Leng

Episode Length: 151 Episode Reward: -150.0
Episode Length: 172 Episode Reward: -171.0
Episode Length: 131 Episode Reward: -130.0
Episode Length: 188 Episode Reward: -187.0
Episode Length: 149 Episode Reward: -148.0
Episode Length: 125 Episode Reward: -124.0
Episode Length: 183 Episode Reward: -182.0
Episode Length: 178 Episode Reward: -177.0
Episode Length: 183 Episode Reward: -182.0
Episode Length: 170 Episode Reward: -169.0
Episode Length: 184 Episode Reward: -183.0
29 164.38709677419354 -163.38709677419354
Episode Length: 194 Episode Reward: -193.0
Episode Length: 168 Episode Reward: -167.0
Episode Length: 152 Episode Reward: -151.0
Episode Length: 169 Episode Reward: -168.0
Episode Length: 161 Episode Reward: -160.0
Episode Length: 155 Episode Reward: -154.0
Episode Length: 170 Episode Reward: -169.0
Episode Length: 157 Episode Reward: -156.0
Episode Length: 145 Episode Reward: -144.0
Episode Length: 202 Episode Reward: -201.0
Episode Length: 185 Episode Reward: -184.0
Episode Leng

Episode Length: 151 Episode Reward: -150.0
Episode Length: 164 Episode Reward: -163.0
Episode Length: 169 Episode Reward: -168.0
Episode Length: 167 Episode Reward: -166.0
Episode Length: 208 Episode Reward: -207.0
Episode Length: 129 Episode Reward: -128.0
Episode Length: 143 Episode Reward: -142.0
35 153.63636363636363 -152.63636363636363
Episode Length: 154 Episode Reward: -153.0
Episode Length: 138 Episode Reward: -137.0
Episode Length: 133 Episode Reward: -132.0
Episode Length: 151 Episode Reward: -150.0
Episode Length: 134 Episode Reward: -133.0
Episode Length: 113 Episode Reward: -112.0
Episode Length: 141 Episode Reward: -140.0
Episode Length: 162 Episode Reward: -161.0
Episode Length: 121 Episode Reward: -120.0
Episode Length: 116 Episode Reward: -115.0
Episode Length: 191 Episode Reward: -190.0
Episode Length: 146 Episode Reward: -145.0
Episode Length: 140 Episode Reward: -139.0
Episode Length: 128 Episode Reward: -127.0
Episode Length: 189 Episode Reward: -188.0
Episode Leng

Episode Length: 102 Episode Reward: -101.0
Episode Length: 113 Episode Reward: -112.0
Episode Length: 130 Episode Reward: -129.0
Episode Length: 170 Episode Reward: -169.0
Episode Length: 108 Episode Reward: -107.0
Episode Length: 112 Episode Reward: -111.0
Episode Length: 108 Episode Reward: -107.0
Episode Length: 125 Episode Reward: -124.0
Episode Length: 135 Episode Reward: -134.0
Episode Length: 145 Episode Reward: -144.0
Episode Length: 149 Episode Reward: -148.0
Episode Length: 97 Episode Reward: -96.0
Episode Length: 102 Episode Reward: -101.0
Episode Length: 89 Episode Reward: -88.0
Episode Length: 120 Episode Reward: -119.0
Episode Length: 107 Episode Reward: -106.0
Episode Length: 141 Episode Reward: -140.0
Episode Length: 136 Episode Reward: -135.0
Episode Length: 105 Episode Reward: -104.0
Episode Length: 129 Episode Reward: -128.0
Episode Length: 128 Episode Reward: -127.0
Episode Length: 154 Episode Reward: -153.0
Episode Length: 129 Episode Reward: -128.0
Episode Length:

Episode Length: 125 Episode Reward: -124.0
Episode Length: 97 Episode Reward: -96.0
Episode Length: 96 Episode Reward: -95.0
Episode Length: 126 Episode Reward: -125.0
Episode Length: 110 Episode Reward: -109.0
Episode Length: 141 Episode Reward: -140.0
Episode Length: 108 Episode Reward: -107.0
Episode Length: 99 Episode Reward: -98.0
Episode Length: 96 Episode Reward: -95.0
Episode Length: 116 Episode Reward: -115.0
Episode Length: 145 Episode Reward: -144.0
Episode Length: 101 Episode Reward: -100.0
Episode Length: 165 Episode Reward: -164.0
Episode Length: 270 Episode Reward: -269.0
Episode Length: 118 Episode Reward: -117.0
Episode Length: 171 Episode Reward: -170.0
Episode Length: 126 Episode Reward: -125.0
Episode Length: 103 Episode Reward: -102.0
Episode Length: 123 Episode Reward: -122.0
Episode Length: 117 Episode Reward: -116.0
Episode Length: 110 Episode Reward: -109.0
Episode Length: 104 Episode Reward: -103.0
Episode Length: 219 Episode Reward: -218.0
Episode Length: 114

Episode Length: 295 Episode Reward: -294.0
Episode Length: 104 Episode Reward: -103.0
Episode Length: 90 Episode Reward: -89.0
Episode Length: 93 Episode Reward: -92.0
Episode Length: 109 Episode Reward: -108.0
Episode Length: 107 Episode Reward: -106.0
50 122.78048780487805 -121.78048780487805


In [6]:
# Dump the recorded per-epoch mean returns on one line, each followed by ", ".
foreach(epoch_performance) do avg_ret
    print(avg_ret, ", ")
end

-500.0, -500.0, -500.0, -500.0, -500.0, -500.0, -500.0, -486.72727272727275, -477.72727272727275, -417.8333333333333, -319.0, -267.7, -255.35, -261.0, -214.875, -210.79166666666666, -210.79166666666666, -203.68, -200.96, -188.8148148148148, -194.69230769230768, -163.16129032258064, -190.74074074074073, -180.0, -167.66666666666666, -168.06666666666666, -180.71428571428572, -179.28571428571428, -163.38709677419354, -175.6206896551724, -160.5483870967742, -163.16129032258064, -152.0, -161.1290322580645, -152.63636363636363, -142.8, -156.9375, -148.5, -135.59459459459458, -146.6764705882353, -126.425, -136.9189189189189, -125.675, -126.0, -125.9, -120.28571428571429, -116.27906976744185, -116.16279069767442, -118.52380952380952, -121.78048780487805, 