In [1]:
using Gen
using Serialization
using Logging
using BenchmarkTools

In [2]:
using GenArrow

In [3]:
function write_to_file(fname, input)
    seekstart(input)
    data = read(input, String)
    open(fname, "w") do io
        write(io, data)
    end
end
function read_from_file(fname)
    io = IOBuffer()
    open(fname, "r") do file
        data = read(file, String)
        write(io, data)
    end
    seekstart(io)
    io
end

read_from_file (generic function with 1 method)

### Leaf Nodes

In [4]:
@gen function submodel(w)
    a ~ bernoulli(0.5)
    b ~ bernoulli(0.5)
end

@gen function model(n)
    z ~ bernoulli(0.5)
    # for k=1:n
    #     @trace(bernoulli(0.5), k)
    # end
    q ~ submodel("what the")
end

DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any], false, Union{Nothing, Some{Any}}[nothing], var"##model#313", Bool[0], false)

In [5]:
debugIO = open("write.txt", "w+")
logger = ConsoleLogger(debugIO, Debug)
global_logger(logger)
tr_old, w_old = generate(model, (10,))
display(get_choices(tr_old))
display(get_score(tr_old))
io = IOBuffer()
GenArrow.serialize(io, tr_old)
close(debugIO)
write_to_file("data.trace", io)

│
├── :z : false
│
└── :q
    │
    ├── :a : false
    │
    └── :b : false


-2.0794415416798357

538

In [10]:
debugIO = open("read.txt", "w+")
logger = ConsoleLogger(debugIO, Debug)
global_logger(logger)

io = read_from_file("data.trace")
recovered_tr = try
    recovered_tr = GenArrow._deserialize_lazy(io)
    close(debugIO)
    recovered_tr
catch
    close(debugIO)
    rethrow()
end
display(get_score(recovered_tr))
display(get_choices(recovered_tr))
display(get_args(recovered_tr))

-5.545177444479562

│
├── :z : false
│
└── :q
    │
    ├── :a : false
    │
    └── :b : false


(10,)

### Internal Nodes

In [11]:
@gen function model()
    {:x=>1} ~ bernoulli(0.5)
    {:x=>2} ~ categorical([0.25, 0.25, 0.25, 0.25])
    return 1
end

DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[], false, Union{Nothing, Some{Any}}[], var"##model#314", Bool[], false)

In [12]:
debugIO = open("write.txt", "w+")
logger = ConsoleLogger(debugIO, Debug)
global_logger(logger)
tr_old, w_old = generate(model, ())
display(get_choices(tr_old))
display(get_score(tr_old))
io = IOBuffer()
GenArrow.serialize(io, tr_old)
close(debugIO)
write_to_file("data.trace", io)

│
└── :x
    │
    ├── 2 : 3
    │
    └── 1 : true


-2.0794415416798357

332

In [13]:
debugIO = open("read.txt", "w+")
logger = ConsoleLogger(debugIO, Debug)
global_logger(logger)
seekstart(io)
recovered_tr = try
    recovered_tr = GenArrow._deserialize_lazy(io)
    close(debugIO)
    recovered_tr
catch
    close(debugIO)
    rethrow()
end
display(get_choices(recovered_tr))
display(get_score(recovered_tr))

│
└── :x
    │
    ├── 2 : 3
    │
    └── 1 : true


-4.1588830833596715

### Conversion

In [17]:
@gen function mixed_submodel(w)
    a ~ bernoulli(0.5)
    b ~ bernoulli(0.5)
end

@gen function mixed(n)
    z ~ bernoulli(0.5)
    {:x=>1} ~ bernoulli(0.5)
    {:x=>2} ~ categorical([0.25, 0.25, 0.25, 0.25])
    q ~ mixed_submodel("hol up")
end

DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any], false, Union{Nothing, Some{Any}}[nothing], var"##mixed#318", Bool[0], false)

In [18]:
debugIO = open("write.txt", "w+")
logger = ConsoleLogger(debugIO, Debug)
global_logger(logger)
tr_old, w_old = generate(mixed, (10,))
display(get_choices(tr_old))
display(get_score(tr_old))
io = IOBuffer()
GenArrow.serialize(io, tr_old)
close(debugIO)
write_to_file("data.trace", io)

│
├── :z : false
│
├── :q
│   │
│   ├── :a : true
│   │
│   └── :b : true
│
└── :x
    │
    ├── 2 : 4
    │
    └── 1 : true


-4.1588830833596715

696

In [19]:
debugIO = open("read.txt", "w+")
logger = ConsoleLogger(debugIO, Debug)
global_logger(logger)

io = read_from_file("data.trace")
recovered_tr = try
    recovered_tr = GenArrow._deserialize_lazy(io, gen_fn = mixed)
    close(debugIO)
    recovered_tr
catch
    close(debugIO)
    rethrow()
