In [1]:
using DrWatson
@quickactivate "BNP2"
using Random, ProgressMeter, WeightsAndBiasLogger, MLDataUtils, BSON, Flux, Zygote
using MLToolkit.Neural, MLToolkit.DistributionsX
using Flux: Optimise
using Revise, BNP2

┌ Info: CUDAdrv.jl failed to initialize, GPU functionality unavailable (set JULIA_CUDA_SILENT or JULIA_CUDA_VERBOSE to silence or expand this message)
└ @ CUDAdrv /Users/kai/.julia/packages/CUDAdrv/b1mvw/src/CUDAdrv.jl:67
┌ Info: Precompiling BNP2 [11504357-4fe5-5405-981c-8cd43ea31635]
└ @ Base loading.jl:1273


In [2]:
# Experiment-wide settings read by the cells below.
args = (
    # Integration step size, handed to the simulator and to `leapfrog`.
    dt = 0.1f0,
    # Not read by any cell visible in this notebook.
    n_trajs = 10,
    # When true, `preprocess` perturbs the states with Gaussian noise.
    is_noisyobs = false,
    # Scale of that noise; also the default likelihood std in `get_lossof`.
    σ_obs = 0.1f0,
)

;

In [3]:
"""
    sim_traj(ms, qs, ps, dt, T)

Build a `Space` from `Particle.(ms, qs, ps)`, simulate it with `DiffEqSimulator(dt)`
for `T`, and return a vector holding the initial space followed by every element
produced by `simulate`.
"""
function sim_traj(ms, qs, ps, dt, T)
    bodies = Particle.(ms, qs, ps)
    initial = Space(bodies)
    simulator = DiffEqSimulator(dt)
    rollout = simulate(initial, simulator, T)
    # Keep the initial configuration as the first element of the trajectory.
    return [initial, rollout...]
end

# Simulate the raw three-body trajectories used for training and testing.
# Four families of initial conditions are generated in turn; within each family the
# first `n_*` trajectories are pushed to `trajs` and the remaining `n_tests` ones to
# `trajs_test`. The loop order is kept fixed because it determines the order of the
# random draws.
data_raw = let ms = fill(5e10, 3),
    qs = [[-1,  0],[ 1,  0], [ 0, √3]],
    ps = [[cos(π/3), -sin(π/3)], [cos(π/3),  sin(π/3)], [cos(π/1),  sin(π/1)]],
    T = 50, n_initials = 10, n_speed = 10, n_directions = 10, n_moving = 10,
    n_tests = div(3, 2)  # evaluates to 1 held-out trajectory per family

    trajs, trajs_test = [], []
    # Trajectories with Gaussian-perturbed initial positions and momenta
    for i in 1:n_initials+n_tests
        traj = sim_traj(
            ms, add_gaussiannoise(qs, 1f-2), add_gaussiannoise(ps, 1f-2), args.dt, T
        )
        push!(i <= n_initials ? trajs : trajs_test, traj)
    end
    # Trajectories with different speed: every `ps` entry is scaled by one common
    # random factor in [1, 2)
    for i in 1:n_speed+n_tests
        traj = sim_traj(ms, qs, ps .* (1 + rand()), args.dt, T)
        push!(i <= n_speed ? trajs : trajs_test, traj)
    end
    # Trajectories with different initial directions: every `ps` entry is passed
    # through `rotate` with one common random angle in [0, π)
    for i in 1:n_directions+n_tests
        traj = sim_traj(ms, qs, rotate.(rand() * π, ps), args.dt, T)
        push!(i <= n_directions ? trajs : trajs_test, traj)
    end
    # Trajectories with different moving directions: the same random 2-vector is
    # added to every `ps` entry
    for i in 1:n_moving+n_tests
        traj = sim_traj(ms, qs, ps .+ [randn(2)], args.dt, T)
        push!(i <= n_moving ? trajs : trajs_test, traj)
    end
    (trajs=trajs, trajs_test=trajs_test)
end

# Render every held-out trajectory as an inline HTML5 video.
for traj_test in data_raw.trajs_test
    video = animof(traj_test).to_html5_video()
    display(HTML(video))
end

;

Assume we have $n$ bodies in the $d$-dimensional space. 
Let $D=nd$ be the total dimensionality of position variables or velocity variables.

