In [1]:
#using COVIDIncidencePOMDPs
using POMCPOW
using POMDPModelTools
using Distributions
using POMDPSimulators
using POMDPs
using POMDPPolicies
using Random
using ARDESPOT
using ParticleFilters
using Statistics
using UUIDs
using MCTS
using MCVI
using Parameters
using POMCPOW
using Roots
using SpecialFunctions
using QuadGK
using LinearAlgebra

In [2]:
########################################################################
############################## PARAMETERS ##############################
########################################################################

# Configuration for the COVID intervention POMDP: epidemiology, observation model, solver
# settings and reward weights. Built with Parameters.@with_kw, so every field is a keyword
# argument; fields without a default (TESTING_LAG, PREVALENCE_THRESH, OBS_DISP, MC_STD,
# FALSE_POS, FALSE_NEG) must be supplied explicitly (see GOOD_PARAMS / BAD_PARAMS).
# NOTE(review): PREVALENCE_THRESH, OBS_ROUNDING, TREE_DEPTH, TREE_ITERS, NUM_PARTICLES,
# SOLVER_STEPS, SOLVER, EXPLORATION_COEFF, the resampling fields and the MCVI fields are not
# read by the visible code (the resampling fields only appear in the commented-out
# resampler) — confirm whether anything else uses them.
@with_kw struct POMDPParams
    # Params from Ferretti et al
    # Generation-interval mean / std in days; fitted to a Weibull by get_weibull_params.
    MEAN_GEN_TIME::Float64 = 5.0
    STD_GEN_TIME::Float64 = 1.9
    # Incubation period: passed unchanged to LogNormal(μ, σ) by discrete_lognormal, so these
    # act as log-scale parameters, not a mean / std in days.
    # NOTE(review): 0.132 ≈ 0.363²; confirm STD_INC_TIME is a log-scale standard deviation
    # and not a variance.
    MEAN_INC_TIME::Float64 = 1.644
    STD_INC_TIME::Float64 = 0.132
    # Exponential growth rate from which base_R0 is derived via calc_R0.
    EXP_GROWTH::Float64 = 0.14
    # Fixed modeling params
    # Per-step death rate: each incidence cohort loses Poisson(count * DEATH_RATE) people.
    DEATH_RATE::Float64 = 0.005
    POP_SIZE::Float64 = 8e6 
    # Recent-case threshold (0.43% of the population) that triggers the overflow reward.
    CASE_THRESH::Float64 = POP_SIZE * 0.0043
    # Length in days of the incidence window and of the discretized kernels.
    TOTAL_TIME::Int64 = 14 
    # The reward compares the sum of the last HOSP_TIME + 1 incidence entries to CASE_THRESH.
    HOSP_TIME::Int64 = 8 
    # INFECTION_SEED: initial daily incidence used by the driver script.
    # INFECTION_MIN: constant added to every incidence inside tf's infection-pressure sum.
    # NOTE(review): both fields are untyped (effectively ::Any), which blocks type
    # inference wherever they are read.
    INFECTION_SEED = 70
    INFECTION_MIN = 20
    # Variance of the mean-1 Gamma noise that multiplies R0 on every transition.
    R0_VARIANCE::Float64 = 0.03
    # Number of joint actions (decoded as 5 testing × 5 quarantine × 5 mobility levels).
    NUM_LEVELS::Int64 = 125
    # Dynamic modeling params
    # (first, last) testing lag in days, used by obs.
    TESTING_LAG::Tuple{Int64,Int64}
    PREVALENCE_THRESH::Float64
    # Negative-binomial dispersion of the observation: variance = mu + OBS_DISP * mu^2.
    OBS_DISP::Float64
    # Passed as confound_std to InitialStateDistribution.
    MC_STD::Float64
    # Test error rates used by obs.
    FALSE_POS::Float64
    FALSE_NEG::Float64
    # Solver parameters
    OBS_ROUNDING::Int64 = -1
    TREE_DEPTH::Int64 = 21
    TREE_ITERS::Int64 = 100_000
    NUM_PARTICLES::Int64 = 1000_000
    SOLVER_STEPS::Int64 = 100
    SOLVER::String = "pomcpow"
    DISCOUNT::Float64 = 0.99
    EXPLORATION_COEFF::Float64 = 1.0
    # Reward parameters
    SMOOTHNESS_MULT::Float64 = 0.0 # * (sum over action components of |new - prev|)
    DROP_MULT::Float64 = 2.5 # * (sum over action components of max(new - prev, 0)); NOTE(review): with a positive value this rewards increases — confirm intended sign
    PREVALENCE_MULT::Float64 = -1/(POP_SIZE) # * newest daily incidence, i.e. -I(t) / N
    OVERFLOW_REWARD::Float64 = -1e-2 
    DEATH_MULT::Float64 = -1e-2
    # Resampling parameters
    NUM_UNIQUES_THRESH::Int64 = 100_000
    RESAMPLE_MULTIPLIER::Float64 = 5.0
    # MCVI-specific parameters
    OBS_BRANCH::Int64 = 8
    STATES_PER_BELIEF::Int64 = 500
    NUM_PRUNE_OBS::Int64 = 1000
    NUM_EVAL_BELIEF::Int64 = 5000
    NUM_OBS::Int64 = 50
