# Imposing exact and soft boundary conditions in Physics-Informed Neural Networks

**Author:** [sebbas](https://twitter.com/sebbas)<br>
**Date created:** 2023/08/27<br>
**Last modified:** 2023/09/29<br>
**License:** MIT<br>
**Description:** An overview of how to build PINNs with exact and soft boundary conditions by using the example of Poisson's equation with variable coefficient.


## Introduction

In this guide, you will learn how to implement a **Physics-Informed Neural Network** that can approximate instances of the **Variable Coefficient Poisson equation**. Everything will be implemented from scratch with Keras and Tensorflow!

The following code includes examples on how to load training data, how to build PINN models with PDE loss and soft/exact boundary conditions, it shows how training and prediction steps work, and provides utility functions to visualize ground truth, predictions and error metrics.

By the end of this tutorial, you should be able to build your own PINNs with custom BC imposition approaches.


## Setup

This guide depends on `keras` and `tensorflow`. Hence, the first step is to import these frameworks - plus the usual packages that  are commonly used when training Machine Learning models.

In [None]:
import keras.layers as KL
import keras.regularizers as KR
import keras.callbacks as KC
import keras.optimizers as KO
import keras.metrics as KM
import keras.losses as KLS
import keras.backend as KB
import tensorflow as tf
import numpy as np
import h5py as h5
import pandas as pd
import scipy.interpolate as spint
import matplotlib.pyplot as plt
import seaborn as sns
import os, time, sys
from matplotlib.ticker import FuncFormatter
from IPython.display import display_markdown

Next, we need to gather the training data. There are two options to get training data into the working directory.

### "make clean"

Before fetching any datasets though, we make sure that the current directory is clean and no data from previous runs is present.

You can remove everything that was downloaded and generated by executing the code box below. I.e. every time you need a clean slate, run the code from below in addition to restarting the run-time.

In [None]:
#%%script false --no-raise-error # uncomment this line to prevent it from being executed (e.g. to preserve models)
!rm -rf psndata_* psnNet* *.png

### Option 1: Download pre-simulated datasets

The code below will download datasets of pre-simulated Poisson equation problems. Each dataset contains 2D grids at different resolutions (32x32 and 128x128) for variables `a` (the variable coefficient), `f` (RHS), `g` (BC) and `p` (solution) from the Poisson equation.

In [None]:
!rm -f psndata_*.h5
!wget -nc https://github.com/epssmallerzero/poisson-pinn/raw/master/data/psndata_1_128_0.h5
!wget -nc https://github.com/epssmallerzero/poisson-pinn/raw/master/data/psndata_1_128_1.h5
!wget -nc https://github.com/epssmallerzero/poisson-pinn/raw/master/data/psndata_1_128_2.h5
!wget -nc https://github.com/epssmallerzero/poisson-pinn/raw/master/data/psndata_1_128_3.h5
!wget -nc https://github.com/epssmallerzero/poisson-pinn/raw/master/data/psndata_1_32_0.h5

### Option 2: Create custom datasets

You can easily generate your own Poisson equation instances using the Poisson generator from [https://github.com/sebbas/poisson-ddm#training-data-generation](https://github.com/sebbas/poisson-ddm#training-data-generation).

The `README` in that repository has a section explaining how to call the generator script.

Once you have a `.h5` dataset, it is best to place it alongside this Python notebook. When using Google Colab the most convenient way to access datasets is through Google Drive. Colab can mount Drive with the command from below.

In [None]:
use_google_drive = 0 # toggle this when using your own datasets from Google Drive
if use_google_drive:
    from google.colab import drive
    drive.mount('/content/drive/')

Similarly to the pre-simulated datasets, the exact path to the datasets can be set later in the training parameters.

### Getting Poisson equation instances

The Poisson equation datasets are available on disk now. Next, we'll need a function that can fetch individual Poisson equation instances from a dataset.

To this end, we define `getBcAFP()` which returns `n` 2D Poisson equation instances of shape `shape`. Each instance contains random but smooth value distributions for variables `g`, `a`, `f`, and `p`.

There is an additional option called `eqId` (equation ID). It takes values 0, 1, or 2 and determines which special case of the Poisson equation to load:

- `eqId=0`: Poisson equation with zero BC
- `eqId=1`: Laplace equation (`f=0`) with non-zero BC
- `eqId=2`: Poisson equation with non-zero BC

Throughout this tutorial, we will stick to `eqId=2`. That is, the variable coefficient Poisson equation with non-zero BC. When you experiment with this tutorial, I would encourage you to try the other equations too.

In [None]:
def getBcAFP(n, fname, shape, eqId):
  dFile = h5.File(fname, 'r')
  nSample  = dFile.attrs['nSample']
  s        = dFile.attrs['shape']
  assert all(x == y for x, y in zip(shape, s))
  length   = dFile.attrs['length']
  aData    = dFile.get('a')[:n, ...]
  fData    = dFile.get('f')[:n, ...]
  ppData   = dFile.get('pp')[:n, ...]
  plData   = dFile.get('pl')[:n, ...]
  pBcData  = dFile.get('pBc')[:n, ...]
  dFile.close()

  assert n == aData.shape[0] and n == fData.shape[0] and n == ppData.shape[0]

  # Construct solution p
  if eqId == 0: # Homogeneous Poisson
    p = np.expand_dims(ppData, axis=-1)
  elif eqId == 1: # Inhomogeneous Laplace
    p = np.expand_dims(plData, axis=-1)
    fData *= 0.0 # For now, just set f to 0 and keep in channels
  elif eqId == 2: # Inhomogeneous Poisson (i.e. Homogeneous Poisson + Inhomogeneous Laplace)
    pplData = ppData + plData
    p = np.expand_dims(pplData, axis=-1)

  # Construct solution, bc, a, and f arrays for Poisson / Laplace equation
  a  = np.expand_dims(aData, axis=-1)
  f  = np.expand_dims(fData, axis=-1)

  # Construct bc array
  bcWidth     = 1
  # Times 2 because bc always on 2 sides in one dim
  pad         = bcWidth * 2
  sizeNoPad   = np.subtract(shape, pad) # Array size without padding
  # 0s on border of ones array
  onesWithPad = np.pad(np.ones(sizeNoPad), bcWidth)
  # Extra dim at beginning to match batchsize and at end to match channels
  onesExpand  = np.expand_dims(onesWithPad, axis=0)
  onesExpand  = np.expand_dims(onesExpand, axis=-1)
  # Repeat array 'nSample' times in 1st array dim
  bc          = np.repeat(onesExpand, n, axis=0)
  # Fill boundary with Dirichlet bc values (data from Laplace solve, ie f=0)
  if eqId == 1 or eqId == 2: # Inhomogeneous Poisson / Laplace
    nx, ny = shape
    bc[:,  0,  :, 0] = p[:,  0,  :, 0] # i- boundary
    bc[:,  :, -1, 0] = p[:,  :, -1, 0] # j+ boundary
    bc[:, -1,  :, 0] = p[:, -1,  :, 0] # i+ boundary
    bc[:,  :,  0, 0] = p[:,  :,  0, 0] # j- boundary
    # Average bc values to counterbalance overlap in corner cells
    if 0:
      bcCnt = np.ones_like(bc)
      corners = [[0,0], [nx-1,0], [0,ny-1], [nx-1,ny-1]]
      for x,y in corners:
        bcCnt[:,x,y,0] += 1
      bc /= bcCnt

  # Combine bc, a, f along channel dim
  bcAF = np.concatenate((bc, a, f), axis=-1)

  return bcAF, p

### Plots and metrics

Lastly, before getting into the actual model setup, we define several utility functions. These will come in handy when evaluating the accuracy of our models. The metric functions will probably look familiar to you. The plotting functions make use of them and will be called once we have trained our first model.

In [None]:
def rmse(targets, predictions):
  return np.sqrt(np.mean((targets-predictions)**2))


def mae(targets, predictions):
  return np.mean(np.abs(targets-predictions))


def mape(targets, predictions):
  return np.mean(np.abs((targets - predictions) / targets)) * 100.0


def getFileName(type, prefix, eqId, archId, bcId):
  return f'{prefix}_{type}_eq-{eqId}_arch-{archId}_bc-{bcId:d}'

In [None]:
def plotSolution(sP, nPred, p, phat, prefix, eqId, archId, bcId):
  sns.set_style("ticks")
  fig = plt.figure(figsize=(10, 2), dpi=300, constrained_layout=True)

  eP = sP + nPred
  errorP = np.abs(p[sP:eP,:,:,0] - phat[0:nPred,:,:,0])

  # Min, max values, needed for colobar range
  minP, maxP       = np.min(p[sP:eP,:,:,0]), np.max(p[sP:eP,:,:,0])
  minPErr, maxPErr = np.min(errorP[0:nPred,:,:]), np.max(errorP[0:nPred,:,:])
  xTicks, yTicks   = [0, p.shape[1]-1], [0, p.shape[2]-1]

  nRows = 1
  nCols = 4
  iInput = sP

  ax = fig.add_subplot(nPred, nCols, 1)
  plt.title('p')
  plt.imshow(p[iInput,:,:,0], vmin=-1, vmax=1, origin='lower', cmap='jet')
  plt.colorbar()
  plt.xticks(xTicks)
  plt.yticks(yTicks)

  ax = fig.add_subplot(nPred, nCols, 2)
  plt.title('phat')
  plt.imshow(phat[0,:,:,0], vmin=-1, vmax=1, origin='lower', cmap='jet')
  plt.colorbar()
  plt.xticks(xTicks)
  plt.yticks(yTicks)

  ax = fig.add_subplot(nPred, nCols, 3)
  plt.title('abs(p-phat)')
  plt.imshow(errorP[0,:,:], vmin=0, vmax=0.1, origin='lower', cmap='jet')
  plt.colorbar()
  plt.xticks(xTicks)
  plt.yticks(yTicks)

  ax = fig.add_subplot(nPred, nCols, 4)
  ax.axis('off')
  plt.title('Stats')
  plt.text(0.2, 0.7, 'MAPE: %.4f %%' % mape(p[sP,:,:,0], phat[0,:,:,0]))
  plt.text(0.2, 0.5, 'RMSE: %.4f' % rmse(p[sP,:,:,0], phat[0,:,:,0]))
  plt.text(0.2, 0.3, 'MAE: %.4f' % mae(p[sP,:,:,0], phat[0,:,:,0]))

  fname = getFileName('solutions', prefix, eqId, archId, bcId)
  plt.savefig(f'{fname}.png', bbox_inches='tight')
  plt.show()

In [None]:
def plotMetrics(model, epochStart, epochEnd, name, eqId, archId, bcId, plotLr=False):
  history = pd.read_csv(model + '.log', sep = ',', engine='python')[epochStart:epochEnd]

  titles  = ['Overall loss', 'MAE', 'Data loss', 'PDE loss']
  metrics = [['loss', 'val_loss'], ['mae', 'val_mae'], \
             ['data', 'val_data'], ['pde', 'val_pde']]
  nCols, nRows = 2, 2

  sns.set_style("darkgrid")
  fig = plt.figure(figsize=(10, 5), dpi=300)
  fig.subplots_adjust(hspace=.5)

  for i, names in enumerate(metrics):
    ax = fig.add_subplot(nRows, nCols, i+1)
    ax.set_yscale('log')
    ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: int(x)))
    for j, n in enumerate(names):
      assert n in names
      sns.lineplot(data=history, x=history['epoch'], y=n, legend='brief', label=n)

    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    ax.title.set_text(titles[i])
    if plotLr:
      sns.lineplot(history['lr'], color='C3')

  fname = getFileName('losses', name, eqId, archId, bcId)
  plt.savefig(f'{fname}.png', bbox_inches='tight')
  plt.show()

