# Imports

In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
import scipy
import numpy as np
import torch
import torch.nn as nn
import pyvista as pv
import pyvistaqt as pvqt
from tqdm import tqdm
import matplotlib.pyplot as plt

import kaolin
import vtk
import matplotlib
colors = matplotlib.colormaps['Set2'].colors

root_dir = 'C:/Users/danpa/OneDrive/Documents/research_code/miccai2025_diffeo'
if root_dir not in sys.path:
    sys.path.append(root_dir)
import deform_diffeo_lib as ddlib

data_dir = os.path.join(root_dir, 'data')

ModuleNotFoundError: No module named 'pyvista'

In [None]:
template_pv = pv.read(os.path.join(data_dir, 'template_test.vtk'))

verts_init_np = template_pv.points
verts_deformed_true_np = np.load(os.path.join(data_dir, 'simulation_results/test/deformed_verts_[id].npy'))

deformed_true_pv = template_pv.copy()
deformed_true_pv.points = verts_deformed_true_np

# plotter = pvqt.BackgroundPlotter()
# _ = plotter.add_mesh(template_pv, show_edges=True, color=colors[0])
# _ = plotter.add_mesh(deformed_true_pv, show_edges=True, color=colors[1])

In [None]:
device = 'cuda'

verts_init = torch.tensor(verts_init_np, dtype=torch.get_default_dtype(), device=device)[None]
verts_deformed_true = torch.tensor(verts_deformed_true_np, dtype=torch.get_default_dtype(), device=device)[None]

# FourierNetwork

### conditions

1. scaling_and_squaring_pointwise
    - optimization worked well, and inverse is near voxelgrid forward Euler
2. diffeomorphic_forward_euler_pointwise
    - optimization worked well, inverse is best

In [None]:
exp_num = 2
n_steps = 10 # only applies to foward euler

In [None]:
img_shape = np.array([120,120,180])
model_v = ddlib.models.FourierNetwork(input_dim=3, output_dim=3, n_random_freqs=128, hidden_dim=128, num_layers=5, freq_stdev=1e-2, img_shape=img_shape)
model_v.to(device)

optimizer_v = torch.optim.Adam(model_v.parameters(), lr=1e-4)

In [None]:
plot_pv = template_pv.copy()

if exp_num == 1:
    displacements = ddlib.ops.scaling_and_squaring_pointwise(model_v, verts_init)
elif exp_num == 2:
    displacements = ddlib.ops.forward_euler_pointwise(model_v, verts_init)

verts_deformed = verts_init + displacements
plot_pv.points = verts_deformed.detach().squeeze().cpu().numpy()

plotter = pvqt.BackgroundPlotter()
_ = plotter.add_mesh(plot_pv, show_edges=True, color=colors[0], name='pred_pv')
_ = plotter.add_mesh(deformed_true_pv, show_edges=True, color=colors[1], opacity=0.3)

In [None]:
n_epochs = 300
pbar = tqdm(range(n_epochs))
for epoch in pbar:
    if exp_num == 1:
        displacements = ddlib.ops.scaling_and_squaring_pointwise(model_v, verts_init)
    elif exp_num == 2:
        displacements = ddlib.ops.forward_euler_pointwise(model_v, verts_init)

    verts_deformed = verts_init + displacements

    loss = (verts_deformed - verts_deformed_true).norm(dim=-1).mean()

    optimizer_v.zero_grad()
    loss.backward()
    optimizer_v.step()

    if epoch % 1 == 0:
        plot_pv.points = verts_deformed.detach().squeeze().cpu().numpy()
        _ = plotter.add_points(np.zeros(3), name='dummy', reset_camera=False)

    pbar.set_postfix_str('loss: {:.3f}'.format(loss.detach().item()))

100%|██████████| 300/300 [00:11<00:00, 25.63it/s, loss: 0.155]


##### check inverse

In [None]:
if exp_num == 1:
    displacements_inv = ddlib.ops.scaling_and_squaring_pointwise(model_v, verts_deformed, reverse_field=True)
elif exp_num == 2:
    displacements_inv = ddlib.ops.forward_euler_pointwise(model_v, verts_deformed, n_steps=n_steps, reverse_field=True)

verts_inversed = verts_deformed + displacements_inv

plot_pv = template_pv.copy()
plot_pv.points = verts_inversed.detach().squeeze().cpu().numpy()

plotter = pvqt.BackgroundPlotter()
actor_plot = plotter.add_mesh(plot_pv, show_edges=True, color=colors[0])
_ = plotter.add_mesh(deformed_true_pv, show_edges=True, color=colors[1], opacity=0.3)
_ = plotter.add_mesh(template_pv, show_edges=True, color=colors[2], opacity=1)
_ = plotter.add_checkbox_button_widget(lambda flag: actor_plot.SetVisibility(flag), value=True, color_on=colors[0])

print(ddlib.utils.compare_tensors(verts_init, verts_inversed))

(False, tensor(5.8188, device='cuda:0', grad_fn=<MaxBackward1>))