end

# Two observation-quality regimes. "bad": longer testing lag (7–10 days), larger
# observation dispersion and MC_STD, and higher false-positive/negative rates.
# "good": shorter lag (4–5 days) with tighter noise and error rates.
const BAD_PARAMS = POMDPParams(
    FALSE_NEG = 0.05,
    FALSE_POS = 0.05,
    MC_STD = 0.5,
    OBS_DISP = 0.15,
    R0_VARIANCE = 0.03,
    PREVALENCE_THRESH = 0.0,
    TESTING_LAG = (7, 10),
)

const GOOD_PARAMS = POMDPParams(
    FALSE_NEG = 0.01,
    FALSE_POS = 0.01,
    MC_STD = 0.2,
    OBS_DISP = 0.05,
    R0_VARIANCE = 0.03,
    PREVALENCE_THRESH = 0.0,
    TESTING_LAG = (4, 5),
)

# Parameter regimes looked up by name.
const PARAM_MAPPING = Dict{String,POMDPParams}(
    "good" => GOOD_PARAMS,
    "bad" => BAD_PARAMS,
)

Dict{String,POMDPParams} with 2 entries:
  "bad"  => POMDPParams(5.0, 1.9, 1.644, 0.132, 0.14, 0.005, 8.0e6, 34400.0, 14…
  "good" => POMDPParams(5.0, 1.9, 1.644, 0.132, 0.14, 0.005, 8.0e6, 34400.0, 14…

In [3]:
########################################################################
########################## UTILITY FUNCTIONS ###########################
########################################################################

"""
    get_weibull_params(mu, sigma)

Return the Weibull `(shape k, scale λ)` whose mean is `mu` and standard deviation is `sigma`.

Uses the moment identity

    ½·log Γ(1 + 2/k) − log Γ(1 + 1/k) = ½·log(μ² + σ²) − log μ,

whose left side depends only on `k`. It is solved numerically with `Roots.find_zero`
(initial guess `k = 1e-4`), and then `λ = μ / Γ(1 + 1/k)`.
"""
function get_weibull_params(mu, sigma)
    target = 0.5 * log(sigma^2 + mu^2) - log(mu)
    moment_gap = shape -> 0.5 * loggamma(1 + 2/shape) - loggamma(1 + 1/shape) - target
    shape = Roots.find_zero(moment_gap, 1e-4)
    scale = mu / gamma(1 + 1/shape)
    return (shape, scale)
end

"""
    calc_R0(k, lam, r, T)

Return the reproduction number implied by growth rate `r` for a Weibull(`k`, `lam`)
generation interval, using the Euler–Lotka relation `R0 = 1 / ∫₀ᵀ g(t)·exp(-r·t) dt`.
The integral runs over `[0, T]` with `g` the (un-renormalized) Weibull density.
"""
function calc_R0(k, lam, r, T)
    gen_dist = Weibull(k, lam)
    discounted_density = t -> Distributions.pdf(gen_dist, t) * exp(-r * t)
    laplace_value = first(QuadGK.quadgk(discounted_density, 0, T))
    return 1 / laplace_value
end

"""
    growth_rate_for_R0(k, lam, R0, T)

Invert `calc_R0`. Find the continuous growth rate `r` (initial guess `1e-4`) at which
`calc_R0(k, lam, r, T) == R0`, then return the per-unit-time discrete growth rate `exp(r) - 1`.
"""
function growth_rate_for_R0(k, lam, R0, T)
    residual(rate) = calc_R0(k, lam, rate, T) - R0
    continuous_rate = Roots.find_zero(residual, 1e-4)
    return exp(continuous_rate) - 1
end

"""
    discrete_weibull(k, lam, total_time)

Discretize the Weibull(`k`, `lam`) density into `total_time` unit bins.
Element `d + 1` is the probability mass on `[d, d + 1]` for `d = 0, …, total_time - 1`.
"""
function discrete_weibull(k, lam, total_time)
    density = t -> Distributions.pdf(Weibull(k, lam), t)
    return [QuadGK.quadgk(density, day, day + 1)[1] for day in 0:total_time-1]
end

"""
    discrete_lognormal(μ, σ, total_time)

Discretize a `LogNormal(μ, σ)` density into `total_time` unit bins. Here `μ` and `σ` are the
log-scale parameters, as in `Distributions.LogNormal`. Element `d + 1` is the probability
mass on `[d, d + 1]` for `d = 0, …, total_time - 1`, mirroring `discrete_weibull`.
"""
function discrete_lognormal(μ, σ, total_time)
    dist = LogNormal(μ, σ)
    integrand = t -> Distributions.pdf(dist, t)
    # Unit-width daily bins. The upper limit used to be `t + t`, which integrated over
    # [t, 2t]: an empty interval for t = 0 and overlapping, widening windows afterwards.
    return map(t -> QuadGK.quadgk(integrand, t, t + 1)[1], 0:total_time-1)
end


discrete_lognormal (generic function with 1 method)

In [4]:
########################################################################
############################# POMDP MODEL ##############################
########################################################################

"""
    CovidState

Hidden epidemic state carried by each particle.

- `mobility_diffs`: R0 multiplier per mobility level. `rand(::StateDistribution)` indexes it
  with the mobility component of the action.
- `incidences`: fixed-length window of daily new infections, oldest first. Each transition
  drops the first entry and appends the newest count.
- `deaths`: cumulative deaths.
- `confounds`: per-testing-level multiplier that `obs` applies to the expected count.
- `quarantine`: per-quarantine-level value `ϵ_q` that `tf` uses to weight transmission.
- `prev_a`: joint action taken on the previous step; the reward's change terms read it.
"""
struct CovidState
   mobility_diffs::Array{Float64}
   incidences::Array{Float64}
   deaths::Float64
   confounds::Array{Float64}
   quarantine::Array{Float64}
   prev_a::Int64
end

# Field-wise sum of two states. Every field is added, including the integer `prev_a`.
Base.:+(a::CovidState, b::CovidState) =
    CovidState(map(name -> getfield(a, name) + getfield(b, name), fieldnames(CovidState))...)
# Field-wise division of a state by an integer (presumably paired with `+` to average states).
# NOTE(review): `prev_a / b` is a Float64 that the constructor converts back to Int64, so this
# throws an InexactError unless `prev_a` is divisible by `b` — confirm how averaged states
# should carry `prev_a`.
Base.:/(a::CovidState, b::Int64) =
    CovidState(map(name -> getfield(a, name) / b, fieldnames(CovidState))...)

"""
    CovidPOMDP <: POMDP{CovidState, Int64, Int64}

POMDP over `CovidState`s with integer joint actions and integer observations
(negative-binomial counts drawn in `obs`).

- `base_R0`: reproduction number before mobility scaling and noise. The driver script
  computes it with `calc_R0` from `params.EXP_GROWTH`.
- `gen_func`: discretized generation-interval kernel, one entry per day of the window.
- `inc_func`: discretized incubation kernel (same length in the driver script).
- `initial_state`: initial incidence window, handed to `InitialStateDistribution`.
- `initial_deaths`: initial cumulative deaths.
- `params`: full parameter set.
"""
struct CovidPOMDP <: POMDP{CovidState, Int64, Int64}
    base_R0::Float64
    gen_func::Array{Float64}
    inc_func::Array{Float64}
    initial_state::Array{Float64}
    initial_deaths::Float64
    params::POMDPParams
end

"""
    StateDistribution

Next-state distribution returned by `tf` and sampled by `rand(rng, ::StateDistribution)`.

- `sum_integrand`: quarantine-weighted generation-interval sum of recent incidences (from `tf`).
- `mort_rate`: per-step death rate applied to each incidence cohort (`params.DEATH_RATE`).
- `base_R0`: reproduction number before mobility scaling and noise.
- `act`: joint action being taken; it becomes the sampled state's `prev_a`.
- `prev_state`: state the transition starts from.
- `r0_var`: variance of the mean-1 Gamma noise on R0.
"""
struct StateDistribution
    sum_integrand::Float64
    mort_rate::Float64
    base_R0::Float64
    act::Int64
    prev_state::CovidState
    r0_var::Float64
end

######################################
# ------ TRANSITION FUNCTION --------#
######################################
# Sample the next CovidState from a StateDistribution.
# Random draws happen in this order: R0 noise, new infections, then one death draw per
# day of the (aged) incidence window.
function POMDPModelTools.rand(rng::AbstractRNG, sd::StateDistribution)
    prev = sd.prev_state
    # Gamma(shape = 1/v, scale = v) has mean 1 and variance v: multiplicative noise on R0.
    r0_noise = Distributions.Gamma(1/sd.r0_var, sd.r0_var)
    # The mobility component of the joint action is its position within each block of 5.
    mob_level = mod(sd.act - 1, 5) + 1
    mob_mult = prev.mobility_diffs[mob_level]
    r0_now = rand(rng, r0_noise) * sd.base_R0 * mob_mult
    newly_infected = rand(rng, Poisson(r0_now * sd.sum_integrand))

    # Age the window by one day (drop the oldest entry), then remove Poisson deaths per cohort.
    # NOTE(review): nothing prevents a cohort's deaths from exceeding its size.
    window = prev.incidences[2:end]
    died = [rand(rng, Poisson(window[i] * sd.mort_rate)) for i in 1:length(window)]
    total_deaths = sum(d for d in died) + prev.deaths
    survivors = window .- died
    append!(survivors, newly_infected)
    return CovidState(prev.mobility_diffs, survivors, total_deaths,
                      prev.confounds, prev.quarantine, sd.act)
end

"""
    tf(m::CovidPOMDP, s::CovidState, a::Int64)

Build the `StateDistribution` for taking joint action `a` in state `s`.

The quarantine component of the action, `q = fld(mod(a - 1, 25), 5) + 1`, selects
`ϵ_q = s.quarantine[q]`. The infection pressure stored in the result is

    dot(gen_func .* (1 .- ϵ_q .+ ϵ_q .* inc_func), reverse(s.incidences .+ INFECTION_MIN))

that is, past incidences (newest first, each shifted up by `INFECTION_MIN`) weighted by the
generation kernel, with the weights modulated by `ϵ_q` and the incubation kernel.
"""
function tf(m::CovidPOMDP, s::CovidState, a::Int64)
    q_level = fld(mod(a - 1, 25), 5) + 1
    eps_q = s.quarantine[q_level]
    kernel = m.gen_func .* (1 .- eps_q .+ eps_q .* m.inc_func)
    newest_first = reverse(s.incidences .+ m.params.INFECTION_MIN)
    pressure = dot(kernel, newest_first)
    prm = m.params
    return StateDistribution(pressure, prm.DEATH_RATE, m.base_R0, a, s, prm.R0_VARIANCE)
end

# POMDPs.jl transition hook: delegates to `tf`.
function POMDPs.transition(m::CovidPOMDP, s::CovidState, a::Int64)
    return tf(m, s, a)
end

######################################
# ------ OBSERVATION FUNCTION -------#
######################################
"""
    obs(m::CovidPOMDP, a::Int64, sp::CovidState)

Observation distribution after taking joint action `a` and landing in `sp`.

The testing component `t = fld(a - 1, 25) + 1` selects `ϵ_t = sp.confounds[t]`. The expected
count `mu` is the sum of three terms:
- `(1 - FALSE_NEG) * mean(...)` over incidences at lags ≥ `TESTING_LAG[1]`, each weighted by
  `1 - inc_func` at the same lag;
- `(1 - FALSE_NEG) * ϵ_t * dot(...)` over incidences in the lag window
  `TESTING_LAG[1]:TESTING_LAG[2]`, weighted by `inc_func`;
- `FALSE_POS * ϵ_t * (POP_SIZE - deaths - sum(incidences))`.

Returns `NegativeBinomial(1/OBS_DISP, 1/(1 + mu*OBS_DISP))`, which has mean `mu` and
variance `mu + OBS_DISP*mu^2`.
NOTE(review): the first term uses `mean` while the second uses `dot` (a sum) — confirm this
asymmetry is intended.
"""
function obs(m::CovidPOMDP, a::Int64, sp::CovidState)
    prm = m.params
    lag_lo = prm.TESTING_LAG[1]
    lag_hi = prm.TESTING_LAG[2]
    eps_t = sp.confounds[fld(a - 1, 25) + 1]
    sensitivity = 1 - prm.FALSE_NEG
    inc = sp.incidences

    older = reverse(inc[1:end-lag_lo])
    mu_sym = sensitivity * mean(older .* (1.0 .- m.inc_func[lag_lo+1:end]))

    window = reverse(inc[end-lag_hi:end-lag_lo])
    mu_asym = sensitivity * eps_t * dot(window, m.inc_func[lag_lo+1:lag_hi+1])

    mu_healthy = prm.FALSE_POS * eps_t * (prm.POP_SIZE - sp.deaths - sum(inc))

    mu = mu_sym + mu_asym + mu_healthy
    p_success = 1 - (mu * prm.OBS_DISP) / (1 + mu * prm.OBS_DISP)
    n_success = 1 / prm.OBS_DISP
    return Distributions.NegativeBinomial(n_success, p_success)
end

# POMDPs.jl observation hook: depends only on the action and the next state.
function POMDPs.observation(m::CovidPOMDP, s::CovidState, a::Int64, sp::CovidState)
    return obs(m, a, sp)
end

######################################
# ------ REWARD FUNCTION -------#
######################################
"""
    r(m::CovidPOMDP, s::CovidState, a::Int64, sp::CovidState)

Reward for the transition from `s` to `sp` under joint action `a`.

Both `a` and the previous action `s.prev_a` are decoded into three components, each in 1..5:
- `t = fld(a - 1, 25) + 1`: selects `confounds` in `obs`;
- `m = mod(a - 1, 5) + 1`: selects `mobility_diffs` in the transition;
- `q = fld(mod(a - 1, 25), 5) + 1`: selects `quarantine` in `tf`.

With Δ denoting the change from the previous action, the reward is

    (m - t - q)
    + SMOOTHNESS_MULT * (|Δt| + |Δm| + |Δq|)
    + DROP_MULT       * (max(Δt, 0) + max(Δm, 0) + max(Δq, 0))
    + PREVALENCE_MULT * newest incidence
    + DEATH_MULT      * new deaths

If the last `HOSP_TIME + 1` incidence entries of `sp` sum to more than `CASE_THRESH`, the
reward is instead `OVERFLOW_REWARD * newest incidence`.
NOTE(review): despite its name, the DROP_MULT term rewards *increases* when DROP_MULT > 0.
"""
function r(m::CovidPOMDP, s::CovidState, a::Int64, sp::CovidState)
    prm = m.params
    cur, old = a - 1, s.prev_a - 1
    t_now, t_old = fld(cur, 25) + 1, fld(old, 25) + 1
    m_now, m_old = mod(cur, 5) + 1, mod(old, 5) + 1
    q_now, q_old = fld(mod(cur, 25), 5) + 1, fld(mod(old, 25), 5) + 1

    new_deaths = sp.deaths - s.deaths
    smooth = prm.SMOOTHNESS_MULT * (abs(t_now - t_old) + abs(m_now - m_old) + abs(q_now - q_old))
    raised = prm.DROP_MULT * (max(t_now - t_old, 0) + max(m_now - m_old, 0) + max(q_now - q_old, 0))
    newest = sp.incidences[end]
    prevalence = prm.PREVALENCE_MULT * newest
    mortality = prm.DEATH_MULT * new_deaths

    recent_cases = sum(sp.incidences[end-prm.HOSP_TIME:end])
    if recent_cases > prm.CASE_THRESH
        return prm.OVERFLOW_REWARD * newest
    end
    return (m_now - t_now - q_now) + smooth + raised + prevalence + mortality
end

# POMDPs.jl reward hook: delegates to `r`.
function POMDPs.reward(m::CovidPOMDP, s::CovidState, a::Int64, sp::CovidState)
    return r(m, s, a, sp)
end

######################################
# ------ MISCELLANEOUS FUNCTIONS -----#
######################################
# True initial-state distribution, built from the model's initial incidences and deaths.
# NOTE(review): MC_STD is passed as confound_std, which the visible
# rand(::InitialStateDistribution) does not read.
function POMDPs.initialstate_distribution(m::CovidPOMDP)
    prm = m.params
    return InitialStateDistribution(m.initial_state, m.initial_deaths, prm.MC_STD, prm.NUM_LEVELS)
end
# Actions are the integers 1:NUM_LEVELS (see `POMDPs.actions`), so each action is its own
# index. The previous method dispatched on `Array{Int64}`, a leftover from an array-valued
# action encoding, and returned `a + 1`. It never matched the Int64 actions this model uses,
# and `array + 1` is a MethodError anyway.
POMDPs.actionindex(::CovidPOMDP, a::Int64) = a
# Per-step discount factor, taken from the parameter set.
function POMDPs.discount(m::CovidPOMDP)
    return m.params.DISCOUNT
end
# Legacy enumeration of the 27 action triples in {-1, 0, 1}^3 as 1×3 row matrices, last
# component varying fastest. Only commented-out code references it.
A = Any[[a b c] for a in -1:1 for b in -1:1 for c in -1:1]
# Joint actions are the integer codes 1..NUM_LEVELS. With the default NUM_LEVELS = 125 they
# decode as 5 testing × 5 quarantine × 5 mobility levels.
function POMDPs.actions(m::CovidPOMDP)
    n_actions = m.params.NUM_LEVELS
    return 1:n_actions
end

#######################################
# -- FIXED SOLVER (val estimation) -- #
#######################################
# Baseline policy that always plays the first action of the problem's action space (see the
# `action` methods below). `rng` is stored but not read by those methods.
mutable struct FixedPolicy{RNG<:AbstractRNG, P<:Union{POMDP,MDP}, U<:Updater} <: Policy
    rng::RNG
    problem::P
    updater::U # belief updater returned by `updater(policy)`; the keyword constructor defaults it to NothingUpdater()
end
# Keyword convenience constructor: `rng` defaults to the global RNG and `updater` to a
# NothingUpdater. The previous default, `FixedPolicy.GLOBAL_RNG`, tried to read a field off
# the `FixedPolicy` type itself and errored whenever `rng` was omitted. `Random.GLOBAL_RNG`
# matches the default that `FixedSolver` uses.
FixedPolicy(problem::Union{POMDP,MDP}; rng=Random.GLOBAL_RNG,
             updater=NothingUpdater()) = FixedPolicy(rng, problem, updater)

## policy execution ##
# Always play the first action available from `s` (a state or belief).
function POMDPPolicies.action(policy::FixedPolicy, s)
    available = actions(policy.problem, s)
    return available[1]
end
# With no belief at all, fall back to the problem's full action space.
function POMDPPolicies.action(policy::FixedPolicy, b::Nothing)
    available = actions(policy.problem)
    return available[1]
end
# Updater used alongside this policy (stored on the policy).
function POMDPPolicies.updater(policy::FixedPolicy)
    return policy.updater
end

# Solver whose `solve` wraps a problem in a `FixedPolicy`; it only stores the RNG handed to
# that policy.
mutable struct FixedSolver <: Solver
    rng::AbstractRNG
end
# Keyword constructor; defaults to the global RNG.
function FixedSolver(; rng=Random.GLOBAL_RNG)
    return FixedSolver(rng)
end
# "Solving" needs no computation: wrap the problem in a FixedPolicy with a NothingUpdater.
function POMDPPolicies.solve(solver::FixedSolver, problem::Union{POMDP,MDP})
    return FixedPolicy(solver.rng, problem, NothingUpdater())
end

In [5]:
########################################################################
######################### CUSTOM INIT BELIEF ###########################
########################################################################


"""
    InitialBeliefDistribution

Planner-side prior over the initial hidden state. The incidence window and deaths are known;
mobility, confound and quarantine levels are sampled randomly by `rand`.

- `incidences`: initial incidence window, shared by reference with every sampled state.
- `deaths`: initial cumulative deaths.
- `confound_std`: log-scale standard deviation of the `LogNormal(-5, confound_std)` confounds.
- `num_levels`: number of confound values sampled per state.
"""
struct InitialBeliefDistribution
    incidences::Array{Float64}
    deaths::Float64
    confound_std::Float64
    num_levels::Int64
end

# Draw one particle from the initial belief. Draw order: mobility (Dirichlet), then
# `num_levels` log-normal confounds, then quarantine (Dirichlet). The previous action is set to 1.
function POMDPModelTools.rand(rng::AbstractRNG, ibd::InitialBeliefDistribution)
    mobility = rand(rng, Distributions.Dirichlet(5, .5))
    confound_draws = [rand(rng, LogNormal(-5, ibd.confound_std)) for _ in 1:ibd.num_levels]
    quarantine_levels = rand(rng, Distributions.Dirichlet(5, .5))
    return CovidState(mobility, ibd.incidences, ibd.deaths, confound_draws, quarantine_levels, 1)
end

### Custom Initial State Distribution
"""
    InitialStateDistribution

Distribution of the true initial state used by the simulator (via
`POMDPs.initialstate_distribution`).

- `incidences`: initial incidence window.
- `deaths`: initial cumulative deaths.
- `confound_std`: not read by the visible `rand` method (only by commented-out code).
- `num_levels`: used as the initial `prev_a`.
"""
struct InitialStateDistribution
    incidences::Array{Float64}
    deaths::Float64
    confound_std::Float64
    num_levels::Int64
end

# Deterministic initial state: `rng` is not used. Mobility and quarantine levels are evenly
# spaced on [0, 0.8] and confounds on [0, 0.01], five levels each. The previous action is
# `num_levels`.
function POMDPModelTools.rand(rng::AbstractRNG, isd::InitialStateDistribution)
    mobility_levels = LinRange(0, 0.8, 5)
    confound_levels = LinRange(0.0, 0.01, 5)
    quarantine_levels = LinRange(0, 0.8, 5)
    return CovidState(mobility_levels, isd.incidences, isd.deaths,
                      confound_levels, quarantine_levels, isd.num_levels)
end
#=
function ParticleFilters.resample(r, bp::WeightedParticleBelief, pm::CovidPOMDP,
                                  rm::CovidPOMDP, b, a, o, rng)
    # Run the normal resampling procedure
    ps = Array{CovidState}(undef, r.n)
    ws = Array{Float64}(undef, r.n)
    mobility_diffs::Array{Array{Float64}} = []
    confounders::Array{Array{Float64}} = []
    for bpp in particles(bp)
        push!(mobility_diffs, bpp.mobility_diffs)
        push!(confounders, bpp.confounds)
    end
    unique_mobility_diffs = unique(mobility_diffs)
    unique_confounders = unique(confounders)

    # Logging
    resample_mobs = length(unique_mobility_diffs) < pm.params.NUM_UNIQUES_THRESH
    resample_cfs = length(unique_confounders) < pm.params.NUM_UNIQUES_THRESH / 10

    if !(resample_mobs | resample_cfs)
        return ParticleFilters.resample(LowVarianceResampler(r.n), bp, rng)
    end

    mobility_stdevs = ones(length(unique_mobility_diffs[1])) * 0.1
    confounder_stdevs = ones(length(unique_confounders[1])) * 0.1
    if length(unique_mobility_diffs) > 1
        mobility_stdevs = Statistics.std(unique_mobility_diffs)
    end
    if length(unique_confounders) > 1
        confounder_stdevs = Statistics.std(unique_confounders)
    end

    if resample_mobs
        println("Noising mobilities | std: ", round.(mobility_stdevs, digits=3))
    end
    if resample_cfs
        println("Noising confounds | std: ", round.(confounder_stdevs, digits=3))
    end

    for i in (1:r.n)
        curr_p = particle(bp, i)
        w = weight(bp, i)
        new_mob = curr_p.mobility_diffs
        new_confounds = curr_p.confounds
        if resample_mobs
            mus = clamp.(curr_p.mobility_diffs, 1e-2, 1 - 1e-2) 
            mus = mus / sum(mus) * length(mus) 
            mus = mus * pm.params.RESAMPLE_MULTIPLIER / mean(mobility_stdevs) 
            new_mob = rand(Dirichlet(mus)) 
        end
        if resample_cfs
            mus = curr_p.confounds
            new_confounds = clamp.(rand(MvNormal(mus, max.(confounder_stdevs.^2, 0.01))), 0.5, 2.0)
        end
        ps[i] = CovidState(new_mob, curr_p.incidences, curr_p.deaths, new_confounds, curr_p.prev_a) 
        ws[i] = weight(bp, i)
    end

    new_bp = WeightedParticleBelief(ps, ws)
    return ParticleFilters.resample(LowVarianceResampler(r.n), new_bp, rng)
end
=#

In [23]:
########################################################################
########################### SOLVER TESTING #############################
########################################################################

# Driver cell: build the "good" parameter regime, derive the epidemiological kernels, and
# simulate one episode with POMCPOW plus a particle-filter belief updater. The globals bound
# here (p, covid_pomdp, stepper, ...) are used by later cells.
p = PARAM_MAPPING["good"]

# Generation interval: Weibull fitted to (mean, std), discretized into TOTAL_TIME daily bins.
k, lam = get_weibull_params(p.MEAN_GEN_TIME, p.STD_GEN_TIME)
gen_func = discrete_weibull(k, lam, p.TOTAL_TIME)
# Incubation kernel from the log-normal parameters.
inc_func = discrete_lognormal(p.MEAN_INC_TIME,p.STD_INC_TIME,p.TOTAL_TIME)
# R0 implied by the assumed exponential growth rate.
base_R0 = calc_R0(k, lam, p.EXP_GROWTH, p.TOTAL_TIME)
# Start with INFECTION_SEED new infections on every day of the window and no deaths.
initial_incidences = ones(Float64, p.TOTAL_TIME) * p.INFECTION_SEED
initial_deaths = 0.0
#initial_belief = InitialBeliefDistribution(initial_incidences, p.NUM_LEVELS)
# Planner's prior: known incidences/deaths, random mobility/confound/quarantine levels.
initial_belief = InitialBeliefDistribution(initial_incidences, initial_deaths, 0.10, p.NUM_LEVELS)

covid_pomdp = CovidPOMDP(base_R0, gen_func, inc_func, initial_incidences, initial_deaths, p)

# NOTE(review): max_depth=20 and tree_queries=1 are hard-coded here; p.TREE_DEPTH (21) and
# p.TREE_ITERS (100_000) are unused — confirm a single tree query per step is intentional.
solver = POMCPOWSolver(max_depth=20, tree_queries=1,
    k_observation=6.0, alpha_observation=1/20., enable_action_pw=false)#,
    #estimate_value=RolloutEstimator(FixedSolver()))

planner = solve(solver, covid_pomdp)
# Particle-filter belief updater with 100 particles.
updater = SIRParticleFilter(covid_pomdp, 100)

#stepper = stepthrough(covid_pomdp, planner, updator, initial_belief, "b,s,sp,a,o,r", max_steps=200)

#sim = RolloutSimulator()
# Record up to 500 steps. The true initial state presumably comes from
# POMDPs.initialstate_distribution — confirm against the POMDPSimulators version in use.
sim = HistoryRecorder(max_steps=500)
stepper = simulate(sim, covid_pomdp, planner, updater, initial_belief)

# Debug hook: the body is only commented-out output, so this loop currently has no effect.
for (s, a, sp, o, r, t) in stepper
#    @show s, b, r, sp, o
#    @show a
#    @show a
#    @show sum(s.incidences)
#    @show o
#    @show s.confounds[s.prev_a]
#    @show s.deaths
#    @show r
end


In [24]:
# Write one whitespace-separated row per recorded step of `stepper` to results.txt.
# `act_vec[i]` holds the components of joint action `i` in column-major order:
#   act_vec[i][1] == mod(i-1, 5) + 1          -> mobility level (transition)
#   act_vec[i][2] == fld(mod(i-1, 25), 5) + 1 -> quarantine level (tf)
#   act_vec[i][3] == fld(i-1, 25) + 1         -> testing/confound level (obs)
# The columns used to be written as con_a = act_vec[a][2] and q_a = act_vec[a][3]. That put
# the quarantine level under `con_a` and the confound level under `q_a`.
act_vec = vec(collect((a, b, c) for a in 1:1:5, b in 1:1:5, c in 1:1:5))
open("results.txt", "w") do io
    write(io, "deaths sum_inc inc_end mob_a con_a q_a rew\n")
    # Select history fields by name so the columns do not depend on the field order of the
    # recorded step NamedTuples.
    for (s, a, r) in eachstep(stepper, "s,a,r")
        deaths = s.deaths
        mob_a = act_vec[a][1]
        con_a = act_vec[a][3]
        q_a = act_vec[a][2]
        write(io, "$(deaths) $(sum(s.incidences)) $(s.incidences[end]) $(mob_a) $(con_a) $(q_a) $(r)\n")
    end
end;