# Training Reference Notebook

In [None]:
# using Revise
using LinearAlgebra, Random
using StatsBase, Statistics
using Distributions, MultivariateStats   # Categorical, P(P)CA
using Quaternions    # For manipulating 3D Geometry
using MeshCat        # For web visualisation / animation
using PyPlot         # Plotting
using AxUtil         # Cayley, skew matrices
using Flux, CuArrays # Optimisation
using DSP            # convolution / low-pass (MA) filter

# small utils libraries
using ProgressMeter, Formatting, ArgCheck, Dates
using BSON

In [None]:
DIR_MOCAP_MTDS = "." 

# Data loading and transformation utils
include(joinpath(DIR_MOCAP_MTDS, "io.jl"))

# MeshCat skeleton visualisation tools
include(joinpath(DIR_MOCAP_MTDS, "mocap_viz.jl"))

# Data scaling utils
include(joinpath(DIR_MOCAP_MTDS, "util.jl"))

# Models: LDS
include(joinpath(DIR_MOCAP_MTDS, "models.jl"))

# Table visualisation
include(joinpath(DIR_MOCAP_MTDS, "pretty.jl"))

In [None]:
############################################
##    CUSTOM WIDELY USED FUNCTIONS
"""
    zero_grad!(P)

Overwrite the `grad` buffer of every element of `P` with zeros, in place.

Each element must expose a `grad` field that supports broadcast assignment.
Returns `nothing`.
"""
function zero_grad!(P)
    foreach(P) do p
        p.grad .= 0
    end
end

# Model types accepted by the generic `mse` / `smse` methods in this cell
# (presumably the gradient-free variants, going by the `_ng` suffix — confirm in models.jl).
const NoGradModels = Union{model.MyLDS_ng, model.ORNN_ng}
# Memo table for dataset variances, keyed by object identity; filled by `_calc_var!`
# and read by the `Statistics.var` methods defined further down this cell.
const _var_cache = IdDict()

# Mean squared entry of the residual `Δ`, multiplied by `scale`
# (defaults to the number of rows of `Δ`).
function mse(Δ::AbstractArray, scale=size(Δ, 1))
    mean_sq = mean(δ -> δ^2, Δ)
    return mean_sq * scale
end

# Weighted MSE of model `m` over every batch yielded by `d`.
#
# The model state `m.h` is zeroed whenever a batch is flagged as starting a new
# sequence, and once more after the final batch so the model is left clean.
# Per-batch errors are combined using `mocaputil.weights(d, as_pct=true)`.
function mse(d::mocaputil.DataIterator, m::NoGradModels)
    batch_errs = map(d) do (y, u, new_state)
        if new_state
            m.h .= zeros(size(m, 1))
        end
        mse(m(u) - y)
    end
    m.h .= zeros(size(m, 1))
    batch_weights = mocaputil.weights(d, as_pct=true)
    return dot(batch_errs, batch_weights)
end


# Vector of datasets: wrap in a `DataIterator` (second argument 1000000) and defer
# to the iterator method.
function mse(Ds::Vector{D}, m::NoGradModels) where {D <: Dict}
    return mse(mocaputil.DataIterator(Ds, 1000000), m)
end

# Single dataset: inputs under `:U`, targets under `:Y`.
function mse(D::Dict, m::NoGradModels)
    return mse(m(D[:U]) - D[:Y])
end

# Tuple form: targets first, inputs second.
function mse(V::Tuple, m::NoGradModels)
    return mse(m(V[2]) - V[1])
end

# Calculate variance
# Compute the row-wise variance of all targets yielded by `d` (concatenated
# column-wise) and store it in `cache` under the key `d`. Returns the variance.
function _calc_var!(cache::IdDict, d::mocaputil.DataIterator)
    Y = reduce(hcat, [y for (y, u, h) in d])
    # Store in the cache that was passed in, not the global `_var_cache`.
    return cache[d] = var(Y, dims=2)
end

# Compute the row-wise variance of the `:Y` entries of all datasets in `d`
# (concatenated column-wise) and store it in `cache` under the key `d`.
# Returns the variance.
function _calc_var!(cache::IdDict, d::Vector{D}) where {D <: Dict}
    Y = reduce(hcat, [dd[:Y] for dd in d])
    # Store in the cache that was passed in, not the global `_var_cache`.
    return cache[d] = var(Y, dims=2)
end

# Memoised variance of a whole dataset collection: computed on first request via
# `_calc_var!`, then served from `_var_cache` (keyed by object identity).
function Statistics.var(d::Union{mocaputil.DataIterator, Vector{D}}) where {D <: Dict}
    if !haskey(_var_cache, d)
        _calc_var!(_var_cache, d)
    end
    return _var_cache[d]
