In [None]:
using LinearAlgebra
using Printf
using SparseArrays

using UnPack
using OrdinaryDiffEq, DiffEqCallbacks
using Roots
using StaticArrays
using SummationByPartsOperators

using PyCall, LaTeXStrings; import PyPlot; plt=PyPlot
# Inset-axes helpers. The legacy `mpl_toolkits.axes_grid` package is deprecated and
# has been removed from recent Matplotlib releases; `mpl_toolkits.axes_grid1`
# provides the same `inset_locator` module.
inset_locator = pyimport("mpl_toolkits.axes_grid1.inset_locator")

# Property cyclers built from the Okabe–Ito colour-blind-friendly colours:
# `line_cycler` pairs each colour with a line style (installed as the default axes
# cycle below); `marker_cycler` pairs the same colours with markers and no lines.
cycler = pyimport("cycler").cycler
line_cycler   = (cycler(color=["#E69F00", "#56B4E9", "#009E73", "#0072B2", "#D55E00", "#CC79A7", "#F0E442"]) +
                 cycler(linestyle=["-", "--", "-.", ":", "-", "--", "-."]))
marker_cycler = (cycler(color=["#E69F00", "#56B4E9", "#009E73", "#0072B2", "#D55E00", "#CC79A7", "#F0E442"]) +
                 cycler(linestyle=["none", "none", "none", "none", "none", "none", "none"]) +
                 cycler(marker=["4", "2", "3", "1", "+", "x", "."]))

# Global Matplotlib style. `usetex=true` renders all text with an external LaTeX
# installation, which must provide the packages listed in the preamble.
plt.rc("axes", prop_cycle=line_cycler)
plt.rc("text", usetex=true)
plt.rc("text.latex", preamble="\\usepackage{newpxtext}\\usepackage{newpxmath}\\usepackage{commath}\\usepackage{mathtools}")
plt.rc("font", family="serif", size=18.)
plt.rc("savefig", dpi=100)
plt.rc("legend", loc="best", fontsize="medium", fancybox=true, framealpha=0.5)
plt.rc("lines", linewidth=2.5, markersize=10, markeredgewidth=2.5)


In [None]:
# Relaxation step, used as the `affect!` of a `DiscreteCallback` after every step.
# The update `unew - uold` is scaled by a factor γ such that the quadratic functional
# `relaxation_functional` (∫u² via `integrate`) of the new state equals its value at
# the previous state `uold`. The time is moved to `told + γ * (tnew - told)` unless the
# step ended on the next tstop. The integration is terminated (with a warning) if no
# root γ is bracketed in [1/2, 3/2] or if γ is smaller than machine epsilon.
function relaxation!(integrator)
    told = integrator.tprev
    uold = integrator.uprev
    tnew = integrator.t
    unew = integrator.u  # alias of the integrator state (no copy); modified in place below

    terminate_integration = false
    
    # General approach using a scalar root finding procedure
    γ = one(tnew)
    # Search bracket [1/2, 3/2] around the unrelaxed value γ = 1.
    γlo = one(γ) / 2
    γhi = 3*one(γ) / 2
    energy_old = relaxation_functional(uold, integrator.p)
    # Product of the energy deviations at both bracket ends; a negative value means a
    # sign change, i.e. a root γ lies strictly inside the bracket.
    sign_condition = (relaxation_functional(γlo, unew, uold, integrator.p)-energy_old) * (relaxation_functional(γhi, unew, uold, integrator.p)-energy_old)
    if iszero(sign_condition)
        # everything is fine
        # NOTE(review): zero also occurs when only one endpoint matches energy_old
        # exactly; γ = 1 is used in every zero case (e.g. unew == uold).
        γ = one(tnew)
    elseif sign_condition > 0
        # No sign change: keep γ = 1 for this step and stop afterwards.
        terminate_integration = true
        @warn "Terminating integration because no solution γ can be found."
    else
        γ = find_zero(g -> relaxation_functional(g, unew, uold, integrator.p)-energy_old, (γlo, γhi), Roots.AlefeldPotraShi())
    end
    
    if γ < eps(typeof(γ))
        terminate_integration = true
        @warn "Terminating integration because γ=$γ is too small."
    end

    # Relaxed state uold + γ (unew - uold), written into integrator.u in place.
    @. unew = uold + γ * (unew - uold)
    DiffEqBase.set_u!(integrator, unew)
    # Keep the time unchanged when the step landed on the next tstop (the saveat
    # points are passed as tstops), presumably so that stop is still hit exactly.
    if !(tnew ≈ first(integrator.opts.tstops))
        tγ = told + γ * (tnew - told)
        # doesn't work always, see https://github.com/SciML/OrdinaryDiffEq.jl/issues/1353
        # DiffEqBase.set_t!(integrator, tγ)
        integrator.t = tγ
    end

    if terminate_integration
        terminate!(integrator)
    end

    nothing
