In [None]:
# Trivial expression; presumably just a check that the kernel responds.
1 + 1

In [None]:
# Environment setup: load tooling, print version info, activate the project.
using InteractiveUtils, DrWatson, Comonicon
if isdefined(Main, :IJulia) && Main.IJulia.inited
    # Inside an initialised IJulia kernel: load Revise so edited sources are reloaded.
    using Revise
else
    ENV["GKSwstype"] = 100 # suppress warnings during gif saving
end
versioninfo()
# DrWatson: activate the project environment this notebook belongs to.
@quickactivate

In [None]:
# Plotting, progress-bar and logging packages; default theme and plot size for all figures.
using Plots, ProgressMeter, Logging
theme(:bright; size=(300, 300))

In [None]:
using Random, Turing, BayesianSymbolic
using ExprOptimization.ExprRules
using PyCall
pd = pyimport("pandas")

# Include a source file, tracking it with Revise's `includet` when Revise is loaded
# (so edits are picked up live); otherwise fall back to a plain `include`.
includef(args...) = isdefined(Main, :Revise) ? includet(args...) : include(args...)
includef(srcdir("utility.jl"))
includef(srcdir("app_inf.jl"))
includef(srcdir("sym_reg.jl"))
includef(srcdir("network.jl"))
includef(srcdir("exp_max.jl"))
includef(srcdir("analyse.jl"))
includef(srcdir("dataset.jl"))
# Only Error-level logs are shown while loading this file, to suppress the
# warnings about using _varinfo.
with_logger(SimpleLogger(stderr, Logging.Error)) do
    includef(srcdir("scenarios", "ullman.jl"))
end
includef(scriptsdir("ullman_hacks.jl"))

In [None]:
# Load world 1 of the processed Ullman data and inspect it.
scenario, attribute = loadullman(datadir("ullman", "processed"), 1)

# Log each scene's group and what `groupinfo` reports for it.
foreach(scenario.scenes) do scene
    @info scene.group
    @info groupinfo(scene.group)
end

@info attribute

# Baseline error (presumably a normalised RMSE, per the name) of
# `BayesianSymbolic.getforce` given the oracle latents built from `attribute`, nahead=1.
@info compute_normrmse(
    UllmanScenario, [scenario], make_latents([attribute]), BayesianSymbolic.getforce, Likelihood(nahead=1)
)

In [None]:
# Animate the second scene of the loaded world: one frame per trajectory step.
scene = scenario.scenes[2]

@gif for t in 1:length(scene.traj)
    plot(scene.entity, scene.traj, t)
end

In [None]:
# Position traces of the selected scene (plot helpers come from the included sources).
make_pos_plot(scene.traj)

In [None]:
# Velocity traces of the selected scene.
make_vel_plot(scene.traj)

In [None]:
# Acceleration traces of the selected scene.
make_acc_plot(scene.traj)

## Study how to extract the assumed external force

In [None]:
# Sweep the M-step settings (#iterations × #ahead × noise level) with the oracle
# latents of world 1 and record the fitted force for each combination.
# d1/d2/d3 map each setting value to its index along the matching axis of `C` below;
# r3 is the inverse of d3 (level index → noise level).
d1 = Dict(i => i for i in 1:5)
d2 = Dict(i => i for i in 1:5)
d3 = Dict(0.01 => 1, 0.1 => 2, 1.0 => 3)
r3 = Dict(1 => 0.01, 2 => 0.1, 3 => 1.0)

sweep = vec(collect(Iterators.product(1:5, 1:5, [0.01, 0.1, 1.0])))
# Pre-size the output and write each slot by index: calling `push!` on one shared
# vector from concurrent `Threads.@threads` iterations is a data race.
results = Vector{Any}(undef, length(sweep))

Threads.@threads for k in eachindex(sweep)
    niterations, nahead, nlevel = sweep[k]

    force = let seed = 1,
        ScenarioModel = UllmanScenario,
        malg = HandCodedForce(niterations=niterations, mask=Bool[1,1,0,0,1]),
        mlike = Likelihood(nahead=nahead, nlevel=nlevel),
        scenarios = [scenario],
        attributes = [attribute]

        Random.seed!(seed)
        latents = make_latents(attributes) # oracle latents
        mstep(malg, ScenarioModel, scenarios, latents, mlike; verbose=true)
    end

    results[k] = (force=force, niterations=niterations, nahead=nahead, nlevel=nlevel)
end

In [None]:
# Gather the fitted constants into C[constant, iterations index, ahead index, level index].
C = zeros(5, 5, 5, 3)
foreach(results) do r
    i, j, k = d1[r.niterations], d2[r.nahead], d3[r.nlevel]
    C[:, i, j, k] = r.force.constant
end

