<a href="https://colab.research.google.com/github/sjiang23/senbaojiang.github.io/blob/main/Heat2D_Solver_v3_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Differences from v3

- Training loss: weighted MSE (`wMSE`) of the PDE residual.

In [None]:
import jax.numpy as jnp
import jax.scipy as jsp
from jax import grad, jit, vmap, jacobian, hessian, jacrev, pmap, lax
from jax import random
from jax.scipy import optimize
from jax.example_libraries import optimizers
import matplotlib.pyplot as plt
from matplotlib import cm
from jax.flatten_util import ravel_pytree
from functools import partial
import numpy as np
import scipy as sp

import seaborn as sns
from tqdm import tqdm
import time
import itertools

import plotly.graph_objects as go
import datetime

In [None]:
# pip install -U kaleido plotly==5.5.0

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

# Parameters, functions

In [None]:
def random_layer_params(m, n, key):
  # Glorot (Xavier) uniform initialization of one dense layer with m inputs and n outputs:
  # W ~ U(-r, r) of shape (m, n) and (here) also b ~ U(-r, r) of shape (n,), r = sqrt(6 / (m + n)).
  weight_key, bias_key = random.split(key, 2)
  limit = jnp.sqrt(6./(n+m))
  weights = random.uniform(weight_key, shape=(m, n), minval=-limit, maxval=limit)
  bias = random.uniform(bias_key, shape=(n,), minval=-limit, maxval=limit)
  return weights, bias

def init_network_params(sizes, key):
  # One (W, b) tuple per consecutive pair of layer widths in `sizes`.
  # sizes[0] must be d + 1 = 3 (inputs t, x1, x2): uNN reads rows W[0], W[1], W[2] of the first layer.
  layer_keys = random.split(key, len(sizes))
  params = []
  for i in range(len(sizes) - 1):
    params.append(random_layer_params(sizes[i], sizes[i + 1], layer_keys[i]))
  return params


In [None]:
def sigmoid(x):
  '''
  Logistic activation 1 / (1 + exp(-x)), evaluated with jax.nn.sigmoid.

  The hand-written form 1./(1 + jnp.exp(-x)) overflows exp(-x) to inf for very negative x
  (below about -88 in float32). The value still comes out as 0, but its JAX derivative
  evaluates 0 * inf = nan. uNN is differentiated with respect to t, x1, x2 and the parameters,
  so such a NaN would poison the loss gradient. jax.nn.sigmoid has a numerically stable derivative.
  '''
  from jax import nn  # local import: jax.nn is not imported at the top of the notebook
  return nn.sigmoid(x)

def relu(x):
  # Rectified linear unit: clamp negative entries to zero (NaNs propagate, as with jnp.maximum).
  return jnp.clip(x, 0)

def init_cond(x1, x2):
  # Separable polynomial bump ((1 - x1^2)(1 - x2^2))^4, which vanishes on the boundary of [-1, 1]^2.
  # Since int_{-1}^{1} (1 - s^2)^4 ds = 256/315, the factor (315/256)^2 gives it unit integral
  # over the square.
  bump = (1 - x1**2) * (1 - x2**2)
  return (315/256)**2 * bump**4
  # return jnp.sin(jnp.pi * x1)**2 * jnp.sin(jnp.pi * x2)**4


