<a href="https://colab.research.google.com/github/sjiang23/senbaojiang.github.io/blob/main/Heat_3D_FDM_computation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import jax.numpy as jnp
import jax.scipy as jsp
from jax import grad, jit, vmap, jacobian, hessian, jacrev, pmap, lax
from jax import random
from jax.scipy import optimize
from jax.example_libraries import optimizers
from functools import partial
import numpy as np
import scipy as sp
from tqdm import tqdm
import datetime

In [None]:
def sigmoid(x):
  """Logistic function 1 / (1 + exp(-x)), evaluated with jax.numpy.

  Works element-wise on arrays and on plain scalars; the result is a JAX array.
  """
  denominator = 1 + jnp.exp(-x)
  return 1./denominator

def init_cond(x1, x2, x3):
  """Initial profile: a normalised polynomial bump on the cube [-1, 1]^3.

  The bump ((1 - x1^2)(1 - x2^2)(1 - x3^2))^5 vanishes whenever any coordinate
  is +/-1. Since the integral of (1 - x^2)^5 over [-1, 1] equals 512/693, the
  prefactor (693/512)^3 gives the profile unit integral over the cube.
  Arguments may be scalars or arrays that broadcast against each other.
  """
  normalisation = (693/512)**3
  bump = (1 - x1**2) * (1 - x2**2) * (1 - x3**2)
  return normalisation * bump**5

In [None]:
def levy_const(alpha):
    """Return 2^alpha * Gamma((3 + alpha)/2) / (pi^(3/2) * |Gamma(-alpha/2)|).

    The Gamma values come from scipy.special; the pi^(3/2) factor and the
    absolute value go through jax.numpy, so the result is a JAX scalar.
    Presumably this is the normalising constant of the 3-D fractional
    Laplacian of order alpha - confirm against the reference being used.
    """
    numerator = 2**alpha * sp.special.gamma((3+alpha)/2)
    denominator = jnp.sqrt(jnp.pi)**3 * jnp.abs(sp.special.gamma(-alpha/2))
    return numerator / denominator

In [None]:
def analytical_part(x1, x2, x3):
  """Half the sum of twelve arctan terms built from the distances 1 -/+ x_i.

  For every unordered pair of axes (a, b), with c the remaining axis, and for
  every sign choice p in {1 - x_a, 1 + x_a}, q in {1 - x_b, 1 + x_b}, the term

      (arctan((1 - x_c)/r) + arctan((1 + x_c)/r)) * r / (p * q),   r = sqrt(p^2 + q^2)

  is accumulated; the function returns half of the total. All twelve terms
  share this one pattern, so the result is invariant under permutations and
  sign flips of (x1, x2, x3).

  Args:
    x1, x2, x3: coordinates (scalars or broadcastable arrays). Each term
      divides by (1 -/+ x_i), so the value is singular when any coordinate
      equals +/-1.

  Returns:
    JAX array with the broadcast shape of the inputs.

  NOTE(review): presumably this is the closed-form contribution of the region
  outside the cube [-1, 1]^3 to the kernel integral used in get_matrix_3d -
  confirm against the derivation.
  """
  def pair_term(p, q, c_minus, c_plus):
    # One summand: both arctans share the same in-plane distance r.
    r = jnp.sqrt(p**2 + q**2)
    return (jnp.arctan(c_minus/r) + jnp.arctan(c_plus/r)) * r/(p * q)

  # Distances from the point to the two faces perpendicular to each axis.
  d1 = (1 - x1, 1 + x1)
  d2 = (1 - x2, 1 + x2)
  d3 = (1 - x3, 1 + x3)

  I = 0
  # (pair axis a, pair axis b, remaining axis c) for the three axis pairs.
  for pair_a, pair_b, (c_minus, c_plus) in ((d1, d2, d3), (d1, d3, d2), (d3, d2, d1)):
    for p in pair_a:
      for q in pair_b:
        I = I + pair_term(p, q, c_minus, c_plus)

  return I/2

In [None]:
# Notebook-level parameters; get_matrix_3d reads levy_constant and omega0 as globals.
# NOTE(review): get_matrix_3d hard-codes the kernel as squared distance to the
# power 2 (|y - x|**4), which looks tied to alpha = 1 - confirm before changing alpha.
alpha = 1
# NOTE(review): the origin of this constant is not shown in this notebook; it
# scales the second-difference correction term added in get_matrix_3d.
omega0 = 2.97121097252838375756

