In [3]:
using DrWatson
@quickactivate "BNP2"
using Random, ProgressMeter, WeightsAndBiasLogger, MLDataUtils, BSON, Flux, Zygote
using MLToolkit.Neural, MLToolkit.DistributionsX
using Flux: Optimise
using Revise, BNP2

┌ Info: Precompiling MLToolkit [519e820e-097c-11e9-2274-1b004aeb0b9b]
└ @ Base loading.jl:1273
┌ Info: Precompiling BNP2 [91a68366-b384-5123-a8f1-3305599cf021]
└ @ Base loading.jl:1273


In [4]:
# Experiment configuration; the whole tuple is also recorded in W&B via `config!`.
args = (
    dt = 1f-1,           # step size for `DiffEqSimulator` and `Models.NeuralForce`
    n_trajs = 10,        # not read by the visible cells; still logged with `config!`
    is_noisyobs = false, # when true, `preprocess` adds Gaussian noise to the states
    σ_obs = 1f-1,        # std of that noise; also passed to `Models.NeuralForce`
);

In [5]:
"""
    sim_traj(ms, qs, ps, dt, T)

Build a `Space` of `Particle`s from `ms`, `qs`, `ps` (masses, positions and
momenta, one entry per body) and advance it `T` steps with `DiffEqSimulator(dt)`.

Returns a vector holding the initial `Space` followed by every element that
`simulate` produced.
"""
function sim_traj(ms, qs, ps, dt, T)
    bodies = Particle.(ms, qs, ps)
    space0 = Space(bodies)
    simulator = DiffEqSimulator(dt)
    steps = simulate(space0, simulator, T)
    return [space0, steps...]
end

# Simulated three-body data. Every trajectory starts from a perturbed version of
# one reference initial condition (`qs`, `ps`). There are four kinds of
# perturbation; for each kind the first `n_<kind>` trajectories go to `trajs`
# (training) and the next `n_tests` go to `trajs_test` (held out).
data_raw = let ms = fill(5e10, 3),
    qs = [[-1,  0],[ 1,  0], [ 0, √3]],
    ps = [[cos(π/3), -sin(π/3)], [cos(π/3),  sin(π/3)], [cos(π/1),  sin(π/1)]],
    T = 50, n_initials = 10, n_speed = 10, n_directions = 10, n_moving = 10,
    n_tests = 1  # held-out trajectories per kind of perturbation

    trajs, trajs_test = [], []
    # Different initial conditions: `add_gaussiannoise` (scale 1f-2) on both `qs` and `ps`
    for i in 1:n_initials+n_tests
        traj = sim_traj(ms, add_gaussiannoise(qs, 1f-2), add_gaussiannoise(ps, 1f-2), args.dt, T)
        push!(i <= n_initials ? trajs : trajs_test, traj)
    end
    # Different speeds: every entry of `ps` is scaled by the same random factor in [1, 2)
    for i in 1:n_speed+n_tests
        traj = sim_traj(ms, qs, ps .* (1 + rand()), args.dt, T)
        push!(i <= n_speed ? trajs : trajs_test, traj)
    end
    # Different initial directions: every entry of `ps` is rotated by the same
    # random amount `rand() * π` (presumably an angle in radians)
    for i in 1:n_directions+n_tests
        traj = sim_traj(ms, qs, rotate.(rand() * π, ps), args.dt, T)
        push!(i <= n_directions ? trajs : trajs_test, traj)
    end
    # Different moving directions: the same random 2-vector is added to every
    # entry of `ps` (`Ref` keeps broadcasting from iterating over that vector)
    for i in 1:n_moving+n_tests
        traj = sim_traj(ms, qs, ps .+ Ref(randn(2)), args.dt, T)
        push!(i <= n_moving ? trajs : trajs_test, traj)
    end
    (trajs=trajs, trajs_test=trajs_test)
end

# Animate every held-out trajectory.
for traj in data_raw.trajs_test
    HTML(animof(traj).to_html5_video()) |> display
end

;

