In [4]:
using Revise

# first import the POMDPs.jl interface
using POMDPs

# import our helper Distributions.jl module
using Distributions

# POMDPToolbox has some glue code to help us use Distributions.jl
using POMDPToolbox

using DataFrames

# first let's load the value iteration module
using DiscreteValueIteration

In [5]:
pitch_codes = ["change-up","curve-ball","fast-ball","cutter","splitter","sinker","slider"];

In [6]:
struct GameState 
    batter_hand::Int64  # 0 for left, 1 for right
    pitcher_hand::Int64 # 0 for left, 1 for right
    outs::Int64         # number of outs
    balls::Int64        # number of balls
    strikes::Int64      # number of strikes
    base_code::Int64    # coded runners on base
    done::Bool          # are we in a terminal state?
end

In [7]:
struct Pitch
    pitch_code::Int64   # pitch type ie: curveball, fastball, ..
    x_loc::Int64        # x quantile of strike zone
    z_loc::Int64        # z quantile of strike zone
    switch_hand::Bool   # signal to switch pitcher hand
end

In [12]:
function getPrior(s::Int64, numStates::Int64, numActions::Int64, N::Array{Int64,3})
    prior = zeros(numStates)
    for a = 1:numActions
        prior = prior .+ N[s,a,:]
    end
    return prior./numActions
end

function computeTransitionDists(data::Array{Int64,2}, numStates::Int64, numActions::Int64)
    N = zeros(Int64,numStates, numActions, numStates)
    for i = div(length(data[:,1]),2):length(data[:,1])
        s = data[i,1]
        a = data[i,2]
        sp = data[i,4]
        N[s,a,sp] = N[s,a,sp] + 1
    end
    
    T = zeros(numStates, numStates, numActions)
    for s = 1:numStates
        prior = getPrior(s,numStates,numActions,N)
        for a = 1:numActions
            T[:,s,a] = N[s,a,:] .+ prior
            T[:,s,a] = T[:,s,a] ./ sum(T[:,s,a])
        end
    end
    for s = 1:(numStates-1)
        if any(isnan,T[:,s,:])
            T[:,s,:] = T[:,(s+864),:]
        end
    end
    T[:,1153,:] = zeros(numStates,numActions)
    T[1153,1153,:] = ones(numActions)
    return T
end
;

In [9]:
# object declaration
type PitchCalling <: MDP{GameState, Pitch} # Note that our MDP is parametarized by the state and the action
    discount_factor::Float64 # discount factor
    T::Array{Float64,3}
    S::Array{GameState,1}
    A::Array{Pitch,1}
    switch_penalty::Float64
end

# constructor
function PitchCalling(;
                    discount_factor::Float64=1.0,
                    switch_penalty::Float64=-0.1)
    inputfilename = "../obs3.csv"
    s = GameState[]
    for batterhand = 0:1, pitcher_hand = 0:1, outs = 0:2, balls = 0:3, strikes = 0:2, base_code = 0:7
        push!(s, GameState(batterhand,pitcher_hand,outs,balls,strikes,base_code,false))
    end
    push!(s, GameState(2,0,0,0,0,0,true))
    
    a = Pitch[] # initialize an array of Pitches
    for pitch_code = 0:6, x_loc = 0:2, z_loc = 0:2
        push!(a, Pitch(pitch_code, x_loc, z_loc, false))
    end
    push!(a,Pitch(0,0,0,true))
    
    data = readtable(inputfilename)
    data = convert(Array{Int64}, data)
    T = computeTransitionDists(data, length(s), length(a))
    return PitchCalling(discount_factor, T, s, a, switch_penalty)
end

# states method
function POMDPs.states(mdp::PitchCalling)
    return mdp.S
end;

# actions method
function POMDPs.actions(mdp::PitchCalling)
    return mdp.A
end;

