In [17]:
using POMDPs
using Random # for AbstractRNG
using POMDPModelTools # for Deterministic
using POMDPSolve
using BasicPOMCP # For the solver
using POMDPPolicies # For creating a random policy
using POMCPOW
import POMDPs.observation
using Distributions
using BeliefUpdaters
using ParticleFilters

In [18]:
#define struct for covid POMDP
struct CovidPOMDP <: POMDP{Tuple{Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64,Float64}, Tuple{Float64,Float64}, Tuple{Int64,Int64}}
    inf_r::Float64 #natural infection rate
    rec_r::Float64 #natural recovery rate
    sus_r::Float64 #natural susceptibility rate
    dea_r::Float64 #natural death rate
    fp_r::Float64 #false pos rate
    fn_r::Float64 #false neg rate
    discount::Float64 #discount factor
end

CovidPOMDP() = CovidPOMDP(0.01, 0.2, 0.2, 0.02, 0.05, 0.05, 0.9); #instance of struct


In [19]:
#generative function see here for explanation: 
# https://github.com/JuliaPOMDP/POMDPExamples.jl/blob/master/notebooks/Defining-a-POMDP-with-the-Generative-Interface.ipynb

function POMDPs.gen(m::CovidPOMDP, s, a, rng)
    #transition
    #(su,i,rec,d) = s
    #(sd_r,test_r) = a
    #_su = su*(1- m.inf_r*(1-sd_r)) + rec*m.sus_r
    #_i = i*(1-(m.dea_r*exp(i))-m.inf_r)+su*m.inf_r*(1-sd_r)
    #_rec = rec*(1-m.dea_r)+i*m.rec_r
    #_d = d + i*m.dea_r*exp(i)
    #sp = (_su,_i,_rec,_d)
    
    #(S,I0,D,Es,Eq,Et,Esd,Eqd,Etd) = s
    
    #es = (2/pi)
    
    #Es_ = max(Es+Esd,0)
    #Eq_ = max(Eq+Eqd,0)
    #Ed_ = max(Ed+Edd,0)
    
    sp = (0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0)
    
    #find reward
    #r = -10*d-(su+i+rec)*((1/(1-sd_r)))+test_r-1
    #observation update
    #o = (0.0,0.0)
    r = 0
    
    return (sp=sp, r=r)
end

In [20]:
POMDPs.discount(m::CovidPOMDP) = m.discount
POMDPs.reward(m::CovidPOMDP) = 0

In [25]:
POMDPs.actions(m::CovidPOMDP) = ([-1,0,1],[-1,0,1])
@show actions(m)
POMDPs.initialstate_distribution(m::CovidPOMDP) = Deterministic((0.8,0.2,0.0,0.0,0.0,0.0,0.0,0.0,0.0))

actions(m) = ([-1, 0, 1], [-1, 0, 1])


In [26]:
d = Normal()
function POMDPs.observation(::CovidPOMDP, d, ::Any, ::Any)
    return d
end

In [27]:
#simulate static policy
using POMDPSimulators
using POMDPPolicies

m = CovidPOMDP()

# policy that maps every input to a feed (true) action
policy = FunctionPolicy(o->(-1,1))

for (s, a, r, o) in stepthrough(m, policy, "s,a,r,o", max_steps=5)
    @show s
    @show a
    @show r
    @show o
    println()
end

s = (0.8, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
a = (-1, 1)
r = 0
o = 0.0

s = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
a = (-1, 1)
r = 0
o = 0.0

s = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
a = (-1, 1)
r = 0
o = 0.0

s = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
a = (-1, 1)
r = 0
o = 0.0

s = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
a = (-1, 1)
r = 0
o = 0.0



In [28]:
#solvers
d = Deterministic((0.8, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0))
solver = POMCPOWSolver(criterion=MaxUCB(20.0),)
pomdp = CovidPOMDP()
planner = solve(solver, pomdp)
rng = MersenneTwister();
up = SIRParticleFilter(m, 100, rng=rng)

b = initialize_belief(up, d)

action(planner,b)

#rng = MersenneTwister();

#sim = RolloutSimulator()

#importance sampling particle filter

#hist = simulate(sim, pomdp, planner, up, d, (0.8, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0))
#updater(policy::Policy) = Deterministic(1)
#initialize_belief(updater, d) = (0.8, 0.2, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)

#hr = HistoryRecorder(max_steps=20)
#hist = simulate(sim, pomdp, planner)
#for (s, b, a, r, sp, o) in hist
#    @show s, a, r, sp
#end

MethodError: MethodError: no method matching push_anode!(::POMCPOWTree{POWNodeBelief{NTuple{9,Float64},Tuple{Float64,Float64},Tuple{Int64,Int64},CovidPOMDP},Tuple{Float64,Float64},Tuple{Int64,Int64},ParticleCollection{NTuple{9,Float64}}}, ::Int64, ::Array{Int64,1}, ::Int64, ::Float64, ::Bool)
Closest candidates are:
  push_anode!(::POMCPOWTree{B,A,O,RB} where RB, ::Int64, !Matched::A, ::Int64, ::Float64, ::Any) where {B, A, O} at /Users/wyattraich/.julia/packages/POMCPOW/7FCnV/src/tree.jl:42
  push_anode!(::POMCPOWTree{B,A,O,RB} where RB, ::Int64, !Matched::A, ::Int64, ::Float64) where {B, A, O} at /Users/wyattraich/.julia/packages/POMCPOW/7FCnV/src/tree.jl:42
  push_anode!(::POMCPOWTree{B,A,O,RB} where RB, ::Int64, !Matched::A, ::Int64) where {B, A, O} at /Users/wyattraich/.julia/packages/POMCPOW/7FCnV/src/tree.jl:42
  ...

In [None]:
policy = solve(POMCPOWSolver(), CovidPOMDP())
@show_requirements action(policy)