In [1]:
import numpy as np
import random
import os
import pickle
import copy 
import trimesh 
import copy
import time 

import plotly.graph_objects as go
from plotly.offline import plot
import plotly.figure_factory as ff

from ega import default_meshgraphnet_dataset_path
from ega.algorithms.brute_force import BFGFIntegrator
from ega.util.gaussian_kernel import GaussianKernel
from ega.util.interpolator import Interpolator

In [2]:
def calculate_interpolation_metrics(true_fields, interpolated_fields):
    """Print error metrics comparing interpolated vectors with the ground truth.

    Two numbers are printed:

    * the Frobenius norm of ``true_fields - interpolated_fields``;
    * the mean cosine similarity between corresponding rows.

    Args:
        true_fields: 2-D numpy array with one ground-truth vector per row
            (``ord='fro'`` only accepts 2-D operands).
        interpolated_fields: 2-D numpy array of the same shape holding the
            interpolated vectors.

    Returns:
        None. The metrics are written to stdout.
    """
    frobenius_norm = np.linalg.norm(true_fields - interpolated_fields, ord='fro')

    true_norms = np.linalg.norm(true_fields, axis=-1)
    interpolated_norms = np.linalg.norm(interpolated_fields, axis=-1)
    dot_products = (true_fields * interpolated_fields).sum(axis=-1)

    # Cosine similarity is undefined when either vector has zero length: the
    # division becomes 0/0 = NaN and a single such row would turn the whole
    # mean into NaN. Average only over rows where both norms are non-zero.
    # `!= 0` (rather than `> 0`) keeps NaN norms in, so NaNs already present
    # in the inputs still propagate instead of being silently dropped.
    defined = (true_norms != 0) & (interpolated_norms != 0)
    if defined.any():
        cosine_similarity = np.mean(
            dot_products[defined] / true_norms[defined] / interpolated_norms[defined])
    else:
        # No row has a well-defined direction, so there is nothing to average.
        cosine_similarity = float('nan')

    print("Frobenious Norm: {}\nCosine Similarity: {}".format(frobenius_norm, cosine_similarity))

In [22]:
def plot_mesh(world_pos, vertices_interpolate_pos, true_arrow_matrix, interpolated_arrow_matrix, snapshot_index):
    """Render the mesh, the interpolated vertices and both sets of velocity lines.

    The figure is shown inline, written as ``flag_<snapshot_index>.png`` into
    the ``flag_simple`` folder under ``default_meshgraphnet_dataset_path``, and
    finally handed to ``plotly.offline.plot`` with ``filename="vector.html"``
    and ``auto_open=True``.

    Args:
        world_pos: array whose transpose unpacks into the x, y, z coordinates
            of every mesh vertex.
        vertices_interpolate_pos: array whose transpose unpacks into the
            x, y, z coordinates of the interpolated vertices.
        true_arrow_matrix: three rows (x, y, z) describing the ground-truth
            velocity line segments.
        interpolated_arrow_matrix: three rows (x, y, z) describing the
            interpolated velocity line segments.
        snapshot_index: value substituted into the PNG file name.
    """

    def _velocity_lines(arrow_matrix, label, colour):
        # One 'lines' trace per velocity set; the rows are x, y and z.
        line_x, line_y, line_z = arrow_matrix
        return go.Scatter3d(x=line_x, y=line_y, z=line_z, mode='lines', name=label,
                            line=dict(color=colour, width=4))

    # semi-transparent grey surface built from all vertex positions
    mesh_x, mesh_y, mesh_z = world_pos.T
    surface_trace = go.Mesh3d(x=mesh_x, y=mesh_y, z=mesh_z, alphahull=5, opacity=0.4, color='grey')

    # small black markers at the vertices that were interpolated
    point_x, point_y, point_z = vertices_interpolate_pos.T
    marker_trace = go.Scatter3d(
        x=point_x,
        y=point_y,
        z=point_z,
        mode='markers',
        name='interpolated_points',
        marker=dict(size=3, color="black"),
    )

    # ground truth in red, interpolation in blue
    true_trace = _velocity_lines(true_arrow_matrix, 'true_velocities', 'red')
    interpolated_trace = _velocity_lines(interpolated_arrow_matrix, 'interpolated_velocities', 'blue')

    figure = go.Figure(data=[marker_trace, surface_trace, true_trace, interpolated_trace],)
    figure.show()

    # static PNG next to the dataset, then the standalone HTML page
    png_name = 'flag_{}.png'.format(snapshot_index)
    figure.write_image(os.path.join(default_meshgraphnet_dataset_path, 'flag_simple', png_name))
    plot(figure, filename="vector.html", auto_open=True, image='png', image_height=1000, image_width=1100)