end


"""
    quadratic_projection!(integrator)

Projection callback (`affect!` of a `DiscreteCallback`). Rescale the current state
`integrator.u` in place by a scalar factor so that its discrete quadratic invariant
`∫u²` matches that of the initial condition `integrator.sol.prob.u0`. The invariant
is computed with `integrate` and the operator `D1` from the problem parameters.

The projection always targets the value of the invariant at the *initial* time, not
at the previous step. If the current state has a zero quadratic invariant, it cannot
be rescaled to a different value, and it is left unchanged. Otherwise the factor
would be `Inf`/`NaN` and would corrupt the state.
"""
function quadratic_projection!(integrator)
    u0 = integrator.sol.prob.u0
    unew = integrator.u
    @unpack D1 = integrator.sol.prob.p

    energy_target = integrate(u -> u^2, u0, D1)
    energy_new = integrate(u -> u^2, unew, D1)
    # Scaling a zero state cannot change its invariant; avoid dividing by zero.
    if iszero(energy_new)
        return nothing
    end

    factor = sqrt(energy_target / energy_new)
    unew .= factor .* unew

    DiffEqBase.set_u!(integrator, unew)
    return nothing
end


# Quadratic functional ∫ ((1 - γ) uold + γ unew)² of the relaxed state, evaluated with
# the quadrature associated with `param.D1`. The convex combination is assembled in
# `param.tmp1`, which is overwritten.
function relaxation_functional(γ, unew, uold, param)
    buffer = param.tmp1
    buffer .= (1 .- γ) .* uold .+ γ .* unew
    return integrate(v -> v^2, buffer, param.D1)
end

# Quadratic functional ∫u² of the state `u`, using the quadrature of `param.D1`.
relaxation_functional(u, param) = integrate(v -> v^2, u, param.D1)


"""
    save_func_KdV(u, t, integrator)

Saving function for a `SavingCallback`. Print a progress dot and return
`SVector(∫u, ∫u², sqrt(∫(usol(t, x) - u)²))`, where the integrals use the
quadrature of `integrator.p.D1`, and `usol` is the reference solution stored in
`integrator.p`. The scratch array `integrator.p.tmp1` is overwritten.
"""
function save_func_KdV(u, t, integrator)
    p = integrator.p
    D1 = p.D1
    print(".")

    mass = integrate(u, D1)
    energy = integrate(v -> v^2, u, D1)

    # pointwise squared deviation from the reference solution on the grid
    buffer = p.tmp1
    nodes = grid(D1)
    buffer .= (p.usol.(t, nodes) .- u) .^ 2

    return SVector(mass, energy, sqrt(integrate(buffer, D1)))
end


"""
    KdV_nonlinear!(du, u, param, t)

Nonlinear part of the KdV semidiscretization in split (conservative) form:
`du = -(u .* (D1 * u) .+ D1 * (u.^2)) ./ 3`, a discretization of `-u u_x`.
The arrays `param.tmp1` and `param.tmp2` are used as scratch storage.
"""
function KdV_nonlinear!(du, u, param, t)
    D1 = param.D1
    work1 = param.tmp1
    work2 = param.tmp2

    # work2 ← D1 * u² (flux form)
    @. work1 = u^2
    mul!(work2, D1, work1)
    # work1 ← D1 * u (advective form)
    mul!(work1, D1, u)
    # split form combining both
    @. du = -(u * work1 + work2) / 3

    return nothing
end