In [6]:
"""
    preprocess(data; do_shuffle=true)

Turn the simulated trajectories in `data.trajs` into one-step training pairs.

Each trajectory becomes a matrix with one column per time step, where each
column is `vec(stateof(space))`. If `args.is_noisyobs` is set, Gaussian noise
with std `args.σ_obs` is added to that matrix. Consecutive columns are then
paired as `x_t -> x_{t+1}`, and the pairs of all trajectories are concatenated
column-wise.

Returns a named tuple `(s, s′)` of `Matrix{Float32}` whose i-th columns form a pair.

`do_shuffle` is currently ignored and is kept only so existing calls keep
working. The training loop shuffles the pairs itself.
"""
function preprocess(data; do_shuffle=true)
    steps = map(data.trajs) do traj
        # One column per time step, each the flattened state of that step.
        states = reduce(hcat, vec.(stateof.(traj)))
        if args.is_noisyobs
            states = states + args.σ_obs * randn(size(states))
        end
        # Pair x_t (columns 1..end-1) with x_{t+1} (columns 2..end).
        (states[:, 1:end-1], states[:, 2:end])
    end
    # `reduce(hcat, …)` avoids splatting an arbitrary number of matrices into `hcat`.
    s  = Matrix{Float32}(reduce(hcat, first.(steps)))
    s′ = Matrix{Float32}(reduce(hcat, last.(steps)))
    return (s=s, s′=s′)
end

# Build the (s, s′) training pairs and report dataset sizes.
data = data_raw |> preprocess

@info("Processed data",
      length(data_raw.trajs), length(data_raw.trajs_test), size(data.s), size(data.s′))

;

┌ Info: Processed data
│   length(data_raw.trajs) = 40
│   length(data_raw.trajs_test) = 4
│   size(data.s) = (12, 2000)
│   size(data.s′) = (12, 2000)
└ @ Main In[6]:20


In [47]:
data[:, 1:5] |> gpu

(s = Float32[-0.9947667 -0.9394256 … -0.7936965 -0.7044908; -0.013464166 -0.09409873 … -0.22922291 -0.28129536; … ; -1.0049078 -0.99859375 … -0.94751054 -0.9008759; -0.010137976 -0.15378906 … -0.44318953 -0.58958435], s′ = Float32[-0.9394256 -0.87220806 … -0.7044908 -0.6052438; -0.09409873 -0.16640484 … -0.28129536 -0.32122526; … ; -0.99859375 -0.9796907 … -0.9008759 -0.83802325; -0.15378906 -0.2980196 … -0.58958435 -0.7373483])

In [50]:
include(srcdir("Models.jl"))
import .Models

# Learned dynamics model, moved to the GPU. `(50, 50)` presumably gives the
# hidden-layer widths — confirm in Models.jl; dt and σ_obs come from `args`.
nf = gpu(Models.NeuralForce((50, 50), args.dt, args.σ_obs))

# Clear Zygote's cached derivative code, presumably so the re-included
# Models definitions are differentiated afresh.
Zygote.refresh()

# Sanity check: evaluate the loss on the first five training pairs.
let first_pairs = gpu(data[:, 1:5])
    display(Models.lossof(nf, first_pairs...))
end

;

(loss = -12.492491f0, lp = 12.514783f0, l2 = 0.02229226f0)



In [51]:
# Open a Weights & Biases run for this project and store `args` as its config.
logger = WBLogger(; project=projectname(), notes="")
config!(logger, args)

Loading chipmunk for Linux (64bit) [/afs/inf.ed.ac.uk/user/s16/s1672897/miniconda2/envs/ml/lib/python3.6/site-packages/pymunk/libchipmunk.so]


wandb: Tracking run with wandb version 0.8.29
wandb: Run data is saved locally in wandb/run-20200307_173243-umckzx8r
wandb: Syncing run fallen-fire-68
wandb: ⭐️ View project at https://app.wandb.ai/xukai92/BNP2
wandb: 🚀 View run at https://app.wandb.ai/xukai92/BNP2/runs/umckzx8r
wandb: Run `wandb off` to turn off syncing.



In [55]:
# Fit `nf` with ADAM (learning rate 2f-4) on mini-batches of 100 pairs taken
# from a fresh shuffle of `data` in every epoch. The loss components of every
# step are logged through `logger` under the message "train".
let n_epochs = 1_000, batch_size = 100, θ = params(nf), opt = Optimise.ADAM(2f-4)
    with(logger) do
        @showprogress for epoch in 1:n_epochs
            batches = eachbatch(values(shuffle(data)); size=batch_size)
            for host_batch in batches
                batch = gpu(host_batch)
                # Filled inside the gradient closure so the full loss breakdown can be logged.
                local stats
                ∇θ = gradient(θ) do
                    stats = Models.lossof(nf, batch...)
                    stats.loss
                end
                @info "train" stats...
                Optimise.update!(opt, θ, ∇θ)
            end
        end
    end
end

