In [1]:
using LinearAlgebra
using Revise
using PyPlot, AxPlot, AxUtil, InferGMM
using Distributions, Random
using Formatting, Dates
using Flux
using Flux: Tracker, params
using Flux.Tracker: @grad

using Parameters, ArgCheck
using StatsFuns, ProgressMeter

include("../mtds-julia/mtdsutil.jl")

# Previous MTDS script

### -- Params

In [2]:
# Number of synthetic sequences generated by the data cell below.
# NOTE(review): a later cell assigns `N = 4` without `const`, which rebinds this constant.
const N = 20
# Length (time steps) of every generated sequence; also the default rollout length `tt`.
const TS_LEN = 80

### -- Utils

In [3]:
# utils
"""
    unsqueeze(xs, dim)

Return `xs` reshaped so that a singleton dimension is inserted at position `dim`.
"""
function unsqueeze(xs, dim)
    shape = size(xs)
    before, after = shape[1:dim-1], shape[dim:end]
    return reshape(xs, (before..., 1, after...))
end;
"""
    issquare(A)

Return `true` when `A` has exactly two dimensions and both have the same length.
"""
function issquare(A)
    dims = size(A)
    length(dims) == 2 || return false
    return dims[1] == dims[2]
end
"""
    eye(d)

Return the dense `d`×`d` identity matrix with `Float64` entries.
"""
eye(d) = Matrix{Float64}(I, d, d)
"""
    trans_matrix_2d(θ, ρ)

Return the 2×2 matrix `[cos θ  sin θ; -sin θ  cos θ]` scaled elementwise by `ρ`.
"""
function trans_matrix_2d(θ, ρ)
    c, s = cos(θ), sin(θ)
    rotation = [c s; -s c]
    return ρ .* rotation
end;

"""
    evaluate_state_deterministic(A, x0, timesteps)

Roll out the noiseless linear recursion `x[t+1] = A * x[t]` starting from `x0`.

Returns a `size(A, 1) × (timesteps + 1)` matrix of `Float64` whose first column is `x0`
and whose column `t + 1` is the state after `t` applications of `A`.
"""
function evaluate_state_deterministic(A, x0, timesteps)
    @assert issquare(A)
    n_state = size(A, 1)
    traj = zeros(n_state, timesteps + 1)
    traj[:, 1] = x0
    for (prev, next) in zip(1:timesteps, 2:timesteps+1)
        traj[:, next] = A * traj[:, prev]
    end
    return traj
end;

# Fix the RNG seed so the noise added to the sequences below is reproducible.
Random.seed!(80)
# Emission row vector: X' * Ctrue' below keeps only the first state coordinate.
const Ctrue = [1. 0.]
# Range for the second Sobol coordinate h; the decay used below is ρ = exp(log(0.5)/h),
# i.e. ρ^h = 0.5.
p_ρ_log = Uniform(4, 80)
# Range for the rotation angle θ in radians (6.75° to 36°).
p_θ = Uniform((6.75/360) * 2π, (36/360) * 2π)
# Initial state (written as a 1×2 row; it is copied into the first column of the rollout).
x0 = [0 1]

d = 2   # reqd later?
# Containers for the clean and the noise-corrupted sequences (one vector per sequence).
seq_deterministic = []
seq_noise = []

# NOTE(review): presumably an N×2 array of quasi-random draws, column 1 in the θ range and
# column 2 in the p_ρ_log range — confirm against AxUtil.Random.uniform_rand_sobol.
rsob = AxUtil.Random.uniform_rand_sobol(N, [p_θ.a, p_θ.b], [p_ρ_log.a, p_ρ_log.b])
for nn = 1:N
    # Scaled rotation matrix for this sequence.
    _A = trans_matrix_2d(rsob[nn,1], exp(log(0.5)/rsob[nn,2]))
    # TS_LEN states (initial state plus TS_LEN-1 steps), projected through Ctrue.
    Xcur = evaluate_state_deterministic(_A, x0, TS_LEN-1)'*Ctrue'
    push!(seq_deterministic, Xcur[:])
    # Add i.i.d. Gaussian noise with standard deviation 0.05.
    push!(seq_noise,  Xcur[:] .+ randn(TS_LEN)*0.05)
end

In [4]:
"""
    cayley_suborthog(x, d)

Build the `d`×`d` matrix `U * S * Vt` from the flat parameter vector `x`; intended to
parameterise a system matrix with spectral radius <= 1.

With `n = d(d-1)/2`, entries `1:n` of `x` are passed to `AxUtil.Math.cayley_orthog` to
form `U`, entries `n+1:2n` likewise form `Vt`, and the remaining entries are squashed
through `σ` and passed to `AxUtil.Flux.diag0` to form `S`.
"""
function cayley_suborthog(x, d)
    n_skew = Int(d*(d-1)/2)
    u_range, v_range = 1:n_skew, (n_skew+1):(2*n_skew)
    U = AxUtil.Math.cayley_orthog(x[u_range], d)
    Vt = AxUtil.Math.cayley_orthog(x[v_range], d)
    squashed = σ.(x[(2*n_skew+1):end])
    return U * AxUtil.Flux.diag0(squashed) * Vt
end

"""
    cayley_suborthog_easy(x, d)

Simplified variant of `cayley_suborthog` in which the right factor `Vt` is fixed to the
identity, so only `U` and the diagonal `S` are parameterised by `x`.

With `n = d(d-1)/2`, entries `1:n` of `x` are passed to `AxUtil.Math.cayley_orthog` to
form `U`; all remaining entries are squashed through `σ` to form `S`.
"""
function cayley_suborthog_easy(x, d)
    n_skew = Int(d*(d-1)/2)
    x_U, x_S = x[1:n_skew], x[n_skew+1:end]
    # Helpers are referenced with their module path, exactly as in `cayley_suborthog`;
    # the bare names `cayley_orthog` / `diag0` are not brought into scope by this file.
    U = AxUtil.Math.cayley_orthog(x_U, d)
    Vt = eye(d)
    S = AxUtil.Flux.diag0(σ.(x_S))
    return U * S * Vt
end

"""
    feat_extract(z)

Stack `z`, `sin.(z)`, `cos.(z)` and the Euclidean norm taken along the first dimension
of `z` vertically. For a `d`-row input the result has `3d + 1` rows.
"""
function feat_extract(z)
    col_norm = sqrt.(sum(x -> x^2, z, dims=1))
    return vcat(z, sin.(z), cos.(z), col_norm)
end
# Tracked inputs are routed through `Tracker.track`, so the custom gradients registered
# with `@grad function feat_extract(...)` below are used instead of differentiating the
# generic method.
feat_extract(z::TrackedArray) = Tracker.track(feat_extract, z)


"""
    feat_extract_deriv(Δ, f, d)

Pull the output sensitivity `Δ` of `feat_extract` back onto its `d`-row input.

`f` is the value returned by `feat_extract` (rows `1:d` = z, `d+1:2d` = sin z,
`2d+1:3d` = cos z, row `3d+1` = norm). The result is a `d`-row matrix accumulating the
identity, sine, cosine and norm contributions for every input row.
"""
function feat_extract_deriv(Δ, f, d)
    grad_z = Δ[1:d, :]            # identity block
    norm_row = 3 * d + 1
    for k in 1:d
        sin_row, cos_row = d + k, 2 * d + k
        # d sin(z_k)/dz_k = cos(z_k)
        grad_z[k, :] .+= (f[cos_row, :] .* Δ[sin_row, :])
        # d cos(z_k)/dz_k = -sin(z_k)
        grad_z[k, :] .+= -(f[sin_row, :] .* Δ[cos_row, :])
        # d||z||/dz_k = z_k / ||z||, with a small constant guarding against division by zero
        grad_z[k, :] .+= (f[k, :] .* Δ[norm_row, :]) ./ (f[norm_row, :] .+ 1e-16)
    end
    return grad_z
end

# Custom Tracker gradient of `feat_extract` for vector inputs.
@grad function feat_extract(z::AbstractVector)
    n_latent = size(z, 1)
    feats = feat_extract(Tracker.data(z))
    # `feat_extract_deriv` returns a matrix, so flatten it back to a vector.
    pullback(Δ) = (vec(feat_extract_deriv(Δ, feats, n_latent)),)
    return feats, pullback
end

# Custom Tracker gradient of `feat_extract` for matrix inputs (one sample per column).
@grad function feat_extract(z::AbstractMatrix)
    n_latent = size(z, 1)
    feats = feat_extract(Tracker.data(z))
    pullback(Δ) = (feat_extract_deriv(Δ, feats, n_latent),)
    return feats, pullback
end

### --Model

In [5]:
# Only optimising nn and logσ for DHO1X problems. B is fixed.
# Trainable model: `nn` is the Flux `Chain` applied to the latent `Z` in `_model_forward`,
# and `logσ` is the Tracker-tracked log emission standard deviation (the likelihood
# functions in this file read `logσ[1]`).
struct DhoModel{T <: AbstractFloat, F <: Chain}
    nn::F                             # link network producing the per-sample parameter vector
    logσ::Tracker.TrackedVector{T}    # tracked log std-dev of the emission noise
end

