In [1]:
using Revise

# first import the POMDPs.jl interface
using POMDPs

# import our helper Distributions.jl module
using Distributions

# POMDPToolbox has some glue code to help us use Distributions.jl
using POMDPToolbox

using DataFrames

In [2]:
struct GameState 
    outs::Int64         # number of outs
    balls::Int64        # number of balls
    strikes::Int64      # number of strikes
    base_code::Int64    # coded runners on base
    done::Bool          # are we in a terminal state?
end

In [3]:
struct Pitch
    pitch_code::Int64   # pitch type ie: curveball, fastball, ..
    x_loc::Int64        # x quantile of strike zone
    z_loc::Int64        # z quantile of strike zone
end

In [4]:
function getPrior(s::Int64, numStates::Int64, numActions::Int64, N::Array{Int64,3})
    prior = zeros(numStates)
    for a = 1:numActions
        prior = prior .+ N[s,a,:]
    end
    return prior./numActions
end

function computeTransitionDists(data::Array{Int64,2}, numStates::Int64, numActions::Int64)
    N = zeros(Int64,numStates, numActions, numStates)
    for i = 1:length(data[:,1])
        s = data[i,1]
        a = data[i,2]
        sp = data[i,4]
        N[s,a,sp] = N[s,a,sp] + 1
    end
    
    T = zeros(numStates, numStates, numActions)
    for s = 1:numStates
        prior = getPrior(s, numStates, numActions, N)
        for a = 1:numActions
            T[:,s,a] = prior .+ N[s,a,:]
            T[:,s,a] = T[:,s,a]./sum(T[:,s,a])
        end
    end
    for s = 1:numStates
        if any(isnan,T[:,s,:])
            T[:,s,:] = (T[:,max(s-1,0),:] + T[:,min(s+1,numStates),:])./2
        end
    end
    return T
end
;

In [5]:
# object declaration
type PitchCalling <: MDP{GameState, Pitch} # Note that our MDP is parametarized by the state and the action
    discount_factor::Float64 # discount factor
    T::Array{Float64,3}
    S::Array{GameState,1}
    A::Array{Pitch,1}
end

# constructor
function PitchCalling(;
                    discount_factor::Float64=0.99)
    inputfilename = "../obs2.csv"
    s = GameState[] # initialize an array of GameStates
    for outs = 0:2, balls = 0:3, strikes = 0:2, base_code = 0:7
        push!(s, GameState(outs,balls,strikes,base_code, false))
    end
    push!(s, GameState(3, 0, 0, 0, true))
    a = Pitch[] # initialize an array of Pitches
    for pitch_code = 0:6, x_loc = 0:2, z_loc = 0:2
        push!(a, Pitch(pitch_code, x_loc, z_loc))
    end
    data = readtable(inputfilename)
    data = convert(Array{Int64}, data)
    T = computeTransitionDists(data, length(s), length(a))
    return PitchCalling(discount_factor, T, s, a)
end

# states method
function POMDPs.states(mdp::PitchCalling)
    return mdp.S
end;

# actions method
function POMDPs.actions(mdp::PitchCalling)
    return mdp.A
end;

# transition method
function POMDPs.transition(mdp::PitchCalling, state::GameState, action::Pitch)
    s = state.outs*96 + state.balls*24 + state.strikes*8 + state.base_code + 1
    a = action.pitch_code*9 + action.x_loc*3 + action.z_loc + 1
    if state.done
        return SparseCat([GameState(3,0,0,0,true)], [1.0])
    end
    return SparseCat(mdp.S,mdp.T[:,s,a])
end

function numRunners(state::GameState)
    if state.base_code in (0)
        return 0
    elseif state.base_code in (1, 2, 3)
        return 1
    elseif state.base_code in (4, 5, 6)
        return 2
    else
        return 3
    end
end

# reward method
function POMDPs.reward(mdp::PitchCalling, s::GameState, a::Pitch, sp::GameState)
    if s.done
        return 0.0
    end
    runs_scored = numRunners(s) - numRunners(sp) + 1 - (sp.outs) + (s.outs)
    return -runs_scored
end

# miscellaneous methods
POMDPs.n_states(mdp::PitchCalling) = length(mdp.S)
POMDPs.n_actions(mdp::PitchCalling) = length(mdp.A)
POMDPs.discount(mdp::PitchCalling) = mdp.discount_factor;

function POMDPs.state_index(mdp::PitchCalling, s::GameState)
    return  convert(Int64,s.outs*96 + s.balls*24 + s.strikes*8 + s.base_code + 1)
end

function POMDPs.action_index(mdp::PitchCalling, a::Pitch)
    return convert(Int64,a.pitch_code*9 + a.x_loc*3 + a.z_loc + 1)
end

POMDPs.isterminal(mdp::PitchCalling, s::GameState) = s.done

In [6]:
#=mdp = PitchCalling()
sim(mdp, GameState(0,0,0,0,false), max_steps=100) do s
    println("state is: $s")
    a = Pitch(rand(0:6), rand(0:2), rand(0:2))
    println("throwing $a")
    return a
end
;=#

In [7]:
# first let's load the value iteration module
using DiscreteValueIteration

# initialize the problem
mdp = PitchCalling()

# initialize the solver
# max_iterations: maximum number of iterations value iteration runs for (default is 100)
# belres: the value of Bellman residual used in the solver (defualt is 1e-3)
solver = ValueIterationSolver(max_iterations=1000, belres=1e-3)

