# Linear dispersive equation with BBM-type dispersive term

In [None]:
include("setup.jl")

# Discretization

In [None]:
"""
    relaxation_functional(γ, unew, uold, param)

Evaluate the quadratic functional `∫ (u² - u ⋅ D2 u)` at the affine combination
`u = (1 - γ) * uold + γ * unew`, where `D` and `D2` are taken from `param`.

`param.tmp2` receives the combined state and `param.tmp1` is used as scratch space;
both buffers are overwritten. Returns the value of `integrate(…, param.D)`.
"""
function relaxation_functional(γ, unew, uold, param)
    D, D2 = param.D, param.D2
    ucomb, work = param.tmp2, param.tmp1

    # combine the old and the new state with weight γ
    @. ucomb = (1 - γ) * uold + γ * unew
    mul!(work, D2, ucomb)
    # pointwise integrand, written over the D2*u buffer
    @. work = ucomb^2 - ucomb * work
    return integrate(work, D)
end
"""
    relaxation_functional(u, param)

Evaluate the quadratic functional `∫ (u² - u ⋅ D2 u)` at the state `u`, with `D` and
`D2` taken from `param`. `param.tmp1` is used as scratch space and is overwritten.
"""
function relaxation_functional(u, param)
    D, D2, work = param.D, param.D2, param.tmp1

    mul!(work, D2, u)
    work .= u .^ 2 .- u .* work
    return integrate(work, D)
end


"""
    save_func_linear_periodic(u, t, integrator)

Saving function for a `SavingCallback`. Prints a progress dot and returns
`SVector(mass, quadratic, error_u)` for the state `u` at time `t`:

- `mass`: `integrate(u, D)`,
- `quadratic`: `relaxation_functional(u, integrator.p)`, i.e. `∫ (u² - u ⋅ D2 u)`,
- `error_u`: discrete L2 distance between `u` and the analytical solution `usol` at `t`.

`integrator.p.tmp1` is used as scratch space and is overwritten.
"""
function save_func_linear_periodic(u, t, integrator)
    param = integrator.p
    @unpack D, tmp1, usol = param
    print(".")

    mass = integrate(u, D)

    # reuse the shared helper instead of duplicating the functional's computation
    quadratic = relaxation_functional(u, param)

    # discrete L2 error against the analytical solution
    # NOTE(review): the domain is passed as (x[1], -x[1]), i.e. assumed symmetric about 0
    x = grid(D)
    tmp1 .= (usol.(t, x, x[1], -x[1]) .- u).^2
    error_u = sqrt(integrate(tmp1, D))

    return SVector(mass, quadratic, error_u)
end


"""
    linear_periodic!(du, u, param, t)

In-place right-hand side for `ODEProblem`: solve `(I - D2) du = -D u` for `du`, using
the operator `D` and the (factorized) `param.invImD2` for `I - D2`.

`param.tmp1` is used as scratch space and is overwritten; `t` is unused because the
equation is autonomous. Returns `nothing`.
"""
function linear_periodic!(du, u, param, t)
    @unpack D, invImD2, tmp1 = param

    # conservative semidiscretization: du = -(I - D2) \ (D * u)
    mul!(tmp1, D, u)
    @. tmp1 = -tmp1
    ldiv!(du, invImD2, tmp1)

    return nothing
end