In [None]:
@jit
def uNN(params, t, x1, x2):
  '''
  Network ansatz for u(t, x1, x2) on [-1, 1]^2 that satisfies the initial and zero boundary
  conditions by construction:

      u = (1 - exp(-t)) * sqrt((1 - x1^2)(1 - x2^2)) * N(t, x1, x2) + init_cond(x1, x2),

  where N is the last layer's linear output with that layer's bias removed. At t = 0 the
  first factor vanishes (u = init_cond). On x1 = +-1 or x2 = +-1 both the sqrt factor and
  init_cond vanish (u = 0).

  inputs: params: list of (W, b) layer tuples; the first W must have 3 rows (for t, x1, x2).
          t, x1, x2: scalars/1D (combined through jnp.outer, so 1-D inputs must share a length
          or be scalars)
  output: 1D (squeezed, so a 0-d array when all inputs are scalars)
  '''
  # First hidden layer: same as [t, x1, x2] @ W + b, without stacking the inputs
  W, b = params[0]
  output = jnp.outer(t, W[0]) + jnp.outer(x1, W[1]) + jnp.outer(x2, W[2]) + b
  activation = sigmoid(output)
  # Other hidden layers. After the loop `output` holds the last layer's pre-activation
  # (linear) value; the final sigmoid result in `activation` is unused. The loop variable
  # `b` is left bound to the last layer's bias.
  for W, b in params[1:]:
    output = jnp.dot(activation, W) + b
    activation = sigmoid(output)
  # Impose boundary/ic condition
  output = (1 - jnp.exp(-t)) * jnp.sqrt((1 - x1**2) * (1 - x2**2)) * jnp.ravel(output - b) + init_cond(x1, x2) # -b to get rid of last b in output

  return jnp.squeeze(output)


In [None]:
def levy_const(alpha, d=2):
    '''
    Normalizing constant C_{d,alpha} of the fractional Laplacian in R^d:

        C_{d,alpha} = 2^(alpha-1) * alpha * Gamma((d+alpha)/2) / (pi^(d/2) * Gamma(1 - alpha/2)),

    the factor multiplying the singular integral of (u(x) - u(y)) / |x - y|^(d+alpha).

    Parameters
    alpha : fractional order (presumably 0 < alpha < 2; Gamma(1 - alpha/2) has a pole at alpha = 2).
    d     : spatial dimension. Defaults to 2, which reproduces the previously hard-coded
            pi**(2/2) and (2+alpha)/2.

    Returns a float.
    '''
    return 2**(alpha-1) * alpha * sp.special.gamma((d+alpha)/2) \
        /( jnp.pi**(d/2) * sp.special.gamma(1 - alpha/2) )

In [None]:
@jit
def discrete_part(params, t, x1, x2):
  '''
  Quadrature of the nonlocal integral at one collocation point (t, x1, x2).

  Approximates  int (u(t, x) - u(t, y)) / |x - y|^(2 + alpha) dy  over [-1, 1]^2, sampled on
  the global grid (Y1, Y2) with spacing h:
    * trapezoidal weights: edge nodes are halved and corners quartered (by doubling the
      denominator D), and the singular node y = x is dropped (D == 0 -> inf, i.e. weight 0);
    * a correction for the dropped singular node,
      -omega0/2 * h^(2 - alpha) * (u_x1x1 + u_x2x2)(t, x), using second derivatives of uNN.

  params: network parameters for uNN. t, x1, x2: scalars (vectorized via batched_discrete_part).
  Reads the module-level globals Y1, Y2, h, alpha and omega0. Under @jit their values are baked
  in when the function is traced, so later reassignments are not seen by already-compiled calls.
  '''
  # D = |y - x|^(2 + alpha) on the meshgrid (Y1, Y2), with trapezoidal edge/corner weights
  # folded into the denominator (corners are doubled twice)
  D = jnp.sqrt((Y1 - x1)**2 + (Y2 - x2)**2)**(2 + alpha)
  D = D.at[0,:].multiply(2)
  D = D.at[-1,:].multiply(2)
  D = D.at[:,0].multiply(2)
  D = D.at[:,-1].multiply(2)
  # exclude the singular node y = x (1/inf = 0)
  D = jnp.where(D == 0, jnp.inf, D)

  # Trapezoidal rule summation with boundary correction.
  # uNN evaluated at boundary nodes of the grid is 0 by construction of the ansatz.
  trapz = h**2 * jnp.sum( ( uNN(params, t, x1, x2) - uNN(params, t, jnp.ravel(Y1), jnp.ravel(Y2)) ) * jnp.ravel(1./D) )

  # singularity correction: grad(grad(uNN,2),2) = d^2u/dx1^2, grad(grad(uNN,3),3) = d^2u/dx2^2
  trapz = trapz - omega0/2 * h**(2-alpha) * ( grad(grad(uNN,2),2)(params, t, x1, x2) \
                                             + grad(grad(uNN,3),3)(params, t, x1, x2) )

  return trapz

