# Políticas para jugar 21 con Aprendizaje por Refuerzo

In [1]:
import Distributions

In [2]:
struct MDP
    states
    actions
    is_terminal
    legal_actions
    transition
    reward
end

## Algoritmos de aprendizaje por refuerzo

In [3]:
function ϵ_greedy(model::MDP, Q, s, ϵ)
    if rand(Distributions.Binomial(1, ϵ)) == 1
        return rand(model.legal_actions(s))
    else
        actions_q_value = Dict(a => get(Q, [s, a], 0) for a ∈ model.legal_actions(s))
        
        return findmax(actions_q_value)[2]
    end
end

ϵ_greedy (generic function with 1 method)

In [4]:
function Sarsa(mdp::MDP, γ::Float64, α::Float64, ϵ::Float64, episodes::Int64, steps::Int64)
    Q = Dict()
    for s ∈ mdp.states
        if !mdp.is_terminal(s)
            for a ∈ mdp.legal_actions(s)
                Q[[s, a]] = 0
            end
        else
            Q[[s, nothing]] = 0
        end
    end
    
    for episode ∈ 1:episodes
        println("In episode ", episode)
        
        s = rand(mdp.states)
        
        not_terminal = false
        while !not_terminal
            s = rand(mdp.states)
            
            if !mdp.is_terminal(s)
                not_terminal = true
            end
        end
        
        if mdp.is_terminal(s)
            Q[[s, nothing]] = mdp.reward(s, nothing, s)
            continue
        end
        
        a = ϵ_greedy(mdp, Q, s, ϵ)
        
        for step ∈ 1:steps
            n_s = mdp.transition(s, a)
            
            if mdp.is_terminal(n_s)
                Q[[n_s, nothing]] = mdp.reward(n_s, nothing, n_s)
                break
            end
            
            r = mdp.reward(s, a, n_s)
            n_a = ϵ_greedy(mdp, Q, s, ϵ)
            
            Q[[s, a]] = get(Q, [s, a], 0) + α * (r + γ*get(Q, [n_s, n_a], 0) - get(Q, [s, a], 0))
            
            s = n_s
            a = n_a
        end
    end
    
    pol = Dict()
    
    for (s, a) ∈ keys(Q)
        #println(s)
        
        if !mdp.is_terminal(s)
            actions_value = Dict()
            
            for a ∈ mdp.legal_actions(s)
                actions_value[a] = Q[[s, a]]
            end

            pol[s] = findmax(actions_value)[2]
        else
            pol[s] = a
        end
    end
    
    return pol
end

Sarsa (generic function with 1 method)

## MDP para simular

In [5]:
states = []

for i ∈ 2:21
    for j ∈ 1:11
        push!(states, [2, i, 1, j, 0])
    end
end

for s ∈ states
    if s[2] >= 21 || s[4] >= 17 || s[1] == 4 || s[3] == 4
        s[5] = 1
    end
end

states

220-element Array{Any,1}:
 [2, 2, 1, 1, 0]  
 [2, 2, 1, 2, 0]  
 [2, 2, 1, 3, 0]  
 [2, 2, 1, 4, 0]  
 [2, 2, 1, 5, 0]  
 [2, 2, 1, 6, 0]  
 [2, 2, 1, 7, 0]  
 [2, 2, 1, 8, 0]  
 [2, 2, 1, 9, 0]  
 [2, 2, 1, 10, 0] 
 [2, 2, 1, 11, 0] 
 [2, 3, 1, 1, 0]  
 [2, 3, 1, 2, 0]  
 ⋮                
 [2, 20, 1, 11, 0]
 [2, 21, 1, 1, 1] 
 [2, 21, 1, 2, 1] 
 [2, 21, 1, 3, 1] 
 [2, 21, 1, 4, 1] 
 [2, 21, 1, 5, 1] 
 [2, 21, 1, 6, 1] 
 [2, 21, 1, 7, 1] 
 [2, 21, 1, 8, 1] 
 [2, 21, 1, 9, 1] 
 [2, 21, 1, 10, 1]
 [2, 21, 1, 11, 1]

In [6]:
actions = [:hit, :stand]

2-element Array{Symbol,1}:
 :hit  
 :stand

In [7]:
function is_terminal(s)
    terminal = false
    
    if s[5] == 1
        terminal = true
    end
    
    return terminal
end

is_terminal (generic function with 1 method)

In [8]:
function legal_actions(s)
    if s[5] == 0
        return [:hit, :stand]
    else
        return nothing
    end
end

legal_actions (generic function with 1 method)

In [9]:
function transition(s, a)
    if a == :hit
        s[1] += 1
        s[2] += rand([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10, 11])
    elseif a == :stand
        s[3] += 1
        s[4] += rand([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10, 11])
    end
    
    if s[2] >= 21 || s[4] >= 21 || s[1] == 4 || s[3] == 4
        s[5] = 1
    end
    
    return s
end

transition (generic function with 1 method)

In [10]:
function reward(s, a, n_s)
    if n_s[5] == 1
        if n_s[2] <= 21
            if n_s[2] == 21 || n_s[1] == 4 || n_s[2] >= n_s[4] || n_s[4] > 21
                return 1
            else
                return -1
            end
        else
            return -1
        end
    else
        return 0
    end
end

reward (generic function with 1 method)

In [11]:
twenty_one = MDP(states, actions, is_terminal, legal_actions, transition, reward)