end
# display(get_score(recovered_tr))
# display(get_choices(recovered_tr))
# display(get_args(recovered_tr))

Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any], false, Union{Nothing, Some{Any}}[nothing], var"##mixed#318", Bool[0], false), Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}(:z => Gen.ChoiceOrCallRecord{Bool}(false, -0.6931471805599453, NaN, true), :q => Gen.ChoiceOrCallRecord{LazyTrace}(LazyTrace(Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}(:a => Gen.ChoiceOrCallRecord{Bool}(true, -0.6931471805599453, NaN, true), :b => Gen.ChoiceOrCallRecord{Bool}(true, -0.6931471805599453, NaN, true)), Dict{Any, Trie{Any, Gen.ChoiceOrCallRecord}}()), false, -2.772588722239781, 0.0, ("hol up",), true), -2.772588722239781, 0.0, false)), Dict{Any, Trie{Any, Gen.ChoiceOrCallRecord}}(:x => Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}(2 => Gen.ChoiceOrCallRecord{Int64}(4, -1.3862943611198906, NaN, true), 1 => Gen.ChoiceOrCallRecord{Bool}(true, -0.6931471805599453, NaN,

### Update

In [42]:
@gen function mixed_submodel(w)
    a ~ bernoulli(0.5)
    b ~ bernoulli(0.5)
end

@gen function mixed(n)
    z ~ bernoulli(0.5)
    # {:x=>1} ~ bernoulli(0.5)
    # {:x=>2} ~ categorical([0.25, 0.25, 0.25, 0.25])
    # q ~ mixed_submodel("hol up")
end

DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any], false, Union{Nothing, Some{Any}}[nothing], var"##mixed#322", Bool[0], false)

In [92]:
global_logger(NullLogger())
tr_old, w_old = generate(mixed, (10,))
display("Original")
display(get_choices(tr_old))
display(get_score(tr_old))
io = IOBuffer()
GenArrow.serialize(io, tr_old)
seekstart(io)
recovered_tr = GenArrow._deserialize_lazy(io, gen_fn = mixed)
display("Recovered")
display(get_choices(recovered_tr))
display(get_score(recovered_tr))

"Original"

│
└── :z : false


-0.6931471805599453

"Recovered"

│
└── :z : false


-1.3862943611198906

In [104]:
chm = choicemap((:z, true))
tr_update, _, _, _ = update(recovered_tr, (10,), (NoChange(),), chm)
recovered_tr_update, _, _, _ = update(tr_old, (10,), (NoChange(),), chm)

(Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any], false, Union{Nothing, Some{Any}}[nothing], var"##mixed#322", Bool[0], false), Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}(:z => Gen.ChoiceOrCallRecord{Bool}(true, -0.6931471805599453, NaN, true)), Dict{Any, Trie{Any, Gen.ChoiceOrCallRecord}}()), false, -0.6931471805599453, 0.0, (10,), true), 0.0, UnknownChange(), DynamicChoiceMap(Dict{Any, Any}(:z => false), Dict{Any, Any}()))

In [105]:
display(get_choices(tr_update))
display(get_choices(recovered_tr_update))

│
└── :z : true


│
└── :z : true


In [106]:
@gen function away(n)
    for i=1:n
        {:k=>i} ~ bernoulli(0.4)
    end
    return n
end
@gen function throw(n)
    a ~ bernoulli(0.4)
    q ~ away(n)
end

DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any], false, Union{Nothing, Some{Any}}[nothing], var"##throw#325", Bool[0], false)

In [135]:
tr, _ = generate(throw, (4,))
tr2 = deepcopy(tr)

Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any], false, Union{Nothing, Some{Any}}[nothing], var"##throw#325", Bool[0], false), Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}(:a => Gen.ChoiceOrCallRecord{Bool}(false, -0.5108256237659907, NaN, true), :q => Gen.ChoiceOrCallRecord{Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}}(Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any], false, Union{Nothing, Some{Any}}[nothing], var"##away#324", Bool[0], false), Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}(), Dict{Any, Trie{Any, Gen.ChoiceOrCallRecord}}(:k => Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}(4 => Gen.ChoiceOrCallRecord{Bool}(false, -0.5108256237659907, NaN, true), 2 => Gen.ChoiceOrCallRecord{Bool}(true, -0.916290731874155, NaN, true), 3 => Gen.ChoiceOrCallRecord{Bool}(false, -0.510

In [152]:
chm = choicemap((:a, false))
# display(chm)
update(tr, (4,), (NoChange(),), chm)

(Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any], false, Union{Nothing, Some{Any}}[nothing], var"##throw#325", Bool[0], false), Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}(:a => Gen.ChoiceOrCallRecord{Bool}(false, -0.5108256237659907, NaN, true), :q => Gen.ChoiceOrCallRecord{Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}}(Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any], false, Union{Nothing, Some{Any}}[nothing], var"##away#324", Bool[0], false), Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}(), Dict{Any, Trie{Any, Gen.ChoiceOrCallRecord}}(:k => Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}(4 => Gen.ChoiceOrCallRecord{Bool}(false, -0.5108256237659907, NaN, true), 2 => Gen.ChoiceOrCallRecord{Bool}(true, -0.916290731874155, NaN, true), 3 => Gen.ChoiceOrCallRecord{Bool}(false, -0.51

In [112]:
tr==tr2

false