In [1]:
import sys
sys.path.append("../pseudosplit")

In [2]:
import numpy as np

In [3]:
from scipy.integrate import solve_ivp

In [4]:
%matplotlib qt
import matplotlib.pyplot as plt

In [5]:
from basis import FourierBasis, HermiteBasis

In [6]:
from state import State

In [7]:
from scheme import (AffineSym4, Strang, AffineSym6, Neri, Yoshida6, AffineSym2)

In [8]:
from solver import Solver

In [9]:
from models import FisherModel1D

In [10]:
model = FisherModel1D(alpha=2., beta=0.5)

In [11]:
u0 = lambda x: 1/(np.cosh(10*x))**2

In [12]:
method_list = [Strang(), Neri(), Yoshida6(), AffineSym2(), AffineSym4(), AffineSym6()]
dt_list = [2.5e-1, 1e-1, 7.5e-2, 5e-2, 2.5e-2, 1e-2, 7.5e-3, 5e-3, 2.5e-3, 1e-3, 7.5e-4, 5e-4, 2.5e-4, 1e-4]

In [13]:
# Set the pseudo-spectral basis
N = 2**10 # Number of Fourier modes
I = (-80, 80)  # Interval
fb = FourierBasis('fb', N, I)

In [14]:
# Define the initial state
u0 = State(name='u0', basis=fb, u=u0) 

  u0 = lambda x: 1/(np.cosh(10*x))**2
  u0 = lambda x: 1/(np.cosh(10*x))**2


In [15]:
t0 = 0.0
tf = 10.0

In [16]:
# Calculate the reference solution with DOP853
RHS = model.get_RHS(fb)

def wrapped_RHS(t, y):
    """
    Wrap the RHS to adapt it to the scipy solver.

    Parameters
    ----------
    t : float
        Time.
    y : array
        Initial value.

    Returns
    -------
    array
        Final value.

    """
    return -1j * RHS(y)

In [17]:
ref_sol = solve_ivp(wrapped_RHS, (t0, tf), u0.values.astype(np.complex128), 'DOP853', rtol=2.5e-14, atol=1e-16)
num_sol = ref_sol.y[:,-1]

In [18]:
plt.style.use("classic")
fig_dt, ax_dt = plt.subplots()
fig_nfev, ax_nfev = plt.subplots()
ax_dt.grid()
ax_nfev.grid()


In [19]:
for method in method_list:
    err_values = []
    nfev_values = []
    print(method.name)
    for dt in dt_list:
        solver = Solver(model, method)
        solver.start(u0, t0, tf) 

        count = 0
        while solver.active:
            u = solver.step(dt)
            t = solver.sim_time
            count += 1
            if count % 1000 == 0:
                print(t)
                count = 0
        err_values.append(np.max(np.abs(num_sol-u.values)))
        nfev_values.append(method.P_A.nfev)
        
    label = "$\mathtt{" + method.name + "}$"
    ax_dt.plot(dt_list, err_values, label=label)
    ax_nfev.plot(nfev_values, err_values, label=label)
    
ax_dt.legend()
ax_nfev.legend()

strang
9.999999999999831
7.500000000000095
4.999999999999916
10.0
2.499999999999958
5.000000000000082
7.5000000000004725
10.0
1.0000000000000007
1.9999999999998905
2.9999999999997806
3.9999999999996705
5.000000000000004
6.000000000000338
7.000000000000672
8.000000000001005
9.000000000000451
9.999999999999897
0.7500000000000007
1.500000000000029
2.2500000000000573
3.0000000000000857
3.750000000000114
4.500000000000142
5.2500000000001705
6.000000000000199
6.750000000000227
7.500000000000256
8.250000000000284
9.000000000000313
9.750000000000341
0.5000000000000003
0.9999999999999453
1.4999999999998903
1.9999999999998352
2.500000000000002
3.000000000000169
3.500000000000336
4.000000000000503
4.500000000000226
4.9999999999999485
5.499999999999671
5.999999999999394
6.499999999999117
6.99999999999884
7.499999999998563
7.999999999998286
8.499999999998897
8.999999999999508
9.500000000000119
10.0
0.25000000000000017
0.49999999999997263
0.7499999999999452
0.9999999999999176
1.250000000000001
1.500

<matplotlib.legend.Legend at 0x7ff472e28390>

In [20]:
plt.figure()
plt.plot(u.grid, np.abs(num_sol))

[<matplotlib.lines.Line2D at 0x7ff472ca2dd0>]

In [21]:
ax_dt.legend(loc="upper left")
ax_nfev.legend(loc="upper right")

<matplotlib.legend.Legend at 0x7ff477f29dd0>

In [22]:
plt.figure()

<Figure size 640x480 with 0 Axes>

In [23]:
from scipy.fft import fftshift

In [24]:
plt.plot(fftshift(np.abs(u.coeffs)))

[<matplotlib.lines.Line2D at 0x7ff472d06010>]

In [25]:
ax_dt.set_ylabel(ylabel="$\mathcal{E}_\infty$", fontsize=16)
ax_dt.set_xlabel(xlabel="$\Delta t$", fontsize=16)
ax_nfev.set_ylabel(ylabel="$\mathcal{E}_\infty$", fontsize=16)
ax_nfev.set_xlabel(xlabel="computational cost", fontsize=16)

Text(0.5, 0, 'computational cost')

In [26]:
ax_dt.set_xscale('log')
ax_dt.set_yscale('log')
ax_nfev.set_xscale('log')
ax_nfev.set_yscale('log')