# transition method
function POMDPs.transition(mdp::PitchCalling, s::GameState, a::Pitch)
    s_i = state_index(mdp, s)
    a_i = action_index(mdp, a)
    if s.done
        return SparseCat([GameState(2,0,0,0,0,0,true)], [1.0])
    elseif a.switch_hand
        sp = GameState(s.batter_hand,(s.pitcher_hand+1)%2,s.outs,s.balls,s.strikes,s.base_code,s.done)
        return SparseCat([sp],[1.0])
    else
        return SparseCat(mdp.S,mdp.T[:,s_i,a_i])
    end
end

function numRunners(state::GameState)
    if state.base_code in (0)
        return 0
    elseif state.base_code in (1, 2, 3)
        return 1
    elseif state.base_code in (4, 5, 6)
        return 2
    else
        return 3
    end
end

# reward method
function POMDPs.reward(mdp::PitchCalling, s::GameState, a::Pitch, sp::GameState)
    if s.done
        return 0.0
    elseif a.switch_hand
        return mdp.switch_penalty 
    else
        n_s = numRunners(s)
        n_sp = numRunners(sp)
        if (sp.balls == 0) && (sp.strikes == 0)
            return -(numRunners(s) - numRunners(sp) + 1 - sp.outs + s.outs)
        else
            return -(numRunners(s) - numRunners(sp) - sp.outs + s.outs)
        end
    end
end

# miscellaneous methods
POMDPs.n_states(mdp::PitchCalling) = length(mdp.S)
POMDPs.n_actions(mdp::PitchCalling) = length(mdp.A)
POMDPs.discount(mdp::PitchCalling) = mdp.discount_factor;

function POMDPs.state_index(mdp::PitchCalling, s::GameState)
    if s.done
        return 1153
    else
        return convert(Int64,s.batter_hand*576 + s.pitcher_hand*288 + s.outs*96 + s.balls*24 + s.strikes*8 + s.base_code + 1)
    end
end

function POMDPs.action_index(mdp::PitchCalling, a::Pitch)
    if a.switch_hand
        return 64
    else
        return convert(Int64,a.pitch_code*9 + a.x_loc*3 + a.z_loc + 1)
    end
end

POMDPs.isterminal(mdp::PitchCalling, s::GameState) = s.done

In [10]:
#=for s = 1:n_states(mdp)
    if any(isnan,mdp.T[:,s,:])
        println("state is: $s")
    end
end=#

In [11]:
# initialize the problem
mdp = PitchCalling()

# initialize the solver
# max_iterations: maximum number of iterations value iteration runs for (default is 100)
# belres: the value of Bellman residual used in the solver (defualt is 1e-3)
solver = ValueIterationSolver(max_iterations=1000, belres=1e-3)

# initialize the policy by passing in your problem
policy = ValueIterationPolicy(mdp) 

# solve for an optimal policy
# if verbose=false, the text output will be supressed (false by default)
solve(solver, mdp, policy, verbose=true);

[Iteration 1   ] residual:      0.938 | iteration runtime:    942.419 ms, (     0.942 s total)
[Iteration 2   ] residual:      0.901 | iteration runtime:    956.553 ms, (       1.9 s total)
[Iteration 3   ] residual:      0.832 | iteration runtime:   1005.960 ms, (       2.9 s total)
[Iteration 4   ] residual:      0.513 | iteration runtime:    969.527 ms, (      3.87 s total)
[Iteration 5   ] residual:      0.291 | iteration runtime:   1197.945 ms, (      5.07 s total)
[Iteration 6   ] residual:      0.284 | iteration runtime:   1138.879 ms, (      6.21 s total)
[Iteration 7   ] residual:        0.2 | iteration runtime:   1094.514 ms, (      7.31 s total)
[Iteration 8   ] residual:        0.2 | iteration runtime:    976.939 ms, (      8.28 s total)
[Iteration 9   ] residual:        0.2 | iteration runtime:    999.166 ms, (      9.28 s total)
[Iteration 10  ] residual:        0.2 | iteration runtime:    917.530 ms, (      10.2 s total)
[Iteration 11  ] residual:        0.2 | iteration 

