In [None]:
import os
import sys

import numpy as np

sys.path.append('../Packages')
import util.tensors as tensors
import data.convert as convert
import algo.geodesic as geo
import algo.euler as euler

from matplotlib import pyplot as plt
from disp.vis import *
%matplotlib widget

In [None]:
# --- configuration ---------------------------------------------------------
name = 'braid'          # experiment identifier, used in every file path below
input_dir = '../Brains'
output_dir = f'../Checkpoints/{name}'
epoch = 1000            # selects which learned-tensor checkpoint file is read

# --- inputs ----------------------------------------------------------------
# NOTE(review): each file is permuted with a different axis order
# (mask (1, 0), tensors (2, 1, 0), vector fields (2, 0, 1)); presumably this
# reflects how each file was written — confirm the spatial axes line up.
mask = (
    convert.read_nhdr(f'{input_dir}/{name}/{name}_filt_mask.nhdr')
    .double()
    .permute(1, 0)
    .numpy()
)
tensor_pred_lin = (
    convert.read_nhdr(f'{input_dir}/{name}/{name}_learned_tensors_{epoch}.nhdr')
    .permute(2, 1, 0)
    .numpy()
)

# Reference vector fields (sin, cos): their integral curves are compared
# against the geodesics traced further down.
vector_lin1 = (
    convert.read_nhdr(f'{input_dir}/sin/sin_vector_field.nhdr')
    .permute(2, 0, 1)
    .numpy()
)
vector_lin2 = (
    convert.read_nhdr(f'{input_dir}/cos/cos_vector_field.nhdr')
    .permute(2, 0, 1)
    .numpy()
)

# Invert the learned tensors pointwise (linear form -> matrix -> inverse ->
# linear form); metric_pred_lin is the field drawn as background glyphs.
tensor_pred_mat = tensors.lin2mat(tensor_pred_lin)
metric_pred_mat = np.linalg.inv(tensor_pred_mat)
metric_pred_lin = tensors.mat2lin(metric_pred_mat)

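## Geodesic plotting

Trace geodesics of the learned tensor field from two seed points. Overlay them on the Euler integral curves of the `sin` and `cos` vector fields, started from the same seeds.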

In [None]:
# Step sizes for the geodesic solver and the Euler integrator.
geo_delta_t = 5e-2
euler_delta_t = 5e-2
# Iteration budgets for each integrator.
geo_iters = 60_000
euler_iters = 60_000

In [None]:
# Seed for the first curve pair; the initial velocity is the sin vector
# field sampled at the seed's grid location.
start_coords = np.array([60, 29])
seed_row, seed_col = start_coords
init_velocities = vector_lin1[:, seed_row, seed_col]

# Geodesic under the learned tensor field, integrated with both_directions=True.
geox_pred1, geoy_pred1 = geo.geodesicpath(
    'f', tensor_pred_lin, vector_lin1, mask,
    start_coords, init_velocities,
    geo_delta_t,
    iter_num=geo_iters,
    both_directions=True,
)

# Reference: Euler integral curve of the same vector field from the same seed.
eulx1, euly1 = euler.eulerpath_vectbase_2d_w_dv(
    vector_lin1, mask, start_coords, euler_delta_t,
    iter_num=euler_iters,
    both_directions=True,
)

In [None]:
# Seed for the second curve pair; the initial velocity is the cos vector
# field sampled at the seed's grid location.
start_coords = np.array([60, 70])
seed_row, seed_col = start_coords
init_velocities = vector_lin2[:, seed_row, seed_col]

# Geodesic under the learned tensor field, integrated with both_directions=True.
geox_pred2, geoy_pred2 = geo.geodesicpath(
    'f', tensor_pred_lin, vector_lin2, mask,
    start_coords, init_velocities,
    geo_delta_t,
    iter_num=geo_iters,
    both_directions=True,
)

# Reference: Euler integral curve of the same vector field from the same seed.
eulx2, euly2 = euler.eulerpath_vectbase_2d_w_dv(
    vector_lin2, mask, start_coords, euler_delta_t,
    iter_num=euler_iters,
    both_directions=True,
)

In [None]:
# Background: the (inverted) learned tensor field as glyphs, zeroed outside
# the mask by multiplying every component by it.
tens_fig = vis_tensors(metric_pred_lin * np.stack((mask, mask, mask), 0), '', False, scale=8e-1, opacity=0.3,
                       show_axis_labels=False)

# Reference integral curves first, then the geodesics, in the same draw order
# as before (euler 1, euler 2, geodesic 1, geodesic 2).
for curve_x, curve_y in ((eulx1, euly1), (eulx2, euly2)):
    vis_path(curve_x, curve_y, tens_fig, "integral curve on vector field", 'black', 2, 1, False, show_legend=False)
for curve_x, curve_y in ((geox_pred1, geoy_pred1), (geox_pred2, geoy_pred2)):
    vis_path(curve_x, curve_y, tens_fig, f"geodesic on learned {name}", '#0082fb', 10, 1, False, show_legend=False)
plt.axis('off')

# Star markers at the two seed points. These are hardcoded and must be kept
# in sync with the start_coords values used in the two tracing cells.
seed_points = [(60, 29), (60, 70)]
seed_x, seed_y = zip(*seed_points)
plt.plot(seed_x, seed_y, linestyle='', marker='*', color='black', markersize=12)

# Set to True to write the figure into output_dir (created if missing).
# The cell intentionally does not end on an expression, so no Line2D repr is
# echoed as cell output.
SAVE_FIGURE = False
if SAVE_FIGURE:
    os.makedirs(output_dir, exist_ok=True)
    plt.savefig(f'{output_dir}/{name}_{epoch}_for_nips.png', bbox_inches='tight', dpi=300)