- `states` or `s` has a shape of $2D \times T$
- `states[i]` consists of `[pos1, pos2, pos3, vel1, vel2, vel3]`, where each `pos` or `vel` has a shape of $d$
- `q` and `p` are the positions and velocities of all bodies, respectively
  - `q` and `p` both have a shape of $D \times T$
  
Some definitions

- `attribute[t] = [pos, vel, mass, others...] -> not first order if attribute[t] != attribute[t-1]`
- `state[t] = [pos, vel] -> first order Markovian`
- `visible_state[t] = [pos_sub, pos_vel] `

In [4]:
"""
    shuffle(data::NamedTuple{(:s, :s′)})

Return a new `(s, s′)` named tuple whose columns are randomly permuted. The same
permutation is applied to `data.s` and `data.s′`, so column pairs stay aligned.
"""
function Random.shuffle(data::NamedTuple{(:s, :s′), <:Any})
    n_pairs = size(data.s, 2)
    perm = shuffle(1:n_pairs)
    s_perm = data.s[:, perm]
    s′_perm = data.s′[:, perm]
    return (s=s_perm, s′=s′_perm)
end

"""
    preprocess(data; do_shuffle=true)

Convert the trajectories in `data.trajs` into one-step transition pairs.

For each trajectory, `stateof` is applied to every element, each result is flattened
with `vec`, and the flattened states are laid out as the columns of a matrix. Column
`t` is then paired with column `t + 1`, and the pairs of all trajectories are
concatenated column-wise.

Reads the global `args`: when `args.is_noisyobs` is true, Gaussian noise with scale
`args.σ_obs` is added to the states before pairing.

# Keywords
- `do_shuffle`: when true (default), the columns of the result are randomly permuted,
  using the same permutation for `s` and `s′` so pairs stay aligned.

# Returns
A named tuple `(s, s′)` of `Matrix{Float32}` with identical sizes.
"""
function preprocess(data; do_shuffle=true)
    pairs_per_traj = map(data.trajs) do traj
        # One flattened state per column.
        states = reduce(hcat, vec.(stateof.(traj)))
        if args.is_noisyobs
            states = states + args.σ_obs * randn(size(states))
        end
        # Pair x_{t} with x_{t+1}.
        (states[:, 1:end-1], states[:, 2:end])
    end
    s = Matrix{Float32}(reduce(hcat, first.(pairs_per_traj)))
    s′ = Matrix{Float32}(reduce(hcat, last.(pairs_per_traj)))
    if do_shuffle
        # Previously this keyword was accepted but never consulted.
        perm = shuffle(1:size(s, 2))
        s, s′ = s[:, perm], s′[:, perm]
    end
    return (s=s, s′=s′)
end

# Build the one-step training pairs from the simulated training trajectories.
data = preprocess(data_raw)

# Log the number of training/test trajectories and the sizes of the pair matrices.
@info "Processed data" length(data_raw.trajs) length(data_raw.trajs_test) size(data.s) size(data.s′)

;

┌ Info: Processed data
│   length(data_raw.trajs) = 40
│   length(data_raw.trajs_test) = 4
│   size(data.s) = (12, 2000)
│   size(data.s′) = (12, 2000)
└ @ Main In[4]:25


In [5]:
"""
    apply_kernels(X)

Stack `X` on top of its element-wise reciprocal, sine and cosine (in that order)
along the first dimension.
"""
function apply_kernels(X)
    reciprocals = 1 ./ X
    sines = sin.(X)
    cosines = cos.(X)
    return vcat(X, reciprocals, sines, cosines)
end

"""
    euclidsq(X::AbstractMatrix)

Return the matrix of pairwise squared Euclidean distances between the columns of `X`,
computed as `‖xᵢ‖² + ‖xⱼ‖² - 2 xᵢᵀxⱼ`.
"""
function euclidsq(X::AbstractMatrix)
    sqnorms = sum(X .^ 2; dims=1)                  # 1 × n squared column norms
    gram = transpose(X) * X                        # n × n inner products
    pairwise_sums = transpose(sqnorms) .+ sqnorms  # n × n sums of squared norms
    return pairwise_sums - 2 * gram
end

