In [1]:
using Gen
using GenArrow
using Serialization
using BenchmarkTools

In [2]:
# Toy generative model with a random branch.
#   :x       — `mvnormal` draw with mean [0, 0] and covariance [1 0; 0 1]
#   :b       — fair coin choosing which branch's choices get traced
#   b true:  :y (categorical, four equal weights) and :a => 1 (fair coin)
#   b false: :z (`exponential(2)`) and :c => 1 (fair coin)
# Always returns the constant 1; only the choice map matters.
@gen function model()
    x ~ mvnormal([0, 0], [1 0; 0 1])
    b ~ bernoulli(0.5)
    if b
        y ~ categorical(fill(0.25, 4))
        {:a => 1} ~ bernoulli(0.5)
    else
        z ~ exponential(2)
        {:c => 1} ~ bernoulli(0.5)
    end
    return 1
end
# Single-element kernel (mapped by `Map` elsewhere): traces one Gaussian choice at
# address `:z`, centred on `x1 + x2` with standard deviation 1.0, and returns it.
@gen function foo(x1::Float64, x2::Float64)
    mu = x1 + x2
    z ~ normal(mu, 1.0)
    return z
end
# Step kernel (wrapped by `Unfold` elsewhere): traces a Bernoulli choice at `:y`
# whose success probability is `z1` when the previous state `y_prev` is true and
# `z2` otherwise, then returns it as the next state. The step index `t` is
# accepted (presumably supplied by `Unfold`) but not used.
@gen function zoobar(t::Int, y_prev::Bool, z1::Float64, z2::Float64)
    p = ifelse(y_prev, z1, z2)
    y ~ bernoulli(p)
    return y
end

# Wrap the kernels in Gen combinators:
# `bar` calls `foo` once per element of its vector arguments (see the 1 / 2
# addresses in the displayed choices); `zoo` presumably chains `zoobar` steps,
# feeding each returned `y` back in as `y_prev` — confirm against Gen's Unfold docs.
bar = Map(foo)
zoo = Unfold(zoobar)


Unfold{Any, Gen.DynamicDSLTrace}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Int64, Bool, Float64, Float64], false, Union{Nothing, Some{Any}}[nothing, nothing, nothing, nothing], var"##zoobar#314", Bool[0, 0, 0, 0], false))

In [3]:
# Sample one trace from each generative function; the weights are not used below.
tr_old, w_old = generate(model, ())
# Two mapped calls of `foo`: (x1, x2) = (0.0, 0.5) and (0.5, 1.0).
(trace, _) = generate(bar, ([0.0, 0.5], [0.5, 1.0]))
# Presumably (number of steps, initial y_prev, z1, z2) — confirm against Unfold's argument order.
(brace, _) = generate(zoo, (5, false, 0.05, 0.95))
display(get_choices(trace))

│
├── 1
│   │
│   └── :z : 1.300320394913035
│
└── 2
    │
    └── :z : 1.8646730016085349


In [3]:
"""
    write_to_file(io, path="./data.trace")

Write the bytes buffered in `io` to the file at `path`, creating or truncating it.

The bytes are extracted with `take!`, which empties `io` as a side effect
(the name has no `!` only to stay compatible with existing calls).
`path` defaults to the location this notebook originally hard-coded.

Returns the number of bytes written.
"""
function write_to_file(io, path::AbstractString="./data.trace")
    bytes = take!(io)
    # `file` (not `io`) so the argument is not shadowed inside the do-block.
    return open(path, "w") do file
        write(file, bytes)
    end
end
"""
    read_from_file(path="./data.trace")

Load the entire file at `path` into memory and return it as an `IOBuffer`
positioned at its start. This is the counterpart of `write_to_file`, which
consumes an in-memory buffer.

Reading eagerly with `read(path)` opens and closes the file internally, so no
open file handle is left for the caller to manage.
"""
read_from_file(path::AbstractString="./data.trace") = IOBuffer(read(path))

read_from_file (generic function with 1 method)

In [9]:
# Serialize the `bar` trace with GenArrow; the result is seeked below, so presumably
# an in-memory IOBuffer. The trailing `;` suppresses the cell's display.
io = GenArrow.serialize(trace);

1138

In [5]:
# Rewind the serialized stream, then rebuild a trace of `bar` from it.
seekstart(io)
recovered_trace = GenArrow.deserialize(bar, io)

Gen.VectorTrace{Gen.MapType, Any, Gen.DynamicDSLTrace}(Map{Any, Gen.DynamicDSLTrace}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Float64, Float64], false, Union{Nothing, Some{Any}}[nothing, nothing], var"##foo#313", Bool[0, 0], false)), Gen.DynamicDSLTrace[Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Float64, Float64], false, Union{Nothing, Some{Any}}[nothing, nothing], var"##foo#313", Bool[0, 0], false), Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}(:z => Gen.ChoiceOrCallRecord{Float64}(1.300320394913035, -1.239194900461551, NaN, true)), Dict{Any, Trie{Any, Gen.ChoiceOrCallRecord}}()), false, -1.239194900461551, 0.0, (0.0, 0.5), 1.300320394913035), Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Float64, Float64], false, Union{Nothing, Some{Any}}[nothing, nothing], var"##foo#313", Bool[0, 0], false),

In [6]:
# Round-trip check: the choices shown should match those of the original `trace`.
display(get_choices(recovered_trace))

│
├── 1
│   │
│   └── :z : 1.300320394913035
│
└── 2
    │
    └── :z : 1.8646730016085349


In [168]:
"""
    bench(n=100_000; roundtrip=false)

Benchmark workload:
1. Generate a `bar` (`Map(foo)`) trace over `n` elements, with arguments
   `x1 = 0.5i` and `x2 = 0.5i + 0.5` for `i in 1:n`.
2. Serialize it with `GenArrow.serialize` and rewind the stream.
3. When `roundtrip=true`, also deserialize the stream back into a `bar` trace.

The timing includes trace generation, not just (de)serialization.

Returns the rewound stream, or the recovered trace when `roundtrip=true`.
"""
function bench(n::Integer=100_000; roundtrip::Bool=false)
    x1s = [0.5 * i for i in 1:n]
    x2s = [0.5 * i + 0.5 for i in 1:n]
    (trace, _) = generate(bar, (x1s, x2s))
    io = GenArrow.serialize(trace)
    seekstart(io)
    return roundtrip ? GenArrow.deserialize(bar, io) : io
end
@benchmark bench()

BenchmarkTools.Trial: 5 samples with 1 evaluation.
 Range (min … max):  1.002 s …   1.193 s  ┊ GC (min … max):  7.98% … 22.63%
 Time  (median):     1.142 s              ┊ GC (median):    18.61%
 Time  (mean ± σ):   1.112 s ± 73.457 ms  ┊ GC (mean ± σ):  17.07% ±  5.78%

  [39m█[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [34m█[39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m█[39m█[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m█[39m [39m 
  [39m█[39m▁[39m▁[39m▁[39m▁[39m▁[39m▁[39m▁[39m▁[

In [6]:
# Scratch check of raw byte I/O: create or overwrite "dump.b" with the bytes of
# b"hi there". `write(path, data)` opens the file in "w" mode, writes, closes it,
# and returns the byte count (8), which is the cell's displayed value.
write("dump.b", b"hi there")

8