<a href="https://colab.research.google.com/github/pinakm9/sphere-fp/blob/master/data/6D/sphere6D.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Import required modules**

In [7]:
# run this cell to download data and necessary modules
# NOTE: any existing local checkout named `repo` is deleted first, so every
# run of this cell starts from a fresh clone of the repository.
import os, shutil
repo = 'sphere-fp'
if os.path.isdir(repo):
  shutil.rmtree(repo)
!git clone https://github.com/pinakm9/sphere-fp.git
# add modules folder to Python's search path
# (inserted at position 0 so it is searched before other sys.path entries;
# the path is relative to the current working directory)
import sys
sys.path.insert(0, repo + '/modules')
# import the necessary modules
# NOTE(review): lss_solver is presumably provided by the cloned repo's
# modules folder added to sys.path above - confirm.
import numpy as np
import tensorflow as tf
import lss_solver as lss
# restrict TensorFlow's (compat.v1) logger to ERROR-level messages
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

Cloning into 'sphere-fp'...
remote: Enumerating objects: 121, done.[K
remote: Total 121 (delta 0), reused 0 (delta 0), pack-reused 121[K
Receiving objects: 100% (121/121), 47.44 MiB | 27.16 MiB/s, done.
Resolving deltas: 100% (34/34), done.


**Define the equation through the $\mathcal L_{\log}$ operator**

In [8]:
# Experiment configuration shared by the cells below.
DTYPE = tf.float32  # dtype passed to the solver
D = 1.0             # coefficient D used by diff_log_op and p_inf2
dim = 6             # number of coordinates of the problem
ones = np.ones(dim)
# Domain given as [lower corner, upper corner] of the box [-2, 2]^dim.
domain = [bound * ones for bound in (-2., 2.)]
# Folder inside the cloned repository that is handed to solver.learn.
save_folder = '{}/data/{}D'.format(repo, dim)

@tf.function
def diff_log_op(f, x, y, x1, y1, x2, y2):
    """Evaluate the L_log operator applied to ``f`` at the given coordinates.

    The six coordinates are handled as three (x, y) pairs. Each pair
    contributes a factor ``4 * (x**2 + y**2 - 1)`` that weights its
    first-derivative term, and the module-level ``D`` scales the sum of the
    squared first derivatives and the unmixed second derivatives of ``f``.

    Args:
        f: callable invoked as ``f(x, y, x1, y1, x2, y2)``.
        x, y, x1, y1, x2, y2: coordinates at which the operator is evaluated
            (assumed to be TensorFlow tensors of matching shape - TODO confirm
            against the solver that calls this function).

    Returns:
        The value of the operator, combined from the first-derivative term,
        the term ``4 * (z + z1 + z2 + dim)`` and the ``D``-scaled term.
    """
    coords = [x, y, x1, y1, x2, y2]

    # Per-pair factors 4 * (r^2 - 1); they do not depend on f.
    w = 4.*(x**2 + y**2 - 1.)
    w1 = 4.*(x1**2 + y1**2 - 1.)
    w2 = 4.*(x2**2 + y2**2 - 1.)

    # The tape is persistent because it is queried again below for the
    # second derivatives; the first derivatives are taken inside the context
    # so that they are themselves recorded.
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(coords)
        value = f(x, y, x1, y1, x2, y2)
        g_x, g_y, g_x1, g_y1, g_x2, g_y2 = tape.gradient(value, coords)

    # Unmixed second derivatives, one per coordinate.
    h_x = tape.gradient(g_x, x)
    h_y = tape.gradient(g_y, y)
    h_x1 = tape.gradient(g_x1, x1)
    h_y1 = tape.gradient(g_y1, y1)
    h_x2 = tape.gradient(g_x2, x2)
    h_y2 = tape.gradient(g_y2, y2)

    first_order = w*(x*g_x + y*g_y) + w1*(x1*g_x1 + y1*g_y1) + w2*(x2*g_x2 + y2*g_y2)
    zeroth_order = 4.*(w+ w1 + w2 + dim)
    second_order = D*(g_x**2 + g_y**2 + h_x + h_y + g_x1**2 + g_y1**2 + h_x1 + h_y1 + g_x2**2 + g_y2**2 + h_x2 + h_y2)
    return first_order + zeroth_order + second_order

**Define the steady state $p_\infty(\mathbf x)$**

In [9]:
from scipy.special import erf
import numpy as np

def p_inf2(x, y):
  """Evaluate ``exp(-(x**2 + y**2 - 1)**2 / D) / Z`` for one coordinate pair.

  ``Z = 0.5 * sqrt(pi**3 * D) * (1 + erf(1 / sqrt(D)))`` is computed with
  NumPy/SciPy from the module-level ``D``; the exponential is evaluated with
  ``tf.exp``, so the result is whatever ``tf.exp`` returns for the inputs.
  """
  norm_const = 0.5 * np.sqrt(np.pi**3 * D) * (1. + erf(1/np.sqrt(D)))
  ring_penalty = (x**2 + y**2 - 1.)**2
  return tf.exp(-ring_penalty / D) / norm_const

def p_inf(x, y, x1, y1, x2, y2):
  """Return the product of ``p_inf2`` evaluated on the three coordinate pairs."""
  first_pair = p_inf2(x, y)
  second_pair = p_inf2(x1, y1)
  third_pair = p_inf2(x2, y2)
  return first_pair * second_pair * third_pair

**Set up experiment parameters and learn the stationary distribution**

In [None]:
# Piecewise-constant learning rate: one value per interval delimited by the
# boundaries (hence one more value than there are boundaries).
learning_rate = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[1000, 2000, 10000],
    values=[5e-3, 1e-3, 5e-4, 1e-4],
)
optimizer = tf.keras.optimizers.Adam(learning_rate)

# Build the solver from the operator and domain defined in the cells above.
solver = lss.LogSteadyStateSolver(
    num_nodes=50,
    num_blocks=3,
    dtype=DTYPE,
    name='sphere{}D'.format(dim),
    diff_log_op=diff_log_op,
    optimizer=optimizer,
    domain=domain,
)

# Run training; save_folder is passed through to the solver.
solver.learn(epochs=5000, n_sample=1000, save_folder=save_folder)

 Epoch        Loss        Runtime(s)
     013432.664062           13.2231
    10 8438.982422           13.5877
    20 5234.733398           13.9400
    30 2255.836914           14.4556
    40 1058.288452           14.8020
    50  629.859619           15.1445
    60  388.939911           15.4800
    70  277.967682           15.8198
    80  220.200256           16.1545
    90  190.514191           16.4917
   100  179.540710           16.8228
   110  164.808655           17.1533
   120  140.847687           17.4891
   130  136.604080           17.8250
   140  126.056511           18.1633
   150  116.668999           18.5001
   160  113.209435           18.9237
   170   98.460663           19.4425
   180   98.791435           19.9379
   190   89.620224           20.4225
   200   84.902748           20.9343
   210   74.737778           21.3012
   220   71.985832           21.6277
   230   71.187965           21.9632
   240   62.573185           22.3069
   250   65.493309           22.6361
 

**Visualize the learned distribution**

In [None]:
import matplotlib.pyplot as plt

def plot_solutions(learned, true, resolution=30, low=domain[0], high=domain[1]):
  """Draw wireframes of a learned and a reference density side by side.

  Both callables are evaluated on a ``resolution x resolution`` grid over the
  first two coordinates, spanning ``[low[0], high[0]] x [low[1], high[1]]``.

  Args:
    learned: callable ``learned(x, y)`` whose result exposes ``.numpy()``;
      ``x`` and ``y`` are float32 arrays of shape ``(resolution**2, 1)``.
    true: callable with the same calling convention as ``learned``.
    resolution: number of grid points along each of the two axes.
    low: lower bounds; only entries 0 and 1 are read. The default is taken
      from the module-level ``domain`` when this function is defined.
    high: upper bounds; only entries 0 and 1 are read.
  """
  fig = plt.figure(figsize=(16, 8))
  ax_l = fig.add_subplot(121, projection='3d')
  ax_t = fig.add_subplot(122, projection='3d')

  xs = np.linspace(low[0], high[0], num=resolution, endpoint=True).astype('float32')
  ys = np.linspace(low[1], high[1], num=resolution, endpoint=True).astype('float32')
  # Default 'xy' indexing: x varies fastest along each row and y is constant
  # per row, which is the same ordering the flattened inputs had before.
  grid_x, grid_y = np.meshgrid(xs, ys)
  x = grid_x.reshape((-1, 1))
  y = grid_y.reshape((-1, 1))

  grid = (resolution, resolution)
  z_l = learned(x, y).numpy().reshape(grid)
  z_t = true(x, y).numpy().reshape(grid)

  # Raw strings keep the backslash of \infty without relying on Python
  # passing an unrecognised escape sequence through unchanged.
  ax_l.plot_wireframe(grid_x, grid_y, z_l, color='deeppink')
  ax_l.set_title(r'learned $p_\infty$', fontsize=15)
  ax_t.plot_wireframe(grid_x, grid_y, z_t, color='blue')
  ax_t.set_title(r'true $p_\infty$', fontsize=15)
  plt.tight_layout()
  plt.show()
  
def learned(x, y):
  """Evaluate ``exp(solver.net(...))`` on the slice where the last ``dim - 2`` coordinates are zero.

  Args:
    x, y: values of the first two coordinates; the remaining coordinates are
      filled with ``tf.zeros_like(x)``.

  Returns:
    ``tf.exp`` of the network output for ``(x, y, 0, ..., 0)``.
  """
  z = tf.zeros_like(x)
  zs = [z] * (dim - 2)
  # The closing parenthesis of tf.exp(...) was missing, which made the whole
  # cell fail with a SyntaxError.
  return tf.exp(solver.net(x, y, *zs))

def true(x, y):
  """Evaluate ``p_inf`` at ``(x, y)`` with the other ``dim - 2`` coordinates set to zero."""
  zero_coords = (tf.zeros_like(x),) * (dim - 2)
  return p_inf(x, y, *zero_coords)

# Compare the learned and reference densities on the (x, y) slice where the
# remaining coordinates are zero. The previous lambda referenced an undefined
# name `args` (NameError when called); the `true` helper defined in this cell
# supplies the zero-padded coordinates instead.
plot_solutions(learned=learned, true=true)

**Investigate the size of $\theta$**

In [None]:
# Call summary() on the solver's network.
# NOTE(review): presumably solver.net is a Keras model, in which case this
# prints its layers and parameter counts - confirm in lss_solver.
solver.net.summary()