# Prefactor applied to every part of the matrix assembled by get_matrix_3d.
levy_constant = levy_const(alpha)

# Finite-difference method (FDM): assembling the operator matrix

In [None]:
# Grid spacing and the time step derived from it (dt_ = dx^2 / 4).
dx = 1/20
dt_ = dx**2/4

# 1-D nodes on [-1, 1]: yt includes both end points, xt drops them
# (the "_t" suffix stands for "test").
yt = np.linspace(-1, 1, num=int(2/dx) + 1)
xt = yt[1:-1]

# 3-D coordinate arrays (np.meshgrid with its default indexing):
# Y* cover every node, X* only the nodes away from the end points.
Yt1, Yt2, Yt3 = np.meshgrid(*[yt]*3)
Xt1, Xt2, Xt3 = np.meshgrid(*[xt]*3)

In [None]:
def unravel(U,d):
  """Fold a flat JAX array into a d-dimensional cube, one slab at a time.

  For d == 1 the input is returned unchanged. Otherwise the side length is
  round(len(U) ** (1/d)); U is cut into `side` consecutive chunks of
  side**(d-1) entries and each chunk is unravelled recursively into slab i of
  the result. When len(U) is an exact d-th power this is the row-major layout
  U.reshape((side,) * d) would give, except that the output buffer comes from
  jnp.empty and therefore has jnp.empty's default dtype.

  U must support `.at[...]` indexing (i.e. be a JAX array).
  """
  if d == 1:
    return U

  side = int(np.round(len(U)**(1./d)))
  chunk = side**(d-1)
  folded = jnp.empty(shape = (side,)*d)
  for slab in range(side):
    piece = U.at[slab*chunk:(slab+1)*chunk].get()
    folded = folded.at[slab].set(unravel(piece, d-1))
  return folded

