In [95]:
# See also
# https://github.com/tejank10/Flux-baselines/blob/master/baselines/dqn/dqn.jl
# https://github.com/JuliaML/Reinforce.jl/blob/master/examples/mountain_car.jl

using Base: Matrix
using DataStructures
using Flux
using Plots
using Reinforce
using Reinforce: MaxIter, IterFunction, LearningStrategy, state, action, maxsteps
using Reinforce.MountainCarEnv: MountainCar, MountainCarState

In [126]:
gr()

# Environment setup
env = MountainCar()
Reinforce.maxsteps(env::MountainCar) = 500
state_array(state::MountainCarState) = [state.position, state.velocity]

state_array (generic function with 1 method)

In [127]:
MEMORY = CircularBuffer{Tuple}(10000)

function episode!(env, π)
  ep = Episode(env, π)
  for (s, a, r, sp) in ep
    gui(plot(env))
    push!(MEMORY, (state_array(s), a, r, state_array(sp), finished(env, sp)))
  end
  ep.total_reward, ep.niter
end

episode! (generic function with 2 methods)

In [128]:
# Deterministic policy that is solving the problem
mutable struct BasicCarPolicy <: Reinforce.AbstractPolicy end
Reinforce.action(policy::BasicCarPolicy, r, s, A) = s.velocity < 0 ? 1 : 3

In [227]:
@show s = state(env)
@show A = actions(env, s)

STATE_SIZE = 2 # length(s)
ACTION_SIZE = length(A)
@show STATE_SIZE, ACTION_SIZE

@show model = Chain(
    Dense(STATE_SIZE, 10, relu),
    Dense(10, ACTION_SIZE))
loss(x, y) = Flux.mse(model(x), y)
evalcb = () -> @show(loss(x, r))

@show model([1, 2]).data

function replay()
    batch_size = length(MEMORY)
    minibatch = sample(MEMORY, batch_size, replace = false)
    
    x = Matrix{Number}(undef, STATE_SIZE, batch_size)
    y = Matrix{Number}(undef, ACTION_SIZE, batch_size)
    
    for (iter, (state, action, reward, next_state, done)) in enumerate(minibatch)
        target = reward
        if !done
            target += 1.0 * maximum(model(next_state).data)
        end

        target_f = model(state).data
        target_f[action] = target
    
        x[:, iter] .= state
        y[:, iter] .= target_f
    end
     
    Flux.train!(loss, Flux.params(model), [(x, y)], ADAM())
end
replay()

mutable struct LearnedCarPolicy <: Reinforce.AbstractPolicy end
Reinforce.maxsteps(env::MountainCar) = 500
function Reinforce.action(π::LearnedCarPolicy, r, s, a)
  act_values = model(state_array(s))
  return Flux.onecold(act_values)
end

s = state(env) = MountainCarState(-1.0487946424807484, 0.021861897029926712)
A = actions(env, s) = DiscreteSet{UnitRange{Int64}}(1:3)
(STATE_SIZE, ACTION_SIZE) = (2, 3)
model = Chain(Dense(STATE_SIZE, 10, relu), Dense(10, ACTION_SIZE)) = Chain(Dense(2, 10, NNlib.relu), Dense(10, 3))
(model([1, 2])).data = Float32[-0.887088, -0.462151, -0.730609]


In [228]:
# Main part
R, n = episode!(env, BasicCarPolicy())
println("reward: $R, iter: $n")

reward: -72.0, iter: 72


In [229]:
# Main part
R, n = episode!(env, LearnedCarPolicy())
println("reward: $R, iter: $n")

replay()

reward: -500.0, iter: 500


In [224]:
for i in 1:100
    episode!(env, BasicCarPolicy())
    episode!(env, LearnedCarPolicy())
    replay()
end

InterruptException: InterruptException:

In [235]:
Flux.params(model)

Params([Float32[-0.430641 -0.263542; -0.150519 0.185679; … ; -0.552938 -0.595616; 0.223367 0.375575] (tracked), Float32[0.002, -0.002, 0.002, 0.00199999, 0.002, 0.002, 0.002, -0.002, -0.002, 0.002] (tracked), Float32[0.310715 0.215176 … 0.116383 -0.445864; -0.399639 0.00915593 … 0.403386 -0.318224; -0.027919 0.623285 … -0.396398 -0.535551] (tracked), Float32[-0.002, -0.002, -0.002] (tracked)])