## Examples: running the ensemble Kalman filter (EnKF) with multiway (inverse) covariance models

In [1]:
using Plots
using MIRT: jim
using TensorGraphicalModels

┌ Info: Precompiling Plots [91a5bcdd-55d7-5caf-9e0b-520d859cae80]
└ @ Base loading.jl:1278
┌ Info: Precompiling MIRT [7035ae7a-3787-11e9-139a-5545ed3dc201]
└ @ Base loading.jl:1278
│   exception = (LoadError("/home/wayneyw/.julia/packages/Plots/Awg62/src/backends/hdf5.jl", 162, UndefVarError(:Group)), Union{Ptr{Nothing}, Base.InterpreterIP}[Ptr{Nothing} @0x00007f0a53bf105f, Ptr{Nothing} @0x00007f0a53c8198c, Ptr{Nothing} @0x00007f0a53c81ed5, Ptr{Nothing} @0x00007f0a53c81b3f, Ptr{Nothing} @0x00007f0a53c82543, Ptr{Nothing} @0x00007f0a53c837a7, Base.InterpreterIP in top-level CodeInfo for Plots._hdf5_implementation at statement 4, Ptr{Nothing} @0x00007f0a53ca0a19, Ptr{Nothing} @0x00007f0a53ca1192, Ptr{Nothing} @0x00007f0a53ca057a, Ptr{Nothing} @0x00007f0a53ca0519, Ptr{Nothing} @0x00007f0a53c73205, Ptr{Nothing} @0x00007f0a53ca1fc1, Ptr{Nothing} @0x00007f0a456d0cae, Ptr{Nothing} @0x00007f09c81d167e, Ptr{Nothing} @0x00007f0a53c68727, Ptr{Nothing} @0x00007f0a53c81ef5, Ptr{Nothing} @0x00007f0a5

In [3]:
# Settings for generating the ground-truth data.
dynamic_type, obs_type = "poisson", "linear_perm_miss"
T = 20                      # time horizon passed to the generator
N = 50                      # unused here; passed to `enkf` later (presumably ensemble size)
px = (32, 32)               # columns of X are reshaped to this shape for display
py = (32, 32)               # second shape argument of the generator / `enkf`
obs_noise, process_noise = 0.01, 0.01
add_process_noise = false

# The assignment is the last expression, so the notebook displays the returned tuple.
X, Y, H = TensorGraphicalModels.gen_kalmanfilter_data(
    dynamic_type, obs_type, T, px, py, obs_noise, process_noise, add_process_noise)

LoadError: [91mUndefVarError: px not defined[39m

In [None]:
# Animate the ground-truth field for time indices 2 through T+1.
# Each column X[:, t] is reshaped to `px` and drawn with `jim`, one frame per index.
# The frames are then assembled into a GIF at 5 frames per second.
# (An earlier version passed `clim=(-3.0, 3.0)` to `jim` to fix the colour scale.)
anim_x = @animate for t in 2:(T + 1)
    frame = reshape(X[:, t], px)
    Plots.plot(jim(frame); title = "Time step: $t")
end
gif(anim_x; fps = 5)

In [None]:
# Run the EnKF once per covariance-estimation method.
# For each method, record the result of `compute_nrmse(X, Xhat)` in `NRMSEs_list`.
# Record the elapsed wall-clock time in seconds in `time_list`.
# Both lists are filled in the same order as `method_list`.
method_list = ["glasso", "kpca", "kglasso", "teralasso", "sg_palm"]
NRMSEs_list = []
time_list = Float64[]
# NOTE(review): nothing is pushed into this list. The third output of `enkf` is
# discarded, and it is not visible here whether it is the estimated precision matrix.
Omegahat_list = []
for method in method_list
    starttime = time()
    # Only the filtered estimate `Xhat` is used; the other two outputs are discarded.
    Xhat, _, _ = enkf(Y,
                      method_str_to_type(method),
                      dynamic_type,
                      H,
                      px,
                      py,
                      N,
                      obs_noise,
                      process_noise,
                      add_process_noise)
    push!(time_list, time() - starttime)
    push!(NRMSEs_list, compute_nrmse(X, Xhat))
end

In [None]:
# Plot the NRMSE of every method against time step on a single figure.
# NRMSEs_list[k] holds the result for method_list[k], since both are filled in the
# same order. They are paired with `zip` rather than a hard-coded index per method
# name, so reordering or extending `method_list` cannot mismatch the curves.
length(NRMSEs_list) == length(method_list) ||
    throw(DimensionMismatch(
        "expected $(length(method_list)) NRMSE results, got $(length(NRMSEs_list))"))
fig = Plots.plot()
xlabel!("Time step")
ylabel!("NRMSE")
for (method, NRMSEs) in zip(method_list, NRMSEs_list)
    # Presumably draws onto the current figure (`fig`), since it is never passed in
    # explicitly; confirm this against `plot_nrmse!`.
    plot_nrmse!(NRMSEs, method)
end
display(fig)