In [None]:
import jax.numpy as jnp
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LinearRegression

In [None]:
#######################
# Heading parameters. #
# Strong order proxy. #
#######################

solver_name = "EKF1_2"
problem_name = "FHN"
prefix = f"{solver_name}_{problem_name}"
folder = "./EKF1_FHN"

# Step sizes for the convergence study. The full sweep is kept in
# FULL_SWEEP_DELTAS; only the finest step size is currently enabled.
# NOTE(review): the log-log regressions further down need at least two step
# sizes to give a meaningful slope -- set `deltas = FULL_SWEEP_DELTAS` for that.
FULL_SWEEP_DELTAS = 1/jnp.array([16, 32, 64, 128, 256, 512, 1024])
deltas = 1/jnp.array([1024])
Ns = 1/deltas
# Tag of the fine reference files (`..._pathwise_sols2_{N}_{fine}.npy`);
# currently the same as Ns.
fineN = Ns
# `M` tag of the approximation files (`..._pathwise_sols_{N}_{M}.npy`);
# currently 1 for every step size.
Mdeltas = jnp.ones((len(deltas),))
T = 1.0
# Number of steps per step size. Rounded so that the int() conversions
# downstream cannot truncate a value like 1023.9999 to 1023.
Ndeltas = jnp.rint(T/deltas)

# For every step size, plot log(mean pathwise error) / log(delta) over time:
# a pointwise proxy for the empirical strong order.
fig, ax = plt.subplots()
for n in range(len(Ndeltas)):
    delta = deltas[n]
    # rint before int(): these counts come from float arithmetic and a bare
    # int() truncates, which would point at the wrong file name.
    N = int(jnp.rint(Ndeltas[n]))
    M = int(jnp.rint(Mdeltas[n]))
    fine = int(jnp.rint(fineN[n]))
    paths_1 = jnp.load(f'{folder}/{prefix}_pathwise_sols_{N}_{M}.npy')
    paths_2 = jnp.load(f'{folder}/{prefix}_pathwise_sols2_{N}_{fine}.npy')
    # N steps produce N + 1 grid points (t = 0 included).
    n_points = N + 1
    assert n_points == paths_1.shape[1] == paths_2.shape[1], (
        f"expected {n_points} time points, got {paths_1.shape[1]} and {paths_2.shape[1]}")
    ts = jnp.linspace(0, T, n_points)
    # Norm of the pathwise difference, averaged over axis 0
    # (presumably the sample axis -- TODO confirm against the file layout).
    mean_error = jnp.mean(jnp.linalg.norm(paths_1 - paths_2, axis=-1), axis=0)
    # Where mean_error is exactly zero the ratio is non-finite and matplotlib
    # leaves a gap in the curve.
    ax.plot(ts, jnp.log(mean_error) / jnp.log(delta), label=f'delta={delta}, M={M}')
ax.set(xlabel='t', ylabel='log(mean error) / log(delta)', title='Strong order proxy')
ax.legend()

In [None]:
###########################
# Weak and global errors  #
# With log-log regression.#
###########################

# Three independent empty lists, one entry appended per step size by the loop
# in this cell (tuple unpacking gives distinct list objects, not aliases).
STRONG_GLOBAL_ERRORS_P1P2, STRONG_LOCAL_ERRORS_P1P2, WEAK_GLOBAL_ERRORS_P1P2 = [], [], []
@partial(jnp.vectorize, signature="(d,x)->(d,d)")
def WEAK_POLYNOMIAL(x):
    """Return the matrix product x x^T for a (d, k) core matrix.

    ``jnp.vectorize`` broadcasts over any leading axes, so a stack of column
    vectors of shape (..., d, 1) maps to a stack of (d, d) outer products.
    In this notebook their sample mean is the moment compared in the weak error.
    """
    return jnp.matmul(x, x.T)