;

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:05:53[39mm


In [57]:
# Roll the learned model forward from the first state of held-out trajectory #4.
# Animate the rollout alone, then overlaid on the ground truth for the first T+1 frames.
let nf = cpu(nf), traj_true = data_raw.trajs_test[4], T = 50
    # Rebuild a `Space` from a flat state vector: the first half goes to the
    # second `Particle` argument and the second half to the third, with the same
    # `fill(5e10, 3)` used when generating the data.
    space_of(x) = let half = div(size(x, 1), 2)
        Space(Particle.(fill(5e10, 3), x[1:half], x[half+1:end]))
    end

    x = vec(stateof(first(traj_true)))
    rollout = [space_of(x)]
    for t in 1:T+20
        x = Models.leapfrog(nf, x)
        push!(rollout, space_of(x))
    end
    display(HTML(animof(rollout).to_html5_video()))

    # The rollout's objects come first, so the true bodies follow them
    # (indices 4:6 assuming 3 bodies per space).
    overlay = map(1:T+1) do t
        Space([objectsof(rollout[t])..., objectsof(traj_true[t])...])
    end
    display(HTML(animof(overlay; refs=[4, 5, 6]).to_html5_video()))
end

;

## MCMC

In [None]:
# Bayesian inference of the initial `q` and `p` of the bodies from a sequence of
# observed states: Gaussian priors on q and p, Gaussian likelihood per step.
# NOTE(review): this cell has never been run (`In [None]`) and looks stale:
#   - `data` (returned by `preprocess`) only has fields `s` and `s′`, not
#     `objs0` / `states`;
#   - `args` only has `dt`, `n_trajs`, `is_noisyobs`, `σ_obs`, not `alg`,
#     `n_samples` or `σ`;
#   - `@model`, `sample`, `MvNormal`, `transition` and `splot` are not loaded by
#     this notebook's `using` lines (unless re-exported by BNP2/MLToolkit) —
#     presumably Turing and friends need to be loaded.
@model three_body(ms, states, dt) = begin
    # Priors over the 6 position and 6 momentum coordinates, scale 10.
    q ~ MvNormal(zeros(6), 10)
    p ~ MvNormal(zeros(6), 10)
    # NOTE(review): if `ms` has one entry per body (3) while `q`/`p` have 6,
    # this broadcast will not line up — confirm the intended layout.
    space = Space(Particle.(ms, q, p))
    for i in 1:length(states)
        # Advance one step, then score the i-th observation against the flattened state.
        space = transition(space, dt)
        # NOTE(review): the observation noise is hard-coded to 1e-1 here rather than taken from `args`.
        states[i] ~ MvNormal(vec(stateof(space)), 1e-1)
    end
end

# Draw posterior samples (timed).
@time chn = sample(three_body(massof.(data.objs0), data.states, args.dt), args.alg, args.n_samples)

# Plot the chain, coloured along the parameter dimension.
splot(chn; colordim=:parameter) |> display

# Save the chain to the data directory; the file name records the noise level.
bson(datadir("three_body-noise=$(args.σ).bson"), chn=chn)

;

# Posterior predictive check: simulate forward from the last `n_mc` posterior
# samples of (q, p), or from their mean when `do_mean`, and animate the
# averaged positions.
# NOTE(review): depends on `chn` from the never-run MCMC cell and looks stale:
#   - `mean` needs `Statistics`, which this notebook does not load explicitly;
#   - `data.ms` does not exist (`data` has only `s` and `s′`);
#   - `simulate(space, args.dt, T)` differs from the
#     `simulate(env, DiffEqSimulator(dt), T)` call used in `sim_traj`.
let n_mc = 50, T = 100, do_mean = true, res = get(chn[end-n_mc+1:end], [:q, :p])
    # Presumably one column per retained sample and one row per coordinate —
    # assumes `res.q` / `res.p` hold one vector per coordinate; TODO confirm.
    Q = Matrix{Float64}(hcat(res.q...)')
    P = Matrix{Float64}(hcat(res.p...)')
    if do_mean
        # Collapse the samples into a single column (the posterior mean).
        Q = mean(Q; dims=2)
        P = mean(P; dims=2)
    end
    # NOTE(review): `Q̂` is 6×T; confirm this matches the size of
    # `hcat(positionof.(traj)...)` for a trajectory returned by `simulate`.
    Q̂ = zeros(6, T)
    for i in 1:size(Q, 2)
        q, p = Q[:,i], P[:,i]
        space = Space(Particle.(data.ms, q, p))
        traj = simulate(space, args.dt, T)
        Q̂ = Q̂ + hcat(positionof.(traj)...)
    end
    # Average the accumulated positions over the simulated initial conditions.
    Q̂ = Q̂ / size(Q, 2)
    HTML(animof(Q̂).to_html5_video()) |> display
end

;