## Implementation of JuliaBP interpreter using Gen

Standard imports (ResumableFunctions, Gen, PyPlot):

In [1]:
using ResumableFunctions
using Gen
using PyPlot

In [2]:
# add bthread to the BP
"""
    registerBThread(bThread_dict, name, bThread, bt_id)

Add a b-thread to a b-program's thread table: store `bThread` under key `bt_id`
in `bThread_dict`, wrapped as `Dict("Name" => name, "bThread" => bThread)`.
Returns the stored wrapper dictionary.
"""
function registerBThread(bThread_dict, name, bThread, bt_id)
    entry = Dict("Name" => name, "bThread" => bThread)
    bThread_dict[bt_id] = entry
    return entry
end;

In [3]:
"""
    event_to_int(e)

Return the 1-based position of event `e` (e.g. `"H2"`, `"G1"`, `"O3"`) in the
`events` vector `["H1","H2","H3","G1","G2","G3","O1","O2","O3"]`:
H-events map to 1-3, G-events to 4-6 and O-events to 7-9.

The second character of `e` must be a decimal digit.

# Throws
- `ArgumentError` if `e` does not start with 'H', 'G' or 'O'.
- `ArgumentError` if the second character is not a digit.
"""
function event_to_int(e)
    # Offsets follow the order of the `events` vector (H block, then G, then O).
    offset = if e[1] == 'H'
        0
    elseif e[1] == 'G'
        3
    elseif e[1] == 'O'
        6
    else
        throw(ArgumentError("unknown event prefix in $(repr(e)); expected 'H', 'G' or 'O'"))
    end
    return parse(Int, e[2]) + offset
end;

In [5]:
# Resumable id generator for b-threads: successive calls return "bt_1", "bt_2",
# "bt_3", ... with no upper bound (monty_hall uses it to key its thread table).
@resumable function get_bt_id()
    i = 1
    while true 
        @yield "bt_$i"
        i+=1
    end    
end;

In [8]:
"""
    btDict(bThread_list)

Return a new `Dict{Any,Any}` that has the same keys as `bThread_list`, each mapped
to `nothing`. This gives an empty per-b-thread slot that is filled in later.
"""
btDict(bThread_list) = Dict{Any,Any}(key => nothing for key in keys(bThread_list));

In [9]:
# Build the event-selection distribution for one step of the b-program.
#
# Every event in `request_list` that is not in `block_list` gets probability
# 1/m, where m is the number of such distinct events. Every other entry of the
# global `events` vector gets 0.0. The result is a Vector{Float64} aligned with
# `events`, ready to be passed to `categorical`.
#
# `i` (the step index) is not used here; the caller traces the draw at (:event, i).
#
# Throws ArgumentError when no requested event is left unblocked. An all-zero
# vector is not a valid `categorical` parameter, so it is better to fail clearly here.
@gen function selectEvent(i, request_list, block_list)
    # Build the enabled set once, rather than once per entry of `events`.
    enabled = Set(setdiff(request_list, block_list))
    m = length(enabled)
    if m == 0
        throw(ArgumentError("selectEvent: no requested event is unblocked; nothing to select"))
    end
    prob = [e in enabled ? 1 / m : 0.0 for e in events]
    log_flag ? println("probability vector: $prob") : -1
    return prob
end;

In [10]:
"""
    is_wait_for_requested(event, wait_for_list, requested_list)

Return `true` when `event` occurs in `wait_for_list` or in `requested_list`.
The two lists are concatenated before the membership test, so either may also
be `nothing` (for example, the initial slot values produced by `btDict`).
"""
function is_wait_for_requested(event, wait_for_list, requested_list)
    watched = vcat(wait_for_list, requested_list)
    return event in watched
end;