In [None]:
def plotMetricsTable(p, phatSoft, phatExact, sTime, eTime):
  sMAPE = mape(p[0,:,:,0], phatSoft[0,:,:,0])
  sRMSE = rmse(p[0,:,:,0], phatSoft[0,:,:,0])
  sMAE  = mae(p[0,:,:,0], phatSoft[0,:,:,0])

  eMAPE = mape(p[0,:,:,0], phatExact[0,:,:,0])
  eRMSE = rmse(p[0,:,:,0], phatExact[0,:,:,0])
  eMAE  = mae(p[0,:,:,0], phatExact[0,:,:,0])

  return f'''
  | Model      |  MAE       | RMSE        | MAPE         | Training time (sec) |
  | :--------: | :--------: | :---------: | :----------: | :-----------------: |
  | Soft BCs   | {sMAE:.2e} | {sRMSE:.2e} | {sMAPE:.2f}% | {sTime:.2f}         |
  | Exact BCs  | {eMAE:.2e} | {eRMSE:.2e} | {eMAPE:.2f}% | {eTime:.2f}         |
  '''

## Generators

We need to have ways to generate points uniformly/randomly and exactly at grid indices. Therefore, we define a continuous and a discrete point generator.

The former will be used to generate points during training. The latter will serve during the prediction step where points coinciding with pixel positions will be needed so that model outputs can easily be compared to the ground truth and plotted as images.

