In [None]:
using POMDPs # for MDP type
using DiscreteValueIteration
using POMDPPolicies
using POMDPModelTools #for sparse cat 
using Parameters
using Random
using Plots; default(fontfamily="Computer Modern", framestyle=:box) # LaTex-style
using QuickPOMDPs
using Distributions 
#using PlotlyJS
using LinearAlgebra
using POMDPSimulators
using Measures

In [None]:
# Seed the global RNG so that later unseeded `rand` calls in this notebook are repeatable.
Random.seed!(0xC0FFEE)

**States**

In [None]:
# One decision point of the evacuation problem: the resources that are left plus the
# family currently waiting at the gate. All four fields are plain integers.
struct State
    c::Int # chairs (seats) remaining
    t::Int # time steps remaining
    f::Int # size of the family at the gate
    v::Int # visa status of the family at the gate (integer code)
end 

**Environment Parameters**

In [None]:
"""
    EvacuationParameters(; kwargs...)

Keyword-constructible bundle of every constant of the evacuation problem.

# Fields
- `family_sizes`, `family_prob`: possible family sizes (no repeats) and the probability
  of each one arriving; the two vectors are index-aligned.
- `visa_status`, `visa_prob`: possible visa-status codes and the probability of each one
  arriving; the two vectors are index-aligned.
- `capacity`: number of seats available at the start.
- `time`: number of time steps available at the start.
- `size`: `(number of visa statuses, number of family sizes)`, the policy-grid shape.
- `p_transition`: kept for compatibility; not read anywhere in the visible code.
- `null_state`: absorbing state that marks the end of an episode.
- `accept_prob`: `[P(family boards), P(family does not board)]` after `ACCEPT`.
- `reject_prob`: probability mass placed on "family does not board" after `REJECT`.
"""
@with_kw struct EvacuationParameters
    family_sizes::Vector{Int} = [1, 2, 3, 4, 5] # set with no repeats
    family_prob::Vector{Float64} = [0.1, 0.2, 0.3, 0.2, 0.2]
    visa_status::Vector{Int} = [-2, -1, 0, 1, 2] # TODO: map to various status strings
    visa_prob::Vector{Float64} = [0.1, 0.1, 0.4, 0.2, 0.2]
    capacity::Int = 60
    time::Int = 60
    size::Tuple{Int, Int} = (length(visa_status), length(family_sizes)) # size of grid
    p_transition::Float64 = 0.8
    null_state::State = State(-1, -1, -1, -1)
    accept_prob::Vector{Float64} = [0.80, 0.20]
    reject_prob::Vector{Float64} = [1.0]
    # Arrival distributions must line up with their supports and be proper distributions.
    @assert length(family_sizes) == length(family_prob)
    @assert length(visa_status) == length(visa_prob)
    @assert sum(family_prob) ≈ 1
    @assert sum(visa_prob) ≈ 1
end

In [None]:
# Global parameter object built from the defaults; the functions below read it directly.
params = EvacuationParameters(); 

In [None]:
# Size of the enumerated state space: chairs and time both range over 0:max (hence the
# `+ 1` on each), crossed with every family size and visa status, plus the null state.
number_states = (params.capacity + 1) * (params.time + 1) *
                length(params.family_sizes) * length(params.visa_status) + 1
@show number_states

In [None]:
# The state space 𝒮 is every combination of chairs remaining (0:capacity), time remaining
# (0:time), family size and visa status, followed by the single null state.
# A typed vector grown with `push!` keeps construction linear in the number of states
# (re-concatenating the whole vector for every state is quadratic) and avoids reassigning
# a global from inside a loop.
𝒮 = State[]
sizehint!(𝒮, (params.capacity + 1) * (params.time + 1) *
             length(params.family_sizes) * length(params.visa_status) + 1)
for c in 0:params.capacity, t in 0:params.time
    for f in params.family_sizes, v in params.visa_status
        push!(𝒮, State(c, t, f, v))
    end
end
push!(𝒮, params.null_state)

**Actions**

