# Shallow water equations with variable bathymetry and periodic BCs

In [None]:
include("setup.jl")

## Discretization in 2D

In [None]:
"""
    local_energy(h, hvx, hvy, g, b)

Return the pointwise energy density `½ ((hvx² + hvy²) / h + g h²) + g h b` for the
water height `h`, the discharges `hvx`, `hvy`, the gravitational constant `g`, and
the bathymetry `b`. Scalar function; broadcast it over grid arrays.
"""
function local_energy(h, hvx, hvy, g, b)
    momentum_sq_over_h = (hvx^2 + hvy^2) / h
    return 0.5 * (momentum_sq_over_h + g * h^2) + g * h * b
end

"""
    relaxation_functional(γ, u_new, u_old, param)

Return the total energy (integral of `local_energy` computed with `integrate` and the
operator `D`) of the state `(1-γ) u_old + γ u_new`. Each component of `u_new`/`u_old`
(`.x[1]` = h, `.x[2]` = hvx, `.x[3]` = hvy) is blended pointwise. Overwrites the
scratch array `param.tmp1`.
"""
function relaxation_functional(γ, u_new, u_old, param)
    @unpack D, g, b, tmp1 = param
    h_new, hvx_new, hvy_new = u_new.x
    h_old, hvx_old, hvy_old = u_old.x

    # pointwise convex combination between the old and the new state
    blend(old, new) = (1-γ)*old + γ*new

    @. tmp1 = local_energy(blend(h_old, h_new),
                           blend(hvx_old, hvx_new),
                           blend(hvy_old, hvy_new), g, b)
    return integrate(tmp1, D)
end
"""
    relaxation_functional(u, param)

Return the total energy (integral of `local_energy` computed with `integrate` and the
operator `D`) of the state `u`, whose components are `.x[1]` = h, `.x[2]` = hvx,
`.x[3]` = hvy. Overwrites the scratch array `param.tmp1`.
"""
function relaxation_functional(u, param)
    @unpack D, g, b, tmp1 = param
    h, hvx, hvy = u.x

    tmp1 .= local_energy.(h, hvx, hvy, g, b)
    return integrate(tmp1, D)
end


"""
    save_func_periodic(u, t, integrator)

Diagnostics for a `SavingCallback`. Prints a `.` as a progress indicator and returns an
`SVector` with 13 entries, in this order:

1-3. the integrals of `h`, `hvx`, `hvy`;
4.   the energy `relaxation_functional(u, integrator.p)`;
5-7. the L2 errors of `h`, `hvx`, `hvy` against `h_sol`, `hvx_sol`, `hvy_sol` at time `t`;
8-13. `(min, max)` of `h`, of `hvx`, and of `hvy`.

Overwrites the scratch array `tmp1` stored in `integrator.p`.
"""
function save_func_periodic(u, t, integrator)
    param = integrator.p
    @unpack D, tmp1, h_sol, hvx_sol, hvy_sol = param
    h, hvx, hvy = u.x
    print(".")

    integrals = (integrate(h, D), integrate(hvx, D), integrate(hvy, D))
    energy = relaxation_functional(u, param)

    # L2 distance between a numerical field and a reference solution at time t,
    # using `tmp1` as workspace
    xy = grid(D)
    function l2_error(numerical, reference)
        @. tmp1 = (numerical - reference(t, xy))^2
        return sqrt(integrate(tmp1, D))
    end
    err_h   = l2_error(h,   h_sol)
    err_hvx = l2_error(hvx, hvx_sol)
    err_hvy = l2_error(hvy, hvy_sol)

    h_lo,   h_hi   = extrema(h)
    hvx_lo, hvx_hi = extrema(hvx)
    hvy_lo, hvy_hi = extrema(hvy)

    return SVector(integrals..., energy, err_h, err_hvx, err_hvy,
                   h_lo, h_hi, hvx_lo, hvx_hi, hvy_lo, hvy_hi)
