<a href="https://colab.research.google.com/github/sjiang23/senbaojiang.github.io/blob/main/Heat3D_Solver.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 3D AB Solver with $\alpha = 1$

In [None]:
import jax.numpy as jnp
import jax.scipy as jsp
from jax import grad, jit, vmap, jacobian, hessian, jacrev, pmap, lax
from jax import random
from jax.scipy import optimize
from jax.example_libraries import optimizers
import matplotlib.pyplot as plt
from matplotlib import cm
from jax.flatten_util import ravel_pytree
from functools import partial
import numpy as np
import scipy as sp

import seaborn as sns
from tqdm import tqdm
import time
import itertools

import plotly.graph_objects as go
import datetime

In [None]:
pip install -U kaleido plotly==5.5.0

Collecting kaleido
  Downloading kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl (79.9 MB)
[K     |████████████████████████████████| 79.9 MB 1.1 MB/s 
Installing collected packages: kaleido
Successfully installed kaleido-0.2.1


In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
def random_layer_params(m, n, key):
  """Draw Glorot-uniform weights and biases for one dense layer.

  Args:
    m: number of layer inputs (rows of the weight matrix).
    n: number of layer outputs (columns of the weight matrix, length of the bias).
    key: JAX PRNG key; split once so weights and biases use distinct subkeys.

  Returns:
    Tuple ``(W, b)`` with ``W`` of shape ``(m, n)`` and ``b`` of shape ``(n,)``,
    both sampled uniformly between ``-sqrt(6/(m+n))`` and ``sqrt(6/(m+n))``.
  """
  weight_key, bias_key = random.split(key, 2)
  # Glorot bound, shared by the weight matrix and the bias vector.
  limit = jnp.sqrt(6./(n+m))
  weights = random.uniform(weight_key, shape = (m, n), minval = -limit, maxval = limit)
  biases = random.uniform(bias_key, shape = (n,), minval = -limit, maxval = limit)
  return weights, biases

def init_network_params(sizes, key):
  """Initialise every dense layer of a fully connected network.

  Args:
    sizes: layer widths, input first and output last.
    key: JAX PRNG key used to derive one subkey per layer.

  Returns:
    List of ``(W, b)`` tuples, one per consecutive pair of entries in ``sizes``.
  """
  # sizes[0] == d+1 !!!
  # len(sizes) subkeys are generated although only len(sizes) - 1 layers exist;
  # zip() below simply leaves the last subkey unused.
  layer_keys = random.split(key, len(sizes))
  layers = []
  for fan_in, fan_out, layer_key in zip(sizes[:-1], sizes[1:], layer_keys):
    layers.append(random_layer_params(fan_in, fan_out, layer_key))
  return layers


In [None]:
def unravel(U,d):
  """Reshape a flat array of ``N**d`` values into a ``d``-dimensional cube of side ``N``.

  Entries are laid out in row-major (C) order, i.e. ``result[i]`` holds the
  ``i``-th consecutive chunk of ``N**(d-1)`` values of ``U``.

  Args:
    U: one-dimensional array-like of length ``N**d``.
    d: number of dimensions of the result.

  Returns:
    ``U`` itself when ``d == 1``; otherwise an array of shape ``(N,) * d``
    with the dtype of ``U``.

  Raises:
    ValueError: if ``len(U)`` is not a perfect ``d``-th power (previously the
      trailing entries were dropped silently).
  """
  if d == 1:
    return U
  total = len(U)
  N = int(round(total**(1./d)))
  if N**d != total:
    raise ValueError("len(U) = {} is not a perfect {}-th power".format(total, d))
  # A single reshape replaces the recursive slice-and-set fill, which copied
  # the whole output array once per slab at every recursion level.
  return jnp.reshape(jnp.asarray(U), (N,) * d)

In [None]:
def sigmoid(x):
  """Logistic sigmoid ``1 / (1 + exp(-x))``, evaluated in a numerically stable way.

  The naive formula overflows ``exp(-x)`` for large negative ``x``; the value
  still comes out as 0, but its automatic derivative becomes ``inf * 0 = nan``.
  ``expit`` returns the same values with finite gradients everywhere.

  Args:
    x: scalar or array.

  Returns:
    Array of the same shape as ``x`` with values in ``[0, 1]``.
  """
  return jsp.special.expit(x)

def init_cond(x1, x2, x3):
  """Initial condition: a polynomial bump on the cube ``[-1, 1]^3``.

  Evaluates ``(693/512)^3 * ((1 - x1^2)(1 - x2^2)(1 - x3^2))^5``, which is zero
  whenever any coordinate equals ``+1`` or ``-1``.

  Args:
    x1, x2, x3: coordinates (scalars or arrays that broadcast together).

  Returns:
    The bump value at the given point(s).
  """
  bump = (1 - x1**2) * (1 - x2**2) * (1 - x3**2)
  return (693/512)**3 * bump**5

In [None]:
@jit
def uNN(params, t, x1, x2, x3):
  """Neural-network ansatz with the initial and boundary conditions built in.

  The raw network output is multiplied by
  ``(1 - exp(-t)) * sqrt((1 - x1^2)(1 - x2^2)(1 - x3^2))`` and added to
  ``init_cond(x1, x2, x3)``. The multiplier is zero at ``t = 0`` and whenever a
  coordinate is ``+-1``, so the result equals ``init_cond`` there.

  Args:
    params: list of ``(W, b)`` pairs; rows 0..3 of the first ``W`` are the
      weights for ``t``, ``x1``, ``x2``, ``x3`` respectively.
    t: time (scalar or 1-D array).
    x1, x2, x3: spatial coordinates (scalars or 1-D arrays).

  Returns:
    The ansatz value with singleton dimensions squeezed out.
  """
  # Input layer, written as a sum of outer products so that scalar and vector
  # inputs are both accepted.
  W_in, b_in = params[0]
  pre_act = jnp.outer(t, W_in[0]) + jnp.outer(x1, W_in[1]) + jnp.outer(x2, W_in[2]) \
  + jnp.outer(x3, W_in[3]) + b_in
  hidden = sigmoid(pre_act)
  last_bias = b_in
  # Remaining layers; the pre-activation of the final one is the network output.
  for W_k, b_k in params[1:]:
    pre_act = jnp.dot(hidden, W_k) + b_k
    hidden = sigmoid(pre_act)
    last_bias = b_k
  # Subtracting the final bias again makes the output layer bias-free.
  raw = jnp.ravel(pre_act - last_bias)
  # Impose boundary/initial condition.
  envelope = (1 - jnp.exp(-t)) * jnp.sqrt((1 - x1**2) * (1 - x2**2) * (1 - x3**2))
  return jnp.squeeze(envelope * raw + init_cond(x1, x2, x3))


In [None]:
def levy_const(alpha):
    """Return ``2^alpha * Gamma((3 + alpha)/2) / (pi^(3/2) * |Gamma(-alpha/2)|)``.

    Args:
        alpha: the exponent parameter; ``-alpha/2`` must not be a pole of Gamma.

    Returns:
        The constant as a scalar; for ``alpha = 1`` it equals ``1/pi^2``.
    """
    gamma = sp.special.gamma
    numerator = 2**alpha * gamma((3+alpha)/2)
    denominator = jnp.sqrt(jnp.pi)**3 * jnp.abs(gamma(-alpha/2))
    return numerator / denominator

In [None]:
@jit
def discrete_part(params, t, x1, x2, x3):
  """Trapezoidal quadrature of ``(u(x) - u(y)) / |x - y|^4`` over the grid ``Y1, Y2, Y3``.

  Reads the module-level grid ``Y1, Y2, Y3``, spacing ``h`` and constant ``omega0``.
  The exponent 4 is hard-coded; the global ``alpha`` is not read here.

  Args:
    params: network parameters passed to ``uNN``.
    t, x1, x2, x3: a single evaluation point. They must be scalars because
      ``grad`` is applied to ``uNN`` at this point.

  Returns:
    The quadrature sum minus ``omega0/2 * h`` times the spatial Laplacian of
    ``uNN`` at the point.
  """
  # |y - x|^4 on the whole grid.
  dist4 = ((Y1 - x1)**2 + (Y2 - x2)**2 + (Y3 - x3)**2)**2
  # Doubling the denominator on each of the six faces halves the weight there
  # (edges end up with 1/4 and corners with 1/8).
  for axis in range(3):
    for face in (0, -1):
      face_index = [slice(None)] * 3
      face_index[axis] = face
      dist4 = dist4.at[tuple(face_index)].multiply(2)
  # A grid node that coincides with x gets an infinite denominator, so it
  # contributes exactly zero to the sum.
  inv_kernel = jnp.ravel(1./jnp.where(dist4 == 0, jnp.inf, dist4))

  # Trapezoidal rule summation with boundary correction
  u_here = uNN(params, t, x1, x2, x3)
  u_grid = uNN(params, t, jnp.ravel(Y1), jnp.ravel(Y2), jnp.ravel(Y3))
  trapz = h**3 * jnp.sum( ( u_here - u_grid ) * inv_kernel )

  # singularity correction: second derivatives w.r.t. x1, x2, x3 (argnums 2, 3, 4)
  second_derivs = [grad(grad(uNN, k), k)(params, t, x1, x2, x3) for k in (2, 3, 4)]
  laplacian = second_derivs[0] + second_derivs[1] + second_derivs[2]

  return trapz - omega0/2 * h * laplacian

In [None]:
def _edge_term(p, a, b):
  """One summand ``(arctan((1-p)/r) + arctan((1+p)/r)) * r / (a*b)`` with ``r = sqrt(a^2 + b^2)``."""
  r = jnp.sqrt(a**2 + b**2)
  return (jnp.arctan((1 - p)/r) + jnp.arctan((1 + p)/r)) * r/(a * b)

@jit
def analytical_part(x1, x2, x3):
  """Closed-form coefficient that multiplies ``uNN`` in the residual.

  The result is half the sum of twelve terms. For each coordinate ``p`` in turn,
  and for each of the four sign choices ``a = 1 -+ q``, ``b = 1 -+ s`` of the
  other two coordinates ``q`` and ``s``, it adds

      (arctan((1 - p)/r) + arctan((1 + p)/r)) * r / (a * b),   r = sqrt(a^2 + b^2).

  The sum is symmetric under permutations and sign flips of the coordinates.
  The earlier hand-expanded version had a typo in its last term (``x1`` in
  place of ``x3`` under one square root), which broke that symmetry.

  NOTE(review): presumably this is the exact integral of the kernel over the
  exterior of the cube ``[-1, 1]^3`` - confirm against the derivation.

  Args:
    x1, x2, x3: coordinates strictly inside ``(-1, 1)``; the terms divide by
      ``1 -+ xi`` and are singular on the boundary.

  Returns:
    The coefficient, broadcast over the input shapes.
  """
  total = 0.
  # (p, q, s): coordinate inside the arctans and the two that form r.
  for p, q, s in ((x3, x1, x2), (x2, x1, x3), (x1, x3, x2)):
    for a in (1 - q, 1 + q):
      for b in (1 - s, 1 + s):
        total = total + _edge_term(p, a, b)
  return total/2

In [None]:
@jit
def wMSE(params, batch):
  """Mean squared residual over a batch of collocation points.

  The residual at each point is
  ``du/dt + levy_constant * (discrete_part + uNN * analytical_part)``.
  Despite the name, the active return value is an unweighted mean; a weighted
  variant is kept below as a comment.

  Args:
    params: network parameters.
    batch: tuple ``(t, x1, x2, x3)`` of equally long 1-D arrays.

  Returns:
    Scalar mean of the squared residuals.
  """
  t, x1, x2, x3 = batch
  time_derivative = batched_udt(params, t, x1, x2, x3)
  nonlocal_term = batched_discrete_part(params, t, x1, x2, x3) \
    + uNN(params, t, x1, x2, x3) * analytical_part(x1, x2, x3)
  residual = time_derivative + levy_constant * nonlocal_term
  return jnp.mean(residual**2)
  # Weighted alternative (disabled):
  # return jnp.average( ( batched_udt(params, t, x1, x2, x3) + \
  # levy_constant * ( batched_discrete_part(params, t, x1, x2, x3) + uNN(params, t, x1, x2, x3) * analytical_part(x1, x2, x3) ) )**2,
  # weights = 1/(3 - x1**2 - x2**2 - x3**2) )

# Hyperparameters

In [None]:
# PRNG keys: one for parameter initialisation, four for sampling the training
# batch (t, x1, x2, x3) and one for the test batches.
key_init = random.PRNGKey(12345)
key0 = random.PRNGKey(2)
key1 = random.PRNGKey(3)
key2 = random.PRNGKey(5)
key3 = random.PRNGKey(7)
key_test = random.PRNGKey(11)

# Network widths: 4 inputs (t, x1, x2, x3), six hidden layers of 20 units, 1 output.
sizes = [4] + [20] * 6 + [1] # d = 3+1 = 4  !!!
params_len = len(sizes) - 1

alpha = 1
# Coefficient of the singularity correction in discrete_part.
# NOTE(review): the origin of this value is not shown in the notebook.
omega0 = 2.97121097252838375756

levy_constant = levy_const(alpha)

# Spatial grid spacing and time step of the sampling grid.
h = 1/20
dt = 0.01
# Training times are drawn below T; test batches are evaluated at T_pred.
T = 0.2
T_pred = 0.3

def step_size(n):
  """Constant learning-rate schedule; the step index ``n`` is ignored."""
  return 10**-3

num_epochs = 3 * 10**5
batch_size = 100

In [None]:
# data = np.load('/content/drive/MyDrive/Colab Notebooks/Fokker-Planck Equations/Absorbing Boundary/Heat 3D/Experiments/params_OU_3D_2022-04-02.npy', allow_pickle = True)
# params = []
# for W, b in data:
#   params.append((jnp.array(W, dtype=jnp.float32), jnp.array(b, dtype=jnp.float32)))
# del data

In [None]:
# Quadrature nodes on [-1, 1] with spacing h (both endpoints included) and the
# corresponding 3-D grid. 'xy' is meshgrid's default indexing, spelled out here.
y = jnp.linspace(-1, 1, num = int(2/h) + 1)
Y1, Y2, Y3 = jnp.meshgrid(y, y, y, indexing = 'xy')

# Data Preparation

In [None]:
def get_batch(batch_size, *keys):
  """Sample a training batch of grid-aligned collocation points.

  Times are drawn from ``jnp.arange(0, T, dt)`` without its first entry (0);
  coordinates are drawn from the interior nodes of the ``h``-spaced grid on
  ``[-1, 1]``. All draws are uniform with replacement.

  Args:
    batch_size: number of points.
    *keys: four PRNG keys, used for ``t``, ``x1``, ``x2``, ``x3`` in that order.

  Returns:
    Tuple ``(t, x1, x2, x3)`` of arrays of shape ``(batch_size,)``.
  """
  interior = jnp.linspace(-1, 1, int(2/h) + 1)[1:-1]
  times = jnp.arange(0, T, dt)[1:]

  def pick(pool, key):
    # Uniform random indices into `pool`, then gather.
    idx = random.randint(key = key, shape = (batch_size,), minval = 0, maxval = pool.shape[0])
    return pool[idx]

  return pick(times, keys[0]), pick(interior, keys[1]), pick(interior, keys[2]), pick(interior, keys[3])

def get_test_batch(batch_size, *keys):
  """Sample a test batch at the fixed time ``T_pred``.

  Coordinates are drawn uniformly with replacement from the interior nodes of
  the ``h``-spaced grid on ``[-1, 1]``.

  Args:
    batch_size: number of points.
    *keys: three PRNG keys, used for ``x1``, ``x2``, ``x3`` in that order.

  Returns:
    Tuple ``(t, x1, x2, x3)`` of arrays of shape ``(batch_size,)``, with every
    entry of ``t`` equal to ``T_pred``.
  """
  interior = jnp.linspace(-1, 1, int(2/h) + 1)[1:-1]
  n_nodes = interior.shape[0]

  def pick(key):
    return interior[random.randint(key = key, shape = (batch_size,), minval = 0, maxval = n_nodes)]

  times = T_pred * jnp.ones((batch_size,))
  return times, pick(keys[0]), pick(keys[1]), pick(keys[2])

# Batched Functions

In [None]:
# Time derivative du/dt of the ansatz, vectorised over the batch. The parameters
# are not mapped (None for every W and b); t, x1, x2, x3 are mapped along axis 0.
batched_udt = vmap(grad(uNN, argnums = 1), in_axes = ([(None, None)] * params_len, 0, 0, 0, 0))

# Same mapping for the per-point quadrature term.
batched_discrete_part = vmap(discrete_part, in_axes = ([(None, None)] * params_len, 0, 0, 0, 0))

# Training

In [None]:
@jit
def update(i, opt_state, batch):
  """Perform one optimizer step on ``wMSE``.

  The whole step (gradient plus optimizer update) is compiled as one function,
  so the training loop does not redo the Python-level gradient transform and
  dispatch every optimizer operation separately on each iteration.

  Args:
    i: step index forwarded to the optimizer's ``update_fun``.
    opt_state: current optimizer state.
    batch: tuple ``(t, x1, x2, x3)`` of collocation points.

  Returns:
    The updated optimizer state.
  """
  params = get_params(opt_state)
  return update_fun(i, grad(wMSE, 0)(params, batch), opt_state)

In [None]:
# Adam driven by the constant schedule `step_size`.
init_fun, update_fun, get_params = optimizers.adam(step_size = step_size)

# Fresh network parameters wrapped in the optimizer state.
init_params = init_network_params(sizes, key_init)
opt_state = init_fun(init_params)
# Running step counter handed to `update`.
itercount = itertools.count()

# Losses recorded by the training loop.
mse_training_history = []
mse_test_history = []


In [None]:
# Main training loop: one optimizer step per iteration on a freshly sampled batch.
print("\n Start training...")

for epoch in tqdm(range(num_epochs)):

  # Advance each of the four sampling streams and keep one subkey per variable.
  key0, subkey0 = random.split(key0)
  key1, subkey1 = random.split(key1)
  key2, subkey2 = random.split(key2)
  key3, subkey3 = random.split(key3)

  batch = get_batch(batch_size, subkey0, subkey1, subkey2, subkey3)
  opt_state = update(next(itercount), opt_state, batch)
  params = get_params(opt_state)

  # Log on the first iteration and then on every 500th.
  if epoch == 0 or (epoch + 1) % 500 == 0:
    # Training loss is evaluated with the already-updated parameters on the batch just used.
    train_acc = wMSE(params, batch)
    # Stop as soon as the loss is NaN; nothing is recorded for this iteration.
    if jnp.isnan(train_acc):
      break

    # generate test batch (all points at t = T_pred); note that `batch` is rebound here
    key_test, *sub_key_tests = random.split(key_test, 4)
    batch = get_test_batch(batch_size, *sub_key_tests)
    test_acc = wMSE(params, batch)

    # record history
    mse_training_history.append(train_acc)
    mse_test_history.append(test_acc)
    print(" MSE train/test {}/{}".format(train_acc, test_acc))
  



# Numerical Verification

In [None]:
# Evaluation grid: spacing dx on [-1, 1]; `yt` includes the boundary nodes and
# `xt` keeps only the interior ones. 'xy' is meshgrid's default indexing.
dx = 1/20
yt = jnp.linspace(-1, 1, num = int(2/dx) + 1)
xt = yt[1:-1]
Yt1, Yt2, Yt3 = jnp.meshgrid(yt, yt, yt, indexing = 'xy')
Xt1, Xt2, Xt3 = jnp.meshgrid(xt, xt, xt, indexing = 'xy')

In [None]:
# Reference solution used for the error measurements below.
# NOTE(review): the path is an absolute Google Drive location, so Drive must be
# mounted before this cell runs. Presumably the file holds the flattened values at
# t = 0.3 on the interior grid - confirm. allow_pickle=True executes pickled
# objects on load; only use it with trusted files.
U = jnp.load('/content/drive/MyDrive/Colab Notebooks/Fokker-Planck Equations/Absorbing Boundary/Heat 3D/Experiments/heat_3D_ref_sol_T_0.3.npy', allow_pickle = True)

In [None]:
# Relative L2 error ||u_NN - U|| / ||U|| at t = T_pred over the flattened interior grid.
# Assumes U is ordered like jnp.ravel of the Xt meshgrid - TODO confirm.
L2_rel_err = jnp.sqrt(jnp.sum((uNN(params, T_pred, jnp.ravel(Xt1), jnp.ravel(Xt2), jnp.ravel(Xt3)) - U)**2)/jnp.sum(U**2))

In [None]:
# Report the relative L2 error computed above.
print(L2_rel_err)

### Maximum Absolute Error

In [None]:
# Reshape the flat reference solution and the flat network prediction at
# t = T_pred into 3-D cubes so that 2-D slices can be plotted.
Umat = unravel(U,3)
Z = unravel(uNN(params, T_pred, jnp.ravel(Xt1), jnp.ravel(Xt2), jnp.ravel(Xt3)),3)

In [None]:
# Slice indices along the first axis: the first slice and the central one.
# For an odd side length N the centre of a 0-based axis is N // 2; the former
# int((N + 1)/2) pointed one node past it.
K_top, K_mid = 0, Umat.shape[0] // 2

In [None]:
# Contour plot of the pointwise absolute error on the slice with index K_mid.
fig_max_err = go.Figure(
    data = go.Contour(z = jnp.abs(Umat[K_mid] - Z[K_mid]), x = xt, y = xt)
).update_layout(autosize = False, width = 500, height = 500)
fig_max_err.update_xaxes(title_text='x').update_yaxes(title_text='y')
fig_max_err.show()

In [None]:
# Contour plot of the pointwise absolute error on the slice with index K_top.
# This rebinds `fig_max_err`, replacing the figure created for the K_mid slice.
fig_max_err = go.Figure(
    data = go.Contour(z = jnp.abs(Umat[K_top] - Z[K_top]), x = xt, y = xt)
).update_layout(autosize = False, width = 500, height = 500)
fig_max_err.update_xaxes(title_text='x').update_yaxes(title_text='y')
fig_max_err.show()

### Maximum Relative Error

In [None]:
# Contour plot of log10 of the pointwise relative error on the slice with index K_mid.
fig_r_err = go.Figure(
    data = go.Contour(z = jnp.log10(jnp.abs((Umat[K_mid] - Z[K_mid])/Umat[K_mid])), x = xt, y = xt)
).update_layout(autosize = False, width = 500, height = 500)
fig_r_err.update_xaxes(title_text='x').update_yaxes(title_text='y')
fig_r_err.show()

In [None]:
# Contour plot of log10 of the pointwise relative error on the slice with index K_top.
# This rebinds `fig_r_err`, replacing the figure created for the K_mid slice.
fig_r_err = go.Figure(
    data = go.Contour(z = jnp.log10(jnp.abs((Umat[K_top] - Z[K_top])/Umat[K_top])), x = xt, y = xt)
).update_layout(autosize = False, width = 500, height = 500)
fig_r_err.update_xaxes(title_text='x').update_yaxes(title_text='y')
fig_r_err.show()

### Training History

In [None]:
# Epoch numbers (1-based) at which the training loop recorded a loss: the first
# epoch and then every 500th. The list is cut to the number of values actually
# recorded, so the axis stays aligned even if training stopped early.
logged_epochs = ([1] + list(range(500, num_epochs + 1, 500)))[:len(mse_training_history)]

fig_training = go.Figure()
fig_training.add_trace(go.Scatter(x = logged_epochs, y = jnp.log10(jnp.array(mse_training_history)), mode='lines', name='train'))
fig_training.add_trace(go.Scatter(x = logged_epochs, y = jnp.log10(jnp.array(mse_test_history)), mode='lines', name='test'))
fig_training.update_layout(title = 'wMSE History', autosize = False, width = 500, height = 500)
fig_training.update_xaxes(title_text='epoch')
fig_training.update_yaxes(title_text='log10(wMSE)')
fig_training.show()

# Output

In [None]:
# fig = go.Figure()
# fig.write_image("/content/drive/MyDrive/Colab Notebooks/Fokker-Planck Equations/Absorbing Boundary/Heat 3D/Experiments/fig_empty_{}.pdf"\
#                         .format(datetime.date.today()), format = 'pdf')

In [None]:
# fig_max_err.write_image("/content/drive/MyDrive/Colab Notebooks/Fokker-Planck Equations/Absorbing Boundary/Heat 3D/Experiments/fig_max_err_T_{1}_{0}.pdf"\
#                         .format(datetime.date.today(), T_pred), format = 'pdf')

In [None]:
# fig_r_err.write_image("/content/drive/MyDrive/Colab Notebooks/Fokker-Planck Equations/Absorbing Boundary/Heat 3D/Experiments/fig_r_err_T_{1}_{0}.pdf"\
#                       .format(datetime.date.today(), T_pred), format = 'pdf')

In [None]:
# fig_training.write_image("/content/drive/MyDrive/Colab Notebooks/Fokker-Planck Equations/Absorbing Boundary/Heat 3D/Experiments/fig_training_{}.pdf"\
#                          .format(datetime.date.today()), format = 'pdf')

In [None]:
# Save the trained parameters as a pickled object array, with today's date in the file name.
# NOTE(review): the target is an absolute Google Drive path, so Drive must be mounted.
jnp.save('/content/drive/MyDrive/Colab Notebooks/Fokker-Planck Equations/Absorbing Boundary/Heat 3D/Experiments/params_heat_3D_date_{}.npy'.format(datetime.date.today()), \
         np.array(params, dtype = object), allow_pickle = True)

In [None]:
# Save the loss histories as a 2-row array (row 0: training, row 1: test), with
# today's date in the file name.
# NOTE(review): the target is an absolute Google Drive path, so Drive must be mounted.
jnp.save('/content/drive/MyDrive/Colab Notebooks/Fokker-Planck Equations/Absorbing Boundary/Heat 3D/Experiments/train_test_history_{0}.npy'\
         .format(datetime.date.today()), jnp.vstack((jnp.array(mse_training_history),jnp.array(mse_test_history))), allow_pickle = True)