In [None]:
# The two possible decisions for the family waiting at the gate: turn it away or let it in.
@enum Action REJECT ACCEPT

In [None]:
# Action space: every value of the `Action` enum, in declaration order.
𝒜 = collect(instances(Action))

In [None]:
# A state still admits a gate decision only while strictly positive time remains.
validtime(s::State) = s.t > 0

In [None]:
# A state is within capacity as long as the number of remaining chairs is not negative.
validcapacity(s::State) = s.c >= 0

**Transition Function** 

In [None]:
# #***** OLD ONE THAT WORKED ******
# function T(s::State, a::Action)
#     next_states = []
#     if validtime(s) 
        
#         f′ = rand(params.family_sizes) # pull according to those probabilities 
#         v′ = rand(params.visa_status) # TODO: possibly make this weighted in some way 
#         # keep pushing to next states and have a bigger associated probability mass with those possible next states. 
#         # need a for loop iterating over all family sizes and and visa statuses 
        
#         if a == ACCEPT 
#             next_state_accept = State(s.c - s.f, s.t - 1, f′, v′) # they get seats
#             next_state_reject = State(s.c, s.t - 1, f′, v′)
#             push!(next_states, next_state_accept)
#             push!(next_states, next_state_reject)
#             if !validcapacity(next_state_accept) 
#                 probabilities = [0, 1] #no room for full family :( so we make probability 0 to accept and 1 reject
#             else
#                 probabilities = [.80, .20]
#             end

#         elseif a == REJECT
#             probabilities = [1.0]
#             push!(next_states, State(s.c, s.t - 1, f′, v′))
#         end
#     else
#         push!(next_states,params.null_state)
#         probabilities = [1]
#     end
#     return SparseCat(next_states, probabilities)
# end

In [None]:
"""
    T(s::State, a::Action)

Return the distribution over successor states of `s` under action `a` as a `SparseCat`.

- If no time remains in `s`, all mass goes to `params.null_state`.
- Otherwise time decreases by one and the next family (size, visa status) is drawn from
  `params.family_prob` and `params.visa_prob` independently.
- Under `ACCEPT`, the current family boards (chairs decrease by `s.f`) with probability
  `params.accept_prob[1]` and does not board with probability `params.accept_prob[2]`;
  when the family does not fit, it never boards.
- Under `REJECT`, the chairs are unchanged.

Only successors with non-zero probability are listed, so states with negative capacity
(which are not part of the state space) are never returned.
"""
function T(s::State, a::Action)
    # Out of time: the episode is over.
    validtime(s) || return SparseCat([params.null_state], [1.0])

    t′ = s.t - 1
    c_boarded = s.c - s.f

    # Probability that the current family ends up on the plane / stays off it.
    if a == ACCEPT && validcapacity(State(c_boarded, t′, s.f, s.v))
        p_board, p_stay = params.accept_prob[1], params.accept_prob[2]
    elseif a == ACCEPT
        p_board, p_stay = 0.0, 1.0 # no room for the full family, so it cannot board
    else
        p_board, p_stay = 0.0, params.reject_prob[1]
    end

    next_states = State[]
    probabilities = Float64[]
    # `eachindex` on both vectors throws if a support and its probabilities differ in length.
    for i in eachindex(params.family_sizes, params.family_prob)
        for j in eachindex(params.visa_status, params.visa_prob)
            f′ = params.family_sizes[i]
            v′ = params.visa_status[j]
            if p_board > 0
                push!(next_states, State(c_boarded, t′, f′, v′))
                push!(probabilities, p_board * params.visa_prob[j] * params.family_prob[i])
            end
            if p_stay > 0
                push!(next_states, State(s.c, t′, f′, v′))
                push!(probabilities, p_stay * params.visa_prob[j] * params.family_prob[i])
            end
        end
    end

    # Remove floating-point drift so the weights sum to one.
    normalize!(probabilities, 1)
    @assert sum(probabilities) ≈ 1
    return SparseCat(next_states, probabilities)
end

In [None]:
#check = T(State(2,0,2,2), ACCEPT)
#check = T(State(0,10,2,2), ACCEPT) # im confused with what should be happening here....