In [None]:
def get_matrix_3d(dx, X1, X2, X3, Y1, Y2, Y3):
  """Assemble the dense (X1.size x X1.size) operator matrix, one row per X point.

  Row k belongs to the point x = (X1.ravel()[k], X2.ravel()[k], X3.ravel()[k])
  and is built in three passes:

  1. Dense part: levy_constant * dx**3 / |y - x|**4 for every node y of the Y
     grid with its outermost layer stripped (Y[1:-1, 1:-1, 1:-1], ravelled);
     the entry with y == x is 0.
  2. Diagonal part: M[k, k] is reduced by
     levy_constant * (S * dx**3 + analytical_part(x)), where S sums
     1/|y - x|**4 over the whole Y grid (skipping y == x) with each term
     halved once for every boundary face of the Y grid the node lies on
     (weights 1/2 on faces, 1/4 on edges, 1/8 at corners).
  3. Correction part: (0.5 * levy_constant * omega0 / dx) times the sum over
     the three axes of a second-difference stencil (1, -2, 1), centred in the
     interior and one-sided at the first/last X node of an axis.

  Args:
    dx: grid spacing.
    X1, X2, X3: NumPy coordinate arrays of the unknowns. The stencil strides
      below assume a cubic grid with N = X1.shape[0] >= 3 nodes per axis and
      the same node values along every axis.
    Y1, Y2, Y3: NumPy coordinate arrays of the full grid; Y[1:-1, 1:-1, 1:-1]
      must contain X1.size nodes, otherwise the row assignment fails.

  Returns:
    NumPy array of shape (X1.size, X1.size).

  Reads the module-level names `levy_constant` and `omega0` and calls
  `analytical_part`; progress bars are printed through `tqdm`.

  NOTE(review): levy_constant and analytical_part go through jax.numpy, so
  unless JAX's 64-bit mode is enabled the rows presumably carry only float32
  precision before being stored in the float64 matrix - confirm.
  """
  raveled_X1, raveled_X2, raveled_X3 = np.ravel(X1), np.ravel(X2), np.ravel(X3)
  n_rows = X1.size

  def discrete_sum(x1, x2, x3):
    # |y - x|**4 on the full Y grid (a fresh array, so the in-place edits are local).
    arr = ( (Y1 - x1)**2 + (Y2 - x2)**2 + (Y3 - x3)**2 )**2
    # Doubling the denominator on each of the six boundary faces halves the
    # corresponding 1/arr terms; edges and corners are hit two and three times.
    arr[0,:,:] = arr[0,:,:] * 2
    arr[-1,:,:] = arr[-1,:,:] * 2
    arr[:,0,:] = arr[:,0,:] * 2
    arr[:,-1,:] = arr[:,-1,:] * 2
    arr[:,:,0] = arr[:,:,0] * 2
    arr[:,:,-1] = arr[:,:,-1] * 2
    # y == x would divide by zero; an infinite denominator contributes 0.
    arr = np.where(arr == 0, np.inf, arr)
    return np.sum(1./arr)

  def dense_mat(x1, x2, x3):
    arr = ((Y1 - x1)**2 + (Y2 - x2)**2 + (Y3 - x3)**2 )**2
    arr = np.where(arr == 0, np.inf, arr)
    # Keep only the nodes that are unknowns (outer layer of the Y grid removed).
    arr = arr[1:-1,1:-1,1:-1]
    return np.ravel(1./arr)

  # dense part
  M = np.zeros(shape = (n_rows, n_rows))
  dense_scale = levy_constant * dx**3
  for k in tqdm(range(n_rows)):
    M[k] = dense_scale * dense_mat(raveled_X1[k], raveled_X2[k], raveled_X3[k])

  # diag part
  for k in tqdm(range(n_rows)):
    M[k,k] = M[k,k] - levy_constant * (discrete_sum(raveled_X1[k],raveled_X2[k],raveled_X3[k]) * dx**3  \
                                       + analytical_part(raveled_X1[k],raveled_X2[k],raveled_X3[k]))

  # corr part
  N = X1.shape[0]
  left_end, right_end = X1.min(), X1.max()
  # Flat-index stride of a one-node move along each coordinate. With the
  # C-order ravel of np.meshgrid's default indexing, x3 is the fastest index
  # (stride 1), x1 the middle one (stride N) and x2 the slowest (stride N**2).
  axis_strides = ((raveled_X1, N), (raveled_X2, N**2), (raveled_X3, 1))
  corr_scale = 0.5 * levy_constant * omega0/dx
  for k in tqdm(range(n_rows)):
    E = np.zeros(shape = (n_rows,))
    for coords, stride in axis_strides:
      x = coords[k]
      if x == left_end:
        offsets = (0, 1, 2)       # forward one-sided: u[k] - 2u[k+s] + u[k+2s]
      elif x == right_end:
        offsets = (-2, -1, 0)     # backward one-sided: u[k-2s] - 2u[k-s] + u[k]
      else:
        offsets = (-1, 0, 1)      # centred: u[k-s] - 2u[k] + u[k+s]
      # Accumulate, because all three axes contribute to the centre entry k.
      for offset, weight in zip(offsets, (1, -2, 1)):
        E[k + offset*stride] += weight

    M[k] = M[k] + corr_scale * E

  return M


In [None]:
# Assemble the operator on the Xt* nodes against the full Yt* grid.
# The result is a dense NumPy array with Xt1.size rows and columns, built in
# three row-by-row passes, so this cell is slow and memory-hungry.
M = get_matrix_3d(dx, Xt1, Xt2, Xt3, Yt1, Yt2, Yt3)

In [None]:
# Persist the assembled matrix; the file name embeds today's date, alpha and dx.
# NOTE(review): this writes under ".../Absorbing Boundary/Experiment Data/Heat 3D/",
# while the jnp.load cell further down reads from
# ".../Absorbing Boundary/Heat 3D/Experiments/" - confirm which directory is intended.
# NOTE(review): no Google Drive mount call is visible in this notebook; the
# path presumably relies on the drive having been mounted beforehand.
np.save('/content/drive/MyDrive/Colab Notebooks/Fokker-Planck Equations/Absorbing Boundary/Experiment Data/Heat 3D/Ref_mat_{0}_alpha_{1}_dx_{2}.npy'\
         .format(datetime.date.today(),alpha, dx), \
         M, allow_pickle = True)