end


"""
    shallow_water_periodic!(du, u, param, t)

In-place right-hand side of the semidiscretized 2D shallow water equations with
bathymetry `b` (used as the `ODEProblem` function). `u` and `du` have the components
`.x[1]` = h, `.x[2]` = h vx, `.x[3]` = h vy. Derivatives are applied with
`mul!(out, D, f, Val(:x))` / `mul!(out, D, f, Val(:y))`.

In continuous form the computed tendencies are

    ∂t h    = -∂x(h vx)    - ∂y(h vy)
    ∂t h vx = -∂x(h vx vx) - ∂y(h vx vy) - g h ∂x(h + b)
    ∂t h vy = -∂x(h vy vx) - ∂y(h vy vy) - g h ∂y(h + b)

Each advective product derivative is written in a split form that equals it by the
product rule, e.g. ∂x(hvx vx) = ½ vx ∂x(hvx) + ½ ∂x(hvx vx) + ½ hvx ∂x(vx).
The split form is presumably chosen for the discrete energy behaviour — TODO confirm.

`tmp1`, `tmp2`, `vx`, `vy` from `param` are scratch arrays and are overwritten; the
statement order below matters because they are reused.
"""
function shallow_water_periodic!(du, u, param, t)
    @unpack D, g, b, tmp1, tmp2, vx, vy = param
    h   = u.x[1]
    hvx = u.x[2]
    hvy = u.x[3]
    dh   = du.x[1]
    dhvx = du.x[2]
    dhvy = du.x[3]
    
    # velocities
    @. vx = hvx / h
    @. vy = hvy / h
    
    # continuity equation; the same derivatives also give the first split term
    # ½ vx ∂x(hvx) resp. ½ vy ∂y(hvy) of the momentum equations
    mul!(tmp1, D, hvx, Val(:x))
    mul!(tmp2, D, hvy, Val(:y))
    @. dh = -tmp1 - tmp2
    @. dhvx = -0.5 * vx * tmp1
    @. dhvy = -0.5 * vy * tmp2
    
    # cross terms: ½ vy ∂y(hvx) and ½ vx ∂x(hvy)
    mul!(tmp1, D, hvx, Val(:y))
    mul!(tmp2, D, hvy, Val(:x))
    @. dhvx -= 0.5 * vy * tmp1
    @. dhvy -= 0.5 * vx * tmp2
    
    # gravity / bathymetry terms: g h ∇(h + b)
    @. tmp1 = (h + b)
    mul!(tmp2, D, tmp1, Val(:x))
    @. dhvx -= g * h * tmp2
    mul!(tmp2, D, tmp1, Val(:y))
    @. dhvy -= g * h * tmp2
    
    # ½ ∂x(hvx vx) and ½ ∂y(hvy vy)
    @. tmp1 = hvx * vx
    mul!(tmp2, D, tmp1, Val(:x))
    @. dhvx -= 0.5 * tmp2
    @. tmp1 = hvy * vy
    mul!(tmp2, D, tmp1, Val(:y))
    @. dhvy -= 0.5 * tmp2
    
    # ½ (hv) ∂x(vx) and ½ (hv) ∂y(vy), applied to both momentum components
    mul!(tmp1, D, vx, Val(:x))
    @. dhvx -= 0.5 * hvx * tmp1
    @. dhvy -= 0.5 * hvy * tmp1
    mul!(tmp1, D, vy, Val(:y))
    @. dhvx -= 0.5 * hvx * tmp1
    @. dhvy -= 0.5 * hvy * tmp1
    
    # ½ ∂y(hvx vy) and ½ ∂x(hvx vy); hvx vy is used for both since hvx vy = hvy vx
    # (= h vx vy) in the continuous setting
    @. tmp1 = hvx * vy
    mul!(tmp2, D, tmp1, Val(:y))
    @. dhvx -= 0.5 * tmp2
    mul!(tmp2, D, tmp1, Val(:x))
    @. dhvy -= 0.5 * tmp2
    
    nothing
