## Physics Informed Neural Network

Simulation of complex physical systems described by nonlinear partial differential equations (PDEs) is central to engineering and physical science. We'll now train a neural network to solve a PDE given its boundary conditions. So far, we have trained MLPs purely from data, penalizing the network until it reproduces the desired outputs. For PDEs, however, we can also use the known governing equation itself to guide training. Using the core idea of [PINNs](https://www.sciencedirect.com/science/article/pii/S0021999118307125), we train a network to solve the Poisson equation 
$$u_{xx} + u_{yy} = -\sin (\pi x) \sin(\pi y)$$
on the unit square $[0, 1]^2$, with the homogeneous Dirichlet boundary conditions
$$u(0, y) = u(1, y) = u(x, 0) = u(x, 1) = 0$$

Use 10,000 collocation points in the domain to enforce the PDE and 100 data points on each boundary to enforce the boundary conditions. Compare the network's solution against the analytic solution, $u(x, y) = \frac{\sin(\pi x)\sin(\pi y)}{2\pi^2}$, and report the error in the relative $\mathbb{L}_{2}$ norm.


In [None]:
import jax.numpy as np
import numpy as onp
from jax import random, jit, vmap, grad, device_put
from jax.example_libraries import optimizers

import itertools
from functools import partial
from tqdm import trange
import matplotlib.pyplot as plt

import jax

In [None]:
def MLP(layers, activation=np.tanh):
    """Build a fully connected network as an ``(init, apply)`` pair.

    Args:
        layers: Layer widths, input first and output last, e.g. ``[2, 16, 16, 1]``.
        activation: Elementwise nonlinearity applied after every hidden layer;
            the output layer is left linear.

    Returns:
        ``init(rng_key) -> params``, which creates a list of ``(W, b)`` pairs,
        and ``apply(params, inputs) -> outputs``, which evaluates the network.
    """

    def init(rng_key):
        def make_layer(layer_key, fan_in, fan_out):
            w_key, _ = random.split(layer_key)
            # Glorot/Xavier-normal scale: std = sqrt(2 / (fan_in + fan_out)).
            scale = 1. / np.sqrt((fan_in + fan_out) / 2.)
            weights = scale * random.normal(w_key, (fan_in, fan_out))
            bias = np.zeros(fan_out)
            return weights, bias

        # The first key is discarded; the remaining len(layers) - 1 keys seed
        # one layer each.
        _, *layer_keys = random.split(rng_key, len(layers))
        return [
            make_layer(k, fan_in, fan_out)
            for k, fan_in, fan_out in zip(layer_keys, layers[:-1], layers[1:])
        ]

    def apply(params, inputs):
        *hidden, (W_out, b_out) = params
        h = inputs
        for W, b in hidden:
            h = activation(np.dot(h, W) + b)
        # Linear read-out layer.
        return np.dot(h, W_out) + b_out

    return init, apply

In [None]:
@optimizers.optimizer
def adam(step_size, b1=0.9, b2=0.999, eps=1e-8):
    """Adam optimizer in the ``jax.example_libraries.optimizers`` triple form.

    Args:
        step_size: Constant learning rate or a schedule ``i -> rate``.
        b1: Decay rate of the first-moment (mean) estimate.
        b2: Decay rate of the second-moment (uncentered variance) estimate.
        eps: Added to the denominator for numerical stability.

    Returns:
        ``(init, update, get_params)`` acting on a single array; the
        ``optimizers.optimizer`` decorator lifts them to parameter pytrees.
    """
    schedule = optimizers.make_schedule(step_size)

    def init(x0):
        # Per-array state: (value, first moment, second moment).
        return x0, np.zeros_like(x0), np.zeros_like(x0)

    def update(i, g, state):
        x, m, v = state
        m = (1 - b1) * g + b1 * m
        v = (1 - b2) * np.square(g) + b2 * v
        # Bias-correct the moments, which start at zero.
        t = i + 1
        m_hat = m / (1 - np.asarray(b1, m.dtype) ** t)
        v_hat = v / (1 - np.asarray(b2, m.dtype) ** t)
        x = x - schedule(i) * m_hat / (np.sqrt(v_hat) + eps)
        return x, m, v

    def get_params(state):
        return state[0]

    return init, update, get_params