[Iteration 88  ] residual:     0.0396 | iteration runtime:    914.984 ms, (      83.2 s total)
[Iteration 89  ] residual:     0.0387 | iteration runtime:    910.044 ms, (      84.1 s total)
[Iteration 90  ] residual:     0.0378 | iteration runtime:   1005.095 ms, (      85.1 s total)
[Iteration 91  ] residual:     0.0368 | iteration runtime:    963.838 ms, (      86.1 s total)
[Iteration 92  ] residual:     0.0359 | iteration runtime:    899.212 ms, (        87 s total)
[Iteration 93  ] residual:     0.0351 | iteration runtime:    912.684 ms, (      87.9 s total)
[Iteration 94  ] residual:     0.0343 | iteration runtime:    911.860 ms, (      88.8 s total)
[Iteration 95  ] residual:     0.0334 | iteration runtime:    912.278 ms, (      89.7 s total)
[Iteration 96  ] residual:     0.0326 | iteration runtime:    967.389 ms, (      90.7 s total)
[Iteration 97  ] residual:     0.0318 | iteration runtime:    957.026 ms, (      91.6 s total)
[Iteration 98  ] residual:      0.031 | iteration 

[Iteration 175 ] residual:    0.00374 | iteration runtime:    918.943 ms, (       165 s total)
[Iteration 176 ] residual:    0.00363 | iteration runtime:    925.630 ms, (       166 s total)
[Iteration 177 ] residual:    0.00352 | iteration runtime:    924.372 ms, (       167 s total)
[Iteration 178 ] residual:    0.00342 | iteration runtime:    919.665 ms, (       167 s total)
[Iteration 179 ] residual:    0.00332 | iteration runtime:    953.394 ms, (       168 s total)
[Iteration 180 ] residual:    0.00322 | iteration runtime:    966.169 ms, (       169 s total)
[Iteration 181 ] residual:    0.00312 | iteration runtime:    953.615 ms, (       170 s total)
[Iteration 182 ] residual:    0.00303 | iteration runtime:    929.005 ms, (       171 s total)
[Iteration 183 ] residual:    0.00294 | iteration runtime:    956.193 ms, (       172 s total)
[Iteration 184 ] residual:    0.00286 | iteration runtime:    929.964 ms, (       173 s total)
[Iteration 185 ] residual:    0.00277 | iteration 

In [13]:
function write_Policy(policy::DiscreteValueIteration.ValueIterationPolicy, score::Float64)
    open(outputfilename, "a") do io
        @printf(io, "Policy Score: %f\n", score)
        @printf(io, "Baseline Score: %f\n", baseline)
        @printf(io, "Pitcher Switch Cost: %f\n", mdp.switch_penalty)
        @printf(io, "s_i bh ph outs balls strikes base_code a_i pitch_type x_loc z_loc\n")
        for s in mdp.S
            a = action(policy,s)
            #@printf(io, "State -> bh: %i\t ph: %i\touts: %i\tballs: %i\tstrikes: %i\tbase_code: %i\t\t",s.batter_hand, s.pitcher_hand, s.outs, s.balls, s.strikes, s.base_code)
            @printf(io, "%i\t%i\t%i\t%i\t%i\t%i\t%i\t",state_index(mdp,s),s.batter_hand, s.pitcher_hand, s.outs, s.balls, s.strikes, s.base_code)
            if a.switch_hand
                @printf(io, "%d\t%s\t0\t0\n", action_index(mdp,a), "Switch",)
            else
                @printf(io, "%d\t%s\t%i\t%i\n", action_index(mdp,a), pitch_codes[a.pitch_code+1], a.x_loc, a.z_loc)
            end
        end
    end
end

function computePolicyScore(policy::DiscreteValueIteration.ValueIterationPolicy)
    U = zeros(length(mdp.S),1)
    for i = 100
        for s in mdp.S
            if state_index(mdp,s) == 1153
                continue
            end
            s_i = state_index(mdp,s)
            a_i = action_index(mdp,action(policy,s))
            U[s_i] = R[s_i,a_i] + sum(mdp.T[:,s_i,a_i].*U[:])
        end
    end
    return sum(U)