end

# Row-wise variance of the `:Y` entry of a single dataset.
# NOTE(review): this extends `Statistics.var` on a Base type (type piracy).
function Statistics.var(d::Dict)
    return var(d[:Y], dims=2)
end

# Standardised MSE
# Residual form: MSE divided by the summed row-wise variance of the residual itself.
function smse(Δ::AbstractArray, scale=size(Δ, 1))
    total_var = sum(var(Δ, dims=2))
    return mse(Δ, scale) / total_var
end

# Data/model forms: MSE divided by the summed variance of the targets.
function smse(d::mocaputil.DataIterator, m::NoGradModels)
    return mse(d, m) / sum(var(d))
end
function smse(D::Dict, m::NoGradModels)
    return mse(m(D[:U]) - D[:Y]) / sum(var(D))
end
function smse(Ds::Vector{D}, m::NoGradModels) where {D <: Dict}
    return mse(mocaputil.DataIterator(Ds, 1000000), m) / sum(var(Ds))
end
function smse(D::Tuple, m::NoGradModels)
    return mse(D, m) / sum(var(D[1], dims=2))
end

# Root of the standardised MSE; arguments are forwarded to `smse` unchanged.
rsmse(args...) = smse(args...) |> sqrt

In [None]:
# Weighted MSE of a multi-task LDS over the batches of `d`, where column `ii` of `z`
# is the latent used to build the model for batch `ii`.
#
# `z` must have exactly one column per batch. The state `m.h` is zeroed at the
# start of each new sequence and again after the last batch.
function mse(d::mocaputil.DataIterator, m::model.MTLDS_ng, z::AbstractArray)
    @argcheck size(z, 2) == length(d)
    batch_errs = map(enumerate(d)) do (ii, (y, u, new_state))
        if new_state
            m.h .= zeros(size(m, 1))
        end
        task_lds = model.make_lds(m, z[:,ii], m.η_h)
        mse(task_lds(u) - y)
    end
    m.h .= zeros(size(m, 1))
    batch_weights = mocaputil.weights(d, as_pct=true)
    return dot(batch_errs, batch_weights)
end


# Weighted MSE of a multi-task ORNN over the batches of `d`. For batch `ii` the
# parameters are obtained by passing column `ii` of `z` through `nn` (tracking
# stripped with `Tracker.data`) and handing the result to `model.make_rnn_psi`.
#
# `z` must have exactly one column per batch. The state `m.h` is zeroed at the
# start of each new sequence and again after the last batch.
function mse(d::mocaputil.DataIterator, m::model.ORNN_ng, z::AbstractArray, nn::Chain)
    @argcheck size(z, 2) == length(d)
    batch_errs = map(enumerate(d)) do (ii, (y, u, new_state))
        if new_state
            m.h .= zeros(size(m, 1))
        end
        ψ = Tracker.data(nn(z[:,ii]))
        task_rnn = model.make_rnn_psi(m, ψ, 1f0)
        mse(task_rnn(u) - y)
    end
    m.h .= zeros(size(m, 1))
    batch_weights = mocaputil.weights(d, as_pct=true)
    return dot(batch_errs, batch_weights)
end

# Standardised MSE for the multi-task models: the matching `mse` divided by the
# summed (cached) variance of the targets in `d`.
function smse(d::mocaputil.DataIterator, m::model.MTLDS_ng, z::AbstractArray)
    return mse(d, m, z) / sum(var(d))
end
function smse(d::mocaputil.DataIterator, m::model.ORNN_ng, z::AbstractArray, nn::Chain)
    return mse(d, m, z, nn) / sum(var(d))
end

### Load in Data
See `2_Preprocess.ipynb`

**Note: for the files currently on disk**,
* `edin_Ys_30fps.bson` was created with `include_ftcontact=false, fps=30`,
* `edin_Xs_30fps.bson` was created with `include_ftcontact=true, include_ftmid=true, joint_pos=false, fps=fps, speed=false`.

In [None]:
# Task descriptors: the `:styles_lkp` entry of the BSON file "styles_lkp".
styles_lkp = getindex(BSON.load("styles_lkp"), :styles_lkp);

In [None]:
# Raw inputs (`:Xs`) and raw outputs (`:Ys`), one entry per sequence.
Usraw = getindex(BSON.load("edin_Xs_30fps.bson"), :Xs);
Ysraw = getindex(BSON.load("edin_Ys_30fps.bson"), :Ys);

In [None]:
# Fit one scaler per stream on all sequences stacked row-wise.
standardize_Y = fit(mocaputil.MyStandardScaler, reduce(vcat, Ysraw),  1)
standardize_U = fit(mocaputil.MyStandardScaler, reduce(vcat, Usraw),  1)