In [None]:
@jax.jit
def step(i, opt_state, x_b, y_b, x_c, y_c):
    """Run one compiled optimizer step on the PINN loss.

    Uses the module-level ``loss``, ``get_params`` and ``opt_update``, which are
    defined in later cells and must exist before the first call.

    Args:
        i: Iteration index, forwarded to the optimizer (drives its schedule).
        opt_state: Current optimizer state.
        x_b, y_b: Boundary-point coordinates.
        x_c, y_c: Collocation-point coordinates.

    Returns:
        The updated optimizer state.
    """
    current = get_params(opt_state)
    loss_grads = grad(loss)(current, x_b, y_b, x_c, y_c)
    return opt_update(i, loss_grads, opt_state)

In [None]:
# Boundary data points: u = 0 on all four edges of [0, 1]^2.
# n_b points per edge; each corner appears on two adjacent edges.
n_b = 100

# u(x, 0) = 0  -- bottom edge, traversed left to right
x_bb = np.linspace(0, 1, n_b)
y_bb = np.zeros(n_b)

# u(1, y) = 0  -- right edge, traversed bottom to top
x_br = np.ones(n_b)
y_br = np.linspace(0, 1, n_b)

# u(x, 1) = 0  -- top edge, traversed right to left
x_bt = np.linspace(1, 0, n_b)
y_bt = np.ones(n_b)

# u(0, y) = 0  -- left edge, traversed top to bottom
x_bl = np.zeros(n_b)
y_bl = np.linspace(1, 0, n_b)

# Concatenate in genuinely counter-clockwise order (bottom, right, top, left),
# so the points trace the boundary as one closed path. The boundary loss is a
# mean and therefore does not depend on this ordering.
x_b = np.concatenate((x_bb, x_br, x_bt, x_bl))
y_b = np.concatenate((y_bb, y_br, y_bt, y_bl))

In [None]:
# Collocation points: a uniform 100 x 100 grid over [0, 1]^2 (10,000 points,
# boundary included), flattened row-major so x varies fastest.
x_c, y_c = np.meshgrid(np.linspace(0, 1, 100), np.linspace(0, 1, 100))
x_c, y_c = np.ravel(x_c), np.ravel(y_c)

In [None]:
# Network: 2 inputs (x, y) -> two hidden layers of 16 units (MLP's default
# tanh activation) -> 1 output u(x, y).
layers = [2, 16, 16, 1]
init, apply = MLP(layers)
rng_model = random.PRNGKey(0)
params = init(rng_model)

In [None]:
# Network solution u(x, y) evaluated at a single point.
def u(params, x_b, y_b):
    """Evaluate the network at one point ``(x_b, y_b)`` and return a scalar.

    Used for boundary points and collocation points alike. It must return a
    scalar (hence the ``[0]``) so that ``jax.grad`` can differentiate it.
    """
    out = apply(params, np.array([x_b, y_b]))
    return out[0]


# PINN residual of the Poisson equation u_xx + u_yy = -sin(pi x) sin(pi y).
def PINN(params, x_c, y_c):
    """Return the PDE residual ``u_xx + u_yy + sin(pi x) sin(pi y)`` at one point.

    The residual is zero wherever the network satisfies the PDE. Second
    derivatives come from nesting ``jax.grad`` over the scalar function ``u``:
    argument 1 is x, argument 2 is y.
    """
    # Differentiate the *function* ``u``. Do not rebind the name ``u`` in this
    # scope: that would make it a local and the calls below would fail.
    u_xx = jax.grad(jax.grad(u, 1), 1)(params, x_c, y_c)
    u_yy = jax.grad(jax.grad(u, 2), 2)(params, x_c, y_c)

    return u_xx + u_yy + np.sin(np.pi * x_c) * np.sin(np.pi * y_c)