### The continuous point generator

The continuous generator can sample points uniformly and randomly in a given range [`minCoord`, `maxCoord`] - the size of our domain. While point coordinates are random, the location of points can be finetuned by distinguishing between collocation and boundary points: Collocation points can end up anywhere in the domain, boundary points will always be sampled near domain boundaries.

The number of points for each point type can be specified with the `nColPts` and `nBndPts` arguments. A good rule of thumb is to use a 10:1 ratio of collocation and boundary points. This ensures, that

1. there will always be a sufficient amount of points inside the domain (ensures model accuracy), and
2. the density of points along domain boundaries is higher relative to the domain interior (ensures that the model learns the BC found at domain boundaries).

We chose this setup as PINNs tend to converge faster when the BC can be learned accurately. However, to optimize model convergence further, it is best to try out multiple different point ratios.

### Interpolating from ground truth data

You might wonder how we can generate points at random positions and with training data for `a`, `f`, and `p` when the ground truth data is only available on a discrete grid (i.e. the simulated Poisson data we obtain from the `.h5` datasets).

The continuous generator goes around this problem by interpolating points. While this operation carries an error it is still better than only using the data found at grid indices. Especially when domains are smaller (e.g. 32x32 grids), interpolation augments the possible number of training points and increases model accuracy.

When experimenting with the generators I would recommend searching for `RegularGridInterpolator` in the generator below. The interpolation method, for example, could also be set to `cubic` for increased precision.

In [None]:
def generatePtsCont(begin, end, input, label, batchsize=32, loop=True, shuffle=True, \
                    minCoord=0, maxCoord=1, exactBc=False, nColPts=1e4, nBndPts=0, nBcPts=0):
  assert label.shape == input[:,:,:,0:1].shape
  size = (label.shape[1], label.shape[2])

  # Transform grid indices to value range from args
  xx = np.linspace(minCoord, maxCoord, num=size[0])
  yy = np.linspace(minCoord, maxCoord, num=size[1])
  xyBcBatch, pBcBatch = None, None

  # spint.RegularGridInterpolator expects ij indexing
  bc = np.expand_dims(np.swapaxes(input[0,:,:,0], 0, 1), axis=-1)
  a = np.expand_dims(np.swapaxes(input[0,:,:,1], 0, 1), axis=-1)
  f = np.expand_dims(np.swapaxes(input[0,:,:,2], 0, 1), axis=-1)
  p = np.expand_dims(np.swapaxes(label[0,:,:,0], 0, 1), axis=-1)
  # Interpolation functions for discrete grid data
  interpMethod = 'linear' # or use 'nearest' or 'cubic' - slower but more precise
  bcInterp = spint.RegularGridInterpolator((xx, yy), bc, method=interpMethod)
  aInterp = spint.RegularGridInterpolator((xx, yy), a, method=interpMethod)
  fInterp = spint.RegularGridInterpolator((xx, yy), f, method=interpMethod)
  pInterp = spint.RegularGridInterpolator((xx, yy), p, method=interpMethod)

  # Generate random x, y values in range
  eps = 0. # optional space to domain boundary
  x = np.random.uniform(minCoord+eps, maxCoord-eps, size=nColPts)
  y = np.random.uniform(minCoord+eps, maxCoord-eps, size=nColPts)
  xy = np.concatenate((x[:, np.newaxis], y[:, np.newaxis]), axis=-1)

  # Generate additional points along domain boundaries
  if nBndPts > 0:
    xyBnd = _generatePtsBnd(nBndPts, random=True)
    xy = np.concatenate((xy, xyBnd), axis=0) # Add boundary points to global point array

  if shuffle:
    perm = np.random.permutation(xy.shape[0])
    xy = xy[perm]

  if exactBc and nBcPts > 0:
    xyBc = _generatePtsBnd(nBcPts, random=False)
    pBc = bcInterp(xyBc) # Solution at bnd points
    # Repeat bnd xy and p batch size times (every point in batch has a copy of bnd points)
    xyBcBatch = np.expand_dims(xyBc, axis=0)
    xyBcBatch = np.repeat(xyBcBatch, batchsize, axis=0)
    pBcBatch = np.expand_dims(pBc, axis=0)
    pBcBatch = np.repeat(pBcBatch, batchsize, axis=0)

  # Infinite generator loop, abort only after covering all pts and when loop arg is false
  abort = False
  s = begin
  while not abort:
    if loop and s + batchsize > end:
      s = 0
    e = s + batchsize
    if e >= end:
      e = end
      abort = True
    yield [xy[s:e,:], pInterp(xy[s:e,:]), xyBcBatch, pBcBatch, aInterp(xy[s:e,:]), bcInterp(xy[s:e,:]), fInterp(xy[s:e,:])], []
    s += batchsize


def _generatePtsBnd(nPts, random=False, minCoord=0, maxCoord=1, eps=0):
  assert nPts > 0, 'Must supply additional boundary points when using exact boundary condition'

  if random:
    xLow = np.random.uniform(minCoord+eps, maxCoord-eps, size=nPts) # bottom
    xUp  = np.random.uniform(minCoord+eps, maxCoord-eps, size=nPts) # top
    yLow = np.random.uniform(minCoord+eps, maxCoord-eps, size=nPts) # left
    yUp  = np.random.uniform(minCoord+eps, maxCoord-eps, size=nPts) # right
  else:
    offset = 1 / nPts
    xLow = np.linspace(minCoord+eps, maxCoord-eps-offset, num=nPts)
    xUp  = np.linspace(minCoord+eps + offset, maxCoord-eps, num=nPts)
    yLow = np.linspace(minCoord+eps + offset, maxCoord-eps, num=nPts)
    yUp  = np.linspace(minCoord+eps, maxCoord-eps-offset, num=nPts)

  mins  = np.repeat(minCoord+eps, nPts)
  maxs  = np.repeat(maxCoord-eps, nPts)

  # Append bnd points to existing xy
  xy = np.empty((0,2))
  for a, b in list(zip([xLow, xUp, mins, maxs], [mins, maxs, yLow, yUp])):
    tmp = np.concatenate((a[:, np.newaxis], b[:, np.newaxis]), axis=-1)
    xy = np.concatenate((xy, tmp), axis=0)

  return xy

