In [1]:
import pyvista as pv
import numpy as np

import torch
from torch_dvf.data import Data
from torch_dvf.transforms import SkeletonPointCloudHierarchy, RadiusPointCloudHierarchy, GeodesicPointCloudHierarchy

# Seed torch's global RNG so any stochastic step below is reproducible; the trailing ';'
# suppresses the returned Generator repr in the cell output.
torch.manual_seed(0);

<torch._C.Generator at 0x2a8f473b0>

In [2]:
def make_points(n_points: int = 1000) -> np.ndarray:
    """Sample XYZ points along a spiral whose radius widens away from z = 0.

    theta sweeps [-2*pi, 2*pi] (two full turns) while z rises linearly from -2 to 2,
    and the radius follows r = z**2 + 1 (r = 1 at z = 0, r = 5 at z = +/-2).

    Args:
        n_points: Number of samples along the curve. Defaults to 1000, the original
            resolution; lower it for faster experiments.

    Returns:
        Array of shape (n_points, 3) with columns x, y, z.
    """
    theta = np.linspace(-2 * np.pi, 2 * np.pi, n_points)
    z = np.linspace(-2, 2, n_points)
    r = z**2 + 1
    x = r * np.sin(theta)
    y = r * np.cos(theta)
    return np.column_stack((x, y, z))

points = make_points()

def polyline_from_points(points):
    """Return a PolyData whose ``points`` are joined, in order, by a single polyline cell.

    The cell array is ``[n, 0, 1, ..., n - 1]``: the point count followed by the index
    of every point.
    """
    count = len(points)
    connectivity = np.arange(-1, count, dtype=np.int_)
    connectivity[0] = count  # overwrite the -1 placeholder with the cell's point count
    poly = pv.PolyData()
    poly.points = points
    poly.lines = connectivity
    return poly

polyline = polyline_from_points(points)
# Per-point scalar along the curve; tube() scales the tube radius with it (radius_factor=2).
# Float dtype on purpose: NumPy's default int64 point data is a likely trigger for the vtk.js
# "Cannot mix BigInt and other types" JS error this plot raised — TODO confirm. The values
# (0, 1, ..., n-1) are unchanged, so the radius variation is unchanged too.
polyline['stuff'] = np.arange(polyline.n_points, dtype=float)
tube = polyline.tube(scalars='stuff', radius_factor=2)

pl = pv.Plotter()
pl.add_mesh(tube, opacity=0.5)
pl.add_mesh(polyline, color="red")
pl.show()

Widget(value='<iframe src="http://localhost:63864/index.html?ui=P_0x2f37b72d0_0&reconnect=auto" class="pyvista…

 JS Error => error: Uncaught TypeError: Cannot mix BigInt and other types, use explicit conversions


In [6]:
# Radius-based pooling hierarchy over the tube's mesh points.
# NOTE(review): the two tuples presumably give per-scale sampling ratios and pooling radii —
# confirm against torch_dvf's RadiusPointCloudHierarchy signature.
radius_hierarchy = RadiusPointCloudHierarchy(
    (0.5, 0.5, 0.5),
    (3.0, 4.0, 5.0),
    interp_simplex="triangle",
    max_num_neighbors=1024,
)
data = Data(pos=torch.tensor(tube.points).float())
data_radius = radius_hierarchy(data)

In [4]:
# Geodesic pooling hierarchy over the tube's mesh points.
# NOTE(review): data_geodesic is not used by the comparison plot (its entry is commented
# out there) — confirm it is still needed. Its first radius (2.0) differs from the others (3.0).
geodesic_hierarchy = GeodesicPointCloudHierarchy(
    (0.5, 0.5, 0.5),
    (2.0, 3.0, 5.0),
    interp_simplex="triangle",
)
data = Data(pos=torch.tensor(tube.points).float())
data_geodesic = geodesic_hierarchy(data)

In [7]:
# Skeleton-based pooling hierarchy: the tube's mesh points plus the polyline the tube was
# built around, given as a chain graph over its vertices.
skeleton_nodes = torch.arange(polyline.n_points)  # int64, same dtype as the old `.long()`
data = Data(
    pos=torch.tensor(tube.points).float(),
    skeleton_pos=torch.tensor(polyline.points).float(),
    # Edges i -> i + 1 between consecutive vertices, shape (2, n_points - 1). Built with
    # tensor ops instead of a Python loop, and stays 2-D ((2, 0)) for a polyline with
    # fewer than two points.
    skeleton_edge_index=torch.stack((skeleton_nodes[:-1], skeleton_nodes[1:])),
)
data_skeleton = SkeletonPointCloudHierarchy((0.5, 0.5, 0.5), (3.0, 4.0, 5.0), interp_simplex="triangle", max_num_neighbors=1024)(data)

In [8]:
# Pooling hierarchies to compare side by side, one subplot each.
datas = {
    "pool_skeleton": data_skeleton,
    # NOTE(review): data_geodesic is computed above but left out of this comparison — confirm intent.
    # "pool_geodesic": data_geodesic,
    "pool_radius": data_radius,
}

sampling_index = 0  # which scale-0 pooling edge to inspect
# Radius of the translucent sphere drawn around the pooling point; equals the first radius
# passed to the radius/skeleton hierarchies above (the geodesic one starts at 2.0).
POOL_SPHERE_RADIUS = 3.0

# Grid sized from `datas`, so toggling an entry never leaves an empty subplot.
pl = pv.Plotter(shape=(1, len(datas)))

# Loop variable is `hierarchy`, not `data`, so the input `data` built above is not clobbered.
for i, (name, hierarchy) in enumerate(datas.items()):
    pool_target = hierarchy["scale0_pool_target"]
    print(f"{name} pooling edges: {pool_target.shape}")

    # Target of the chosen edge, mapped through scale0_sampling_index to a row of `pos`.
    index = pool_target[sampling_index]
    pooling_point = hierarchy.pos[hierarchy["scale0_sampling_index"][index]]
    # Source indices of every edge that points at that same target.
    pool = hierarchy["scale0_pool_source"][pool_target == index]

    # 0/1 indicator over the tube's points. It is filled before being attached, rather than
    # written through the array object returned by `tube[name]`.
    membership = np.zeros(tube.n_points)
    membership[pool.numpy()] = 1
    tube[name] = membership

    pl.subplot(0, i)
    pl.add_text(name, font_size=10)
    pl.add_mesh(tube.copy(), opacity=0.5, scalars=name)
    pl.add_mesh(polyline.copy(), color="red")
    pl.add_mesh(pv.PolyData(pooling_point.numpy()), render_points_as_spheres=True, point_size=10)
    pl.add_mesh(pv.Sphere(radius=POOL_SPHERE_RADIUS, center=pooling_point.numpy()), opacity=0.1)

pl.link_views()
pl.show()

pool_skeleton pooling edges: torch.Size([10260320])
pool_radius pooling edges: torch.Size([10253931])


Widget(value='<iframe src="http://localhost:63864/index.html?ui=P_0x2f1d36650_2&reconnect=auto" class="pyvista…