"""
    solve_ode_linear_periodic(usol, D, D2, tspan, alg, tol, dt, adaptive)

Integrate the semidiscretization `linear_periodic!` on `tspan` twice with the time
integration method `alg`:

- once with the `relaxation!` callback applied after every step ("conservative"),
- once without it (baseline, "non-conservative").

Each run is compared with the analytical solution `usol`.

# Arguments
- `usol`: analytical solution, called as `usol(t, x, xmin, xmax)`.
- `D`, `D2`: derivative operators; `grid(D)` provides the nodes, and `I - D2` is
  factorized when `D2` is an `AbstractMatrix`.
- `tspan`: `(t0, tend)`.
- `alg`, `tol`, `dt`, `adaptive`: forwarded to `solve` (`tol` is used as both
  `abstol` and `reltol`).

# Output
Prints the discrete L2 errors of both runs at `tspan[end]` and their mutual
difference. Returns a named tuple with the following fields:

- `sol_baseline`, `sol_relaxation`: the two solutions.
- `saved_values_baseline`, `saved_values_relaxation`: `(mass, quadratic functional,
  L2 error)` recorded at 100 equispaced times.
- `fig_u`, `fig_invariants`, `fig_error`, `fig_legend`: the generated figures.

NOTE(review): `usol` is evaluated with `(xmin, xmax) = (x[1], -x[1])`, i.e. the
domain is assumed to be symmetric about zero. Confirm before using other domains.
"""
function solve_ode_linear_periodic(usol, D, D2, tspan, alg, tol, dt, adaptive)
    # factorize I - D2 once if it is a matrix; otherwise use the operator I - D2 directly
    invImD2 = isa(D2, AbstractMatrix) ? factorize(I - D2) : I - D2

    x = grid(D)
    u0 = usol.(tspan[1], x, x[1], -x[1])
    tmp1 = similar(u0); tmp2 = similar(tmp1)
    param = (D=D, D2=D2, invImD2=invImD2, tmp1=tmp1, tmp2=tmp2, usol=usol)

    ode = ODEProblem(linear_periodic!, u0, tspan, param)

    # record (mass, quadratic functional, L2 error) at 100 equispaced times for each run
    saveat = range(tspan..., length=100)
    saved_values_baseline = SavedValues(eltype(D), SVector{3,eltype(D)})
    saving_baseline = SavingCallback(save_func_linear_periodic, saved_values_baseline, saveat=saveat)
    saved_values_relaxation = SavedValues(eltype(D), SVector{3,eltype(D)})
    saving_relaxation = SavingCallback(save_func_linear_periodic, saved_values_relaxation, saveat=saveat)
    # apply relaxation! after every step without creating extra save points
    relaxation = DiscreteCallback((u,t,integrator) -> true, relaxation!, save_positions=(false,false))
    cb_baseline = CallbackSet(saving_baseline)
    cb_relaxation = CallbackSet(relaxation, saving_relaxation)

    # both runs share all solver settings; only the callbacks differ
    run_solver(callback) = solve(ode, alg; abstol=tol, reltol=tol, dt=dt, adaptive=adaptive,
        save_everystep=false, callback=callback, tstops=saveat, saveat=saveat, maxiters=10^8)
    sol_relaxation = run_solver(cb_relaxation)
    flush(stdout)
    sol_baseline = run_solver(cb_baseline)
    flush(stdout)

    # `sol.u[end]` is the final state; plain `sol[end]` linear indexing of solutions is
    # deprecated/changed in recent SciML releases
    unum_baseline   = sol_baseline.u[end]
    unum_relaxation = sol_relaxation.u[end]
    uana = usol.(tspan[end], x, x[1], -x[1])
    l2norm(v) = sqrt(integrate(w -> w^2, v, D))
    @printf("Error in u (baseline):   %.3e\n", l2norm(unum_baseline - uana))
    @printf("Error in u (relaxation): %.3e\n", l2norm(unum_relaxation - uana))
    @printf("Difference of baseline and relaxation in u: %.3e\n",
        l2norm(unum_baseline - unum_relaxation))

    # NOTE(review): presumably gives the printed output time to appear before plotting
    sleep(0.1)
    fig_u, ax = plt.subplots(1, 1)
    plt.plot(x, unum_baseline, label="non-conservative")
    plt.plot(x, unum_relaxation, label="conservative")
    plt.plot(x, uana, ":", color="gray", label="analytical")
    plt.xlabel(L"x"); plt.ylabel(L"u")

    # extract the i-th saved quantity as a time series
    saved_component(saved_values, i) = map(v -> v[i], saved_values.saveval)
    t_baseline = saved_values_baseline.t
    t_relaxation = saved_values_relaxation.t
    mass_baseline        = saved_component(saved_values_baseline, 1)
    mass_relaxation      = saved_component(saved_values_relaxation, 1)
    quadratic_baseline   = saved_component(saved_values_baseline, 2)
    quadratic_relaxation = saved_component(saved_values_relaxation, 2)

    fig_invariants, ax = plt.subplots(1, 1)
    # `linthresh` replaces `linthreshy`, deprecated in Matplotlib 3.3 and removed in 3.5
    ax.set_yscale("symlog", linthresh=1.0e-14)
    plt.plot(t_baseline,   mass_baseline   .- mass_baseline[1],
                label=L"$\int u$ (non-conservative)",
                color="#E69F00", linestyle="-")
    plt.plot(t_relaxation, mass_relaxation .- mass_relaxation[1],
                label=L"$\int u$ (conservative)",
                color="#56B4E9", linestyle="-")
    plt.plot(t_baseline,   quadratic_baseline   .- quadratic_baseline[1],
                label=L"$\int (u^2 + (\partial_x u)^2)$ (non-conservative)",
                color="#E69F00", linestyle="--")
    plt.plot(t_relaxation, quadratic_relaxation .- quadratic_relaxation[1],
                label=L"$\int (u^2 + (\partial_x u)^2)$ (conservative)",
                color="#56B4E9", linestyle="--")
    plt.xlabel(L"t"); plt.ylabel("Change of Invariants")
    plt.legend(loc="center left", bbox_to_anchor=(1.0, 0.5))

    error_u_baseline   = saved_component(saved_values_baseline, 3)
    error_u_relaxation = saved_component(saved_values_relaxation, 3)

    fig_error, ax = plt.subplots(1, 1)
    ax.set_xscale("log")
    ax.set_yscale("log")
    plt.plot(t_baseline,   error_u_baseline, label=L"Error of $u$ (non-conservative)")
    plt.plot(t_relaxation, error_u_relaxation, label=L"Error of $u$ (conservative)")
    plt.xlabel(L"t"); plt.ylabel("Error")

    # a separate figure holding only the legend of the solution plot
    fig_legend = plt.figure()
    handles, labels = fig_u.axes[1].get_legend_handles_labels()
    plt.figlegend(handles, labels, loc="center", ncol=3)

    return (; sol_baseline, sol_relaxation, saved_values_baseline, saved_values_relaxation,
        fig_u, fig_invariants, fig_error, fig_legend)
