In [None]:
#include("BAPOMCP.jl")

In [1]:
using POMDPs, POMDPModels, POMDPToolbox, ParticleFilters, Distributions#, BAPOMCP




In [2]:

# pomdp = TigerPOMDP()

# solver = POMCPSolver()
# planner = solve(solver, pomdp)

# for (s, a, o) in stepthrough(pomdp, planner, "sao", max_steps=10)
#     println("State was $s,")
#     println("action $a was taken,")
#     println("and observation $o was received.\n")
# end

In [3]:
#rand(Categorical([0.5,0.5]))

In [4]:
# Model exercised by all the scratch cells below. The line after the blank is the
# cell's echoed output, i.e. the field values produced by the zero-argument constructor
# (presumably costs/rewards, observation accuracy and discount — confirm in POMDPModels).
pomdp = TigerPOMDP()

POMDPModels.TigerPOMDP(-1.0, -100.0, 10.0, 0.85, 0.95)

In [5]:

# Bayes-adaptive POMDP state: an underlying model state plus integer pseudo-counts
# that obs_count_prob / trans_count_prob normalise into probability estimates.
# Indices into the count arrays are positions in states(p) / actions(p) /
# observations(p). Declared mutable because generate_s reassigns `s` on a copy.
mutable struct BAPOMDPState{S}
    s::S                        # underlying POMDP state
    oc::Array{Int,3}             # observation counts, indexed [state, action, observation]
    tc::Array{Int,3}             # transition counts, indexed [state, action, next state]
end


In [6]:
# Estimate P(o | s.s, a) from the counts in `s.oc`: the count stored for
# (s.s, a, o) divided by the total count over all observations for (s.s, a).
# Returns a scalar.
# NOTE(review): the counts are keyed on the particle's current state `s.s`; BA-POMDP
# formulations typically condition observations on the *next* state — confirm intent.
function obs_count_prob(p::POMDP,s::BAPOMDPState,a,o)
    # `find` yields index vectors, so the slices below are arrays, not scalars.
    st_pos = find(st -> st == s.s, states(p))
    ac_pos = find(ac -> ac == a, actions(p))
    ob_pos = find(ob -> ob == o, observations(p))
    hits = s.oc[st_pos, ac_pos, ob_pos]
    total = sum(s.oc[st_pos, ac_pos, :])
    # The ratio is a 1x1x1 array; unwrap its single entry.
    return first(hits / total)
end

# Estimate P(sp.s | s.s, a) from the counts in `s.tc`: the count stored for
# (s.s, a, sp.s) divided by the total count over all next states for (s.s, a).
# Only `s`'s counts are read; `sp` contributes just its state value. Returns a scalar.
function trans_count_prob(p::POMDP,s::BAPOMDPState,a,sp::BAPOMDPState)
    all_states = states(p)
    from_pos = find(x -> x == s.s, all_states)
    act_pos = find(x -> x == a, actions(p))
    to_pos = find(x -> x == sp.s, all_states)
    row_total = sum(s.tc[from_pos, act_pos, :])
    ratio = s.tc[from_pos, act_pos, to_pos] / row_total
    return first(ratio)  # vector indexing gives a 1x1x1 array; unwrap to a scalar
end

# Estimated next-state distribution for taking `a` from `s.s`: the (s.s, a) row of
# `s.tc`, normalised to sum to one and returned as a vector of length n_states(p),
# whose i-th entry corresponds to the i-th element of states(p).
function trans_count_dist(p::POMDP,s::BAPOMDPState,a)
    row_state = find(x -> x == s.s, states(p))
    row_action = find(x -> x == a, actions(p))
    counts = s.tc[row_state, row_action, :]
    probs = counts / sum(counts)
    return reshape(probs, n_states(p))  # drop the singleton state/action dimensions
end




trans_count_dist (generic function with 1 method)

In [7]:
# Build the initial BAPOMDP state for model `p`: `s` is converted to the model's
# state type by the parametric constructor (e.g. 1 -> true for a Bool state), and
# every observation and transition count starts at 1.
function initiate_state(p::POMDP,s)
    n_s = POMDPs.n_states(p)
    n_a = POMDPs.n_actions(p)
    n_o = POMDPs.n_observations(p)
    obs_counts = ones(Int, n_s, n_a, n_o)
    trans_counts = ones(Int, n_s, n_a, n_s)
    return BAPOMDPState{state_type(p)}(s, obs_counts, trans_counts)
end


initiate_state (generic function with 1 method)

In [8]:
# Copy a BAPOMDP state. `s.s` is reused as-is; both count arrays are duplicated,
# so incrementing the copy's counts leaves the original's untouched.
#
# This adds a method to `Base.copy` instead of defining a new `copy` function here.
# A fresh `copy` would shadow `Base.copy`, so calling `copy` on arrays, dicts, etc.
# would raise a MethodError wherever this definition is visible.
# The type parameter `S` is preserved explicitly, rather than re-inferred from `s.s`.
function Base.copy(s::BAPOMDPState{S}) where S
    return BAPOMDPState{S}(s.s, copy(s.oc), copy(s.tc))
end


copy (generic function with 1 method)

