In [1]:
using OpenAIGym
using Flux
using Reinforce
using StatsBase
import Reinforce.action
import Flux.params
using Flux: onehotbatch
import Statistics: mean
import Distributions: Multinomial

In [2]:
env = GymEnv(:Acrobot, :v1)

# env parameters
n_actions = length(env.actions)
actions_arr = collect(0:n_actions-1) # ARR [0, .., A-1]
actions_arr_ = collect(1:n_actions)  # ARR [1, .., A]
n_states = length(env.state)
η = 0.01f0 # learning rate

model = Chain(Dense(n_states, 32, tanh),
                Dense(32, n_actions))

function loss(weights, actions, observations)
    # weights: batch * 1, actions: arr(batch * 1), observations: batch * states
    batch_size = length(weights)
    action_masks = onehotbatch(actions, actions_arr_) # actions * batch
    logits = logsoftmax(model(observations'))  # actions * batch
    log_probs = logits .* action_masks
    reduce_sum = sum(log_probs, dims=1) # 1 * batch
    loss = reduce_sum .* weights'
    return -mean(loss)
end

opt = ADAM(params(model), η)

#43 (generic function with 1 method)

In [3]:
function get_action(state)
    logits = model(state)
    dist = Multinomial(1, softmax(logits.data))
    sample = rand(dist, 1)
    pi = Flux.argmax(dropdims(sample, dims=2))
    return pi - 1 
end

function get_matrix(arr, batch_size, dim)
    mat = Matrix{Float32}(undef, batch_size, dim)
    if dim == 1
        for i in 1:length(arr)
            mat[i, 1] = arr[i] 
        end
        return mat
    end
    for (i, item) in enumerate(arr)
        mat[i, :] = item
    end
    return mat
end

function rtg(arr)
    rtg = Array{Float32}(undef, length(arr))
    for i in length(arr):-1:1
        if i == length(arr)
            rtg[i] = arr[i]
        else
            rtg[i] = arr[i] + rtg[i+1]
        end
    end
    return rtg
end
    
function train_one_epoch(env, batch_size=5000, render=false)
    # make some empty lists for logging.
    batch_obs = []          # for observations
    batch_acts = []         # for actions
    batch_weights = []      # for R(tau) weighting in policy gradient
    batch_rets = []         # for measuring episode returns
    batch_lens = []         # for measuring episode lengths
    
    # collect experience by acting in the environment with current policy
    ep_rews, o, r, ep_len, ep_ret = [], reset!(env), 0, 0, 0     # first obs comes from starting distribution
    
    while true
        # reset episode-specific variables
        temp = []
        for i in 1:length(o)
            push!(temp, o[i])
        end
        a = get_action(temp)
        push!(batch_obs, temp)
        push!(batch_acts, a+1)
        push!(ep_rews, r)
    
        r, o = step!(env, a)
        ep_len += 1
        ep_ret += r
        
        done = env.done
        if env.done
            push!(batch_rets, ep_ret)
            push!(batch_lens, ep_len)
            push!(batch_weights, rtg(ep_rews))
    
            if length(batch_obs) > batch_size
                break
            end
            println("Episode Length: ", ep_len, " Episode Reward: ", ep_ret)
            ep_rews, o, r, ep_len, ep_ret = [], reset!(env), 0, 0, 0
        end
    end
    # flatten batch_weights
    batch_weights = collect(Iterators.flatten(batch_weights))
    batch = length(batch_weights)
    mat_obs = get_matrix(batch_obs, batch, n_states)
    mat_acts = get_matrix(batch_acts, batch, 1)
    mat_weights = get_matrix(batch_weights, batch, 1)
    Flux.train!(loss, [(mat_weights, batch_acts, mat_obs)], opt)
    return batch_lens, batch_rets
end

epoch = []
epoch_performance = []
function train(env, epochs = 10)
    for i in 1:epochs
        batch_lens, batch_rets = train_one_epoch(env)
        push!(epoch_performance, mean(batch_rets))
        push!(epoch, i)
        println(i, " ", mean(batch_lens), " ", mean(batch_rets))
    end
end 

train (generic function with 2 methods)

In [4]:
train(env, 50)

│   caller = get_action(::Array{Any,1}) at In[3]:5
└ @ Main ./In[3]:5


Episode Length: 388 Episode Reward: -387.0
Episode Length: 376 Episode Reward: -375.0
Episode Length: 454 Episode Reward: -453.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 461 Episode Reward: -460.0
Episode Length: 386 Episode Reward: -385.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 308 Episode Reward: -307.0
Episode Length: 499 Episode Reward: -498.0
Episode Length: 500 Episode Reward: -500.0
1 447.6666666666667 -447.0833333333333
Episode Length: 499 Episode Reward: -498.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 423 Episode Reward: -422.0
Episode Length: 297 Episode Reward: -296.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 500 Episode Reward: -500.0
Episode Length: 419 Episode Reward: -418.0
2 459.8181818181818 -459.3636363636364
Episode Length: 381

Episode Length: 211 Episode Reward: -210.0
Episode Length: 176 Episode Reward: -175.0
Episode Length: 201 Episode Reward: -200.0
Episode Length: 381 Episode Reward: -380.0
Episode Length: 252 Episode Reward: -251.0
Episode Length: 153 Episode Reward: -152.0
Episode Length: 193 Episode Reward: -192.0
Episode Length: 162 Episode Reward: -161.0
Episode Length: 139 Episode Reward: -138.0
Episode Length: 217 Episode Reward: -216.0
Episode Length: 168 Episode Reward: -167.0
Episode Length: 134 Episode Reward: -133.0
Episode Length: 179 Episode Reward: -178.0
Episode Length: 283 Episode Reward: -282.0
Episode Length: 174 Episode Reward: -173.0
Episode Length: 255 Episode Reward: -254.0
Episode Length: 200 Episode Reward: -199.0
Episode Length: 175 Episode Reward: -174.0
Episode Length: 185 Episode Reward: -184.0
Episode Length: 175 Episode Reward: -174.0
Episode Length: 144 Episode Reward: -143.0
Episode Length: 289 Episode Reward: -288.0
Episode Length: 261 Episode Reward: -260.0
Episode Len

Episode Length: 142 Episode Reward: -141.0
Episode Length: 95 Episode Reward: -94.0
Episode Length: 96 Episode Reward: -95.0
Episode Length: 170 Episode Reward: -169.0
Episode Length: 247 Episode Reward: -246.0
Episode Length: 150 Episode Reward: -149.0
Episode Length: 104 Episode Reward: -103.0
Episode Length: 147 Episode Reward: -146.0
Episode Length: 125 Episode Reward: -124.0
Episode Length: 144 Episode Reward: -143.0
Episode Length: 176 Episode Reward: -175.0
Episode Length: 88 Episode Reward: -87.0
Episode Length: 399 Episode Reward: -398.0
Episode Length: 94 Episode Reward: -93.0
Episode Length: 115 Episode Reward: -114.0
Episode Length: 126 Episode Reward: -125.0
Episode Length: 110 Episode Reward: -109.0
Episode Length: 109 Episode Reward: -108.0
Episode Length: 111 Episode Reward: -110.0
Episode Length: 112 Episode Reward: -111.0
Episode Length: 110 Episode Reward: -109.0
Episode Length: 146 Episode Reward: -145.0
Episode Length: 141 Episode Reward: -140.0
Episode Length: 173

Episode Length: 103 Episode Reward: -102.0
Episode Length: 162 Episode Reward: -161.0
Episode Length: 157 Episode Reward: -156.0
Episode Length: 106 Episode Reward: -105.0
Episode Length: 170 Episode Reward: -169.0
Episode Length: 114 Episode Reward: -113.0
Episode Length: 114 Episode Reward: -113.0
Episode Length: 97 Episode Reward: -96.0
Episode Length: 118 Episode Reward: -117.0
Episode Length: 243 Episode Reward: -242.0
Episode Length: 131 Episode Reward: -130.0
Episode Length: 112 Episode Reward: -111.0
Episode Length: 105 Episode Reward: -104.0
Episode Length: 113 Episode Reward: -112.0
Episode Length: 106 Episode Reward: -105.0
Episode Length: 129 Episode Reward: -128.0
Episode Length: 137 Episode Reward: -136.0
23 132.10256410256412 -131.10256410256412
Episode Length: 119 Episode Reward: -118.0
Episode Length: 122 Episode Reward: -121.0
Episode Length: 121 Episode Reward: -120.0
Episode Length: 108 Episode Reward: -107.0
Episode Length: 93 Episode Reward: -92.0
Episode Length: 

Episode Length: 110 Episode Reward: -109.0
Episode Length: 114 Episode Reward: -113.0
Episode Length: 104 Episode Reward: -103.0
Episode Length: 114 Episode Reward: -113.0
Episode Length: 118 Episode Reward: -117.0
Episode Length: 141 Episode Reward: -140.0
Episode Length: 167 Episode Reward: -166.0
Episode Length: 99 Episode Reward: -98.0
Episode Length: 121 Episode Reward: -120.0
Episode Length: 144 Episode Reward: -143.0
Episode Length: 149 Episode Reward: -148.0
Episode Length: 119 Episode Reward: -118.0
Episode Length: 157 Episode Reward: -156.0
Episode Length: 131 Episode Reward: -130.0
Episode Length: 144 Episode Reward: -143.0
Episode Length: 128 Episode Reward: -127.0
Episode Length: 110 Episode Reward: -109.0
Episode Length: 90 Episode Reward: -89.0
Episode Length: 123 Episode Reward: -122.0
Episode Length: 123 Episode Reward: -122.0
Episode Length: 92 Episode Reward: -91.0
Episode Length: 125 Episode Reward: -124.0
Episode Length: 227 Episode Reward: -226.0
Episode Length: 1

Episode Length: 118 Episode Reward: -117.0
Episode Length: 109 Episode Reward: -108.0
Episode Length: 88 Episode Reward: -87.0
Episode Length: 98 Episode Reward: -97.0
Episode Length: 127 Episode Reward: -126.0
32 106.125 -105.125
Episode Length: 103 Episode Reward: -102.0
Episode Length: 112 Episode Reward: -111.0
Episode Length: 116 Episode Reward: -115.0
Episode Length: 135 Episode Reward: -134.0
Episode Length: 88 Episode Reward: -87.0
Episode Length: 97 Episode Reward: -96.0
Episode Length: 110 Episode Reward: -109.0
Episode Length: 109 Episode Reward: -108.0
Episode Length: 99 Episode Reward: -98.0
Episode Length: 85 Episode Reward: -84.0
Episode Length: 100 Episode Reward: -99.0
Episode Length: 83 Episode Reward: -82.0
Episode Length: 119 Episode Reward: -118.0
Episode Length: 151 Episode Reward: -150.0
Episode Length: 77 Episode Reward: -76.0
Episode Length: 107 Episode Reward: -106.0
Episode Length: 116 Episode Reward: -115.0
Episode Length: 92 Episode Reward: -91.0
Episode Le

Episode Length: 114 Episode Reward: -113.0
Episode Length: 101 Episode Reward: -100.0
Episode Length: 204 Episode Reward: -203.0
Episode Length: 86 Episode Reward: -85.0
Episode Length: 113 Episode Reward: -112.0
Episode Length: 124 Episode Reward: -123.0
Episode Length: 90 Episode Reward: -89.0
Episode Length: 89 Episode Reward: -88.0
Episode Length: 94 Episode Reward: -93.0
Episode Length: 150 Episode Reward: -149.0
Episode Length: 76 Episode Reward: -75.0
Episode Length: 100 Episode Reward: -99.0
Episode Length: 93 Episode Reward: -92.0
Episode Length: 94 Episode Reward: -93.0
Episode Length: 128 Episode Reward: -127.0
Episode Length: 137 Episode Reward: -136.0
Episode Length: 100 Episode Reward: -99.0
Episode Length: 70 Episode Reward: -69.0
Episode Length: 101 Episode Reward: -100.0
Episode Length: 133 Episode Reward: -132.0
Episode Length: 90 Episode Reward: -89.0
Episode Length: 70 Episode Reward: -69.0
Episode Length: 87 Episode Reward: -86.0
Episode Length: 101 Episode Reward:

Episode Length: 91 Episode Reward: -90.0
Episode Length: 125 Episode Reward: -124.0
Episode Length: 86 Episode Reward: -85.0
Episode Length: 99 Episode Reward: -98.0
Episode Length: 89 Episode Reward: -88.0
Episode Length: 86 Episode Reward: -85.0
Episode Length: 100 Episode Reward: -99.0
Episode Length: 81 Episode Reward: -80.0
Episode Length: 143 Episode Reward: -142.0
Episode Length: 95 Episode Reward: -94.0
Episode Length: 113 Episode Reward: -112.0
Episode Length: 101 Episode Reward: -100.0
Episode Length: 173 Episode Reward: -172.0
Episode Length: 95 Episode Reward: -94.0
Episode Length: 103 Episode Reward: -102.0
Episode Length: 118 Episode Reward: -117.0
Episode Length: 92 Episode Reward: -91.0
Episode Length: 95 Episode Reward: -94.0
Episode Length: 107 Episode Reward: -106.0
Episode Length: 87 Episode Reward: -86.0
Episode Length: 110 Episode Reward: -109.0
Episode Length: 86 Episode Reward: -85.0
Episode Length: 88 Episode Reward: -87.0
Episode Length: 112 Episode Reward: -1

Episode Length: 99 Episode Reward: -98.0
Episode Length: 80 Episode Reward: -79.0
Episode Length: 112 Episode Reward: -111.0
Episode Length: 80 Episode Reward: -79.0
Episode Length: 88 Episode Reward: -87.0
Episode Length: 91 Episode Reward: -90.0
Episode Length: 147 Episode Reward: -146.0
Episode Length: 90 Episode Reward: -89.0
Episode Length: 101 Episode Reward: -100.0
Episode Length: 90 Episode Reward: -89.0
Episode Length: 106 Episode Reward: -105.0
Episode Length: 99 Episode Reward: -98.0
Episode Length: 101 Episode Reward: -100.0
Episode Length: 105 Episode Reward: -104.0
Episode Length: 76 Episode Reward: -75.0
Episode Length: 163 Episode Reward: -162.0
Episode Length: 114 Episode Reward: -113.0
Episode Length: 102 Episode Reward: -101.0
Episode Length: 92 Episode Reward: -91.0
Episode Length: 102 Episode Reward: -101.0
Episode Length: 119 Episode Reward: -118.0
Episode Length: 86 Episode Reward: -85.0
Episode Length: 94 Episode Reward: -93.0
Episode Length: 79 Episode Reward: 

Episode Length: 94 Episode Reward: -93.0
Episode Length: 96 Episode Reward: -95.0
48 95.0377358490566 -94.0377358490566
Episode Length: 97 Episode Reward: -96.0
Episode Length: 86 Episode Reward: -85.0
Episode Length: 98 Episode Reward: -97.0
Episode Length: 143 Episode Reward: -142.0
Episode Length: 86 Episode Reward: -85.0
Episode Length: 96 Episode Reward: -95.0
Episode Length: 103 Episode Reward: -102.0
Episode Length: 82 Episode Reward: -81.0
Episode Length: 98 Episode Reward: -97.0
Episode Length: 72 Episode Reward: -71.0
Episode Length: 95 Episode Reward: -94.0
Episode Length: 106 Episode Reward: -105.0
Episode Length: 83 Episode Reward: -82.0
Episode Length: 92 Episode Reward: -91.0
Episode Length: 80 Episode Reward: -79.0
Episode Length: 86 Episode Reward: -85.0
Episode Length: 121 Episode Reward: -120.0
Episode Length: 123 Episode Reward: -122.0
Episode Length: 87 Episode Reward: -86.0
Episode Length: 92 Episode Reward: -91.0
Episode Length: 73 Episode Reward: -72.0
Episode L

In [6]:
for i in epoch_performance
    print(i, ", ")
end

-447.0833333333333, -459.3636363636364, -406.46153846153845, -416.3076923076923, -314.625, -318.9375, -249.1, -250.15, -218.0, -210.5, -184.85185185185185, -201.08, -175.27586206896552, -163.93548387096774, -161.3548387096774, -150.84848484848484, -151.21212121212122, -140.33333333333334, -155.5, -134.72972972972974, -127.075, -131.31578947368422, -131.10256410256412, -127.225, -120.9047619047619, -123.70731707317073, -129.89743589743588, -125.25, -119.14285714285714, -117.4186046511628, -119.02380952380952, -105.125, -116.77777777777777, -103.83333333333333, -108.67391304347827, -104.64583333333333, -106.55319148936171, -105.72340425531915, -113.88636363636364, -100.94, -106.51063829787235, -97.27450980392157, -96.42307692307692, -96.88461538461539, -102.61224489795919, -96.6923076923077, -96.0576923076923, -94.0377358490566, -95.9423076923077, -93.05555555555556, 