for n in range(len(deltas)):
    # rint before int(): these counts come from float arithmetic and a bare
    # int() truncates, which would point at the wrong file name.
    N = int(jnp.rint(Ndeltas[n]))
    M = int(jnp.rint(Mdeltas[n]))
    fine = int(jnp.rint(fineN[n]))
    paths_1 = jnp.load(f'{folder}/{prefix}_pathwise_sols_{N}_{M}.npy')
    paths_2 = jnp.load(f'{folder}/{prefix}_pathwise_sols2_{N}_{fine}.npy')

    # Pathwise error norm at every time point, averaged over axis 0
    # (presumably the sample axis -- TODO confirm against the file layout).
    mean_pathwise_error = jnp.mean(jnp.linalg.norm(paths_2 - paths_1, axis=-1), axis=0)
    STRONG_GLOBAL_ERRORS_P1P2.append(mean_pathwise_error[-1])  # at the final time point
    STRONG_LOCAL_ERRORS_P1P2.append(mean_pathwise_error[1])    # after the first step

    # Weak error: Frobenius norm of the difference between the averaged
    # x x^T moments, maximised over time.
    moment_p1 = jnp.mean(WEAK_POLYNOMIAL(paths_1[..., jnp.newaxis]), axis=0)
    moment_p2 = jnp.mean(WEAK_POLYNOMIAL(paths_2[..., jnp.newaxis]), axis=0)
    WEAK_GLOBAL_ERRORS_P1P2.append(jnp.max(jnp.linalg.norm(moment_p2 - moment_p1, axis=(-2, -1))))

STRONG_GLOBAL_ERRORS_P1P2 = jnp.array(STRONG_GLOBAL_ERRORS_P1P2)
STRONG_LOCAL_ERRORS_P1P2 = jnp.array(STRONG_LOCAL_ERRORS_P1P2)
WEAK_GLOBAL_ERRORS_P1P2 = jnp.array(WEAK_GLOBAL_ERRORS_P1P2)

log_inv_deltas = -jnp.log(deltas)
error_series = {
    'strong global error P1P2': STRONG_GLOBAL_ERRORS_P1P2,
    'strong local error P1P2': STRONG_LOCAL_ERRORS_P1P2,
    'weak global error P1P2': WEAK_GLOBAL_ERRORS_P1P2,
}

fig, ax = plt.subplots()
for label, errors in error_series.items():
    # Markers keep the series visible even with a single step size.
    ax.plot(log_inv_deltas, -jnp.log(errors), marker='o', label=label)
ax.set(xlabel='-log(delta)', ylabel='-log(error)', title='Convergence (log-log)')
ax.legend()

# The fitted slope is the empirical convergence order. With fewer than two
# step sizes the slope is not identifiable, so the fit is skipped.
if len(deltas) < 2:
    print(f'Only {len(deltas)} step size(s) configured; skipping order regression.')
else:
    X = log_inv_deltas.reshape(-1, 1)
    for label, errors in error_series.items():
        linear_regressor = LinearRegression().fit(X, -jnp.log(errors))
        print(label, linear_regressor.coef_)

In [None]:

# Persist each error series as a stacked (deltas, errors) array; jnp.save
# appends ".npy". NOTE(review): these files go to the working directory,
# not to `folder`.
for error_suffix, error_values in (
    ('STRONG_GLOBAL_ERRORS', STRONG_GLOBAL_ERRORS_P1P2),
    ('STRONG_LOCAL_ERRORS', STRONG_LOCAL_ERRORS_P1P2),
    ('WEAK_GLOBAL_ERRORS', WEAK_GLOBAL_ERRORS_P1P2),
):
    jnp.save(f'{prefix}_{error_suffix}', jnp.array([deltas, error_values]))

In [None]:
# Reload the stacked (deltas, errors) arrays and write them transposed, as
# two-column CSV files with a "deltas,errors" header line.
STRONG_GLOBAL_ERRORS = jnp.load(f'{prefix}_STRONG_GLOBAL_ERRORS.npy')
STRONG_LOCAL_ERRORS = jnp.load(f'{prefix}_STRONG_LOCAL_ERRORS.npy')
WEAK_GLOBAL_ERRORS = jnp.load(f'{prefix}_WEAK_GLOBAL_ERRORS.npy')

csv_exports = {
    f'{prefix}_STRONG_GLOBAL_ERRORS.csv': STRONG_GLOBAL_ERRORS,
    f'{prefix}_STRONG_LOCAL_ERRORS.csv': STRONG_LOCAL_ERRORS,
    f'{prefix}_WEAK_GLOBAL_ERRORS.csv': WEAK_GLOBAL_ERRORS,
}
for csv_path, table in csv_exports.items():
    np.savetxt(csv_path, table.T, delimiter=',', header='deltas,errors', comments="")

In [None]:
###############
# Sample path #
###############

# Overlay one trajectory of each solver, using paths_1 / paths_2 left over
# from the last iteration of the error loop above.
SAMPLE_INDEX = 10
fig, ax = plt.subplots()
for component in range(paths_1.shape[-1]):
    ax.plot(paths_1[SAMPLE_INDEX, :, component], label=f'approximation, component {component}')
    ax.plot(paths_2[SAMPLE_INDEX, :, component], '--', label=f'fine solution, component {component}')