In [9]:
# Return the 1-based position of `x` in the iterable `space` (first match), or throw
# an ArgumentError naming `what` when `x` is not an element of it.
function _space_position(space, x, what)
    for (i, y) in enumerate(space)
        y == x && return i
    end
    throw(ArgumentError("$what $(repr(x)) is not in the model's $what space"))
end

# Increment, in place on `sp`, the transition count for (s.s, a, sp.s) and the
# observation count for (s.s, a, o). `s` itself is not modified. Returns `sp`.
#
# Scalar positions are used instead of `find`'s index vectors, for two reasons.
# Each update becomes a plain integer increment. An unknown state, action or
# observation now raises an error; with `find` its empty index vector silently
# updated nothing.
#
# NOTE(review): the observation count is keyed on the pre-transition state `s.s`
# (as obs_count_prob reads it); confirm this is intended rather than `sp.s`.
function increment_trans_obs_counts(p::POMDP,s::BAPOMDPState,a,o,sp::BAPOMDPState)
    s_idx = _space_position(states(p), s.s, "state")
    a_idx = _space_position(actions(p), a, "action")
    sp_idx = _space_position(states(p), sp.s, "state")
    o_idx = _space_position(observations(p), o, "observation")
    sp.tc[s_idx, a_idx, sp_idx] += 1
    sp.oc[s_idx, a_idx, o_idx] += 1
    return sp
end


increment_trans_obs_counts (generic function with 1 method)

In [10]:

# Sample a successor BAPOMDP state.
# Works on a copy of `s`, so `s`'s counts are untouched. The next underlying state
# is drawn from the count-based estimate trans_count_dist(model, s, a).
#
# The draw uses the caller-supplied `rng` rather than the global RNG, so particle
# filter runs are reproducible from a seed.
# Assumes states(model) is indexable in the same order used for the count arrays.
function ParticleFilters.generate_s(model::POMDP,s::BAPOMDPState,a,rng::AbstractRNG)
    sp = copy(s)
    # Draw a position in states(model), then map it back to the state value
    # (handles states that are not integers).
    s_index = rand(rng, Categorical(trans_count_dist(model,s,a)))
    sp.s = states(model)[s_index]
    return sp
end

# Weight for a particle `sp` produced by generate_s from `s`. Per the original
# note, the result is paired with that successor in a WeightedParticleBelief.
# It estimates P(sp.s, o | s.s, a) as the count-based observation probability
# times the count-based transition probability, both read from `s`'s counts.
#
# Side effect: increments `sp`'s transition and observation counts for
# (s, a, sp, o) before the weight is computed.
# NOTE(review): generate_s already samples sp.s from the transition estimate, so
# multiplying by trans_count_prob again may double-weight likely transitions —
# confirm this weighting is intended.
function ParticleFilters.obs_weight(model::POMDP,a,s::BAPOMDPState,sp::BAPOMDPState,o)
    increment_trans_obs_counts(model, s, a, o, sp)
    p_obs = obs_count_prob(model, s, a, o)
    p_trans = trans_count_prob(model, s, a, sp)
    return p_obs * p_trans
end


In [11]:
# Initial BAPOMDP state from underlying state 1. As the echoed output shows,
# initiate_state converts it to the model's state type (Bool here, so 1 -> true)
# and every count starts at 1.
s1 = initiate_state(pomdp,1)

BAPOMDPState{Bool}(true, [1 1 1; 1 1 1]

[1 1 1; 1 1 1], [1 1 1; 1 1 1]

[1 1 1; 1 1 1])

In [12]:
#BAPOMDPState{Int}(1,zeros(Int,1,2,3),zeros(Int,1,2,3))


In [13]:
# Scratch cells exercising the helpers; cells marked In [None] were not executed
# in this session. `aa` is a copy of s1 with its own count arrays, used as a successor.
aa = copy(s1)


BAPOMDPState{Bool}(true, [1 1 1; 1 1 1]

[1 1 1; 1 1 1], [1 1 1; 1 1 1]

[1 1 1; 1 1 1])

In [None]:
# Bumps aa's transition and observation counts for (s1.s, action 0, observation true).
increment_trans_obs_counts(pomdp,s1,0,true,aa)

In [None]:
observations(pomdp)

In [None]:
actions(pomdp)


In [None]:
aa

In [None]:
# obs_weight also increments aa's counts, so running it after the increment cell
# above bumps the same entries a second time.
ParticleFilters.obs_weight(pomdp,0,s1,aa,true)

In [None]:
a1=obs_count_prob(pomdp,s1,0,true)

In [None]:
a2=trans_count_prob(pomdp,s1,0,aa)

In [None]:
# a1 is already a scalar (obs_count_prob unwraps with `first`), so this returns a1.
first(a1)

In [14]:
# generate_s works on a copy of s1, so the echoed counts below are unchanged (all 1).
ParticleFilters.generate_s(pomdp,s1,0,MersenneTwister(1234))

BAPOMDPState{Bool}(true, [1 1 1; 1 1 1]

[1 1 1; 1 1 1], [1 1 1; 1 1 1]

[1 1 1; 1 1 1])

In [None]:
# Normalised next-state distribution for (s1.s, action 0).
d1 = trans_count_dist(pomdp,s1,0)


In [None]:
# Draw a position in states(pomdp) from d1 using the global RNG.
s_index = rand(Categorical(d1))

In [None]:
# Map a position back to its state value.
states(pomdp)[1]