"""
    pairwise_compute(X; n_bodys=3)

Split the rows of `X` into `n_bodys` consecutive blocks of `div(size(X, 1), n_bodys)`
rows each. For every column `t`, gather the blocks into a `dim × n_bodys` matrix, apply
`euclidsq` to it, and sum the result over its second dimension.

# Keywords
- `n_bodys`: number of row blocks to split `X` into. Defaults to 3, which was
  previously hard-coded.

# Returns
The per-column results concatenated horizontally, giving one row per block and one
column per column of `X`.
"""
function pairwise_compute(X; n_bodys=3)
    dim = div(size(X, 1), n_bodys)
    blocks = [X[(i-1)*dim+1:i*dim, :] for i in 1:n_bodys]
    stacked = cat(blocks...; dims=3)  # dim × size(X, 2) × n_bodys
    hs = map(1:size(stacked, 2)) do t
        Dt = euclidsq(stacked[:, t, :])
        sum(Dt; dims=2)
    end
    return hcat(hs...)
end

;

12×5 Array{Float64,2}:
 2560.0          2560.0          …  2560.0          2560.0        
 1024.0          1024.0             1024.0          1024.0        
 2560.0          2560.0             2560.0          2560.0        
    0.000390625     0.000390625        0.000390625     0.000390625
    0.000976563     0.000976563        0.000976563     0.000976563
    0.000390625     0.000390625  …     0.000390625     0.000390625
    0.387587        0.387587           0.387587        0.387587   
   -0.158533       -0.158533          -0.158533       -0.158533   
    0.387587        0.387587           0.387587        0.387587   
   -0.921833       -0.921833          -0.921833       -0.921833   
    0.987354        0.987354     …     0.987354        0.987354   
   -0.921833       -0.921833          -0.921833       -0.921833   

In [32]:
"""
    l1of(x)

Return the sum of the absolute values of the entries of `x`.
"""
function l1of(x)
    magnitudes = abs.(x)
    return sum(magnitudes)
end

"""
    l2of(x)

Return the sum of the squared entries of `x`.
"""
function l2of(x)
    squares = x .^ 2
    return sum(squares)
end

"""
    l1regof(m)

Return `l1of` summed over `params(m)`, divided by `nparams(m)`.
"""
function l1regof(m)
    total = sum(l1of, params(m))
    return total / nparams(m)
end

"""
    l2regof(m)

Return `l2of` summed over `params(m)`, divided by `nparams(m)`.
"""
function l2regof(m)
    total = sum(l2of, params(m))
    return total / nparams(m)
end

"""
    NeuralForce(f)
    NeuralForce(Dh::Int; n_bodys=3, d=2)

Callable wrapper around a network `f`; calling the wrapper forwards its argument to `f`.

The second form builds `f` as a `Chain` of five `Dense` layers: an input layer mapping
`2 * n_bodys * d` features to `Dh`, three `Dh -> Dh` layers (all four with `relu`), and
an output layer mapping `Dh` to `n_bodys * d` without an activation argument.
"""
struct NeuralForce{T}
    f::T
end

Flux.@functor NeuralForce

function NeuralForce(Dh::Int; n_bodys=3, d=2)
    dim = n_bodys * d
    # Layers are constructed in network order, input first.
    layers = (
        Dense(2dim, Dh, relu),
        Dense(Dh, Dh, relu),
        Dense(Dh, Dh, relu),
        Dense(Dh, Dh, relu),
        Dense(Dh, dim),
    )
    return NeuralForce(Chain(layers...))
end

function (nf::NeuralForce)(state)
    return nf.f(state)
end

# TODO: make this variational to account uncertainty in integration
"""
    leapfrog(state, nf, dt)

Advance `state` by one step of size `dt`.

The first half of the rows of `state` is treated as `q` and the second half as `p`.
`p` is updated by `dt / 2 * nf(vcat(q, p))`, then `q` by `dt * p`, then `p` again by
`dt / 2 * nf(vcat(q, p))` evaluated at the updated values. Returns `vcat(q, p)`.
"""
function leapfrog(state, nf, dt)
    half = div(size(state, 1), 2)
    q = state[1:half, :]
    p = state[half+1:end, :]
    half_dt = dt / 2
    # First half update of p.
    p = p + half_dt * nf(vcat(q, p))
    # Full update of q using the updated p.
    q = q + dt * p
    # Second half update of p at the new q.
    p = p + half_dt * nf(vcat(q, p))
    return vcat(q, p)
end

