# Analysis of energy-preserving operator inference (EP-OpInf) for the KS equation (KSE)

In [None]:
using FileIO
using JLD2
using LaTeXStrings
using LinearAlgebra
using Plots
using Statistics

# Project sources: the KS model and the LiftAndLearn module, aliased as `LnL`
include("../src/model/KS.jl")
include("../src/LiftAndLearn.jl")
const LnL = LiftAndLearn

# KS equation configuration.
# NOTE(review): the meaning of each positional argument is defined by the `KS`
# constructor in ../src/model/KS.jl; values are passed through unchanged.
KSE = KS([0.0, 100.0], [0.0, 200.0], [1.0, 1.0], 512, 0.001, 1, "ep")

# WARNING: do you want to save data?
save_data = true

# JLD2 files holding the data, the inferred operators, and the analysis results
datafile, opfile, resultfile =
    "data/kse_data.jld2", "data/kse_operators.jld2", "data/kse_results.jld2"

# Keep every DS-th time step when down-sampling
DS = 100
Tdim_ds = length(1:DS:KSE.Tdim)  # number of time steps retained after down-sampling

num_test_ic = 50  # number of random test inputs
;

In [None]:
DATA = FileIO.load(datafile);  # data dictionary; later cells read "Xtr_all", "Vr", "IC_train", "ro"

In [None]:
OPS = FileIO.load(opfile);  # ROM operators keyed by name, e.g. "op_LS", "op_int", "op_ephec"

In [None]:
# Widen the loaded results to Dict{String, Any} so entries of any type can be stored later
RES = Dict{String, Any}(FileIO.load(resultfile));

In [None]:
# (Re)allocate the containers for the training-set metrics stored in RES.
# Rows index the reduced dimensions in DATA["ro"], columns the parameters 1:KSE.Pdim;
# the `:fom` entries hold one value per parameter.
# The reduced dimensions come from DATA["ro"] (no global `ro` is defined in this notebook).
let nr = length(DATA["ro"]), np = KSE.Pdim
    rom_keys = (:int, :LS, :ephec, :epsic, :epp)

    # One (reduced dimension × parameter) Float64 matrix per ROM
    function rom_metric()
        return Dict{Symbol, Array{Float64}}(
            k => Array{Float64}(undef, nr, np) for k in rom_keys
        )
    end

    # Same as `rom_metric`, plus a per-parameter vector for the full-order model
    function rom_fom_metric()
        d = rom_metric()
        d[:fom] = Array{Float64}(undef, np)
        return d
    end

    RES["train_proj_err"] = Array{Float64}(undef, nr, np)
    RES["train_state_err"] = rom_metric()
    RES["train_CR"] = rom_fom_metric()

    # `autocorr` returns an N×1 array per trajectory, so each autocorrelation entry
    # holds a vector; the abstract `Array` value type accepts whatever layout the
    # analysis cells assign.
    ac = Dict{Symbol, Array}(
        k => Matrix{Vector{Float64}}(undef, nr, np) for k in rom_keys
    )
    ac[:fom] = Vector{Vector{Float64}}(undef, np)
    RES["train_AC"] = ac

    RES["train_FTLE"] = rom_fom_metric()
    RES["train_ME"] = rom_fom_metric()
end
;


