In [1]:
import os, sys
from pathlib import Path
# The notebook is run from a sub-folder of the repository, so the repository root is the
# parent of the current working directory. Its `modules` folder is put first on sys.path so
# that project-local modules imported below (presumably `arch` and `sim`) resolve.
script_dir = Path(os.path.abspath('')).parent
module_dir = str(script_dir)
sys.path.insert(0, '{}/modules'.format(module_dir))
print(module_dir)

# import the rest of the modules
%matplotlib nbagg
import numpy as np
import tensorflow as tf 
import matplotlib.pyplot as plt
import pandas as pd
import arch
import time  
import sim
from mpl_toolkits.axes_grid1 import make_axes_locatable

C:\Users\pinak\Documents\GitHub\fp-solvers


In [2]:
# initialize the Monte-Carlo scheme
# spatial dimension of the problem; it also selects the '{dim}D' data folder and weight files
dim = 2
def mu(X):
    """Drift of the simulated SDE  dX = mu(X) dt + sigma dW.

    Returns ``-grad V`` for the ring potential ``V(X) = (|X|^2 - 1)^2``, i.e.
    ``-4 (|X|^2 - 1) X``, which pushes particles towards the unit sphere.
    The squared norm is taken over the last axis, so the drift works for any
    number of spatial dimensions. It does not depend on the global ``dim``, and
    for 2-D input it gives the same values as the original x/y formulation.

    Args:
        X: array-like (NumPy array or eager TF tensor) of shape ``(..., d)``
            holding particle positions.

    Returns:
        numpy.ndarray with the same shape and dtype as ``X``.
    """
    X = np.asarray(X)
    # 4 * (|X|^2 - 1), kept as a trailing singleton axis so it broadcasts over components
    z = 4. * (np.sum(X * X, axis=-1, keepdims=True) - 1.0)
    return -X * z

sigma = np.sqrt(2.0)   # noise amplitude of the SDE passed to MCProb
N = int(1e6)           # number of Monte-Carlo particles
n_subdivs = 100        # subdivisions per axis (the MC histogram below is n_subdivs x n_subdivs)
dt = 0.01              # time step passed to mc.ready
n_steps = 1000         # number of time steps passed to mc.ready
save_folder = '../circle-fp/data/{0}D'.format(dim)
# initial ensemble: N particles uniform in [-2, 2]^dim; the shape follows `dim`
# so changing the dimension keeps X0 consistent with the data folder above
X0 = tf.random.uniform(minval=-2., maxval=2., shape=(N, dim)).numpy()
mc = sim.MCProb(save_folder=save_folder, n_subdivs=n_subdivs, mu=mu, sigma=sigma, X0=X0)
mc.ready(n_steps=n_steps, dt=dt)
# alternative entry points, kept commented out (not run):
#mc.compute_all_wo_prop(n_steps=n_steps, dt=dt)
#mc.set_grid(None)

Time taken by propagate is 79.14346933364868 seconds
Time taken by set_grid is 2.336635112762451 seconds
Time taken by assign_pts is 2.8146474361419678 seconds
Time taken by ready is 84.2947518825531 seconds


In [3]:
# load the learned solution
domain = mc.get_grid()  # MC grid; its mins/maxs are used as the plotting domain below
net = arch.LSTMForgetNet(num_nodes=50, num_blocks=3)
# expect_partial() silences warnings about checkpoint values the model does not restore
net.load_weights('../circle-fp/data/{0}D/circle{0}D'.format(dim)).expect_partial()


def sol(*args):
    """Learned density: the exponential of the network output (net treated as a log-density)."""
    return tf.exp(net(*args))

# compute MC estimate: a normalised 2-D histogram of the particles' box indices
# along dimensions i and j (columns i, j of coordinates.csv)
i, j = 0, 1
coords = np.genfromtxt('{}/coordinates.csv'.format(save_folder), delimiter=',')
boxes, counts = np.unique(coords[:, [i, j]], return_counts=True, axis=0)
pd.DataFrame(boxes).to_csv('{}/boxes_{}_{}.csv'.format(save_folder, i, j), index=False, header=False)
pd.DataFrame(counts).to_csv('{}/counts_{}_{}.csv'.format(save_folder, i, j), index=False, header=False)
prob = np.zeros((mc.n_subdivs, mc.n_subdivs))
# rows of `boxes` are unique, so one fancy-indexed assignment equals the per-box loop
# (astype(int) truncates exactly like int() did)
box_idx = boxes.astype(int)
prob[box_idx[:, 0], box_idx[:, 1]] = counts
prob /= np.sum(prob)

In [4]:
from scipy.special import erf
# diffusion coefficient of the stationary density; for dX = -grad V dt + sigma dW this
# is sigma**2 / 2, and sigma = sqrt(2) in the simulation cell gives D = 1
D = 1.0

def p_inf(x, y):
    """Analytic stationary density  p(x, y) = exp(-(x^2 + y^2 - 1)^2 / D) / Z.

    Z = 0.5 * sqrt(pi^3 D) * (1 + erf(1 / sqrt(D))) normalises p over R^2.
    Returns a TensorFlow tensor (callers convert it with ``.numpy()``).
    """
    normalizer = 0.5 * np.sqrt(np.pi**3 * D) * (1. + erf(1/np.sqrt(D)))
    ring_gap = x**2 + y**2 - 1.
    return tf.exp(-ring_gap**2 / D) / normalizer