### Exact BC option

The `exactBc` option will only be used when generating points for exact BC models. When using it, the generator will uniformly sample `nBcPts` additional  points along domain boundaries. These points become the BC vector which is repeated for every point in a batch ("the BC is broadcasted to every point in the generator batch"). When the generator yields a batch, the BC vector is part of the output that the model receives.

This way, every training point in an exact BC model knows what the BC looks like. Later on, when defining the exact BC model, we will see how models process the BC (hint: the BC is not part of the model input).


### The discrete point generator

For every grid index in the `input` data, the discrete generator should generate the `xy` coordinates and corresponding `a`, `f`, and `p` Poisson values found at grid indices in a domain with given `minCoord` and `maxCoord`.

The idea is to use this generator during the prediction stage. This way, predictions can be compared to the ground truth and error metrics such as MSE can be computed.

In [None]:
def generatePtsDisc(begin, end, input, label, batchsize=1024, loop=True, shuffle=True, \
                    minCoord=0, maxCoord=1, exactBc=False):
  assert label.shape == input[:,:,:,0:1].shape
  size = (label.shape[1], label.shape[2])

  # Transform grid indices to value range from args
  xx = np.linspace(minCoord, maxCoord, num=size[0])
  yy = np.linspace(minCoord, maxCoord, num=size[1])
  xyBcBatch, pBcBatch = None, None

  bcFlat = input[:,:,:,0].flatten()
  aFlat = input[:,:,:,1].flatten()
  fFlat = input[:,:,:,2].flatten()
  pFlat = label.flatten()
  if shuffle:
    perm = np.random.permutation(len(bcFlat))
    bcFlat = bcFlat[perm]
    aFlat = aFlat[perm]
    fFlat = fFlat[perm]
    pFlat = pFlat[perm]
  bcFlat = np.expand_dims(bcFlat, axis=-1)
  aFlat = np.expand_dims(aFlat, axis=-1)
  fFlat = np.expand_dims(fFlat, axis=-1)
  pFlat = np.expand_dims(pFlat, axis=-1)

  x, y = np.meshgrid(xx, yy, indexing='xy')
  x, y = x.flatten(), y.flatten()
  if shuffle:
    x = x[perm]
    y = y[perm]
  x = np.expand_dims(x, axis=-1)
  y = np.expand_dims(y, axis=-1)
  xy = np.concatenate((x, y), axis=-1)

  if exactBc:
    p = label[0,:,:,0]
    pLst = [p[0,:-1], p[:-1,-1], p[-1,1:], p[1:,0]] # left, top, right, bottom
    xi, yj = np.meshgrid(xx, yy, indexing='xy')
    xLst = [xi[0,:-1], xi[:-1,-1], xi[-1,1:], xi[1:,0]] # left, top, right, bottom
    yLst = [yj[0,:-1], yj[:-1,-1], yj[-1,1:], yj[1:,0]] # left, top, right, bottom

    xyBc = np.empty((0,2))
    pBc = np.empty((0,1))
    for arrX, arrY, arrP in list(zip(xLst, yLst, pLst)):
      arrX, arrY, arrP =  arrX[:, np.newaxis], arrY[:, np.newaxis], arrP[:, np.newaxis]
      tmp = np.concatenate((arrX, arrY), axis=-1)
      xyBc = np.concatenate((xyBc, tmp), axis=0)
      pBc = np.concatenate((pBc, arrP), axis=0)

    # Repeat bnd xy and p batch size times (every point in batch has a copy of bnd points)
    xyBcBatch = np.expand_dims(xyBc, axis=0)
    xyBcBatch = np.repeat(xyBcBatch, batchsize, axis=0)
    pBcBatch = np.expand_dims(pBc, axis=0)
    pBcBatch = np.repeat(pBcBatch, batchsize, axis=0)

  # Infinite generator loop, abort only after covering all pts and when loop arg is false
  abort = False
  s = begin
  while not abort:
    if loop and s + batchsize > end:
      s = 0
    e = s + batchsize
    if e >= end:
      e = end
      abort = True
    yield [xy[s:e,:], pFlat[s:e,:], xyBcBatch, pBcBatch, aFlat[s:e,:], bcFlat[s:e,:], fFlat[s:e,:]], []
    s += batchsize

### Verifying generators

We can write a simple plotting function to visualize the points from the generators.

In [None]:
def plotPoints(nPts, fname, bcAF, p, discrete=False):
  if discrete:
    ptsGen = generatePtsDisc(0, nPts, bcAF, p, batchsize=nPts, loop=False)
    name = 'discrete'
  else:
    ptsGen = generatePtsCont(0, nPts, bcAF, p, batchsize=nPts, nColPts=nPts, loop=False)
    name = 'continuous'

  xy = next(ptsGen)[0][0]

  sns.set_style('darkgrid')
  fig, ax = plt.subplots(1, 1, figsize=(5, 5))

  ax.set_title(f'{nPts} points from {name} point generator')
  ax.set_box_aspect(1)

  sns.scatterplot(x=xy[:,0], y=xy[:,1], alpha=0.3, edgecolor="black", marker="o", ax=ax)

  plt.tight_layout()
  plt.show()
  fig.savefig(fname)

For illustration purposes, we plot 400 points: Once with the continuous point generator ...

In [None]:
gs = 32
bcAF, p  = getBcAFP(n=1, fname=f'psndata_1_{gs}_0.h5', shape=(gs, gs), eqId=2)
plotPoints(400, 'continuous_points.png', bcAF, p, discrete=False)

... and once with the discrete point generator:

In [None]:
plotPoints(400, 'discrete_points.png', bcAF, p, discrete=True)

## PINN with Soft BCs

We finally get to the implementation of the first PINN! That is, the PINN for soft BCs. The model is built on top of `keras.Model` and uses custom `train_step` and `test_step` functions.