end

function computePolicyScore(Pi::Array{Int64,1})
    U = zeros(kNumStates,1)
    for i = 1:100
        for s = 1:(kNumStates-1)
            U[s] = R[s,Pi[s]] + sum(mdp.T[:,s,Pi[s]].*U[:])
        end
    end
    return sum(U)
end

function computeRewardMatrix(data::Array{Int64,2})
    N = zeros(Int64,kNumStates, kNumActions, kNumStates)
    for i = 1:length(data[:,1])
        s = data[i,1]
        a = data[i,2]
        sp = data[i,4]
        N[s,a,sp] = N[s,a,sp] + 1
    end

    R = zeros(kNumStates, kNumActions)
    counts = zeros(kNumStates, kNumActions)
    for i = 1:length(data[:,1])
        s = data[i,1]
        a = data[i,2]
        r = data[i,3]
        counts[s,a] = counts[s,a] + 1
        rhat = r/sum(N[s,a,:])
        R[s,a] = R[s,a] + (rhat - R[s,a])/counts[s,a]
    end
    return R
end

function computeBaselineScore()
    aveScore = 0.0
    n = 100
    for i = 1:n
        Pi = rand(1:(kNumActions-1),kNumStates)
        score = computePolicyScore(Pi)
        @printf("Random Policy Score: %f\r\n", score)
        aveScore = aveScore + score
    end
    aveScore = aveScore/n
    @printf("Average of 100 Random Policy Scores: %f\r\n", aveScore)  
    return aveScore
end

mdp = PitchCalling()
kNumStates = n_states(mdp)
kNumActions = n_actions(mdp)
inputfilename = "../obs3.csv"
outputfilename = "./Results_Handedness/Value_Iteration/run4_half_data.policy"
data = readtable(inputfilename)
data = convert(Array{Int64}, data)
R = computeRewardMatrix(data)
baseline = -391.635155 #computeBaselineScore()
write_Policy(policy, computePolicyScore(policy));

In [17]:
function simulateInningOptimal()
    score = 0.0
    s = GameState(rand([0,1]),rand([0,1]),0,0,0,0,false)
    sp = GameState
    while(s.done == false)
        a = action(policy,s)
        x = transition(mdp, s, a)
        sp = sample(x.vals, ProbabilityWeights(x.probs))
        r = reward(mdp, s, a, sp)
        if r == -1.0
           score = score + r 
        end
        s = sp
    end
    return score
end

function simulateInningRandom()
    score = 0.0
    s = GameState(rand([0,1]),rand([0,1]),0,0,0,0,false)
    sp = GameState
    while(s.done == false)
        a = actions(mdp)[rand(1:n_actions(mdp))]
        x = transition(mdp, s, a)
        sp = sample(x.vals, ProbabilityWeights(x.probs))
        r = reward(mdp, s, a, sp)
        if r == -1.0
           score = score + r 
        end
        s = sp
    end
    return score
end

function simulateInningMccann(P::Array{Float64,2})
    score = 0.0
    s = GameState(rand([0,1]),rand([0,1]),0,0,0,0,false)
    sp = GameState
    while(s.done == false)
        s_i = state_index(mdp, s)
        if any(isnan,P[s_i,:])
            a = action(policy,s)
        else
            a_i = sample(collect(1:n_actions(mdp)), ProbabilityWeights(P[s_i,:]))
            a = actions(mdp)[a_i]
        end
        x = transition(mdp, s, a)
        sp = sample(x.vals, ProbabilityWeights(x.probs))
        r = reward(mdp, s, a, sp)
        if r == -1.0
           score = score + r 
        end
        s = sp
    end
    return score 
end