def fmt(x, pos):
    r"""Colour-bar tick formatter: render ``x`` as LaTeX ``$m.mm \times 10^{e}$``.

    ``pos`` is the tick position passed by matplotlib's FuncFormatter; it is unused.
    """
    mantissa, exponent = format(x, '.2e').split('e')
    return r'${} \times 10^{{{}}}$'.format(mantissa, int(exponent))

import matplotlib.ticker as ticker
# set up plotting parameters: every font size is a base size plus a shared `scale` offset
scale = 5
xlabel_size = ylabel_size = legend_size = title_size = 15 + scale
tick_size = cbar_tick_size = 10 + scale


# grid points per axis for the density plots; set to n_subdivs so it matches the MC
# histogram built as mc.n_subdivs x mc.n_subdivs (presumably mc.n_subdivs == n_subdivs)
resolution = n_subdivs
def _draw_error_panel(fig, ax, x, y, error, title):
    """Draw one absolute-error heat map on `ax` with a scientific-notation colour bar at its right."""
    cax = make_axes_locatable(ax).append_axes('right', '5%', '5%')
    im = ax.pcolormesh(x, y, error, cmap='inferno', shading='auto')
    cbar = fig.colorbar(im, format=ticker.FuncFormatter(fmt), ax=ax, cax=cax)
    ax.tick_params(axis='both', which='both', labelsize=tick_size)
    cbar.ax.tick_params(labelsize=cbar_tick_size)
    ax.set_title(title, fontsize=title_size)


def plot_error(learned, low, high, mc_data, resolution=None, save_path='../plots/2D-error.png'):
    """Plot absolute errors of the learned and Monte-Carlo densities against ``p_inf``.

    ``p_inf`` and ``learned`` are evaluated on a uniform ``resolution x resolution``
    grid (end points included) covering ``[low[0], high[0]] x [low[1], high[1]]``.
    Each is normalised to sum to one over the grid. The left panel shows
    ``|learned - truth|`` and the right panel ``|truth - mc_data|``. The figure is
    saved to ``save_path`` and then shown.

    Args:
        learned: callable ``learned(x, y)`` taking two float32 arrays of shape
            ``(resolution**2, 1)`` and returning an object with ``.numpy()``
            (e.g. a TensorFlow tensor).
        low, high: lower / upper domain bounds; only entries 0 and 1 are used.
        mc_data: square 2-D array of Monte-Carlo probabilities, compared
            element-wise with the truth grid (rows follow ``y``, columns ``x``).
        resolution: grid points per axis; defaults to ``mc_data.shape[0]`` so
            both panels are always comparable.
        save_path: output file; its parent folder is created if it is missing.

    Raises:
        ValueError: if ``mc_data`` is not a ``resolution x resolution`` array.
    """
    # NOTE(review): the MC cell fills prob[b[0], b[1]] with b[0] from coordinate column 0.
    # If that is the box index along x, mc_data is indexed [x, y] while this grid is [y, x],
    # so the MC panel is transposed. p_inf is symmetric under x <-> y, which hides it — confirm.
    start = time.time()
    mc_data = np.asarray(mc_data)
    if mc_data.ndim != 2:
        raise ValueError('mc_data must be 2-D, got shape {}'.format(mc_data.shape))
    if resolution is None:
        resolution = mc_data.shape[0]
    grid = (resolution, resolution)
    if mc_data.shape != grid:
        raise ValueError('mc_data has shape {}, expected {}'.format(mc_data.shape, grid))

    fig = plt.figure(figsize=(16, 8))
    ax_l = fig.add_subplot(121)
    ax_m = fig.add_subplot(122)

    # x varies along columns and y along rows, as pcolormesh(x, y, z) expects
    xs = np.linspace(low[0], high[0], num=resolution, endpoint=True).astype('float32')
    ys = np.linspace(low[1], high[1], num=resolution, endpoint=True).astype('float32')
    x, y = np.meshgrid(xs, ys)
    x_col, y_col = x.reshape((-1, 1)), y.reshape((-1, 1))

    z_l = learned(x_col, y_col).numpy()
    truth = p_inf(x_col, y_col).numpy()
    truth = (truth / truth.sum()).reshape(grid)
    z_l = (z_l / z_l.sum()).reshape(grid)

    _draw_error_panel(fig, ax_l, x, y, np.abs(z_l - truth), 'Deep learning absolute error')
    _draw_error_panel(fig, ax_m, x, y, np.abs(truth - mc_data), 'Monte Carlo absolute error')

    plt.tight_layout()
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    plt.savefig(save_path)
    plt.show()
    end = time.time()
    print('Time taken by this cell is {} seconds'.format(end - start))

# compare both estimates against the analytic density over the MC grid's bounding box
plot_error(
    learned=sol,
    low=domain.mins,
    high=domain.maxs,
    mc_data=prob,
)

<IPython.core.display.Javascript object>

Time taken by this cell is 1.5597317218780518 seconds