In [None]:
def analytical_part(x1, x2):
  '''
  Closed-form integral term at the point (x1, x2), built from Gauss hypergeometric functions
  (scipy.special.hyp2f1) and Gamma functions.

  In wMSE it multiplies uNN and is added to the quadrature from discrete_part, both scaled by
  levy_constant. Presumably it is the integral of |x - y|^-(2+alpha) over y outside [-1, 1]^2,
  where u vanishes — TODO confirm against the derivation.

  Reads the module-level `alpha`. x1, x2 may be scalars or arrays (elementwise). The terms
  blow up as 1 +- x1 -> 0 or 1 +- x2 -> 0, so only pass interior points.
  '''
  def edge_term(dist, offset):
    # offset * ((1 + alpha) * 2F1(1/2, (2+alpha)/2; 3/2; z)
    #           - 2F1((1+alpha)/2, (2+alpha)/2; (3+alpha)/2; z)),  z = -(offset / dist)^2
    z = -(offset/dist)**2
    f_half = sp.special.hyp2f1(1/2, (2. + alpha)/2, 3/2, z)
    f_alpha = sp.special.hyp2f1((1. + alpha)/2, (2. + alpha)/2, (3. + alpha)/2, z)
    return offset * ((1. + alpha) * f_half - f_alpha)

  scale = alpha * (1. + alpha)
  central = np.sqrt(np.pi) * ( (1. + x2)**(-alpha) + (1. - x2)**(-alpha) ) * sp.special.gamma((1 + alpha)/2) \
    /(alpha * sp.special.gamma(1 + alpha/2))
  from_left = (1. + x1)**(-1. - alpha)/scale * (edge_term(1. + x1, 1. - x2) + edge_term(1. + x1, 1. + x2))
  from_right = (1. - x1)**(-1. - alpha)/scale * (edge_term(1. - x1, 1. - x2) + edge_term(1. - x1, 1. + x2))
  return central + from_left + from_right

In [None]:
# def l2_params(params):
#   res = 0
#   for W, b in params:
#     res = res + jnp.sum(W**2) + jnp.sum(b**2)
#   return res

In [None]:
def wMSE(params, batch):
  '''
  Training loss: mean squared PDE residual over a batch of collocation points.

  For each point (t, x1, x2) the residual is

      du/dt + levy_constant * ( discrete_part + u * analytical_part(x1, x2) ),

  presumably the fractional heat equation u_t = -(-Delta)^(alpha/2) u with u = 0 outside the
  square — confirm against the derivation.
  NOTE(review): despite the name, no per-point weights are applied here; every point counts equally.

  batch: tuple (t, x1, x2) of equal-length 1-D arrays (as returned by get_batch/get_test_batch).
  '''
  t, x1, x2 = batch
  time_derivative = batched_udt(params, t, x1, x2)
  nonlocal_term = batched_discrete_part(params, t, x1, x2) + uNN(params, t, x1, x2) * analytical_part(x1, x2)
  residual = time_derivative + levy_constant * nonlocal_term
  return jnp.mean(residual**2)

# Hyperparameters

In [None]:
# PRNG keys: network initialization, the three training-batch coordinates (t, x1, x2), and test batches.
key_init, key0, key1, key2 = random.PRNGKey(12345), random.PRNGKey(5), random.PRNGKey(2), random.PRNGKey(3)
key_test = random.PRNGKey(7)

# Layer widths: 3 inputs (t, x1, x2) -> four hidden layers of 20 -> scalar output.
sizes = [3, 20, 20, 20, 20, 1]
params_len = len(sizes) - 1  # number of (W, b) layers

# Fractional order of the operator.
alpha = 1