"""
    get_lossof(nf::NeuralForce; dt=args.dt, σ=args.σ_obs)

Return a closure `lossof(s, s′)` for training `nf`.

The closure predicts the next state with `leapfrog(s, nf, dt)`, evaluates `logpdf` of
`s′` under a `Normal` centred on the prediction with scale `σ` everywhere, sums the
result over the first dimension and averages over the second. It returns the named
tuple `(loss, lp, l2)`, where `l2 = l2regof(nf)` and `loss = -lp + l2`.
"""
function get_lossof(nf::NeuralForce; dt=args.dt, σ=args.σ_obs)
    function lossof(s, s′)
        predicted = leapfrog(s, nf, dt)
        scale = fill(σ, size(predicted))
        # Element-wise log probabilities of the observed next states.
        elementwise_lp = logpdf(Normal(predicted, scale), s′)
        # Sum over the state dimension (dim 1), average over columns (dim 2).
        per_column_lp = sum(elementwise_lp; dims=1)
        lp = mean(per_column_lp)
        l2 = l2regof(nf)
        return (loss=-lp + l2, lp=lp, l2=l2)
    end
    return lossof
end

# Instantiate the force network with hidden width 50 and the default body count
# and dimensionality of the `NeuralForce` constructor.
nf = NeuralForce(50)

# Loss closure bound to `nf`, using the default `dt` and `σ` taken from `args`.
lossof = get_lossof(nf)
Zygote.refresh()

# Smoke test: evaluate the loss on the first five training pairs and show the result.
let s = data.s[:,1:5], s′ = data.s′[:,1:5]    
    lossof(s, s′) |> display
#     @code_warntype lossof(s, s′)
end

;

(loss = -13.601502f0, lp = 13.622411f0, l2 = 0.020908466f0)

In [33]:
# Create a Weights & Biases logger for this project and record `args` as its config.
logger = WBLogger(project=projectname(), notes="")
config!(logger, args)


wandb: Waiting for W&B process to finish, PID 84473
wandb: Program ended successfully.
wandb: Run summary:
wandb:     train/lp 16.530418395996094
wandb:     _runtime 36750.425022125244
wandb:        _step 59999
wandb:   _timestamp 1583586713.00773
wandb:   train/loss -16.530418395996094
wandb:     train/l2 0.07922817021608353
wandb: Syncing files in wandb/run-20200307_031223-mgd3xaon:
wandb:   upstream_diff_0d40ca2888a766d1596ba1bd00716363712faa70.patch
wandb: plus 8 W&B file(s) and 0 media file(s)
wandb:                                                                                
wandb: Synced twilight-wildflower-64: https://app.wandb.ai/xukai92/BNP2/runs/mgd3xaon
wandb: Tracking run with wandb version 0.8.27
wandb: Wandb version 0.8.29 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade
wandb: Run data is saved locally in wandb/run-20200307_131351-45x8lyev
wandb: Syncing run trim-moon-65
wandb: ⭐️ View project at https://app.wandb.ai/xukai92/BNP2
wandb: 🚀

In [39]:
# Train `nf` with ADAM on mini-batches of transition pairs, logging every step.
# `ps` here is the set of trainable parameters of `nf`.
let n_epochs = 1_000, opt = Optimise.ADAM(2f-4), ps = params(nf), batch_size=100
    # Route the `@info` records emitted below to `logger`.
    with(logger) do
        @showprogress for epoch in 1:n_epochs
            # Reshuffle the pairs every epoch; the returned `(s, s′)` named tuple is
            # destructured positionally.
            s_shuffled, s′_shuffled = shuffle(data)
            for (iter, (s, s′)) in enumerate( eachbatch((s_shuffled, s′_shuffled); size=batch_size))
                # Declared here so the `do` block below assigns to this variable and
                # the full loss tuple is available after the gradient call.
                local res
                gs = gradient(ps) do
                    res = lossof(s, s′)
                    res.loss
                end
                # Log every field of the loss tuple (`loss`, `lp`, `l2`).
                @info "train" res...
                Optimise.update!(opt, ps, gs)
            end
        end
    end
end

