**Import required modules**

In [2]:
import os, sys
from pathlib import Path
# Parent directory of the current working directory (os.path.abspath('') is the CWD).
# NOTE(review): despite its name this is not the notebook file's own folder; it assumes
# the notebook is launched with its own folder as the CWD — confirm.
script_dir = Path(os.path.dirname(os.path.abspath('')))
module_dir = str(script_dir)
# Make the project-local `modules` folder importable so `import lss_solver` below resolves.
# Must run before that import.
sys.path.insert(0, module_dir + '/modules')
print(module_dir)
# import the necessary modules
import numpy as np
import tensorflow as tf
import lss_solver as lss
# Only show TensorFlow log messages at ERROR level or above (hides INFO/WARNING noise).
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

C:\Users\pinak\Documents\GitHub\non-grad3D


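**Define the equation through the $\mathcal L_{\log}$ operator**

Write the stationary density as $p_\infty = e^{f}$. The drift of Thomas' cyclically symmetric system is $\mu = (\sin y - b x,\ \sin z - b y,\ \sin x - b z)$. Dividing the stationary Fokker–Planck equation $-\nabla\cdot(\mu\, p_\infty) + D\,\Delta p_\infty = 0$ by $p_\infty$, and using $\nabla\cdot\mu = -3b$, gives
$$\mathcal L_{\log} f := -\mu\cdot\nabla f + 3b + D\left(\Delta f + |\nabla f|^2\right) = 0.$$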

In [5]:
# Floating-point type handed to the solver network.
DTYPE = tf.float32
# D scales the diffusion (Laplacian) term in diff_log_op, b is the linear
# dissipation coefficient in the drift mu, and dim is the state dimension.
D, b, dim = 1.0, 0.2, 3
# Sampling/plotting box [-4, 4]^dim, stored as [lower corner, upper corner].
ones = np.ones(dim)
domain = [-4. * ones, 4. * ones]
# Directory handed to solver.learn (relative to the current working directory).
save_folder = '../data/Thomas-true-vs-learned'

def mu(x, y, z):
  """Drift of Thomas' cyclically symmetric system.

  Component i is sin(next coordinate) - b * (current coordinate), with the
  coordinates taken cyclically x -> y -> z -> x, i.e.
  (sin y - b x, sin z - b y, sin x - b z).

  Returns:
    Tuple of three tensors, one drift component per coordinate.
  """
  state = (x, y, z)
  rotated = state[1:] + state[:1]  # (y, z, x): each coordinate's cyclic successor
  return tuple(tf.math.sin(nxt) - b * cur for cur, nxt in zip(state, rotated))


@tf.function
def diff_log_op(f, x, y, z):
    """Residual of the stationary Fokker-Planck equation written for f = log p.

    Substituting p = exp(f) into -div(mu p) + D * Laplacian(p) = 0 and dividing
    by p gives
        -mu . grad f - div(mu) + D * (Laplacian f + |grad f|^2),
    and for the Thomas drift div(mu) = -3b, which is the constant +3b term.

    Args:
        f: callable f(x, y, z) returning the log-density (presumably the
           solver network; its exponential is the unnormalized density).
        x, y, z: coordinate tensors at which the residual is evaluated.

    Returns:
        Tensor holding the pointwise residual.
    """
    coords = [x, y, z]
    # Persistent tape: queried once for the gradient and three more times for
    # the diagonal second derivatives.
    with tf.GradientTape(persistent=True) as tape:
        tape.watch(coords)
        log_p = f(x, y, z)
        # Computed inside the context so the first derivatives are recorded
        # and can be differentiated a second time below.
        grad_x, grad_y, grad_z = tape.gradient(log_p, coords)
    laplacian = tape.gradient(grad_x, x) + tape.gradient(grad_y, y) + tape.gradient(grad_z, z)
    mu_x, mu_y, mu_z = mu(x, y, z)
    drift = mu_x * grad_x + mu_y * grad_y + mu_z * grad_z
    diffusion = D * (laplacian + grad_x**2 + grad_y**2 + grad_z**2)
    return -drift + 3. * b + diffusion

**Set up experiment parameters and learn the stationary distribution**

In [7]:
# Piecewise-constant learning rate: 5e-3 up to step 1000, then 1e-3 up to 2000,
# 5e-4 up to 10000, 1e-4 up to 50000 and 1e-5 afterwards.
# NOTE(review): with epochs=10000 below, and assuming one optimizer step per epoch
# (confirm in lss_solver), only the first three rates are ever used.
learning_rate = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[1000, 2000, 10000, 50000],
    values=[5e-3, 1e-3, 5e-4, 1e-4, 1e-5],
)
optimizer = tf.keras.optimizers.Adam(learning_rate)
# The solver name is the plain literal 'Thomas' (no dimension suffix).
# NOTE(review): if a suffix such as 'Thomas3D' was intended, add it here; it may
# affect names of saved files — confirm in lss_solver.
solver = lss.LogSteadyStateSolver(num_nodes=50, num_blocks=3, dtype=DTYPE, name='Thomas',
                                  diff_log_op=diff_log_op, optimizer=optimizer, domain=domain)
# n_sample is presumably the number of collocation points drawn from `domain` per
# epoch; save_along / stop_saving control periodic saving into save_folder
# (semantics defined in lss_solver).
solver.learn(epochs=10000, n_sample=1000, save_folder=save_folder, save_along=100, stop_saving=10000)

 Epoch        Loss        Runtime(s)
     0    0.536071           11.1047
    10    0.052583           11.5076
    20    0.027198           11.9191
    30    0.017032           12.2758
    40    0.012083           12.6525
    50    0.009153           13.0615
    60    0.005836           13.4362
    70    0.004756           13.8354
    80    0.003275           14.2221
    90    0.002630           14.5771
   100    0.001957           14.9679
   110    0.001652           15.3672
   120    0.001307           15.7346
   130    0.001149           16.1454
   140    0.000837           16.5151
   150    0.000739           16.8742
   160    0.000684           17.2618
   170    0.000559           17.6167
   180    0.000553           18.0572
   190    0.000475           18.4373
   200    0.000467           18.8386
   210    0.000402           19.2678
   220    0.000328           19.6200
   230    0.000334           20.0055
   240    0.000337           20.4135
   250    0.000320           20.7609
 

  2210    0.000021           98.6662
  2220    0.000021           99.0281
  2230    0.000019           99.5284
  2240    0.000020           99.9223
  2250    0.000023          100.3017
  2260    0.000021          100.7350
  2270    0.000021          101.1238
  2280    0.000022          101.4979
  2290    0.000018          101.8825
  2300    0.000021          102.2652
  2310    0.000018          102.7175
  2320    0.000021          103.1152
  2330    0.000019          103.4889
  2340    0.000018          103.8992
  2350    0.000018          104.2727
  2360    0.000017          104.6687
  2370    0.000017          105.0902
  2380    0.000016          105.4826
  2390    0.000020          105.8433
  2400    0.000020          106.2379
  2410    0.000017          106.6607
  2420    0.000017          107.0835
  2430    0.000018          107.5407
  2440    0.000022          107.9818
  2450    0.000018          108.4457
  2460    0.000018          108.9188
  2470    0.000016          109.3944
 

  4430    0.000013          190.1535
  4440    0.000009          190.5382
  4450    0.000010          190.9283
  4460    0.000010          191.3376
  4470    0.000010          191.7159
  4480    0.000009          192.1087
  4490    0.000011          192.5254
  4500    0.000011          192.9217
  4510    0.000009          193.3465
  4520    0.000009          193.7231
  4530    0.000009          194.1040
  4540    0.000009          194.4899
  4550    0.000010          194.9146
  4560    0.000012          195.3054
  4570    0.000010          195.7281
  4580    0.000008          196.1230
  4590    0.000008          196.5042
  4600    0.000009          196.8845
  4610    0.000009          197.3186
  4620    0.000009          197.6918
  4630    0.000008          198.0908
  4640    0.000009          198.5049
  4650    0.000008          198.9111
  4660    0.000008          199.2919
  4670    0.000009          199.6746
  4680    0.000008          200.0481
  4690    0.000008          200.4527
 

  6650    0.000005          644.3478
  6660    0.000004          644.7120
  6670    0.000005          645.0714
  6680    0.000005          645.4625
  6690    0.000005          645.8308
  6700    0.000004          646.1769
  6710    0.000005          646.5562
  6720    0.000004          646.9106
  6730    0.000005          647.2869
  6740    0.000005          647.6574
  6750    0.000005          648.0041
  6760    0.000004          648.3697
  6770    0.000005          648.7598
  6780    0.000005          649.1551
  6790    0.000004          649.5639
  6800    0.000006          649.9222
  6810    0.000005          650.3797
  6820    0.000005          650.7818
  6830    0.000005          651.1947
  6840    0.000005          651.5883
  6850    0.000004          651.9836
  6860    0.000006          652.3607
  6870    0.000005          652.7279
  6880    0.000005          653.1206
  6890    0.000005          653.7018
  6900    0.000005          654.1694
  6910    0.000005          654.6111
 

  8870    0.000003          732.4609
  8880    0.000003          732.8331
  8890    0.000003          733.2094
  8900    0.000002          733.5764
  8910    0.000003          734.0061
  8920    0.000003          734.3630
  8930    0.000003          734.7029
  8940    0.000003          735.0966
  8950    0.000002          735.4554
  8960    0.000003          735.8324
  8970    0.000003          736.2360
  8980    0.000003          736.6429
  8990    0.000003          737.0197
  9000    0.000002          737.4117
  9010    0.000003          737.8530
  9020    0.000002          738.2465
  9030    0.000003          738.6202
  9040    0.000002          739.0208
  9050    0.000003          739.3809
  9060    0.000003          739.8414
  9070    0.000003          740.2107
  9080    0.000002          740.5644
  9090    0.000002          740.9463
  9100    0.000002          741.3067
  9110    0.000002          741.7013
  9120    0.000003          742.0708
  9130    0.000002          742.4827
 

**Visualize the learned distribution**

In [None]:
import matplotlib.pyplot as plt

def plot_solutions(learned, true, resolution=30, low=domain[0], high=domain[1]):
  """Draw wireframes of the learned (and optionally true) density over an (x, y) grid.

  Args:
    learned: callable (x, y) -> array-like; x and y are float32 arrays of shape
      (resolution**2, 1). Tensors and numpy arrays are both accepted as results.
    true: callable with the same signature as `learned`, or None to draw only
      the learned surface.
    resolution: number of grid points along each axis.
    low, high: lower / upper corners of the box; only the first two coordinates
      are used.
  """
  xs = np.linspace(low[0], high[0], num=resolution, endpoint=True).astype('float32')
  ys = np.linspace(low[1], high[1], num=resolution, endpoint=True).astype('float32')
  # Default 'xy' indexing: x varies along columns and y along rows, so the
  # flattened points are xs tiled and ys repeated.
  x_grid, y_grid = np.meshgrid(xs, ys)
  x = x_grid.reshape((-1, 1))
  y = y_grid.reshape((-1, 1))

  panels = [(learned, 'learned', 'deeppink')]
  if true is not None:
    panels.append((true, 'true', 'blue'))

  fig = plt.figure(figsize=(8 * len(panels), 8))
  for position, (density, label, color) in enumerate(panels, start=1):
    ax = fig.add_subplot(1, len(panels), position, projection='3d')
    # np.asarray accepts both eager tensors and numpy arrays.
    z = np.asarray(density(x, y)).reshape(x_grid.shape)
    ax.plot_wireframe(x_grid, y_grid, z, color=color)
    # Raw string: '\i' is not a valid Python escape sequence.
    ax.set_title(r'{} $p_\infty$'.format(label), fontsize=15)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
  plt.tight_layout()
  plt.show()
  
def learned(x, y):
  """Learned density exp(solver.net(...)) with every coordinate after y fixed at 0.

  Presumably unnormalized, since it is just the exponential of the network output.
  """
  padding = [tf.zeros_like(x) for _ in range(dim - 2)]
  return tf.exp(solver.net(x, y, *padding))

def true(x, y):
  """Reference density p_inf with every coordinate after y fixed at 0.

  NOTE(review): `p_inf` is not defined in any cell shown in this notebook, so this
  raises NameError on a fresh kernel unless it is defined elsewhere — confirm.
  """
  padding = [tf.zeros_like(x) for _ in range(dim - 2)]
  return p_inf(x, y, *padding)

# Side-by-side wireframes of the learned and reference densities on the z = 0 slice.
plot_solutions(true=true, learned=learned)

**Investigate the size of $\theta$ (the number of network parameters)**

In [None]:
# Print the network's layers and parameter counts, i.e. the size of θ.
# NOTE(review): assumes solver.net is a Keras model exposing .summary() — confirm in lss_solver.
solver.net.summary()