In [11]:
# Generative b-program interpreter.
#
# `bThread_list` maps a b-thread id to Dict("Name" => ..., "bThread" => f).
# Each step works as follows:
#   * Resume every b-thread that waits for or requested the last selected event
#     (all threads on the initial "init" step).
#   * Each resumed thread returns its next sync statement: a Dict with optional
#     "request", "block" and "wait_for" lists, or `false` when it is finished.
#   * Select one requested, unblocked event uniformly at random, traced at
#     (:event, i).
# The loop stops when no requested event is left unblocked.
#
# Finished b-threads are removed from `bThread_list`, so the argument is mutated.
# Reads the globals `events` and `log_flag`.
@gen function bProgram(bThread_list)
    # Index of the current selection step; each selection is traced at (:event, i).
    i = 0
    # Most recently selected event. "init" resumes every b-thread once so each
    # can announce its first sync statement.
    event = "init"
    # Latest sync statement of each b-thread, split by kind. A thread's statement
    # stays in force until that thread is resumed again. It is therefore stored
    # per thread, rather than rebuilt only from the threads resumed this step.
    bt_wait_for = btDict(bThread_list)
    bt_requeste = btDict(bThread_list)
    bt_block = btDict(bThread_list)

    while true
        to_delete = []
        for bt_id in keys(bThread_list)
            if event == "init" ||
               is_wait_for_requested(event, bt_wait_for[bt_id], bt_requeste[bt_id])
                tmp = bThread_list[bt_id]["bThread"](event)
                if tmp != false
                    log_flag ? println(bt_id,"::",bThread_list[bt_id]["Name"],"::\t ",tmp) : -1
                    # A new sync statement replaces the previous one entirely; a
                    # missing key means the thread no longer requests/blocks/waits.
                    stmt_keys = keys(tmp)
                    bt_requeste[bt_id] = in("request", stmt_keys) ? tmp["request"] : []
                    bt_block[bt_id] = in("block", stmt_keys) ? tmp["block"] : []
                    bt_wait_for[bt_id] = in("wait_for", stmt_keys) ? tmp["wait_for"] : []
                else
                    # The thread has finished; remove it after this pass so that
                    # `keys(bThread_list)` is not modified while it is iterated.
                    push!(to_delete, bt_id)
                end
            end
        end
        for bt_id in to_delete
            delete!(bThread_list, bt_id)
            delete!(bt_wait_for, bt_id)
            delete!(bt_requeste, bt_id)
            delete!(bt_block, bt_id)
        end

        # Standing requests and blocks of every live b-thread.
        requests = []
        block = []
        for bt_id in keys(bThread_list)
            requests = [requests ; something(bt_requeste[bt_id], [])]
            block = [block ; something(bt_block[bt_id], [])]
        end

        # Stop when nothing can be selected: either no thread requests anything,
        # or every requested event is blocked. Passing an all-zero vector on to
        # `categorical` would raise an error instead.
        if isempty(setdiff(requests, block))
            break
        end
        i += 1
        prob = selectEvent(i, requests, block)
        idx2 = @trace(categorical(prob), (:event, i))
        event = events[idx2]
        log_flag ? println("$i. Selected event: $event") : -1
        log_flag ? println("*************") : -1
    end
end

