# Begin to productionise experiments

* Write save/load function for MyLDS.
* Think carefully about experiments:
    * Train on 70%, validate 10%, test 20%. How?
        * Remember that inter-files/intra-files are non-stationary.
    * Write loops to extract data and be able to train LDS of various dimension with little user interaction.
* How to choose $k$? 
    * Well we could cross-validate, but I need more Azure cores to do this.
    * My belief is that even by $k=30$ we are not really overfitting, especially as resetting every 256.
    * Argue that the data exhibit $\approx 20$ d.o.f., and that we need to store the previous frame in the state to calculate velocity if we are even to beat the copy baseline (i.e. we need $\ge 20$ d.o.f. purely for this). In addition, linear dynamics from one frame to the next may need more dimensions than PCA alone suggests. I think 30 is a good place to start.
* Build Docker container.
* Work out how to mount filesystem on Azure and/or scp out.
* Request more Azure compute power.


### Also

* Remember that for MTL, we will allow each segment (e.g. 256 length) to be modelled ~ independently.
* Need to optimise nograd version of LDS for sampling. My suspicion is that it is not type stable (although it's possible that the matmuls are now dominating a lot more).

In [1]:
# using Revise
using LinearAlgebra, Random
using StatsBase, Statistics
using Distributions, MultivariateStats   # Categorical, P(P)CA
using Quaternions    # For manipulating 3D Geometry
using MeshCat        # For web visualisation / animation
using PyPlot         # Plotting
using AxUtil         # Cayley, skew matrices
using Flux           # Optimisation
using DSP            # convolution / low-pass (MA) filter

# small utils libraries
using ProgressMeter, Formatting, ArgCheck
using BSON

In [2]:
DIR_MOCAP_MTDS = "." 

# Data loading and transformation utils
include(joinpath(DIR_MOCAP_MTDS, "io.jl"))

# MeshCat skeleton visualisation tools
include(joinpath(DIR_MOCAP_MTDS, "mocap_viz.jl"))

# Data scaling utils
include(joinpath(DIR_MOCAP_MTDS, "util.jl"))
import .mocaputil: MyStandardScaler, scale_transform, invert
import .mocaputil: OutputDifferencer, difference_transform, fit_transform
import .mocaputil: no_pos, no_poscp

# Models: LDS
include(joinpath(DIR_MOCAP_MTDS, "models.jl"))
import .model: Astable

# Table visualisation
include(joinpath(DIR_MOCAP_MTDS, "pretty.jl"))

In [3]:
############################################
##    CUSTOM WIDELY USED FUNCTIONS
"""
    zero_grad!(P)

Overwrite, in place, the `grad` array of every element of the iterable `P`
with zeros. Returns `nothing`.
"""
function zero_grad!(P)
    foreach(P) do p
        p.grad .= 0
    end
    return nothing
end

"""
    rmse(Δ, scale=size(Δ, 1))

Root of the mean squared entry of the residual array `Δ`, multiplied by
`scale` (by default the number of rows of `Δ`).
"""
function rmse(Δ::AbstractArray, scale=size(Δ, 1))
    mean_sq = mean(δ -> δ^2, Δ)
    return sqrt(mean_sq) * scale
end

# Weighted RMSE of model `m` over every batch yielded by the iterator `d`.
# Each batch is a `(y, u, new_state)` triple; when `new_state` is true the model
# state `m.h` is zeroed before predicting. The per-batch errors are combined via
# a dot product with `mocaputil.weights(d, as_pct=true)`. `m.h` is left at zero.
function rmse(d::mocaputil.DataIterator, m::model.MyLDS_ng)
    batch_errs = map(d) do (y, u, new_state)
        if new_state
            m.h .= zeros(size(m, 1))
        end
        rmse(m(u) - y)
    end
    # do not leak the state reached on the last batch to later callers
    m.h .= zeros(size(m, 1))
    w = mocaputil.weights(d, as_pct=true)
    return dot(batch_errs, w)
end


# Convenience method: wrap a vector of data dictionaries in a `DataIterator`
# and evaluate the iterator-based `rmse` on it. The second constructor argument
# is 1000000 -- presumably a batch length large enough that no sequence is
# split; confirm against `mocaputil.DataIterator`.
function rmse(Ds::Vector{D}, m::model.MyLDS_ng) where {D <: Dict}
    iter = mocaputil.DataIterator(Ds, 1000000)
    return rmse(iter, m)
end
############################################

### Load in Data
See `2_Preprocess.ipynb`

**Note that in the current harddisk state, `edin_Ys.bson` was created with `include_ftcontact=false`**


In [4]:
# Directory holding the mocap style-transfer files, and the path of every entry in it.
database = "../data/mocap/edin-style-transfer/"
files_edin = [joinpath(database, f) for f in readdir(database)];
# Style label of each file: the run of lowercase letters that precedes the first `_`
# of the file name (the regex capture group). The pattern hard-codes a path of the form
# `../<dir>/<dir>/<dir>/<name>`; a path that does not match makes `match` return
# `nothing`, and `x[1]` then fails.
style_name_edin = [x[1] for x in match.(r"\.\./[a-z\-]+/[a-z\-]+/[a-z\-]+/([a-z]+)_.*", files_edin)];
# Distinct style labels, in order of first appearance.
styles = unique(style_name_edin)
# styles_lkp[j] = indices (into files_edin) of the files whose label is styles[j].
styles_lkp = [findall(s .== style_name_edin) for s in styles];

In [5]:
# Raw inputs (Xs) and outputs (Ys), one matrix per file, loaded from BSON.
Usraw = BSON.load("edin_Xs_30fps.bson")[:Xs];
Ysraw = BSON.load("edin_Ys_30fps.bson")[:Ys];

# Scalers are fitted on all files stacked row-wise.
standardize_Y = fit(MyStandardScaler, reduce(vcat, Ysraw),  1)
standardize_U = fit(MyStandardScaler, reduce(vcat, Usraw),  1)

# Shift by one row so that row t of U is paired with row t+1 of the raw Y.
Ys = [scale_transform(standardize_Y, y[2:end, :] ) for y in Ysraw];  # (1-step ahead of u)
Us = [scale_transform(standardize_U, u[1:end-1,:]) for u in Usraw];  # (1-step behind y)

# Leakage check on the first file: after the shift, no input column may be
# (numerically) perfectly correlated with an output column. NaN correlations
# are ignored.
@assert (let c=cor(Usraw[1][1:end-1, :], Ysraw[1][2:end, :], dims=1); 
        !isapprox(maximum(abs.(c[.!isnan.(c)])), 1.0); end) "some input features are (almost) perfectly correlated with an output feature"

In [6]:
# Undo the output transforms for file `i`: first invert the standardisation
# (`standardize_Y`), then invert `dtforms[i]`.
# NOTE(review): `dtforms` is not defined in the visible part of this notebook --
# confirm it exists before calling.
function invert_output_tform(y, i)
    unscaled = invert(standardize_Y, y)
    return invert(dtforms[i], unscaled)
end

In [7]:
# standardize_Y = fit(model.MyStandardScaler, reduce(vcat, Ysraw),  1)
# standardize_U = fit(model.MyStandardScaler, reduce(vcat, Usraw),  1)

In [8]:
# SENSE CHECK
# check that no bugs in constructing U, Y (i.e. esp that t's align and can predict U --> Y)
# c is the cross-correlation matrix between the columns of the stacked (scaled)
# inputs and the columns of the stacked (scaled) outputs.
let c=cor(reduce(vcat, Us) |>f64, reduce(vcat, Ys) |> f64, dims=1)
    imshow(c, aspect="auto")
    # drop undefined correlations before taking maxima
    nonan_c = c[.!isnan.(c)]
    title(format("max (abs) corrcoeff: {:.8f}", maximum(abs.(nonan_c))))
    flush(stdout)
    # NOTE(review): this reshape only succeeds if the NaN entries remove exactly two
    # full rows of c (presumably two constant input columns) -- confirm.
    display(findmax(reshape(nonan_c, size(c, 1) - 2, size(c,2))))
    printfmtln("10th best result {:.5f}", reverse(sort(nonan_c))[10]) 
end

## Useful functions

In [9]:
# Bundle the data for the experiments: the raw outputs, the scaled outputs and
# inputs (each transposed and materialised as a `Matrix`), and the per-style
# file index lookup.
expmtdata = mocapio.ExperimentData(Ysraw, [Matrix(y') for y in Ys], 
    [Matrix(u') for u in Us], styles_lkp);

In [200]:
?mocapio.get_data

In [10]:
"""
    findin(x, q)

Return, in ascending order, the positions within `sort(unique(q))` of those
elements of `x` that also occur in `q`. Duplicates in `x` are reported once.
When `q` is already sorted and duplicate-free (as `lkah` is), the result is a
plain index vector into `q`.

Returns an empty `Vector{Int}` when nothing matches or when either input is empty.
"""
function findin(x, q)
    xs, qs = sort(unique(x)), sort(unique(q))
    out = Int[]
    for xx in xs
        # Binary search. The previous linear scan stopped as soon as it reached the
        # last element of `q`, so later elements of `x` were never compared with it.
        j = searchsortedfirst(qs, xx)
        (j <= length(qs) && qs[j] == xx) && push!(out, j)
    end
    return out
end

"""
results struct.
Initialise with `r=init_results()`. Obtain indices via e.g. `validation(r, 2)` will give 
the 2-step error in the validation set. To set indices use e.g. `set_validation(r, val, 2)`.

`data` has one row per data split (1 = training, 2 = validation, 3 = testing) and
one column per look-ahead horizon in `lkah`. The getters return the selected
entries as a row (adjoint of a vector); the setters return `val`.
"""
mutable struct _results{T}
    data::Matrix{T}
end

# Look-ahead horizons (in frames) at which prediction error is recorded.
lkah = [1,2,4,8,14,25]

# All-NaN table sized from `lkah` rather than a hard-coded column count.
init_results() = _results(fill(NaN, 3, length(lkah)))

# One empty results table per model name.
results_struct(model_names) = Dict(m=>init_results() for m in model_names)

# Getters: `i` is the horizon(s) wanted, `ix` the matching column(s) of `data`.
training(s::_results, i=lkah, ix=findin(i,lkah)) = s.data[1,ix]'
validation(s::_results, i=lkah, ix=findin(i,lkah)) = s.data[2,ix]'
testing(s::_results, i=lkah, ix=findin(i,lkah)) = s.data[3,ix]'

# Setters: write `val` into the columns selected by `ix`.
set_training(s::_results, val, i=lkah, ix=findin(i,lkah)) = (s.data[1,ix] = val)
set_validation(s::_results, val, i=lkah, ix=findin(i,lkah)) = (s.data[2,ix] = val)
set_testing(s::_results, val, i=lkah, ix=findin(i,lkah)) = (s.data[3,ix] = val)



# Apply the accessor `f` (e.g. `training`, `testing`) to the `model` entry of every
# element of `v`, stack the results row-wise, and reduce them with `g`. Keyword
# arguments are forwarded to `g`. Every element of `v` is used, so the function no
# longer assumes exactly 8 entries.
aggregate_generic(v::Array, f::Function, model::Symbol, g::Function; args...) = 
    g(reduce(vcat, [f(r[model]) for r in v]); args...)

# Column-wise mean / standard deviation across the elements of `v`.
aggregate_mean(v::Array, f::Function, model::Symbol) = aggregate_generic(v, f, model, mean; dims=1)
aggregate_std(v::Array, f::Function, model::Symbol) = aggregate_generic(v, f, model, std; dims=1)

# Training Loop

In [11]:
# Names of every model variant whose errors are recorded.
model_types = [:copy, :LR, :LDS2init, :LDS20init, :LDS40init, :LDS20smp, :LDS40smp]
# One results dictionary per style index (the loops below run `style_ix = 1:8`),
# kept separately for the STL models and the pooled models.
results_stl = [results_struct(model_types) for i in 1:8]
results_pool = [results_struct(model_types) for i in 1:8];

In [12]:
# Fetch the train/validation/test split for style index 1.
# NOTE(review): the result is bound to the `*STL` names but the task argument is
# `:pool`, while every other call in this notebook passes `:stl` or `:pooled` --
# confirm that `:pool` is accepted by `mocapio.get_data` and is what was intended.
trainSTL, validSTL, testSTL = mocapio.get_data(expmtdata, 1, :split, :pool; concat=true,
simplify=true)

### Copy prev frame model

In [13]:
# MODEL TYPE: COPY
# Baseline that predicts frame t+k to be identical to frame t. Errors are stored
# for every horizon in `lkah`, per style, for both the STL and pooled data splits.
for style_ix = 1:8
    # Get training set for STL and pooled models.
    trainSTL, validSTL, testSTL = mocapio.get_data(expmtdata, style_ix, :split, :stl;
                                    concat=true, simplify=true)
    trainPool, validPool, testPool = mocapio.get_data(expmtdata, style_ix, :split, 
                                        :pooled, concat=true, simplify=true)
    # NOTE(review): cN_STL is not used anywhere in this loop body.
    cN_STL = size(trainSTL[:Y], 2);
    
    # Create "Copy" model
    # The "prediction" is the sequence with its last k_step columns dropped; it is
    # compared below against the same sequence with its first k columns dropped.
    model_copy(test_set, k_step=1) = test_set[:,1:end-k_step]
    
    for (settr, data) in zip([set_training, set_validation, set_testing], [trainSTL, validSTL, testSTL])
        results = [rmse(model_copy(data[:Y], k) - data[:Y][:, 1+k:end]) for k in lkah]
        settr(results_stl[style_ix][:copy], results);
    end
    
    # Pooled data: train/validation via the loop, test handled separately below
    # with the same computation.
    for (settr, data) in zip([set_training, set_validation], [trainPool, validPool])
        results = [rmse(model_copy(data[:Y], k) - data[:Y][:, 1+k:end]) for k in lkah]
        settr(results_pool[style_ix][:copy], results);
    end
    results = [rmse(model_copy(testPool[:Y], k) - testPool[:Y][:, 1+k:end]) for k in lkah]
    set_testing(results_pool[style_ix][:copy], results);
end

### Linear Regression

In [193]:
# MODEL TYPE: LINEAR REGRESSION
# Static linear map from inputs to outputs, evaluated through the LDS k-step
# prediction routine so that it is scored in the same way as the LDS models.

@showprogress for style_ix = 1:8
    # Get training set for STL and pooled models.
    trainSTL, validSTL, testSTL = mocapio.get_data(expmtdata, style_ix, :split, :stl,
                                    concat=true, simplify=true)
    trainPool, validPool, testPool = mocapio.get_data(expmtdata, style_ix, :split, :pooled,
                                        concat=true, simplify=true)
    # NOTE(review): cN_STL is not used anywhere in this loop body.
    cN_STL = size(trainSTL[:Y], 2);
    
    # Create model
    function model_LR(train_set) 
        # Least-squares regression of Y on [U; 1] (right division); the last column
        # of Dd is the intercept.
        Dd = train_set[:Y] / [train_set[:U]; ones(1, size(train_set[:U], 2))]
        # Build an LDS object with a 1-d state to act as a container; note it is
        # always initialised from `validSTL` (captured from the enclosing loop),
        # whatever `train_set` is. Its C, B, b are then zeroed and D, d overwritten
        # with the regression solution.
        clds = model.init_LDS_spectral(validSTL[:Y], validSTL[:U], 1, max_iter=1);
        clds.C .= 0
        clds.B .= 0
        clds.b .= 0
        clds.D .= Dd[:,1:end-1]
        clds.d .= Dd[:,end]
        # Returned closure: k-step prediction of this model on a data dictionary.
        lr(test, k) =  model.kstep_predict(clds, test[:U], test[:Y], standardize_Y, standardize_U, 10, k)
        return lr
    end
    
    # Train model
    cmodel = model_LR(trainSTL)
    
    # Evaluate on Train/Validation/Test
    # Predictions are compared against the targets from column k onwards.
    for (settr, data) in zip([set_training, set_validation, set_testing], [trainSTL, validSTL, testSTL])
        results = [rmse(cmodel(data, k) - data[:Y][:, k:end]) for k in lkah]
        settr(results_stl[style_ix][:LR], results);
    end
    
    ################### POOLED DATA ##############################

    # Train model
    cmodel = model_LR(trainPool)
    
    # Evaluate on Train/Validation set
    for (settr, data) in zip([set_training, set_validation], [trainPool, validPool])
        results = [rmse(cmodel(data, k) - data[:Y][:, k:end]) for k in lkah]
        settr(results_pool[style_ix][:LR], results);
    end
    
    # Evaluate on Test Set
    results = [rmse(cmodel(testPool, k) - testPool[:Y][:, k:end]) for k in lkah]
    set_testing(results_pool[style_ix][:LR], results);
end

In [196]:
# Mean (over styles) train and test error of the pooled Copy and LR models, one
# column per look-ahead horizon.
# NOTE(review): the `/8` is a fixed rescaling of the stored errors -- presumably
# sqrt(64) relative to the default `scale` used by `rmse`; confirm.
pretty.table([aggregate_mean(results_pool, training, m)/8 for m in [:copy, :LR]],
             [aggregate_mean(results_pool, testing, m)/8 for m in [:copy, :LR]]; 
    header_row=lkah, header_col=["Copy", "Linear Reg"], title=["Train", "Test"], dp=2, header="Average")

In [198]:
# Standard deviation (over styles) of the pooled Copy and LR errors, with the same
# layout and `/8` rescaling as the table of averages.
pretty.table([aggregate_std(results_pool, training, m)/8 for m in [:copy, :LR]],
             [aggregate_std(results_pool, testing, m)/8 for m in [:copy, :LR]]; 
        header_row=lkah, header_col=["Copy", "Linear Reg"], title=["Train", "Test"], dp=2, header="Standard Dev.")

## Coord Desc LDS

In [33]:
# LDS with a 20-dimensional state, obtained from `model.init_LDS_spectral`, scored
# per style on the STL split and on the pooled split.
model_name = :LDS20init

@showprogress for style_ix = 1:8
    # Get training set for STL and pooled models.
    trainSTL, validSTL, testSTL = mocapio.get_data(expmtdata, style_ix, :split, :stl,
                                    concat=true, simplify=true)
    trainPool, validPool, testPool = mocapio.get_data(expmtdata, style_ix, :split, :pooled,
                                    concat=true, simplify=true)
    # NOTE(review): cN_STL is not used anywhere in this loop body.
    cN_STL = size(trainSTL[:Y], 2);
    
    # Train model
    clds = model.init_LDS_spectral(trainSTL[:Y], trainSTL[:U], 20);
    # NOTE(review): `cmodel` is defined twice in this loop body with the same
    # signature and body; both read whatever `clds` holds when called -- confirm the
    # STL evaluation below really uses the STL model.
    cmodel(test, k) = model.kstep_predict(clds, test[:U], test[:Y], standardize_Y, standardize_U, 10, k)
    
    # Evaluate on Train/Validation/Test
    for (settr, data) in zip([set_training, set_validation, set_testing], [trainSTL, validSTL, testSTL])
        results = [rmse(cmodel(data, k) - data[:Y][:, k:end]) for k in lkah]
        settr(results_stl[style_ix][model_name], results);
    end
    
    ################### POOLED DATA ##############################

    # Train model
    clds = model.init_LDS_spectral(trainPool[:Y], trainPool[:U], 20);
    cmodel(test, k) = model.kstep_predict(clds, test[:U], test[:Y], standardize_Y, standardize_U, 10, k)
    
    # Evaluate on Train/Validation/Test set
    for (settr, data) in zip([set_training, set_validation, set_testing], [trainPool, validPool, testPool])
        results = [rmse(cmodel(data, k) - data[:Y][:, k:end]) for k in lkah]
        settr(results_pool[style_ix][model_name], results);
    end
end

In [34]:
# Same experiment as the :LDS20init cell, with a 40-dimensional state; the only
# differences are `model_name` and the state dimension passed to
# `model.init_LDS_spectral`.
model_name = :LDS40init

@showprogress for style_ix = 1:8
    # Get training set for STL and pooled models.
    trainSTL, validSTL, testSTL = mocapio.get_data(expmtdata, style_ix, :split, :stl,
                                    concat=true, simplify=true)
    trainPool, validPool, testPool = mocapio.get_data(expmtdata, style_ix, :split, :pooled,
                                    concat=true, simplify=true)
    # NOTE(review): cN_STL is not used anywhere in this loop body.
    cN_STL = size(trainSTL[:Y], 2);
    
    # Train model
    clds = model.init_LDS_spectral(trainSTL[:Y], trainSTL[:U], 40);
    # NOTE(review): `cmodel` is defined twice in this loop body with the same
    # signature and body; both read whatever `clds` holds when called.
    cmodel(test, k) = model.kstep_predict(clds, test[:U], test[:Y], standardize_Y, standardize_U, 10, k)
    
    # Evaluate on Train/Validation/Test
    for (settr, data) in zip([set_training, set_validation, set_testing], [trainSTL, validSTL, testSTL])
        results = [rmse(cmodel(data, k) - data[:Y][:, k:end]) for k in lkah]
        settr(results_stl[style_ix][model_name], results);
    end
    
    ################### POOLED DATA ##############################

    # Train model
    clds = model.init_LDS_spectral(trainPool[:Y], trainPool[:U], 40);
    cmodel(test, k) = model.kstep_predict(clds, test[:U], test[:Y], standardize_Y, standardize_U, 10, k)
    
    # Evaluate on Train/Validation/Test set
    for (settr, data) in zip([set_training, set_validation, set_testing], [trainPool, validPool, testPool])
        results = [rmse(cmodel(data, k) - data[:Y][:, k:end]) for k in lkah]
        settr(results_pool[style_ix][model_name], results);
    end
    
end

In [205]:
# Mean (over styles) train and test error of the STL models, one row per model
# and one column per look-ahead horizon; same `/8` rescaling as the earlier tables.
pretty.table([aggregate_mean(results_stl, training, m)/8 for m in [:copy, :LR, :LDS20init, :LDS40init]],
             [aggregate_mean(results_stl, testing, m)/8 for m in [:copy, :LR, :LDS20init, :LDS40init]]; 
    header_row=lkah, header_col=["Copy", "Linear Reg", "LDS20init", "LDS40init"], title=["Train", "Test"], 
    dp=2, header="Average -- STL Models")

In [206]:
# Mean (over styles) train and test error of the pooled models, laid out like the
# STL table.
pretty.table([aggregate_mean(results_pool, training, m)/8 for m in [:copy, :LR, :LDS20init, :LDS40init]],
             [aggregate_mean(results_pool, testing, m)/8 for m in [:copy, :LR, :LDS20init, :LDS40init]]; 
    header_row=lkah, header_col=["Copy", "Linear Reg", "LDS20init", "LDS40init"], title=["Train", "Test"], 
    dp=2, header="Average -- Pooled Models")

### Bits and pieces

In [604]:
# model_types = [:LDS40init_t2]
# for i in 1:8
#     for m in model_types
#         results_stl[i][m] = init_results()
#         results_pool[i][m] = init_results()
#     end
# end

## Training an STL model initialised from Pooled

* Looking at `angry` style (`ix=1`).
* L2 error of STL is 0.84 in training, 1.65 in test. Note however the end of the seq is harder.
* L2 test error of Pooled model on this style is 1.71. Clearly this is worse than STL, especially since the test result of STL is only of the test part: the hardest bit.
* So the gap is something like 1.71 (pooled) vs ~0.80 (STL). (not apples to apples due to test set of STL.)

From initialisation using the pooled model, **can SGD get close to the STL results, when training on the STL training set?**.

==> *The answer is that it can do **fairly** well, getting to ~ 1.00 on the STL test set. If the exact regression step can be used (note it cannot obviously be done for MTL), then we get close to 0.90.*

In [728]:
vcat(training(results_pool[1][:LDS20init]), testing(results_pool[1][:LDS20init]))/8

In [842]:
vcat(training(results_stl[1][:LDS20init]), validation(results_stl[1][:LDS20init]))/8

In [14]:
# Get training set for STL and pooled models.
# Pooled split for style index 1, and an LDS with a k = 20 dimensional state
# initialised from the pooled training data; used below as the starting point for
# gradient-based fine-tuning.
style_ix = 1
k = 20

trainPool, validPool, testPool = mocapio.get_data(expmtdata, style_ix, :split, :pooled, concat=true, simplify=true)

# Train model
clds_orig = model.init_LDS_spectral(trainPool[:Y], trainPool[:U], k);

In [16]:
trainSTL, validSTL, testSTL = mocapio.get_data(expmtdata, style_ix, :split, :stl);

In [16]:
# Gradient-tracking copy of the pooled LDS, plus a no-grad counterpart derived
# from it (the order matters, see the inline remark).
# clds_g = model.make_grad(model.init_LDS_spectral(cYs, cUT, k))
clds_g = model.make_grad(clds_orig)
clds   = model.make_nograd(clds_g)   # MUST DO THIS SECOND, (Flux.param takes copy)

# clds.a .= vcat(ones(k)*2.3, zeros(Int(k*(k-1)/2)))

# Two optimisers so that the two parameter groups below can use different step sizes.
opt = ADAM(1e-4)
opt_hidden = ADAM(0.7e-5)

# Parameter groups: `a`, `B`, `b` are updated with `opt_hidden`, and `C`, `D`, `d`
# with `opt`, in the training loop.
ps_hidden = Flux.params(clds_g.a, clds_g.B, clds_g.b)
ps_observ = Flux.params(clds_g.C, clds_g.D, clds_g.d)

In [432]:
testing(results_pool[1][:LDS20init])

In [1182]:
model.fit_optimal_obs_params(clds, trainSTL[:Y], trainSTL[:U]);   # don't want to use this.

In [434]:
rmse(testPool, clds)

In [None]:
# Gradient-based fine-tuning of `clds_g`. `h` receives the per-batch objective
# history (length n_epochs * nB); `nB` and `history` are assigned inside the
# `begin` block and are read again by the plotting cell.
batch_size = 192
min_size = 50
@time h = begin
    n_epochs = 200
    # max batch length ==> will be variable as starting position is stochastic.
    # NOTE(review): the iterator is built from `testPool`, whereas the surrounding
    # notes and the commented-out line below talk about training on `trainSTL` --
    # confirm which data set is intended.
    trainIter = mocaputil.DataIterator(testPool, batch_size, min_size=min_size)
    # Per-batch weights, divided by the nominal batch size.
    W = mocaputil.weights(trainIter; as_pct=false) ./ batch_size
    nB = length(trainIter)
    history = ones(n_epochs*nB) * NaN
    
    for ee in 1:n_epochs
#         bix = rand(1:64)   # begin at index (in each mocap file)
#         trainIter = mocaputil.DataIterator(trainSTL, batch_size, min_size=min_size, start=bix);
        # Learning-rate schedule: both rates are zero for epochs 1-2, then set to
        # fixed values from epoch 3 onwards.
        if ee == 1
            opt.eta, opt_hidden.eta = 0., 0.
        elseif ee == 3
            opt.eta = 5e-4 / 10 * 0.1 * 0.3 * 10 
            opt_hidden.eta =  0.02e-4 * 10 * 0.005 * 1
        elseif ee % 100 == 0
            printfmtln("100 epochs")
        end
        
        for (ii, (_cY, _cU, h0)) in enumerate(trainIter)
            h0 && (clds_g.h.data .= zeros(size(clds_g, 1)))   # reset state?
            # Roll the state forward over the batch and form the model output.
            X̂ = model.state_rollout(clds_g, _cU); 
            Yhat = clds_g.C * X̂ + clds_g.D * _cU .+ clds_g.d;
            # Mean squared error scaled by 64^2 and by the batch weight.
            obj = mean(x->x^2, _cY - Yhat)*64^2 * W[ii]
            history[(ee-1)*nB + ii] = obj.data
            Tracker.back!(obj)
            
            # Separate optimisers for the two parameter groups.
            for p in ps_hidden
                Tracker.update!(opt_hidden, p, -Tracker.grad(p))
            end
            for p in ps_observ
                Tracker.update!(opt, p, -Tracker.grad(p))
            end
            # Carry the final state of this batch over as the start of the next.
            clds_g.h.data .= X̂.data[:,end]
        end
        # Square root of the mean per-batch objective of this epoch.
        println(sqrt(mean(history[(1:nB) .+ nB*(ee-1)])))
    end
    history
end;

In [436]:
# Moving average of the objective history over a window of nB batches (one epoch),
# keeping only fully-overlapping positions, plotted on a square-root scale.
plot(sqrt.(conv(h, Windows.rect(nB))[nB:end-nB+1]/nB))

# Multi-task modelling!

In [29]:
include(joinpath(DIR_MOCAP_MTDS, "models.jl"))

In [13]:
# Initial LDS for the multi-task model: state dimension `d_state`, initialised
# from the data returned by `get_data` with the `:all` option.
style_ix = 1
d_state = 4
# Train model
trainAll = mocapio.get_data(expmtdata, style_ix, :all, :pooled, concat=true, simplify=true)
clds_orig = model.init_LDS_spectral(trainAll[:Y], trainAll[:U], d_state);
trainAll = nothing  # fairly large (~60 MB), don't want a multiplicity of these kicking around memory.

In [50]:
k = 2                 # dimension of manifold
d_nn = 200            # "complexity" of manifold
d_subspace = 40       # dim of subspace (⊆ parameter space) containing the manifold

# Number of LDS parameters that the network must output.
d_par = length(model.get_pars(clds_orig))
# NOTE: this three-layer network is discarded -- `nn` is reassigned to a single
# linear layer on the line that follows it, and that is the network actually used.
nn = Chain(Dense(k, d_nn, tanh), Dense(d_nn, d_subspace, σ), 
    Dense(d_subspace, d_par,identity, initW = ((dims...)->Flux.glorot_uniform(dims...)*0.1)),
    Flux.Diagonal(d_par))
nn = Chain(Dense(k, d_par, identity), Flux.Diagonal(d_par))
# Initial log-std, one value (-1.5) per entry along dimension 2 of the LDS size.
clogσ = repeat([-1.5f0], size(clds_orig, 2))

# Gradient-tracking multi-task LDS built from the initial LDS, and its no-grad
# counterpart.
cmtlds_g = model.mtldsg_from_lds(clds_orig, nn, clogσ, 0.1f0);
cmtlds = model.make_nograd(cmtlds_g);
# model.change_relative_lr!(cmtlds, 0.001f0)   # reduce sensitivity of chain params. 

In [170]:
rmse(trainSTL, model.make_lds(cmtlds, randn(Float32, k), 0.1f0))/8

In [18]:
rmse(trainSTL, model.make_lds(cmtlds, _Z[1:k,1], 0.1f0))/8

### Forward pass

For sampling, we can happily perform the following:

```julia
    ϵ = randn(Float32, k, 200)
    for i in 1:200
        lds = model.make_lds(cmtlds, ϵ[:,i])
        smp = lds(U)
        ...
    end
```

where we can use all the nice utility functions to create an LDS from the MTLDS and sample accordingly using the fast sampling from LDS objects.

However, for optimisation (once samples have been chosen), it is highly suboptimal to backprop through the (large) FFNN one sample at a time. Instead we should probably do the following:

```julia
    Z # from posterior sampling over the ϵ in the above stage
    Ψ = cmtlds.nn(Z)
    for i in 1:size(Ψ,2)
        lds = model._make_lds_psi(cmtlds, Ψ[:,i])
        obj += error(lds(U), Y)
        ...
    end
```

### Useful functions

* **make_lds**: `model.make_lds(cmtlds, randn(Float32, 3), 0.1f0)`. Make the LDS object corresponding to the value of `z` (here a `randn`) acting on `cmtlds`. The 3rd argument specifies the *relative* learning rate of the latent chain parameters. (Irrelevant if not performing optimisation.)
* **rmse**: `rmse(dataDicts, lds::MyLDS_ng)`: calculate the sum of RMSE over all data in `dataDicts`.
* **p_log_llh**: `p_log_llh(cmtlds, randn(Float32, 64, 200), randn(Float32, 121, 200)*0.1f0, Z)`. Calculate the log likelihood of the LDS models corresponding to the samples `Z` (can also supply importance weights in final argument (not shown)). Note that this value $\propto$ the length of data it is given.
* **sample_posterior**: `sample_posterior(cmtlds, trainSTL[1][:Y], trainSTL[1][:U], 100)`. Returns weights, the estimated log_px, and the samples used to calculate them. Instead of generating samples, the function can also be given them as a 5th argument.

In [19]:
# Forward-simulate `M` LDS models on the inputs `U`, one per column of `ϵ`.
# Column `j` of `ϵ` is passed to `model.make_lds` to build the j-th LDS, which is
# then applied to `U`. Returns a vector with the `M` outputs.
function sample_forward(mtlds::model.MTLDS_ng{T,F}, U::AbstractArray{T}, M::Int, ϵ::AbstractMatrix{T}) where {T, F}
    return map(1:M) do j
        lds = model.make_lds(mtlds, view(ϵ, :, j), mtlds.η_h)
        lds(U)
    end
end

# As the four-argument method, but draws the latent samples itself via
# `AxUtil.Random.sobol_gaussian`, with the latent dimension read off the first
# layer of `mtlds.nn`.
# NOTE(review): the vector of M outputs is splatted into the returned tuple, so the
# result is `(Ŷ_1, ..., Ŷ_M, ϵ)` rather than `(Ŷs, ϵ)` -- confirm this is intended.
function sample_forward(mtlds::model.MTLDS_ng{T,F}, U::AbstractArray{T}, M::Int) where {T, F}
    d_latent = size(mtlds.nn.layers[1].W, 2)
    sobol = AxUtil.Random.sobol_gaussian(M, d_latent)
    ϵ = convert(Array{T}, sobol')
    samples = sample_forward(mtlds, U, M, ϵ)
    return (samples..., ϵ)
end


# Log normalising constant of `tt` independent draws from a diagonal Gaussian
# whose per-dimension log standard deviations are `logσ`:
#   tt * (d/2 * log(2π) + sum(logσ)),  with d = length(logσ).
@inline function _gauss_lognormconst(logσ::Vector, tt)
    const_term = 0.5*tt*length(logσ)*log(2π)
    scale_term = 0.5*tt*sum(2*logσ)
    return const_term + scale_term
end

#= Single-sequence version (one Y matrix, so the residuals can be formed in place).
   This is the variant called most often during training. Given latent samples `ϵ`
   (one per column), it returns
     * the importance weights of the M samples (softmax of the per-sample log densities),
     * the log of the average sample density, i.e. logsumexp - log(M). =#
function sample_posterior(mtlds::model.MTLDS_ng{T,F}, Y::AbstractArray{T}, 
            U::AbstractArray{T}, M::Int, ϵ::AbstractArray{T}) where {T, F}

    # Model outputs for each sample; overwritten by the residuals below.
    resid = sample_forward(mtlds, U, M, ϵ)
    inv_var = 1 ./ exp.(2*mtlds.logσ)

    # Precision-weighted sum of squared residuals for each sample.
    sq_err = Vector{T}(undef, M)
    for j in 1:M
        R = resid[j]
        R .-= Y
        sq_err[j] = dot(sum(x->x^2, R, dims=2), inv_var)
    end

    # Per-sample log density ==> importance weights and a single log density.
    log_dens = -0.5 * sq_err .- _gauss_lognormconst(mtlds.logσ, size(Y,2))
    W, lse = AxUtil.Math.softmax_lse!(reshape(log_dens, :, 1))   # softmax AND logsumexp
    return W, lse[1] .- log(M)
end

# Single-sequence version that draws its own latent samples via
# `AxUtil.Random.sobol_gaussian` and appends them to the result:
# returns `(W, log_px, ϵ)`.
function sample_posterior(mtlds::model.MTLDS_ng{T,F}, Y::AbstractArray{T}, 
            U::AbstractArray{T}, M::Int) where {T, F}
    d_latent = size(mtlds.nn.layers[1].W, 2)
    sobol = AxUtil.Random.sobol_gaussian(M, d_latent)
    ϵ = convert(Array{T}, sobol')
    W, log_px = sample_posterior(mtlds, Y, U, M, ϵ)
    return (W, log_px, ϵ)
end


# Slower version for a *series of matrices* Ys / Us (cannot work in place).
# The same latent samples ϵ are shared by every sequence.
# NOTE(review): the ϵ-taking method called here already returns its weight matrix
# transposed, and this method transposes it again, so the two methods return the
# weights in opposite orientations -- confirm which one callers expect.
function sample_posterior(mtlds::model.MTLDS_ng{T,F}, Ys::AbstractArray{MT}, 
            Us::AbstractArray{MT}, M::Int) where {T, F, MT <: Matrix}

    d_latent = size(mtlds.nn.layers[1].W, 2)

    # common r.v.s
    sobol = AxUtil.Random.sobol_gaussian(M, d_latent)
    ϵ = convert(Array{T}, sobol')
    weights, lgpdf = sample_posterior(mtlds, Ys, Us, M, ϵ)

    return Matrix(weights'), lgpdf, ϵ  # W, log(1/M sum_j exp(log_j)), ϵ
end

# Series-of-matrices version with the latent samples ϵ supplied by the caller.
# Runs the single-sequence method on every (Ys[n], Us[n]) pair and collects the
# results. Returns the weights as an M × N matrix (N = number of sequences) and
# the vector of N log densities.
function sample_posterior(mtlds::model.MTLDS_ng{T,F}, Ys::AbstractArray{MT}, 
            Us::AbstractArray{MT}, M::Int, ϵ::AbstractArray{T}) where {T, F, MT <: Matrix}

    n_seq = length(Ys)
    weights = Matrix{T}(undef, n_seq, M)
    log_px = Vector{T}(undef, n_seq)

    for n in 1:n_seq
        w_n, lp_n = sample_posterior(mtlds, Ys[n], Us[n], M, ϵ)
        weights[n, :] = w_n[:]
        log_px[n] = lp_n
    end

    return Matrix(weights'), log_px  # W, log(1/M sum_j exp(log_j))
end

# Second element returned by `sample_posterior` with `m` self-generated samples,
# i.e. the estimated log density of `Y` given `U`.
# (The argument is named `model`, which inside the body hides the `model` module.)
function get_log_px(model::model.MTLDS_ng{T,F}, Y::AbstractArray, U::AbstractArray, m::Int) where {T,F}
    posterior = sample_posterior(model, Y, U, m)
    return posterior[2]
end

In [20]:
const prior_var = 1.0
"""
    p_log_llh(mtlds, Y, U, Z, wgt=ones(T, size(Z, 2)))

Weighted log-likelihood terms of the LDS models obtained from the latent samples
`Z` (a vector, or a matrix with one sample per column) under `mtlds`. `wgt` holds
the importance weight of each sample (typically integer-valued after resampling).

Returns a tuple `(llh, state)`:
* `llh[j]` is `-wgt[j]/2` times the precision-weighted sum of squared residuals of
  sample `j`;
* `state` is the `wgt`-weighted average, over samples, of the final state reached
  on `U`.

NOTE(review): unlike `sample_posterior`, no log-normalising term (which depends on
`mtlds.logσ`) is included here -- confirm that callers account for it.
"""
function p_log_llh(mtlds::Union{model.MTLDS_g{T,F}, model.MTLDS_ng{T,F}}, 
        Y::AbstractMatrix{T}, U::AbstractMatrix{T}, Z::AbstractVecOrMat{T}, 
        wgt::AbstractVector{T}=ones(T, size(Z, 2))) where {T,F}
    n_smp = size(Z, 2)
    # Map all latent samples to LDS parameters in a single network pass.
    Ψ = mtlds.nn(Z)
    inv_var = 1 ./ exp.(2*mtlds.logσ)
    final_states = Matrix{T}(undef, size(mtlds,1), n_smp)

    llh = map(1:n_smp) do j
        lds = model._make_lds_psi(mtlds, Ψ[:,j], mtlds.η_h)
        X   = model.state_rollout(lds, U)
        resid = Y - (lds.C * X + lds.D * U .+ lds.d)
        # keep the (untracked) last state of this sample
        final_states[:,j] = Tracker.data(X)[:,end]
        - wgt[j] * dot(sum(x->x^2, resid, dims=2), inv_var)/2
    end
    return llh, final_states * (wgt/sum(wgt))
end

# Log density (up to an additive constant) of `Z` under a zero-mean isotropic
# Gaussian with variance `prior_var`, summed over every entry of `Z`.
function p_log_prior(Z)
    scaled = Z .* 1/sqrt(prior_var)
    return -0.5*sum(scaled.^2)
end

"""
    p_log_posterior_unnorm(mtlds, Y, U, Z)

Unnormalised log posterior of the latent samples `Z`: the sum of the
log-likelihood terms from `p_log_llh` plus the log prior from `p_log_prior`.

`p_log_llh` returns a tuple `(llh, state)` in which `llh` holds one term per
sample. Only `llh` is used here, and it is summed so that the result is a scalar
like the value returned by `p_log_prior`. (Adding the tuple to the prior directly,
as was done before, raises a `MethodError`.)
"""
function p_log_posterior_unnorm(mtlds::Union{model.MTLDS_g{T,F}, model.MTLDS_ng{T,F}}, 
        Y::AbstractMatrix{T}, U::AbstractMatrix{T}, Z::AbstractVecOrMat{T}) where {T,F}
    llh, _ = p_log_llh(mtlds, Y, U, Z)
    f_llh = sum(llh)
    f_prior = p_log_prior(Z)
    return f_llh + f_prior
end

"""
    p_log_posterior_unnorm_beta(mtlds, Y, U, Z, beta)

Tempered and untempered unnormalised log posterior of the latent samples `Z`.
Returns the tuple `(log_prior + beta * log_llh, log_llh + log_prior)`.

`p_log_llh` returns a tuple `(llh, state)` in which `llh` holds one term per
sample. Only `llh` is used here, and it is summed so that both returned values
are scalars like the value returned by `p_log_prior`. (Using the tuple directly,
as was done before, raises a `MethodError`.)
"""
function p_log_posterior_unnorm_beta(mtlds::Union{model.MTLDS_g{T,F}, model.MTLDS_ng{T,F}}, 
        Y::AbstractMatrix{T}, U::AbstractMatrix{T}, Z::AbstractVecOrMat{T}, beta::T) where {T,F}
    llh, _ = p_log_llh(mtlds, Y, U, Z)
    f_llh = sum(llh)
    f_prior = p_log_prior(Z)
    return f_prior + beta*f_llh, f_llh + f_prior
end

In [37]:
# ============= CALCULATING LOGPX IS REALLY EXPENSIVE!! ============================
# Timing probe: posterior sampling with M = 500 samples on the first 680 columns of
# seven `trainPool` sequences (this indexes `trainPool` as a collection of
# dictionaries, i.e. data fetched without `concat=true`).
tmp = sample_posterior(cmtlds, [trainPool[i][:Y][:,1:680] for i in 1:7], 
    [trainPool[i][:U][:,1:680] for i in 1:7], 500)   # 8 secs for 10% of (*nonstationary!*) training set (M=500)

In [21]:
# GET DATA
trainSTL, validSTL, testSTL = mocapio.get_data(expmtdata, style_ix, :split, :stl);

In [106]:
k = 2                 # dimension of manifold
d_nn = 200            # "complexity" of manifold
d_subspace = 40       # dim of subspace (⊆ parameter space) containing the manifold

# Number of LDS parameters that the network must output.
d_par = length(model.get_pars(clds_orig))
# Network from the k-dim latent to the LDS parameters: tanh layer, sigmoid layer
# down to the subspace, a linear layer whose initial weights are scaled by 0.1, and
# a final `Flux.Diagonal` layer. (The purely linear alternative is commented out.)
nn = Chain(Dense(k, d_nn, tanh), Dense(d_nn, d_subspace, σ), 
    Dense(d_subspace, d_par,identity, initW = ((dims...)->Flux.glorot_uniform(dims...)*0.1)),
    Flux.Diagonal(d_par))
# nn = Chain(Dense(k, d_par, identity), Flux.Diagonal(d_par))
# Initial log-std, one value (-1.5) per entry along dimension 2 of the LDS size.
clogσ = repeat([-1.5f0], size(clds_orig, 2))

# Gradient-tracking multi-task LDS built from the initial LDS, and its no-grad
# counterpart.
cmtlds_g = model.mtldsg_from_lds(clds_orig, nn, clogσ, 0.1f0);
cmtlds = model.make_nograd(cmtlds_g);
# model.change_relative_lr!(cmtlds, 0.001f0)   # reduce sensitivity of chain params. 

In [107]:
# Optimiser and trainable parameters of the gradient-tracking multi-task model.
opt = ADAM(8e-4)
pars = model.pars(cmtlds_g);

# --- Training parameters ---
T = eltype(cmtlds)
n_epochs   = 2
m_proposal = 200
m_bprop    = 3
batch_sz   = 192 # ???
# Zero-mean Gaussian prior with diagonal covariance 100*I over a 64-dim vector
# (used with `cmtlds.logσ` in later cells). NOTE(review): 64 is hard-coded --
# presumably the output dimension; confirm it matches `length(cmtlds.logσ)`.
prior_lsigma = MvNormal(zeros(T, 64), Diagonal(100*ones(T, 64)))
# --- /end ---

history = zeros(T, n_epochs, 2)
trainData = mocaputil.DataIterator(trainSTL, batch_sz, min_size=50)
Ws = mocaputil.weights(trainData, as_pct=false)

# for ee in 1:n_epochs

# Per-epoch accumulators, initialised here by hand because the epoch loop above is
# commented out.
total_len  = sum(Ws)
epoch_loss = zero(T)
recon_loss = zero(T)
logp = zero(T)


In [1163]:
cmtlds = model.make_nograd(cmtlds_g);

In [1165]:
cmtlds.nn.layers[1].W[1,1] = -0.05

In [1166]:
cmtlds_g.nn.layers[1].W[1,1]

In [917]:
_Z = randn(Float32, 3, 1)

In [1008]:
sample_posterior(cmtlds, validSTL[1][:Y], validSTL[1][:U], 1, _Z)

In [1031]:
p_log_llh(cmtlds, validSTL[1][:Y], validSTL[1][:U], randn(Float32, 3))[1]

In [609]:
(Yb, Ub, h0) = first(trainData);

In [610]:
h0 && model.zero_state!(clds_g)
Tb = size(Yb, 2)

In [836]:
get_log_px(cmtlds, trainSTL[4][:Y], trainSTL[4][:U], 200)

In [24]:
cmtlds.nn(randn(Float32, 2))

In [990]:
rmse(validSTL, model.make_lds(cmtlds, randn(Float32, 3), cmtlds.η_h))/8

In [960]:
model.make_lds(cmtlds, randn(Float32, 3), cmtlds.η_h)(validSTL[1][:U])

In [162]:
_eps = randn(Float32, 3, 20);

In [177]:
sample_posterior(cmtlds, validSTL[1][:Y][:,1:192], validSTL[1][:U][:,1:192], 20, _eps)[1]

In [166]:
using AxPlot

In [171]:
# Scatter rows 1 and 3 of `_eps`, then overlay the same points via AxPlot.scatter_alpha,
# passing the flattened first output of sample_posterior as its third argument
# (presumably per-point alpha values — confirm against AxPlot).
scatter(_eps[1,:], _eps[3,:], alpha=0.1)
AxPlot.scatter_alpha(_eps[1,:], _eps[3,:], 
    f32(sample_posterior(cmtlds, validSTL[1][:Y], validSTL[1][:U], 20, _eps)[1])[:])

In [170]:
# Same plot as for rows 1 and 3 of `_eps`, here for rows 1 and 2.
scatter(_eps[1,:], _eps[2,:], alpha=0.1)
AxPlot.scatter_alpha(_eps[1,:], _eps[2,:], 
    f32(sample_posterior(cmtlds, validSTL[1][:Y], validSTL[1][:U], 20, _eps)[1])[:])

In [None]:
model.make_lds(cmtlds, 

In [647]:
# Histogram of every gradient entry across all trainable parameters, with log-scaled counts.
grad_vecs = [vec(p.grad) for p in pars]
hist(vcat(grad_vecs...))
ax = gca()
ax.set_yscale("log")

In [65]:
# Manually quadruple the optimiser's step size.
opt.eta *= 4

In [75]:
# Manually halve the optimiser's step size.
opt.eta /=2

In [950]:
# Half the squared Mahalanobis distance of logσ from the prior mean (the quadratic term of
# the Gaussian prior's negative log-density), divided by the batch length Tb.
0.5*(let δ=(cmtlds.logσ - prior_lsigma.μ); δ' * (inv(prior_lsigma.Σ) * δ); end) / Tb

In [272]:
# Display the current logσ; the trailing comment is a disabled in-place shift by -0.5.
cmtlds.logσ   #.-= 0.5

In [153]:
# Half the squared Mahalanobis distance of logσ from the prior mean (not divided by Tb).
0.5*(let δ=(cmtlds.logσ - prior_lsigma.μ); δ' * (inv(prior_lsigma.Σ) * δ); end)

## Open questions

* Why does STL not appear to converge to even the non-MTL optimum?
    * Why is it still attempting to increase even with the dynamics grad ~=0?
* Why does logσ go in the wrong direction?



* *(Bonus: parameterising the final map as two stacked linear layers (i.e. a low-rank factorisation of the output weight matrix) didn't appear to work well. What is going on here? When does the sigmoid nonlinearity help? Is it a conditioning problem?)*

In [116]:
# Call change_relative_lr! with factor 0.1 on the no-grad model, and set every entry of the
# tracked model's η_h to 0.1 in place.
# NOTE(review): the exact effect of change_relative_lr! is defined elsewhere — confirm.
model.change_relative_lr!(cmtlds, 0.1f0) 
cmtlds_g.η_h .= 0.1f0

In [109]:
# get_log_px for every training sequence, then the total divided by the total number of
# frames (columns of Y) across all sequences.
_lp = map(i -> get_log_px(cmtlds, trainSTL[i][:Y], trainSTL[i][:U], 200), 1:length(trainSTL))
n_frames = [size(trainSTL[i][:Y], 2) for i in 1:length(trainSTL)]
sum(_lp)/sum(n_frames)

In [122]:
# Training loop for the multi-task LDS.
# For each batch: (1) draw proposals with `sample_posterior` on the no-grad model `cmtlds`,
# (2) resample `m_bprop` of them according to the returned weights, (3) evaluate the
# weighted likelihood with the tracked model `cmtlds_g` and take one optimiser step.

# --- Training parameters ---
T = eltype(cmtlds)
n_epochs   = 100
m_proposal = 300      # proposals requested from sample_posterior per batch
m_bprop    = 2        # draws resampled from the proposals and backpropagated through
batch_sz   = 192 # ???
# --- /end ---

history = zeros(T, n_epochs, 2)    # per epoch: column 1 = approx logp(x), column 2 = recon loss
trainData = mocaputil.DataIterator(trainSTL, batch_sz, min_size=50)
Ws = mocaputil.weights(trainData, as_pct=false)
total_len = sum(Ws)                # loop-invariant normaliser for the epoch statistics

# (mean, variance) of the Gaussian prior applied independently to each logσ entry.
logσ_prior = (-1.3, 0.05)
prior_lsigma = MvNormal(ones(T, 64)*logσ_prior[1], Diagonal(logσ_prior[2]*ones(T, 64)))

opt.eta = 0.2e-5    # 3.8e-5  0.8e-5
for ee in 1:n_epochs
    epoch_loss = zero(T)
    recon_loss = zero(T)
    logp = zero(T)


    for (Yb, Ub, h0) in trainData
        h0 && model.zero_state!(cmtlds_g)
        Tb = size(Yb, 2)      # not constant

        # ====== Generate approximate posterior samples ==============
        # ---> importance sample w forward sim.
        w, logp_s, Zproposal = sample_posterior(cmtlds, Yb, Ub, m_proposal)
        logp += sum(logp_s)

        # ---> resample and aggregate duplicates
        smps = AxUtil.Random.multinomial_indices_linear(m_bprop, w[:])
        # often duplicates (e.g. m_bprop of same sample), esp nr beginning. Aggregate to improve efficiency.
        smps, smp_wgt = countmap(smps) |> x-> (collect(keys(x)), collect(values(x)))
        Zs_post = Zproposal[:, smps]
        Zs_wgt  = T.(smp_wgt) / m_bprop
#         logp += 0
#         Zs_post = _Z[:,1:1]
#         Zs_wgt = [T(1.0)]

        # ===== AutoDiff-aware pass through posterior-integrated likelihood =====
        # ---> neg. log likelihood / reconstruction
        llh, state = p_log_llh(cmtlds_g, Yb, Ub, Zs_post, Zs_wgt)
        cmtlds.h .= state   # update state acc. to forward sim in p_log_llh.
#         println(llh.data)

        # ---> Update loss functions
        loss   = - sum(llh)  # *decrease* *negative* llh. (Don't normalise for data length.)
        recon_loss += -loss.data - Tb*(sum(cmtlds.logσ) + size(cmtlds, 2)*log(2π)/2)

        # .... Likelihood normalising constant  (# not *n* Tb sum(logσ) as we have normalised by n already)
        # Must use the tracked logσ (cmtlds_g), not the no-grad copy: with the plain array this
        # term is a constant to the autodiff and gives logσ no gradient, leaving only the data
        # term, whose descent direction always increases logσ.
        loss += Tb * sum(cmtlds_g.logσ)

        # .... Log Normal prior on log sigma  (tracked for the same reason as above)
        loss += T(0.5) * sum(x->x^2, (cmtlds_g.logσ .- T(logσ_prior[1])))/T(logσ_prior[2])
#         loss += 0.5*(let δ=(cmtlds.logσ - prior_lsigma.μ); δ' * (inv(prior_lsigma.Σ) * δ); end)

        # .... Regularisation of MLP weights
        # careful of regularizing layer 3 (due to lr rescale)
#         loss += 1e-3*sum(abs, cmtlds.nn.layers[1].W) / Tb
#         loss += 1e-3*sum(abs, cmtlds.nn.layers[2].W) / Tb

        epoch_loss += loss.data

        # ===== Backprop/update of posterior-integrated likelihood =====
        Tracker.back!(loss)

        # NOTE(review): the gradient is negated before being handed to update!; confirm this
        # matches the sign convention of the installed Flux/Tracker version.
        for p in pars
#             display(norm(Tracker.grad(p)))
            Tracker.update!(opt, p, -Tracker.grad(p))
        end

    end

    history[ee,1] = logp / total_len
    history[ee,2] = recon_loss / total_len
    if (ee % 1 == 0)
        printfmtln("Epoch {1:d}, recon loss: {2:.4f}, approx logp(x): {3:.3f}", ee, history[ee,2], history[ee,1])
        flush(stdout)
    end

end  # for ee in 1:n_epochs
model.zero_state!(cmtlds_g)

In [124]:
# Plot the first 40 epochs of history column 1 (the approx. logp(x) trace written by the training loop).
plot(history[1:40,1])

In [125]:
# Per-frame average of get_log_px over the whole training set.
seq_ids = 1:length(trainSTL)
_lp = [get_log_px(cmtlds, trainSTL[s][:Y], trainSTL[s][:U], 200) for s in seq_ids]
sum(_lp)/sum([size(trainSTL[s][:Y], 2) for s in seq_ids])

In [None]:
# _eps = randn(Float32, 3, 20)

In [137]:
# printfmtln("RMSE = {:.3f}", rmse(cYs - Yhat)); flush(stdout)
# _Z = randn(Float32, 3, 20)

# Predictions on training sequence `dset_i` for the first 12 latent draws in `_eps`.
dset_i = 4
Yhats = map(i -> model.make_lds(cmtlds, _eps[1:k,i], cmtlds.η_h[1])(trainSTL[dset_i][:U]), 1:12)

# Show the p_log_llh objective for the first four draws.
for i in 1:4
    display(p_log_llh(cmtlds, trainSTL[dset_i][:Y], trainSTL[dset_i][:U], _eps[1:k,i])[1])
end


# 5×4 grid: one panel per output channel (1+offset … 20+offset), showing the data and all
# 12 predictions over the window (1:_Δt-1) .+ offset_tt.
fig, axs = subplots(5,4,figsize=(10,10))
offset = 0
offset_tt = 400
_Δt = 200
for i in 1:20
    panel   = axs[i]
    channel = i + offset
    window  = (1:_Δt-1) .+ offset_tt
    panel.plot(trainSTL[dset_i][:Y][channel, window])
    for j in 1:12
        panel.plot(Yhats[j][channel, window], alpha=0.4)
    end
end
# savefig("mountains_in_the_sky.png")

In [131]:
# Channel 2 of the data against the first three predictions, over the same time window.
plot(trainSTL[dset_i][:Y][2, (1:_Δt-1) .+ offset_tt])
for j in 1:3
    plot(Yhats[j][2, (1:_Δt-1) .+ offset_tt], alpha=0.8)
end

# Fresh 2×10 matrix of standard-normal latents.
_Z = randn(Float32, 2, 10);

In [140]:
# Report RMSE/8 and the p_log_llh objective for the LDS generated from the first column of `_Z`.
# NOTE(review): the RMSE is computed on validSTL[2] but the LLH on validSTL[1] — confirm
# whether the two metrics are meant to use the same sequence.
printfmtln("Previously: RMSE/8: {:.3f}, LLH: {:.3f}.", 
    rmse(model.make_lds(cmtlds, _Z[:,1], cmtlds.η_h)(validSTL[2][:U]) - validSTL[2][:Y])/8,
    p_log_llh(cmtlds, validSTL[1][:Y], validSTL[1][:U], _Z[:,1])[1])

In [141]:
# Same report as the "Previously" cell, labelled "Post-op" for comparison after a change.
# NOTE(review): the RMSE is computed on validSTL[2] but the LLH on validSTL[1] — confirm
# whether the two metrics are meant to use the same sequence.
printfmtln("Post-op: RMSE/8: {:.3f}, LLH: {:.3f}.", 
    rmse(model.make_lds(cmtlds, _Z[:,1], cmtlds.η_h)(validSTL[2][:U]) - validSTL[2][:Y])/8,
    p_log_llh(cmtlds, validSTL[1][:Y], validSTL[1][:U], _Z[:,1])[1])

In [81]:
# RMSE/8 over the training iterator for the LDS generated from the first column of `_Z`.
rmse(trainData, model.make_lds(cmtlds, _Z[:,1], cmtlds.η_h))/8

In [82]:
# Tracked-model objective on the first training batch for a single latent with unit weight.
p_log_llh(cmtlds_g, first(trainData)[1], first(trainData)[2], _Z[:,1], [1.0f0])[1]

In [48]:
# Call change_relative_lr! with factor 0.1 on the no-grad model (effect defined elsewhere — confirm).
model.change_relative_lr!(cmtlds, 0.1f0) 

In [88]:
# Display the tracked model's η_h.
cmtlds_g.η_h 

In [86]:
# For the first three layers, print whether the no-grad model's W and b are the very same
# arrays (===) as the data underlying the tracked model's parameters.
for l in 1:3
    ng_layer, g_layer = cmtlds.nn.layers[l], cmtlds_g.nn.layers[l]
    println(ng_layer.W === g_layer.W.data)
    println(ng_layer.b === g_layer.b.data)
end

In [181]:
# Hyper-network output for the first three latent draws (one column per draw).
Ψ = cmtlds.nn(_eps[:,1:3])

In [183]:
# Step-by-step reproduction of the setup lines of p_log_llh at global scope.
cmtlds.logσ
# Per-channel precision 1/σ² computed from logσ.
precision = 1 ./ exp.(2*cmtlds.logσ)
# Uninitialised buffer with one column per draw in `_eps`.
states = Matrix{T}(undef, size(cmtlds,1), size(_eps,2))

In [187]:
# Build the LDS for the first column of Ψ (same call as inside p_log_llh).
lds = model._make_lds_psi(cmtlds, Ψ[:,1], cmtlds.η_h)

In [192]:
# Roll the state forward over the first 192 input frames of validation sequence 1, then form
# the predicted outputs Ŷ = C·X + D·U + d.
X   = model.state_rollout(lds, validSTL[1][:U][:,1:192])
Ŷ   = lds.C * X + lds.D * validSTL[1][:U][:,1:192] .+ lds.d;

In [213]:
# Manual log-likelihood of the 192 frames: the precision-weighted squared-error term (as in
# p_log_llh) minus `_gauss_lognormconst(logσ, 192)` (defined elsewhere).
-dot(sum(x->x^2, validSTL[1][:Y][:,1:192] - Ŷ, dims=2), precision)/2 -_gauss_lognormconst(cmtlds.logσ, 192)

In [217]:
# Same computation with unit precision and logσ = 0 on all 64 channels.
-dot(sum(x->x^2, validSTL[1][:Y][:,1:192] - Ŷ, dims=2), ones(64))/2 -_gauss_lognormconst(zeros(64), 192)

In [None]:
"""
    p_log_llh(mtlds, Y, U, Z[, wgt])

Weighted, unnormalised Gaussian log-likelihood of `Y` given inputs `U`, summed over the
latent draws in the columns of `Z`.

For each column of `Z`, the hyper-network `mtlds.nn` produces an LDS parameter vector, the
LDS is rolled out over `U`, and the precision-weighted squared error
`-wgt[i]/2 * Σ (Y - Ŷ)² ⊙ exp(-2 logσ)` is accumulated. The Gaussian normalising constant is
*not* included.

# Arguments
- `mtlds`: tracked (`MTLDS_g`) or no-grad (`MTLDS_ng`) multi-task LDS.
- `Y`, `U`: output and input matrices, one column per time step.
- `Z`: latent draw(s); a vector is treated as a single draw.
- `wgt`: one weight per draw (default: all ones).

# Returns
`(obj, state)`: the accumulated objective, and the final LDS states of all draws averaged
with the normalised weights `wgt / sum(wgt)`.
"""
function p_log_llh(mtlds::Union{model.MTLDS_g{T,F}, model.MTLDS_ng{T,F}}, 
        Y::AbstractMatrix{T}, U::AbstractMatrix{T}, Z::AbstractVecOrMat{T}, 
        wgt::AbstractVector{T}=ones(T, size(Z, 2))) where {T,F}
    n_draws      = size(Z, 2)
    lds_params   = mtlds.nn(Z)
    inv_var      = 1 ./ exp.(2*mtlds.logσ)
    final_states = Matrix{T}(undef, size(mtlds,1), n_draws)
    total        = zero(T)

    for j in 1:n_draws
        lds_j = model._make_lds_psi(mtlds, lds_params[:,j], mtlds.η_h)
        X_j   = model.state_rollout(lds_j, U)
        Ŷ_j   = lds_j.C * X_j + lds_j.D * U .+ lds_j.d
        # Squared error summed over time for each output channel.
        sq_err = sum(x->x^2, Y - Ŷ_j, dims=2)
        total -= wgt[j] * dot(sq_err, inv_var)/2
        # Keep only the untracked final state of this rollout.
        final_states[:,j] = Tracker.data(X_j)[:,end]
    end
    return total, final_states * (wgt/sum(wgt))
end

In [210]:
# Cross-check: total standard-normal log-density of the residuals of the first 192 frames,
# summed over all 64 output channels.
sum([sum(logpdf.(Normal(0,1), validSTL[1][:Y][i,1:192] - Ŷ[i,:])) for i in 1:64])