**Reward Function**

In [None]:
"""
    R(s::State, a::Action)

Reward for taking action `a` in state `s`: visa status times family size when the family
is accepted, and zero otherwise.
"""
# NOTE(review): the reward depends only on (s, a), so it is also granted when `T` leaves
# the chairs unchanged after ACCEPT (family did not board) — confirm this is intended.
R(s::State, a::Action) = a == ACCEPT ? s.v * s.f : 0

In [None]:
# Discount factor; passed to the MDP definition below as `discount`.
γ = 0.95

In [None]:
# An episode is over exactly when the state equals the configured null state.
termination(s::State) = params.null_state == s

**MDP Formulation**

In [None]:
# Abstract MDP type for the evacuation problem (states are `State`, actions are `Action`),
# declared so that other methods can refer to the model by type.
abstract type Evacuation <: MDP{State, Action} end

In [None]:
# Start with every chair and every time step available; the first family at the gate is
# drawn uniformly from the configured family sizes and visa statuses.
c_initial, t_initial = params.capacity, params.time
f_initial = first(rand(params.family_sizes, 1))
v_initial = first(rand(params.visa_status, 1))

initial_state = State(c_initial, t_initial, f_initial, v_initial)

In [None]:
# State type and the one-element collection of start states handed to the MDP definition.
statetype = State
initialstate_array = State[initial_state]

In [None]:
# Assemble the evacuation MDP from the pieces defined above: state and action spaces,
# transition and reward functions, discount, start states and the terminal-state test.
# NOTE(review): `render` is not defined in this file, so `render = render` presumably
# passes a package-exported generic function back as the renderer — confirm this is
# intended, or drop the keyword.
mdp = QuickMDP(Evacuation,
    states       = 𝒮,
    actions      = 𝒜,
    transition   = T,
    reward       = R,
    discount     = γ,
    initialstate = initialstate_array, 
    isterminal   = termination,
    render       = render,
    statetype    = statetype 
    );

In [None]:
# Value-iteration solver: at most 30 iterations, a residual tolerance of 1e-6 (`belres`),
# and progress printed while solving.
solver = ValueIterationSolver(max_iterations=30, belres=1e-6, verbose=true);

**Policy**

In [None]:
# Solve the MDP; the resulting policy is reused by the threshold baselines, the
# simulations and the visualizations below.
mdp_policy = solve(solver, mdp) 

**Baseline Policies**

In [None]:
"""
Baseline policy that answers `ACCEPT` in every state.
"""
struct AcceptAll <: Policy end

# action(policy, state): the state is never inspected, so the answer is always ACCEPT,
# even when no chairs remain.
POMDPs.action(::AcceptAll, ::State) = ACCEPT;

AcceptAll_policy = AcceptAll()

In [None]:
"""
Baseline policy that accepts a family only when its visa status is 2.
"""
struct AMCITS <: Policy end

# Every visa status other than 2 is rejected.
POMDPs.action(::AMCITS, s::State) = s.v == 2 ? ACCEPT : REJECT;

AMCITS_policy = AMCITS()

In [None]:
"""
Baseline policy that accepts a family only when its visa status is 2 or 1.
"""
struct SIV_AMCITS <: Policy end

# Membership test over the two accepted visa-status codes.
POMDPs.action(::SIV_AMCITS, s::State) = s.v in (2, 1) ? ACCEPT : REJECT;
SIV_AMCITS_policy = SIV_AMCITS()

In [None]:
"""
    AfterThresholdAMCITS(; threshold = 20)

Baseline policy that follows the solved MDP policy while more than `threshold` time steps
remain, and accepts only visa status 2 (the status the `AMCITS` baseline accepts) once
`threshold` or fewer time steps remain.

The struct is immutable; build a new instance to use a different threshold.
"""
@with_kw struct AfterThresholdAMCITS <: Policy
    threshold::Int = 20
end

