In [1]:
import pyvista as pv
import numpy as np

import torch
from torch_dvf.data import Data
from torch_dvf.transforms import SkeletonPointCloudHierarchy, RadiusPointCloudHierarchy, GeodesicPointCloudHierarchy

# Seed torch's global RNG so any torch-based random sampling further down
# repeats across runs. The trailing semicolon keeps the returned Generator
# repr out of the cell output.
torch.manual_seed(0);

<torch._C.Generator at 0x28e6473b0>

In [6]:
def make_points():
    """Return a (1000, 3) array of XYZ samples along a spiral.

    The angle sweeps from -2*pi to 2*pi while z runs from -2 to 2; the
    distance from the z axis at each sample is ``z**2 + 1``.
    """
    n_samples = 1000
    angle = np.linspace(-2 * np.pi, 2 * np.pi, n_samples)
    height = np.linspace(-2, 2, n_samples)
    radius = height**2 + 1
    return np.column_stack(
        (radius * np.sin(angle), radius * np.cos(angle), height)
    )

# Sample the spiral once; these XYZ points are turned into a polyline below.
points = make_points()

def polyline_from_points(points):
    """Wrap ``points`` in a PolyData whose single line cell visits them in order.

    Args:
        points: sequence of XYZ coordinates assigned to the mesh as-is.

    Returns:
        A ``pv.PolyData`` holding the points and one line cell through all of them.
    """
    n_points = len(points)
    line_mesh = pv.PolyData()
    line_mesh.points = points
    # Connectivity layout: the point count first, then the point ids 0 .. n_points - 1.
    line_mesh.lines = np.insert(np.arange(0, n_points, dtype=np.int_), 0, n_points)
    return line_mesh

# Connect the sampled points into one polyline and attach a per-point index array.
polyline = polyline_from_points(points)
# NOTE(review): np.arange yields an integer array here, and it travels with every
# polyline.copy() handed to the plotter; this may be related to the
# "Cannot mix BigInt and other types" JS error in the viewer output — confirm.
polyline['stuff'] = np.arange(polyline.n_points)
# Build a tube around the polyline, passing the index array as the tube's scalars.
tube = polyline.tube(scalars='stuff', radius_factor=2)

In [3]:
# Point cloud made of the tube surface vertices, converted to float32.
data = Data(pos=torch.tensor(tube.points).float())

# Radius-based hierarchy. Both tuples have five entries, one per scale read by
# the plotting cell (scale0 .. scale4).
# NOTE(review): presumably the first tuple holds per-scale sampling ratios and
# the second per-scale radii — confirm against the torch_dvf documentation.
data_radius = RadiusPointCloudHierarchy(
    (0.2, 0.5, 0.5, 0.5, 0.5),
    (3.0, 4.0, 5.0, 6.0, 7.0),
    interp_simplex="triangle",
    max_num_neighbors=1024,
)(data)

In [4]:
# Point cloud of the tube surface, paired with the polyline as its skeleton.
data = Data(
    pos=torch.tensor(tube.points).float(),
    skeleton_pos=torch.tensor(polyline.points).float(),
    # Join consecutive skeleton points: row 0 holds i, row 1 holds i + 1.
    # Built with arange/stack instead of a Python list of tuples, so the result
    # is a contiguous int64 tensor of shape (2, n - 1) — including (2, 0) when
    # the skeleton has fewer than two points.
    skeleton_edge_index=torch.stack(
        (
            torch.arange(0, polyline.n_points - 1),
            torch.arange(1, polyline.n_points),
        )
    ),
)

# Skeleton-based hierarchy. Both tuples have five entries, one per scale read
# by the plotting cell (scale0 .. scale4).
# NOTE(review): presumably the first tuple holds per-scale sampling ratios and
# the second per-scale distances along/around the skeleton — confirm against
# the torch_dvf documentation.
data_skeleton = SkeletonPointCloudHierarchy(
    (0.2, 0.5, 0.5, 0.5, 0.5),
    (2.0, 5.0, 10.0, 15.0, 20.0),
    interp_simplex="triangle",
)(data)

In [5]:
num_scales = 5
sampling_index = 0
pl = pv.Plotter(shape=(2, num_scales))


def add_pooling_row(plotter, row, hierarchy, skeleton, num_scales, sampling_index):
    """Fill one plotter row with a subplot per scale highlighting one pooling group.

    For every scale ``i`` the current point cloud is coloured by a 0/1 array
    that marks the entries of ``scale{i}_pool_source`` whose
    ``scale{i}_pool_target`` equals the target of entry ``sampling_index``.
    The point selected through ``scale{i}_sampling_index`` for that target is
    drawn as a sphere and ``skeleton`` is drawn in red. Afterwards the cloud is
    reduced to ``points[scale{i}_sampling_index]`` for the next scale.

    Working inside a function keeps the notebook-level names ``data`` and
    ``points`` from being rebound by the loop.

    Args:
        plotter: ``pv.Plotter`` with at least ``num_scales`` columns.
        row: subplot row to draw into.
        hierarchy: object exposing ``pos`` and the ``scale{i}_*`` index tensors.
        skeleton: mesh drawn in red in every subplot (copied before adding).
        num_scales: number of scales (subplot columns) to draw.
        sampling_index: position in ``scale{i}_pool_target`` whose target is shown.
    """
    level_points = hierarchy.pos

    for scale in range(num_scales):
        pool_target = hierarchy[f"scale{scale}_pool_target"]
        pool_source = hierarchy[f"scale{scale}_pool_source"]
        kept = hierarchy[f"scale{scale}_sampling_index"]

        # Target of the chosen pooling entry, and the point it maps to.
        target = pool_target[sampling_index]
        pooling_point = level_points[kept[target]]
        # All source entries that share that target.
        members = pool_source[pool_target == target]

        # Build the 0/1 mask first, then attach it to the mesh in one assignment.
        membership = np.zeros(len(level_points))
        membership[members.numpy()] = 1
        cloud = pv.PolyData(level_points.numpy())
        cloud[f"scale{scale}"] = membership

        plotter.subplot(row, scale)
        plotter.add_mesh(cloud.copy(), opacity=0.5, scalars=f"scale{scale}")
        plotter.add_mesh(skeleton.copy(), color="red")
        plotter.add_mesh(pv.PolyData(pooling_point.numpy()), render_points_as_spheres=True, point_size=10)

        # The next scale operates on the points kept at this scale.
        level_points = level_points[kept]


# Row 0: radius-based hierarchy; row 1: skeleton-based hierarchy.
for row, hierarchy in enumerate([data_radius, data_skeleton]):
    add_pooling_row(pl, row, hierarchy, polyline, num_scales, sampling_index)

pl.link_views()
pl.show()

Widget(value='<iframe src="http://localhost:55104/index.html?ui=P_0x2fbc0cf90_1&reconnect=auto" class="pyvista…

 JS Error => error: Uncaught TypeError: Cannot mix BigInt and other types, use explicit conversions
