In [2]:
#using COVIDIncidencePOMDPs
using POMCPOW
using POMDPModelTools
using Distributions
using POMDPSimulators
using POMDPs
using POMDPPolicies
using Random
using ARDESPOT
using ParticleFilters
using Statistics
using UUIDs
using MCTS
using MCVI
using Parameters
using POMCPOW
using Roots
using SpecialFunctions
using QuadGK
using LinearAlgebra

In [3]:
########################################################################
############################## PARAMETERS ##############################
########################################################################

@with_kw struct POMDPParams
    # Params from Ferretti et al
    MEAN_GEN_TIME::Float64 = 5.0   # mean of the Weibull generation-time distribution
    STD_GEN_TIME::Float64 = 1.9    # std of the Weibull generation-time distribution
    EXP_GROWTH::Float64 = 0.14     # growth rate passed to calc_R0 to derive the base R0
    # Fixed modeling params
    POP_SIZE::Float64 = 8e6
    CASE_THRESH::Float64 = POP_SIZE * 0.0043  # overflow threshold on summed recent incidences
    TOTAL_TIME::Int64 = 14         # length of the incidence window and generation kernel
    HOSP_TIME::Int64 = 8           # overflow check sums incidences[end-HOSP_TIME:end]
    # Declared Float64 so the fields are concretely typed (untyped fields are `Any`);
    # both values are only ever combined with Float64 incidence arrays.
    INFECTION_SEED::Float64 = 70   # initial incidence at every step of the window
    INFECTION_MIN::Float64 = 20    # offset added to every incidence in the transition
    R0_VARIANCE::Float64 = 0.03    # variance of the mean-1 Gamma multiplier on R0
    NUM_LEVELS::Int64 = 5          # number of mobility levels; actions are 1:NUM_LEVELS
    # Dynamic modeling params
    TESTING_LAG::Tuple{Int64,Int64}  # observation averages incidences[end-lag[2]:end-lag[1]]
    PREVALENCE_THRESH::Float64
    OBS_DISP::Float64              # negative-binomial dispersion: variance = mean * (1 + OBS_DISP)
    MC_STD::Float64                # log-std of the LogNormal confounders of the initial state
    # Solver parameters
    OBS_ROUNDING::Int64 = -1
    TREE_DEPTH::Int64 = 21
    TREE_ITERS::Int64 = 100_000
    NUM_PARTICLES::Int64 = 1_000_000
    SOLVER_STEPS::Int64 = 100
    SOLVER::String = "pomcpow"
    DISCOUNT::Float64 = 0.99
    EXPLORATION_COEFF::Float64 = 1.0
    # Reward parameters
    SMOOTHNESS_MULT::Float64 = 0.0 # * |new - prev|  (0 disables the term)
    DROP_MULT::Float64 = 2.5 # * min(0, new - prev)
    PREVALENCE_MULT::Float64 = -1/(800. * POP_SIZE) # * I(t), i.e. -I(t) / (800 N)
    OVERFLOW_REWARD::Float64 = -1e-2 # * I(t); replaces the reward when CASE_THRESH is exceeded
    # Resampling parameters
    NUM_UNIQUES_THRESH::Int64 = 100_000  # jitter mobilities below this many unique vectors (confounders: /10)
    RESAMPLE_MULTIPLIER::Float64 = 5.0   # scales the Dirichlet concentration of the mobility jitter
    # MCVI-specific parameters
    OBS_BRANCH::Int64 = 8
    STATES_PER_BELIEF::Int64 = 500
    NUM_PRUNE_OBS::Int64 = 1000
    NUM_EVAL_BELIEF::Int64 = 5000
    NUM_OBS::Int64 = 50
end

# Two preset regimes, selectable by name through PARAM_MAPPING.
# "bad" uses a longer testing lag and larger OBS_DISP / MC_STD than "good".
const BAD_PARAMS = POMDPParams(
    TESTING_LAG = (7, 10),
    PREVALENCE_THRESH = 0.0,
    R0_VARIANCE = 0.03,
    OBS_DISP = 0.15,
    MC_STD = 0.5,
)

const GOOD_PARAMS = POMDPParams(
    TESTING_LAG = (4, 5),
    PREVALENCE_THRESH = 0.0,
    R0_VARIANCE = 0.03,
    OBS_DISP = 0.05,
    MC_STD = 0.2,
)