function POMDPs.action(policy::AfterThresholdAMCITS, s::State)
    if s.t <= policy.threshold
        # Visa statuses range over -2..2 in the default parameters, so the restricted
        # phase must test for 2; comparing against 5 would reject every family.
        return s.v == 2 ? ACCEPT : REJECT
    end
    return action(mdp_policy, s)
end

SIV_AfterThresholdAMCITS_policy = AfterThresholdAMCITS()

In [None]:
"""
    BeforeThresholdAMCITS(; threshold = 20)

Baseline policy that accepts only visa status 2 (the status the `AMCITS` baseline
accepts) while `threshold` or more time steps remain, and follows the solved MDP policy
once fewer than `threshold` time steps remain.
"""
@with_kw struct BeforeThresholdAMCITS <: Policy
    threshold::Int = 20
end

function POMDPs.action(policy::BeforeThresholdAMCITS, s::State)
    if s.t >= policy.threshold
        # Visa statuses range over -2..2 in the default parameters, so the restricted
        # phase must test for 2; comparing against 5 would reject every family.
        return s.v == 2 ? ACCEPT : REJECT
    end
    return action(mdp_policy, s)
end

BeforeThresholdAMCITS_policy = BeforeThresholdAMCITS()
#simulations(BeforeThresholdAMCITS_policy, mdp, 10)
# could play with changing this threshold


**Simulation**

In [None]:
# # for reference, this is what is happening in sim
# # b = initialize_belief(up, b0)

# r_total = 0.0
# d = 1.0
# while !isterminal(pomdp, s)
#     a = action(policy, b)
#     s, o, r = @gen(:sp,:o,:r)(pomdp, s, a) # gen is 
#     r_total += d*r
#     d *= discount(pomdp)
#     b = update(up, b, a, o)
# end

# gen is     
#     sp = rand(transition(pomdp, s, a))
#     o = rand(observation(pomdp, s, a, sp))
#     r = reward(pomdp, s, a, sp, o)
#     s = sp
# function simulation(policy, mdp)
#     sim = RolloutSimulator()
#     r = simulate(sim, mdp, policy) #accumulated discounted reward 
#     # could we also return the number of ppl on the plane? 
#     return r
# end
# sim w/ out history 
# function simulation(policy, mdp)
#     sim = RolloutSimulator()
#     r = simulate(sim, mdp, policy) #accumulated discounted reward 
#     # could we also return the number of ppl on the plane? 
#     return r
# end
# INTEGRATE HISTORY HERE https://juliapomdp.github.io/POMDPSimulators.jl/latest/histories/#Examples and use in stats
# Run one episode of `mdp` under `policy` with a fresh `HistoryRecorder` and return the
# recorded history (rather than only the accumulated reward).
simulation(policy, mdp) = simulate(HistoryRecorder(), mdp, policy)


In [None]:
# One recorded episode under the solved policy.
mdp_history = simulation(mdp_policy, mdp)

In [None]:
"""
    get_metrics(history)

Aggregate one recorded episode into summary numbers.

Returns the tuple
`(total_accepted_people, total_accepted_families, total_reward, visa_dict_accepts)`:
the number of people and of families for which the action was `ACCEPT`, the
undiscounted sum of step rewards, and a dictionary mapping each visa status in
`params.visa_status` to the number of accepted families with that status.

Only the pre-transition state `s` of every step is counted, so no family is counted
twice. An empty history yields all-zero metrics.
"""
function get_metrics(history)
    total_accepted_people = 0
    total_accepted_families = 0
    total_reward = 0.0
    visa_dict_accepts = Dict{Int, Int}(v => 0 for v in params.visa_status)

    for (s, a, r, _) in eachstep(history, "(s, a, r, sp)")
        if a == ACCEPT
            total_accepted_people += s.f
            total_accepted_families += 1
            visa_dict_accepts[s.v] += 1
        end
        total_reward += r
    end

    # Return only after every step has been visited; returning from inside the loop
    # would report the first step alone.
    return total_accepted_people, total_accepted_families, total_reward, visa_dict_accepts
end

In [None]:
# Metrics of the single episode recorded above, unpacked into four globals.
total_accepted_people, total_accepted_families, total_reward, visa_dict_accepts = get_metrics(mdp_history)

