In [10]:
# add folders to Python's search space
import os, sys
from pathlib import Path
# The notebook is run from a sub-folder of the repository, so the repository
# root is the parent of the current working directory (abspath('') == cwd).
script_dir = Path(os.path.dirname(os.path.abspath('')))
module_dir = str(script_dir)
# Make the repository's `modules` folder importable (presumably where the
# project-local `arch` and `dom` modules imported below live).
# Guard the insert so re-running this cell does not pile up duplicate entries.
modules_path = module_dir + '/modules'
if modules_path not in sys.path:
    sys.path.insert(0, modules_path)
print(module_dir)

# import the rest of the modules
%matplotlib nbagg
import numpy as np
import tensorflow as tf 
import matplotlib.pyplot as plt
import arch
import pandas as pd
from scipy.signal import savgol_filter
from matplotlib.gridspec import GridSpec
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.transforms as mtrans
import dom
# Floating-point precision shared by the networks, the domain sampler and the
# tensor conversions in the cells below.
DTYPE = "float32"

C:\Users\pinak\Documents\GitHub\var-al


In [11]:
# load learned solutions
def _restore_net(name):
    """Build an arch.LSTMForgetNet with this notebook's fixed hyper-parameters
    (50, 3, dim=3) and restore its weights from ../data/<name>/<net.name>.

    expect_partial() silences warnings about checkpoint values (e.g. optimizer
    slots) that are not used when only running inference.
    """
    net = arch.LSTMForgetNet(50, 3, DTYPE, dim=3, name=name)
    net.load_weights(f'../data/{name}/{net.name}').expect_partial()
    return net


net_p = _restore_net("Beltrami")       # shown as the 'Penalty solution'
net_al = _restore_net('Beltrami-al')   # shown as the 'Augmented Lagrangian solution'

def curl(f, x, y, z):
    """Compute the curl of the vector field ``f`` by automatic differentiation.

    ``f(x, y, z)`` must return a tensor whose last axis holds the three
    components (Ax, Ay, Az). The result concatenates, along the last axis,
    (dAz/dy - dAy/dz, dAx/dz - dAz/dx, dAy/dx - dAx/dy); each term has the
    shape of the corresponding coordinate tensor.

    Args:
        f: callable mapping (x, y, z) to the field, e.g. one of the networks above.
        x, y, z: coordinate tensors. They are watched explicitly, so plain
            tensors (not only tf.Variable) work.

    Returns:
        tf.Tensor with the three curl components concatenated on axis -1.

    Note:
        ``tape.gradient`` of a non-scalar target differentiates the *sum* of
        its entries. The values are therefore point-wise partial derivatives
        only if each output row depends solely on its own input row.
        NOTE(review): presumably true for the networks used here; confirm.
    """
    # persistent=True because the same tape is queried six times below.
    with tf.GradientTape(persistent=True) as tape:
        tape.watch([x, y, z])
        Ax, Ay, Az = tf.split(f(x, y, z), 3, axis=-1)

    def d(component, coord):
        # ZERO rather than the default None: a component that does not depend
        # on `coord` (e.g. an analytic field) would otherwise make the
        # subtraction below fail with a TypeError.
        return tape.gradient(component, coord,
                             unconnected_gradients=tf.UnconnectedGradients.ZERO)

    return tf.concat([d(Az, y) - d(Ay, z),
                      d(Ax, z) - d(Az, x),
                      d(Ay, x) - d(Ax, y)], axis=-1)

In [13]:
# set up plotting parameters
# All font sizes are offsets from `scale`; raise it to enlarge every label at once.
scale = 0
tick_size = scale + 7
cbar_tick_size = scale + 10
xlabel_size = ylabel_size = legend_size = title_size = scale + 15
# Colour of the divider lines drawn between panels.
line_color = "darkgrey"