In [None]:
def loss_u(params, x_bc, y_bc):
    """Mean squared boundary loss: the network should be zero on the boundary.

    Args:
        params: Network parameters.
        x_bc, y_bc: 1-D arrays of boundary-point coordinates.
    """
    # Use the arguments, not the module-level x_b / y_b.
    u_boundary = jax.vmap(u, (None, 0, 0))(params, x_bc, y_bc)
    return (u_boundary ** 2).mean()


def loss_f(params, x_c, y_c):
    """Mean squared PDE residual over the collocation points."""
    residual = jax.vmap(PINN, (None, 0, 0))(params, x_c, y_c)
    return (residual ** 2).mean()


@jax.jit
def loss(params, x_b, y_b, x_c, y_c):
    """Total PINN loss: boundary-condition term plus PDE-residual term."""
    return loss_u(params, x_b, y_b) + loss_f(params, x_c, y_c)


# Gradient of the total loss with respect to the network parameters.
grad_loss = jax.grad(loss)

In [None]:
# Adam with a smoothly decaying learning rate: 1e-3 * 0.99 ** (i / 100).
lr = optimizers.exponential_decay(1e-3, decay_steps=100, decay_rate=0.99)
opt_init, opt_update, get_params = adam(lr)
opt_state = opt_init(params)

n_iter = 100000  # number of optimizer steps
log_every = 50   # record the total loss every `log_every` steps

loss_log = []

# Loss of the freshly initialised network.
loss_log.append(loss(params, x_b, y_b, x_c, y_c))

pbar = trange(n_iter)
for it in pbar:
    opt_state = step(it, opt_state, x_b, y_b, x_c, y_c)
    if it % log_every == 0:
        params = get_params(opt_state)
        loss_value = loss(params, x_b, y_b, x_c, y_c)
        loss_log.append(loss_value)
        pbar.set_postfix(loss=f"{float(loss_value):.3e}")

# Inside the loop `params` is only refreshed on logging steps, so it would lag
# the optimizer by up to `log_every - 1` steps; fetch the final parameters.
params = get_params(opt_state)

In [None]:
# Training history. loss_log[0] is the initial loss; entry k >= 1 was recorded
# right after optimizer step index 50 * (k - 1), i.e. after 50 * (k - 1) + 1
# steps. The 50 must match the logging interval of the training loop.
steps_done = onp.concatenate(([0], 50 * onp.arange(len(loss_log) - 1) + 1))

plt.figure(dpi=300)
plt.plot(steps_done, loss_log, label='Total loss (BC + PDE residual)')
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.yscale("log")
plt.legend()

In [None]:
# Analytical solution on the collocation points:
#   u(x, y) = sin(pi x) sin(pi y) / (2 pi^2),
# which gives u_xx + u_yy = -sin(pi x) sin(pi y) and vanishes on the boundary.
# (Stored as `dydx` because later cells read that name.)
_sin_prod = np.sin(np.pi * x_c) * np.sin(np.pi * y_c)
dydx = _sin_prod / (2 * np.pi ** 2)

In [None]:
# our solution
# Network prediction u(x, y) at every collocation point.
# NOTE(review): this rebinds `PINN` (until now the residual function used by
# `loss_f`) to an array; re-tracing the loss after this cell may fail.
predict_all = jax.vmap(u, in_axes=(None, 0, 0))
PINN = predict_all(params, x_c, y_c)

In [None]:
# Fold the flattened fields back onto the 100 x 100 collocation grid.
der = np.reshape(dydx, (100, 100))    # analytical solution
PINN = np.reshape(PINN, (100, 100))   # network prediction
x, y = np.reshape(x_c, (100, 100)), np.reshape(y_c, (100, 100))