In [None]:
# function reward_simulations(policy, mdp, n_sims) # n is number of times to run 
#     policy_rewards = []
#     for i in 1:n_sims
#         push!(policy_rewards, simulation(policy, mdp))
#     end
#     std_policy_reward = std(policy_rewards)
#     mean_policy_reward = mean(policy_rewards)
#     return mean_policy_reward, std_policy_reward
# end

In [None]:
"""
    reward_simulations(policy, mdp, n_sims)

Simulate `n_sims` episodes of `mdp` under `policy` and summarize their metrics.

Returns a named tuple with the mean and standard deviation, across episodes, of the
accepted people, the accepted families and the total reward reported by `get_metrics`.
With a single episode the standard deviations are `NaN`.

# Throws
- `ArgumentError` if `n_sims` is less than 1.
"""
function reward_simulations(policy, mdp, n_sims) # n_sims is the number of episodes to run
    n_sims >= 1 || throw(ArgumentError("n_sims must be at least 1, got $n_sims"))

    accepted_people = Int[]
    accepted_families = Int[]
    rewards = Float64[]
    for _ in 1:n_sims
        people, families, reward, _ = get_metrics(simulation(policy, mdp))
        push!(accepted_people, people)
        push!(accepted_families, families)
        push!(rewards, reward)
    end

    # Return every statistic explicitly; relying on the value of the last assignment
    # would hand back only the standard deviation of the reward.
    return (
        mean_accepted_people = mean(accepted_people),
        std_accepted_people = std(accepted_people),
        mean_accepted_families = mean(accepted_families),
        std_accepted_families = std(accepted_families),
        mean_reward = mean(rewards),
        std_reward = std(rewards),
    )
end

In [None]:
# Quick check: summarize two episodes under the solved policy.
reward_simulations(mdp_policy, mdp, 2)

In [None]:
# Evaluate each baseline policy with `reward_simulations` over 10 episodes and return
# the per-policy results, in the order the baselines are listed.
# The baseline list and the episode count could later move into the parameters struct.
function experiments()
    n_sims = 10
    baselines = (
        AcceptAll_policy,
        AMCITS_policy,
        SIV_AMCITS_policy,
        SIV_AfterThresholdAMCITS_policy,
        BeforeThresholdAMCITS_policy,
    )
    return Any[reward_simulations(baseline, mdp, n_sims) for baseline in baselines]
end

In [None]:
# Run every baseline experiment and keep the per-policy results.
mean_std_rewards = experiments()

**Visualizations**

In [None]:
"""
    vis_time_step(policy, c, t)
    vis_time_step(policy, history, c, t)

Return a heatmap of what `policy` decides for every (visa status, family size) pair when
`c` chairs and `t` time steps remain. Cells hold 100 where the action is `ACCEPT` and 0
otherwise; visa status is on the x-axis and family size on the y-axis.

The four-argument method is kept for existing callers; `history` is not used.
"""
function vis_time_step(policy, c, t)
    visa_statuses = params.visa_status
    family_sizes = params.family_sizes

    # Rows follow visa status and columns follow family size.
    accept_score(v, f) = action(policy, State(c, t, f, v)) == ACCEPT ? 100.0 : 0.0
    policyGraph = [accept_score(v, f) for v in visa_statuses, f in family_sizes]

    title_time_cap = "t = $t c = $c $policy"
    return heatmap(visa_statuses, family_sizes, policyGraph',
        aspect_ratio = :equal,
        legend = :none,
        xlims = (first(visa_statuses), last(visa_statuses)),
        xlabel = "Visa Status",
        ylabel = "Family Size",
        ylims = (first(family_sizes), last(family_sizes)),
        title = title_time_cap,
        xtickfont = font(5, "Courier"),
        ytickfont = font(5, "Courier"),
        thickness_scaling = .5,
        palette = cgrad([:red, :green], [0, 1]),
        )
end

vis_time_step(policy, history, c, t) = vis_time_step(policy, c, t)