;

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:01:32[39m


In [44]:
# Roll out the learned dynamics from the initial state of the third test trajectory
# and display it, first on its own and then overlaid with the true trajectory.
let traj_true = data_raw.trajs_test[3], state = vec(stateof(first(traj_true))), T = 50
    # Rebuild a `Space` from a flat state: the first half of the entries is passed as
    # the second `Particle` argument and the second half as the third.
    # NOTE(review): `fill(5e10, 3)` has 3 elements while `state[1:dim]` has `dim`
    # elements; if `dim` is 6 (12-entry states, as in `data.s`) these sizes do not
    # broadcast together. Confirm whether the halves need to be regrouped into
    # per-body vectors before being passed to `Particle`.
    function envfrom(state)
        dim = div(size(state, 1), 2)
        return Space(Particle.(fill(5e10, 3), state[1:dim], state[dim+1:end]))
    end
    traj = [envfrom(state)]
    # The rollout runs for T + 20 steps; only the first T + 1 frames are used in the
    # overlay below.
    for t in 1:T+20
        # `leapfrog` slices with `[rows, :]`, so `state` is two-dimensional from the
        # first iteration on; `envfrom` then relies on linear indexing.
        state = leapfrog(state, nf, args.dt)
        push!(traj, envfrom(state))
    end
    HTML(animof(traj).to_html5_video()) |> display
    # Frame-by-frame union of the predicted objects (listed first) and the true ones.
    traj_combined = [Space([objectsof(traj[t])..., objectsof(traj_true[t])...]) for t in 1:T+1]
    # `refs=[4, 5, 6]` indexes the true objects in the combined spaces, assuming three
    # predicted objects come first — TODO confirm what `animof` does with `refs`.
    HTML(animof(traj_combined; refs=[4, 5, 6]).to_html5_video()) |> display
end

;

## MCMC

In [None]:
# Probabilistic model of the three-body system: priors over the initial `q` and `p`,
# followed by one Gaussian observation per entry of `states` around the state obtained
# by repeatedly applying `transition` with step `dt`.
# NOTE(review): this cell has no execution count, and `@model` is not provided by any
# package imported in the visible cells (presumably Turing.jl — confirm).
@model three_body(ms, states, dt) = begin
    # Priors over the six stacked entries of `q` and of `p`.
    q ~ MvNormal(zeros(6), 10)
    p ~ MvNormal(zeros(6), 10)
    space = Space(Particle.(ms, q, p))
    for i in 1:length(states)
        space = transition(space, dt)
        # Observation centred on the flattened state, with scale argument 1e-1.
        states[i] ~ MvNormal(vec(stateof(space)), 1e-1)
    end
end

# Draw samples from the model, timing the call.
# NOTE(review): `data.objs0`, `data.states`, `args.alg` and `args.n_samples` are not
# defined by the visible cells (`data` holds `s`/`s′`; `args` holds `dt`, `n_trajs`,
# `is_noisyobs`, `σ_obs`) — this cell presumably targets an older layout; confirm.
@time chn = sample(three_body(massof.(data.objs0), data.states, args.dt), args.alg, args.n_samples)

# Plot the chain.
splot(chn; colordim=:parameter) |> display

# Save the chain under the data directory.
# NOTE(review): `args.σ` is not a field of the visible `args` (which has `σ_obs`).
bson(datadir("three_body-noise=$(args.σ).bson"), chn=chn)

;

# Simulate forward from the last `n_mc` samples of `q` and `p` in the chain, average
# the resulting position trajectories, and display the average as a video.
let n_mc = 50, T = 100, do_mean = true, res = get(chn[end-n_mc+1:end], [:q, :p])
    Q = Matrix{Float64}(hcat(res.q...)')
    P = Matrix{Float64}(hcat(res.p...)')
    if do_mean
        # Average along the second dimension, leaving a single column for the loop
        # below.
        Q = mean(Q; dims=2)
        P = mean(P; dims=2)
    end
    # Accumulator for the simulated positions.
    Q̂ = zeros(6, T)
    for i in 1:size(Q, 2)
        q, p = Q[:,i], P[:,i]
        # NOTE(review): `data.ms` is not a field of the `data` built by `preprocess`,
        # and `simulate` is called here without the `DiffEqSimulator(dt)` wrapper that
        # `sim_traj` uses — confirm against the current API.
        space = Space(Particle.(data.ms, q, p))
        traj = simulate(space, args.dt, T)
        Q̂ = Q̂ + hcat(positionof.(traj)...)
    end
    # Average over the simulated columns.
    Q̂ = Q̂ / size(Q, 2)
    HTML(animof(Q̂).to_html5_video()) |> display
end

;