In [23]:
# configuration
# (the spelling of `trajactory_index` is kept because other cells may reference it)
trajactory_index = 0  # id of the trajectory pkl file used for generating snapshots
snapshot_index = 5  # index of the snapshot to take from that trajectory
mask_ratio = 0.03  # fraction of vertices that are held out and interpolated
scale = 15  # scales the velocity vectors for better visualization
sigma = 10  # parameter of the gaussian kernel


def _arrow_segments(start_points, end_points):
    """Pack line segments into a (3, 3*n) matrix of x, y, z rows.

    Columns come in groups of three: segment start, segment end, NaN. The NaN
    column is presumably what makes plotly break the line between consecutive
    segments -- confirm against the Scatter3d `connectgaps` behaviour.
    """
    segments = np.full((3, 3 * len(start_points)), np.nan)
    segments[:, 0::3] = start_points.T
    segments[:, 1::3] = end_points.T
    return segments


# read data
meshgraph_path = os.path.join(default_meshgraphnet_dataset_path, 'flag_simple', 'processed_data')
meshgraph_file = os.path.join(meshgraph_path, 'trajectory_{}.pkl'.format(trajactory_index))
# NOTE: unpickling can execute arbitrary code; only load trajectory files from a trusted source.
# The context manager closes the file handle once the snapshot has been extracted.
with open(meshgraph_file, 'rb') as trajectory_file:
    mesh_data = pickle.load(trajectory_file)[snapshot_index]
print(mesh_data.keys())

vertices = mesh_data['vertices']
adjacency_list = mesh_data['adjacency_list']
weight_list = mesh_data['weight_list']
# first three node-feature columns; the rest of this cell treats them as the velocity field
field = mesh_data['node_features'][:, :3]
world_pos = mesh_data['world_pos']
faces = mesh_data['faces']
n_vertices = len(vertices)

# divide vertices into known vertices and vertices to be interpolated
random.seed(0)  # fixed seed so the same vertices are held out on every run
vertices_interpolate = random.sample(vertices, int(mask_ratio * n_vertices))
vertices_known = list(set(vertices) - set(vertices_interpolate))
true_fields = field[vertices_interpolate]
vertices_interpolate_pos = world_pos[vertices_interpolate]
n_vertices_interpolate = len(vertices_interpolate)

# create integrator and interpolator
f_fun = GaussianKernel(sigma)
brute_force = BFGFIntegrator(adjacency_list, weight_list, vertices, f_fun)
interpolator = Interpolator(brute_force, vertices_known, vertices_interpolate)

# mask out vertices to be interpolated
# NOTE(review): indexing with the same list on both axes pairs the indices, so only the
# (i, i) entries for i in vertices_interpolate are zeroed -- confirm this is the intended mask.
interpolator.integrator._m_matrix[vertices_interpolate, vertices_interpolate] = 0
# divide every column by its sum
interpolator.integrator._m_matrix /= interpolator.integrator._m_matrix.sum(axis=0, keepdims=True)
# a copy is passed; whether interpolate() modifies its argument is not visible from this cell
interpolated_fields = interpolator.interpolate(copy.deepcopy(field))

# calculate arrows representing velocity directions:
# each segment runs from the vertex position to position + scale * velocity
true_velocity_directions = vertices_interpolate_pos + scale * true_fields
interpolated_velocity_directions = vertices_interpolate_pos + scale * interpolated_fields
true_arrow_matrix = _arrow_segments(vertices_interpolate_pos, true_velocity_directions)
interpolated_arrow_matrix = _arrow_segments(vertices_interpolate_pos, interpolated_velocity_directions)


dict_keys(['adjacency_list', 'weight_list', 'node_features', 'vertices', 'world_pos', 'prev_world_pos', 'faces'])


In [24]:
plot_mesh(world_pos, vertices_interpolate_pos, true_arrow_matrix, interpolated_arrow_matrix, snapshot_index)
calculate_interpolation_metrics(true_fields, interpolated_fields)

Frobenious Norm: 0.13804070385932937
Cosine Similarity: 0.7912914905094018