### Model setup

Here are some interesting observations about the model:
- The architecture goes in as a list of strings (`operators` argument) which the model converts to `keras` layers and stores them in `self.mlp`
- `call()` uses the layers from `self.mlp` and concatenates them in `keras` functional API fashion.
  - This kind of model setup makes it possible to quickly try out various architectures. All it takes is a new `architecture` string in the training parameters.
- `_compute_losses()` contains the functionality for automatic differentiation.
  - Note that we need to differentiate twice to solve the Poisson equation, and hence, there are two nested `GradientTape`'s.
  - When solving other PDEs, this setup should change accordingly. I.e. if the PDE has order 1, we obviously don't have to differentiate twice.
  - Depending on the distance to the domain boundaries, we compute a PDE or data loss for every point.
  - The data loss is where the model learns the soft BC!


In [None]:
class PsnPinn(tf.keras.Model):
  def __init__(self, operators=[], reg=None, alpha=0.01, save_grad_stat=False, **kwargs):
    super(PsnPinn, self).__init__(**kwargs)
    assert len(operators) > 0
    if reg is not None: assert len(reg) == len(operators)
    self.alpha        = alpha
    self.operators    = operators
    self.reg          = np.zeros(len(operators)) if reg==None else np.array(reg)
    self.saveGradStat = save_grad_stat

    tf.print('PsnPinn with architecture:')

    # Input parameters will be filled with info from input layer (1st in arch string)
    self.inputLayer = None
    self.inputShape = None
    self.batchSize  = None

    self.mlp = []
    for i, op in enumerate(self.operators):

      layerName, layerArgs = self._getLayerName(op)

      ## 1st layer must be an input layer
      if layerName == 'input':
        self.inputShape = (layerArgs[0],)
        self.batchSize = layerArgs[1]
        self.inputLayer = ( KL.InputLayer( input_shape=self.inputShape,
                                           batch_size=self.batchSize) )
        tf.print(f'==> {layerName} layer with input shape {self.inputShape}, batch size {self.batchSize}')
        continue

      assert i > 0, 'Invalid architecture, missing input layer'

      if layerName == 'batchnorm':
        self.mlp.append( KL.BatchNormalization() )

      elif layerName == 'leakyrelu':
        self.mlp.append( KL.LeakyReLU() )

      elif layerName == 'relu':
        self.mlp.append( KL.ReLU() )

      elif layerName == 'concat':
        self.mlp.append( KL.Concatenate() )

      elif layerName == 'dense':
        if layerArgs[1] == 0:
          activation = 'gelu'
        elif layerArgs[1] == 1:
          activation = 'linear'
        else:
          print('Invalid activation in Dense layer')
          sys.exit()
        self.mlp.append( KL.Dense( units=layerArgs[0], activation=activation) )

      tf.print(f'==> {layerName} layer with args {layerArgs[0:]}')

    # Dicts for metrics and statistics
    self.trainMetrics = {}
    self.validMetrics = {}
    # Construct metric names and add to train/valid dicts
    names = ['loss', 'data', 'pde']
    for key in names:
      self.trainMetrics[key] = KM.Mean(name='train_'+key)
      self.validMetrics[key] = KM.Mean(name='valid_'+key)
    self.trainMetrics['mae'] = KM.MeanAbsoluteError(name='train_mae')
    self.validMetrics['mae'] = KM.MeanAbsoluteError(name='valid_mae')

    ## Add metrics for layers' weights, if save_grad_stat is required
    ## i even for weights, odd for bias
    if self.saveGradStat:
      for i, op in enumerate(self.operators):
        if op.trainable:
          names = ['dat_'+repr(i)+'w_avg', 'dat_'+repr(i)+'w_std',\
                   'dat_'+repr(i)+'b_avg', 'dat_'+repr(i)+'b_std',\
                   'pde_'+repr(i)+'w_avg', 'pde_'+repr(i)+'w_std',\
                   'pde_'+repr(i)+'b_avg', 'pde_'+repr(i)+'b_std']
          for name in names:
            self.trainMetrics[name] = KM.Mean(name='train '+name)

    # Dicts to save training and validation statistics
    self.trainStat = {}
    self.validStat = {}


  def _getLayerName(self, operator):
    strs = operator.split('_')
    layerName = strs[0]
    layerArgs = [int(s) for s in strs[1:]] # Convert layer args from string to int
    return layerName, layerArgs


  def _isBc(self, xy, eps=1e-2, minDomain=[0,0], maxDomain=[1,1]):
    xLowerBc = tf.less_equal(    xy[:,0], minDomain[0] + eps )
    xUpperBc = tf.math.greater_equal( xy[:,0], maxDomain[0] - eps )
    yLowerBc = tf.math.less_equal(    xy[:,1], minDomain[0] + eps )
    yUpperBc = tf.math.greater_equal( xy[:,1], maxDomain[0] - eps )

    isBcX = tf.math.logical_or(xLowerBc, xUpperBc)
    isBcY = tf.math.logical_or(yLowerBc, yUpperBc)
    isBc = tf.cast(tf.math.logical_or(isBcX, isBcY), tf.float32)
    return isBc


  def call(self, inputs, training=False, withInputLayer=1):
    xy = inputs[0]
    tensors = [self.inputLayer(xy)] if withInputLayer else [xy]

    for i, (curLayer, op) in enumerate(zip(self.mlp, self.operators[0:])):
      layerName, layerArgs = self._getLayerName(op)
      if layerName == 'concat':
        p1, p2 = layerArgs[0], layerArgs[1] # indices in operator list of tensors to concat
        resultTensor = curLayer([tensors[p1], tensors[p2]])
      else:
        curTensor = tensors[-1]
        if layerName == 'batchnorm':
          resultTensor = curLayer(curTensor, training=training)
        else:
          resultTensor = curLayer(curTensor)
      tensors.append(resultTensor)

    return resultTensor


  def _compute_losses(self, xy, a, f, p, xyBc, pBc):
    with tf.GradientTape(watch_accessed_variables=False, persistent=True) as tape2:
      tape2.watch(xy)
      with tf.GradientTape(watch_accessed_variables=False, persistent=True) as tape1:
        tape1.watch(xy)
        pPred = self([xy, p, xyBc, pBc], training=True)
      # 1st order derivatives
      p_grad = tape1.gradient(pPred, xy)
      p_x, p_y = p_grad[:,0], p_grad[:,1]
      del tape1
    # 2nd order derivatives
    p_xx = tape2.gradient(p_x, xy)[:,0]
    p_yy = tape2.gradient(p_y, xy)[:,1]
    del tape2

    # Convert to row vector
    a = tf.transpose(a)[0]
    f = tf.transpose(f)[0]

    # Find points close to domain boundary
    isBc  = self._isBc(xy, eps=0.2)
    isCol = tf.abs(1.0 - isBc)

    # Compute PDE loss
    pde = a * (p_xx + p_yy)
    pdeLoss = KLS.mean_squared_error(f*isCol, pde*isCol)

    # Convert to row vector
    p = tf.transpose(p)[0]
    pPred = tf.transpose(pPred)[0]
    # Compute data loss
    dataLoss = KLS.mean_squared_error(p*isBc, pPred*isBc)

    return dataLoss, pdeLoss, p, pPred


  def train_step(self, data):
    xy   = data[0][0]
    p    = data[0][1]
    xyBc = data[0][2]
    pBc  = data[0][3]
    a    = data[0][4]
    bc   = data[0][5]
    f    = data[0][6]

    with tf.GradientTape() as tape:
      dataLoss, pdeLoss, p, pPred = self._compute_losses(xy, a, f, p, xyBc, pBc)
      loss = dataLoss + self.alpha * pdeLoss

    # Compute gradients
    gradients = tape.gradient(loss, self.trainable_variables)
    # Update weights
    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

    # Update metrics
    self.trainMetrics['loss'].update_state(loss)
    self.trainMetrics['data'].update_state(dataLoss)
    self.trainMetrics['pde'].update_state(pdeLoss)
    self.trainMetrics['mae'].update_state(p, pPred)

    # Return metrics in statistics dictionary
    for key in self.trainMetrics:
      self.trainStat[key] = self.trainMetrics[key].result()
    return self.trainStat


  def test_step(self, data):
    xy   = data[0][0]
    p    = data[0][1]
    xyBc = data[0][2]
    pBc  = data[0][3]
    a    = data[0][4]
    bc   = data[0][5]
    f    = data[0][6]

    dataLoss, pdeLoss, p, pPred = self._compute_losses(xy, a, f, p, xyBc, pBc)
    loss = dataLoss + self.alpha * pdeLoss

    # Update metrics
    self.validMetrics['loss'].update_state(loss)
    self.validMetrics['data'].update_state(dataLoss)
    self.validMetrics['pde'].update_state(pdeLoss)
    self.validMetrics['mae'].update_state(p, pPred)

    # Return metrics in statistics dictionary
    for key in self.trainMetrics:
      self.validStat[key] = self.validMetrics[key].result()
    return self.validStat


  def reset_metrics(self):
    for key in self.trainMetrics:
      self.trainMetrics[key].reset_states()
    for key in self.validMetrics:
      self.validMetrics[key].reset_states()


  @property
  def metrics(self):
    return [self.trainMetrics[key] for key in self.trainMetrics] \
         + [self.validMetrics[key] for key in self.validMetrics]