In [None]:
"""
    EP_analyze(Data, operators, model, integrator, options, X_all, IC, ro)

Evaluate the ROM operators on the training initial conditions. Collect error, energy and
constraint metrics for every reduced dimension in `ro` and every parameter index
`i in eachindex(model.μs)`.

Only the intrusive (`:int`) and least-squares OpInf (`:LS`) ROMs are currently evaluated.
The energy-preserving variants (`:ephec`, `:epsic`, `:epp`) are disabled (see the
commented-out code), so their scalar entries stay `NaN` and their series stay empty.

# Arguments
- `Data`: dictionary providing the bases `Data["Vr"][i]` and the full-order operators
  `Data["op_fom_tr"][i]` (field `F`).
- `operators`: dictionary with `"op_LS"`, `"op_int"`, `"op_ephec"`, `"op_epsic"`,
  `"op_epp"`, each indexed by parameter and exposing fields `A` and `F`.
- `model`: provides `μs`, `Pdim`, `Tdim` and the time grid `t`.
- `integrator`: called as `integrator(A, F, model.t, x0)`; presumably returns the reduced
  trajectory with one column per time step (the energy arrays assume `model.Tdim` columns).
- `options`: `options.optim.which_quad_term` is forwarded to the constraint functions.
- `X_all`: full-order snapshots. NOTE(review): indexed both as `X_all[i]` (FOM metrics)
  and `X_all[i, ct]` (per-IC errors); confirm the intended layout.
- `IC`: full-order initial conditions.
- `ro`: reduced dimensions to evaluate.

The state errors use the global down-sampling rate `DS`.

# Returns
A `Dict{String, Any}` built from `Data`, with these keys added:
`"train_proj_err"`, `"train_state_err"`, `"train_En"`, `"train_CR"`, `"train_mmt"`,
`"train_CV"`.
- Scalar metrics are `length(ro) × model.Pdim` matrices.
- `"train_En"` and `"train_CV"` store the IC-averaged time series as vectors.
"""
function EP_analyze(Data, operators, model, integrator, options, X_all, IC, ro)
    num_ro = length(ro)             # number of reduced dimensions
    num_ic_params = length(IC)      # number of initial conditions
    Pdim = model.Pdim
    Tdim = model.Tdim
    quad_term = options.optim.which_quad_term
    rom_keys = (:int, :LS, :ephec, :epsic, :epp)
    evaluated = (:int, :LS)         # ROMs actually integrated below

    # Scalar metric per (reduced dimension, parameter); NaN marks entries never computed
    scalar_metric() = Dict{Symbol, Array{Float64}}(
        k => fill(NaN, num_ro, Pdim) for k in rom_keys
    )
    # Array (of the given dims) of independent empty time series
    empty_series(dims...) = [Float64[] for _ in CartesianIndices(dims)]
    # Time-series metric per (reduced dimension, parameter)
    series_metric() = Dict{Symbol, Array{Vector{Float64}}}(
        k => empty_series(num_ro, Pdim) for k in rom_keys
    )

    PE_all = fill(NaN, num_ro, Pdim)    # projection error
    SE_all = scalar_metric()            # relative state error
    CR_all = scalar_metric()            # constraint residual
    MMT_all = scalar_metric()           # momentum
    EN_all = series_metric()            # energy
    CV_all = series_metric()            # constraint violation
    CR_all[:fom] = fill(NaN, Pdim)
    MMT_all[:fom] = fill(NaN, Pdim)
    EN_all[:fom] = empty_series(Pdim)
    CV_all[:fom] = empty_series(Pdim)

    op_LS = operators["op_LS"]
    op_int = operators["op_int"]
    # Read for the energy-preserving variants (currently disabled below)
    op_ephec = operators["op_ephec"]
    op_epsic = operators["op_epsic"]
    op_epp = operators["op_epp"]
    Vrmax = Data["Vr"]

    @info "Analyze the operators..."
    @showprogress for i in eachindex(model.μs)
        # Full-order model: energy, constraint residual/momentum and constraint violation
        EN_all[:fom][i] = norm.(eachcol(X_all[i]), 2)
        F_full = Data["op_fom_tr"][i].F
        CR_all[:fom][i], MMT_all[:fom][i] =
            LnL.constraintResidual(F_full, size(F_full, 1), quad_term)
        CV_all[:fom][i] = vec(LnL.constraintViolation(X_all[i], F_full, quad_term))

        for (j, r) in enumerate(ro)
            Vr = Vrmax[i][:, 1:r]

            # The reduced operators depend only on (i, r): extract them once, not per IC
            A_LS = op_LS[i].A[1:r, 1:r]
            Finf_extract_LS = LnL.extractF(op_LS[i].F, r)
            A_int = op_int[i].A[1:r, 1:r]
            Fint_extract = LnL.extractF(op_int[i].F, r)

            # The constraint residual and momentum depend only on the operators as well
            CR_all[:LS][j, i], MMT_all[:LS][j, i] =
                LnL.constraintResidual(Finf_extract_LS, r, quad_term)
            CR_all[:int][j, i], MMT_all[:int][j, i] =
                LnL.constraintResidual(Fint_extract, r, quad_term)

            # Per-IC storage
            PE = Vector{Float64}(undef, num_ic_params)
            SE = Dict(k => Vector{Float64}(undef, num_ic_params) for k in evaluated)
            En = Dict(k => Matrix{Float64}(undef, num_ic_params, Tdim) for k in evaluated)
            CV = Dict(k => Matrix{Float64}(undef, num_ic_params, Tdim) for k in evaluated)

            for (ct, ic) in enumerate(IC)
                x0 = Vr' * ic
                Xinf_LS = integrator(A_LS, Finf_extract_LS, model.t, x0)
                Xint = integrator(A_int, Fint_extract, model.t, x0)

                # To re-enable the energy-preserving variants, uncomment the block below,
                # add their keys to `evaluated`, and compute their metrics as for :LS.
                # if options.optim.SIGE
                #     # Integrate the energy-preserving hard equality constraint operator inference model
                #     Finf_extract_ephec = op_ephec[i,j].F
                #     Xinf_ephec = integrator(op_ephec[i,j].A, op_ephec[i,j].F, model.t, Vr' * ic)
                #
                #     # Integrate the energy-preserving soft inequality constraint operator inference model
                #     Finf_extract_epsic = op_epsic[i,j].F
                #     Xinf_epsic = integrator(op_epsic[i,j].A, op_epsic[i,j].F, model.t, Vr' * ic)
                #
                #     # Integrate the energy-preserving unconstrained operator inference model
                #     Finf_extract_epp = op_epp[i,j].F
                #     Xinf_epp = integrator(op_epp[i,j].A, op_epp[i,j].F, model.t, Vr' * ic)
                # else
                #     # Integrate the energy-preserving hard equality constraint operator inference model
                #     Finf_extract_ephec = LnL.extractF(op_ephec[i].F, r)
                #     Xinf_ephec = integrator(op_ephec[i].A[1:r, 1:r], Finf_extract_ephec, model.t, Vr' * ic)
                #
                #     # Integrate the energy-preserving soft inequality constraint operator inference model
                #     Finf_extract_epsic = LnL.extractF(op_epsic[i].F, r)
                #     Xinf_epsic = integrator(op_epsic[i].A[1:r, 1:r], Finf_extract_epsic, model.t, Vr' * ic)
                #
                #     # Integrate the energy-preserving unconstrained operator inference model
                #     Finf_extract_epp = LnL.extractF(op_epp[i].F, r)
                #     Xinf_epp = integrator(op_epp[i].A[1:r, 1:r], Finf_extract_epp, model.t, Vr' * ic)
                # end

                X_fom = X_all[i, ct]
                X_fom_ds = X_fom[:, 1:DS:end]   # down-sampled reference trajectory

                # Projection error and relative state error
                PE[ct] = LnL.compProjError(X_fom, Vr)
                SE[:LS][ct] = LnL.compStateError(X_fom_ds, Xinf_LS[:, 1:DS:end], Vr)
                SE[:int][ct] = LnL.compStateError(X_fom_ds, Xint[:, 1:DS:end], Vr)

                # Energy (2-norm of the reconstructed state at each time step)
                En[:LS][ct, :] .= norm.(eachcol(Vr * Xinf_LS), 2)
                En[:int][ct, :] .= norm.(eachcol(Vr * Xint), 2)

                # Constraint violation along the trajectory
                CV[:LS][ct, :] .= LnL.constraintViolation(Xinf_LS, Finf_extract_LS, quad_term)
                CV[:int][ct, :] .= LnL.constraintViolation(Xint, Fint_extract, quad_term)
            end

            # Average over the initial conditions
            PE_all[j, i] = mean(PE)
            for k in evaluated
                SE_all[k][j, i] = mean(SE[k])
                EN_all[k][j, i] = vec(mean(En[k], dims=1))
                CV_all[k][j, i] = vec(mean(CV[k], dims=1))
            end
        end
    end

    result = Dict{String, Any}(Data)  # widen the type so new entries can be added
    result["train_proj_err"] = PE_all
    result["train_state_err"] = SE_all
    result["train_En"] = EN_all
    result["train_CR"] = CR_all
    result["train_mmt"] = MMT_all
    result["train_CV"] = CV_all

    return result