"""
    solve_ode_KdV(usol, D1, D3, tspan, alg, tol, dt, adaptive; kwargs...)

Solve the periodic KdV semidiscretization `u_t = -D3 u + KdV_nonlinear!(u)` three
times with the same `SplitODEProblem`:
- with relaxation (`relaxation!` after every step),
- with projection onto the initial quadratic invariant (`quadratic_projection!`),
- and without either ("baseline", labelled "non-conservative").

# Arguments
- `usol(t, x)`: reference solution; provides the initial condition and the errors.
- `D1`, `D3`: first- and third-derivative operators. `D1` also defines the grid and
  the quadrature used by `integrate`.
- `tspan`: time interval. 100 equally spaced points in it are used as `saveat`
  and `tstops`.
- `alg`, `tol`, `dt`, `adaptive`: passed to `solve` (with `abstol = reltol = tol`).
- `kwargs...`: forwarded to every `solve` call.

Prints timings and final errors, and creates four figures.

# Returns
A `NamedTuple` with the three solutions and the three `SavedValues`. The saved
values hold the time points and `SVector(∫u, ∫u², L² error)`. The tuple also holds
the figures `fig_u`, `fig_invariants`, `fig_error` and `fig_legend`.
"""
function solve_ode_KdV(usol, D1, D3, tspan, alg, tol, dt, adaptive; kwargs...)

    x = grid(D1)
    u0 = usol.(tspan[1], x)
    tmp1 = similar(u0); tmp2 = similar(tmp1)
    param = (; D1, D3, tmp1, tmp2, usol)
    
    linear_part = DiffEqArrayOperator(-D3)
    nonlin_part = KdV_nonlinear!
    ode = SplitODEProblem(linear_part, nonlin_part, u0, tspan, param)
    
    # One SavingCallback per run so each run records its own time series.
    saveat = range(tspan..., length=100)
    saved_values_baseline = SavedValues(eltype(D1), SVector{3,eltype(D1)})
    saving_baseline = SavingCallback(save_func_KdV, saved_values_baseline, saveat=saveat)
    saved_values_relaxation = SavedValues(eltype(D1), SVector{3,eltype(D1)})
    saving_relaxation = SavingCallback(save_func_KdV, saved_values_relaxation, saveat=saveat)
    relaxation = DiscreteCallback((u,t,integrator) -> true, relaxation!, save_positions=(false,false))
    cb_baseline = CallbackSet(saving_baseline)
    cb_relaxation = CallbackSet(relaxation, saving_relaxation)
    saved_values_projection = SavedValues(eltype(D1), SVector{3,eltype(D1)})
    saving_projection = SavingCallback(save_func_KdV, saved_values_projection, saveat=saveat)
    projection = DiscreteCallback((u,t,integrator) -> true, quadratic_projection!, save_positions=(false,false))
    cb_projection = CallbackSet(projection, saving_projection)
    
    println("Relaxation:")
    @time sol_relaxation = solve(ode, alg, abstol=tol, reltol=tol, dt=dt, adaptive=adaptive, save_everystep=false,
        callback=cb_relaxation, tstops=saveat, saveat=saveat, maxiters=10^8; kwargs...)
    flush(stdout)
    println("Projection:")
    @time sol_projection = solve(ode, alg, abstol=tol, reltol=tol, dt=dt, adaptive=adaptive, save_everystep=false,
        callback=cb_projection, tstops=saveat, saveat=saveat, maxiters=10^8; kwargs...)
    flush(stdout)
    println("Baseline:")
    @time sol_baseline = solve(ode, alg, abstol=tol, reltol=tol, dt=dt, adaptive=adaptive, save_everystep=false,
        callback=cb_baseline, tstops=saveat, saveat=saveat, maxiters=10^8; kwargs...)
    println()
    flush(stdout)

    # discrete L² norm induced by the quadrature of D1
    l2norm(v) = integrate(u->u^2, v, D1) |> sqrt

    unum_baseline   = sol_baseline[end]
    unum_relaxation = sol_relaxation[end]
    unum_projection = sol_projection[end]
    uana = usol.(tspan[end], x)
    @printf("Error in u (baseline):   %.3e\n", l2norm(unum_baseline - uana))
    @printf("Error in u (relaxation): %.3e\n", l2norm(unum_relaxation - uana))
    @printf("Error in u (projection): %.3e\n", l2norm(unum_projection - uana))
    @printf("Difference of baseline and relaxation in u: %.3e\n", 
        l2norm(unum_baseline - unum_relaxation))

    sleep(0.1)
    fig_u, ax = plt.subplots(1, 1)
    plt.plot(x, unum_baseline, label="non-conservative")
    plt.plot(x, unum_relaxation, label="relaxation")
    plt.plot(x, unum_projection, label="projection")
    plt.plot(x, uana, ":", color="gray", label="analytical")
    plt.xlabel(L"x"); plt.ylabel(L"u")
    plt.legend(loc="center left", bbox_to_anchor=(1.0, 0.5))

    t_baseline = saved_values_baseline.t
    t_relaxation = saved_values_relaxation.t
    t_projection = saved_values_projection.t
    mass_baseline      = map(x->x[1], saved_values_baseline.saveval)
    mass_relaxation    = map(x->x[1], saved_values_relaxation.saveval)
    mass_projection    = map(x->x[1], saved_values_projection.saveval)
    quadratic_baseline   = map(x->x[2], saved_values_baseline.saveval)
    quadratic_relaxation = map(x->x[2], saved_values_relaxation.saveval)
    quadratic_projection = map(x->x[2], saved_values_projection.saveval)

    # Each series is plotted against the time points of its own run; the runs may
    # have different lengths if one of them terminates early.
    fig_invariants, ax = plt.subplots(1, 1)
    # `linthresh` replaces the `linthreshy` keyword removed from newer Matplotlib.
    ax.set_yscale("symlog", linthresh=1.0e-14)
    plt.plot(t_baseline,   mass_baseline   .- mass_baseline[1], 
                label=L"$\int u$ (non-conservative)",
                color="#E69F00", linestyle="-")
    plt.plot(t_relaxation, mass_relaxation .- mass_relaxation[1], 
                label=L"$\int u$ (relaxation)",
                color="#56B4E9", linestyle="-")
    plt.plot(t_projection, mass_projection .- mass_projection[1], 
                label=L"$\int u$ (projection)",
                color="#009E73", linestyle="-")
    plt.plot(t_baseline,   quadratic_baseline   .- quadratic_baseline[1], 
                label=L"$\int u^2$ (non-conservative)",
                color="#E69F00", linestyle="--")
    plt.plot(t_relaxation, quadratic_relaxation .- quadratic_relaxation[1], 
                label=L"$\int u^2$ (relaxation)",
                color="#56B4E9", linestyle="--")
    plt.plot(t_projection, quadratic_projection .- quadratic_projection[1], 
                label=L"$\int u^2$ (projection)",
                color="#009E73", linestyle="--")
    plt.xlabel(L"t"); plt.ylabel("Change of Invariants")
    plt.legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
    
    error_u_baseline   = map(x->x[3], saved_values_baseline.saveval)
    error_u_relaxation = map(x->x[3], saved_values_relaxation.saveval)
    error_u_projection = map(x->x[3], saved_values_projection.saveval)

    fig_error, ax = plt.subplots(1, 1)
    ax.set_xscale("log")
    ax.set_yscale("log")
    plt.plot(t_baseline,   error_u_baseline, label="non-conservative")
    plt.plot(t_relaxation, error_u_relaxation, label="relaxation")
    plt.plot(t_projection, error_u_projection, label="projection")
    plt.xlabel(L"t"); plt.ylabel("Error")
    plt.legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
    
    # stand-alone legend figure built from the solution plot's entries
    fig_legend = plt.figure()
    handles, labels = fig_u.axes[1].get_legend_handles_labels()
    plt.figlegend(handles, labels, loc="center", ncol=3)
        
    return (;  sol_baseline, sol_relaxation, sol_projection,
        saved_values_baseline, saved_values_relaxation, saved_values_projection,
        fig_u, fig_invariants, fig_error, fig_legend)