# Tabulated singular-point correction constants omega0 (used by discrete_part and get_mat),
# one per supported alpha.
omega0_table = {0.5: 0.960844610589965, 1: 1.950132460000978, 1.5: 5.038779739396576}
if alpha not in omega0_table:
  # Fail here instead of leaving omega0 undefined and hitting a NameError much later.
  raise ValueError('no such option: alpha = {!r}; supported values are {}'.format(alpha, sorted(omega0_table)))
omega0 = omega0_table[alpha]

levy_constant = levy_const(alpha)

h = 1/32    # spatial grid spacing (quadrature grid and collocation nodes)
dt = 0.01   # spacing of the training time grid
# Training times are drawn from dt, 2*dt, ... < T_train. Test points all sit at T_pred, just
# beyond the training window, so the test loss measures short-time extrapolation.
T_train, T_pred = 0.2, 0.21

def step_size(n):
  # Adam learning-rate schedule: constant 1e-3 for every step n.
  return 10**-3

num_epochs = int(3.5 * 10**5)
batch_size = 64

In [None]:
# Quadrature grid used by discrete_part, including the boundary nodes -1 and 1.
y = jnp.linspace(-1, 1, num=int(2/h) + 1)
Y1, Y2 = jnp.meshgrid(y, y, indexing='xy')

# Data Preparation

In [None]:
def get_batch(batch_size, *keys):
  '''
  Sample a training batch of collocation points, uniformly with replacement.

  Times come from the grid dt, 2*dt, ... < T_train (t = 0 excluded). x1 and x2 are drawn
  independently from the interior nodes of the spatial grid with spacing h (the nodes -1 and 1
  are excluded).

  keys: at least three PRNG keys, used for t, x1 and x2 respectively.
  Returns a tuple (t, x1, x2) of arrays of shape (batch_size,).
  '''
  interior = jnp.linspace(-1, 1, int(2/h) + 1)[1:-1]
  times = jnp.arange(0, T_train, dt)[1:]
  t_key, x1_key, x2_key = keys[0], keys[1], keys[2]
  t_idx = random.randint(key=t_key, shape=(batch_size,), minval=0, maxval=times.shape[0])
  x1_idx = random.randint(key=x1_key, shape=(batch_size,), minval=0, maxval=interior.shape[0])
  x2_idx = random.randint(key=x2_key, shape=(batch_size,), minval=0, maxval=interior.shape[0])
  return times[t_idx], interior[x1_idx], interior[x2_idx]

def get_test_batch(batch_size, *keys):
  '''
  Sample a test batch: every point is at time T_pred. x1 and x2 are drawn uniformly, with
  replacement, from the interior nodes of the spatial grid with spacing h.

  keys: at least two PRNG keys, used for x1 and x2 respectively.
  Returns a tuple (t, x1, x2) of arrays of shape (batch_size,).
  '''
  interior = jnp.linspace(-1, 1, int(2/h) + 1)[1:-1]
  n_interior = interior.shape[0]
  # t = jnp.arange(T, T_pred, dt); t = t.at[1:].get()
  x1_idx = random.randint(key=keys[0], shape=(batch_size,), minval=0, maxval=n_interior)
  x2_idx = random.randint(key=keys[1], shape=(batch_size,), minval=0, maxval=n_interior)
  return T_pred * jnp.ones((batch_size,)), interior[x1_idx], interior[x2_idx]

# Batched Functions

In [None]:
# Vectorize over collocation points (axis 0 of t, x1, x2) while sharing the network parameters:
# in_axes=None on the params pytree leaves every (W, b) leaf unmapped.
# du/dt at each point (argument 1 of uNN is t).
batched_udt = vmap(grad(uNN, argnums=1), in_axes=(None, 0, 0, 0), out_axes=0)

# Nonlocal quadrature at each point.
batched_discrete_part = vmap(discrete_part, in_axes=(None, 0, 0, 0), out_axes=0)

# Training with Adam

