In [3]:
import sys; sys.path.insert(0, '../..')

import fenics as fa
import matplotlib.pyplot as plt
import mshr
import numpy as np
import pdb
import argparse
import jax
from collections import namedtuple

from src.nonlinear_stokes.nonlinear_stokes_common import (
    plot_solution,
    loss_fn,
    fenics_to_jax,
    SecondOrderTaylorLookup,
    error_on_coords,
    sample_params,
    sample_points,
)

from src.nonlinear_stokes.nonlinear_stokes_fenics import (
    solve_fenics,
    is_defined,
    parser
)

In [4]:
args = parser.parse_args("")
args = namedtuple("ArgsTuple", vars(args))(**vars(args))

params = sample_params(jax.random.PRNGKey(args.seed), args)
source_params, bc_params, per_hole_params, num_holes = params
print("params: ", params)



params:  (DeviceArray([0.4130522 , 0.25975317], dtype=float32), DeviceArray([3000.803], dtype=float32), DeviceArray([[ 0.04156993, -0.0613506 ,  0.47837767,  0.20494777,
               0.44709057],
             [-0.03926784, -0.11530625,  0.10562178,  0.4114218 ,
               0.35706055],
             [-0.04186818,  0.01709074,  0.18381561, -0.53561956,
               0.29799348]], dtype=float32), DeviceArray(1, dtype=int32))


In [5]:
u_p = solve_fenics(params)

In [6]:
points = sample_points(jax.random.PRNGKey(args.seed + 1), 1024, params)
points_on_inlet, points_on_walls, points_on_holes, points_in_domain = points

all_points = np.concatenate(points)

In [7]:
x0 = points_in_domain[100]


In [8]:
taylor = SecondOrderTaylorLookup(u_p, [x0])

In [24]:
direction = np.array([-1., 1.])
for i in range(8):
    x = x0 + direction * 1e-8 * 10**i
    y = np.array(u_p(x))
    yhat = np.array(taylor(x)).reshape(y.shape)
    err = np.linalg.norm(y - yhat)
    print("delta size: {}, err {:.3e}".format(1e-8 * 10**i, err))

delta size: 1e-08, err 2.165e-04
delta size: 1e-07, err 3.706e-04
delta size: 1e-06, err 1.401e-04
delta size: 1e-05, err 5.061e-04
delta size: 0.0001, err 3.757e-03
delta size: 0.001, err 2.718e-02
delta size: 0.01, err 7.837e-01
delta size: 0.1, err 1.168e+02


In [None]:
y = np.array(u_p(x0))
yhat = np.array(taylor(x0)).reshape(y.shape)
np.linalg.norm(y - yhat)