const PARAM_MAPPING = Dict{String,POMDPParams}(
    "good" => GOOD_PARAMS,
    "bad" => BAD_PARAMS,
)

Dict{String,POMDPParams} with 2 entries:
  "bad"  => POMDPParams(5.0, 1.9, 0.14, 8.0e6, 34400.0, 14, 8, 70, 20, 0.03, 5,…
  "good" => POMDPParams(5.0, 1.9, 0.14, 8.0e6, 34400.0, 14, 8, 70, 20, 0.03, 5,…

In [4]:
########################################################################
########################## UTILITY FUNCTIONS ###########################
########################################################################

"""
    get_weibull_params(mu, sigma)

Return the Weibull `(k, lam)` (shape, scale) whose mean is `mu` and whose standard
deviation is `sigma`.

It exploits the moment identity

    1/2 log Γ(1 + 2/k) - log Γ(1 + 1/k) = 1/2 log(mu^2 + sigma^2) - log(mu),

whose left side depends only on `k`. The identity is solved numerically for `k`, and
then `lam = mu / Γ(1 + 1/k)`.
"""
function get_weibull_params(mu, sigma)
    f_targ = 0.5 * log(sigma^2 + mu^2) - log(mu)
    f = x -> 0.5 * loggamma(1 + 2 / x) - loggamma(1 + 1 / x) - f_targ
    # Root-find for the shape k from the starting guess k = 1e-4.
    k = Roots.find_zero(f, 1e-4)
    lam = mu / gamma(1 + 1 / k)
    return (k, lam)
end

"""
    calc_R0(k, lam, r, T)

Return the reproduction number implied by growth rate `r` under a Weibull(`k`, `lam`)
generation-time distribution g. It uses the Euler–Lotka relation
`1 / R0 = ∫₀ᵀ g(t) exp(-r t) dt`, with the integral truncated at `T`.
"""
function calc_R0(k, lam, r, T)
    discounted_density(t) = Distributions.pdf(Weibull(k, lam), t) * exp(-r * t)
    laplace_term = first(QuadGK.quadgk(discounted_density, 0, T))
    return 1 / laplace_term
end

"""
    growth_rate_for_R0(k, lam, R0, T)

Invert `calc_R0`. Find the continuous growth rate whose implied reproduction number is
`R0`, then return the discrete-step growth rate `exp(rate) - 1`.
"""
function growth_rate_for_R0(k, lam, R0, T)
    mismatch(rate) = calc_R0(k, lam, rate, T) - R0
    continuous_rate = Roots.find_zero(mismatch, 1e-4)
    return exp(continuous_rate) - 1
end

"""
    discrete_weibull(k, lam, total_time)

Discretise Weibull(`k`, `lam`) into `total_time` unit-width bins. Entry `i` is the
probability mass on `[i-1, i]`.
"""
function discrete_weibull(k, lam, total_time)
    density(t) = Distributions.pdf(Weibull(k, lam), t)
    return [first(QuadGK.quadgk(density, lo, lo + 1)) for lo in 0:total_time-1]
end

discrete_weibull (generic function with 1 method)

In [5]:
########################################################################
############################# POMDP MODEL ##############################
########################################################################

"""
    CovidState(mobility_diffs, incidences, confounds, prev_a)

Hidden state of the COVID POMDP.

- `mobility_diffs`: per-level mobility increments. The mobility multiplier of action
  `a` is `cumsum(mobility_diffs)[a]`.
- `incidences`: sliding window of incidence counts, oldest first. Each transition drops
  the first entry and appends the newest.
- `confounds`: per-action multipliers applied to the expected observation.
- `prev_a`: the most recent action; the transition sets it to the action taken.

Fields are concrete `Vector{Float64}`. `Array{Float64}` leaves the dimension unknown,
which makes every field access type-unstable.
"""
struct CovidState
   mobility_diffs::Vector{Float64}
   incidences::Vector{Float64}
   confounds::Vector{Float64}
   prev_a::Int64
end

# Element-wise arithmetic on CovidState. This is presumably used to average particles
# (sum the states, then divide by the count) — confirm against callers.
Base.:+(a::CovidState, b::CovidState) = CovidState(a.mobility_diffs + b.mobility_diffs,
                                                    a.incidences + b.incidences,
                                                    a.confounds + b.confounds,
                                                    a.prev_a + b.prev_a)

# `prev_a` is an Int64 action index, so its quotient is rounded to the nearest integer
# (ties to even). Storing the raw Float64 `a.prev_a / b` in the Int64 field would throw
# `InexactError` whenever the division is inexact.
Base.:/(a::CovidState, b::Int64) = CovidState(a.mobility_diffs / b,
                                              a.incidences / b,
                                              a.confounds / b,
                                              round(Int64, a.prev_a / b))

"""
    CovidPOMDP(base_R0, gen_func, initial_state, params)

POMDP with `CovidState` states, integer mobility-level actions, and integer
case-count observations.

- `base_R0`: reproduction number before the mobility and Gamma-noise multipliers.
- `gen_func`: discretised generation-interval probabilities. `gen_func[1]` weights the
  most recent incidence in the transition.
- `initial_state`: incidence window that the true initial state starts from.
- `params`: model, reward and solver parameters.

Array fields are concrete `Vector{Float64}`, so accesses in the hot transition and
observation code are type-stable.
"""
struct CovidPOMDP <: POMDP{CovidState, Int64, Int64}
    base_R0::Float64
    gen_func::Vector{Float64}
    initial_state::Vector{Float64}
    params::POMDPParams
end

# Next-state distribution returned by `transition`. Sampling it draws the next
# incidence from the infection pressure, the chosen action and a Gamma-noised R0.
struct StateDistribution
    sum_integrand::Float64 # generation-interval weighted sum of past incidences (offset by INFECTION_MIN)
    base_R0::Float64       # reproduction number before mobility/noise multipliers
    act::Int64             # action (mobility level) being taken
    prev_state::CovidState # state the transition starts from
    r0_var::Float64        # variance of the mean-1 Gamma multiplier applied to R0
end

######################################
# ------ TRANSITION FUNCTION --------#
######################################
"""
Sample the successor `CovidState`.

The new incidence is Poisson with rate `noise * base_R0 * mobility * sum_integrand`.
`noise` is a Gamma draw with mean 1 and variance `r0_var`. `mobility` is the running
total of the mobility increments up to the chosen level. Mobility and confounders are
carried over unchanged, and `prev_a` becomes the chosen action.
"""
function POMDPModelTools.rand(rng::AbstractRNG, sd::StateDistribution)
    prev = sd.prev_state
    # Shape 1/v and scale v give mean 1 and variance v.
    noise_dist = Distributions.Gamma(1/sd.r0_var, sd.r0_var)
    level = cumsum(prev.mobility_diffs)[sd.act]
    effective_R0 = rand(rng, noise_dist) * sd.base_R0 * level
    new_cases = rand(rng, Poisson(effective_R0 * sd.sum_integrand))
    # Slide the window: drop the oldest incidence and append the new draw.
    window = prev.incidences[2:end]
    push!(window, new_cases)
    return CovidState(prev.mobility_diffs, window, prev.confounds, sd.act)
end

# Build the next-state distribution. The infection pressure is the generation-interval
# weighted sum of past incidences (each offset by INFECTION_MIN); reversing puts the
# newest incidence first so that it meets gen_func[1].
function tf(m::CovidPOMDP, s::CovidState, a::Int64)
    offset_history = s.incidences .+ m.params.INFECTION_MIN
    pressure = dot(m.gen_func, reverse(offset_history))
    return StateDistribution(pressure, m.base_R0, a, s, m.params.R0_VARIANCE)
end
POMDPs.transition(m::CovidPOMDP, s::CovidState, a::Int64) = tf(m, s, a)

######################################
# ------ OBSERVATION FUNCTION -------#
######################################
"""
    obs(m::CovidPOMDP, a::Int64, sp::CovidState)

Return the observation distribution, a negative binomial over case counts.

- Its mean is the average incidence over
  `sp.incidences[end-TESTING_LAG[2]:end-TESTING_LAG[1]]`, scaled by the confounder of
  action `a`.
- Its variance is `mean * (1 + OBS_DISP)`.

The small 1e-4 offsets keep the parameters valid when the mean is 0.
"""
function obs(m::CovidPOMDP, a::Int64, sp::CovidState)
    lag_near, lag_far = m.params.TESTING_LAG
    lagged_window = sp.incidences[end-lag_far:end-lag_near]
    mu = mean(lagged_window) * sp.confounds[a]
    dispersion = m.params.OBS_DISP
    p = 1 / (1 + dispersion)
    r = mu / dispersion
    return Distributions.NegativeBinomial(r+1e-4, min(p + 1e-4, 1.0-1e-4))
end
POMDPs.observation(m::CovidPOMDP, s::CovidState, a::Int64, sp::CovidState) = obs(m, a, sp)

######################################
# ------ REWARD FUNCTION -------#
######################################
"""
    r(m::CovidPOMDP, s::CovidState, a::Int64, sp::CovidState)

Reward for taking mobility level `a` from `s` and landing in `sp`:

    a + SMOOTHNESS_MULT * |a - s.prev_a|
      + DROP_MULT * min(0, a - s.prev_a)
      + PREVALENCE_MULT * sp.incidences[end]

- Higher levels `a` earn more reward.
- The smoothness term is off with the default SMOOTHNESS_MULT = 0; a negative value
  would penalise changes of level.
- With the default positive DROP_MULT, lowering the level is penalised in proportion
  to the size of the drop.
- The prevalence term charges for the newest incidence (PREVALENCE_MULT < 0 by default).

If the incidences summed over `sp.incidences[end-HOSP_TIME:end]` exceed `CASE_THRESH`,
the whole reward is replaced by `OVERFLOW_REWARD * sp.incidences[end]`.
"""
function r(m::CovidPOMDP, s::CovidState, a::Int64, sp::CovidState)
    smoothness_term = m.params.SMOOTHNESS_MULT * abs(a - s.prev_a)
    drop_term = m.params.DROP_MULT * min(0, a - s.prev_a)
    prevalence_term = m.params.PREVALENCE_MULT * sp.incidences[end]
    rew = a + smoothness_term + drop_term + prevalence_term
    # NOTE(review): `end-HOSP_TIME:end` spans HOSP_TIME + 1 entries — confirm this
    # window length is intended.
    if sum(sp.incidences[end-m.params.HOSP_TIME:end]) > m.params.CASE_THRESH
        rew = m.params.OVERFLOW_REWARD * sp.incidences[end]
    end
    return rew
end
POMDPs.reward(m::CovidPOMDP, s::CovidState, a::Int64, sp::CovidState) = r(m, s, a, sp)

######################################
# ------ MISCELLANEOUS FUNCTIONS -----#
######################################
# The true initial state uses the problem's incidence window, LogNormal(0, MC_STD)
# confounders and NUM_LEVELS mobility levels.
POMDPs.initialstate_distribution(m::CovidPOMDP) =
    InitialStateDistribution(m.initial_state, m.params.MC_STD, m.params.NUM_LEVELS)
# Actions are 1:NUM_LEVELS, so each action is its own 1-based index. The previous
# `a + 1` produced indices 2..NUM_LEVELS+1, which are out of range for per-action arrays.
POMDPs.actionindex(::CovidPOMDP, a::Int64) = a
POMDPs.discount(m::CovidPOMDP) = m.params.DISCOUNT
POMDPs.actions(m::CovidPOMDP) = (1:m.params.NUM_LEVELS)

#######################################
# -- FIXED SOLVER (val estimation) -- #
#######################################
"""
    FixedPolicy(problem; rng=Random.GLOBAL_RNG, updater=NothingUpdater())

Policy that always takes the first action of the problem's action space. It is a cheap
baseline, e.g. for rollout-based value estimation.
"""
mutable struct FixedPolicy{RNG<:AbstractRNG, P<:Union{POMDP,MDP}, U<:Updater} <: Policy
    rng::RNG
    problem::P
    updater::U # belief updater returned by `updater(policy)`; NothingUpdater by default
end
# Keyword constructor. The default RNG is `Random.GLOBAL_RNG`. Writing it as
# `FixedPolicy.GLOBAL_RNG` would look up a field on the type and throw whenever `rng`
# is not passed explicitly.
FixedPolicy(problem::Union{POMDP,MDP}; rng=Random.GLOBAL_RNG,
             updater=NothingUpdater()) = FixedPolicy(rng, problem, updater)

## policy execution ##
POMDPPolicies.action(policy::FixedPolicy, s) = actions(policy.problem, s)[1]
POMDPPolicies.action(policy::FixedPolicy, b::Nothing) = actions(policy.problem)[1]
POMDPPolicies.updater(policy::FixedPolicy) = policy.updater

"""
    FixedSolver(; rng=Random.GLOBAL_RNG)

Solver whose `solve` returns a `FixedPolicy` with a `NothingUpdater`.
"""
mutable struct FixedSolver <: Solver
    rng::AbstractRNG
end
FixedSolver(;rng=Random.GLOBAL_RNG) = FixedSolver(rng)
POMDPPolicies.solve(solver::FixedSolver, problem::Union{POMDP,MDP}) = FixedPolicy(solver.rng, problem, NothingUpdater())

In [6]:
########################################################################
######################### CUSTOM INIT BELIEF ###########################
########################################################################


"""
    InitialBeliefDistribution(incidences, confound_std, num_levels)

Diffuse prior over `CovidState`s used as the planner's initial belief.

- Mobility increments are drawn from Dirichlet(num_levels, 1).
- Confounders are drawn independently from LogNormal(0, confound_std).
- The incidence window is `incidences`.
"""
struct InitialBeliefDistribution
    incidences::Vector{Float64}
    confound_std::Float64
    num_levels::Int64
end

function POMDPModelTools.rand(rng::AbstractRNG, ibd::InitialBeliefDistribution)
    mobility_diffs = rand(rng, Distributions.Dirichlet(ibd.num_levels, 1))
    # NOTE: every sample shares this array; nothing may mutate incidences in place.
    incidences = ibd.incidences
    confounds = map(_ -> rand(rng, LogNormal(0, ibd.confound_std)), 1:ibd.num_levels)
    # The previous action is known to the agent, so particles must agree with the true
    # initial state (InitialStateDistribution starts at the top level, num_levels).
    # Starting them at 1 made the planner ignore the drop penalty on the first step.
    prev_a = ibd.num_levels
    return CovidState(mobility_diffs, incidences, confounds, prev_a)
end

### Custom Initial State Distribution
"""
    InitialStateDistribution(incidences, confound_std, num_levels)

Distribution of the true initial state.

- Mobility increments are fixed.
- Confounders are drawn independently from LogNormal(0, confound_std).
- The incidence window is `incidences`.
- The previous action is the top level, `num_levels`.
"""
struct InitialStateDistribution
    incidences::Vector{Float64}
    confound_std::Float64
    num_levels::Int64
end

function POMDPModelTools.rand(rng::AbstractRNG, isd::InitialStateDistribution)
    # Fixed increments; cumulative mobility levels ≈ 0.3, 0.45, 0.6, 0.7, 1.0.
    # NOTE(review): this vector always has 5 entries, independent of `num_levels`.
    # Confirm that NUM_LEVELS == 5 whenever this distribution is used.
    # A random alternative kept for reference:
    #     gaps = rand(rng, Distributions.Dirichlet(4, 0.7)) * 0.7
    #     prepend!(gaps, 0.3)
    #     mobility_diffs = cumsum(gaps)
    mobility_diffs = [0.3, 0.15, 0.15, 0.1, 0.3]
    incidences = isd.incidences
    confounds = map(_ -> rand(rng, LogNormal(0, isd.confound_std)), 1:isd.num_levels)
    prev_a = isd.num_levels
    return CovidState(mobility_diffs, incidences, confounds, prev_a)
end

"""
    ParticleFilters.resample(r, bp::WeightedParticleBelief, pm::CovidPOMDP,
                             rm::CovidPOMDP, b, a, o, rng)

Resample the weighted belief `bp` with `LowVarianceResampler(r.n)`. First, jitter the
static state components if they have collapsed to too few distinct values.

The transition copies `mobility_diffs` and `confounds` unchanged, so resampling alone
can only shrink their set of distinct values. Jittering happens when fewer than
`NUM_UNIQUES_THRESH` distinct mobility vectors remain, or fewer than
`NUM_UNIQUES_THRESH / 10` distinct confounder vectors. Each particle is then perturbed
before resampling:

- Mobility increments are redrawn from a Dirichlet. Its mean is the particle's
  increments clamped to `[0.01, 0.99]` and renormalised; its total concentration is
  `length * RESAMPLE_MULTIPLIER / mean(std)`.
- Confounders are redrawn from an independent normal centred on the current values,
  with variance `max(std^2, 0.01)`, then clamped to `[0.5, 2.0]`.

Here `std` is the per-component standard deviation across the distinct vectors (0.1
when only one exists). Weights are kept, and every random draw uses `rng`.
"""
function ParticleFilters.resample(r, bp::WeightedParticleBelief, pm::CovidPOMDP,
                                  rm::CovidPOMDP, b, a, o, rng)
    params = pm.params
    unique_mobility_diffs = unique([q.mobility_diffs for q in particles(bp)])
    unique_confounders = unique([q.confounds for q in particles(bp)])

    # Decide which static components have collapsed and need jittering.
    resample_mobs = length(unique_mobility_diffs) < params.NUM_UNIQUES_THRESH
    resample_cfs = length(unique_confounders) < params.NUM_UNIQUES_THRESH / 10

    if !(resample_mobs || resample_cfs)
        return ParticleFilters.resample(LowVarianceResampler(r.n), bp, rng)
    end

    # Per-component spread across distinct vectors; 0.1 when std is undefined.
    mobility_stdevs = ones(length(unique_mobility_diffs[1])) * 0.1
    confounder_stdevs = ones(length(unique_confounders[1])) * 0.1
    if length(unique_mobility_diffs) > 1
        mobility_stdevs = Statistics.std(unique_mobility_diffs)
    end
    if length(unique_confounders) > 1
        confounder_stdevs = Statistics.std(unique_confounders)
    end

    if resample_mobs
        println("Noising mobilities | std: ", round.(mobility_stdevs, digits=3))
    end
    if resample_cfs
        println("Noising confounds | std: ", round.(confounder_stdevs, digits=3))
    end

    # `MvNormal(μ, σ::AbstractVector)` takes standard deviations, so the floored
    # variance max(std^2, 0.01) is passed as its square root, max(std, 0.1).
    confounder_sigmas = max.(confounder_stdevs, 0.1)

    ps = Vector{CovidState}(undef, r.n)
    ws = Vector{Float64}(undef, r.n)
    for i in 1:r.n
        cur = particle(bp, i)
        new_mob = cur.mobility_diffs
        new_confounds = cur.confounds
        if resample_mobs
            alphas = clamp.(cur.mobility_diffs, 1e-2, 1 - 1e-2)
            alphas = alphas / sum(alphas) * length(alphas)
            alphas = alphas * params.RESAMPLE_MULTIPLIER / mean(mobility_stdevs)
            new_mob = rand(rng, Dirichlet(alphas))
        end
        if resample_cfs
            draw = rand(rng, MvNormal(cur.confounds, confounder_sigmas))
            new_confounds = clamp.(draw, 0.5, 2.0)
        end
        ps[i] = CovidState(new_mob, cur.incidences, new_confounds, cur.prev_a)
        ws[i] = weight(bp, i)
    end

    new_bp = WeightedParticleBelief(ps, ws)
    return ParticleFilters.resample(LowVarianceResampler(r.n), new_bp, rng)
end

In [20]:
########################################################################
########################### SOLVER TESTING #############################
########################################################################

p = PARAM_MAPPING["good"]

# Generation-time kernel and the R0 implied by the assumed growth rate.
k, lam = get_weibull_params(p.MEAN_GEN_TIME, p.STD_GEN_TIME)
gen_func = discrete_weibull(k, lam, p.TOTAL_TIME)
base_R0 = calc_R0(k, lam, p.EXP_GROWTH, p.TOTAL_TIME)

# Flat incidence history seeded at INFECTION_SEED for every step of the window.
initial_incidences = ones(Float64, p.TOTAL_TIME) * p.INFECTION_SEED
# Very diffuse belief over the confounders (LogNormal with log-std 10.0).
initial_belief = InitialBeliefDistribution(initial_incidences, 10.0, p.NUM_LEVELS)

covid_pomdp = CovidPOMDP(base_R0, gen_func, initial_incidences, p)

# NOTE(review): tree_queries=1 gives a single search iteration per step, which looks
# like a debugging setting — confirm before drawing conclusions from the results.
# To estimate leaf values with fixed-action rollouts, also pass
#     estimate_value=RolloutEstimator(FixedSolver())
solver = POMCPOWSolver(max_depth=20, tree_queries=1,
    k_observation=6.0, alpha_observation=1/20., enable_action_pw=false)

planner = solve(solver, covid_pomdp)
updater = SIRParticleFilter(covid_pomdp, 100)

sim = HistoryRecorder(max_steps=300)
stepper = simulate(sim, covid_pomdp, planner, updater, initial_belief)

# History steps are NamedTuples ordered (s, a, sp, o, r, t, ...). Destructuring them
# positionally mislabels every field, so `eachstep` is used to yield tuples in the
# requested order.
for (s, b, a, r, sp, o) in eachstep(stepper, "s,b,a,r,sp,o")
    @show s, b, a, r, sp, o
end


Noising mobilities | std: [0.174, 0.179, 0.149, 0.18, 0.194]
Noising confounds | std: [5.6246783589527806e14, 8.2176021197702e10, 1.4428147820011e10, 1.23238631695029e11, 1.88925340484803e11]
Noising mobilities | std: [0.213, 0.113, 0.051, 0.032, 0.279]
Noising confounds | std: [0.866, 0.0, 0.866, 0.0, 0.866]
Noising mobilities | std: [0.129, 0.101, 0.051, 0.025, 0.203]
Noising confounds | std: [0.429, 0.005, 0.441, 0.006, 0.615]
Noising mobilities | std: [0.129, 0.102, 0.058, 0.028, 0.202]
Noising confounds | std: [0.425, 0.006, 0.403, 0.009, 0.566]
Noising mobilities | std: [0.131, 0.106, 0.063, 0.03, 0.202]
Noising confounds | std: [0.391, 0.008, 0.365, 0.011, 0.556]
Noising mobilities | std: [0.089, 0.074, 0.064, 0.036, 0.145]
Noising confounds | std: [0.345, 0.01, 0.306, 0.011, 0.476]
Noising mobilities | std: [0.062, 0.031, 0.036, 0.007, 0.022]
Noising confounds | std: [0.15, 0.006, 0.168, 0.015, 0.17]
Noising mobilities | std: [0.059, 0.027, 0.029, 0.012, 0.024]
Noising confound

Noising confounds | std: [0.1, 0.1, 0.1, 0.1, 0.1]
Noising mobilities | std: [0.1, 0.1, 0.1, 0.1, 0.1]
Noising confounds | std: [0.1, 0.1, 0.1, 0.1, 0.1]


100-element SimHistory{NamedTuple{(:s, :a, :sp, :o, :r, :t, :action_info, :b, :bp, :update_info),Tuple{CovidState,Int64,CovidState,Int64,Float64,Int64,Dict{Symbol,Any},ParticleCollection{CovidState},ParticleCollection{CovidState},Nothing}},Float64}:
 (s = CovidState([0.3, 0.15, 0.15, 0.1, 0.3], [70.0, 70.0, 70.0, 70.0, 70.0, 70.0, 70.0, 70.0, 70.0, 70.0, 70.0, 70.0, 70.0, 70.0], [1.0960341125813755, 0.9359674632933408, 0.8451108828158969, 1.2545914072306936, 1.0995891915236868], 5), a = 2, sp = CovidState([0.3, 0.15, 0.15, 0.1, 0.3], [70.0, 70.0, 70.0, 70.0, 70.0, 70.0, 70.0, 70.0, 70.0, 70.0, 70.0, 70.0, 70.0, 93.0], [1.0960341125813755, 0.9359674632933408, 0.8451108828158969, 1.2545914072306936, 1.0995891915236868], 2), o = 71, r = -5.50000001453125, t = 1, action_info = Dict(:tree_queries => 0,:search_time_us => 0x0000000000000072), b = ParticleCollection{CovidState}(CovidState[CovidState([0.3680004980347918, 0.1988598813951622, 0.07181537422289604, 0.07072831760361745, 0.2905959287