In [None]:
def update(i, opt_state, batch):
  '''
  One Adam step on the wMSE loss.

  i: step counter passed to the optimizer's update_fun. opt_state: current optimizer state.
  batch: (t, x1, x2) collocation points. Returns the new optimizer state.
  '''
  current_params = get_params(opt_state)
  loss_grads = grad(wMSE, 0)(current_params, batch)
  return update_fun(i, loss_grads, opt_state)

In [None]:
# Adam with the learning-rate schedule step_size(i).
init_fun, update_fun, get_params = optimizers.adam(step_size)

# Glorot-initialized network, and the optimizer state that wraps it.
init_params = init_network_params(sizes, key_init)
opt_state = init_fun(init_params)
itercount = itertools.count()  # step counter fed to update_fun

# Losses recorded during training, one entry per logged epoch.
mse_training_history = []
mse_test_history = []


In [None]:
print("\n Start training...")

for epoch in tqdm(range(num_epochs)):

  # Advance the three batch-coordinate PRNG streams (t, x1, x2) so every epoch gets a fresh
  # random batch of collocation points.
  key0, subkey0 = random.split(key0)
  key1, subkey1 = random.split(key1)
  key2, subkey2 = random.split(key2)

  batch = get_batch(batch_size, subkey0, subkey1, subkey2)
  opt_state = update(next(itercount), opt_state, batch)
  params = get_params(opt_state)

  # Log after the first epoch and then every 500 epochs.
  if epoch == 0 or (epoch + 1) % 500 == 0:
    # Loss of the updated parameters on the batch just used for the step.
    train_acc = wMSE(params, batch)

    # Fresh test batch at t = T_pred, kept in its own variable so `batch` keeps meaning
    # "current training batch".
    key_test, *sub_key_tests = random.split(key_test, 3)
    test_batch = get_test_batch(batch_size, *sub_key_tests)
    test_acc = wMSE(params, test_batch)

    # record history
    mse_training_history.append(train_acc)
    mse_test_history.append(test_acc)
    # tqdm.write prints above the progress bar instead of breaking it.
    tqdm.write(" wMSE train/test {}/{}".format(train_acc, test_acc))



# Verification

In [None]:
# jnp.save('/content/drive/MyDrive/Colab Notebooks/Fokker-Planck Equations/Absorbing Boundary/{}-{}.npy'.format(datetime.date.today(),num_epochs), \
#          np.array(params, dtype = object), allow_pickle = True)

In [None]:
# data = np.load('/content/drive/MyDrive/Colab Notebooks/Fokker-Planck Equations/Absorbing Boundary/Heat 2D/Experiments/Predictive Capacity/2022-03-16-300000.npy', allow_pickle = True)
# params = []
# for W, b in data:
#   params.append((jnp.array(W, dtype=jnp.float32), jnp.array(b, dtype=jnp.float32)))
# del data

In [None]:
# Setup for the finite-difference reference solver.
dx = 1/32
dt_ = dx**2/4  # step-size bound for the explicit reference time integrator
# Node grids (suffix t for "test"): yt includes the boundary nodes -1 and 1; xt keeps only
# interior nodes.
yt = jnp.linspace(-1, 1, int(2/dx) + 1)
xt = yt[1:-1]
Yt1, Yt2 = jnp.meshgrid(yt, yt)
Xt1, Xt2 = jnp.meshgrid(xt, xt)

In [None]:
def unravel(U):
  '''
  Reshape a vector of N*N values into an (N, N) matrix, row by row.

  Used to put flattened meshgrid-ordered values (network outputs or the reference state U)
  back on the 2-D grid. The result uses JAX's default float dtype, as before.

  Raises ValueError if the number of elements is not a perfect square. The previous
  implementation silently dropped the trailing elements in that case.
  '''
  flat = jnp.ravel(jnp.asarray(U, dtype=float))
  N = int(round(flat.size ** 0.5))
  if N * N != flat.size:
    raise ValueError('unravel expects N*N elements, got {}'.format(flat.size))
  # A single reshape replaces the row-by-row copy loop.
  return jnp.reshape(flat, (N, N))

