In [None]:
using POMDPs
using Random # for AbstractRNG
using POMDPModelTools # for Deterministic
using QMDP, DiscreteValueIteration
using POMDPSolve
using FIB # For the solver
using BasicPOMCP # For the solver
using POMDPPolicies # For creating a random policy

In [None]:
"""
    CovidPOMDP(; inf_r=0.01, rec_r=0.2, sus_r=0.2, dea_r=0.02,
               fp_r=0.05, fn_r=0.05, discount=0.9)

Parameters of the COVID POMDP.

- State: `(susceptible, infected, recovered, dead)` fractions.
- Action: `(social-distancing level, testing level)`.
- Observation: a pair of `Float64`.

Every field has a default, so `CovidPOMDP()` builds the baseline model. Any subset
of parameters can be overridden by keyword, e.g. `CovidPOMDP(discount=0.95)`.
The positional constructor taking all seven values is also available.
"""
Base.@kwdef struct CovidPOMDP <: POMDP{NTuple{4,Float64},NTuple{2,Float64},NTuple{2,Float64}}
    inf_r::Float64 = 0.01    # natural infection rate (susceptible -> infected)
    rec_r::Float64 = 0.2     # natural recovery rate (infected -> recovered)
    sus_r::Float64 = 0.2     # natural susceptibility rate (recovered -> susceptible)
    dea_r::Float64 = 0.02    # natural death rate
    fp_r::Float64 = 0.05     # false positive test rate; NOTE(review): not yet used by `gen`
    fn_r::Float64 = 0.05     # false negative test rate; NOTE(review): not yet used by `gen`
    discount::Float64 = 0.9  # discount factor
end


In [None]:
# Generative model: return the next state `sp`, observation `o` and reward `r`
# for state `s` and action `a`. See here for an explanation of the interface:
# https://github.com/JuliaPOMDP/POMDPExamples.jl/blob/master/notebooks/Defining-a-POMDP-with-the-Generative-Interface.ipynb
#
# State  s = (su, i, rec, d): susceptible, infected, recovered and dead fractions.
# Action a = (sd_r, test_r): social-distancing level (must be < 1) and testing level.
# The dynamics are deterministic, so `rng` is currently unused.
function POMDPs.gen(m::CovidPOMDP, s, a, rng)
    (su, i, rec, d) = s
    (sd_r, test_r) = a

    # The reward contains 1/(1 - sd_r); reject the value that would make it infinite.
    sd_r < 1 || throw(DomainError(sd_r, "social-distancing level must be < 1"))

    # Transition. Each flow leaves exactly one compartment and enters exactly one
    # other, so su + i + rec + d is conserved from step to step.
    new_inf = su * m.inf_r * (1 - sd_r)  # susceptible -> infected, damped by distancing
    inf_death = i * m.dea_r * exp(i)     # infected -> dead, growing with prevalence
    inf_rec = i * m.rec_r                # infected -> recovered
    lost_imm = rec * m.sus_r             # recovered -> susceptible
    rec_death = rec * m.dea_r            # recovered -> dead

    _su = su - new_inf + lost_imm
    _i = i + new_inf - inf_death - inf_rec
    _rec = rec + inf_rec - lost_imm - rec_death
    _d = d + inf_death + rec_death
    sp = (_su, _i, _rec, _d)

    # Reward: deaths so far are penalized at weight 10. Interventions cost the living
    # population a factor that is zero for the action (0, 0) and grows with
    # distancing (sharply as sd_r -> 1) and with testing.
    r = -10 * d - (su + i + rec) * (1 / (1 - sd_r) + test_r - 1)

    # NOTE(review): placeholder observation; it does not depend on the state yet,
    # and the fp_r / fn_r test error rates are not applied.
    o = (0.8, 0.8)

    return (sp=sp, o=o, r=r)
end

In [None]:
# Additional interface methods needed by solvers; see the Tiger example:
# https://github.com/JuliaPOMDP/POMDPModels.jl/blob/master/src/TigerPOMDPs.jl

# Every episode starts with the whole population susceptible: (su, i, rec, d) = (1, 0, 0, 0).
POMDPs.initialstate_distribution(m::CovidPOMDP) = Deterministic((1.0, 0.0, 0.0, 0.0))

# Expose the discount factor stored in the model rather than hard-coding it.
POMDPs.discount(m::CovidPOMDP) = m.discount

# NOTE(review): explicit state/action/observation spaces are still undefined.
# The drafts below would not be valid spaces:
# - iterating a tuple of vectors yields vectors, not the Tuple{Float64,...}
#   elements this POMDP uses;
# - the state is a continuous Float64 tuple.
# Solvers that enumerate states cannot use them as-is.
#POMDPs.actions(::CovidPOMDP) = ([0., 1.],[0.,1.])
#POMDPs.states(::CovidPOMDP) = ([0., 1.],[0.,1.],[0., 1.],[0.,1.])
#POMDPs.observations(::CovidPOMDP) = ([0., 1.],[0.,1.])

In [None]:
#simulate static policy
using POMDPSimulators
using POMDPPolicies

m = CovidPOMDP()

# Static (open-loop) policy: ignores the observation and always plays
# social-distancing level 0.1 with testing level 0.0.
policy = FunctionPolicy(o -> (0.1, 0.0))

# Step the model for at most 50 steps, printing state, action and reward each step.
for (s, a, r) in stepthrough(m, policy, "s,a,r", max_steps=50)
    @show s
    @show a
    @show r
    println()
end

In [None]:
# Offline solvers.
# NOTE(review): QMDP works from an explicit model (enumerable states and actions,
# transition and reward). CovidPOMDP only implements the generative `gen`, so this
# `solve` call will likely fail until those methods are defined.
pomdp, solver = CovidPOMDP(), QMDPSolver()
policy = solve(solver, pomdp)

# Alternative offline solvers, kept for reference:
#solver = POMDPSolveSolver()
#solve(solver, pomdp) # returns an AlphaVectorPolicy
#solver = FIBSolver()
#policy = solve(solver, pomdp)