### Parameters

The following parameters will be used in both the soft and the exact BC models. We define them here once.

Here are some ideas to try when playing with the model:
- Define a new architecture by adding its string representation and `archId`.
- Increase `alpha`, the PDE loss contribution.
- Load a different dataset by adjusting `datasetId`.
- Train a different variation of the Poisson equation, e.g. `eqId=0` for zero-BC Poisson.
- Train for more epochs and adjust the learning rate along the way.

In [None]:
# Variables (feel free to adjust these!)
prefix      = 'psnNet'
archId      = 1
eqId        = 2
batchsize   = 64
reg         = None
patience    = 200
initEpoch   = 0
lr0         = 5e-4       # Initial learning rate
lrmin       = 1e-7       # Minimum learning rate (used in callbacks)
lrRestart   = 5e-5       # Learning rate to use when reloading the mode
alpha       = 0.1        # Hyperparameter for PDE loss contribution. Based on int(nBndPts / nColPts)
shape       = (128, 128) # Grid size from dataset
datasetId   = 2          # Choose one of datasets 0, 1, 2, 3
fname       = f'psndata_1_{shape[0]}_{datasetId}.h5'
initEpoch   = 0
if use_google_drive:
  fname     =  '/content/drive/MyDrive/Colab Notebooks/' + fname
if archId == 1: # Fully-connected MLP
  architecture = ['input_{}_{}',
                  'dense_128_0',
                  'dense_128_0',
                  'dense_128_0',
                  'dense_128_0',
                  'dense_1_1']
elif archId == 2:        # Fully-connected (cone-shape) MLP
  architecture = ['input_{}_{}',
                  'dense_128_0',
                  'dense_64_0',
                  'dense_32_0',
                  'dense_16_0',
                  'dense_1_1']

# Constants (do not change these!)
nGridPts    = shape[0] * shape[1]
nSample     = 1
nPinnInputs = 2          # Model inputs are x and y -> nPinnInputs = 2
eagerExec   = False      # Only enable when debugging
# Insert values variables in input layer
architecture[0] = architecture[0].format(nPinnInputs, batchsize)
reduceLrCB = KC.ReduceLROnPlateau(monitor='loss', min_delta=0.01, patience=patience, min_lr=lrmin)

assert nSample == 1, 'Can only use 1 sample with PINN architecture for now'
assert eqId in [0, 1, 2], 'Invalid equation ID'
assert archId in [1, 2], 'Invalid architecture ID'

The next code block covers parameters needed only in the soft BC model. The number of points is set to a 10:1 ratio (collocation to boundary). To see how different ratios affect model accuracy, it is a good idea to try different variations here. For example, setting `nBndPts=10` should yield less accurate results than when training with `nBndPts=1000`. Give it a try!

In [None]:
exactBc    = False
nColPts    = 10000
nBndPts    = 4 * 250
nBcPts     = 0
nEpochSoft = 80
nameSoft   = getFileName('model', prefix, eqId, archId, exactBc)

# Traing and validation split
nAllPts    = nColPts + nBndPts
nTrainSoft = int(nAllPts * 0.9)
nValidSoft = nAllPts - nTrainSoft
print(f'{nTrainSoft} in training, {nValidSoft} in validation')