In [None]:
def analy_integral(x1, x2):
  '''
  Closed-form integral term added to the diagonal of the reference matrix (diag_mat in get_mat).

  This had the same formula as analytical_part, which the training loss uses; the only textual
  difference was the exponent, written -1 - alpha instead of -1. - alpha. Delegating keeps the
  reference solver and the loss on one implementation. For integer alpha the old int exponent
  took JAX's integer-power path, so results may differ from before in the last float bits.

  Reads the module-level `alpha` (through analytical_part).
  '''
  return analytical_part(x1, x2)

def get_mat(dx, X1, X2, Y1, Y2):
  '''
  Assemble the dense matrix A of the semi-discrete reference operator (dU/dt = A U), where U
  holds the values at the interior nodes (X1, X2) flattened row-major.

  A = M + D + E, mirroring the terms of wMSE / discrete_part:
    M : off-diagonal weights  levy_constant * dx^2 / |x - y|^(2+alpha)  for interior y
        (boundary columns are dropped because U = 0 there);
    D : diagonal  -levy_constant * (dx^2 * trapezoidal sum of 1/|x - y|^(2+alpha) over the
        full grid (Y1, Y2)  +  analy_integral(x));
    E : singular-point correction  0.5 * levy_constant * omega0 / dx^alpha  times an integer
        5-point Laplacian stencil. Nodes next to the boundary use one-sided second differences
        in the normal direction.

  Parameters
  dx     : grid spacing.
  X1, X2 : (N, N) meshgrid of interior nodes, N >= 3.
  Y1, Y2 : meshgrid of all nodes including the boundary (presumably (N+2, N+2), so that
           dense_mat yields N*N columns).
  Reads the module-level alpha, levy_constant, omega0 and analy_integral.

  Returns an (N^2, N^2) array. Raises ValueError if N < 3.
  '''
  def discrete_sum(x1, x2):
    # Trapezoidal weights folded into the denominator (edges halved, corners quartered);
    # the singular node y = x is excluded via inf.
    arr = jnp.sqrt( (Y1 - x1)**2 + (Y2 - x2)**2 )**(2 + alpha)
    arr = arr.at[0,:].multiply(2)
    arr = arr.at[-1,:].multiply(2)
    arr = arr.at[:,0].multiply(2)
    arr = arr.at[:,-1].multiply(2)
    arr = jnp.where(arr == 0, jnp.inf, arr)
    return jnp.sum(1./arr)

  batched_discrete_sum = vmap(discrete_sum, in_axes=(0,0), out_axes=0)

  def diag_mat(a, b):
    return batched_discrete_sum(a, b) * dx**2 + analy_integral(a, b) # + 2 * omega0/dx**alpha

  def dense_mat(x1, x2):
    arr = jnp.sqrt((Y1 - x1)**2 + (Y2 - x2)**2)**(2 + alpha)
    arr = jnp.where(arr == 0, jnp.inf, arr)
    arr = arr.at[1:-1,1:-1].get() # removing bdry, since U = 0 on bdry.
    return jnp.ravel(1./arr)

  batched_dense_mat = vmap(dense_mat, in_axes=(0,0), out_axes=0)

  def correction_stencil(N, shape, dtype):
    # Integer stencil built in NumPy with in-place writes. The same writes via `.at[].set` on a
    # JAX array would copy the whole N^2 x N^2 matrix once per row.
    if N < 3:
      raise ValueError('the boundary-corrected stencil needs at least a 3 x 3 interior grid, got N = {}'.format(N))
    E = np.zeros(shape, dtype=dtype)
    for k in range(E.shape[0]):
      # upper-left corner
      if k == 0:
        E[k, [k, k+1, k+2, k+N, k+2*N]] = [2,-2,1,-2,1]
      # upper side
      elif (k > 0) and (k < N-1):
        E[k, [k-1, k, k+1, k+N, k+2*N]] = [1,-1,1,-2,1]
      # upper-right corner
      elif k == N-1:
        E[k, [k-2, k-1, k, k+N, k+2*N]] = [1,-2,2,-2,1]
      # left side
      elif (k % N == 0) and (k > 0) and (k < N**2-N):
        E[k, [k-N, k, k+1, k+2, k+N]] = [1,-1,-2,1,1]
      # right side
      elif ((k+1) % N == 0) and (k > N-1) and (k < N**2-1):
        E[k, [k-N, k-2, k-1, k, k+N]] = [1,1,-2,-1,1]
      # lower-left corner
      elif k == N**2-N:
        E[k, [k-2*N, k-N, k, k+1, k+2]] = [1,-2,2,-2,1]
      # lower side
      elif (k > N**2-N) and (k < N**2-1):
        E[k, [k-2*N, k-N, k-1, k, k+1]] = [1,-2,1,-1,1]
      # lower right corner
      elif k == N**2-1:
        E[k, [k-2*N, k-N, k-2, k-1, k]] = [1,-2,1,-2,2]
      # internal
      else:
        E[k, [k-N, k-1, k, k+1, k+N]] = [1,1,-4,1,1]
    return E

  # Dense and diagonal parts, one grid row of collocation points at a time. Collect the row
  # blocks and concatenate once instead of re-concatenating a growing array every iteration.
  rows = list(zip(X1, X2))
  M = jnp.concatenate([batched_dense_mat(a, b) for a, b in rows])
  D = jnp.concatenate([diag_mat(a, b) for a, b in rows])
  M = levy_constant * dx**2 * M
  D = -levy_constant * jnp.diag(D)

  # correction part
  N = X1.shape[0]
  E = jnp.asarray(correction_stencil(N, M.shape, M.dtype))
  E = (0.5 * levy_constant * omega0/dx**alpha) * E

  return M + D + E