ax.set(xlabel='time step', title=f'Sample path {SAMPLE_INDEX}')
ax.legend();

In [None]:
# Time grid matching the saved trajectories: paths_2.shape[1] equispaced
# points on [0, T]. Derived from the data so it stays valid for any step size.
ts = jnp.linspace(0, T, paths_2.shape[1])

# Sample means (over axis 0) of each component: "correct" = fine solution
# (paths_2), "incorrect" = approximation (paths_1).
for component, tag in ((0, 'first'), (1, 'secnd')):
    np.savetxt(f'{prefix}_correct_mean_{tag}.csv',
               jnp.array([ts, jnp.mean(paths_2[:, :, component], axis=0)]).T,
               delimiter=',', header='t,mean', comments="")
    np.savetxt(f'{prefix}_incorrect_mean_{tag}.csv',
               jnp.array([ts, jnp.mean(paths_1[:, :, component], axis=0)]).T,
               delimiter=',', header='t,mean', comments="")

In [None]:
# Both components of fine-solution sample path 0 (a single trajectory, not a mean).
fig, ax = plt.subplots()
ax.plot(ts, paths_2[0, :, 0], label='fine path 0, first component')
ax.plot(ts, paths_2[0, :, 1], label='fine path 0, second component')
ax.set(xlabel='t')
ax.legend();

In [None]:
# Export each component of fine-solution sample path 0 as a (t, value) CSV.
# NOTE(review): the header says "mean" although the values are a single
# trajectory; kept as-is in case downstream plots depend on the column name.
for component in (0, 1):
    np.savetxt(f'{prefix}_path_{component}.csv',
               jnp.array([ts, paths_2[0, :, component]]).T,
               delimiter=',', header='t,mean', comments="")

In [None]:
import pandas as pd

# Inputs for the regression summary of another run. These use their own names
# so that the notebook-wide solver_name / prefix / folder set in the heading
# cell are not overwritten (otherwise re-running earlier cells after this one
# would silently read the wrong run's files).
summary_solver_name = "EKF0_SSM"
summary_problem_name = "FHN"
summary_prefix = f"{summary_solver_name}_{summary_problem_name}"
summary_folder = "./EKF0_FHN"

print(summary_prefix)
#res_global_error = pd.read_csv(f'{summary_folder}/{summary_prefix}_STRONG_GLOBAL_ERRORS.csv', index_col=False, header=0)
#res_local_error = pd.read_csv(f'{summary_folder}/{summary_prefix}_STRONG_LOCAL_ERRORS.csv', index_col=False, header=0)
res_weak_error = pd.read_csv(f'{summary_folder}/{summary_prefix}_WEAK_GLOBAL_ERRORS.csv', index_col=False, header=0)



In [None]:
def summary(df):
    """Least-squares fit of log(error) = coef * log(delta) + intercept.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with ``deltas`` and ``errors`` columns (the layout written by
        the CSV export cells), with at least three rows.

    Returns
    -------
    tuple
        ``(coef_, intercept_, residual_std, r2)``: the fitted slope (empirical
        order), the intercept, the residual standard error
        ``sqrt(SSR / (n - 2))`` and the R^2 score of the fit.

    Raises
    ------
    ValueError
        If ``df`` has fewer than three rows (n - 2 degrees of freedom needed).
    """
    n = len(df)
    if n < 3:
        raise ValueError(f"summary needs at least 3 step sizes, got {n}")
    X = jnp.log(df.deltas.values).reshape(n, 1)
    Y = jnp.log(df.errors.values).reshape(n, 1)
    regr = LinearRegression()
    regr.fit(X, Y)
    # Residual standard error with n - 2 degrees of freedom (slope + intercept);
    # for 6 step sizes this is the factor 1/4.
    residual_std = jnp.sqrt(jnp.sum((regr.predict(X) - Y) ** 2, axis=0) / (n - 2))
    return regr.coef_, regr.intercept_, residual_std, regr.score(X, Y)
# Only the weak-error table is loaded in the cell above; add
# ('STRONG LOCAL', res_local_error) / ('STRONG GLOBAL', res_global_error)
# here once those CSVs are read as well.
for title, frame in (('WEAK GLOBAL', res_weak_error),):
    print(title)
    print(summary(frame))