end

# Extract entry `i` of every `SVector` saved by `save_func_periodic` in `saved_values`.
_saved_component(saved_values, i) = map(v -> v[i], saved_values.saveval)

"""
    solve_ode_shallow_water_periodic(h_sol, hvx_sol, hvy_sol, g, b_sol,
                                     D, tspan, alg, abstol, reltol, dt, adaptive)

Solve the periodic 2D shallow water problem twice with the time integrator `alg`:
once with the `relaxation!` callback applied after every step ("relaxation", labelled
conservative in the plots) and once without it ("baseline", labelled
non-conservative). Print the L2 differences of the two final states and plot the
change of the invariants and the errors over time.

The initial data and the bathymetry `b` are obtained by evaluating `h_sol`, `hvx_sol`,
`hvy_sol`, and `b_sol` at `tspan[1]` on `grid(D)`. The same functions serve as
reference solutions in `save_func_periodic`.

Returns the named tuple `(; sol_relaxation, sol_baseline, fig_invariants, fig_error,
saved_values_baseline, saved_values_relaxation)`.
"""
function solve_ode_shallow_water_periodic(h_sol, hvx_sol, hvy_sol, g, b_sol,
                                          D, tspan, alg, abstol, reltol, dt, adaptive)
    xy = grid(D)
    h0   = h_sol.(  tspan[1], xy)
    hvx0 = hvx_sol.(tspan[1], xy)
    hvy0 = hvy_sol.(tspan[1], xy)
    b    = b_sol.(  tspan[1], xy)
    u0 = ArrayPartition(h0, hvx0, hvy0)
    # scratch arrays used by the right-hand side and by the callbacks
    tmp1 = similar(h0); tmp2 = similar(h0)
    vx   = similar(h0); vy   = similar(h0)
    param = (; D, g, b, tmp1, tmp2, vx, vy, h_sol, hvx_sol, hvy_sol, b_sol)

    ode = ODEProblem(shallow_water_periodic!, u0, tspan, param)

    # diagnostics (see `save_func_periodic`) are recorded at 100 equispaced times,
    # which are also enforced as `tstops`
    saveat = range(tspan..., length=100)
    saved_values_baseline = SavedValues(eltype(D), SVector{13,eltype(D)})
    saving_baseline = SavingCallback(save_func_periodic, saved_values_baseline, saveat=saveat)
    saved_values_relaxation = SavedValues(eltype(D), SVector{13,eltype(D)})
    saving_relaxation = SavingCallback(save_func_periodic, saved_values_relaxation, saveat=saveat)
    # the condition is always true, so `relaxation!` runs after every step
    relaxation = DiscreteCallback((u, t, integrator) -> true, relaxation!,
        save_positions=(false, false))
    cb_baseline = CallbackSet(saving_baseline)
    cb_relaxation = CallbackSet(relaxation, saving_relaxation)

    # full solution snapshots at 0%, 5%, 25%, 50%, 75%, and 100% of the time span
    saveat_full = [tspan[1], 0.95*tspan[1]+0.05*tspan[end],
                             0.75*tspan[1]+0.25*tspan[end],
                             0.50*tspan[1]+0.50*tspan[end],
                             0.25*tspan[1]+0.75*tspan[end], tspan[end]]

    @time sol_relaxation = solve(ode, alg, abstol=abstol, reltol=reltol, dt=dt, adaptive=adaptive,
        save_everystep=false, callback=cb_relaxation, tstops=saveat, saveat=saveat_full)
    flush(stdout)
    @time sol_baseline = solve(ode, alg, abstol=abstol, reltol=reltol, dt=dt, adaptive=adaptive,
        save_everystep=false, callback=cb_baseline, tstops=saveat, saveat=saveat_full)
    flush(stdout)

    # Final states. Use `sol.u[end]`: integer indexing `sol[end]` of a solution object
    # is deprecated / has changed meaning in newer SciMLBase versions.
    h_num_baseline     = sol_baseline.u[end].x[1]
    hvx_num_baseline   = sol_baseline.u[end].x[2]
    hvy_num_baseline   = sol_baseline.u[end].x[3]
    h_num_relaxation   = sol_relaxation.u[end].x[1]
    hvx_num_relaxation = sol_relaxation.u[end].x[2]
    hvy_num_relaxation = sol_relaxation.u[end].x[3]
    @printf("Difference of baseline and relaxation in h:   %.3e\n", 
        integrate(u->u^2, h_num_baseline  - h_num_relaxation, D) |> sqrt)
    @printf("Difference of baseline and relaxation in hvx: %.3e\n", 
        integrate(u->u^2, hvx_num_baseline - hvx_num_relaxation, D) |> sqrt)
    @printf("Difference of baseline and relaxation in hvy: %.3e\n", 
        integrate(u->u^2, hvy_num_baseline - hvy_num_relaxation, D) |> sqrt)

    t_baseline = saved_values_baseline.t
    t_relaxation = saved_values_relaxation.t
    # entries 1, 2, 4 of the saved diagnostics: integrals of h and hvx, and the energy
    mass_h_baseline     = _saved_component(saved_values_baseline, 1)
    mass_h_relaxation   = _saved_component(saved_values_relaxation, 1)
    mass_hvx_baseline   = _saved_component(saved_values_baseline, 2)
    mass_hvx_relaxation = _saved_component(saved_values_relaxation, 2)
    energy_baseline     = _saved_component(saved_values_baseline, 4)
    energy_relaxation   = _saved_component(saved_values_relaxation, 4)

    fig_invariants, ax = plt.subplots(1, 1)
    # Matplotlib ≥ 3.3 names this keyword `linthresh`; `linthreshy` was deprecated and
    # later removed.
    ax.set_yscale("symlog", linthresh=1.0e-14)
    plt.plot(t_baseline,   mass_h_baseline   .- mass_h_baseline[1],   
        label=L"$\int h$ (non-conservative)", color="#E69F00", linestyle="-")
    plt.plot(t_relaxation, mass_h_relaxation .- mass_h_relaxation[1], 
        label=L"$\int h$ (conservative)", color="#56B4E9", linestyle="-")
    plt.plot(t_baseline,   mass_hvx_baseline   .- mass_hvx_baseline[1],   
        label=L"$\int hv_x$ (non-conservative)", color="#E69F00", linestyle="--")
    plt.plot(t_relaxation, mass_hvx_relaxation .- mass_hvx_relaxation[1], 
        label=L"$\int hv_x$ (conservative)", color="#56B4E9", linestyle="--")
    plt.plot(t_baseline,   energy_baseline   .- energy_baseline[1],   
        label="Energy (non-conservative)", color="#E69F00", linestyle=":")
    plt.plot(t_relaxation, energy_relaxation .- energy_relaxation[1], 
        label="Energy (conservative)", color="#56B4E9", linestyle=":")
    plt.xlabel(L"t"); plt.ylabel("Change of Invariants")
    plt.legend(loc="center left", bbox_to_anchor=(1.0, 0.5))

    # entries 5-7 of the saved diagnostics: L2 errors of h, hvx, hvy
    error_h_baseline     = _saved_component(saved_values_baseline, 5)
    error_h_relaxation   = _saved_component(saved_values_relaxation, 5)
    error_hvx_baseline   = _saved_component(saved_values_baseline, 6)
    error_hvx_relaxation = _saved_component(saved_values_relaxation, 6)
    error_hvy_baseline   = _saved_component(saved_values_baseline, 7)
    error_hvy_relaxation = _saved_component(saved_values_relaxation, 7)

    fig_error, ax = plt.subplots(1, 1)
    ax.set_xscale("log")
    ax.set_yscale("log")
    plt.plot(t_baseline,   error_h_baseline,   label=L"Error of $h$ (non-conservative)",
        color="#E69F00", linestyle="-")
    plt.plot(t_relaxation, error_h_relaxation, label=L"Error of $h$ (conservative)",
        color="#56B4E9", linestyle="-")
    plt.plot(t_baseline,   error_hvx_baseline,   label=L"Error of $h v_x$ (non-conservative)",
        color="#E69F00", linestyle="--")
    plt.plot(t_relaxation, error_hvx_relaxation, label=L"Error of $h v_x$ (conservative)",
        color="#56B4E9", linestyle="--")
    plt.plot(t_baseline,   error_hvy_baseline,   label=L"Error of $h v_y$ (non-conservative)",
        color="#E69F00", linestyle="-.")
    plt.plot(t_relaxation, error_hvy_relaxation, label=L"Error of $h v_y$ (conservative)",
        color="#56B4E9", linestyle="-.")
    plt.xlabel(L"t"); plt.ylabel("Error")
    plt.legend(loc="center left", bbox_to_anchor=(1.0, 0.5))

    return (; sol_relaxation, sol_baseline, fig_invariants, fig_error,
              saved_values_baseline, saved_values_relaxation)
