In [1]:
import sys; sys.path.insert(0, '../..')

import fenics as fa
import matplotlib.pyplot as plt
import mshr
import numpy as np
import pdb
import argparse
import jax
from collections import namedtuple

from src.linear_stokes.linear_stokes_common import (
    plot_solution,
    loss_fn,
    fenics_to_jax,
    SecondOrderTaylorLookup,
    error_on_coords,
    sample_params,
    sample_points,
    loss_divu_fn,
    loss_stress_fn,
    get_p,
    get_u,
    deviatoric_stress
)

from src.linear_stokes.linear_stokes_fenics import (
    solve_fenics,
    is_defined,
    parser
)

In [2]:
# Build the default configuration. argparse expects a *list* of argument
# strings; passing an explicit empty list means the kernel's own sys.argv is
# ignored. (The previous `parse_args("")` only worked because list("") == [].)
args = parser.parse_args([])
# Freeze the argparse Namespace into an immutable namedtuple (hashable as long
# as every field value is hashable).
args = namedtuple("ArgsTuple", vars(args))(**vars(args))

# Draw one random problem instance from the configured seed and unpack its parts.
params = sample_params(jax.random.PRNGKey(args.seed), args)
source_params, bc_params, per_hole_params, num_holes = params
print("params: ", params)



params:  (DeviceArray([0.4130522 , 0.25975317], dtype=float32), DeviceArray([3000.803], dtype=float32), DeviceArray([[ 0.04156993, -0.0613506 ,  0.47837767,  0.20494777,
               0.44709057],
             [-0.03926784, -0.11530625,  0.10562178,  0.4114218 ,
               0.35706055],
             [-0.04186818,  0.01709074,  0.18381561, -0.53561956,
               0.29799348]], dtype=float32), DeviceArray(1, dtype=int32))


In [None]:
# Solve the problem with FEniCS for the sampled parameters.
# The result is unpacked later with `.split()` into `u` and `p`. It is
# presumably a mixed velocity/pressure Function — confirm in solve_fenics.
u_p = solve_fenics(params)

In [None]:
# Sample locations from the geometry. The key uses seed + 1 so it differs from
# the key used for parameter sampling. 1024 is passed as the sample count;
# whether that is per group or in total is decided by sample_points.
points_key = jax.random.PRNGKey(args.seed + 1)
n_points = 1024
points = sample_points(points_key, n_points, params)

# The sampler returns four groups, in this order: inlet, walls, holes, interior.
(points_on_inlet,
 points_on_walls,
 points_on_holes,
 points_in_domain) = points

# Stack every group into a single array of sample locations.
all_points = np.concatenate(points)

In [None]:
# Split the mixed solution into its two sub-functions and plot the first one.
# NOTE(review): u/p presumably denote velocity and pressure — confirm in solve_fenics.
u, p = u_p.split()

plt.figure()
fa.plot(u)
plt.title("Solution component u")
# plt.show() renders the figure. It also stops the notebook from echoing
# fa.plot's return value as the cell output.
plt.show()

In [None]:
# Pick one interior sample (index 120, chosen arbitrarily), presumably as a
# probe location for spot checks.
# NOTE(review): x0 is not used in the cells visible here — confirm it is still needed.
probe_index = 120
x0 = points_in_domain[probe_index]

In [None]:
# Build the lookup object from the FEniCS solution over every sampled point.
# The loss functions in the next cell are evaluated through it. Judging by its
# name, it is presumably a second-order Taylor expansion of the solution.
# NOTE(review): the meaning of `d` (3 here) is defined by
# SecondOrderTaylorLookup — confirm.
taylor = SecondOrderTaylorLookup(
    u_p,
    all_points,
    d=3,
)

# Loss functions: stress and divergence-of-velocity residuals

In [None]:
def _scatter_loss(points, loss, title):
    """Scatter-plot a per-point loss over the sample locations.

    Args:
        points: array whose columns 0 and 1 are used as the x/y plot coordinates.
        loss: one value per row of `points`, used as the marker colour.
        title: figure title.
    """
    fig, ax = plt.subplots()
    colours = ax.scatter(points[:, 0], points[:, 1], c=np.array(loss))
    fig.colorbar(colours, ax=ax)
    ax.set(title=title, xlabel="x", ylabel="y")
    plt.show()


# Evaluate both losses at the interior sample points through the Taylor lookup.
jax_stress_loss = loss_stress_fn(taylor, points_in_domain, params)
jax_divu_loss = loss_divu_fn(taylor, points_in_domain, params)

# Mean of each loss over the sample points.
print(np.mean(jax_stress_loss, axis=0))
print(np.mean(jax_divu_loss, axis=0))

_scatter_loss(points_in_domain, jax_stress_loss, "Stress loss per interior point")
_scatter_loss(points_in_domain, jax_divu_loss, "div(u) loss per interior point")

In [None]:
# Histogram of each per-point loss on a natural-log scale.
# `jax_domain_loss` is not defined in any earlier cell, so plotting it raised
# NameError on Restart & Run All. This cell plots the two losses that the
# previous cell actually computes.
for loss_name, loss_values in (
    ("Stress loss", jax_stress_loss),
    ("div(u) loss", jax_divu_loss),
):
    flat_loss = np.asarray(loss_values, dtype=np.float64).ravel()
    # np.log gives -inf at 0 and NaN below 0, which breaks the histogram's bin
    # range. Keep only strictly positive values and report how many were dropped.
    positive_loss = flat_loss[flat_loss > 0]
    n_dropped = flat_loss.size - positive_loss.size
    if n_dropped:
        print(f"{loss_name}: dropped {n_dropped} non-positive value(s) before taking the log")

    fig, ax = plt.subplots()
    ax.hist(np.log(positive_loss))
    ax.set(title=f"{loss_name} (log scale)", xlabel="log(loss)", ylabel="count")
    plt.show()