end


# Error growth of traveling wave solutions

In [None]:
# Traveling-wave solution on the periodic domain [xmin, xmax]
xmin = -1.0
xmax = -xmin

"""
    get_c()

Return the wave speed `c = 1 / (1 + 4π²)` of the traveling wave `usol`.
"""
get_c() = 1 / (1 + 4 * π^2)

"""
    usol(t, x, xmin, xmax)

Evaluate the traveling wave `sin(K * (x - c t))` with speed `c = get_c()` and wave
number `K = sqrt(1/c - 1)` (so that `c = 1 / (1 + K²)`; `K ≈ 2π` for this `c`).
The shifted coordinate `x - c t` is wrapped periodically back into the interval
from `xmin` to `xmax`.
"""
function usol(t, x, xmin, xmax)
    speed = get_c()
    wavenumber = sqrt(1 / speed - 1)
    # move back along the characteristic and wrap into the periodic domain
    ξ = xmin + mod(x - speed * t - xmin, xmax - xmin)
    return sin(wavenumber * ξ)
end


In [None]:
# Print the setup; usol evaluated at xmin and at xmax is shown as a periodicity check.
println("c = ", get_c())
@show xmin
@show xmax
@show usol(0., xmin, xmin, xmax)
@show usol(0., xmax, xmin, xmax)
@show N = 2^6
@show dt = 1.0e-1 * (xmax - xmin) / N
flush(stdout)

# Final time: (xmax - xmin) / c is the time the wave needs to travel once across the
# periodic domain, so the active line runs one third of a traversal plus 50_000 full
# traversals. The commented-out lines are shorter alternatives (1 and 10 traversals).
# tspan = (0.0, (xmax-xmin)/(3*get_c()) + 1*(xmax-xmin)/get_c())
# tspan = (0.0, (xmax-xmin)/(3*get_c()) + 10*(xmax-xmin)/get_c())
tspan = (0.0, (xmax-xmin)/(3*get_c()) + 50_000*(xmax-xmin)/get_c())

# tol is used as both abstol and reltol; with adaptive = true, dt above only serves as
# the initial step size of the solver.
tol = 1.0e-5
adaptive = true
# Fourier derivative operator on [xmin, xmax] built with N = 2^6; D2 = D^2 is used as
# the second-derivative operator.
D = fourier_derivative_operator(xmin, xmax, N)
D2 = D^2

# Runs the relaxation and the baseline integrations with Tsit5 and creates the figures.
results = solve_ode_linear_periodic(usol, D, D2, tspan, Tsit5(), tol, dt, adaptive);

In [None]:
# Add an O(t) reference slope to the error plot and export the figures as PDF.
ax = results.fig_error.axes[1]
t = [1.0e3, 5.0e6]
ax.plot(t, 1.0e-7 .* t.^1, ":", color="gray")
ax.annotate(L"\mathcal{O}(t^{1})", (1.0e5, 5.0e-3), color="gray")

# savefig does not create missing directories, so make sure the target exists first.
figdir = joinpath(dirname(@__DIR__), "figures")
mkpath(figdir)

results.fig_error.savefig(joinpath(figdir, "linear_error.pdf"), bbox_inches="tight")
results.fig_u.savefig(joinpath(figdir, "linear_solution_u.pdf"), bbox_inches="tight")
results.fig_legend.savefig(joinpath(figdir, "linear_solution_legend.pdf"), bbox_inches="tight")

In [None]:
# Copy every line currently on the error axes into a fresh log-log figure and draw the
# O(t) reference slope on it.
fig, ax = plt.subplots(1, 1)
ax.set_xscale("log")
ax.set_yscale("log")
foreach(results.fig_error.axes[1].lines) do curve
    ax.plot(curve.get_xdata(), curve.get_ydata())
end
t = [1.0e3, 5.0e6]
ax.plot(t, 1.0e-7 .* t.^1, ":", color="gray")
ax.annotate(L"\mathcal{O}(t^{1})", (1.0e5, 5.0e-3), color="gray")