# Reference Solution Computation

In [None]:
# Grid spacing and the time step derived from it (dt_ = dx^2 / 4).
dx = 1/20
dt_ = dx**2/4

# Output times: six equally spaced values from 0.2 to 0.25 (a JAX array).
T_preds = jnp.linspace(0.2,0.25,6)

# 1-D nodes on [-1, 1]: yt includes both end points, xt drops them
# (the "_t" suffix stands for "test").
yt = np.linspace(-1, 1, num=int(2/dx) + 1)
xt = yt[1:-1]

# 3-D coordinate arrays (np.meshgrid with its default indexing):
# Y* cover every node, X* only the nodes away from the end points.
Yt1, Yt2, Yt3 = np.meshgrid(*[yt]*3)
Xt1, Xt2, Xt3 = np.meshgrid(*[xt]*3)

In [None]:
# Load a previously saved operator matrix as a JAX array.
# NOTE(review): the file name pins the date 2022-03-23 and the directory differs
# from the one the np.save cell above writes to, so a freshly assembled M is
# not what gets loaded here - confirm this is intended.
# NOTE(review): allow_pickle=True permits unpickling arbitrary objects from the
# file; a plain numeric .npy presumably does not need it.
A = jnp.load('/content/drive/MyDrive/Colab Notebooks/Fokker-Planck Equations/Absorbing Boundary/Heat 3D/Experiments/Ref_mat_2022-03-23_alpha_1_dx_0.05.npy', allow_pickle = True)

In [None]:
# Integrate dU/dt = A U from the initial condition and save the state at every
# time in T_preds. Each step is the three-stage SSP Runge-Kutta (Shu-Osher) update.
#
# The state is carried over from one target time to the next, so only the
# additional steps are taken. Restarting from the initial condition for every
# target would repeat exactly the same operations and give the same states.
U = jnp.ravel(init_cond(Xt1,Xt2,Xt3))
steps_done = 0
for T_pred in T_preds:
  print(T_pred)
  steps = int(jnp.round(T_pred/dt_))
  if steps < steps_done:
    # Target lies before the current state (T_preds not ascending): start over.
    U = jnp.ravel(init_cond(Xt1,Xt2,Xt3))
    steps_done = 0
  for _ in tqdm(range(steps - steps_done)):
    U1 = U + dt_ * jnp.dot(A, U)
    U2 = 3/4 * U + 1/4 * U1 + 1/4 * dt_ * jnp.dot(A, U1)
    U = 1/3 * U + 2/3 * U2 + 2/3 * dt_ * jnp.dot(A, U2)
  steps_done = steps
  jnp.save('/content/drive/MyDrive/Colab Notebooks/Fokker-Planck Equations/Absorbing Boundary/Heat 3D/Experiments/OU_3D_ref_sol_T_{}.npy'\
         .format(T_pred), U, allow_pickle = True)

In [None]:
# Set up a single run: number of time steps and the flattened initial state.
# NOTE(review): T_pred is not assigned in this cell; it is whatever value the
# `for T_pred in T_preds` loop in the earlier cell left behind (its last
# element), so this cell only works after that loop has run.
steps = int(jnp.round(T_pred/dt_))
U = jnp.ravel(init_cond(Xt1,Xt2,Xt3))

In [None]:
# Advance U by `steps` time steps of size dt_ with the same three-stage update
# used in the T_preds loop cell: a forward-Euler stage followed by two convex
# combinations with the previous state.
for _ in tqdm((range(steps))):
  U1 = U + dt_ * jnp.dot(A, U)
  U2 = 3/4 * U + 1/4 * U1 + 1/4 * dt_ * jnp.dot(A, U1)
  U = 1/3 * U + 2/3 * U2 + 2/3 * dt_ * jnp.dot(A, U2)

In [None]:
# Save the state reached by the single run above.
# NOTE(review): the path and format string are identical to the ones used in
# the T_preds loop cell, so for the same T_pred this overwrites that file.
jnp.save('/content/drive/MyDrive/Colab Notebooks/Fokker-Planck Equations/Absorbing Boundary/Heat 3D/Experiments/OU_3D_ref_sol_T_{}.npy'\
         .format(T_pred), U, allow_pickle = True)