end
;


## Autocorrelation

In [None]:
"""
    analyze_autocorr(op, model, X_all, Vr_all, IC, ro, DS, integrator; burn_factor=10)

Compute the IC-averaged autocorrelation of reconstructed ROM trajectories.

For each parameter index `i in eachindex(model.μs)` and each reduced dimension `r` in `ro`:
1. Integrate the ROM `(op[i].A[1:r, 1:r], LnL.extractF(op[i].F, r))` from every initial
   condition in `IC`, projected with `Vr = Vr_all[i][:, 1:r]`.
2. Lift the trajectory back with `Vr`.
3. Pass the lifted trajectory to `autocorr(·, burn_factor)`.
4. Average the results over the initial conditions.

`X_all` and `DS` are accepted but currently unused.

# Returns
A `length(ro) × model.Pdim` matrix whose `(j, i)` entry is the IC-averaged
autocorrelation vector for reduced dimension `ro[j]` and parameter `i`.
"""
function analyze_autocorr(op, model, X_all, Vr_all, IC, ro, DS, integrator; burn_factor=10)
    # `autocorr` yields a vector per trajectory, so each entry stores a vector
    auto_correlation = Matrix{Vector{Float64}}(undef, length(ro), model.Pdim)

    @showprogress for i in eachindex(model.μs)
        for (j, r) in enumerate(ro)
            Vr = Vr_all[i][:, 1:r]
            # The reduced operators depend only on (i, r), so extract them once
            Ar = op[i].A[1:r, 1:r]
            Fr = LnL.extractF(op[i].F, r)

            ac = map(IC) do ic
                X = integrator(Ar, Fr, model.t, Vr' * ic)
                vec(autocorr(Vr * X, burn_factor))
            end
            auto_correlation[j, i] = mean(ac)
        end
    end

    return auto_correlation
end

"""
    autocorr(X, burn_factor)

Compute the column-averaged circular autocorrelation along the first dimension of `X`.

Each column `x = X[:, k]` of length `N` is mapped to
`C[m] = (1/N) * Σₙ x[n] * x[mod(n + m, N)]` for lags `m = 0, …, N-1`.
This is computed via the correlation theorem as `real(ifft(abs2.(fft(x)))) / N`.

The per-column results are averaged over columns `max(K ÷ burn_factor, 1)` through `K`,
where `K = size(X, 2)`. The leading columns are discarded as burn-in; presumably the
columns are time steps.

# Returns
An `N × 1` matrix with the averaged autocorrelation; lag 0 is in the first row.

# Throws
- `ArgumentError` if `burn_factor` is not positive.
"""
function autocorr(X, burn_factor)
    burn_factor > 0 || throw(ArgumentError("burn_factor must be positive, got $burn_factor"))
    N, K = size(X)
    Cx = zeros(N, K)
    pfft = plan_fft(Cx[:, 1])

    for k in 1:K
        Xhat = pfft * X[:, k]
        # The inverse FFT of |x̂|² is the (unnormalised) circular autocorrelation
        Cx[:, k] = real.(ifft(abs2.(Xhat))) ./ N
    end

    # Burn-in runs along the column axis, so it is derived from K (not N)
    first_col = max(K ÷ burn_factor, 1)
    return mean(view(Cx, :, first_col:K), dims=2)
end


In [None]:
# Least-squares OpInf ROM: autocorrelation on the training initial conditions
RES["train_AC"][:LS] = analyze_autocorr(
    OPS["op_LS"], KSE,
    DATA["Xtr_all"], DATA["Vr"], DATA["IC_train"], DATA["ro"],
    DS, KSE.integrate_FD,
);

In [None]:
# Intrusive ROM: autocorrelation on the training initial conditions.
# This cell fills RES["train_AC"], so it calls the autocorrelation analysis
# (the same function used for the least-squares ROM).
RES["train_AC"][:int] = analyze_autocorr(
    OPS["op_int"], KSE,
    DATA["Xtr_all"], DATA["Vr"], DATA["IC_train"], DATA["ro"],
    DS, KSE.integrate_FD,
);

In [None]:
# EPHEC (energy-preserving hard equality constraint) ROM: relative state error.
# NOTE(review): `analyze_rse` is not defined in the cells shown here, and this cell sits in
# the autocorrelation section. Confirm whether it should instead fill
# RES["train_AC"][:ephec] via analyze_autocorr.
RES["train_state_err"][:ephec] = analyze_rse(
    OPS["op_ephec"], KSE,
    DATA["Xtr_all"], DATA["Vr"], DATA["IC_train"], DATA["ro"],
    DS, KSE.integrate_FD,
);

In [None]:
# Average the relative state error over the parameter axis (dims=2): one value per reduced dimension
mean_LS_state_err = mean(RES["train_state_err"][:LS]; dims=2)
mean_int_state_err = mean(RES["train_state_err"][:int]; dims=2)
# mean_ephec_state_err = mean(RES["train_state_err"][:ephec]; dims=2)
# mean_epsic_state_err = mean(RES["train_state_err"][:epsic]; dims=2)
# mean_epp_state_err = mean(RES["train_state_err"][:epp]; dims=2)

# Error versus reduced dimension r, one series per ROM
plot(
    DATA["ro"], mean_LS_state_err;
    c=:black, marker=(:circle, 3.5, :black), label=L"\mathrm{opinf}",
)
plot!(
    DATA["ro"], mean_int_state_err;
    c=:orange, marker=(:cross, 8, :orange), label=L"\mathrm{intrusive}",
)
# plot!(DATA["ro"], mean_ephec_state_err, c=:blue, markerstrokecolor=:blue, marker=(:rect, 3), ls=:dash, lw=2, label=L"\mathrm{ephec}\rm{-}\mathrm{opinf}")
# plot!(DATA["ro"], mean_epsic_state_err, c=:purple, markerstrokecolor=:purple, marker=(:dtriangle, 5), ls=:dot, label=L"\mathrm{epsic}\rm{-}\mathrm{opinf}")
# plot!(DATA["ro"], mean_epp_state_err, c=:red, markerstrokecolor=:red, marker=(:star, 4), lw=1, ls=:dash, label=L"\mathrm{epp}\rm{-}\mathrm{opinf}")
# Optional tick control:
# tmp = log10.(mean_int_state_err)
# yticks!([10.0^i for i in floor(minimum(tmp))-1:ceil(maximum(tmp))+1])
# xticks!(rmin:rmax)

# Log-scaled y axis, labels and font sizes in a single attribute update
plot!(
    yscale=:log10, majorgrid=true, legend=:bottomleft,
    xlabel=L"\mathrm{reduced~model~dimension~} r",
    ylabel=L"\mathrm{average~~relative~~state~~error}",
    guidefontsize=16, tickfontsize=13, legendfontsize=13,
)