MDP(Any[[2, 2, 1, 1, 0], [2, 2, 1, 2, 0], [2, 2, 1, 3, 0], [2, 2, 1, 4, 0], [2, 2, 1, 5, 0], [2, 2, 1, 6, 0], [2, 2, 1, 7, 0], [2, 2, 1, 8, 0], [2, 2, 1, 9, 0], [2, 2, 1, 10, 0]  …  [2, 21, 1, 2, 1], [2, 21, 1, 3, 1], [2, 21, 1, 4, 1], [2, 21, 1, 5, 1], [2, 21, 1, 6, 1], [2, 21, 1, 7, 1], [2, 21, 1, 8, 1], [2, 21, 1, 9, 1], [2, 21, 1, 10, 1], [2, 21, 1, 11, 1]], Symbol[:hit, :stand], is_terminal, legal_actions, transition, reward)

In [12]:
γ = 0.9
α = 0.5
ϵ = 0.1
episodes = 50
steps = 10

10

In [13]:
pol = Sarsa(twenty_one, γ, α, ϵ, episodes, steps)

In episode 1
In episode 2
In episode 3
In episode 4
In episode 5
In episode 6
In episode 7
In episode 8
In episode 9
In episode 10
In episode 11
In episode 12
In episode 13
In episode 14
In episode 15
In episode 16
In episode 17
In episode 18
In episode 19
In episode 20
In episode 21
In episode 22
In episode 23
In episode 24
In episode 25
In episode 26
In episode 27
In episode 28
In episode 29
In episode 30
In episode 31
In episode 32
In episode 33
In episode 34
In episode 35
In episode 36
In episode 37
In episode 38
In episode 39
In episode 40
In episode 41
In episode 42
In episode 43
In episode 44
In episode 45
In episode 46
In episode 47
In episode 48
In episode 49
In episode 50


Dict{Any,Any} with 217 entries:
  [2, 18, 4, 20, 1] => nothing
  [2, 18, 1, 6, 0]  => :stand
  [2, 21, 1, 1, 1]  => nothing
  [2, 7, 3, 24, 1]  => :stand
  [2, 21, 1, 6, 1]  => nothing
  [2, 17, 1, 2, 0]  => :stand
  [2, 4, 1, 11, 0]  => :stand
  [2, 3, 1, 8, 0]   => :stand
  [3, 8, 4, 30, 1]  => :hit
  [2, 14, 3, 23, 1] => :hit
  [2, 13, 1, 1, 0]  => :stand
  [2, 19, 1, 2, 0]  => :stand
  [2, 15, 1, 5, 0]  => :stand
  [2, 21, 1, 3, 1]  => nothing
  [2, 17, 1, 8, 0]  => :stand
  [2, 9, 4, 30, 1]  => :stand
  [2, 7, 1, 7, 0]   => :stand
  [2, 20, 4, 24, 1] => :stand
  [2, 10, 1, 5, 0]  => :stand
  [2, 15, 1, 1, 0]  => :stand
  [2, 3, 1, 6, 0]   => :stand
  [2, 19, 1, 10, 0] => :stand
  [2, 12, 1, 4, 0]  => :stand
  [2, 14, 3, 21, 1] => :stand
  [2, 10, 1, 7, 0]  => :stand
  ⋮                 => ⋮

In [14]:
for k ∈ keys(pol)
    println(k)
    if !(pol[k] == nothing)
        println(k, pol[k])
    end
end

[2, 18, 4, 20, 1]
[2, 18, 1, 6, 0]
[2, 18, 1, 6, 0]stand
[2, 21, 1, 1, 1]
[2, 7, 3, 24, 1]
[2, 7, 3, 24, 1]stand
[2, 21, 1, 6, 1]
[2, 17, 1, 2, 0]
[2, 17, 1, 2, 0]stand
[2, 4, 1, 11, 0]
[2, 4, 1, 11, 0]stand
[2, 3, 1, 8, 0]
[2, 3, 1, 8, 0]stand
[3, 8, 4, 30, 1]
[3, 8, 4, 30, 1]hit
[2, 14, 3, 23, 1]
[2, 14, 3, 23, 1]hit
[2, 13, 1, 1, 0]
[2, 13, 1, 1, 0]stand
[2, 19, 1, 2, 0]
[2, 19, 1, 2, 0]stand
[2, 15, 1, 5, 0]
[2, 15, 1, 5, 0]stand
[2, 21, 1, 3, 1]
[2, 17, 1, 8, 0]
[2, 17, 1, 8, 0]stand
[2, 9, 4, 30, 1]
[2, 9, 4, 30, 1]stand
[2, 7, 1, 7, 0]
[2, 7, 1, 7, 0]stand
[2, 20, 4, 24, 1]
[2, 20, 4, 24, 1]stand
[2, 10, 1, 5, 0]
[2, 10, 1, 5, 0]stand
[2, 15, 1, 1, 0]
[2, 15, 1, 1, 0]stand
[2, 3, 1, 6, 0]
[2, 3, 1, 6, 0]stand
[2, 19, 1, 10, 0]
[2, 19, 1, 10, 0]stand
[2, 12, 1, 4, 0]
[2, 12, 1, 4, 0]stand
[2, 14, 3, 21, 1]
[2, 14, 3, 21, 1]stand
[2, 10, 1, 7, 0]
[2, 10, 1, 7, 0]stand
[2, 12, 4, 25, 1]
[2, 12, 4, 25, 1]stand
[2, 8, 4, 24, 1]
[2, 8, 4, 24, 1]stand
[2, 3, 1, 9, 0]
[2, 3, 1, 9, 0]sta