# Políticas para jugar 21 con Aprendizaje por Refuerzo

En esta libreta atacamos el problema del 21 otra vez. Pero ahora usando **aprendizaje por refuerzo**, y en particular con el algoritmo **SARSA**.

Si llegaste directamente a esta libreta, te recomiendo visitar [esta página](https://rexemin.github.io/Topicos-IA-UNISON/2018/12/03/revancha-21.html) para que tengas un mejor contexto.

Ahora si, vamos a ver el código que puede resolver este problema. Primero importaremos un módulo que facilita el uso de distribuciones de probabilidad.

In [1]:
import Distributions

## Modelo

Análogamente a las libretas pasadas ([aquí](https://nbviewer.jupyter.org/github/rexemin/Topicos-IA-UNISON/blob/master/ProgramacionDinamica/21/21-ProgramacionDinamica-SinRepartidor.ipynb) y [aquí](https://nbviewer.jupyter.org/github/rexemin/Topicos-IA-UNISON/blob/master/ProgramacionDinamica/21/21-ProgramacionDinamica-ConRepartidor.ipynb)), aquí también usaremos una estructura que agrupe todas las cosas que necesitamos (a excepción del SARSA).

El modelo tendrá las siguientes cosas:
- states: los estados iniciales del 21
- actions: un arreglo de todas las posibles acciones
- is_terminal: una función que nos dice si el estado que recibe es terminal o no
- legal_actions: una función que regresa las posibles acciones en un estado particular
- transition: una función que nos devuelve un estado nuevo a partir de un estado actual y una acción
- reward: una función que nos devuelve la recompensa en un estado

In [2]:
mutable struct MDP
    states
    actions
    is_terminal
    legal_actions
    transition
    reward
end

## Los algoritmos

Usaremos dos algoritmos: $\epsilon$-greedy para decidir qué acción tomar, y SARSA para aproximar la función acción-valor.

In [3]:
function ϵ_greedy(model::MDP, Q, s, ϵ)
    if rand(Distributions.Binomial(1, ϵ)) == 1
        return rand(model.legal_actions(s))
    else
        actions_q_value = Dict(a => get(Q, [s, a], 0) for a ∈ model.legal_actions(s))
        
        return findmax(actions_q_value)[2]
    end
end

ϵ_greedy (generic function with 1 method)

In [4]:
function Sarsa(mdp::MDP, γ::Float64, α::Float64, ϵ::Float64, episodes::Int64, steps::Int64)
    Q = Dict()
    for s ∈ mdp.states
        if !mdp.is_terminal(s)
            for a ∈ mdp.legal_actions(s)
                Q[[s, a]] = 0
            end
        else
            Q[[s, nothing]] = 0
        end
    end
    
    for episode ∈ 1:episodes
        println("In episode ", episode)
        
        s = rand(mdp.states)
        
        not_terminal = false
        while !not_terminal
            s = rand(mdp.states)
            
            if !mdp.is_terminal(s)
                not_terminal = true
            end
        end
        
        if mdp.is_terminal(s)
            Q[[s, nothing]] = mdp.reward(s, nothing, s)
            continue
        end
        
        a = ϵ_greedy(mdp, Q, s, ϵ)
        
        for step ∈ 1:steps
            n_s = mdp.transition(s, a)
            
            if mdp.is_terminal(n_s)
                Q[[n_s, nothing]] = mdp.reward(n_s, nothing, n_s)
                break
            end
            
            r = mdp.reward(s, a, n_s)
            n_a = ϵ_greedy(mdp, Q, s, ϵ)
            
            Q[[s, a]] = get(Q, [s, a], 0) + α * (r + γ*get(Q, [n_s, n_a], 0) - get(Q, [s, a], 0))
            
            s = n_s
            a = n_a
        end
    end
    
    # Aquí recuperamos la política óptima en base a la Q que aproximamos.
    pol = Dict()
    
    for (s, a) ∈ keys(Q)
        if !mdp.is_terminal(s)
            actions_value = Dict()
            
            for a_ ∈ mdp.legal_actions(s)
                actions_value[a] = get(Q, [s, a_], 0)
            end

            pol[s] = findmax(actions_value)[2]
        else
            pol[s] = a
        end
    end
    
    return pol
end

Sarsa (generic function with 1 method)

## Simulación

### Estados iniciales
Ahora crearemos las cosas que necesitamos para poder empezar a simular el 21. Empezaremos calculando los estados iniciales (y únicamente estos estados).

In [5]:
states = []

for i ∈ 2:21
    for j ∈ 1:11
        push!(states, [2, i, j, 0])
    end
end

for s ∈ states
    if s[2] >= 21 || s[3] >= 17 || s[1] == 4
        s[4] = 1
    end
end

states

220-element Array{Any,1}:
 [2, 2, 1, 0]  
 [2, 2, 2, 0]  
 [2, 2, 3, 0]  
 [2, 2, 4, 0]  
 [2, 2, 5, 0]  
 [2, 2, 6, 0]  
 [2, 2, 7, 0]  
 [2, 2, 8, 0]  
 [2, 2, 9, 0]  
 [2, 2, 10, 0] 
 [2, 2, 11, 0] 
 [2, 3, 1, 0]  
 [2, 3, 2, 0]  
 ⋮             
 [2, 20, 11, 0]
 [2, 21, 1, 1] 
 [2, 21, 2, 1] 
 [2, 21, 3, 1] 
 [2, 21, 4, 1] 
 [2, 21, 5, 1] 
 [2, 21, 6, 1] 
 [2, 21, 7, 1] 
 [2, 21, 8, 1] 
 [2, 21, 9, 1] 
 [2, 21, 10, 1]
 [2, 21, 11, 1]

### Acciones

En las siguientes celdas se definen las funciones que nos permitirán manejar las acciones que pueden realizarse en cada estado durante la simulación.

In [6]:
actions = [:hit, :stand]

2-element Array{Symbol,1}:
 :hit  
 :stand

In [8]:
function legal_actions(s)
    if s[4] == 0
        return [:hit, :stand]
    else
        return nothing
    end
end

legal_actions (generic function with 1 method)

### Transiciones

En las siguientes celdas están las funciones que nos permitirán dar pasos en la simulación, así como saber cuando parar.

En la función de transición solo se le suma una carta aleatoria a la mano respectiva y se revisa si se llegó a un estado terminal.

In [7]:
function is_terminal(s)
    terminal = false
    
    if s[4] == 1
        terminal = true
    end
    
    return terminal
end

is_terminal (generic function with 1 method)

In [9]:
function transition(s, a)
    if a == :hit
        s[1] += 1
        s[2] += rand([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10, 11])
    elseif a == :stand
        s[3] += rand([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 10, 10, 11])
    end
    
    if s[2] >= 21 || s[3] >= 21 || s[1] == 4
        s[4] = 1
    end
    
    return s
end

transition (generic function with 1 method)

### Recompensa

Por último, definiremos la función de recompensa para un paso en la simulación.

In [10]:
function reward(s, a, n_s)
    if n_s[4] == 1
        if n_s[2] <= 21
            if n_s[2] == 21 || n_s[1] == 4 || n_s[2] >= n_s[3] || n_s[3] > 21
                return 1
            else
                return -1
            end
        else
            return -1
        end
    else
        return 0
    end
end

reward (generic function with 1 method)

## Ejecución

Con todo listo, podemos pasar a ejecutar la simulación. También podemos jugar con los parámetros de la misma.

In [11]:
twenty_one = MDP(states, actions, is_terminal, legal_actions, transition, reward)

MDP(Any[[2, 2, 1, 0], [2, 2, 2, 0], [2, 2, 3, 0], [2, 2, 4, 0], [2, 2, 5, 0], [2, 2, 6, 0], [2, 2, 7, 0], [2, 2, 8, 0], [2, 2, 9, 0], [2, 2, 10, 0]  …  [2, 21, 2, 1], [2, 21, 3, 1], [2, 21, 4, 1], [2, 21, 5, 1], [2, 21, 6, 1], [2, 21, 7, 1], [2, 21, 8, 1], [2, 21, 9, 1], [2, 21, 10, 1], [2, 21, 11, 1]], Symbol[:hit, :stand], is_terminal, legal_actions, transition, reward)

In [12]:
γ = 0.9
α = 0.5
ϵ = 0.1
episodes = 90
steps = 10

10

In [13]:
pol = Sarsa(twenty_one, γ, α, ϵ, episodes, steps)

In episode 1
In episode 2
In episode 3
In episode 4
In episode 5
In episode 6
In episode 7
In episode 8
In episode 9
In episode 10
In episode 11
In episode 12
In episode 13
In episode 14
In episode 15
In episode 16
In episode 17
In episode 18
In episode 19
In episode 20
In episode 21
In episode 22
In episode 23
In episode 24
In episode 25
In episode 26
In episode 27
In episode 28
In episode 29
In episode 30
In episode 31
In episode 32
In episode 33
In episode 34
In episode 35
In episode 36
In episode 37
In episode 38
In episode 39
In episode 40
In episode 41
In episode 42
In episode 43
In episode 44
In episode 45
In episode 46
In episode 47
In episode 48
In episode 49
In episode 50
In episode 51
In episode 52
In episode 53
In episode 54
In episode 55
In episode 56
In episode 57
In episode 58
In episode 59
In episode 60
In episode 61
In episode 62
In episode 63
In episode 64
In episode 65
In episode 66
In episode 67
In episode 68
In episode 69
In episode 70
In episode 71
In episode 72
I

Dict{Any,Any} with 201 entries:
  [2, 17, 21, 1] => :hit
  [2, 21, 10, 1] => nothing
  [2, 4, 21, 1]  => nothing
  [2, 10, 3, 0]  => :hit
  [2, 13, 21, 1] => :hit
  [2, 6, 7, 0]   => :hit
  [2, 3, 29, 1]  => nothing
  [2, 11, 26, 1] => :stand
  [2, 10, 29, 1] => nothing
  [2, 4, 11, 0]  => :hit
  [2, 2, 8, 0]   => :stand
  [2, 7, 26, 1]  => :hit
  [2, 9, 4, 0]   => :hit
  [2, 9, 3, 0]   => :hit
  [2, 19, 11, 0] => :hit
  [2, 9, 7, 0]   => :hit
  [2, 5, 10, 0]  => :stand
  [2, 15, 1, 0]  => :stand
  [2, 13, 6, 0]  => :hit
  [2, 16, 26, 1] => :hit
  [2, 16, 4, 0]  => :stand
  [2, 3, 11, 0]  => :hit
  [2, 18, 10, 0] => :stand
  [2, 17, 10, 0] => :stand
  [2, 5, 6, 0]   => :stand
  ⋮              => ⋮

## Política encontrada

Finalmente, podemos ver los resultados del SARSA. Muchísimo mejores que los que conseguí con programación dinámica.

In [14]:
for k ∈ keys(pol)
    if !(pol[k] == nothing)
        println(k, pol[k])
    end
end

[2, 17, 21, 1]hit
[2, 10, 3, 0]hit
[2, 13, 21, 1]hit
[2, 6, 7, 0]hit
[2, 11, 26, 1]stand
[2, 4, 11, 0]hit
[2, 2, 8, 0]stand
[2, 7, 26, 1]hit
[2, 9, 4, 0]hit
[2, 9, 3, 0]hit
[2, 19, 11, 0]hit
[2, 9, 7, 0]hit
[2, 5, 10, 0]stand
[2, 15, 1, 0]stand
[2, 13, 6, 0]hit
[2, 16, 26, 1]hit
[2, 16, 4, 0]stand
[2, 3, 11, 0]hit
[2, 18, 10, 0]stand
[2, 17, 10, 0]stand
[2, 5, 6, 0]stand
[2, 18, 6, 0]stand
[2, 14, 11, 0]hit
[2, 12, 27, 1]hit
[2, 18, 5, 0]hit
[2, 14, 9, 0]hit
[2, 5, 27, 1]stand
[2, 3, 9, 0]hit
[2, 19, 22, 1]stand
[2, 20, 11, 0]stand
[2, 7, 7, 0]hit
[2, 12, 23, 1]stand
[2, 17, 24, 1]stand
[2, 8, 2, 0]stand
[2, 20, 2, 0]hit
[2, 14, 5, 0]stand
[2, 16, 21, 1]hit
[2, 7, 6, 0]stand
[2, 15, 27, 1]stand
[2, 10, 7, 0]hit
[2, 4, 28, 1]stand
[2, 5, 11, 0]hit
[2, 15, 3, 0]hit
[2, 17, 23, 1]hit
[2, 5, 9, 0]hit
[2, 3, 28, 1]hit
[2, 10, 4, 0]stand
[2, 6, 1, 0]stand
[2, 15, 30, 1]stand
[3, 23, 19, 1]stand
[2, 16, 6, 0]stand
[2, 13, 22, 1]stand
[2, 10, 6, 0]hit
[2, 2, 30, 1]hit
[2, 17, 6, 0]stand
[2, 19