In [1]:
from functools import partial
from typing import Any, Literal

import numpy as np
from scipy.spatial.transform import Rotation
import vedo

from histalign.backend.models import Orientation

vedo.settings.default_backend = "vtk"

In [2]:
def compute_normal(
    pitch: int,
    yaw: int,
    orientation: Orientation,
) -> np.ndarray:
    match orientation:
        case Orientation.CORONAL:
            normal = [-1, 0, 0]
            rotation = Rotation.from_euler("ZY", [pitch, yaw], degrees=True)
        case Orientation.HORIZONTAL:
            normal = [0, 1, 0]
            rotation = Rotation.from_euler("ZX", [pitch, yaw], degrees=True)
        case Orientation.SAGITTAL:
            normal = [0, 0, 1]
            rotation = Rotation.from_euler("XY", [pitch, yaw], degrees=True)
        case _:
            raise Exception("ASSERT NOT REACHED")

    return rotation.apply(normal)


def generate_grid_points(
    grid_size: int,
    spread: int = 1,
    radius: int = 5,
) -> vedo.Points:
    x = np.linspace(0, (grid_size - 1) * spread, grid_size, dtype=int)
    y = np.linspace(0, (grid_size - 1) * spread, grid_size, dtype=int)
    z = np.linspace(0, (grid_size - 1) * spread, grid_size, dtype=int)

    mesh_grid = np.meshgrid(x, y, z, indexing="ij")
    mesh_grid = np.vstack(mesh_grid)
    mesh_grid = mesh_grid.reshape(3, -1).T

    return vedo.Points(mesh_grid, r=radius)


def snap_to_grid(point: np.ndarray, spread: int) -> np.ndarray:
    base, remainder = np.divmod(np.array(point), spread)
    return (base + np.round(remainder / spread)) * spread


def adjust_angle(value: int, angle: Literal["pitch", "yaw"]) -> None:
    match angle:
        case "pitch":
            global pitch
            pitch = value
        case "yaw":
            global yaw
            yaw = value
        case _:
            raise Exception("ASSERT NOT REACHED")


def get_camera(
    orientation: Orientation,
    grid_size: int,
    spread: int,
) -> dict[str, Any]:
    camera = dict(
        position=np.array((grid_size,) * 3) // 2 * spread,
        focal_point=np.array((grid_size,) * 3) // 2 * spread,
    )
    match orientation:
        case Orientation.CORONAL:
            camera["viewup"] = (0, -1, 0)
            camera["position"][0] += grid_size * 3 * spread
        case Orientation.HORIZONTAL:
            camera["viewup"] = (-1, 0, 0)
            camera["position"][1] += grid_size * 3 * spread
        case Orientation.SAGITTAL:
            camera["viewup"] = (0, -1, 0)
            camera["position"][2] -= grid_size * 3 * spread
        case _:
            raise Exception("ASSERT NOT REACHED")

    return camera


def show(
    volumes: vedo.CommonVisual | list[vedo.CommonVisual],
    camera: dict[str, Any] | None = None,
) -> None:
    if isinstance(volumes, vedo.CommonVisual):
        volumes = [volumes]

    plotter = vedo.Plotter(axes=3, interactive=True)
    for volume in volumes:
        plotter.add(volume)

    plotter.show(camera=camera)

In [34]:
grid_size = 1000
spread = 1

stack_size = 5
stack_spacing = 1

pitch = 6
yaw = 4
orientation = Orientation.CORONAL
camera_orientation = Orientation.SAGITTAL

flat_normal = compute_normal(0, 0, orientation)

In [35]:
plane_origin = np.array([grid_size // 2 * spread] * 3)
plane_normal = compute_normal(pitch, yaw, orientation)
plane_normal_mesh = vedo.Arrow(
    start_pt=plane_origin, end_pt=plane_origin + plane_normal * spread
)
plane_mesh = vedo.Plane(pos=plane_origin, normal=plane_normal, s=(5 * spread,) * 2)

plane_normal_line_point0 = vedo.Plane(
    pos=(flat_normal * spread * 2),
    normal=flat_normal,
    s=(1_000_000,) * 2,
).intersect_with_line(
    plane_origin - 100000 * plane_normal, plane_origin + 100000 * plane_normal
)
plane_normal_line_point1 = vedo.Plane(
    pos=(grid_size * spread,) * 3 + (-flat_normal * spread * 2),
    normal=flat_normal,
    s=((spread * grid_size * 2) ** 2,) * 2,
).intersect_with_line(
    plane_origin - 100000 * plane_normal, plane_origin + 100000 * plane_normal
)
plane_normal_line = vedo.Line(
    p0=plane_normal_line_point0.tolist()[0],
    p1=plane_normal_line_point1.tolist()[0],
    lw=spread // 2,
)

stack_meshes = [
    vedo.Plane(
        pos=plane_origin + i * plane_normal * (stack_spacing * spread),
        normal=plane_normal,
        s=(5 * spread,) * 2,
    )
    for i in range(-stack_size // 2 + 1, stack_size // 2 + 1)
]

intersections = [
    vedo.Plane(
        pos=i * spread * np.abs(flat_normal),
        normal=flat_normal,
        s=(1_000_000,) * 2,
    ).intersect_with_line(
        np.squeeze(plane_normal_line_point0),
        np.squeeze(plane_normal_line_point1),
    )
    for i in range(grid_size)
]

grid_snap_points = [snap_to_grid(point, spread) for point in intersections]

In [33]:
intersections[:10]

[array([[1.73727699e-12, 4.47447900e+03, 5.35156006e+03]]),
 array([[  10.        , 4475.53027344, 5350.85693359]]),
 array([[  20.        , 4476.58105469, 5350.15380859]]),
 array([[  30.        , 4477.63232422, 5349.45068359]]),
 array([[  40.        , 4478.68310547, 5348.74755859]]),
 array([[  50.        , 4479.734375  , 5348.04443359]]),
 array([[  60.        , 4480.78515625, 5347.34130859]]),
 array([[  70.        , 4481.83642578, 5346.63818359]]),
 array([[  80.        , 4482.88720703, 5345.93505859]]),
 array([[  90.        , 4483.93847656, 5345.23193359]])]

In [32]:
show(
    [
        # generate_grid_points(grid_size, spread=spread, radius=spread // 2),
        plane_mesh,
        stack_meshes,
        plane_normal_mesh,
        plane_normal_line,
        [vedo.Point(point.tolist()[0]) for point in intersections],
        [vedo.Point(point.tolist()[0], c="blue") for point in grid_snap_points],
    ],
    camera=get_camera(camera_orientation, grid_size, spread),
)