In [None]:
# Network prediction at t = T_pred on the interior grid, reshaped to (N, N) for plotting.
Z = unravel(uNN(params, T_pred, Xt1.ravel(), Xt2.ravel()))

In [None]:
# Dense matrix of the semi-discrete reference operator on the interior grid (dU/dt = A U).
A = get_mat(dx, Xt1, Xt2, Yt1, Yt2)

In [None]:
# Reference solution: integrate dU/dt = A U from U(0) = init_cond on the interior grid up to
# t = T_pred, using the three-stage strong-stability-preserving Runge-Kutta scheme (Shu-Osher form).
# Take the fewest uniform steps whose size does not exceed dt_, so that steps * dt_step == T_pred
# exactly. Rounding T_pred/dt_ and stepping with dt_ stopped short of T_pred (with the current
# settings 860 * dt_ = 0.20996 vs T_pred = 0.21), while Z is evaluated at exactly T_pred.
steps = max(1, int(np.ceil(T_pred/dt_)))
dt_step = T_pred/steps
U = jnp.ravel(init_cond(Xt1,Xt2))

for _ in range(steps):
  U1 = U + dt_step * jnp.dot(A, U)
  U2 = 3/4 * U + 1/4 * U1 + 1/4 * dt_step * jnp.dot(A, U1)
  U = 1/3 * U + 2/3 * U2 + 2/3 * dt_step * jnp.dot(A, U2)

In [None]:
# Reference solution at the final integration time, back on the 2-D interior grid for comparison with Z.
Umat = unravel(U)

In [None]:
# fig = go.Figure(data = go.Surface( z = Z, x = xt, y = xt))
# fig.update_layout(title='ML Sol', autosize=False, width = 500, height = 500)
# fig.show()

### Pointwise Absolute Error $|U_{FD} - u_{NN}|$ at $t = T_{pred}$

In [None]:
# Pointwise |U_FD - u_NN| on the interior grid.
abs_err = jnp.abs(Umat - Z)
fig_max_err = go.Figure(data=go.Contour(z=abs_err, x=xt, y=xt))
(fig_max_err
 .update_layout(title='Maximum Absolute Error', autosize=False, width=500, height=500)
 .update_xaxes(title_text='x')
 .update_yaxes(title_text='y'))
fig_max_err.show()

### Pointwise Relative Error ($\log_{10}$ scale)