end



# Initial data: a diffracton generated numerically and stored as cell mean values.
"""
    mean_to_center(q::AbstractVector)

Convert the periodic cell mean values `q` to point values at the cell centers by applying
the five-point stencil `(9, -116, 2134, -116, 9) / 1920`, with periodic wrap-around at
both ends. Returns a new vector of the same length (empty for empty input).
"""
function mean_to_center(q::AbstractVector)
    n = length(q)
    T = typeof(zero(eltype(q)) / 1920.0)
    n == 0 && return similar(q, T, 0)
    nghost = 2
    # Periodic padding: entry k of `qBCs` is q[mod1(k - nghost, n)]. The left ghost cells
    # are therefore q[end-1], q[end] in this order (not reversed), mirroring the right
    # ghost cells q[1], q[2]. `mod1` also covers n < nghost.
    qBCs = q[[mod1(k - nghost, n) for k in 1:(n + 2 * nghost)]]
    return @views @. (9*qBCs[1:end-4] - 116*qBCs[2:end-3] + 2134*qBCs[3:end-2] -
                      116*qBCs[4:end-1] + 9*qBCs[5:end]) / 1920.0
end
"""
    mean_to_center(q::AbstractMatrix)

Convert 2D periodic cell mean values to point values at the cell centers by applying the
1D conversion `mean_to_center(::AbstractVector)` to every row and then to every column of
a copy of `q`. The input is not modified.
"""
function mean_to_center(q::AbstractMatrix)
    res = copy(q)
    # first along the second index (rows) ...
    for row in eachrow(res)
        row .= mean_to_center(row)
    end
    # ... then along the first index (columns)
    for col in eachcol(res)
        col .= mean_to_center(col)
    end
    return res