end


# traveling wave solution
get_xmin() = 0.0
get_xmax() = 80.0
get_c() = 2 / 3

"""
    usol(t, x)

Exact soliton (traveling-wave) solution of the KdV equation `u_t + u u_x + u_xxx = 0`.
The wave has speed `c = get_c()` and amplitude `A = 3c`:
`u = A sech²(√(3A)/6 · ξ)`, where `√(3A)/6 = √c/2`. It is extended periodically to
the domain `[get_xmin(), get_xmax()]`, and its peak starts at the domain centre.
"""
function usol(t, x)
    xmin = get_xmin()
    xmax = get_xmax()
    L = xmax - xmin
    μ = L / 2
    c = get_c()
    A = 3 * c
    # Signed distance from the soliton centre, wrapped into [-μ, μ) so that the
    # periodic extension of the symmetric profile is continuous for any xmin.
    x_t = mod(x - c*t - xmin, L) - μ

    return A / cosh(sqrt(3*A) / 6 * x_t)^2
end

In [None]:
# Problem setup: periodic domain [xmin, xmax] with N grid points. The reference
# solution should be (numerically) periodic, so usol at xmin and xmax are printed as a check.
@show get_c()
@show xmin = get_xmin()
@show xmax = get_xmax()
@show usol(0., xmin)
@show usol(0., xmax)
@show N = 2^8
# Step size proportional to the grid spacing; with `adaptive = true` it is only the
# initial step passed to `solve`.
@show dt = 0.1 * (get_xmax() - get_xmin()) / N
# (xmax - xmin) / c is the time the soliton needs to travel once around the periodic
# domain; the final time is 30 such periods plus one third of a period.
@show tspan = (0.0, (xmax-xmin)/(3*get_c()) + 30*(xmax-xmin)/get_c())
flush(stdout)

