# Poisson Equation
In this example we show how to use pancax to solve one of the simplest equations of mathematical physics, the Poisson equation, which we can write as

$$\nabla^2 u + f = 0$$

We can approach this problem in several ways:

1. The strong form as written above
2. The weak form $$-\int_\Omega\nabla u\cdot\nabla v\, d\Omega + \int_\Omega fv\, d\Omega = 0$$ for all test functions $v$ that vanish on the Dirichlet boundary
3. Energy minimization $$\min_u\int_\Omega \left[\frac{1}{2}\|\nabla u\|^2 - fu\right]d\Omega$$
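
The three forms are linked by a short calculation. Multiply the strong form by a test function $v$ and integrate by parts. Then take the first variation of the energy $\Pi[u]$ in the direction $v$. Both steps below assume purely Dirichlet boundaries, so $v = 0$ on $\partial\Omega$:

$$
\int_\Omega \left(\nabla^2 u + f\right) v \, d\Omega
= \int_{\partial\Omega} v \, \nabla u\cdot n \, d\Gamma
- \int_\Omega \nabla u\cdot\nabla v \, d\Omega
+ \int_\Omega f v \, d\Omega = 0
$$

$$
\Pi[u] = \int_\Omega \left[\frac{1}{2}\|\nabla u\|^2 - f u\right] d\Omega,
\qquad
\delta\Pi[u; v] = \int_\Omega \left[\nabla u\cdot\nabla v - f v\right] d\Omega = 0
$$

With $v = 0$ on $\partial\Omega$ the boundary integral drops out, which leaves the weak form. The stationarity condition $\delta\Pi = 0$ is that same weak form multiplied by $-1$.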

All three approaches are supported by tools in pancax.

First, we import what we need from pancax. The simplest way is a wildcard import:

```python
from pancax import *
```

Next, we need a key for the random numbers generated downstream, for example during neural network initialization. For reproducibility, the key is usually created at the top of a script:

```python
##################
# for reproducibility
##################
key = random.key(10)
```
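
In JAX, random numbers come from explicit keys. The same key always reproduces the same draws, and `split` derives new, independent keys. The sketch below calls `jax.random` directly. The cell above calls `random.key`, which we assume is `jax.random` re-exported by the wildcard import. `jax.random.key` also requires a reasonably recent JAX release.

```python
import jax

key = jax.random.key(10)

# the same key always produces the same numbers
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# splitting yields a fresh subkey, e.g. one per network layer
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))
print(a, b, c)
```

Never reuse a key after splitting it. Always draw from the new `key`/`subkey` pair, so that the streams stay independent.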

Next comes file management. To solve a PDE we need a geometry, and pancax represents geometry with computational meshes, since meshes are the standard I/O mechanism in computational mechanics.

pancax provides a helper, ``find_mesh_file``, which searches the current directory and its subdirectories for files with an appropriate extension. If you prefer not to use it, a plain Python ``str`` or a ``pathlib.Path`` also works.

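```python
##################
# file management
##################
# mesh_file = find_mesh_file('mesh_quad4.g')
mesh_file = './mesh/mesh_quad4.g'
# training log written to pinn.log (log_every sets the logging interval)
logger = Logger('pinn.log', log_every=250)
# post-processor that writes results in exodus format
pp = PostProcessor(mesh_file, 'exodus')
```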
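
If you would rather not depend on ``find_mesh_file``, you can write a recursive search with ``pathlib``. The helper ``find_file`` below is hypothetical and only approximates the behaviour described above. It is not pancax's implementation. The demo builds a throwaway directory laid out like this example:

```python
import tempfile
from pathlib import Path


def find_file(name, root="."):
    # hypothetical stand-in for find_mesh_file: search root and all subdirectories
    matches = sorted(Path(root).rglob(name))
    if not matches:
        raise FileNotFoundError(f"{name} not found under {Path(root).resolve()}")
    return matches[0]


# demo: a temporary tree containing ./mesh/mesh_quad4.g
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "mesh").mkdir()
    (Path(tmp) / "mesh" / "mesh_quad4.g").touch()
    found = find_file("mesh_quad4.g", root=tmp)
    print(found.relative_to(tmp))
```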

Next, we set up a domain object, which holds the geometric quantities the solver needs. pancax provides several domain types, depending on the solution approach.

For collocation problems, we set up a ``CollocationDomain``:

```python
##################
# domain setup
##################
times = jnp.linspace(0.0, 0.0, 1)  # a single time, t = 0 (steady problem)
domain = CollocationDomain(mesh_file, times)
print(domain)
```

```
Time in read_exodus_mesh: 0.01130949 seconds
Time in Reading Mesh...: 0.01162285 seconds
CollocationDomain(
  mesh_file='./mesh/mesh_quad4.g',
  mesh=Mesh(
    coords=f32[576,2],
    conns=i32[529,4],
    simplexNodesOrdinals=i32[576],
    parentElement=Quad4Element(
      elementType='quad4',
      degree=1,
      coordinates=f32[4,2],
      vertexNodes=i32[3],
      faceNodes=i32[4,2],
      interiorNodes=None
    ),
    parentElement1d=LineElement(
      elementType='line',
      degree=1,
      coordinates=f32[2],
      vertexNodes=i32[2],
      faceNodes=None,
      interiorNodes=i32[0]
    ),
    blocks={'block_1': i32[529]},
    nodeSets={
      'nset_1':
      i32[24](numpy),
      'nset_2':
      i32[24](numpy),
      'nset_3':
      i32[24](numpy),
      'nset_4':
      i32[24](numpy)
    },
    sideSets={
      'sset_1':
      i32[23,2],
      'sset_2':
      i32[23,2],
      'sset_3':
      i32[23,2],
      'sset_4':
      i32[23,2]
    }
  ),
  coords=f32[576,2],
  times=f
```
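
The network built in the ML setup later takes ``n_dims + 1`` inputs, i.e. space plus time. The sketch below shows one plausible way to pair node coordinates with time samples to form such inputs. ``space_time_points`` is a hypothetical helper, not pancax's internal code, and the 24 × 24 grid on the unit square only mimics the 576 nodes reported above:

```python
import jax.numpy as jnp


def space_time_points(coords, times):
    # pair every node with every time: shape (n_times * n_nodes, n_dims + 1)
    n_nodes = coords.shape[0]
    xs = jnp.tile(coords, (times.shape[0], 1))    # coords block repeated once per time
    ts = jnp.repeat(times, n_nodes)[:, None]      # each time repeated once per node
    return jnp.concatenate([xs, ts], axis=1)


# stand-in for mesh node coordinates (hypothetical, not read from the mesh file)
g = jnp.linspace(0.0, 1.0, 24)
coords = jnp.stack(jnp.meshgrid(g, g), axis=-1).reshape(-1, 2)  # (576, 2)
times = jnp.linspace(0.0, 0.0, 1)  # the single time t = 0 used above

inputs = space_time_points(coords, times)
print(inputs.shape)  # (576, 3)
```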

If you would rather use a variational approach (e.g. the weak form or energy minimization), you can set up a ``VariationalDomain``:

```python
##################
# domain setup
##################
times = jnp.linspace(0.0, 0.0, 1)  # a single time, t = 0 (steady problem)
domain = VariationalDomain(mesh_file, times)
print(domain)
```

```
Time in read_exodus_mesh: 0.00947552 seconds
Time in Reading Mesh...: 0.00999020 seconds
Time in QuadratureRule.__init__: 0.00106756 seconds
Time in NonAllocatedFunctionSpace.__init__: 0.03084259 seconds
Time in QuadratureRule.__init__: 0.00060014 seconds
Time in NonAllocatedFunctionSpace.__init__: 0.02132083 seconds
VariationalDomain(
  mesh_file='./mesh/mesh_quad4.g',
  mesh=Mesh(
    coords=f32[576,2],
    conns=i32[529,4],
    simplexNodesOrdinals=i32[576],
    parentElement=Quad4Element(
      elementType='quad4',
      degree=1,
      coordinates=f32[4,2],
      vertexNodes=i32[3],
      faceNodes=i32[4,2],
      interiorNodes=None
    ),
    parentElement1d=LineElement(
      elementType='line',
      degree=1,
      coordinates=f32[2],
      vertexNodes=i32[2],
      faceNodes=None,
      interiorNodes=i32[0]
    ),
    blocks={'block_1': i32[529]},
    nodeSets={
      'nset_1':
      i32[24](numpy),
      'nset_2':
      i32[24](numpy),
      'nset_3':
      i32[24](numpy),
```

As the two outputs above show, the variational domain carries additional structure, such as function spaces, quadrature rules, and connectivity.
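
As a rough illustration of what a quadrature rule holds, the sketch below builds a tensor-product Gauss-Legendre rule on the reference quadrilateral $[-1, 1]^2$. It is a generic numpy sketch, not pancax's ``QuadratureRule``, and the helper name ``gauss_quad4`` is hypothetical:

```python
import numpy as np


def gauss_quad4(n):
    # tensor-product Gauss-Legendre rule with n points per direction on [-1, 1]^2
    pts_1d, wts_1d = np.polynomial.legendre.leggauss(n)
    X, Y = np.meshgrid(pts_1d, pts_1d)
    W = np.outer(wts_1d, wts_1d)
    return np.stack([X.ravel(), Y.ravel()], axis=1), W.ravel()


points, weights = gauss_quad4(2)
# a 2-point rule is exact per direction up to cubic terms, so x^2 y^2 integrates exactly to 4/9
integral = np.sum(weights * points[:, 0]**2 * points[:, 1]**2)
print(points.shape, integral)
```

The weights sum to the area of the reference element (4). A real function space would also map these points onto each physical element through the element's shape functions.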

Now that the time and geometric domains are set up, we need to specify the physics. Below is a copy of the Poisson implementation built into pancax, shown so that interested users can see how a physics class is written.

```python
from pancax.physics_kernels.base import BaseEnergyFormPhysics, BaseStrongFormPhysics
from typing import Callable
import jax.numpy as jnp


class MyPoisson(
  BaseEnergyFormPhysics,
  BaseStrongFormPhysics
):
  # one scalar field named 'u' (the trailing comma makes this a tuple, not a string)
  field_value_names: tuple[str, ...] = ('u',)
  f: Callable  # source term f(x)

  def __init__(self, f: Callable) -> None:
    super().__init__(('u',))
    self.f = f

  def energy(self, params, x, t, u, grad_u, *args):
    # energy density 0.5 |grad u|^2 - f u, used by variational (energy) losses
    f = self.f(x)
    pi = 0.5 * jnp.dot(grad_u, grad_u.T) - f * u
    return jnp.sum(pi)

  def strong_form_neumann_bc(self, params, x, t, n, *args):
    # boundary flux -grad(u) . n for natural conditions, n being the boundary normal
    field, _ = params
    grad_u = self.field_gradients(field, x, t, *args)
    return -jnp.dot(grad_u, n)

  def strong_form_residual(self, params, x, t, *args):
    # pointwise residual -(laplacian(u) + f), zero for an exact solution
    field, _ = params
    delta_u = self.field_laplacians(field, x, t, *args)
    f = self.f(x)
    return -delta_u - f
```

This implementation is general enough to support both strong-form and variational treatments of the Poisson equation. It uses multiple inheritance to pick up methods from both ``BaseEnergyFormPhysics`` and ``BaseStrongFormPhysics``, which leaves flexibility in the choice of loss function later.
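
To sanity-check the sign convention of ``strong_form_residual``, we can evaluate $-(\nabla^2 u + f)$ for a known solution. The built-in ``Poisson`` class below receives the source term $f = 2\pi^2 \sin(2\pi x)\sin(2\pi y)$. The Laplacian of $\sin(2\pi x)\sin(2\pi y)$ is $-8\pi^2$ times itself, so $u = \tfrac{1}{4}\sin(2\pi x)\sin(2\pi y)$ solves $\nabla^2 u + f = 0$. This $u$ also vanishes on the boundary of the unit square; we assume the mesh spans $[0, 1]^2$. The sketch uses plain ``jax.hessian`` rather than pancax's ``field_laplacians``:

```python
import jax
import jax.numpy as jnp


def f(x):
    # source term passed to the built-in Poisson class in this example
    return 2 * jnp.pi**2 * jnp.sin(2. * jnp.pi * x[0]) * jnp.sin(2. * jnp.pi * x[1])


def u_exact(x):
    # laplacian of sin(2 pi x) sin(2 pi y) is -8 pi^2 times itself,
    # so the factor 1/4 gives laplacian(u) = -f
    return 0.25 * jnp.sin(2. * jnp.pi * x[0]) * jnp.sin(2. * jnp.pi * x[1])


def laplacian(fn, x):
    # trace of the Hessian = sum of the unmixed second derivatives
    return jnp.trace(jax.hessian(fn)(x))


def strong_form_residual(fn, x):
    # same sign convention as MyPoisson.strong_form_residual: -(laplacian(u) + f)
    return -laplacian(fn, x) - f(x)


xs = jnp.array([[0.1, 0.2], [0.3, 0.7], [0.55, 0.25]])
res = jax.vmap(lambda x: strong_form_residual(u_exact, x))(xs)
print(res)  # close to zero, up to float32 round-off
```

This known solution is also handy later for measuring the error of the trained network.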

Alternatively, we can use the ``Poisson`` class already implemented in pancax:

```python
##################
# physics setup
##################
# source term f(x, y) = 2 pi^2 sin(2 pi x) sin(2 pi y)
physics = Poisson(lambda x: 2 * jnp.pi**2 * jnp.sin(2. * jnp.pi * x[0]) * jnp.sin(2. * jnp.pi * x[1]))
```


## Boundary conditions


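```python
##################
# bcs
##################
def bc_func(x, t, z):
  # x(1-x)y(1-y) vanishes on the boundary of the unit square,
  # so the returned value is zero there for any z
  x, y = x[0], x[1]
  return x * (1. - x) * y * (1. - y) * z

physics = physics.update_dirichlet_bc_func(bc_func)

ics = []  # no initial conditions for this steady problem
# essential (Dirichlet) conditions on the four boundary node sets
essential_bcs = [
  EssentialBC('nset_1', 0),
  EssentialBC('nset_2', 0),
  EssentialBC('nset_3', 0),
  EssentialBC('nset_4', 0),
]
natural_bcs = []  # no natural (Neumann) conditions
```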
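
``bc_func`` appears to impose the homogeneous Dirichlet condition exactly. The factor $x(1-x)y(1-y)$ is zero on every edge of the unit square, so the returned field vanishes there whatever ``z`` is. This reading assumes that ``z`` is the raw network output and that the mesh spans $[0, 1]^2$. A quick check in plain JAX:

```python
import jax.numpy as jnp


def bc_func(x, t, z):
    # same form as above: zero on the unit-square boundary, scaled z inside
    x, y = x[0], x[1]
    return x * (1. - x) * y * (1. - y) * z


# one point on each edge, plus the centre of the square
boundary_pts = jnp.array([[0.0, 0.3], [1.0, 0.8], [0.4, 0.0], [0.6, 1.0]])
interior_pt = jnp.array([0.5, 0.5])
z = 7.3  # arbitrary stand-in for a raw network output

print([float(bc_func(p, 0.0, z)) for p in boundary_pts])  # all zero
print(float(bc_func(interior_pt, 0.0, z)))  # 0.0625 * 7.3
```

Because the constraint holds by construction, the loss does not need a separate boundary penalty term for these Dirichlet conditions.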

## Problem setup


```python
##################
# problem setup
##################
problem = ForwardProblem(domain, physics, ics, essential_bcs, natural_bcs)
print(problem)
```

```
Time in DofManager.__init__: 0.00866868 seconds
ForwardProblem(
  domain=VariationalDomain(
    mesh_file='./mesh/mesh_quad4.g',
    mesh=Mesh(
      coords=f32[576,2],
      conns=i32[529,4],
      simplexNodesOrdinals=i32[576],
      parentElement=Quad4Element(
        elementType='quad4',
        degree=1,
        coordinates=f32[4,2],
        vertexNodes=i32[3],
        faceNodes=i32[4,2],
        interiorNodes=None
      ),
      parentElement1d=LineElement(
        elementType='line',
        degree=1,
        coordinates=f32[2],
        vertexNodes=i32[2],
        faceNodes=None,
        interiorNodes=i32[0]
      ),
      blocks={'block_1': i32[529]},
      nodeSets={
        'nset_1':
        i32[24](numpy),
        'nset_2':
        i32[24](numpy),
        'nset_3':
        i32[24](numpy),
        'nset_4':
        i32[24](numpy)
      },
      sideSets={
        'sset_1':
        i32[23,2],
        'sset_2':
        i32[23,2],
        'sset_3':
        i32[23,2],
        'ss
```

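## ML setup

Next, we set up the neural network, the loss function, and the optimizer. The field is represented by an ``MLP`` whose inputs are the spatial coordinates plus time (hence ``n_dims + 1``) and whose outputs are the physics degrees of freedom. The network and the physics are bundled into a ``FieldPropertyPair``. The loss is the strong-form residual, and the optimizer is Adam with a learning rate of ``1e-3``.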

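```python
import jax

##################
# ML setup
##################
n_dims = domain.coords.shape[1]  # spatial dimension (2 here)
# inputs: spatial coordinates plus time; outputs: the physics degrees of freedom
field = MLP(n_dims + 1, physics.n_dofs, 50, 3, jax.nn.tanh, key)
params = FieldPropertyPair(field, problem.physics)

# strong-form residual loss; it returns (loss, aux dict), hence has_aux=True
loss_function = StrongFormResidualLoss()
opt = Adam(loss_function, learning_rate=1e-3, has_aux=True)
opt_st = opt.init(params)
```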
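
The following self-contained sketch in plain JAX makes the strong-form residual loss concrete. A small tanh MLP maps $(x, y, t)$ to a raw output, and the ``bc_func`` factor enforces $u = 0$ on the boundary. The loss is the mean squared residual $-(\nabla^2 u + f)$ over a grid of points, returned with an aux dict, as ``has_aux=True`` implies. The helpers (``init_mlp``, ``mlp``, ``u``, ``residual``, ``loss``) and the layer sizes are hypothetical choices. They are not pancax's ``MLP`` or ``StrongFormResidualLoss``:

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    # hypothetical helper: one (W, b) pair per layer with scaled normal weights
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(k, (n_out, n_in)) / jnp.sqrt(n_in), jnp.zeros(n_out))
        for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])
    ]


def mlp(params, z):
    # tanh hidden layers followed by a linear output layer
    for W, b in params[:-1]:
        z = jnp.tanh(W @ z + b)
    W, b = params[-1]
    return W @ z + b


def f(x):
    return 2 * jnp.pi**2 * jnp.sin(2. * jnp.pi * x[0]) * jnp.sin(2. * jnp.pi * x[1])


def u(params, x, t):
    # network input is (x, y, t); the bubble factor enforces u = 0 on the boundary
    z = mlp(params, jnp.array([x[0], x[1], t]))[0]
    return x[0] * (1. - x[0]) * x[1] * (1. - x[1]) * z


def residual(params, x, t):
    # strong-form residual -(laplacian(u) + f); the Hessian is taken w.r.t. x only
    lap = jnp.trace(jax.hessian(u, argnums=1)(params, x, t))
    return -lap - f(x)


def loss(params, xs, t):
    # mean squared residual over the points, plus an aux dict for logging
    r = jax.vmap(residual, in_axes=(None, 0, None))(params, xs, t)
    value = jnp.mean(r**2)
    return value, {'residual': value}


key = jax.random.key(10)
params = init_mlp(key, [3, 20, 20, 1])
g = jnp.linspace(0.0, 1.0, 24)
xs = jnp.stack(jnp.meshgrid(g, g), axis=-1).reshape(-1, 2)
t0 = jnp.array(0.0)  # the single time used in this example

(value, aux), grads = jax.value_and_grad(loss, has_aux=True)(params, xs, t0)
print(value, aux)
```

An optimizer such as the Adam instance above would then use ``grads`` to update ``params`` at each training step.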


## Training


```python
for epoch in range(5000):
  params, opt_st, loss = opt.step(params, problem, opt_st)

  if epoch % 100 == 0:
    print(epoch)
    print(loss)
```


0
(Array(89.44533, dtype=float32), {'residual': Array(89.44533, dtype=float32)})
100
(Array(15.81907, dtype=float32), {'residual': Array(15.81907, dtype=float32)})
200
(Array(1.6400435, dtype=float32), {'residual': Array(1.6400435, dtype=float32)})
300
(Array(0.6807925, dtype=float32), {'residual': Array(0.6807925, dtype=float32)})
400
(Array(0.44704914, dtype=float32), {'residual': Array(0.44704914, dtype=float32)})
500
(Array(0.34868824, dtype=float32), {'residual': Array(0.34868824, dtype=float32)})
600
(Array(0.28184032, dtype=float32), {'residual': Array(0.28184032, dtype=float32)})
700
(Array(0.22886859, dtype=float32), {'residual': Array(0.22886859, dtype=float32)})
800
(Array(0.186384, dtype=float32), {'residual': Array(0.186384, dtype=float32)})
900
(Array(0.15230344, dtype=float32), {'residual': Array(0.15230344, dtype=float32)})
1000
(Array(0.12577225, dtype=float32), {'residual': Array(0.12577225, dtype=float32)})
1100
(Array(0.10473828, dtype=float32), {'residual': Array(0