In [1]:
using Gen
using Serialization
using Logging
using BenchmarkTools

In [11]:
using GenArrow

In [3]:
"""
    write_to_file(fname, input)

Rewind `input` to its start and copy everything it contains into the file
`fname`, which is opened in `"w"` mode. Returns the value of the underlying
`write` call. `input` is read through to its end.
"""
function write_to_file(fname, input)
    seekstart(input)
    payload = read(input)
    return open(sink -> write(sink, payload), fname, "w")
end
"""
    read_from_file(fname)

Load the full contents of the file `fname` into a fresh `IOBuffer` and return
that buffer rewound to its start, ready to be read from.
"""
function read_from_file(fname)
    buffer = IOBuffer()
    write(buffer, read(fname))
    seekstart(buffer)
    return buffer
end

read_from_file (generic function with 1 method)

### Leaf Nodes

In [4]:
# Generative function with two traced `bernoulli(0.5)` choices, recorded at
# addresses `:a` and `:b`. The argument `w` is accepted but never used.
@gen function submodel(w)
    a = @trace(bernoulli(0.5), :a)
    b = @trace(bernoulli(0.5), :b)
end

DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any], false, Union{Nothing, Some{Any}}[nothing], var"##submodel#312", Bool[0], false)

In [5]:
# Generative function with one traced `bernoulli(0.5)` choice at `:z` and a
# traced call to `submodel` at `:q`. The argument `n` is only referenced by the
# disabled loop below, so it currently has no effect on the trace.
@gen function model(n)
    z = @trace(bernoulli(0.5), :z)
    # for k=1:n
    #     @trace(bernoulli(0.5), k)
    # end
    q = @trace(submodel("what the"), :q)
end

DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Any], false, Union{Nothing, Some{Any}}[nothing], var"##model#313", Bool[0], false)

In [6]:
# Generate a trace of `model`, display its choices and score, serialize it with
# GenArrow into an in-memory buffer, and write that buffer to "data.trace".
# Debug-level log records produced while this runs are sent to "write.txt".
debugIO = open("write.txt", "w+")
logger = ConsoleLogger(debugIO, Debug)
global_logger(logger)
io = IOBuffer()
# try/finally guarantees the debug log is closed even when generation, display
# or serialization throws, so a failing run does not leak the file handle.
tr_old, w_old = try
    local generated = generate(model, (10,))
    local tr = first(generated)
    display(get_choices(tr))
    display(get_score(tr))
    GenArrow.serialize(io, tr)
    generated
finally
    close(debugIO)
end
# NOTE: the global logger still refers to the closed `debugIO` from here on;
# install another logger before emitting further log records.
write_to_file("data.trace", io)

│
├── :z : false
│
└── :q
    │
    ├── :a : false
    │
    └── :b : false


-2.0794415416798357

538

In [12]:
# Load "data.trace" back into memory and rebuild a trace of `model` from it,
# sending Debug-level log records to "read.txt" while doing so.
debugIO = open("read.txt", "w+")
logger = ConsoleLogger(debugIO, Debug)
global_logger(logger)

io = read_from_file("data.trace")
# The debug log is closed whether deserialization succeeds or throws; an
# exception still propagates after the file has been closed.
recovered_tr = try
    GenArrow._deserialize(model, io)
finally
    close(debugIO)
end
recovered_tr |> get_score |> display
recovered_tr |> get_choices |> display
recovered_tr |> get_args |> display


MethodError: MethodError: no method matching get_retval(::Gen.DynamicDSLTrace{DynamicDSLFunction{Any}})
Closest candidates are:
  get_retval(!Matched::LazyTrace) at ~/Documents/probcomp/GenArrow.jl/src/lazy/LazyTrace.jl:73

In [13]:
# Discard all log records from here on (presumably so logging does not affect
# the timings below — confirm).
global_logger(NullLogger())

"""
    bench()

Run one serialize/deserialize round trip: generate a trace of `model`, write it
to an in-memory buffer with `GenArrow.serialize`, rewind the buffer, and return
the result of `GenArrow._deserialize`.
"""
function bench()
    n = 1000
    io = IOBuffer()
    trace = first(generate(model, (n,)))
    GenArrow.serialize(io, trace)
    seekstart(io)
    return GenArrow._deserialize(model, io)