# Periodic derivative operators of derivative order 1 and 3 with accuracy order 8
# (positional arguments of SummationByPartsOperators' periodic_derivative_operator).
D1 = periodic_derivative_operator(1, 8, xmin, xmax, N)
# D3 is converted to a sparse matrix; presumably for the linear solves of the implicit
# part of KenCarp4 — TODO confirm.
D3 = periodic_derivative_operator(3, 8, xmin, xmax, N) |> sparse

tol = 1.0e-7
adaptive = true
results = solve_ode_KdV(usol, D1, D3, tspan, KenCarp4(), tol, dt, adaptive);

In [None]:
fig, ax = plt.subplots(1, 1)
x = grid(D1)
# Six snapshots of the projection run at (rounded) equally spaced save indices,
# shown on a small vertical range to reveal small-amplitude artifacts.
let sol = results.sol_projection
    for idx in round.(Int, range(1, length(sol.u), length=6))
        snapshot_label = @sprintf("\$t = %.1f\$", sol.t[idx])
        plt.plot(x, sol.u[idx], label=snapshot_label)
    end
end
plt.xlabel(L"x"); plt.ylabel(L"u")
plt.legend(loc="center left", bbox_to_anchor=(1.0, 0.5))
plt.xlim(get_xmin(), get_xmax())
plt.ylim(-5.0e-3, 5.0e-3)
plt.savefig(joinpath("..", "figures", "kdv_projection_zoom.pdf"), bbox_inches="tight")

In [None]:
fig, ax = plt.subplots(1, 1)
x = grid(D1)
# Final states of the three runs compared with the analytical solution at the end time.
for (sol, name) in ((results.sol_baseline, "non-conservative"),
                    (results.sol_relaxation, "relaxation"),
                    (results.sol_projection, "projection"))
    plt.plot(x, sol.u[end], label=name)
end
plt.plot(x, usol.(last(tspan), x), ":", color="gray", label="analytical")
plt.xlabel(L"x"); plt.ylabel(L"u")
plt.xlim(get_xmin(), get_xmax())
plt.savefig(joinpath("..", "figures", "kdv_solution.pdf"), bbox_inches="tight")

# Separate figure containing only the legend of the plot above.
plt.figure()
handles, labels = fig.axes[1].get_legend_handles_labels()
plt.figlegend(handles, labels, loc="center", ncol=4)
plt.savefig(joinpath("..", "figures", "kdv_legend.pdf"), bbox_inches="tight")

In [None]:
# Error growth in time (log-log) of the three runs, with reference slopes.
fig, ax = plt.subplots(1, 1)
ax.set_xscale("log")
ax.set_yscale("log")
# Pair each error series with the time points recorded by the same SavingCallback,
# so series and times always belong to the same run (and have equal lengths even if
# one run terminated early).
for (saved, name) in ((results.saved_values_baseline,   "non-conservative"),
                      (results.saved_values_relaxation, "relaxation"),
                      (results.saved_values_projection, "projection"))
    plt.plot(saved.t, map(v -> v[3], saved.saveval), label=name)
end
plt.xlabel(L"t"); plt.ylabel("Error")

# Gray dotted reference lines for O(t²) and O(t) error growth.
t = [2.0e2, 4.0e3]
ax.plot(t, 2.0e-7 .* t.^2, ":", color="gray")
ax.annotate(L"\mathcal{O}(t^{2})", (1.0e3, 1.0), color="gray")
ax.plot(t, 5.0e-6 .* t.^1, ":", color="gray")
ax.annotate(L"\mathcal{O}(t^{1})", (1.0e3, 1.5e-2), color="gray")

plt.savefig(joinpath("..", "figures", "kdv_error.pdf"), bbox_inches="tight")