In [None]:
%env PATH=/usr/local/cuda-11.5/bin:$PATH
%load_ext autoreload
%autoreload 2
%matplotlib widget

In [None]:
from pinns.prelude import *
from pinns.domain import Sphere
from pinns.calc import divergence
from pinns.krr import krr, rbf
from pinns.elm import elm
from pinns.pde import poisson_dirichlet_qp_mc, poisson_dirichlet_ecqp_mc
from scipy.stats.qmc import Sobol

import itertools
import numpy as np
import matplotlib.pyplot as plt
from jax.config import config
config.update("jax_enable_x64", True)

In [None]:
# Unit ball centred at the origin.
domain = Sphere(1., (0., 0., 0.))

# Scrambled Sobol samples: 2**12 points from a 3-D Sobol sequence mapped by
# `domain.transform` (interior), and 2**9 points from a 2-D Sobol sequence
# mapped by `domain.transform_bnd` (boundary).
x_dom = domain.transform(array(Sobol(3, seed=0).random_base2(12)))
x_bnd = domain.transform_bnd(array(Sobol(2, seed=1).random_base2(9)))

# Leading power-of-two prefixes of the same samples, kept as support points.
x_support_dom, x_support_bnd = x_dom[:2**8], x_bnd[:2**6]

In [None]:
def plot_result(sol, ax, N=200, levels=20, cmap="autumn"):
    """Filled-contour plot of `sol` on the y = 0 slice of the unit ball.

    Parameters
    ----------
    sol : callable
        Batched function evaluated on an (N*N, 3) array of points (x, 0, z);
        must return N*N scalar values.
    ax : matplotlib.axes.Axes
        Target axes; a colorbar is attached next to it.
    N : int, default 200
        Grid resolution per axis on [-1, 1]^2.
    levels : int, default 20
        Number of contour levels passed to ``contourf``.
    cmap : str or Colormap, default "autumn"
        Colormap for the filled contours.

    Returns
    -------
    matplotlib.contour.QuadContourSet
        The object returned by ``ax.contourf``.
    """
    x = np.linspace(-1, 1, N)
    z = np.linspace(-1, 1, N)
    # "xy" indexing: XX[j, i] = x[i], ZZ[j, i] = z[j] -- rows follow z and
    # columns follow x, which is the layout contourf(x, z, C) expects.
    XX, ZZ = np.meshgrid(x, z, indexing="xy")
    pts = np.stack([XX.ravel(), np.zeros(XX.size), ZZ.ravel()], axis=-1)

    phi = np.array(sol(jnp.asarray(pts)), dtype=float).reshape(N, N)
    # Blank out everything outside the unit ball so only the domain is coloured.
    phi[np.hypot(XX, ZZ) > 1.] = np.nan

    p = ax.contourf(x, z, phi, levels, cmap=cmap, alpha=0.5)
    ax.figure.colorbar(p, ax=ax)
    ax.set_xlabel("x")
    ax.set_ylabel("z")
    return p


def plot_model(model, exact=None):
    """Side-by-side plots of `model` and its absolute error on the y = 0 slice.

    Parameters
    ----------
    model : callable
        Batched solution, called on an (M, 3) array of points.
    exact : callable, optional
        Reference solution with the same calling convention. Defaults to
        ``|x| - 1``: for m = x/|x| it satisfies Delta(|x| - 1) = div(m) and
        vanishes on the unit sphere.
    """
    if exact is None:
        def exact(x):
            return norm(x, axis=-1) - 1

    fig, (ax_model, ax_err) = plt.subplots(1, 2, figsize=(8, 3))
    fig.subplots_adjust(wspace=0.4)

    plot_result(model, ax_model)
    ax_model.set_title("Model")

    plot_result(lambda x: abs(model(x) - exact(x)), ax_err)
    ax_err.set_title("abs error")


In [None]:
# Random-feature ansatz: 2**8 = 256 tanh features whose weights/biases come from
# a scrambled Sobol sequence, mapped from [0, 1) to [-1, 1).
weights = array(Sobol(4, seed=12345).random_base2(8))
W = weights[:, :3] * 2 - 1
b = weights[:, 3] * 2 - 1

# Test problem on the unit ball: m = x/|x|, source f = -div(m).
m = lambda x: x / norm(x)
f = lambda x: -divergence(m)(x)

# Function vanishing on the unit sphere (needed by `n` and `u` below): every
# feature in `u` is multiplied by it, so `u` is zero on the boundary,
# consistent with the zero Dirichlet data `g1`.
l = lambda x: 1 - norm(x)


def n(x):
    """Unit normal -grad(l)/|grad(l)|; for l = 1 - |x| this is x/|x| (outward)."""
    # jax.nn.normalize standardizes (subtracts the mean, divides by the std);
    # it does not return a unit vector, so normalize explicitly.
    v = -grad(l)(x)
    return v / norm(v)


h = lambda x: tanh(W @ x + b)
#h = lambda x: exp(- 40 * (W @ x + b)**2)  # alternative: Gaussian features
u = lambda x: l(x) * h(x)
#u = lambda x: l(x) * exp(-1 * (W @ x + b)**2)  # alternative ansatz
g1 = lambda x: 0.

phi1 = poisson_dirichlet_qp_mc(u, g1, x_dom, f, tol=1e-9, maxiter=4000)

In [None]:
def phi2_solution(x, x_bnd, eps=1e-7, radius=1.):
    """Monte Carlo single-layer potential of the boundary jump `g` over a sphere.

    For every row ``xi`` of ``x`` this evaluates

        phi2(xi) ~= |S| / N * sum_i g(y_i) / (4 pi |xi - y_i|),  |S| = 4 pi radius**2,

    where ``y_i`` are the rows of ``x_bnd`` and
    ``g(y) = m(y).n(y) - grad(phi1)(y).n(y)``. Samples closer than ``eps`` to
    ``xi`` are dropped (their 1/r term is singular), and N counts only the kept
    samples. Relies on the notebook globals ``m``, ``n`` and ``phi1``.

    Parameters
    ----------
    x : array, shape (M, d)
        Evaluation points (d = 3 in this notebook).
    x_bnd : array, shape (K, d)
        Boundary samples. Presumably roughly uniform on the sphere: the
        estimator gives every sample the equal weight |S| / N.
    eps : float, default 1e-7
        Distance below which a sample is treated as coincident with ``xi``.
    radius : float, default 1.
        Sphere radius; only enters through the surface area |S|.

    Returns
    -------
    array, shape (M,)
    """
    def g(y):
        normal = n(y)
        return dot(m(y), normal) - dot(grad(phi1)(y), normal)

    g_bnd = vmap(g)(x_bnd)
    surface_area = 4 * pi * radius ** 2

    def potential(xi):
        diff = xi - x_bnd
        d2 = jnp.sum(diff * diff, axis=-1)
        keep = d2 > eps ** 2
        # Double `where`: feed a harmless value into 1/sqrt for dropped samples
        # so gradients w.r.t. x stay finite. A single `where` (or `norm` of a
        # zero vector) still differentiates through the singular branch and
        # yields NaN.
        safe_d2 = where(keep, d2, 1.)
        newton_kernel = where(keep, 1 / jnp.sqrt(safe_d2), 0.)
        n_kept = jnp.count_nonzero(keep)
        return surface_area / (4 * pi * n_kept) * dot(newton_kernel, g_bnd)

    return vmap(potential)(x)