# Gradient-free counterpart of `DhoModel`: same fields, but `logσ` is a plain `Vector`.
# Instances are created from a `DhoModel` by `make_untracked`.
struct DhoModelUntracked{T <: AbstractFloat, F <: Chain}
    nn::F               # link network (leaves unwrapped with Tracker.data by make_untracked)
    logσ::Vector{T}     # untracked log std-dev of the emission noise
end
"""
    pars(x::DhoModel)

Collect the trainable parameters of the network and `logσ` into one Flux `params` set.
"""
function pars(x::DhoModel)
    return params(x.nn, x.logσ)
end
"""
    make_untracked(x::DhoModel)

Return a `DhoModelUntracked` whose network leaves and `logσ` are the `Tracker.data` of
the corresponding parts of `x`.
"""
function make_untracked(x::DhoModel)
    plain_nn = mapleaves(Tracker.data, x.nn)
    plain_logσ = Tracker.data(x.logσ)
    return DhoModelUntracked(plain_nn, plain_logσ)
end

In [6]:
"""
    get_pars(x::DhoModel)

Return all parameter values of `x` as one flat vector: every network parameter
(flattened, in `params` order) followed by the entries of `logσ`.
"""
function get_pars(x::DhoModel)
    flat_weights = [vec(Tracker.data(w)) for w in params(x.nn)]
    return vcat(vcat(flat_weights...), x.logσ.data[:])
end

"""
    set_pars!(x::DhoModel, p::Vector)

Load the flat parameter vector `p` (layout as produced by `get_pars`) into `x` in place.

The leading entries of `p` are reshaped to the sizes of the network parameters, in
`params` order, and loaded with `Flux.loadparams!`. All entries after the network
parameters are written into `logσ`, so that `set_pars!(x, get_pars(x))` round-trips for
any length of `logσ`. A single trailing entry is broadcast over `logσ`.
"""
function set_pars!(x::DhoModel{T,F}, p::Vector{T}) where {T <: AbstractFloat, F <: Chain}
    weights = Tracker.data.(params(x.nn))
    sz = map(size, weights)
    # csz[k]+1 : csz[k+1] is the slice of `p` holding the k-th parameter array
    csz = cumsum([0; map(prod, sz)])
    new_weights = [reshape(p[(csz[k]+1):csz[k+1]], sz[k]...) for k in 1:length(sz)]
    Flux.loadparams!(x.nn, new_weights)
    # Everything after the network weights belongs to logσ (mirrors `get_pars`).
    x.logσ.data .= p[(csz[end]+1):end];
end

"""
    get_grad(x::DhoModel)

Return the currently accumulated gradients of `x` as one flat vector, laid out like
`get_pars`: every network parameter gradient (flattened, in `params` order) followed by
the gradient of `logσ`.
"""
function get_grad(x::DhoModel)
    weight_grads = Tracker.grad.(params(x.nn))
    weight_grads = vcat(map(vec, weight_grads)...)
    # Append the *gradient* of logσ (not its value) so the layout matches `get_pars`.
    return vcat(weight_grads, Tracker.grad(x.logσ)[:])
end

"""
    zero_grad!(x::DhoModel)

Reset the accumulated gradients of every network parameter and of `logσ` to zero.
"""
function zero_grad!(x::DhoModel)
    foreach(p -> (p.tracker.grad .= 0), params(x.nn))
    x.logσ.grad .= 0
end;

In [285]:
using Flux.Tracker: TrackedVector

# Mutable holder for the current and previous hidden state of an untracked LDS rollout.
mutable struct lds_internal_notrk{T <: AbstractFloat}
    x::Vector{T}        # current state
    x_prev::Vector{T}   # state before the most recent update
end


# Tracked counterpart of `lds_internal_notrk`: both states are `TrackedVector`s.
mutable struct lds_internal{T <: AbstractFloat}
    x::TrackedVector{T}        # current state
    x_prev::TrackedVector{T}   # state before the most recent update
end

"""
    A(ψ, d)

Build the `d`×`d` transition matrix `S * V` from the leading entries of the parameter
vector `ψ`.

Entries `1:d` are squashed through `tanh` and passed to `AxUtil.Flux.diag0` to form `S`;
the next `d(d-1)/2` entries (divided by 10) are passed to `AxUtil.Math.cayley_orthog` to
form `V`. The slice ends at `d + d(d-1)/2`, which is where the slice read by `B` begins.
"""
function A(ψ, d)
    n_skew = Int(d*(d-1)/2)
    # Upper bound is d + n_skew (identical to the previous `2 + n_skew` when d == 2).
    x_S, x_V = ψ[1:d], ψ[d+1:d+n_skew]
    V = AxUtil.Math.cayley_orthog(x_V/10, d)
    S = AxUtil.Flux.diag0(tanh.(x_S))    # <= NOTICE SWITCHED FOR TANH
    return S * V
end

"""
    B(ψ, d)

Return a copy of the `d` entries of `ψ` that follow the `d + d(d-1)/2` entries used by `A`.
"""
function B(ψ, d)
    n_skew = Int(d*(d-1)/2)
    first_ix = d + n_skew + 1
    ψ_b = ψ[first_ix:(first_ix + d - 1)]
    return identity.(ψ_b)
end

"""
    C(ψ, d)

Return the `1`×`d` emission row taken from the `d` entries of `ψ` that follow the slices
used by `A` and `B`.
"""
function C(ψ, d)
    n_skew = Int(d*(d-1)/2)
    first_ix = 2*d + n_skew + 1
    ψ_c = ψ[first_ix:(first_ix + d - 1)]
    return reshape(ψ_c, 1, d)
end

"""
    _harm_osc2d_state(λ::Vector; tt=TS_LEN)

Roll out `tt` hidden states of the 2-d linear system parameterised by `λ` (untracked).

The initial state is `B(λ, 2)`; column `t` of the returned `2 × tt` matrix is the state
after `t - 1` applications of `A(λ, 2)`.
"""
function _harm_osc2d_state(λ::Vector{T}; tt=TS_LEN)::Matrix{T} where T <: AbstractFloat
    d = 2
    transition = A(λ, d)
    x = convert(Vector{T}, B(λ, d) * T(1))
    columns = Vector{Vector{T}}(undef, tt)
    for t in 1:tt
        columns[t] = x                                 # record the state *before* stepping
        x = convert(Vector{T}, transition * x)
    end
    return hcat(columns...)
end

# Tracked counterpart of `_harm_osc2d_state`: rolls out `tt` hidden states from the tracked
# parameter vector `λ` and returns them column-wise as a `TrackedMatrix`.
function _harm_osc2d_state(λ::TrackedVector{T}; tt=TS_LEN)::TrackedMatrix{T} where T <: AbstractFloat
    d = 2   # the state dimension is fixed to 2 here
    h0 = B(λ, d) * T(1)   # initial state, taken from the `B` slice of λ
    _A = A(λ, d)
    lds_cell = AxUtil.Flux.LDSCell_simple(_A, h0)
    lds = Flux.Recur(lds_cell, h0, h0)  # Flux constructor can't find hidden method def outside Flux
    # The first column is the initial state; every later column comes from one further call
    # of the recurrent cell, so the evaluation order of this comprehension matters.
    # NOTE(review): the argument `1` passed to `lds` is presumably a dummy input — confirm
    # against AxUtil.Flux.LDSCell_simple.
    lds_state = [i == 1 ? lds.state : lds(1) for i in 1:tt]
    return hcat(lds_state...)
end

In [8]:
# function _model_forward(model::Union{DhoModel, DhoModelUntracked}, Z::AbstractMatrix{T}; tt=TS_LEN) where T <: AbstractFloat
#     d, M = size(Z, 1), size(Z, 2)   # explicit to ensure always gives 2nd dim (even if none := 1)
#     Λ = model.nn(feat_extract(Z))
#     Y = map(1:M) do m
#         x = _harm_osc2d_state(Λ[:,m]; tt=tt)
#         ŷ = C(Λ[:,m], d) * x
#     end
#     return reduce(vcat, Y)
# end

# NO FEATURE EXTRACTOR!
"""
    _model_forward(model, Z; tt=TS_LEN)

Map every column of the latent matrix `Z` to a predicted sequence of length `tt`.

The network output for column `m` is used both to roll out the hidden state
(`_harm_osc2d_state`) and to build the emission row (`C`). Predictions are stacked
row-wise, giving a `size(Z, 2) × tt` result. Note that `size(Z, 1)` is what gets passed
to `C` as the state dimension.
"""
function _model_forward(model::Union{DhoModel, DhoModelUntracked}, Z::AbstractMatrix{T}; tt=TS_LEN) where T <: AbstractFloat
    n_latent, n_samples = size(Z, 1), size(Z, 2)   # explicit so the 2nd dim always exists
    Λ = model.nn(Z)
    function predict(m)
        states = _harm_osc2d_state(Λ[:,m]; tt=tt)
        return C(Λ[:,m], n_latent) * states
    end
    rows = map(predict, 1:n_samples)
    return reduce(vcat, rows)
end

# Union type dispatch is unavailable: use separate methods pointing to same function
# Calling a model on a latent matrix `Z` (one sample per column) forwards to `_model_forward`.
# Whenever the model or the input is tracked, the result is wrapped with `Tracker.collect`.
(model::DhoModel)(Z::AbstractMatrix; tt::Int=TS_LEN) = Tracker.collect(_model_forward(model, Z; tt=tt))
(model::DhoModelUntracked)(Z::Matrix; tt::Int=TS_LEN) = _model_forward(model, Z; tt=tt)
(model::DhoModelUntracked)(Z::TrackedMatrix; tt::Int=TS_LEN) = Tracker.collect(_model_forward(model, Z; tt=tt))