# Shift by one row so that the input at row t is paired with the output at row t+1,
# then apply the fitted scalers.
Ys = map(y -> mocaputil.scale_transform(standardize_Y, y[2:end, :]), Ysraw);    # (1-step ahead of u)
Us = map(u -> mocaputil.scale_transform(standardize_U, u[1:end-1, :]), Usraw);  # (1-step behind y)

# Guard: in the first sequence, no (input, output) feature pair may be perfectly correlated.
let c = cor(Usraw[1][1:end-1, :], Ysraw[1][2:end, :], dims=1)
    largest = maximum(abs.(c[.!isnan.(c)]))
    @assert !isapprox(largest, 1.0) "some input features perfectly correlated"
end

# to invert: `mocaputil.invert(standardize_Y, y)`

In [None]:
# SENSE CHECK
# Confirms that U and Y were built without bugs (in particular that the time indices
# align) by inspecting the input/output correlation matrix.
let c = cor(f64(reduce(vcat, Us)), f64(reduce(vcat, Ys)), dims=1)
    imshow(c, aspect="auto")
    nonan_c = c[.!isnan.(c)]
    peak = maximum(abs.(nonan_c))
    title(format("max (abs) corrcoeff: {:.8f}", peak))
    flush(stdout)
    # NOTE(review): the reshape assumes exactly two rows of `c` are entirely NaN — confirm.
    display(findmax(reshape(nonan_c, size(c, 1) - 2, size(c,2))))
    printfmtln("10th best result {:.5f}", sort(nonan_c, rev=true)[10])
end
colorbar()