# initialize the policy by passing in your problem
policy = ValueIterationPolicy(mdp) 

# solve for an optimal policy
# if verbose=false, the text output will be supressed (false by default)
solve(solver, mdp, policy, verbose=true);

[Iteration 1   ] residual:       2.11 | iteration runtime:     88.265 ms, (    0.0883 s total)
[Iteration 2   ] residual:       1.67 | iteration runtime:     72.618 ms, (     0.161 s total)
[Iteration 3   ] residual:       1.63 | iteration runtime:     67.083 ms, (     0.228 s total)
[Iteration 4   ] residual:        1.1 | iteration runtime:     68.950 ms, (     0.297 s total)
[Iteration 5   ] residual:      0.862 | iteration runtime:     65.784 ms, (     0.363 s total)
[Iteration 6   ] residual:      0.808 | iteration runtime:     67.292 ms, (      0.43 s total)
[Iteration 7   ] residual:      0.756 | iteration runtime:     65.732 ms, (     0.496 s total)
[Iteration 8   ] residual:      0.665 | iteration runtime:     66.733 ms, (     0.562 s total)
[Iteration 9   ] residual:      0.599 | iteration runtime:     65.887 ms, (     0.628 s total)
[Iteration 10  ] residual:      0.545 | iteration runtime:     69.916 ms, (     0.698 s total)
[Iteration 11  ] residual:      0.481 | iteration 

In [8]:
function write_Policy(policy::DiscreteValueIteration.ValueIterationPolicy, score::Float64)
    open(outputfilename, "a") do io
        @printf(io, "Policy Score: %f\n", score)
        @printf(io, "Baseline Score: %f\n", baseline)
        for s in mdp.S
            a = action(policy,s)
            @printf(io, "%d\tPitch code: %i\tX_loc: %i\tZ_loc: %i\n", action_index(mdp,a), a.pitch_code, a.x_loc, a.z_loc)
        end
    end
end

function computePolicyScore(policy::DiscreteValueIteration.ValueIterationPolicy)
    U = zeros(length(mdp.S),1)
    for i = 100
        for s in mdp.S
            if state_index(mdp,s) == 289
                continue
            end
            s_i = state_index(mdp,s)
            a_i = action_index(mdp,action(policy,s))
            U[s_i] = R[s_i,a_i] + sum(mdp.T[:,s_i,a_i].*U[:])
        end
    end
    return sum(U)
end

function computePolicyScore(Pi::Array{Int64,1})
    U = zeros(kNumStates,1)
    for i = 1:100
        for s = 1:(kNumStates-1)
            U[s] = R[s,Pi[s]] + sum(mdp.T[:,s,Pi[s]].*U[:])
        end
    end
    return sum(U)
end

function computeRewardMatrix(data::Array{Int64,2})
    N = zeros(Int64,kNumStates, kNumActions, kNumStates)
    for i = 1:length(data[:,1])
        s = data[i,1]
        a = data[i,2]
        sp = data[i,4]
        N[s,a,sp] = N[s,a,sp] + 1
    end

    R = zeros(kNumStates, kNumActions)
    counts = zeros(kNumStates, kNumActions)
    for i = 1:length(data[:,1])
        s = data[i,1]
        a = data[i,2]
        r = data[i,3]
        counts[s,a] = counts[s,a] + 1
        rhat = r/sum(N[s,a,:])
        R[s,a] = R[s,a] + (rhat - R[s,a])/counts[s,a]
    end
    return R
end

function computeBaselineScore()
    aveScore = 0.0
    n = 100
    for i = 1:n
        Pi = rand(1:kNumActions,kNumStates)
        score = computePolicyScore(Pi)
        @printf("Random Policy Score: %f\r\n", score)
        aveScore = aveScore + score
    end
    aveScore = aveScore/n
    @printf("Average of 100 Random Policy Scores: %f\r\n", aveScore)  
    return aveScore
end

kNumStates = n_states(mdp)
kNumActions = n_actions(mdp)
inputfilename = "../obs2.csv"
outputfilename = "./Results_No_Handedness/Value_Iteration/run2.policy"
data = readtable(inputfilename)
data = convert(Array{Int64}, data)
R = computeRewardMatrix(data)
baseline = computeBaselineScore()
write_Policy(policy, computePolicyScore(policy));

Random Policy Score: -50.405577
Random Policy Score: -70.185013
Random Policy Score: -45.707971
Random Policy Score: -54.675371
Random Policy Score: -51.319952
Random Policy Score: -64.189344
Random Policy Score: -91.596078
Random Policy Score: -104.067040
Random Policy Score: -58.025555
Random Policy Score: -72.601539
Random Policy Score: -87.048369
Random Policy Score: -121.634173
Random Policy Score: -137.690268
Random Policy Score: -65.689383
Random Policy Score: -68.920953
Random Policy Score: -67.861567
Random Policy Score: -72.726093
Random Policy Score: -72.232123
Random Policy Score: -65.430750
Random Policy Score: -79.706502
Random Policy Score: -61.403084
Random Policy Score: -96.278574
Random Policy Score: -83.583963
Random Policy Score: -81.039522
Random Policy Score: -103.712260
Random Policy Score: -71.231382
Random Policy Score: -81.008893
Random Policy Score: -46.235871
Random Policy Score: -89.045312
Random Policy Score: -37.091499
Random Policy Score: -83.297957
Rand