# plot solutions
def _style_axis(ax, title):
    """Apply the shared tick/label font sizes, a title and x/y/z labels to a 3-D axis."""
    ax.tick_params(axis='both', which='both', labelsize=tick_size)
    ax.set_title(title, fontsize=title_size)
    ax.set_xlabel('x', fontsize=xlabel_size)
    ax.set_ylabel('y', fontsize=ylabel_size)
    ax.set_zlabel('z', fontsize=ylabel_size)


def _quiver_curl(ax, net, coords, points, color, length=0.1):
    """Draw curl(net) as 3-D arrows, normalised so the largest arrow has unit magnitude.

    Args:
        ax: 3-D axis to draw on.
        net: field callable passed to ``curl``.
        coords: (x, y, z) tensors at which ``curl`` is evaluated.
        points: the same coordinates as flattened numpy arrays (arrow tails).
        color: a single matplotlib colour used for every arrow.
        length: arrow length passed to ``Axes3D.quiver``.
    """
    components = [c.numpy().flatten()
                  for c in tf.split(curl(net, *coords), 3, axis=-1)]
    magnitude = np.sqrt(sum(c * c for c in components))
    peak = np.max(magnitude)
    if peak > 0:  # a vanishing field would otherwise give 0/0 -> NaN arrows
        components = [c / peak for c in components]
    ax.quiver(*points, *components, length=length, color=color)
    ax.grid(False)


def plot_solutions(filename, resolution=6):
    """Plot the curl of the learned penalty and augmented-Lagrangian solutions.

    Samples a grid on ``dom.Box3D`` via ``box.grid_sample(resolution)``,
    evaluates ``curl`` of ``net_p`` and ``net_al`` on it, and draws each field
    as normalised 3-D quiver arrows. A third panel titled 'True solution' is
    created, but nothing is drawn on it yet. The figure is saved to
    ``<filename>.png`` at 300 dpi and then shown.

    Args:
        filename: output path without the ``.png`` extension.
        resolution: sampling resolution passed to ``box.grid_sample``.
    """
    box = dom.Box3D(dtype=DTYPE)

    fig = plt.figure(figsize=(10, 10))
    gs = GridSpec(2, 4, figure=fig)
    ax_p = fig.add_subplot(gs[0, :-2], projection='3d')
    ax_al = fig.add_subplot(gs[0, 2:], projection='3d')
    ax_t = fig.add_subplot(gs[1, 1:-1], projection='3d')
    _style_axis(ax_p, 'Penalty solution')
    _style_axis(ax_al, 'Augmented Lagrangian solution')
    _style_axis(ax_t, 'True solution')
    # TODO: ax_t stays empty; the true solution is not plotted yet.

    x, y, z = box.grid_sample(resolution)
    coords = tuple(tf.convert_to_tensor(c, dtype=DTYPE) for c in (x, y, z))
    points = (x.flatten(), y.flatten(), z.flatten())

    _quiver_curl(ax_p, net_p, coords, points, 'red')
    # Each panel must evaluate its own network: net_al here, not net_p.
    _quiver_curl(ax_al, net_al, coords, points, 'blue')

    fig.subplots_adjust(wspace=0., hspace=0.2)
    # Hand-placed divider lines in figure coordinates: a vertical one between
    # the two top panels and a horizontal one between the two rows.
    x_div = 0.46 + .045
    y_div = 0.45 + .045
    fig.add_artist(plt.Line2D([x_div, x_div], [0.5, 1],
                              transform=fig.transFigure, color=line_color))
    fig.add_artist(plt.Line2D([0.0, 1], [y_div, y_div],
                              transform=fig.transFigure, color=line_color))
    # NOTE: tight_layout recomputes subplot spacing, so it may override the
    # wspace/hspace set above.
    fig.tight_layout()
    fig.savefig('{}.png'.format(filename), dpi=300)
    plt.show()

# Render both learned fields and write ../plots/Beltrami-surface.png
plot_solutions('../plots/Beltrami-surface')

<IPython.core.display.Javascript object>