# Create soft BC model
psnNetSoft = PsnPinn(operators=architecture, reg=reg, alpha=alpha)
psnNetSoft.compile(optimizer=KO.Adam(learning_rate=lr0), run_eagerly=eagerExec)

# Callbacks
checkpointCB = KC.ModelCheckpoint(filepath='./' + nameSoft + '/checkpoint', monitor='val_loss', save_best_only=True, save_weights_only=True, verbose=1)
csvLogCB     = tf.keras.callbacks.CSVLogger(nameSoft + '.log', append=True)
psnCBs       = [checkpointCB, reduceLrCB, csvLogCB]

bcAF, p  = getBcAFP(nSample, fname, shape, eqId)
trainGen = generatePtsCont(0, nTrainSoft, bcAF, p, batchsize=batchsize, nColPts=nColPts, nBndPts=nBndPts, nBcPts=nBcPts, exactBc=exactBc)
validGen = generatePtsCont(nTrainSoft, nValidSoft, bcAF, p, batchsize=batchsize, nColPts=nColPts, nBndPts=nBndPts, nBcPts=nBcPts, exactBc=exactBc)

### "make clean"

In case you want to remove a soft BC model and its logs, run the box from below.

In [None]:
#%%script false --no-raise-error
!rm -rf $nameSoft*

### Training

For illustration purposes, we will just train for a couple of epochs. At this point, the model will not be at its highest accuracy. However, the general idea and trends should be visible.

In [None]:
startTrain = time.perf_counter()

psnNetSoft.fit(
  trainGen,
  batch_size=batchsize,
  initial_epoch=initEpoch,
  epochs=nEpochSoft,
  steps_per_epoch=nTrainSoft//batchsize,
  callbacks=psnCBs,
  validation_data=validGen,
  validation_steps=nValidSoft//batchsize,
  verbose=True)

endTrain = time.perf_counter()
softBcTime = endTrain - startTrain
print(f'fit() execution time in secs: {softBcTime}')

### More training

Optionally, we can train for a couple more epochs (and with a smaller learning rate). While not necessarily required here, we also restore the model weights which have been stored as checkpoints on disk.