In [None]:
"""
    vis_constants(C, inlevel)

Plot one heatmap per displayed force constant (C, G₀ and Gc) over the
(#iterations, #ahead) sweep grid, for noise-level index `inlevel`.

`C` is indexed `C[constant, iterations index, ahead index, level index]`, as filled
from the sweep results. The noise level printed in the log is looked up in the
global `r3` (level index → noise level).
"""
function vis_constants(C, inlevel)
    @info "nlevel=$(r3[inlevel])"

    all_names = ["C", "G₀", "G", "Gm", "Gc"]
    # Indices into the first axis of `C` for the constants to show. Slicing `C` with
    # these (not with each entry's 1/2/3 position here) keeps data and title in sync.
    shown = [1, 2, 5]
    clims = [(15, 40), (0, 30), (-1, 1)]

    # Sweep values equal their indices (d1/d2 are identity maps), so the tick labels
    # are simply the indices along each axis.
    xs = string.(axes(C, 2))
    ys = string.(axes(C, 3))

    ps = map(zip(shown, clims)) do (idx, clim)
        # Plots' `heatmap(x, y, z)` expects `z[y, x]`. C's 2nd axis (#iterations) is x
        # and its 3rd axis (#ahead) is y, so the slice is transposed.
        z = permutedims(C[idx, :, :, inlevel])
        p = plot()
        heatmap!(p, xs, ys, z; aspect_ratio=1, clim=clim)
        xlabel!(p, "#iterations")
        ylabel!(p, "#ahead")
        title!(p, all_names[idx])
        p
    end

    return plot(ps...; size=(300 * length(ps), 300), layout=Plots.GridLayout(1, length(ps)))
end

vis_constants(C, 1)

In [None]:
# Same heatmaps for noise-level index 2 (nlevel = 0.1).
vis_constants(C, 2)

In [None]:
# Same heatmaps for noise-level index 3 (nlevel = 1.0).
vis_constants(C, 3)

## Extract the external force with sensible parameters

In [None]:
# Fit the external force once on world 1 with its oracle latents, using fixed
# settings (3 iterations, 3 steps ahead, noise level 0.1).
force_ext = let ScenarioModel = UllmanScenario, silent = false
    # Load World 1 for fitting the external force
    scenario_ext, attribute_ext = loadullman(datadir("ullman", "processed"), 1)

    malg_ext = HandCodedForce(niterations=3, mask=Bool[1,1,0,0,1])
    mlike_ext = Likelihood(nahead=3, nlevel=0.1)
    latents_ext = make_latents([attribute_ext]) # oracle latents
    mstep(malg_ext, ScenarioModel, [scenario_ext], latents_ext, mlike_ext; verbose=!silent)
end

In [None]:
# NOTE(review): no global `res` exists at this point (the only `res` in this notebook
# is a local inside the per-world loop below), so running top to bottom throws
# UndefVarError here. Confirm what this cell was meant to inspect.
res["malg"].opt

In [None]:
"""
    make_est(latent)

Return the `value` of the element of `latent` with the largest `logweight`
(the first such element on ties).
"""
function make_est(latent)
    logweights = map(elem -> elem.logweight, latent)
    best = argmax(logweights)
    return latent[best].value
end

# For every (world, scenario) pair, take the final force of the saved EM trace
# (`em.jld2`), swap in the external force fitted above, re-infer the latents by
# importance sampling and keep the highest-weight estimate next to the true attributes.
results = []
forces = []
@showprogress for wid in 1:10, sid in 1:6
    let niters = 1, seed = 0, silent = false
        hps = @ntuple(wid, sid, niters, seed)
        respath = projectdir("results-submission", "ullman", savename(hps; connector="-"), "em.jld2")
        try
            scenario, attribute = loadullman(
                datadir("ullman", "processed"), wid; idcs=[sid]
            )
            scenarios = [scenario]

            res = wload(respath)
            @unpack ScenarioModel, elike, trace = res

            # Final learned force, with its external term replaced by `force_ext`.
            force = trace[end].force
            force = @set(force.external = make_getforce(force_ext))

            push!(forces, (wid=wid, expr=BayesianSymbolic.get_executable(force.tree, force.grammar)))

            latents = estep(ImportanceSampling(nsamples=500), ScenarioModel, scenarios, force, elike; verbose=!silent)

            # Highest-weight sample; `expect.(x -> x, latents)[1]` (posterior mean) is an alternative.
            est = make_est.(latents)[1]
            !silent && @info "" est attribute

            push!(results, (wid=wid, sid=sid, est=est, att=attribute))
        catch e
            # Never swallow a user interrupt; otherwise report the actual error
            # (which may come from loading OR inference) and continue with the next case.
            e isa InterruptException && rethrow()
            @warn "Failed to process $respath" exception=(e, catch_backtrace())
        end
    end
end