end
@benchmark bench()

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m32.625 μs[22m[39m … [35m  5.231 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 98.05%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m34.708 μs               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m38.665 μs[22m[39m ± [32m111.869 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m6.37% ±  2.19%

  [39m [39m [39m█[39m▆[34m▃[39m[39m [39m [39m [39m [39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▃[39m█[39m█[39m█[34

### Internal Nodes

In [14]:
# Zero-argument generative function whose two choices live under the
# hierarchical addresses `:x => 1` (a `bernoulli(0.5)` draw) and `:x => 2`
# (a `categorical` draw over four equally weighted outcomes). Always returns 1.
@gen function model()
    @trace(bernoulli(0.5), :x => 1)
    @trace(categorical(fill(0.25, 4)), :x => 2)
    return 1
end

DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[], false, Union{Nothing, Some{Any}}[], var"##model#318", Bool[], false)

In [15]:
# Generate a trace of the zero-argument `model`, display its choices and score,
# serialize it with GenArrow into an in-memory buffer, and write that buffer to
# "data.trace". Debug-level log records go to "write.txt".
debugIO = open("write.txt", "w+")
logger = ConsoleLogger(debugIO, Debug)
global_logger(logger)
io = IOBuffer()
# try/finally guarantees the debug log is closed even when generation, display
# or serialization throws, so a failing run does not leak the file handle.
tr_old, w_old = try
    local generated = generate(model, ())
    local tr = first(generated)
    display(get_choices(tr))
    display(get_score(tr))
    GenArrow.serialize(io, tr)
    generated
finally
    close(debugIO)
end
# NOTE: the global logger still refers to the closed `debugIO` from here on;
# install another logger before emitting further log records.
write_to_file("data.trace", io)

│
└── :x
    │
    ├── 2 : 4
    │
    └── 1 : true


-2.0794415416798357

439

In [16]:
# Rewind the existing in-memory buffer `io` and rebuild a trace of `model` from
# it, sending Debug-level log records to "read.txt" while doing so.
debugIO = open("read.txt", "w+")
logger = ConsoleLogger(debugIO, Debug)
global_logger(logger)
seekstart(io)
# The debug log is closed whether deserialization succeeds or throws; an
# exception still propagates after the file has been closed.
recovered_tr = try
    GenArrow._deserialize(model, io)
finally
    close(debugIO)
end
recovered_tr |> get_score |> display

-2.0794415416798357

### Map

In [4]:
# Generative function that draws from `normal(x1 + x2, 1.0)` at address `:z`
# and returns the sampled value.
@gen function foo(x1::Float64, x2::Float64)
    z ~ normal(x1 + x2, 1.0)
    return z
end

# `Map` combinator built from `foo`.
map_foo = Map(foo)

Map{Any, Gen.DynamicDSLTrace}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Float64, Float64], false, Union{Nothing, Some{Any}}[nothing, nothing], var"##foo#312", Bool[0, 0], false))

In [5]:
# Generate a trace of `map_foo` for two argument pairs, display its choices and
# score, serialize it with GenArrow into an in-memory buffer, and write that
# buffer to "data.trace". Debug-level log records go to "write.txt".
debugIO = open("write.txt", "w+")
logger = ConsoleLogger(debugIO, Debug)
global_logger(logger)
io = IOBuffer()
# try/finally guarantees the debug log is closed even when generation, display
# or serialization throws, so a failing run does not leak the file handle.
trace = try
    local tr = first(generate(map_foo, ([0.0, 0.5], [0.5, 1.0])))
    display(get_choices(tr))
    display(get_score(tr))
    GenArrow.serialize(io, tr)
    tr
finally
    close(debugIO)
end
# NOTE: the global logger still refers to the closed `debugIO` from here on;
# install another logger before emitting further log records.
write_to_file("data.trace", io)

│
├── 1
│   │
│   └── :z : -1.4177475336087921
│
└── 2
    │
    └── :z : 0.8093017910364002


-3.9152868756734103

1210

In [27]:
# Load "data.trace" back into memory and rebuild a trace of `map_foo` from it,
# sending Debug-level log records to "read.txt" while doing so.
debugIO = open("read.txt", "w+")
logger = ConsoleLogger(debugIO, Debug)
global_logger(logger)

io = read_from_file("data.trace")
# The debug log is closed whether deserialization succeeds or throws; an
# exception still propagates after the file has been closed.
recovered_tr = try
    GenArrow._deserialize(map_foo, io)
finally
    close(debugIO)
end
recovered_tr |> get_score |> display
recovered_tr |> get_choices |> display
recovered_tr |> get_args |> display

-3.9152868756734103

│
├── 1
│   │
│   └── :z : -1.4177475336087921
│
└── 2
    │
    └── :z : 0.8093017910364002


([0.0, 0.5], [0.5, 1.0])

### Unfold

In [43]:
# Generative function that flips a coin at address `:y` and returns the result.
# The success probability is `z1` when `y_prev` is true and `z2` otherwise;
# `t` is accepted but not used in the body.
@gen function bar(t::Int, y_prev::Bool, z1::Float64, z2::Float64)
    prob = y_prev ? z1 : z2
    y ~ bernoulli(prob)
    return y
end

# `Unfold` combinator built from `bar`.
unfold_bar = Unfold(bar)

Unfold{Any, Gen.DynamicDSLTrace}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Int64, Bool, Float64, Float64], false, Union{Nothing, Some{Any}}[nothing, nothing, nothing, nothing], var"##bar#328", Bool[0, 0, 0, 0], false))

In [48]:
# Generate a trace of `unfold_bar` with arguments (5, false, 0.05, 0.95); keep
# the trace in `brace` and discard the second element of the result.
brace, _ = generate(unfold_bar, (5, false, 0.05, 0.95))

(Gen.VectorTrace{Gen.UnfoldType, Any, Gen.DynamicDSLTrace}(Unfold{Any, Gen.DynamicDSLTrace}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Int64, Bool, Float64, Float64], false, Union{Nothing, Some{Any}}[nothing, nothing, nothing, nothing], var"##bar#328", Bool[0, 0, 0, 0], false)), Gen.DynamicDSLTrace[Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Int64, Bool, Float64, Float64], false, Union{Nothing, Some{Any}}[nothing, nothing, nothing, nothing], var"##bar#328", Bool[0, 0, 0, 0], false), Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}(:y => Gen.ChoiceOrCallRecord{Bool}(true, -0.05129329438755058, NaN, true)), Dict{Any, Trie{Any, Gen.ChoiceOrCallRecord}}()), false, -0.05129329438755058, 0.0, (1, false, 0.05, 0.95), true), Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Int64, Bool, Float64, Float64], false,

In [205]:
# `write_to_file` takes the destination path first and the buffer second;
# "data.trace" is the path used by every other call in this notebook.
write_to_file("data.trace", io)

602

In [153]:
# Serialize `trace` and keep the result in `io`.
# NOTE(review): other cells call `GenArrow.serialize(io, trace)` with an
# explicit buffer; this one-argument form may come from an older GenArrow API —
# confirm it still exists.
io = GenArrow.serialize(trace);

In [154]:
# Rewind `io` and rebuild a trace from it.
# NOTE(review): other cells call `GenArrow._deserialize`, and the recorded
# output below is a Map trace over `foo`, so `bar` was presumably bound to a
# different generative function when this ran — confirm before re-running.
seekstart(io)
recovered_trace = GenArrow.deserialize(bar, io)

leaf count: 1
Key: z
is trace: false
Deserialize Internal Nodes: 0
leaf count: 1
Key: z
is trace: false
Deserialize Internal Nodes: 0


Gen.VectorTrace{Gen.MapType, Any, Gen.DynamicDSLTrace}(Map{Any, Gen.DynamicDSLTrace}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Float64, Float64], false, Union{Nothing, Some{Any}}[nothing, nothing], var"##foo#384", Bool[0, 0], false)), Gen.DynamicDSLTrace[Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Float64, Float64], false, Union{Nothing, Some{Any}}[nothing, nothing], var"##foo#384", Bool[0, 0], false), Trie{Any, Gen.ChoiceOrCallRecord}(Dict{Any, Gen.ChoiceOrCallRecord}(:z => Gen.ChoiceOrCallRecord{Float64}(1.7223728665993139, -1.6660362457037847, NaN, true)), Dict{Any, Trie{Any, Gen.ChoiceOrCallRecord}}()), false, -1.6660362457037847, 0.0, (0.0, 0.5), 1.7223728665993139), Gen.DynamicDSLTrace{DynamicDSLFunction{Any}}(DynamicDSLFunction{Any}(Dict{Symbol, Any}(), Dict{Symbol, Any}(), Type[Float64, Float64], false, Union{Nothing, Some{Any}}[nothing, nothing], var"##foo#384", Bool[0, 0], fal

In [6]:
# Show the choice map of the recovered trace.
recovered_trace |> get_choices |> display

│
├── 1
│   │
│   └── :z : 1.300320394913035
│
└── 2
    │
    └── :z : 1.8646730016085349


In [168]:
"""
    bench()

Generate a trace of `bar` from two length-`n` argument vectors, serialize it
with `GenArrow.serialize`, and return the rewound buffer. Deserialization is
currently left out of the measurement (see the disabled call below).
"""
function bench()
    n = 100000
    # NOTE(review): two argument vectors match a Map-style generative function,
    # not the four-argument `bar` defined in this notebook — confirm which
    # binding of `bar` this benchmark expects.
    firsts = [0.5 * i for i in 1:n]
    seconds = [0.5 * i + 0.5 for i in 1:n]
    trace = first(generate(bar, (firsts, seconds)))
    io = GenArrow.serialize(trace)
    return seekstart(io)
    # GenArrow.deserialize(bar, io)
end
@benchmark bench()

BenchmarkTools.Trial: 5 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m1.002 s[22m[39m … [35m  1.193 s[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m 7.98% … 22.63%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m1.142 s              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m18.61%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m1.112 s[22m[39m ± [32m73.457 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m17.07% ±  5.78%

  [39m█[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [34m█[39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m█[39m█[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m█[39m [39m 
  [39m█[39m▁[39m▁[39m▁[39m▁[39m▁[39m▁[39m▁[39m▁[

In [23]:
# Serialize the string "mama" with the stdlib serializer and print the raw
# bytes it produced.
io = IOBuffer()
Serialization.serialize(io, "mama")
io |> take! |> println

UInt8[0x37, 0x4a, 0x4c, 0x11, 0x04, 0x00, 0x00, 0x00, 0x21, 0x04, 0x6d, 0x61, 0x6d, 0x61]