In [None]:
# Bundle raw outputs, scaled outputs and scaled inputs (the scaled ones transposed
# into dense matrices) together with the style lookup.
expmtdata = mocapio.ExperimentData(
    Ysraw,
    map(y -> Matrix(y'), Ys),
    map(u -> Matrix(u'), Us),
    styles_lkp
);
# see ?mocapio.get_data

# MT-LDS (Hard-EM) experiment

#### Setup data

In [None]:
# Get training set for STL and pooled models.
# `style_ix` is the style index handed to `mocapio.get_data` in the next cell.
style_ix = 1
# `d` is passed as the third argument to `model.init_LDS_spectral` below
# (presumably the LDS state dimension — confirm in models.jl).
d = 10;

In [None]:
# Pooled train / validation / test sets for the chosen `style_ix` (see ?mocapio.get_data).
trainPool, validPool, testPool = mocapio.get_data(expmtdata, style_ix, :split, :pooled);

In [None]:
# construct batch iterator
# `batch_size` is also reused below to normalise the batch weights and in the save filename.
batch_size = 64
min_size = 50
trainIter = mocaputil.DataIterator(trainPool, batch_size, min_size=min_size);

In [None]:
# style segment lookups
style_names = ["angry", "childlike", "depressed", "neutral", "old", "proud", "sexy", "strutting"];
# `segment_lkp[i]` holds the consecutive batch indices belonging to the i-th
# single-style training set, built from the per-style batch counts.
segment_lkp = let
    n_batches = map(1:7) do i
        stl_train = mocapio.get_data(expmtdata, i, :train, :stl, split=[0.875,0.125])
        length(mocaputil.DataIterator(stl_train, batch_size, min_size=50))
    end
    stops = cumsum(n_batches)
    starts = vcat(0, stops[1:end-1]) .+ 1
    [collect(lo:hi) for (lo, hi) in zip(starts, stops)]
end;

#### Base model

In [None]:
# base model
# Spectral initialisation from the first 200 columns of the first pooled training
# set; `d` is the third argument (presumably the state dimension — confirm).
clds_orig = model.init_LDS_spectral(trainPool[1][:Y][:,1:200], trainPool[1][:U][:,1:200], d)  # whatever...
# Gradient-enabled version first, then a no-gradient version derived from it.
clds_g = model.make_grad(clds_orig)
clds   = model.make_nograd(clds_g);   # MUST DO THIS SECOND, (Flux.param takes copy)

#### Multi-task manifold

In [None]:
k = 2                 # dimension of manifold
d_nn = 200            # "complexity" of manifold
d_subspace = 20       # dim of subspace (⊆ parameter space) containing the manifold
d_out = size(clds,2);

In [None]:
# Total number of scalar parameters in the base LDS: the output width of the manifold net.
d_par = sum(length, model.get_pars(clds))
# Manifold network: latent (k) -> hidden (d_nn, tanh) -> subspace -> parameter vector.
# The final layer's Glorot-uniform initial weights are halved.
nn = Chain(
    Dense(k, d_nn, tanh),
    Dense(d_nn, d_subspace, identity),
    Dense(d_subspace, d_par, identity, initW = (dims...) -> 0.5f0 * Flux.glorot_uniform(dims...))
)
# Untracked copy of the network.
nn_ng = mapleaves(Tracker.data, nn)
# One k-dimensional latent per training batch, initialised near zero.
Zmap = Flux.param(0.01f0 * randn(Float32, k, length(trainIter)));

#### MT-LDS

In [None]:
# Zero the base LDS, then wrap it with the manifold network into a multi-task LDS
# (gradient-enabled first, then the no-gradient counterpart).
model.zero!(clds)
clogσ = zeros(Float32, size(clds, 2))
cmtlds_g = model.mtldsg_from_lds(clds, nn, clogσ, 0.1f0);
cmtlds = model.make_nograd(cmtlds_g);

#### Optimisation

In [None]:
# ADAM optimiser; the step size given here is overwritten via `opt.eta` in the training cell.
opt = ADAM(1e-4)
# Parameters being optimised: the manifold network `nn` and the per-batch latents `Zmap`.
pars = Flux.params(nn, Zmap);

In [None]:
n_epochs = 200
opt.eta = 1e-3  # ... 8e-5
shuffle_examples = true

nB = length(trainIter)
W = mocaputil.weights(trainIter; as_pct=false) ./ batch_size #|> gpu

# Per-batch objective trace, one slot per (epoch, batch index); NaN marks "not visited".
history = fill(NaN, n_epochs*nB)

for ee in 1:n_epochs
    # `ii` below is the dataset index of each batch: it selects the latent column of
    # `Zmap`, the weight in `W` and the slot in `history`, whatever the visiting order.
    if shuffle_examples
        mtl_ixs, trainData = mocaputil.indexed_shuffle(trainIter)
    else
        mtl_ixs, trainData = 1:length(trainIter), trainIter
    end
    for (ii, (Yb, Ub, h0)) in zip(mtl_ixs, trainData)
        h0 && model.zero_state!(cmtlds)

        # Build the task-specific LDS from this batch's latent and roll it out.
        ψ = cmtlds_g.nn(Zmap[:,ii])
        lds = model._make_lds_psi(cmtlds_g, ψ, cmtlds_g.η_h)
        X   = model.state_rollout(lds, Ub)
        Ŷ   = lds.C * X + lds.D * Ub .+ lds.d;
        cmtlds.h = Tracker.data(X)[:,end]   # truncated backprop

        obj = mean(x->x^2, Yb - Ŷ) * 8^2 * W[ii]

        # Prior penalty (weak - avoid collapse --> 0, ∵ Hard-EM indeterminacy)
        obj += 0.02*sum(Zmap[:,ii] .* Zmap[:,ii])

        Tracker.back!(obj)
        history[(ee-1)*nB + ii] = obj.data

        # Gradients accumulate across batches; parameters are only stepped when the
        # dataset index is a multiple of 34.
        # NOTE(review): with shuffling this is keyed on the dataset index, not on the
        # number of batches seen, so the accumulation length varies.
        if ii % 34 == 0
            # L1 regularisation of the manifold network. It is back-propagated as its
            # own term: adding it to `obj` after `back!(obj)` would leave the gradients
            # untouched. It is deliberately kept out of `history`.
            reg = sum(nn.layers) do layer
                1e-3*sum(abs, layer.W) + 1e-3*sum(abs, layer.b)
            end
            Tracker.back!(reg)

            for p in pars
                Tracker.update!(opt, p, Tracker.grad(p))
            end
        end
    end

    # Report the root of the mean recorded objective for this epoch.
    (ee % 1 == 0) && println(sqrt(mean(history[(1:nB) .+ nB*(ee-1)]))); flush(stdout)

end

#### Plot optimisation progress

In [None]:
# Moving average of the objective over a window of nB batches (one epoch), on a sqrt scale.
let window = Windows.rect(nB)
    plot(sqrt.(DSP.conv(history, window)[nB:end-nB+1] / nB))
end

#### Plot latent space

In [None]:
# Scatter the learned latents, one colour per style segment.
ax = gca()
for (i, ixs) in enumerate(segment_lkp[1:7])
    colour = ColorMap("tab10")(i-1)
    ax.scatter(Zmap.data[1, ixs], Zmap.data[2, ixs], color=colour, alpha=0.5)
end
legend(style_names[2:8])

#### Save model

In [None]:
# The `error` call deliberately aborts this cell so that a saved model is not
# overwritten by accident; comment it out to actually save.
error("safeguard")
# Filename fields, in order: `d`, day of month, zero-padded month, `batch_size`.
BSON.bson(format("cmtlds{:d}_pool_{:d}{:02d}_v{:d}.bson", 
        d, day(today()), month(today()), batch_size), m=cmtlds, Zmap=Zmap.data);

#### Global optimisation of latents

In [None]:
# Number of candidate latents to draw.
nsmp = 600
# Quasi-random Gaussian candidates with the empirical covariance of the learned
# latents. The result is bound to `_Zsmp`, which the following cells read.
# The lower factor is used because for Σ = L*L', `L*ε` has covariance Σ,
# whereas `U*ε` has covariance U*U'.
_Zsmp = cholesky(cov(Zmap.data')).L * f32(AxUtil.Random.sobol_gaussian(nsmp, 2)');

In [None]:
# populate error matrix with above samples
# One row per training batch, one column per latent sample. The width comes from
# `nsmp` so the matrix always matches the number of samples that were drawn.
res = ones(Float32, length(trainIter), nsmp)
for i in 1:nsmp
    # Model for the i-th candidate latent, evaluated on every batch.
    lds = model.make_lds(cmtlds, _Zsmp[:,i], cmtlds.η_h)
    for (n, (Yb, Ub, h0)) in enumerate(trainIter)
        h0 && model.zero_state!(lds)
        ŷ = lds(Ub)
        res[n,i] = mean(x->x^2, Yb - ŷ)
    end
end;

In [None]:
# find MAP of the implicit posterior over the sampled latents
pz = softmax(-32 * Matrix(res'))   # note that this is much less peaked than it should be=> mean not sum.
z_smpopt = copy(Zmap.data)
# For each batch, pick the candidate latent with the highest posterior weight.
for n in 1:length(trainIter)
    best = argmax(pz[:, n])
    z_smpopt[:, n] = _Zsmp[:, best]
end

In [None]:
# plot to compare with current position
ax = gca()
# ax.scatter(_Zsmp[:,1], _Zsmp[:,2], alpha=0.1)
for (i, ixs) in enumerate(segment_lkp[1:7])
    # small jitter so that batches sharing the same selected sample remain visible
    z = z_smpopt[:, ixs] .+ randn(Float32, 2, length(ixs))*0.005
    colour = ColorMap("tab10")(i-1)
    ax.scatter(z[1,:], z[2,:], color=colour, alpha=0.5)
end
legend(style_names[2:8])

In [None]:
# update latents
# The `error` call deliberately aborts this cell so that the learned latents are not
# overwritten by accident; comment it out to apply the update.
error("safeguard")
# Overwrite the latents with the selected samples plus small Gaussian jitter (scale 0.005).
Zmap.data .= z_smpopt .+ randn(Float32, 2, length(trainIter))*0.005;

#### Visualise fit (and MT variability) for a batch

In [None]:
dset_i = 100
n_draws = 3

trainIters = collect(trainIter);
_Yb, _Ub, _h = trainIters[dset_i]
# Two alternative ways of choosing latents; the second assignment overrides the first
# (the first is kept so the random stream is unchanged).
_eps = cholesky(cov(Zmap.data')).U * randn(Float32, 2, n_draws)
_eps = Zmap.data[:, rand(Categorical(fill(1/length(trainIter), length(trainIter))), n_draws)]
# The first draw is always the latent learned for the batch being displayed.
_eps[:,1] = Zmap.data[:,dset_i]
cldsY = map(1:n_draws) do j
    ψ = cmtlds.nn(_eps[:,j])
    task_lds = model._make_lds_psi(cmtlds, ψ, cmtlds.η_h)
    task_lds(_Ub)
end

# Output dimensions (1:20) .+ offset: target first, then each draw's prediction.
fig, axs = subplots(5,4,figsize=(10,10))
offset = 0
for i = 1:20
    panel = vec(axs)[i]
    panel.plot(_Yb[i+offset, :])
    for draw in cldsY
        panel.plot(draw[i+offset, :], alpha=0.4)
    end
end

# MT-ORNN (Hard-EM) experiment

#### Setup data

In [None]:
# Get training set for STL and pooled models.
# `style_ix` is the style index handed to `mocapio.get_data` in the next cell.
style_ix = 1
# State dimension of the ORNN; `d` is kept as an alias of `d_state`.
d = d_state = 100;

In [None]:
# Pooled train / validation / test sets for the chosen `style_ix` (see ?mocapio.get_data).
trainPool, validPool, testPool = mocapio.get_data(expmtdata, style_ix, :split, :pooled);

In [None]:
# construct batch iterator
# `batch_size` is also reused below to normalise the batch weights and in the save filename.
batch_size = 64
min_size = 50
trainIter = mocaputil.DataIterator(trainPool, batch_size, min_size=min_size);

In [None]:
# style segment lookups
style_names = ["angry", "childlike", "depressed", "neutral", "old", "proud", "sexy", "strutting"];
# `segment_lkp[i]` holds the consecutive batch indices belonging to the i-th
# single-style training set, built from the per-style batch counts.
segment_lkp = let
    n_batches = map(1:7) do i
        stl_train = mocapio.get_data(expmtdata, i, :train, :stl, split=[0.875,0.125])
        length(mocaputil.DataIterator(stl_train, batch_size, min_size=50))
    end
    stops = cumsum(n_batches)
    starts = vcat(0, stops[1:end-1]) .+ 1
    [collect(lo:hi) for (lo, hi) in zip(starts, stops)]
end;

#### Base model

In [None]:
# base *LDS*
# Training data for style index 4 (requested with `concat=true, simplify=true`),
# indexed below by `:Y` and `:U`.
train_neutral = mocapio.get_data(expmtdata, 4, :train, :stl, concat=true, simplify=true);
# Spectral LDS fit with `d_state` states; `t_ahead` is ceil(d_state/64).
clds_orig = model.init_LDS_spectral(train_neutral[:Y], train_neutral[:U], d_state, t_ahead=ceil(Int, d_state/64));

In [None]:
# init a:
# extract cθ from a spectral LDS fit for ORNN initialisation (see Henaff et al. for block diag init motivation)
lds_evs = eigvals(model.Astable(clds_orig));
# Real part of every second eigenvalue -> sqrt((1-ct)/(1+ct)). Stacking that row on a
# row of zeros and flattening column-major interleaves the values with zeros; the
# final element is dropped by `[1:end-1]`.
blkvals = vcat([sqrt((1-ct)/(1+ct)) for ct in real(lds_evs[2:2:end])]', 
                zeros(Float32, floor(Int, d_state/2))')[1:end-1]
# Place the values on the first sub-diagonal and extract the strictly-lower-triangular parameters.
a = AxUtil.Math.unmake_lt_strict(diagm(-1=>blkvals), d_state)
# Prepend d_state entries, each equal to atanh(0.9).
a = vcat(ones(Float32, d_state)*atanh(0.9f0), a);

In [None]:
# Dimensions are read from the spectral LDS; its state size must equal `d_state`.
_d_state, d_out, d_in = size(clds_orig)
@argcheck d_state == _d_state
_U = train_neutral[:U]
cN = size(_U, 2)

# construct init RNN
rnn = RNN(d_in, d_state, tanh)
rnn.cell.Wh.data .= model.Astable(a, d_state)

# initialise emissions
# Run an untracked copy of the RNN over the inputs column by column, then solve a
# Tikhonov-regularised least-squares problem for [C D d] against the targets.
x̂ = let _rnn = mapleaves(Tracker.data, rnn)
    reduce(hcat, [_rnn(_U[:, t]) for t in 1:cN])
end
CDd = model._tikhonov_mrdivide(train_neutral[:Y], [x̂; _U; ones(1, cN)], 1e-3);
C = f32(param(CDd[:, 1:d_state]))
D = f32(param(CDd[:, d_state+1:end-1]))
d_offset = f32(param(CDd[:, end]));

In [None]:
# initialise base model
# Built from the parameter vector `a`, copies of the RNN cell's input weights, bias and
# initial state, and copies of the emission parameters C, D, d_offset.
ornn_base = model.ORNN_g(param(a), copy(rnn.cell.Wi), copy(rnn.cell.b), copy(rnn.cell.h),
                    copy(C), copy(D), copy(d_offset), tanh, Chain());  # empty input NN (Chain())

#### Multi-task manifold

In [None]:
k = 2                 # dimension of manifold
d_nn = 200            # "complexity" of manifold
d_subspace = 30;       # dim of subspace (⊆ parameter space) containing the manifold

In [None]:
# Total number of scalar ORNN parameters (excluding the input network): the output
# width of the manifold net.
d_par = sum(length, model.pars_no_inpnn(ornn_base))
# Manifold network: latent (k) -> hidden (d_nn, tanh) -> subspace -> parameter vector.
# The final layer's Glorot-uniform initial weights are scaled by 0.05.
nn = Chain(
    Dense(k, d_nn, tanh),
    Dense(d_nn, d_subspace, identity),
    Dense(d_subspace, d_par, identity, initW = (dims...) -> 0.05f0 * Flux.glorot_uniform(dims...))
)
# Untracked copy of the network.
nn_ng = mapleaves(Tracker.data, nn)
# One k-dimensional latent per training batch, initialised near zero.
Zmap = Flux.param(0.01f0 * randn(Float32, k, length(trainIter)));

#### MT-ORNN
Note that the MTORNN object is not yet mature, so the ORNN components are manipulated directly in the cells below.

In [None]:
# Work on a copy so that the initialised base model `ornn_base` stays untouched.
ornn_optim = copy(ornn_base);   # copy to avoid over-writing initialisation
# No-gradient counterpart of the working copy.
ornn_optim_ng = model.make_nograd(ornn_optim);

#### Optimisation

In [None]:
# ADAM optimiser; the step size given here is overwritten via `opt.eta` in the training cell.
opt = ADAM(1e-4)
# Parameters being optimised: the manifold network `nn` and the per-batch latents `Zmap`.
pars = Flux.params(nn, Zmap);

In [None]:
n_epochs = 300
opt.eta = 1e-3 #0.5e-4
shuffle_examples = true

nB = length(trainIter)
W = mocaputil.weights(trainIter; as_pct=false) ./ batch_size

# Per-batch objective trace, one slot per (epoch, batch index); NaN marks "not visited".
history = fill(NaN, n_epochs*nB)

for ee in 1:n_epochs
    # Fresh recurrent container each epoch; its weights are overwritten per batch below.
    rnn = RNN(size(ornn_optim,3), size(ornn_optim,1), ornn_optim.σ)

    # `ii` below is the dataset index of each batch: it selects the latent column of
    # `Zmap`, the weight in `W` and the slot in `history`, whatever the visiting order.
    if shuffle_examples
        mtl_ixs, trainData = mocaputil.indexed_shuffle(trainIter)
    else
        mtl_ixs, trainData = 1:length(trainIter), trainIter
    end
    for (ii, (Yb, Ub, h0)) in zip(mtl_ixs, trainData)
        h0 && Flux.reset!(rnn)
        Tb = size(Yb, 2)      # not constant
        Zs_post = Zmap[:,ii]  #.+ convert(Array{Float32}, AxUtil.Random.sobol_gaussian(m_bprop,2)'*0.01)

        # Task-specific ORNN parameters from this batch's latent, loaded into `rnn`.
        c_ornn = model.make_rnn_psi(ornn_optim, nn(Zs_post), 1f0)

        model.build_rnn!(rnn, c_ornn)
        x̂ = reduce(hcat, [rnn(Ub[:,i]) for i in 1:Tb])  |> Tracker.collect
        #         ŷ = let m=ornn_optim; m.C*x̂ + m.D*Ub .+ m.d; end   # keep same C, D, d ∀ tasks
        ŷ = let m=c_ornn; m.C*x̂ + m.D*Ub .+ m.d; end                 # adapt C, D, d too.
        obj = mean(x->x^2, Yb - ŷ) * 8^2 * W[ii]

        # Prior penalty
        obj += 0.5*sum(Zs_post .* Zs_post)

        Tracker.back!(obj)
        history[(ee-1)*nB + ii] = obj.data

        # Gradients accumulate across batches; parameters are only stepped when the
        # dataset index is a multiple of 34.
        # NOTE(review): with shuffling this is keyed on the dataset index, not on the
        # number of batches seen, so the accumulation length varies.
        if ii % 34 == 0
            # L1 regularisation of the manifold network. It is back-propagated as its
            # own term: adding it to `obj` after `back!(obj)` would leave the gradients
            # untouched. It is deliberately kept out of `history`.
            reg = sum(nn.layers) do layer
                1e-3*sum(abs, layer.W) + 1e-3*sum(abs, layer.b)
            end
            Tracker.back!(reg)

            for p in pars
                Tracker.update!(opt, p, Tracker.grad(p))
            end
        end

        rnn.cell.h.data .= 0       # initial state is a param :/. Easier to reset here.
        Flux.truncate!(rnn);
    end
    # Report the root of the mean recorded objective for this epoch.
    (ee % 1 == 0) && println(sqrt(mean(history[(1:nB) .+ nB*(ee-1)]))); flush(stdout)

end

#### Plot optimisation progress

In [None]:
# Moving average of the objective over a window of nB batches (one epoch), on a sqrt scale.
let window = Windows.rect(nB)
    plot(sqrt.(DSP.conv(history, window)[nB:end-nB+1] / nB))
end

#### Plot latent space

In [None]:
# Scatter the learned latents, one colour per style segment.
ax = gca()
for (i, ixs) in enumerate(segment_lkp[1:7])
    colour = ColorMap("tab10")(i-1)
    ax.scatter(Zmap.data[1, ixs], Zmap.data[2, ixs], color=colour, alpha=0.5)
end
legend(style_names[2:8])

#### Save model

In [None]:
# The `error` call deliberately aborts this cell so that a saved model is not
# overwritten by accident; comment it out to actually save.
error("safeguard")
# Filename fields, in order: `d_state`, day of month, zero-padded month, `batch_size`.
BSON.bson(format("ornn{:d}_pool_{:d}{:02d}_v{:d}.bson", 
        d_state, day(today()), month(today()), batch_size), m=ornn_optim_ng, Zmap=Zmap.data);

#### Global optimisation of latents

In [None]:
# Number of candidate latents to draw.
nsmp = 300
# Gaussian candidates with the empirical covariance of the learned latents.
# The lower factor is used because for Σ = L*L', `L*ε` has covariance Σ,
# whereas `U*ε` has covariance U*U'.
_Zsmp = cholesky(cov(Zmap.data')).L * randn(Float32, 2, nsmp);

In [None]:
# populate error matrix with above samples
# One row per training batch, one column per latent sample.
res = ones(Float32, length(trainIter), nsmp)

ornn_optim_ng = model.make_nograd(ornn_optim);
rnn_ng = mapleaves(Tracker.data, RNN(d_in, d_state, ornn_optim.σ))
@time for s in 1:nsmp
    # ORNN for the s-th candidate latent, loaded into the untracked RNN container.
    c_ornn = model.make_rnn_psi(ornn_optim_ng, nn_ng(_Zsmp[:, s]), 1f0)
    model.build_rnn!(rnn_ng, c_ornn)
    for (n, (Yb, Ub, h0)) in enumerate(trainIter)
        h0 && Flux.reset!(rnn_ng)
        x̂ = reduce(hcat, [rnn_ng(Ub[:, t]) for t in 1:size(Yb, 2)])
        ŷ = c_ornn.C*x̂ + c_ornn.D*Ub .+ c_ornn.d
        res[n, s] = mean(x->x^2, Yb - ŷ)
    end
end

In [None]:
# find MAP of implicit posterior (SNIS)
pz = softmax(-((res') .^ 2))   # note that this is much less peaked than it should be=> mean not sum.
z_smpopt = copy(Zmap.data)
# For each batch, pick the candidate latent with the highest posterior weight.
for n in 1:length(trainIter)
    best = argmax(pz[:, n])
    z_smpopt[:, n] = _Zsmp[:, best]
end

In [None]:
# plot to compare with current position
ax = gca()
# ax.scatter(_Zsmp[:,1], _Zsmp[:,2], alpha=0.1)
for (i, ixs) in enumerate(segment_lkp[1:7])
    # small jitter so that batches sharing the same selected sample remain visible
    z = z_smpopt[:, ixs] .+ randn(Float32, 2, length(ixs))*0.005
    colour = ColorMap("tab10")(i-1)
    ax.scatter(z[1,:], z[2,:], color=colour, alpha=0.5)
end
legend(style_names[2:8])

In [None]:
# update latents
# The `error` call deliberately aborts this cell so that the learned latents are not
# overwritten by accident; comment it out to apply the update.
error("safeguard")
# Overwrite the latents with the selected samples plus small Gaussian jitter (scale 0.005).
Zmap.data .= z_smpopt .+ randn(Float32, 2, length(trainIter))*0.005;

#### Visualise fit (and MT variability) for a batch

In [None]:
dset_i = 400
n_draws = 3

trainIters = collect(trainIter);
_Yb, _Ub, _h = trainIters[dset_i]
_Tb = size(_Yb, 2)
# Two alternative ways of choosing latents; the second assignment overrides the first
# (the first is kept so the random stream is unchanged).
_eps = cholesky(cov(Zmap.data')).U * randn(Float32, 2, n_draws)
_eps = Zmap.data[:, rand(Categorical(fill(1/length(trainIter), length(trainIter))), n_draws)]
# The first draw is always the latent learned for the batch being displayed.
_eps[:,1] = Zmap.data[:,dset_i]
cldsY = map(1:n_draws) do j
    c_ornn = model.make_rnn_psi(ornn_optim_ng, nn_ng(_eps[:, j]), 1f0)
    model.build_rnn!(rnn_ng, c_ornn)
    x̂ = reduce(hcat, [rnn_ng(_Ub[:, t]) for t in 1:_Tb])
    c_ornn.C*x̂ + c_ornn.D*_Ub .+ c_ornn.d
end

# Output dimensions (1:20) .+ offset: target first, then each draw's prediction.
fig, axs = subplots(5,4,figsize=(10,10))
offset = 0
for i = 1:20
    panel = vec(axs)[i]
    panel.plot(_Yb[i+offset, :])
    for draw in cldsY
        panel.plot(draw[i+offset, :], alpha=0.4)
    end
end