In [None]:
trainMoreSoft = True
if trainMoreSoft:
  assert os.path.exists(nameSoft)
  psnNetSoft.load_weights(tf.train.latest_checkpoint(nameSoft))
  KB.set_value(psnNetSoft.optimizer.learning_rate, lrRestart)

  prevEpoch = nEpochSoft
  nEpochSoft += 20 # train for a couple more epochs
  startTrain = time.perf_counter()
  psnNetSoft.fit(trainGen, batch_size=batchsize, initial_epoch=prevEpoch, \
                 epochs=nEpochSoft, steps_per_epoch=nTrainSoft//batchsize, callbacks=psnCBs, \
                 validation_data=validGen, validation_steps=nValidSoft//batchsize, verbose=True)
  softBcTime += (time.perf_counter() - startTrain)

### Predictions

The model for soft BCs has been trained. Now, it is time to use the discrete point generator. Note that predictions are returned in a single-column vector. To plot them, they need to be reshaped to the shape of the domain.

In [None]:
predGen = generatePtsDisc(0, nGridPts, bcAF, p, batchsize=1024, shuffle=False, loop=False, exactBc=False)
assert os.path.exists(nameSoft), f'Model {nameSoft} does not exist'

phat = psnNetSoft.predict(predGen)
phat = phat.reshape((shape[0], shape[1]))
phat = phat[np.newaxis, :, :, np.newaxis]

startPred, nPred = 0, 1
plotSolution(startPred, nPred, p, phat, prefix, eqId, archId, exactBc)

### Loss

Let's find out how well our model learned BCs and the Poisson PDE by plotting the losses:

In [None]:
plotMetrics(nameSoft, 0, nEpochSoft, prefix, eqId, archId, False)

## PINN with Exact BCs

Moving on to the exact BC model. It's an extension of the soft BC model with just a couple of extra functions to process the BC.

To get hold of the BC itself, we make use of the same generators defined earlier. I.e. we let the generator objects know, that this time we also need a copy of the BC (option `exactBC=True`).

### Model setup

The exact BC model is best understood through the code below. A few things to note though:
- `call()` uses its parent's feed-forward implementation.
  - At the end of this call, the exact BC is "injected" into the result tensor.
- `_getG()` implements Inverse-Distance-Weighting (IDW) interpolation
  - For every point in a batch, it computes the interpolated solution value based on the distance to the BC and values at those BC points:
  \begin{align}
    G(\boldsymbol{x}) &=
    \sum_{i}^{N_{bc}}
      \frac
      {
        w_{i}z_{i}
      }
      {
        \sum_{i}^{N_{bc}}
        {
          w_{i}
        }
     }\\
    w_i &= \lvert\boldsymbol{x}-\boldsymbol{x_i}\rvert
  \end{align}
- `_getPhi()` returns a filter function
  - It is used to reduce the network's contribution near domain boundaries where true values are enforced.
  \begin{align}
    \phi(x,y) = x \cdot (1-x) \cdot y \cdot (1-y)
  \end{align}
  - When training domains that are non-square or only use BCs on some of the domain sides, this function should be adjusted. Otherwise, network effects could filtered in areas where this is not desirable.

In [None]:
class PsnPinnExactBc(PsnPinn):
  def __init__(self, **kwargs):
    super(PsnPinnExactBc, self).__init__(**kwargs)


  def call(self, inputs, training=False, withInputLayer=1):
    resultTensor = super().call(inputs, training, withInputLayer)
    xy   = inputs[0]
    p    = inputs[1]
    xyBc = inputs[2]
    pBc  = inputs[3]
    phi = self._getPhi(xy)
    g   = self._getG(xy, p, xyBc, pBc)
    return g + resultTensor * phi


  def _getG(self, xy, p, xyBc, pBc, eps=1e-10, exp=2):
    xyExp = tf.expand_dims(xy, axis=1)       # [nBatch, 1, nDim]
    dist = tf.square(xyExp - xyBc)           # [nBatch, nXyBc, nDim]
    dist = tf.reduce_sum(dist, axis=-1)      # [nBatch, nXyBc]
    dist = tf.math.sqrt(dist)                # [nBatch, nXyBc]
    wi = tf.pow(1.0 / (dist + eps), exp)     # [nBatch, nXyBc]
    wi = tf.expand_dims(wi, axis=-1)         # [nBatch, nXyBc, 1]
    denom = tf.reduce_sum(wi, axis=1)        # [nBatch, 1, nDim]
    numer = tf.reduce_sum(wi * pBc, axis=1)  # [nBatch, 1, dimBc]
    g = numer / denom                        # [nBatch, dimBc]
    return g


  def _getPhi(self, xy):
    x, y  = xy[:,0], xy[:,1]
    xMin, yMin, xMax, yMax = 0, 0, 1, 1 # Same bounds as used in domain setup
    phi = x * (1-x) * y * (1-y)
    phi = tf.expand_dims(phi, axis=-1)
    return phi

### Parameters

The exact BC models use slightly different numbers of points. For instance,
- no additional points along the domain boundary will be needed (`nBndPts=0`). This has obviously to do with the fact that the model will always return the exact BC at domain boundaries. BCs are hardcoded into `call()`.
- only a couple of BC points will be needed for the BC vector (`nBcPts=20` for a 128x128 domain is sufficient). If the overall size of the domain increases, this number should increase too.

In [None]:
exactBc     = True
nColPts     = 10000
nBndPts     = 0
nBcPts      = 20
nEpochExact = 80
nameExact   = getFileName('model', prefix, eqId, archId, exactBc)

# Traing and validation split
nAllPts     = nColPts + nBndPts*4 # 4 because 4 domain sides
nTrainExact = int(nAllPts * 0.9)
nValidExact = nAllPts - nTrainExact
print(f'{nTrainExact} in training, {nValidExact} in validation')

# Create exact BC model
psnNetExact = PsnPinnExactBc(operators=architecture, reg=reg, alpha=alpha)
psnNetExact.compile(optimizer=KO.Adam(learning_rate=lr0), run_eagerly=eagerExec)

# Callbacks
checkpointCB = KC.ModelCheckpoint(filepath='./' + nameExact + '/checkpoint', monitor='val_loss', save_best_only=True, save_weights_only=True, verbose=1)
csvLogCB     = KC.CSVLogger(nameExact + '.log', append=True)
psnCBs       = [checkpointCB, reduceLrCB, csvLogCB]

bcAF, p  = getBcAFP(nSample, fname, shape, eqId)
trainGen = generatePtsCont(0, nTrainExact, bcAF, p, batchsize=batchsize, nColPts=nColPts, nBndPts=nBndPts, nBcPts=nBcPts, exactBc=exactBc)
validGen = generatePtsCont(nTrainExact, nValidExact, bcAF, p, batchsize=batchsize, nColPts=nColPts, nBndPts=nBndPts, nBcPts=nBcPts, exactBc=exactBc)

### "make clean"

You can remove/delete the exact BC model by executing the box below.

In [None]:
#%%script false --no-raise-error
!rm -rf $nameExact*

### Training

Similarly to the soft BC model, we will train for a couple of epochs only.

In [None]:
startTrain = time.perf_counter()

psnNetExact.fit(
    trainGen,
    batch_size=batchsize,
    initial_epoch=initEpoch,
    epochs=nEpochExact,
    steps_per_epoch=nTrainExact//batchsize,
    callbacks=psnCBs,
    validation_data=validGen,
    validation_steps=nValidExact//batchsize,
    verbose=True)

endTrain = time.perf_counter()
exactBcTime = endTrain - startTrain
print(f'fit() execution time in secs: {exactBcTime}')

### More training

Optionally, reduce the learning rate, load the model, and train for more epochs!

In [None]:
trainMoreExact = True
if trainMoreExact:
  assert os.path.exists(nameExact)
  psnNetExact.load_weights(tf.train.latest_checkpoint(nameExact))
  KB.set_value(psnNetExact.optimizer.learning_rate, lrRestart)

  prevEpoch = nEpochExact
  nEpochExact += 20 # train for a couple more epochs
  startTrain = time.perf_counter()
  psnNetExact.fit(trainGen, batch_size=batchsize, initial_epoch=prevEpoch, \
                  epochs=nEpochExact, steps_per_epoch=nTrainExact//batchsize, callbacks=psnCBs, \
                  validation_data=validGen, validation_steps=nValidExact//batchsize, verbose=True)
  exactBcTime += (time.perf_counter() - startTrain)

### Predictions

When making predictions with the exact BC model, it is important to remember that the generator needs to supply the BC vector. The vector has to be supplied during training **and** inference. Otherwise, there would be a mismatch for the scale of solution `phat`: All weights have been found for an interpolated and filtered version of `phat`, and thus, the same interpolation and filtering steps need to be followed during inference too.

In [None]:
predGen  = generatePtsDisc(0, nGridPts, bcAF, p, batchsize=1024, shuffle=False, loop=False, exactBc=True)
assert os.path.exists(nameExact), f'Model {nameExact} does not exist'

phatExact = psnNetExact.predict(predGen)
phatExact = phatExact.reshape((shape[0], shape[1]))
phatExact = phatExact[np.newaxis, :, :, np.newaxis]

startPred, nPred = 0, 1
plotSolution(startPred, nPred, p, phatExact, prefix, eqId, archId, bcId=True)

Note how the exact BC model yields a slightly smaller error. The accuracy was improved just by "showing" training data in a different manner.

In [None]:
plotMetrics(nameExact, 0, nEpochExact, prefix, eqId, archId, True)

## Comparing Soft and Exact BCs

Depending on the dataset and the total number of epochs used during training, exact BC models generally outperform soft BC models in terms of accuracy. Given that exact BC models require fewer points to learn BCs, the time overhead caused by BC enforcement can be amortized. In some cases, exact BC models can even train faster than soft BC models (although again, this is highly dependent on the number of points in the BC vector).

In [None]:
table = plotMetricsTable(p=p, phatSoft=phat, phatExact=phatExact, sTime=softBcTime, eTime=exactBcTime)
display_markdown(table, raw=True)

## Conclusions

While exact BC models have a slight edge over soft BC models, the optimal BC imposition approach will depend on the problem itself (i.e. complexity of PDE, positions where BC is enforced, number of collocation points needed for target model accuracy, etc.).

These examples should give you some inspiration for your future PINN implementations. Should you decide to implement one of the strategies yourself, I'd be happy to hear about your thoughts and findings. Feel free to reach out [@sebbas](https://twitter.com/sebbas)!