DynamicDSLFunction{Any}(Dict{Symbol,Any}(), Dict{Symbol,Any}(), Type[Any], false, Union{Nothing, Some{Any}}[nothing], ##bProgram#258, Bool[0], false)

# Monty Hall

Main scenario:

	Request H1, H2, H3

    Request G1, G2, G3

    Request O1, O2, O3



Second scenario:

	h = Wait H1, H2, H3

    g = Wait G1, G2, G3

    Block O(index(h)), O(index(g))


Third scenario:

	h = Wait H1, H2, H3

    g = Wait G1, G2, G3

    assert ( index(h) == index(g) )

In [30]:
# Player b-thread: three successive sync statements. It first requests one of
# H1-H3, then one of G1-G3, then one of O1-O3, and after that its body ends.
@resumable function main_scenario() 
    @yield Dict("request" => ["H1", "H2", "H3"])
    @yield Dict("request" => ["G1", "G2", "G3"])
    @yield Dict("request" => ["O1", "O2", "O3"])
end;

main_scenario (generic function with 1 method)

In [31]:
# Monty b-thread: waits for an H event, then a G event. The value passed in when
# the thread is resumed (the selected event string, e.g. "H3") is bound to `h`,
# and likewise to `g`. Finally it blocks the O events whose digit matches h and g:
# h[2] is the digit character, so "H3" gives "O3".
@resumable function second_scenario() 
    h = @yield Dict("wait_for" => ["H1", "H2", "H3"])
    g = @yield Dict("wait_for" => ["G1", "G2", "G3"])
    @yield Dict("block" => ["O$(h[2])","O$(g[2])"])
end;

second_scenario (generic function with 1 method)

In [40]:
# Event alphabet. An event's position in this vector is the index drawn by
# `categorical` in bProgram and recorded at address (:event, i):
#   1-3 => H1-H3,   4-6 => G1-G3,   7-9 => O1-O3
events = ["H1", "H2", "H3",
          "G1", "G2", "G3",
          "O1", "O2", "O3"];

In [41]:
# Turn on bProgram / selectEvent logging for this single demonstration run.
log_flag = true
# Monty Hall model: register the player b-thread (id bt_1) and the Monty b-thread
# (id bt_2), then run them with bProgram. The @trace call has no address, so
# bProgram's (:event, i) choices appear directly in this trace.
# NOTE(review): the trailing spaces in "montyBT  " appear to pad the log column.
@gen function monty_hall()
    bThread_list_monty = Dict()
    id_gen = get_bt_id()
    registerBThread(bThread_list_monty,"playerBT" ,main_scenario(),id_gen())
    registerBThread(bThread_list_monty,"montyBT  " ,second_scenario(),id_gen())
    @trace(bProgram(bThread_list_monty))
end
# Sample one unconstrained execution and display its recorded choices.
trace = Gen.simulate(monty_hall, ());
Gen.get_choices(trace)

bt_2::montyBT  ::	 Dict("wait_for" => ["H1", "H2", "H3"])
bt_1::playerBT::	 Dict("request" => ["H1", "H2", "H3"])
probability vector: [0.3333333333333333, 0.3333333333333333, 0.3333333333333333, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
1. Selected event: H3
*************
bt_2::montyBT  ::	 Dict("wait_for" => ["G1", "G2", "G3"])
bt_1::playerBT::	 Dict("request" => ["G1", "G2", "G3"])
probability vector: [0.0, 0.0, 0.0, 0.3333333333333333, 0.3333333333333333, 0.3333333333333333, 0.0, 0.0, 0.0]
2. Selected event: G3
*************
bt_2::montyBT  ::	 Dict("block" => ["O3", "O3"])
bt_1::playerBT::	 Dict("request" => ["O1", "O2", "O3"])
probability vector: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.5, 0.0]
3. Selected event: O1
*************


│
├── (:event, 3) : 7
│
├── (:event, 1) : 3
│
└── (:event, 2) : 6


In [42]:
"""
    do_inference(model, g, o, amount_of_computation)

Run `Gen.importance_sampling` on `model` (called with no arguments) with
`amount_of_computation` samples. The run is conditioned on the second selected
event being `g` and the third being `o`; both are indices into `events`,
constraining addresses (:event, 2) and (:event, 3).

Returns the tuple `(traces, log_norm_weights, lml_est)` from importance sampling.
"""
function do_inference(model, g, o, amount_of_computation)
    # Constrain steps 2 and 3; step 1 is left free so it can be inferred.
    constraints = Gen.choicemap()
    for (step, value) in ((2, g), (3, o))
        constraints[(:event, step)] = value
    end

    traces, log_norm_weights, lml_est =
        Gen.importance_sampling(model, (), constraints, amount_of_computation)

    # TODO: add inference methods (MCMC)
    return (traces, log_norm_weights, lml_est)
end;

In [43]:

# Silence b-program logging during the many importance-sampling runs.
log_flag = false
# Condition on the second event being G1 (index 4) and the third being O3 (index 9).
traces, log_norm_weights, lml_est = do_inference(monty_hall, 4, 9, 10000)

# Weighted mass of each value of the first choice (:event, 1), i.e. H1/H2/H3.
# Whole-array operations avoid accumulating into globals inside a top-level loop,
# which only works under REPL/notebook soft-scope rules.
sample_weights = exp.(log_norm_weights)
first_events = [t[(:event, 1)] for t in traces]
sum_1 = sum(sample_weights[first_events .== 1])
sum_2 = sum(sample_weights[first_events .== 2])
# Any value other than 1 or 2 counts toward H3, as in an if/elseif/else split.
sum_3 = sum(sample_weights[(first_events .!= 1) .& (first_events .!= 2)])

total = sum_1 + sum_2 + sum_3
# All three probabilities are normalized by the same total.
println("p[H1]=$(sum_1/total), p[H2]=$(sum_2/total), p[H3]=$(sum_3/total)")



p[H1]=0.3303357314148686, p[H2]=0.6696642685851314, p[H3]=0.0