end

In [None]:
let
    data_filename = joinpath(dirname(@__DIR__), "data", "shallow_water", "refn2_small_domain.h5")
    # wave speed and computational domain of the diffracton data
    global get_c() = 2.3043
    global get_xmin() = 0.0
    global get_xmax() = 20.0
    global get_ymin() = -0.5
    global get_ymax() = 0.5

    x_center, y_center, h_center, hvx_center, hvy_center, hpb_center, b_center = h5open(data_filename, "r") do io
        # x, y are the coordinates of the cell centers
        x   = read(io, "x")
        y   = read(io, "y")
        # the other values are cell mean values
        h   = read(io, "h")
        hvx = read(io, "hu")
        hvy = read(io, "hv")
        hpb = read(io, "eta")
        b   = hpb - h

        # transpose (presumably the file stores fields as (y, x) — TODO confirm) and
        # convert the cell means to point values at the cell centers
        x, y, mean_to_center(h'), mean_to_center(hvx'), mean_to_center(hvy'), mean_to_center(hpb'), mean_to_center(b')
    end

    # Interpolation knots: the cell centers, assumed uniformly spaced. They are
    # deliberately not named `x`/`y`. The closures in this `let` block assign `x, y = ...`,
    # and a `let`-level `x`/`y` would be captured, boxed, and mutated by them.
    xknots = range(x_center[1], x_center[end], length=length(x_center))
    yknots = range(y_center[1], y_center[end], length=length(y_center))

    # The `*_center` fields already hold point values (converted above), so they are
    # interpolated directly. Applying `mean_to_center` again would deconvolve twice.
    # NOTE(review): the knots span [x_center[1], x_center[end]], so `Periodic()`
    # extrapolation presumably wraps with period x_center[end] - x_center[1]. That
    # differs from get_xmax() - get_xmin() if the cell centers do not touch the domain
    # boundary — TODO confirm against Interpolations.jl.
    h0itp   = CubicSplineInterpolation((xknots, yknots), h_center,   extrapolation_bc=Periodic())
    b0itp   = CubicSplineInterpolation((xknots, yknots), b_center,   extrapolation_bc=Periodic())
    hvx0itp = CubicSplineInterpolation((xknots, yknots), hvx_center, extrapolation_bc=Periodic())
    hvy0itp = CubicSplineInterpolation((xknots, yknots), hvy_center, extrapolation_bc=Periodic())

    global get_g() = 9.8

    # x-coordinate at time 0 of the point that is at `x` at time `t`, for a profile
    # translating with speed get_c() in x, wrapped into [get_xmin(), get_xmax())
    function initial_position(t, x)
        c = get_c()
        xmin = get_xmin()
        xmax = get_xmax()
        return mod(x - c*t - xmin, xmax - xmin) + xmin
    end

    # time-independent bathymetry
    global function b_sol(t, xy)
        x, y = xy
        b0itp(x, y)::Float64
    end
    # reference solutions: the initial profiles translated with speed get_c()
    global function h_sol(t, xy)
        x, y = xy
        h0itp(initial_position(t, x), y)::Float64
    end
    global function hvx_sol(t, xy)
        x, y = xy
        hvx0itp(initial_position(t, x), y)::Float64
    end
    global function hvy_sol(t, xy)
        x, y = xy
        hvy0itp(initial_position(t, x), y)::Float64
    end
end

In [None]:
# Grid resolution in x and y (passed as the number of nodes to `fourier_derivative_operator`)
@show Nx = 2^10
@show Ny = 2^6
# Step size ≈ (larger grid spacing) / (wave speed). With `adaptive = true` this is
# presumably only the initial step size of the solver — TODO confirm.
dt = 1.0 * max((get_xmax() - get_xmin()) / Nx, (get_ymax() - get_ymin()) / Ny) / get_c()
@show dt

# Final time: 15 times the time the wave needs to cross the x-domain at speed get_c()
tspan = (0.0, 15*(get_xmax() - get_xmin())/get_c())
@show tspan
flush(stdout)

abstol = reltol = 1.0e-6
adaptive = true
# periodic Fourier derivative operator on [xmin, xmax] × [ymin, ymax]
D = fourier_derivative_operator(get_xmin(), get_xmax(), Nx,
                                get_ymin(), get_ymax(), Ny)

# run the baseline and the relaxation simulation (results are used by the next cells)
@time results = solve_ode_shallow_water_periodic(
    h_sol, hvx_sol, hvy_sol, get_g(), b_sol, 
    D, tspan, Tsit5(), abstol, reltol, dt, adaptive);


In [None]:
# Initial state: the first saved snapshot of the baseline run, evaluated on the grid
xy, h_initial  = evaluate_coefficients(results.sol_baseline.u[1].x[1], D)
_, hvx_initial = evaluate_coefficients(results.sol_baseline.u[1].x[2], D)
_, hvy_initial = evaluate_coefficients(results.sol_baseline.u[1].x[3], D)
# 1D coordinate vectors taken from the first column / first row of the 2D grid
x = first.(xy[:, 1]); y = last.(xy[1, :])
b = results.sol_baseline.prob.p.b

# Last saved snapshot. Both runs use the same `saveat_full`, so `idx` is assumed to be
# valid for the relaxation run as well.
idx = length(results.sol_baseline)
_, h_baseline   = evaluate_coefficients(results.sol_baseline.u[idx].x[1], D)
_, hvx_baseline = evaluate_coefficients(results.sol_baseline.u[idx].x[2], D)
_, hvy_baseline = evaluate_coefficients(results.sol_baseline.u[idx].x[3], D)

_, h_relaxation   = evaluate_coefficients(results.sol_relaxation.u[idx].x[1], D)
_, hvx_relaxation = evaluate_coefficients(results.sol_relaxation.u[idx].x[2], D)
_, hvy_relaxation = evaluate_coefficients(results.sol_relaxation.u[idx].x[3], D)

# Store everything needed for the plots below so they can be regenerated without
# rerunning the simulation.
bson(joinpath(dirname(@__DIR__), "data", "shallow_water_Fourier.bson"), 
    Dict(:saved_values_baseline=>results.saved_values_baseline, 
         :saved_values_relaxation=>results.saved_values_relaxation,
         :x=>x, :y=>y, :b=>b,
         :h_initial=>h_initial, :hvx_initial=>hvx_initial, :hvy_initial=>hvy_initial,
         :h_baseline=>h_baseline, :hvx_baseline=>hvx_baseline, :hvy_baseline=>hvy_baseline,
         :h_relaxation=>h_relaxation, :hvx_relaxation=>hvx_relaxation, :hvy_relaxation=>hvy_relaxation))

In [None]:
# Load the data saved by the previous cell, so that the plots below can be regenerated
# without rerunning the simulation.
saved_results = BSON.load(joinpath(dirname(@__DIR__), "data", "shallow_water_Fourier.bson"))

In [None]:
# Time series of the L2 errors: entries 5-7 of the diagnostics saved by
# `save_func_periodic` (errors of h, hvx, hvy).
t_baseline   = saved_results[:saved_values_baseline].t
t_relaxation = saved_results[:saved_values_relaxation].t
error_h_baseline     = map(x->x[5], saved_results[:saved_values_baseline].saveval)
error_h_relaxation   = map(x->x[5], saved_results[:saved_values_relaxation].saveval)
error_hvx_baseline   = map(x->x[6], saved_results[:saved_values_baseline].saveval)
error_hvx_relaxation = map(x->x[6], saved_results[:saved_values_relaxation].saveval)
error_hvy_baseline   = map(x->x[7], saved_results[:saved_values_baseline].saveval)
error_hvy_relaxation = map(x->x[7], saved_results[:saved_values_relaxation].saveval)

# NOTE: the first saved time is t = 0, which cannot be shown on the log-scaled t axis.
fig_error, ax = plt.subplots(1, 1)
ax.set_xscale("log")
ax.set_yscale("log")
plt.plot(t_baseline,   error_h_baseline,   label=L"Error of $h$ (non-conservative)",
    color="#E69F00", linestyle="-")
plt.plot(t_relaxation, error_h_relaxation, label=L"Error of $h$ (conservative)",
    color="#56B4E9", linestyle="-")
plt.plot(t_baseline,   error_hvx_baseline,   label=L"Error of $h v_x$ (non-conservative)",
    color="#E69F00", linestyle="--")
plt.plot(t_relaxation, error_hvx_relaxation, label=L"Error of $h v_x$ (conservative)",
    color="#56B4E9", linestyle="--")
plt.plot(t_baseline,   error_hvy_baseline,   label=L"Error of $h v_y$ (non-conservative)",
    color="#E69F00", linestyle="-.")
plt.plot(t_relaxation, error_hvy_relaxation, label=L"Error of $h v_y$ (conservative)",
    color="#56B4E9", linestyle="-.")
plt.xlabel(L"t"); plt.ylabel("Error")
plt.legend(loc="center left", bbox_to_anchor=(1.0, 0.5))

# Reference slopes O(t^0.7) and O(t^0.3) for visual comparison of the error growth
t = [1.0e1, 1.3e2]
ax.plot(t, 2.0e-4 .* t .^ 0.7, ":", color="gray")
ax.annotate(L"\mathcal{O}(t^{0.7})", (4.0e1, 5.0e-3), color="gray")
ax.plot(t, 3.0e-4 .* t .^ 0.3, ":", color="gray")
ax.annotate(L"\mathcal{O}(t^{0.3})", (4.0e1, 7.0e-4), color="gray")

savefig(
    joinpath(dirname(@__DIR__), "figures", "shallow_water_Fourier_error.pdf"),
    bbox_inches="tight")


In [None]:
"""
    plot_single_heatmap!(ax, x, _y, _data, label; kwargs...)

Draw `_data` (indexed as `_data[ix, iy]`) as a `pcolormesh` on the Matplotlib axis `ax`,
with a colorbar labelled `label` and the y-axis labelled `y`. One extra `y` value is
appended, continuing the last grid spacing, and the matching column of data is a copy
of the first column, so that the plot shows the periodic wrap-around in `y`. `kwargs`
(e.g. `vmin`, `vmax`) are passed to `pcolormesh`. `_y` may be any `AbstractVector`,
including ranges. Returns `nothing`.
"""
function plot_single_heatmap!(ax, x, _y, _data, label; kwargs...)
    # `vcat`/`hcat` build new arrays, so ranges work as well (`push!` on a copied range
    # would throw)
    y = vcat(_y, 2 * _y[end] - _y[end-1])
    data = hcat(_data, _data[:, 1])
    pcm = ax.pcolormesh(x, y, data', linewidth=0, rasterized=true; kwargs...)
    pcm.set_edgecolor("face")
    ax.set_ylabel(L"y")
    plt.colorbar(pcm, label=label, ax=ax)
    return nothing
end

"""
    plot_heatmaps(x, y, b, h, hvx, hvy, filename="")

Create a figure with three stacked heatmaps (shared x-axis): the free surface `h + b`,
`hvx`, and `hvy`, each with fixed color limits. If `filename` is non-empty, the figure
is saved there. Returns the figure.
"""
function plot_heatmaps(x, y, b, h, hvx, hvy, filename="")
    fig, axs = plt.subplots(3, 1, figsize=(14,3), sharex=true)
    # (field, colorbar label, vmin, vmax) for each panel, top to bottom
    panels = ((h + b, L"h + b",  0.74,   0.81),
              (hvx,   L"h v_x", -0.15,   0.15),
              (hvy,   L"h v_y", -0.015,  0.015))
    for (i, (field, lbl, lo, hi)) in enumerate(panels)
        plot_single_heatmap!(axs[i], x, y, field, lbl; vmin=lo, vmax=hi)
    end
    axs[3].set_xlabel(L"x")
    plt.subplots_adjust(hspace=0.4)
    isempty(filename) || fig.savefig(filename, bbox_inches="tight")
    return fig
end

# Heatmaps of the initial state and of the final states of both runs.
for (stage, pdf_name) in ((:initial,    "shallow_water_Fourier_initial.pdf"),
                          (:baseline,   "shallow_water_Fourier_baseline.pdf"),
                          (:relaxation, "shallow_water_Fourier_relaxation.pdf"))
    plot_heatmaps(saved_results[:x], saved_results[:y], saved_results[:b],
                  saved_results[Symbol("h_", stage)],
                  saved_results[Symbol("hvx_", stage)],
                  saved_results[Symbol("hvy_", stage)],
                  joinpath(dirname(@__DIR__), "figures", pdf_name))
end