In [None]:
# log10 of the pointwise relative error |U_FD - u_NN| / |U_FD| on the interior grid.
log_rel_err = jnp.log10(jnp.abs((Umat - Z)/Umat))
fig_r_err = go.Figure(data=go.Contour(z=log_rel_err, x=xt, y=xt))
(fig_r_err
 .update_layout(title='Maximum Relative Error', autosize=False, width=500, height=500)
 .update_xaxes(title_text='x')
 .update_yaxes(title_text='y'))
fig_r_err.show()

### Learning history

In [None]:
# Losses are logged at epoch index 0 and whenever (epoch + 1) % 500 == 0 in the training loop,
# i.e. after 1, 500, 1000, ... completed epochs. The old x-axis, jnp.arange(0, num_epochs, 500) + 1,
# gave 1, 501, 1001, ... and had one entry fewer than the histories. Slicing to the history
# length lets a partially completed run still plot.
logged_epochs = jnp.array([1] + list(range(500, num_epochs + 1, 500)))

fig_training = go.Figure()
fig_training.add_trace(go.Scatter(x = logged_epochs[:len(mse_training_history)], y = jnp.log10(jnp.array(mse_training_history)), mode='lines', name='train'))
fig_training.add_trace(go.Scatter(x = logged_epochs[:len(mse_test_history)], y = jnp.log10(jnp.array(mse_test_history)), mode='lines', name='test'))
fig_training.update_layout(title = 'wMSE History', autosize = False, width = 400, height = 400)
fig_training.update_xaxes(title_text='epoch')
fig_training.update_yaxes(title_text='log10(wMSE)')
fig_training.show()

### L2 relative error

In [None]:
# Relative L2 error ||u_NN - U|| / ||U|| over the interior grid, with the network at t = T_pred.
u_pred_flat = uNN(params, T_pred, jnp.ravel(Xt1), jnp.ravel(Xt2))
L2_r_err = jnp.sqrt(jnp.sum((u_pred_flat - U)**2)/jnp.sum(U**2))

In [None]:
# Relative L2 error of the network against the finite-difference reference on the interior grid.
print(L2_r_err)

0.007172835


# Output

In [None]:
# Persist the trained parameters (a list of (W, b) tuples) as a pickled object array.
# NOTE: the target lives on Google Drive, so Drive must be mounted at /content/drive
# (the drive.mount cell near the top is commented out).
params_file = '/content/drive/MyDrive/Colab Notebooks/Fokker-Planck Equations/Absorbing Boundary/Heat 2D/Experiments/Approx Capacity/params_alpha_{1}_date_{0}.npy'\
  .format(datetime.date.today(), alpha)
jnp.save(params_file, np.array(params, dtype=object), allow_pickle=True)

In [None]:
# fig_empty = go.Figure()
# fig_empty.write_image("/content/drive/MyDrive/Colab Notebooks/Fokker-Planck Equations/Absorbing Boundary/Heat 2D/Experiments/Approx Capacity/fig_empty.pdf", format = 'pdf')

In [None]:
# fig_max_err.write_image("/content/drive/MyDrive/Colab Notebooks/Fokker-Planck Equations/Absorbing Boundary/Heat 2D/Experiments/Approx Capacity/fig_max_err_alpha_{1}_date_{0}.pdf"\
#                         .format(datetime.date.today(),alpha), format = 'pdf')

In [None]:
# fig_r_err.write_image("/content/drive/MyDrive/Colab Notebooks/Fokker-Planck Equations/Absorbing Boundary/Heat 2D/Experiments/Approx Capacity/fig_r_err_alpha_{1}_date_{0}.pdf"\
#                       .format(datetime.date.today(),alpha), format = 'pdf')

In [None]:
# fig_training.write_image("/content/drive/MyDrive/Colab Notebooks/Fokker-Planck Equations/Absorbing Boundary/Heat 2D/Experiments/Approx Capacity/fig_training_alpha_{1}_date_{0}.pdf"\
#                          .format(datetime.date.today(),alpha), format = 'pdf')