In [1]:
include("gridworld.jl")
g = DMUGridWorld();

Let's apply Q-learning from Algorithm 5.3 in the text. We'll train over 1000 100-step runs:

In [2]:
function Qlearn(g, alpha, epsilon)
    # initialize dictionary
    Q = Dict{Int, Vector{Float64}}()
    
    # initialize Q-values at initial state (s = 1)
    Q[1] = zeros(n_actions(g))
    
    # 1000 simulations
    for k = 1:1000
        s = 1
        for t = 0:100
            # choose a based on Q and some exploration strategy
            a_idx = findmax(Q[s])[2]
            if rand() < epsilon
                a_idx = rand(1:4)
            end
            a = actions(g)[a_idx]

            # observe new state s_{t+1} and reward rt
            sp, r = simulate(g, s, a)

            # if we've never observed this state, initialize it to zeros
            if !haskey(Q, sp)
                Q[sp] = zeros(n_actions(g))
            end

            # update Q values
            Q[s][a_idx] += alpha * ( r + discount(g)*maximum(Q[sp]) - Q[s][a_idx] )

            # update s
            s = sp
            
            # 73 and 88 are terminal states. Just quit if we get in them.
            if s == 73 || s == 88
                break
            end
        end
    end
    
    return Q
end

Qlearn (generic function with 1 method)

In [3]:
Q = Qlearn(g, 0.5, 0.5);

Did the Q-learning work? Let's compare it to a random policy during 10 simulations.

In [4]:
using Random
Random.seed!(1)     # for reproducibility, seed random number generator

r_sum = 0.0      # sum for policy from Q-learning
rr_sum = 0.0     # sum for random policy

# run 10 simulations
for k = 1:10
    s = 1    # initial state for policy from Q-learning
    sr = 1   # initial state for random policy
    
    for t = 0:100
        
        # generate actions for both policies
        a = actions(g)[findmax(Q[s])[2]]
        ar = actions(g)[rand(1:4)]
        
        # advance Q simulation if you aren't in a terminal state
        if s != 73 && s != 88
            sp, r = simulate(g, s, a)
            r_sum += r * discount(g) ^ (-t)
            s = sp
        end
        
        # advance random simulation if you aren't in a terminal state
        if sr != 73 && sr != 88
            spr, rr = simulate(g, sr, ar)
            rr_sum += rr * discount(g) ^ (-t)
            sr = spr
        end
    end
end

println("Q-learned policy: ", round(Int, r_sum))
println("random poilcy:    ", round(Int, rr_sum))

Q-learned policy: -898
random poilcy:    -1187629


The cumulative sum from Q-learning is much better.