# Vector input --> Matrix
# A single latent vector is first turned into a one-column matrix with `unsqueeze(Z, 2)`.
(model::DhoModel)(Z::AbstractVector; tt::Int=TS_LEN) = Tracker.collect(_model_forward(model, unsqueeze(Z,2); tt=tt))
(model::DhoModelUntracked)(Z::Vector; tt::Int=TS_LEN) = _model_forward(model, unsqueeze(Z,2); tt=tt)
(model::DhoModelUntracked)(Z::TrackedVector; tt::Int=TS_LEN) = Tracker.collect(_model_forward(model, unsqueeze(Z,2); tt=tt))


### -- Optimisation

In [9]:
"""
    sq_diff_matrix(X, Y)

Return the matrix of squared Euclidean distances between every row of `X` (n_x × d) and
every row of `Y` (n_y × d). The result has size n_x × n_y and is computed through the
expansion |x|^2 + |y|^2 - 2 x'y.
"""
function sq_diff_matrix(X, Y)
    row_sqnorm(M) = sum(a -> a^2, M, dims=2)
    normsq_x, normsq_y = row_sqnorm(X), row_sqnorm(Y)
    @assert size(normsq_x, 2) == 1 && size(normsq_y, 2) == 1
    out = normsq_x .+ normsq_y'    # outer sum of the squared norms
    @assert size(out) == (size(normsq_x, 1), size(normsq_y, 1))
    cross = 2*X * Y'
    out .-= cross
    return out
end

"""
    llh_weight_matrix(X, Xmu, Dlogdiag; incl_const=true)

Return the matrix of Gaussian log likelihoods of every row of `X` under every row of
`Xmu` (or 'Xhat') taken as the mean, with a diagonal covariance whose log diagonal is
the vector `Dlogdiag`. The result has size `size(X, 1) × size(Xmu, 1)`.

When `incl_const` is `true` (the default) the normalising constant
`-d/2 * log(2π) - sum(Dlogdiag)/2` is included; otherwise only the quadratic term is
returned.
"""
function llh_weight_matrix(X, Xmu, Dlogdiag; incl_const=true)
    @assert ndims(Dlogdiag) == 1
    Dsqrtdiag = exp.(Dlogdiag*0.5)
    d = size(Dsqrtdiag, 1)
    # Whiten both arguments so the quadratic form reduces to a plain squared distance.
    Xw = X ./ Dsqrtdiag'
    Xmuw = Xmu ./ Dsqrtdiag'
    out = -0.5 * sq_diff_matrix(Xw, Xmuw)
    if incl_const
        out .+= -0.5*d*log(2π) - 0.5*sum(Dlogdiag)
    end
    return out
end