function getMccannCounts()
    fileName = "../obsmccann.csv"
    data = readtable(fileName)
    data = convert(Array{Int64}, data)
    N = zeros(n_states(mdp),n_actions(mdp))
    for i = 1:length(data[:,1])
        s = data[i,1]
        a = data[i,2]
        N[s,a] = N[s,a] + 1
    end
    for s = 1:n_states(mdp)
        N[s,:] = N[s,:]./sum(N[s,:])
    end
    return N
end

function compareNInnings(n::Int64)
    rand_ave_score = 0.0
    optimal_ave_score = 0.0
    mccann_ave_score = 0.0
    P = getMccannCounts()
    for i = 1:n
        rand_ave_score = rand_ave_score + simulateInningRandom()
        optimal_ave_score = optimal_ave_score + simulateInningOptimal()
        mccann_ave_score = mccann_ave_score + simulateInningMccann(P)
    end
    rand_ave_score = rand_ave_score/n
    optimal_ave_score = optimal_ave_score/n
    mccann_ave_score = mccann_ave_score/n
    println("Average random score after $n innings: \t$rand_ave_score")
    println("Average Mccann score after $n innings: \t$mccann_ave_score")
    println("Average optimal score after $n innings: \t$optimal_ave_score")
end

function simulateOneGame(catcher1::Int64, catcher2::Int64)
    P = getMccannCounts()
    if (catcher1 == 0) && (catcher2 == 1)
        optimal_score = 0
        mccmann_score = 0
        for inning = 1:9
            optimal_score = simulateInningOptimal() + optimal_score
            mccmann_score = simulateInningMccann(P) + mccmann_score
        end
        while(optimal_score == mccmann_score)
            optimal_score = simulateInningOptimal() + optimal_score
            mccmann_score = simulateInningMccann(P) + mccmann_score
        end
        return optimal_score, mccmann_score
    elseif (catcher1 == 0) && (catcher2 == 2)
        optimal_score = 0
        random_score = 0
        for inning = 1:9
            optimal_score = simulateInningOptimal() + optimal_score
            random_score = simulateInningRandom() + random_score
        end
        while(optimal_score == random_score)
            optimal_score = simulateInningOptimal() + optimal_score
            random_score = simulateInningRandom() + random_score
        end
        return optimal_score, random_score
    elseif (catcher1 == 1) && (catcher2 == 2)
        mccmann_score = 0
        random_score = 0
        for inning = 1:9
            mccmann_score = simulateInningMccann(P) + mccmann_score
            random_score = simulateInningRandom() + random_score
        end
        while(random_score == mccmann_score)
            mccmann_score = simulateInningMccann(P) + mccmann_score
            random_score = simulateInningRandom() + random_score
        end
        return mccmann_score, random_score
    else
        println("incorrect input numbers")
    end
end

#compareNInnings(1000000)
mutable struct record
    wins::Int64
    losses::Int64
end

mdp = PitchCalling()
optimal_record = record(0,0)
mccmann_record = record(0,0)
random_record = record(0,0)

for games = 1:1000
    p1_score, p2_score = simulateOneGame(0,2)
    if p1_score > p2_score
        optimal_record.wins = optimal_record.wins + 1
        random_record.losses = random_record.losses + 1
    else
        optimal_record.losses = optimal_record.losses + 1
        random_record.wins = random_record.wins + 1
    end
end
println("Optimal record: $optimal_record \t Random record: $random_record")

Mccann record: record(484, 516) 	 Random record: record(516, 484)


In [3]:
score = 0.0
s = GameState(0,0,0,0,0,0,false)
sp = GameState(0,0,0,0,0,0,false)
while(s.done == false)
    a = action(policy,s)
    x = transition(mdp, s, a)
    sp = sample(x.vals, ProbabilityWeights(x.probs))
    r = reward(mdp, s, a, sp)
    if r == -1.0
        score = score + r 
    end
    println("state is: \t$s")
    println("action is: \t$a")
    println("next state is: \t$sp")
    println("score is: \t$score\n\n")
    s = sp
end

LoadError: [91mUndefVarError: GameState not defined[39m