In [None]:
"""
    findclosest(s, x)

Return the element of `s` with the smallest absolute difference to `x`; on ties the
element that comes first in `s` wins. Elements whose distance is NaN are ignored.

Throws `ArgumentError` when no element is at a finite distance from `x`
(e.g. `s` is empty or `x` is NaN/Inf) instead of failing on an unassigned result.
"""
function findclosest(s, x)
    bestdist = Inf
    found = false
    local best
    for y in s
        d = abs(y - x)
        # Strict `<` keeps the first of several equally close elements and skips NaN.
        if d < bestdist
            bestdist, best, found = d, y, true
        end
    end
    found || throw(ArgumentError("findclosest: no element of `s` is at a finite distance from $x"))
    return best
end

"""
    answer_ques(est, att, opt; verbose=false)

For each position `i`, snap the estimate `est[i]` to the closest option in `opt`
(via `findclosest`) and report whether it equals the true value `att[i]`.

Returns a vector of `Bool`s, one per position. Throws `DimensionMismatch` when `est`
and `att` differ in length. With `verbose=true` the inputs are logged first.
"""
function answer_ques(est, att, opt; verbose=false)
    length(est) == length(att) ||
        throw(DimensionMismatch("est has length $(length(est)) but att has length $(length(att))"))
    verbose && @info "" est att opt
    return [findclosest(opt, e) == a for (e, a) in zip(est, att)]
end

# Per-attribute accuracy: masses are est/att[1:3] with options (1, 3, 9) and
# frictions are est/att[7:9] with options (0, 5, 20). Cases with NaN estimates are skipped.
cmass, cfric = [], []
for (idx, (_, _, est, att)) in enumerate(results)
    if any(isnan, est)
        println("$idx has NaN")
        continue
    end
    append!(cmass, answer_ques(est[1:3], att[1:3], [1, 3, 9]))
    append!(cfric, answer_ques(est[7:9], att[7:9], [0, 5, 20]))
end

(mean(cmass), mean(cfric))

In [None]:
"""
    make_mat(est, att, opt)

Build a confusion matrix. Entry `[r, c]` counts the positions `i` where the true
value `att[i]` is the `r`-th smallest option and the estimate `est[i]`, snapped to
the closest option in `opt`, is the `c`-th smallest option.

The result is a `length(opt) × length(opt)` `Float64` matrix. Throws
`DimensionMismatch` if `est` and `att` differ in length, and `ArgumentError` if a
true value is not one of `opt`.
"""
function make_mat(est, att, opt)
    length(est) == length(att) ||
        throw(DimensionMismatch("est has length $(length(est)) but att has length $(length(att))"))
    # Key rows/columns by the sorted options (hoisted out of the loop). This equals the
    # old `sort(att)` keying when `att` is a permutation of `opt`. Unlike that keying, it
    # also works when `att` repeats a value or misses one.
    levels = sort(collect(opt))
    n = length(levels)
    mat = zeros(n, n)
    for (e, a) in zip(est, att)
        row = findfirst(==(a), levels)
        row === nothing && throw(ArgumentError("true value $a is not one of the options $levels"))
        # findclosest returns an element of `opt`, so this lookup always succeeds.
        col = findfirst(==(findclosest(opt, e)), levels)
        mat[row, col] += 1
    end
    return mat
end

# Confusion matrices per case, for masses (est/att[1:3], options 1, 3, 9) and
# frictions (est/att[7:9], options 0, 5, 20). Cases with NaN estimates are skipped.
mmass, mfric = [], []
for (idx, (_, _, est, att)) in enumerate(results)
    if any(isnan, est)
        println("$idx has NaN")
        continue
    end
    push!(mmass, make_mat(est[1:3], att[1:3], [1, 3, 9]))
    push!(mfric, make_mat(est[7:9], att[7:9], [0, 5, 20]))
end

mean(mmass)

In [None]:
# Element-wise mean of the friction confusion matrices collected above.
mfric |> mean

In [None]:
# Learned force expressions recorded for world 10.
filter(f -> f.wid == 10, forces)

In [None]:
# NOTE(review): hand-entered Bool annotations. What `true` means, and how the
# line/blank-line groups map to worlds or scenarios, is not recorded here. TODO document.
# The cell reports the fraction of `true` entries.
cgforce = [
    true, false, true, true, false, true,
    true, true, false, true, false, true,
    false, true, true, false, false,
    true, false,
    true, true,
    true, true, true, true,
    
    true, false, true, true, false,
    
    false, false, false, false, false, false,
    
    true, true, true, true,
    
    false, false, false, false,
] 

cgforce |> mean

In [None]:
# NOTE(review): hand-entered Bool annotations, with the same length and grouping as
# `cgforce`. What `true` means is not recorded here. TODO document.
# The cell reports the fraction of `true` entries.
cpforce = [
    true, true, true, true, true, true,
    false, false, false, true, false, true,
    true, false, true, true, false,
    true, true,
    false, true,
    false, true, false, false,
    false, false, true, true, false,
    true, true, true, false, true, false,
    false, false, false, false,
    false, false, true, false
] 

cpforce |> mean