In [None]:
# Decision heatmap of the solved policy with 10 chairs and 30 time steps remaining.
x = vis_time_step(mdp_policy, mdp_history, 10, 30)

In [None]:
"""
    vis_all(policy)

Plot a grid of decision heatmaps for `policy`, one for every fourth value of time
remaining (0:4:`params.time`) combined with every fourth value of chairs remaining
(0:4:`params.capacity`).
"""
function vis_all(policy)
    graph_per_n = 4
    heat_maps = Any[]
    for t in 0:graph_per_n:params.time, c in 0:graph_per_n:params.capacity
        # `vis_time_step` ignores its history argument, so `nothing` is passed for it.
        push!(heat_maps, vis_time_step(policy, nothing, c, t))
    end
    # Size the layout from the plots actually produced, so it stays correct when the
    # time or capacity is not a multiple of `graph_per_n`.
    return plot(heat_maps...; layout = length(heat_maps), margin = 5mm)
end

In [None]:
# Visualize the solved policy (no global named `policy` is defined in this notebook).
vis_all(mdp_policy)

In [None]:
# Decision heatmaps for the accept-everyone baseline.
vis_all(AcceptAll_policy)

In [None]:
# Decision heatmaps for the AMCITS baseline.
vis_all(AMCITS_policy) 

       

In [None]:
# Decision heatmaps for the SIV_AMCITS baseline.
vis_all(SIV_AMCITS_policy)

In [None]:
# Decision heatmaps for the after-threshold baseline.
vis_all(SIV_AfterThresholdAMCITS_policy)


In [None]:
# Decision heatmaps for the before-threshold baseline.
vis_all(BeforeThresholdAMCITS_policy)

**Aggregate Metrics**

In [None]:
# function getmetrics(policy, mdp)
#     # I think this somehow needs to simulated in our rollout 
#     total_people = 0
#     total_accept = 0
#     total_reject = 0
#     total_accepted_people = 0 
    
#     # Initialize visa_statuses dictionary
#     visa_statuses = params.visa_status
#     visa_dict = Dict()
#     for v in visa_statuses
#         visa_dict[v] = 0
#     end
    
#     for c in 0:params.capacity # capacity ends at 0 
#         for t in 0:params.time # time ends at 0 
#             for f in params.family_sizes # family size here we should have the ACTUAL family sizes 
#                 for v in params.visa_status # actual visa statuses  
#                     state = State(c, t, f, v)
#                     total_people += f 
#                     if action(policy, state) == ACCEPT
#                         total_accept += 1
#                         total_accepted_people += f
#                         visa_dict[v] += 1
#                     else 
#                         total_reject += 1
#                     end
#                 end        
#             end
#         end
#     end
    
#     print("Total people: ", total_people, )
#     print("Total accept: ", total_accept)
#     print("Total reject: ", total_reject)
#     visa_dict 
       
# end 

In [None]:
# getmetrics(policy, mdp)

In [None]:
# function getmetrics(policy, mdp)
#     # I think this somehow needs to simulated in our rollout 
#     total_people = 0
#     total_accept = 0
#     total_reject = 0
#     total_accepted_people = 0 
    
#     # Initialize visa_statuses dictionary
#     visa_statuses = params.visa_status
#     visa_dict = Dict()
#     for v in visa_statuses
#         visa_dict[v] = 0
#     end
    
#     for c in 0:params.capacity # capacity ends at 0 
#         for t in 0:params.time # time ends at 0 
#             for f in params.family_sizes # family size here we should have the ACTUAL family sizes 
#                 for v in params.visa_status # actual visa statuses  
#                     state = State(c, t, f, v)
#                     total_people += f 
#                     if action(policy, state) == ACCEPT
#                         total_accept += 1
#                         total_accepted_people += f
#                         visa_dict[v] += 1
#                     else 
#                         total_reject += 1
#                     end
#                 end        
#             end
#         end
#     end
    
#     print("Total people: ", total_people, )
#     print("Total accept: ", total_accept)
#     print("Total reject: ", total_reject)
#     visa_dict 
       
# end 