"""
    snis_weight_matrix(X, Xhat, Dlogdiag)

Self-normalised importance weights for the full cross comparison between the rows of `X`
and the rows of `Xhat`.

Returns `(W, logp)`: `W` is `size(X, 1) × size(Xhat, 1)`, and `logp` is the
log-sum-exp of the log likelihoods for every row of `X` minus `log(size(Xhat, 1))`,
i.e. an estimate of log(1/M * sum_j exp(loglik_j)).
"""
function snis_weight_matrix(X, Xhat, Dlogdiag)
    loglik = llh_weight_matrix(X, Xhat, Dlogdiag, incl_const=true)
    # softmax_lse! is applied to the transposed copy and the result transposed back,
    # so that the normalisation runs over the rows of the log-likelihood matrix.
    transposed = Matrix(loglik')
    weights, lse = AxUtil.Math.softmax_lse!(transposed)
    logp = lse .- log(size(Xhat, 1))
    return weights', logp
end

"""
    get_posterior_smp_matrix(model, n, seq_obs; tt=TS_LEN)

Draw `n` two-dimensional proposals with `AxUtil.Random.sobol_gaussian`, push them through
`model`, and weight them against every observed sequence in `seq_obs`.

Returns `(Zproposal, w, lp)`: the proposals, the self-normalised importance weights and
the log-evidence estimates from `snis_weight_matrix`. The log variance passed to it is
`2 * model.logσ[1]` for every time step.
"""
function get_posterior_smp_matrix(model::DhoModelUntracked, n::Int, seq_obs; tt=TS_LEN)
    proposals = AxUtil.Random.sobol_gaussian(n, 2)'
    predicted = model(proposals; tt=tt)
    observed = reduce(vcat, map(adjoint, seq_obs))   # one observed sequence per row
    logvar = ones(tt)*model.logσ[1]*2
    weights, logp = snis_weight_matrix(observed, predicted, logvar)
    return proposals, weights, logp
end

"""
    get_log_px(model::DhoModelUntracked, n, seq_obs; tt=TS_LEN)

Return only the log-evidence estimates (third output) of `get_posterior_smp_matrix`.
"""
function get_log_px(model::DhoModelUntracked, n::Int, seq_obs; tt=TS_LEN)
    _, _, logp = get_posterior_smp_matrix(model, n, seq_obs; tt=tt)
    return logp
end

### -- LLH for inference

In [10]:
# Variance of the isotropic zero-mean Gaussian prior on the latent `Z` (read by `p_log_prior`).
const prior_var = 1.0

"""
    p_log_llh(model, Y, Z; tt=TS_LEN)

Unnormalised Gaussian log likelihood of the observed sequence `Y` under the prediction
for every latent sample in `Z` (vector, or one sample per column).

Returns one value per sample: `-0.5 * sum(((Ŷ - Y) / σ)^2)` with `σ = exp(model.logσ[1])`.
The normalising constant is not included.
"""
function p_log_llh(model::Union{DhoModel, DhoModelUntracked}, Y::Vector, Z::AbstractVecOrMat; tt=TS_LEN)
    Ŷ = model(Z, tt=tt)
    @argcheck size(Y, 1) == size(Ŷ,2)
    noise_std = exp(model.logσ[1])
    standardised = (Ŷ .- Y') ./ noise_std
    sq_err = sum(standardised.^2, dims=2)
    return -0.5*sq_err[:]
end

"""
    p_log_prior(Z)

Unnormalised log density of the zero-mean Gaussian prior with variance `prior_var`,
evaluated for every column of the 2-row latent `Z`. Returns a vector with one entry per
column.
"""
function p_log_prior(Z)
    d = size(Z,1)
    @argcheck d == 2
    scaled = (Z .* 1) / sqrt(prior_var)
    per_column = sum(scaled.^2, dims=1)
    return (-0.5*per_column)[:]
end

"""
    p_log_posterior_unnorm(model::DhoModelUntracked, Y, Z; tt=TS_LEN)

Unnormalised log posterior of the latent sample(s) `Z` given the sequence `Y`:
the sum of `p_log_llh` and `p_log_prior`.
"""
function p_log_posterior_unnorm(model::DhoModelUntracked, Y::Vector, Z::AbstractVecOrMat; tt=TS_LEN)
    return p_log_llh(model, Y, Z; tt=tt) + p_log_prior(Z)
end

"""
    p_log_posterior_unnorm_beta(model::DhoModelUntracked, Y, Z, beta)

Return `(tempered, full)`: the unnormalised log posterior with the likelihood scaled by
`beta`, and the ordinary unnormalised log posterior. The likelihood is evaluated with
the default rollout length of `p_log_llh`.
"""
function p_log_posterior_unnorm_beta(model::DhoModelUntracked, Y::Vector, Z::AbstractVecOrMat, beta::AbstractFloat)
    llh = p_log_llh(model, Y, Z)
    prior = p_log_prior(Z)
    tempered = prior + beta*llh
    full = llh + prior
    return tempered, full
end

### -- Create Model

In [11]:
# State dimension of the oscillator.
d = 2
# Length of the network output: (d+1) entries read by `A` (d scales + d(d-1)/2 = 1 skew
# entry when d == 2), d entries read by `B` and d entries read by `C`.
d_out = (d+1) + d + d

In [12]:
# Alternative output length; 11 is the parameter-vector length that
# `convert_mtlds_to_dhomodel` checks for.
d_out = 11

In [13]:
# Back to the A/B/C-only parameter count (evaluates to 7 when d == 2).
d_out = (d+1) + d + d

In [206]:
# Regenerate the data set (clean and noisy sequences) with a different seed and size.
seq_deterministic = []
seq_noise = []

# Random.seed!(5921207)
Random.seed!(2220175)  # explored
# Random.seed!(3197275)  # this one has the low θ example (and not an esp low one!). Clearly highly curved. Suggest look at PCs.
# Random.seed!(20175)
# NOTE(review): `N` was declared `const N = 20` in the params cell; this assignment rebinds
# a constant.
N = 4
rsob = AxUtil.Random.uniform_rand_sobol(N, [p_θ.a, p_θ.b], [p_ρ_log.a, p_ρ_log.b])

# Same generation recipe as the first data cell: scaled rotation, rollout, noise with std 0.05.
for nn = 1:N
    _A = trans_matrix_2d(rsob[nn,1], exp(log(0.5)/rsob[nn,2]))
    Xcur = evaluate_state_deterministic(_A, x0, TS_LEN-1)'*Ctrue'
    push!(seq_deterministic, Xcur[:])
    push!(seq_noise,  Xcur[:] .+ randn(TS_LEN)*0.05)
end

d_nn = 300   # hidden layer of neural net link fn
d_nn_in = 2 # 7
# linkfn = Chain(Dense(d_nn_in, d_nn, σ), Dense(d_nn, d_out, identity));
# Link network: latent (d_nn_in) -> hidden (d_nn, sigmoid) -> parameters (d_out) -> elementwise Diagonal layer.
linkfn = Chain(Dense(d_nn_in, d_nn, σ), Dense(d_nn, d_out, identity), Flux.Diagonal(d_out));
# Initial log standard deviation of the emission noise.
log_emission_std = [-1.5]

# Tracked model used for back-propagation, and its untracked counterpart used for sampling.
dho_model = DhoModel(linkfn, Flux.param(Float32.(log_emission_std)))
dho_model_nog = make_untracked(dho_model)

# Defaults for the optimisation cells below.
nsmps = 500
epochs = 500
n_est = 5
cpars = pars(dho_model)
# Prior on log σ used as a penalty inside the optimisation functions.
prior_lsigma = Normal(-1.3, 0.04)

opt = ADAM(8e-4, (0.9,0.99));


## -- optimise

In [56]:
# Run 100 epochs of the per-sequence ("async") optimiser. The optimiser object is passed in
# and rebound from the return value.
# NOTE(review): `evidence_optimisation_async` is defined in a later cell of this file, so
# that cell has to be evaluated before this one.
history6, opt, pps = evidence_optimisation_async(dho_model, dho_model_nog, seq_noise, cpars; epochs=100, 
                        n_est=5, nsmps=1000, prior_lsigma=prior_lsigma, norm_clip_amt=1000., opt=opt);

In [99]:
# Plot the first history column (logp per sequence) of several runs side by side; the
# right-hand panel repeats the plot with a restricted y-range.
# NOTE(review): `history` ... `history5` are not defined anywhere in this file as shown —
# they presumably come from earlier interactive runs.
fig, axs = subplots(1,2,figsize=(9,4))
for i in 1:2
    axs[i].plot(history[:,1])
    axs[i].plot(history2[:,1])
    # history3 is only drawn in the left panel; a dummy point is plotted on the right,
    # which keeps the number of plotted lines equal in both panels.
    (i == 1) ? axs[i].plot(history3[:,1]) : axs[i].plot(0,0)
    axs[i].plot(history4[:,1])
    axs[i].plot(history5[:,1])
    axs[i].plot(history6[:,1])
end
axs[1].legend([L"$\eta(\cdot), C(\sigma)$", L"$(\cdot), C(\sigma)$", 
            L"$(\cdot), C(\sigma)$, async", L"$(\cdot), C(\cdot)$", L"$(\cdot), C(\cdot), elu$",
            L"$(\cdot), C(\cdot)$, no diaglayer"])
axs[2].set_ylim([20,34])

In [56]:
# history4, opt, pps = evidence_optimisation_async(dho_model, dho_model_nog, seq_noise, cpars; epochs=20, 
#                         n_est=5, nsmps=1000, prior_lsigma=prior_lsigma, norm_clip_amt=1000., opt=opt);

In [256]:
# Optimise `dho_model` by stochastic gradient steps taken one sequence at a time
# ("async": parameters are updated after every sequence rather than once per epoch).
#
# For every sequence, `nsmps` proposals are importance-weighted with the untracked model
# `dho_model_nog`, `n_est` of them are resampled according to the weights, and the negative
# log likelihood of the sequence under those samples (plus the log-σ normaliser and the
# `prior_lsigma` penalty) is back-propagated through `dho_model`.
#
# Returns `(history, opt, allpar)`. `history` is `epochs × 25`, of which only columns 1:2
# (logp and reconstruction log likelihood, both divided by N) are filled. `allpar` is
# returned empty.
# NOTE(review): `norm_clip_amt`, `acc_step`, `cparval` and `g_ν` are not used in this function.
# NOTE(review): sampling uses `dho_model_nog`, which presumably shares parameter storage
# with `dho_model` (see `make_untracked`) — confirm, otherwise proposals use stale weights.
function evidence_optimisation_async(dho_model::DhoModel, dho_model_nog::DhoModelUntracked, seq_obs, pars; 
                               epochs=100, nsmps=1000, n_est=1, prior_lsigma=Normal(0, 100), norm_clip_amt=1000.,
                                opt=ADAM(1e-3), acc_step=false)
    # returns history, opt, allpar
    verbose = false
    N = length(seq_obs)
    history = zeros(epochs, 25)
    allpar = Any[]
    
    AxUtil.Flux.zero_grad!(pars)   # just in case have previously been held and have hidden gradient.
    cparval = [get_pars(dho_model), get_pars(dho_model)]
    
    g_ν = zeros(300)
    for i in 1:epochs
        
        # --- Backprop through posterior-integrated likelihood ---
        recon_loss = 0.
        
        for nn in 1:N
            # --- Generate approximate posterior samples ---
            Zproposal, w, logp_s = get_posterior_smp_matrix(dho_model_nog, nsmps, seq_obs[nn:nn]; tt=TS_LEN)
            # (this per-sequence logp is overwritten after the loop and not used here)
            logp = sum(logp_s)

            # resample n_est proposal indices according to the importance weights
            smps = AxUtil.Random.multinomial_indices_linear(n_est, w[:])
            smps = reduce(vcat, smps)
            verbose && println(smps')
            zsmps_nn = Zproposal[:, smps]
            # --- /end ---

            # neg. log likelihood / reconstruction
            loss = -sum(p_log_llh(dho_model, seq_obs[nn], zsmps_nn; tt=TS_LEN))
            loss /= n_est   # normalise for number of samples
            recon_loss += loss.data
            verbose && println(loss.data)
            
            # Likelihood normalising constant  (# not n*T*logσ as we have normalised by n already)
            loss += sum(TS_LEN*dho_model.logσ)

            # Log Normal prior on log sigma
            loss += 0.5*sum(x->x^2, dho_model.logσ .- prior_lsigma.μ)/prior_lsigma.σ^2
            verbose && println(loss.data)
            Tracker.back!(loss)
            
            # Regularise MLP weights
            # NOTE(review): these terms are added after `Tracker.back!(loss)` and `loss` is not
            # used again, so they do not appear to contribute to the gradients — confirm intent.
            loss += 1e-3*sum(abs, dho_model.nn.layers[1].W) / N
            loss += 1e-1*sum(abs, dho_model.nn.layers[2].W) / N

            # one optimiser step per sequence
            for p in pars
                Tracker.update!(opt, p, -Tracker.grad(p))
            end
        end
        # --- /end ---
        
        # Re-estimate the evidence over all sequences with the updated parameters (logging only).
        __zp, __w, logp_s = get_posterior_smp_matrix(dho_model_nog, nsmps, seq_obs; tt=TS_LEN)
        logp = sum(logp_s)
        
        # reconstruction log likelihood including the Gaussian normalising constant
        logrecon = -recon_loss - N * (TS_LEN*(dho_model.logσ.data[1] + log(2π)/2))
        
        history[i,1:2] = [logp, logrecon]/N
        if (i % 5 == 1) 
            print("Batch iter ", i, ", recon loss: ", sprintf1("%.3f", logrecon/N))
            println(", prev logp(x): ", sprintf1("%.2f", logp/N))
#             flush(stdout);
        end
    end
    
    return history, opt, allpar
end

In [100]:
# Optimise `dho_model` with one optimiser step per epoch ("sync"): posterior samples for
# all sequences are drawn once per epoch from a shared set of proposals, the per-sequence
# losses are back-propagated one after another (gradients accumulate), and the parameters
# are updated once at the end of the epoch.
#
# Returns `(history, opt, allpar)`. Each row of the `epochs × 25` `history` holds
# [logp/N, logrecon/N, 7 row norms of the layer-2 weight gradient, the cosine between the
# current and previous gradient of row 3, the first 15 entries of that gradient row].
# `allpar` is returned empty.
# NOTE(review): `norm_clip_amt`, `acc_step`, `cparval` and `epoch_loss` are not used for
# anything that is returned.
# NOTE(review): the diagnostics hard-code 7 output rows, row 3 and a gradient row of
# length 300 (`g_ν = zeros(300)`), so they presumably assume d_out >= 7 and d_nn == 300.
function evidence_optimisation_sync(dho_model::DhoModel, dho_model_nog::DhoModelUntracked, seq_obs, pars; 
                               epochs=100, nsmps=1000, n_est=1, prior_lsigma=Normal(0, 100), norm_clip_amt=1000.,
                                opt=ADAM(1e-3), acc_step=false)
    # returns history, opt, allpar
    N = length(seq_obs)
    history = zeros(epochs, 25)
    allpar = Any[]
    
    AxUtil.Flux.zero_grad!(pars)   # just in case have previously been held and have hidden gradient.
    cparval = [get_pars(dho_model), get_pars(dho_model)]
    
    g_ν = zeros(300)
    nclips = 0
    for i in 1:epochs

        # --- Generate approximate posterior samples ---
        Zproposal, w, logp_s = get_posterior_smp_matrix(dho_model_nog, nsmps, seq_obs; tt=TS_LEN)
        logp = sum(logp_s)

        # n_est resampled proposal indices per sequence (row i of w), concatenated in sequence order
        smps = [AxUtil.Random.multinomial_indices_linear(n_est, view(w, i, :)) for i ∈ 1:N]
        smps = reduce(vcat, smps)
        Zs_post = Zproposal[:, smps]
        # --- /end ---

        # --- Backprop through posterior-integrated likelihood ---
        epoch_loss = 0.
        recon_loss = 0.
        for nn in 1:N
            # columns of Zs_post belonging to sequence nn
            zsmps_nn = Zs_post[:, (nn-1)*n_est+1:nn*n_est]

            # neg. log likelihood / reconstruction
            loss = -sum(p_log_llh(dho_model, seq_obs[nn], zsmps_nn; tt=TS_LEN))
            loss /= n_est   # normalise for number of samples
            recon_loss += loss.data
            
            # Likelihood normalising constant  (# not n*T*logσ as we have normalised by n already)
            loss += sum(TS_LEN*dho_model.logσ)

            # Log Normal prior on log sigma
            loss += 0.5*sum(x->x^2, dho_model.logσ .- prior_lsigma.μ)/prior_lsigma.σ^2

            epoch_loss += loss.data
            Tracker.back!(loss)
        end
        # --- /end ---

#         nclips += AxUtil.Flux.normclip!(cpars, norm_clip_amt)
        nclips = 0

        # Regularise MLP weights
        # NOTE(review): these terms are added to the `loss` of the last sequence after all
        # `Tracker.back!` calls of this epoch, so they do not appear to contribute to the
        # gradients — confirm intent.
        loss += 1e-3*sum(abs, dho_model.nn.layers[1].W)
        loss += 1e-1*sum(abs, dho_model.nn.layers[2].W)
        
        # Gradient diagnostics for the second layer, taken before the update.
        norm_g = [norm(dho_model.nn.layers[2].W.grad[i,:]) for i in 1:7]
        # cosine between this epoch's and the previous epoch's gradient of row 3
        # (in the first epoch g_ν is all zeros, so this is 0/0)
        ν_angle = (dho_model.nn.layers[2].W.grad[3,:]'*g_ν)/(norm(g_ν)*norm_g[3])
        g_ν = dho_model.nn.layers[2].W.grad[3,:]
        # single optimiser step for the whole epoch
        for p in pars
            Tracker.update!(opt, p, -Tracker.grad(p))
        end
#         newparval = get_pars(dho_model)
#         push!(allpar, newparval)
#         cos_optim = let g1=(cparval[2] - cparval[1]); g2=newparval-cparval[2]; sum(x->x^2, g1 .* g2)/(norm(g1)*norm(g2)); end
#         cparval = [cparval[2], newparval]
        
        # reconstruction log likelihood including the Gaussian normalising constant
        logrecon = -recon_loss - N * (TS_LEN*(dho_model.logσ.data[1] + log(2π)/2))
        
#         mqnt(x) = [mean(x), quantile(x, 0.1), quantile(x, 0.9)]
#         _smps = dho_model_nog.nn(feat_extract(randn(2,1000)))
#         _ρs = vcat([mqnt(σ.(_smps[i,:])) for i in 1:2]...)
#         _θ = vcat([mqnt([2 .* atan.(t/10) .* 360 ./ 2π for t in _smps[i,:]]) for i in 3:3]...)
#         _B_s = vcat([mqnt(σ.(_smps[i,:])) for i in 4:5]...)
#         _C_s = vcat([mqnt(_smps[i,:]) for i in 4:5]...)
        
#         history[i, :] = vcat([logp, logrecon]/N,  _ρs, _θ, _B_s, _C_s, cos_optim, dho_model_nog.logσ[1])
        # (the next assignment is overwritten by the full-row assignment that follows it)
        history[i,1] = logp/N
        history[i,1:25] = vcat([logp, logrecon]/N, norm_g, [ν_angle], g_ν[1:15])
        if (i % 5 == 1) 
            print("Batch iter ", i, ", recon loss: ", sprintf1("%.3f", logrecon/N))
            println(", prev logp(x): ", sprintf1("%.2f", logp/N), " (clip: ", "|"^nclips, ")")
            nclips = 0
            flush(stdout)
#             println(angle_stuff(dho_model_nog.nn)[:])
        end
    end
    
    return history, opt, allpar
end

# Can we do the same thing using the Mocap machinery?

In [23]:
DIR_MOCAP_MTDS = "." 

# Data loading and transformation utils
include(joinpath(DIR_MOCAP_MTDS, "io.jl"))

# MeshCat skeleton visualisation tools
include(joinpath(DIR_MOCAP_MTDS, "mocap_viz.jl"))

# Data scaling utils
include(joinpath(DIR_MOCAP_MTDS, "util.jl"))
import .mocaputil: MyStandardScaler, scale_transform, invert
import .mocaputil: OutputDifferencer, difference_transform, fit_transform
import .mocaputil: no_pos, no_poscp

# Models: LDS
include(joinpath(DIR_MOCAP_MTDS, "models.jl"))
import .model: Astable

# Table visualisation
include(joinpath(DIR_MOCAP_MTDS, "pretty.jl"))

In [395]:
# REMOVE B, b, AND D
# Dropout probability referenced only by the commented-out dropout expressions below.
drop_amt = 0.0

# Override of `model._make_lds_psi`: build one LDS from the parameter vector `ψ` produced by
# the multi-task network, combined with the base parameters stored in `s`.
# In this variant D is multiplied by 0 and the output offset `d` by `m_D = 0`; the state
# offset passed on is `η₁*s.b` (the partitioned `b` from ψ is not used).
# NOTE(review): the `η_h` argument is never read — the scale `η₁` is taken from `s.η_h`
# instead. This is equivalent when the default is used; confirm whether callers that pass
# a different `η_h` expect it to take effect.
function model._make_lds_psi(s::Union{model.MTLDS_g{T,F}, model.MTLDS_ng{T,F}},
        ψ::Union{AbstractVector{T}, TrackedVector{T}},
        η_h::Union{T, Vector{T}}=s.η_h) where {T <: Real, F <: Chain}
    d_state, d_out, d_in = size(s)
    ldsdims = model._partition_ldspars_dims(d_state, d_out, d_in, length(ψ))
    # split ψ into the individual LDS parameter blocks
    a, B, b, C, D, d = model.partition_ldspars(ψ, ldsdims, d_state, d_out, d_in)
    state = deepcopy(s.h)   # independent copy of the hidden state stored in `s`
    ldstype = model.has_grad(s) ? model.MyLDS_g{T} : model.MyLDS_ng{T}
    η₁ = model.arr2sc(s.η_h)
    m_B = 1 #rand() > drop_amt ? 1 : 0   # B dropout
    m_D = 0 #rand() > drop_amt ? 1 : 0   # D dropout
    return ldstype(η₁*a + s.a, η₁*B*m_B*T(0.1) + s.B*m_B, η₁*s.b*m_B, C + s.C, D*0, d*m_D, state)
end

In [24]:
############################################
##    CUSTOM WIDELY USED FUNCTIONS
"""
    zero_grad!(P)

Set the `grad` array of every element of the collection `P` to zero, in place.
"""
function zero_grad!(P)
    foreach(x -> (x.grad .= 0), P)
end

"""
    rmse(Δ::AbstractArray, scale=size(Δ, 1))

Root of the mean squared entry of `Δ`, multiplied by `scale` (by default the number of
rows of `Δ`).
"""
function rmse(Δ::AbstractArray, scale=size(Δ, 1))
    mean_sq = mean(x -> x^2, Δ)
    return sqrt(mean_sq)*scale
end

"""
    rmse(d::mocaputil.DataIterator, m::model.MyLDS_ng)

Weighted average of the per-batch `rmse` of the predictions of `m` over the iterator `d`.

The hidden state `m.h` is reset to zeros whenever a batch is flagged as starting a new
state, and once more after the last batch. Batch weights come from
`mocaputil.weights(d, as_pct=true)`.
"""
function rmse(d::mocaputil.DataIterator, m::model.MyLDS_ng)
    reset_state!() = (m.h .= zeros(size(m, 1)))
    obj = map(d) do (y, u, new_state)
        new_state && reset_state!()
        rmse(m(u) - y)
    end
    reset_state!()
    batch_weights = mocaputil.weights(d, as_pct=true)
    return dot(obj, batch_weights)
end


"""
    rmse(Ds::Vector{<:Dict}, m::model.MyLDS_ng)

Wrap `Ds` in a `mocaputil.DataIterator` (second argument 1000000) and evaluate the
iterator version of `rmse`.
"""
function rmse(Ds::Vector{D}, m::model.MyLDS_ng) where {D <: Dict}
    iterator = mocaputil.DataIterator(Ds, 1000000)
    return rmse(iterator, m)
end
############################################

In [25]:
# Initialise LDS
# Base LDS with every parameter block set to zero (Float32). The positional arguments
# presumably follow the (a, B, b, C, D, d, h) order used in the `ldstype(...)` call of
# `model._make_lds_psi` — confirm against the `MyLDS_ng` constructor.
clds_orig = model.MyLDS_ng{Float32}(f32([zeros(d); zeros(Int(d*(d-1)/2))]), f32(zeros(d, 1)), f32(zeros(d)),
            f32(zeros(1, d)), f32(zeros(1, 1)), f32(zeros(1)), f32([0,0]))

In [745]:
# Build the multi-task LDS: a network mapping the k-dimensional latent to the full
# parameter vector of `clds_orig`.
k = 2                 # dimension of manifold
d_nn = 300            # "complexity" of manifold

d_par = length(model.get_pars(clds_orig))
# latent (k) -> hidden (d_nn, sigmoid) -> LDS parameters (d_par) -> elementwise Diagonal layer
nn = Chain(Dense(k, d_nn, σ), 
           Dense(d_nn, d_par, identity, initW = ((dims...)->Flux.glorot_uniform(dims...)*1.0)),
           Flux.Diagonal(d_par))
# initial log emission std, one entry per `size(clds_orig, 2)`
clogσ = repeat([-1.5f0], size(clds_orig, 2))

# tracked model (for gradients) and its no-gradient counterpart (for sampling)
cmtlds_g = model.mtldsg_from_lds(clds_orig, nn, clogσ, 1.0f0);
cmtlds = model.make_nograd(cmtlds_g);
# model.change_relative_lr!(cmtlds, 0.001f0)   # reduce sensitivity of chain params. 

In [746]:
# One LDS instantiated from a random latent draw.
clds = model.make_lds(cmtlds, randn(Float32, 2), 0.1f0);
# Observations converted to Float32 row matrices (1 × TS_LEN).
seq_noise32 = [reshape(Float32.(y), 1, :) for y in seq_noise];
# Inputs: a single unit impulse at the first time step, zeros afterwards, one per sequence.
Us32 = [hcat(1f0, zeros(Float32, 1, TS_LEN-1)) for i in 1:length(seq_noise32)];

In [747]:
model.change_relative_lr!(cmtlds, 0.10f0)   # reduce sensitivity of chain params. 
# Set the matching scale on the tracked model in place.
# NOTE(review): this assumes `cmtlds` and `cmtlds_g` share the parameters changed by
# `change_relative_lr!` — confirm.
cmtlds_g.η_h .= 0.10f0

In [748]:
# Random.seed!(15201)
# opt = ADAM(8e-4, (0.9,0.99));
# Settings for the MTLDS-based optimisation run below.
nsmps = 300
epochs = 30
n_est = 3
prior_lsigma = Normal(-1.3, 0.04)

# opt = ADAM(4e-4, (0.9,0.99));
# opt.eta /=2
# opt.eta = 1e-10
# Reuses the existing optimiser object `opt` and only resets its step size.
opt.eta = 8e-4
# All sequences share the same input, so only `Us32[1]` is passed.
history6, opt, pps = evidence_optimisation_new(cmtlds_g, seq_noise32, Us32[1]; epochs=epochs, 
                        n_est=n_est, nsmps=nsmps, prior_lsigma=prior_lsigma, opt=opt);

In [641]:
zero_grad!(cpars)

In [287]:
# Random.seed!(15201)
# opt.eta = 2e-3
# cpars = pars(dho_model)
# Continue the DHO ("async") optimisation with a smaller step size; `n_est`,
# `prior_lsigma` and `cpars` are taken from earlier cells.
opt.eta = 1e-4
epochs=100
nsmps = 400
history6, opt, pps = evidence_optimisation_async(dho_model, dho_model_nog, seq_noise, cpars; epochs=epochs, 
                        n_est=n_est, nsmps=nsmps, prior_lsigma=prior_lsigma, norm_clip_amt=1000., opt=opt);

In [145]:
Random.seed!(15201)
get_posterior_smp_matrix(dho_model_nog, nsmps, seq_noise[1:1]; tt=TS_LEN)[2]

In [146]:
Random.seed!(15201)
sample_posterior(cmtlds, seq_noise32[1], Us32[1], nsmps)[1]'

In [None]:
get_posterior_smp_matrix(dho_model_nog, nsmps, seq_noise[1:1]; tt=TS_LEN)

In [156]:
sample_posterior(cmtlds, seq_noise32[1], Us32[1], 3, f32(_zs[:, [193, 180, 141]]))

In [169]:
p_log_llh(cmtlds_g, seq_noise32[1], Us32[1], f32(_zs[:, [193, 180, 141]]))[1]

In [562]:
AxUtil.Random.multinomial_indices_linear(3, vec(w2))

In [555]:
rand(Categorical(vec(w2)), 3)

In [716]:
"""
    evidence_optimisation_new(cmtlds_g, seq_obs, U, pars=model.pars(cmtlds_g); kwargs...)

Optimise the multi-task LDS `cmtlds_g` by per-sequence stochastic gradient steps, using
the Mocap MTLDS machinery (counterpart of `evidence_optimisation_async`).

For every sequence in `seq_obs` (all driven by the same input `U`):
1. `nsmps` latent proposals are importance-weighted with the no-gradient model;
2. `n_est` proposals are resampled from those weights, and duplicates are merged into
   integer weights;
3. the negative weighted log likelihood, the log-σ normalising term and the
   `prior_lsigma` penalty are back-propagated through `cmtlds_g`;
4. one optimiser step is taken on `pars`.

# Keywords
- `epochs`: number of passes over `seq_obs`.
- `nsmps`: number of proposals per sequence.
- `n_est`: number of resampled proposals per sequence.
- `prior_lsigma`: distribution whose `μ` and `σ` define the quadratic penalty on log σ.
- `opt`: Flux optimiser; it is returned so its state can be reused.
- `norm_clip_amt`, `acc_step`, `dropamt`: accepted for call compatibility, currently unused.

# Returns
`(history, opt, allpar)` where `history[i, :] = [logp, recon_loss] / N` for epoch `i`
and `allpar` is an empty `Any[]`.
"""
function evidence_optimisation_new(cmtlds_g::model.MTLDS_g, seq_obs, U, pars=model.pars(cmtlds_g); 
                               epochs=100, nsmps=1000, n_est=1, prior_lsigma=Normal(0, 100), norm_clip_amt=1000.,
                                opt=ADAM(1e-3), acc_step=false, dropamt=0)

    verbose = false
    # No-gradient view of the model, used only for sampling and for logging.
    # NOTE(review): this is created once, so it presumably shares parameter storage with
    # `cmtlds_g` — confirm against `model.make_nograd`.
    cmtlds = model.make_nograd(cmtlds_g)
    T = eltype(cmtlds)
    N = length(seq_obs)
    history = zeros(epochs, 2)
    allpar = Any[]

    model.zero_grad!(cmtlds_g)

    for i in 1:epochs
        # --- Backprop through posterior-integrated likelihood ---
        recon_loss = zero(T)
        logp = zero(T)

        for nn in 1:N
            # --- Generate approximate posterior samples ---
            w, logp_s, Zproposal = sample_posterior(cmtlds, seq_obs[nn], U, nsmps)
            logp += sum(logp_s)

            smps = AxUtil.Random.multinomial_indices_linear(n_est, w[:])
            # often duplicates (e.g. m_bprop of same sample), esp nr beginning. Aggregate to improve efficiency.
            smps, smp_wgt = countmap(smps) |> x-> (collect(keys(x)), collect(values(x)))
            zsmps_nn = Zproposal[:, smps]
            Zs_wgt  = T.(smp_wgt)
            # --- /end ---

            # neg. log likelihood / reconstruction
            llh, _state = p_log_llh(cmtlds_g, seq_obs[nn], U, zsmps_nn, Zs_wgt)
            verbose && display(llh)
            loss   = - sum(llh) /n_est  # *decrease* *negative* llh.
            # logging only, so the untracked values are sufficient here
            recon_loss += -loss.data - TS_LEN*(sum(cmtlds.logσ) + size(cmtlds,2)*log(2π)/2)

            # Likelihood normalising constant  (# not n*T*logσ as we have normalised by n already)
            # The logσ of the gradient-carrying model is used (as in the DHO version), so
            # this term contributes to the gradient with respect to logσ.
            loss += TS_LEN*sum(cmtlds_g.logσ)
            verbose && println(loss.data)

            # Log Normal prior on log sigma (again on the gradient-carrying logσ)
            loss += 0.5*sum(x->x^2, cmtlds_g.logσ .- prior_lsigma.μ) ./ prior_lsigma.σ^2
            verbose && println(loss.data)

            Tracker.back!(loss)

            # Regularise MLP weights
            # NOTE(review): added after `Tracker.back!(loss)` and not used afterwards, so this
            # term does not appear to contribute to the gradients — confirm intent.
            loss += 1e-3*sum(abs, cmtlds_g.nn.layers[1].W) / N   # careful of doing layer 3 (lr rescale)

            for p in pars
                Tracker.update!(opt, p, -Tracker.grad(p))
            end
        end
        # --- /end ---

        history[i,1:2] = [logp, recon_loss]/N
        if (i % 1 == 0) 
            print("Batch iter ", i, ", recon loss: ", sprintf1("%.3f", recon_loss/N))
            println(", prev logp(x): ", sprintf1("%.2f", logp/N))
            flush(stdout)
        end

    end

    return history, opt, allpar
end

In [349]:
"""
    sample_forward(mtlds, U, M, ϵ)

Return a vector of `M` model outputs: for `i in 1:M`, the LDS built from column `i` of
the latent matrix `ϵ` (with scale `mtlds.η_h`) is applied to the input `U`.
"""
function sample_forward(mtlds::model.MTLDS_ng{T,F}, U::AbstractArray{T}, M::Int, ϵ::AbstractMatrix{T}) where {T, F}
    function rollout(i)
        lds = model.make_lds(mtlds, view(ϵ, :, i), mtlds.η_h)
        return lds(U)
    end
    return [rollout(i) for i in 1:M]
end

"""
    sample_forward(mtlds, U, M)

Draw `M` latent samples with `AxUtil.Random.sobol_gaussian` (dimension taken from the
input width of the first network layer) and roll each one out.

Returns a tuple holding the `M` outputs followed by the latent matrix as last element.
"""
function sample_forward(mtlds::model.MTLDS_ng{T,F}, U::AbstractArray{T}, M::Int) where {T, F}
    latent_dim = size(mtlds.nn.layers[1].W, 2)
    ϵ = convert(Array{T}, AxUtil.Random.sobol_gaussian(M, latent_dim)')
    outputs = sample_forward(mtlds, U, M, ϵ)
    return (outputs..., ϵ)
end


"""
    _gauss_lognormconst(logσ::Vector, tt)

Log normalising constant of a diagonal Gaussian with log standard deviations `logσ`,
accumulated over `tt` time steps: `tt/2 * (length(logσ) * log(2π) + sum(2 logσ))`.
"""
@inline function _gauss_lognormconst(logσ::Vector, tt)
    const_term = 0.5*tt*length(logσ)*log(2π)
    scale_term = 0.5*tt*sum(2*logσ)
    return const_term + scale_term
end

#= Single-sequence version: because there is only one Y, the subtraction can be done in
   place on the sampled outputs. This is the variant called most often during training. =#
"""
    sample_posterior(mtlds, Y, U, M, ϵ)

Importance-weight the `M` latent samples in the columns of `ϵ` against the single
observed sequence `Y` with input `U`.

Returns `(W, logp)`: the self-normalised weights from `AxUtil.Math.softmax_lse!`
(computed on an `M × 1` column) and the log-evidence estimate
`logsumexp(loglik) - log(M)`.
"""
function sample_posterior(mtlds::model.MTLDS_ng{T,F}, Y::AbstractArray{T}, 
            U::AbstractArray{T}, M::Int, ϵ::AbstractArray{T}) where {T, F}

    residuals = sample_forward(mtlds, U, M, ϵ)   # predictions; turned into residuals below
    precision = 1 ./ exp.(2*mtlds.logσ)

    # Precision-weighted squared error of every sample
    Δnorm = Vector{T}(undef, M)
    for i in 1:M
        R = residuals[i]
        R .-= Y
        Δnorm[i] = dot(sum(x->x^2, R, dims=2), precision)
    end

    # Elementwise logpdf ==> single logpdf and importance weights
    el_lpdf = -0.5 * Δnorm .- _gauss_lognormconst(mtlds.logσ, size(Y,2))
    W, lgpdf = AxUtil.Math.softmax_lse!(reshape(el_lpdf, :, 1))   # softmax AND logsumexp
    logp = lgpdf[1] .- log(M)
    return W, logp
end

"""
    sample_posterior(mtlds, Y, U, M)

Single-sequence convenience method: draw `M` latent samples with
`AxUtil.Random.sobol_gaussian` and weight them against `Y`.

Returns `(W, logp, ϵ)`, i.e. the outputs of the five-argument method followed by the
latent samples used.
"""
function sample_posterior(mtlds::model.MTLDS_ng{T,F}, Y::AbstractArray{T}, 
            U::AbstractArray{T}, M::Int) where {T, F}
    latent_dim = size(mtlds.nn.layers[1].W, 2)
    ϵ = convert(Array{T}, AxUtil.Random.sobol_gaussian(M, latent_dim)')
    W, logp = sample_posterior(mtlds, Y, U, M, ϵ)
    return (W, logp, ϵ)
end


# Multi-sequence version (a *series of matrices* for Y), which cannot subtract in place.
"""
    sample_posterior(mtlds, Ys, Us, M)

Draw `M` latent samples shared by all sequences and weight them against every
`(Ys[n], Us[n])` pair.

Returns `(W, logp, ϵ)`, where `W` is the transpose of the weight matrix returned by the
five-argument multi-sequence method.
"""
function sample_posterior(mtlds::model.MTLDS_ng{T,F}, Ys::AbstractArray{MT}, 
            Us::AbstractArray{MT}, M::Int) where {T, F, MT <: Matrix}

    latent_dim = size(mtlds.nn.layers[1].W, 2)

    # common random variables for all sequences
    ϵ = convert(Array{T}, AxUtil.Random.sobol_gaussian(M, latent_dim)')
    weights, logp = sample_posterior(mtlds, Ys, Us, M, ϵ)

    return Matrix(weights'), logp, ϵ
end

"""
    sample_posterior(mtlds, Ys, Us, M, ϵ)

Weight the `M` latent samples in `ϵ` against every sequence pair `(Ys[n], Us[n])`.

Returns `(W, logp)`: `W` is `M × length(Ys)` (one column of weights per sequence) and
`logp` holds one log-evidence estimate per sequence.
"""
function sample_posterior(mtlds::model.MTLDS_ng{T,F}, Ys::AbstractArray{MT}, 
            Us::AbstractArray{MT}, M::Int, ϵ::AbstractArray{T}) where {T, F, MT <: Matrix}

    n_seq = length(Ys)
    weights = Matrix{T}(undef, n_seq, M)
    logp = Vector{T}(undef, n_seq)

    for s in 1:n_seq
        seq_weights, seq_logp = sample_posterior(mtlds, Ys[s], Us[s], M, ϵ)
        weights[s,:] = seq_weights[:]
        logp[s] = seq_logp
    end

    return Matrix(weights'), logp
end

"""
    get_log_px(model, Y, U, m)

Return only the log-evidence estimate (second output) of `sample_posterior` with `m`
samples.
"""
function get_log_px(model::model.MTLDS_ng{T,F}, Y::AbstractArray, U::AbstractArray, m::Int) where {T,F}
    _, logp = sample_posterior(model, Y, U, m)
    return logp
end

In [33]:
# Variance of the zero-mean Gaussian prior on the latent `Z` (same value as the earlier
# declaration in the DHO cell).
const prior_var = 1.0
"""
    p_log_llh(mtlds, Y, U, Z, wgt=ones(T, size(Z, 2)))

Calculate the (unnormalised) log likelihood of `Y` under the LDS models corresponding to
the latent samples `Z` (Vector or column Matrix) within `mtlds`. `wgt` holds the
importance weights of the columns of `Z` (typically integer valued under resampling).

Returns `(llh, state)`: `llh[i]` is the weighted log likelihood of sample `i` and `state`
is the `wgt`-weighted average of the final hidden states of all samples.
"""
function p_log_llh(mtlds::Union{model.MTLDS_g{T,F}, model.MTLDS_ng{T,F}}, 
        Y::AbstractMatrix{T}, U::AbstractMatrix{T}, Z::AbstractVecOrMat{T}, 
        wgt::AbstractVector{T}=ones(T, size(Z, 2))) where {T,F}
    n_smp = size(Z, 2)
    Ψ = mtlds.nn(Z)
    precision = 1 ./ exp.(2*mtlds.logσ)
    final_states = Matrix{T}(undef, size(mtlds,1), n_smp)

    llh = map(1:n_smp) do i
        lds = model._make_lds_psi(mtlds, Ψ[:,i], mtlds.η_h)
        X = model.state_rollout(lds, U)
        residual = Y - (lds.C * X + lds.D * U .+ lds.d)
        final_states[:,i] = Tracker.data(X)[:,end]   # last hidden state, without gradient
        sq_err = dot(sum(x->x^2, residual, dims=2), precision)
        - wgt[i] * sq_err/2
    end
    return llh, final_states * (wgt/sum(wgt))
end

"""
    p_log_prior(Z)

Unnormalised log density of the zero-mean Gaussian prior with variance `prior_var`.
Returns a single scalar: the squared entries of `Z` are summed over *all* elements, so
for a matrix `Z` the result is the total over every column.
"""
function p_log_prior(Z)
    scaled = (Z .* 1) / sqrt(prior_var)
    return -0.5*sum(scaled.^2)
end

"""
    p_log_posterior_unnorm(mtlds, Y, U, Z)

Unnormalised log posterior of every latent sample (column) in `Z` given the sequence
`Y` and input `U`: log likelihood plus log prior, one value per sample.

`p_log_llh` returns a tuple `(llh, state)`, so only its first element is used. The
prior is evaluated one column at a time. Because `p_log_prior` is defined twice in
this file with different return shapes (a per-column vector and a scalar total),
`sum(...)` over a single column reduces either form to the scalar for that sample.
"""
function p_log_posterior_unnorm(mtlds::Union{model.MTLDS_g{T,F}, model.MTLDS_ng{T,F}}, 
        Y::AbstractMatrix{T}, U::AbstractMatrix{T}, Z::AbstractVecOrMat{T}) where {T,F}
    f_llh, _ = p_log_llh(mtlds, Y, U, Z)
    f_prior = [sum(p_log_prior(view(Z, :, i))) for i in 1:size(Z, 2)]
    return f_llh .+ f_prior
end

"""
    p_log_posterior_unnorm_beta(mtlds, Y, U, Z, beta)

Return `(tempered, full)` for every latent sample (column) in `Z`: the unnormalised log
posterior with the likelihood scaled by `beta`, and the ordinary unnormalised log
posterior.

`p_log_llh` returns a tuple `(llh, state)`, so only its first element is used. The
prior is evaluated one column at a time. Because `p_log_prior` is defined twice in
this file with different return shapes (a per-column vector and a scalar total),
`sum(...)` over a single column reduces either form to the scalar for that sample.
"""
function p_log_posterior_unnorm_beta(mtlds::Union{model.MTLDS_g{T,F}, model.MTLDS_ng{T,F}}, 
        Y::AbstractMatrix{T}, U::AbstractMatrix{T}, Z::AbstractVecOrMat{T}, beta::T) where {T,F}
    f_llh, _ = p_log_llh(mtlds, Y, U, Z)
    f_prior = [sum(p_log_prior(view(Z, :, i))) for i in 1:size(Z, 2)]
    return f_prior .+ beta .* f_llh, f_llh .+ f_prior
end

In [508]:
cmtlds.a

In [498]:
model.Astable(cmtlds.a, 2)

In [535]:
DhoModel(cmtlds.nn, Flux.param(cmtlds.logσ))

## Comparison to dhoModel code

In [909]:
model.change_relative_lr!(cmtlds, 0.10f0)   # reduce sensitivity of chain params. 
# Set the matching scale on the tracked model in place (same steps as the earlier cell
# that precedes the optimisation run).
cmtlds_g.η_h .= 0.10f0

In [35]:
"""
    convert_mtlds_to_dhomodel(cmtlds::model.MTLDS_ng)

Build a `DhoModelUntracked` whose network reproduces the network of `cmtlds`, with the
11 outputs reordered so that outputs 6, 7, 10 and 11 come last.

Works on a copy of `cmtlds` whose relative learning-rate scale is reset to 1. The network
must have exactly three layers, the last being a `Flux.Diagonal`, and must produce 11
outputs for a 2-d input; otherwise an argument check fails.
"""
function convert_mtlds_to_dhomodel(cmtlds::model.MTLDS_ng)
    cmtlds_cp = copy(cmtlds)
    model.change_relative_lr!(cmtlds_cp, 1.0f0)
    nn = cmtlds_cp.nn
    @argcheck length(nn(randn(eltype(cmtlds), 2))) == 11
    # Output permutation: everything except (6, 7, 10, 11) first, then those four.
    rm_Bb = setdiff(1:11, [6,7,10,11])
    new_ixs = vcat(rm_Bb, 6,7,10,11)
    @argcheck length(nn.layers) == 3
    @argcheck nn.layers[3] isa Flux.Diagonal
    L2 = nn.layers[end-1]
    # Permute rows of the weights and the matching bias entries. The bias is indexed
    # with a single index so it keeps its vector shape, like α and β below.
    L2 = Dense(L2.W[new_ixs,:], L2.b[new_ixs], L2.σ)
    L3 = nn.layers[end]
    L3 = Flux.Diagonal(L3.α[new_ixs], L3.β[new_ixs])
    nn = Chain(nn.layers[1], L2, L3)
    # Wrap as a tracked model and strip the tracking again to obtain the untracked type.
    dho_model = DhoModel(nn, Flux.param(cmtlds_cp.logσ))
    dho_model_nog = make_untracked(dho_model)
    return dho_model_nog
end
#     

In [37]:
_eps = randn(2);

In [265]:
# LDS for the fixed latent draw `_eps`, and the equivalent DHO model, for comparing outputs.
clds = model.make_lds(cmtlds, f32(_eps), cmtlds.η_h);
dho_model_nog = convert_mtlds_to_dhomodel(cmtlds)
# Re-wrap every leaf as a tracked parameter to obtain a trainable DhoModel, then derive
# the untracked model from it again.
dho_model = DhoModel(mapleaves(Flux.param, dho_model_nog.nn), Flux.param(dho_model_nog.logσ))
dho_model_nog = make_untracked(dho_model)

In [109]:
plot(clds(Us32[1])' - dho_model_nog(_eps)'); gcf().set_size_inches(8,2.5)

In [293]:
Zproposals, w, logp = get_posterior_smp_matrix(dho_model_nog, 200, seq_noise[1:4]; tt=80)

In [231]:
mean(get_log_px(cmtlds, seq_noise32[1:4], Us32[1:4], 100))

In [473]:
w2, logp2 = sample_posterior(cmtlds, seq_noise32[1:4], Us32[1:4], 200, f32(Zproposals));

In [234]:
# DHO version vs new version. O(1e-6) error expected due to f32.
logp - logp2

In [235]:
w2' - w

In [1214]:
-0.5*sum(x->x^2,(vec(model.make_lds(cmtlds, f32(Zproposals[:,1]), 1.0f0)(Us32[1])) - vec(seq_noise32[1])) / 
    exp(cmtlds.logσ[1]))

In [None]:
let Δ=(dho_model_nog(Zproposals[:,1]) - seq_noise[1]') ./  exp(Tracker.data(dho_model.logσ[1]));
    -0.5*sum(Δ.^2, dims=2)[:]
end

In [89]:
p_log_llh(cmtlds, seq_noise32[1], Us32[1], f32(Zproposals)[:,1:3])

In [90]:
p_log_llh(dho_model_nog, seq_noise[1], Zproposals[:,1:3]; tt=80)

In [None]:
"""
    p_log_llh(model, Y, Z; tt=TS_LEN)

Redefinition of the DHO log likelihood without the size check, with the noise scale
taken from the untracked value of `model.logσ[1]`.

Returns one value per latent sample: `-0.5 * sum(((Ŷ - Y) / σ)^2)`.
"""
function p_log_llh(model::Union{DhoModel, DhoModelUntracked}, Y::Vector, Z::AbstractVecOrMat; tt=TS_LEN)
    Ŷ = model(Z, tt=tt)
    noise_std = exp(Tracker.data(model.logσ[1]))
    standardised = (Ŷ .- Y') ./ noise_std
    sq_err = sum(standardised.^2, dims=2)
    return -0.5*sq_err[:]
end

In [292]:
w

In [295]:
# Scatter the proposals with the importance weights of each of the four sequences passed
# as third argument: panel (i, j) shows row (i-1)*2 + j of `w`.
fig, axs = subplots(2,2,figsize=(5,5))
for i in 1:2, j in 1:2
    AxPlot.scatter_alpha(Zproposals[1,:], Zproposals[2,:], w[(i-1)*2+j,:], ax=axs[i,j])
end

In [237]:
# Same plot for the MTLDS weights `w2` (samples × sequences, hence the transpose):
# panel (i, j) shows the weights of sequence (i-1)*2 + j, matching the DHO plot above.
fig, axs = subplots(2,2,figsize=(5,5))
for i in 1:2, j in 1:2
    AxPlot.scatter_alpha(Zproposals[1,:], Zproposals[2,:], w2'[(i-1)*2+j,:], ax=axs[i,j])
end

In [None]:
# # ================== REDEFINE w/o /10 ============
# function B(ψ, d)
#     n_skew = Int(d*(d-1)/2)
#     ψ_b = ψ[(d+n_skew+1):(d+n_skew+d)]
#     return identity.(ψ_b)
# end

# # ================== ADDL ============
# function b(ψ, d)
#     n_skew = Int(d*(d-1)/2)
#     ψ_b = ψ[(3*d+n_skew+1):(3*d+n_skew+d)]
#     return reshape(ψ_b, d)
# end

# function d_offset(ψ, d)
#     n_skew = Int(d*(d-1)/2)
#     ψ_d = ψ[(4*d+n_skew+2):(4*d+n_skew+2)]
#     return ψ_d
# end

# function D(ψ, d)
#     n_skew = Int(d*(d-1)/2)
#     ψ_d = ψ[(4*d+n_skew+1):(4*d+n_skew+1)]
#     return ψ_d
# end
# # ======================================

# function _harm_osc2d_state(λ::Vector{T}; tt=TS_LEN)::Matrix{T} where T <: AbstractFloat
#     d = 2
#     h0 = B(λ, d) * T(1)
#     _A = A(λ, d)
#     _b = b(λ, d)   # addl
#     hidden = lds_internal_notrk{T}(h0 + _b, h0 + _b)   # addl
#     function iterate(state)::Vector{T}
#         state.x_prev = state.x
#         state.x = _A*state.x .+ _b ;    # addl
#         state.x_prev
#     end
    
#     lds_state = [iterate(hidden) for t in 1:tt]
#     return hcat(lds_state...)
# end

# function _model_forward(model::Union{DhoModel, DhoModelUntracked}, Z::AbstractMatrix{T}; tt=TS_LEN) where T <: AbstractFloat
#     d, M = size(Z, 1), size(Z, 2)   # explicit to ensure always gives 2nd dim (even if none := 1)
#     Λ = model.nn(Z)
#     Y = map(1:M) do m
#         x = _harm_osc2d_state(Λ[:,m]; tt=tt)
#         offset = d_offset(Λ[:,m], d)
#         ŷ = C(Λ[:,m], d) * x + D(Λ[:,m], d)*hcat(1f0, zeros(Float32, 1, 79)) .+ offset
#     end
#     return reduce(vcat, Y)
# end