# Math foundation

## Ray - surface intersection in 3D

### Definitions

A 3D ray is represented parametrically as the set of points $P + tV$ with $t \in \mathbb{R}$, where:

* $P$ is an origin point
* $V$ is a unit direction vector

A 3D surface is defined implicitly by a function $F$, as the set of points satisfying $F(x,y,z) = 0$.

To compute the intersection of a ray with a surface, we look for the unknown $t$ such that the point $P + tV$, which is on the ray, is also on the surface:

$$
F(P + tV) = 0
$$

We call the left-hand side $Q(t)$. Finding the intersections means finding the roots of $Q$.
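As a concrete case, the sketch below builds $Q(t) = F(P + tV)$ for a sphere of radius 2 centered at the origin. It uses plain Python rather than the PyTorch tensors used later in this notebook. For a sphere, $Q$ is quadratic in $t$, so its roots are also available in closed form. The ray in this example has two roots: one where it enters the sphere and one where it exits. `make_Q` and `F_sphere` are hypothetical helpers for this example only.

```python
import math

def make_Q(F, P, V):
    """Return Q(t) = F(P + tV) for a ray with origin P and direction V (3-tuples)."""
    def Q(t):
        x, y, z = (p + t * v for p, v in zip(P, V))
        return F(x, y, z)
    return Q

# Illustrative surface: sphere of radius 2 centered at the origin
def F_sphere(x, y, z):
    return x**2 + y**2 + z**2 - 4.0

P = (-5.0, 0.0, 0.0)  # ray origin
V = (1.0, 0.0, 0.0)   # unit direction along +X
Q = make_Q(F_sphere, P, V)

# For a sphere and |V| = 1: Q(t) = t^2 + 2 (P.V) t + |P|^2 - R^2
b = 2 * sum(p * v for p, v in zip(P, V))
c = sum(p * p for p in P) - 4.0
disc = b * b - 4 * c
roots = sorted(((-b - math.sqrt(disc)) / 2, (-b + math.sqrt(disc)) / 2))
print(roots)                  # [3.0, 7.0]
print([Q(t) for t in roots])  # [0.0, 0.0]
```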

### Newton's method

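In general there may be multiple solutions, or none. In practice we solve the equation iteratively from an initial guess. The iteration converges to at most one root, typically one close to the guess, so multiple solutions are less of an issue. The no-solution case still has to be detected separately.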

We can solve the intersection equation using Newton's method with the update step:

$$
t_{n+1} = t_n - \frac{Q(t_n)}{Q'(t_n)}
$$

Since $Q(t) = F(P + tV)$, the multivariate chain rule gives $Q'(t) = V \cdot \nabla F(P + tV)$. The update step therefore becomes:

$$
t_{n+1} = t_n - \frac{F(P + t_nV)}{V \cdot \nabla F(P + t_n V)}
$$

where "$\cdot$" in the denominator is the dot product, and $\nabla F$ is the gradient of $F$:

$$
\nabla F = \left(\frac{\partial F}{\partial x}, \frac{\partial F}{\partial y}, \frac{\partial F}{\partial z}\right)
$$

A 3D surface must define functions to compute $F$ and $\nabla F$.

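$Q'$ is zero, and the update step undefined, in two cases. The first is when the ray direction $V$ is orthogonal to $\nabla F$, i.e. when the ray is tangent to the surface at the current point. The second is when $\nabla F$ itself is zero, which should never be the case for a well-behaved surface.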
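The sketch below is a minimal scalar version of this Newton iteration. It assumes the paraboloid $F(x,y,z) = a(y^2+z^2) - x$ with an illustrative $a = 0.05$ and a single ray. The denominator is the chain-rule derivative $V \cdot \nabla F$. The initial guess is the intersection with the $X = 0$ plane. The function returns `nan` when the denominator is exactly zero, which happens when the ray is tangent to the surface. `newton_ray` and its `iters`/`tol` parameters are hypothetical and are not part of the notebook's API.

```python
import math

a = 0.05  # illustrative paraboloid coefficient

def F(p):
    x, y, z = p
    return a * (y**2 + z**2) - x

def grad_F(p):
    x, y, z = p
    return (-1.0, 2 * a * y, 2 * a * z)

def dot(u, v):
    return sum(ui * vi for ui, vi in zip(u, v))

def point_on_ray(P, V, t):
    return tuple(p + t * v for p, v in zip(P, V))

def newton_ray(P, V, t0, iters=20, tol=1e-12):
    """Solve F(P + tV) = 0 with the update t <- t - F / (V . grad F)."""
    t = t0
    for _ in range(iters):
        X = point_on_ray(P, V, t)
        denom = dot(V, grad_F(X))  # Q'(t) by the chain rule
        if denom == 0.0:
            return math.nan        # ray tangent to the surface: step undefined
        delta = F(X) / denom
        t -= delta
        if abs(delta) < tol:
            break
    return t

P = (-5.0, 1.0, 0.0)
n = math.hypot(1.0, 0.5)
V = (1.0 / n, 0.5 / n, 0.0)  # unit direction

t0 = -P[0] / V[0]  # initial guess: intersection with the X = 0 plane
t = newton_ray(P, V, t0)
print(t, F(point_on_ray(P, V, t)))  # residual is approximately 0
```

In this example $Q$ is quadratic with two roots. Starting from the $X = 0$ plane guess, the iteration lands on the nearer one. A different initial guess could converge to the other root, or fail to converge when the ray misses the surface.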

## Surface of revolution

### 2D shape definition

A 2D shape is defined in the $(x, r)$ plane with the implicit equation:
$$
f(x, r) = 0
$$

X is the optical axis, and R is the meridional axis, i.e. the axis perpendicular to the optical axis.

Additionally, for working with lenses, we always have $f(0, 0) = 0$ (the curve passes through the origin) and $f'_r(0, 0) = 0$ (the curve is vertical at the origin, i.e. tangent to the R axis).
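The check below applies the two conditions to two example profiles written in this implicit form. The first is a parabola $x = a r^2$. The second is a circular arc of curvature $K$ through the origin. The parameter values and the finite-difference helper `partial_r` are assumptions chosen for this sketch.

```python
import math

def f_parabola(x, r, a=0.05):
    # Parabola x = a r^2 in implicit form f(x, r) = 0 (a is illustrative)
    return a * r**2 - x

def f_arc(x, r, K=1 / 20):
    # Circular arc of curvature K (radius 1/K) through the origin (K is illustrative)
    return K * r**2 / (1 + math.sqrt(1 - (K * r) ** 2)) - x

def partial_r(f, x, r, h=1e-6):
    # Central finite difference approximation of the partial derivative along r
    return (f(x, r + h) - f(x, r - h)) / (2 * h)

for f in (f_parabola, f_arc):
    print(f.__name__, f(0.0, 0.0), partial_r(f, 0.0, 0.0))
```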

### Rotational symmetry

We want to create the corresponding 3D surface by rotation around the X axis. The R axis from before then becomes an axis perpendicular to X within a meridional plane. A meridional plane is any plane that contains the optical axis X.

So $r$ is the distance from a point to the X axis. In 3D, whichever meridional plane the point lies in, $r = \sqrt{y^2 + z^2}$.

The 3D surface of revolution is defined so that its intersection with every meridional plane is the 2D curve:

$$
F(x,y,z) = f \left(x,  \sqrt{y^2 + z^2} \right) = 0
$$

Often the form $f(x, \sqrt{y^2 + z^2})$ can be simplified analytically to provide an efficient implementation of $F$.
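Take the parabola profile $f(x, r) = a r^2 - x$ with an illustrative $a = 0.05$. Here the square root cancels out, giving $F(x,y,z) = a(y^2+z^2) - x$. The sketch below checks that the generic composition and the simplified form agree. The function names are hypothetical.

```python
import math

a = 0.05  # illustrative parabola coefficient

def f(x, r):
    return a * r**2 - x

def F_generic(x, y, z):
    # Surface of revolution: evaluate the 2D curve at r = sqrt(y^2 + z^2)
    return f(x, math.hypot(y, z))

def F_simplified(x, y, z):
    # Same surface after simplifying analytically: no square root needed
    return a * (y**2 + z**2) - x

pts = [(0.0, 0.0, 0.0), (1.25, 3.0, 4.0), (-2.0, 1.0, -7.0)]
for p in pts:
    print(p, F_generic(*p), F_simplified(*p))
```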

## Generic form of $\nabla F$ for surfaces of revolution

We have:

$$
F(x, y, z) = f \left(x,  \sqrt{y^2 + z^2} \right)
$$

Therefore:

$$
F'_x(x, y, z) = f'_x(x, \sqrt{y^2 + z^2})
$$
$$
F'_y(x,y,z) = \frac{y}{\sqrt{y^2 + z^2}} f_r' \left(x, \sqrt{y^2 + z^2} \right)
$$
$$
F'_z(x,y,z) = \frac{z}{\sqrt{y^2 + z^2}} f_r' \left(x, \sqrt{y^2 + z^2} \right)
$$

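For many curves $f$, however, this expression simplifies considerably. Shapes can therefore provide an optimized version of $\nabla F(x, y, z)$, or even compute $\nabla F(x,y,z) \cdot V$ directly. This also avoids the division by $\sqrt{y^2 + z^2}$, which is zero on the X axis.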
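The generic formulas above can be written once for any profile, given its partial derivatives $(f'_x, f'_r)$. The sketch below does this for the parabola $f(x, r) = a r^2 - x$ and compares the result with the analytic gradient $(-1, 2ay, 2az)$. On the X axis ($y = z = 0$), the factors $y/r$ and $z/r$ are undefined. The sketch assumes $f'_r(x, 0) = 0$ there. That holds for this parabola but should be checked for other profiles. All names and the value of $a$ are illustrative.

```python
import math

a = 0.05  # illustrative parabola coefficient

def f_grad(x, r):
    # (df/dx, df/dr) for the 2D curve f(x, r) = a r^2 - x
    return (-1.0, 2 * a * r)

def F_grad_generic(x, y, z):
    # Chain rule through r = sqrt(y^2 + z^2): dr/dy = y / r, dr/dz = z / r
    r = math.hypot(y, z)
    fx, fr = f_grad(x, r)
    if r == 0.0:
        # y/r and z/r are undefined on the X axis; assume f'_r(x, 0) = 0 there
        return (fx, 0.0, 0.0)
    return (fx, y / r * fr, z / r * fr)

def F_grad_parabola(x, y, z):
    # Analytic gradient of F = a (y^2 + z^2) - x
    return (-1.0, 2 * a * y, 2 * a * z)

print(F_grad_generic(0.3, 1.0, 2.0), F_grad_parabola(0.3, 1.0, 2.0))
```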



## Collision detection with a 3D transform

Surfaces are defined in a local reference frame such that $F(0, 0, 0) = 0$. But what if we want to apply a transform to move them in 3D space? Can we apply scaling, rotation and translation?

Let's assume our 3D transform $T$ is an invertible affine map. It maps input points $X$ to points $X' = T(X) = AX + B$, where $A$ is an invertible $3 \times 3$ matrix.

Consider a point $X'$ on the transformed surface. By definition, undoing the transform puts it back on the original surface:

$$
F(T^{-1}(X')) = F(A^{-1}(X' - B)) = 0
$$

Given a parametric 3D ray $P + tV$, finding its intersection with the transformed surface therefore means solving $F(A^{-1}(P + tV - B)) = 0$, which expands to:

$$
F( A^{-1}(P-B) + tA^{-1}V ) = 0
$$

This is useful because we can reuse the Newton solver from before. We apply the inverse transform to the rays, then use the $F$ function defined in the local frame:

$$
\begin{cases}
P' = A^{-1}(P - B)\\
V' = A^{-1}V
\end{cases}
$$

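In the common case, $A^{-1}$ can be computed without a general matrix inversion. This is because $A$ is the product of a rotation and a scaling, each of which is easy to invert. The inverse of a rotation matrix is its transpose. The inverse of a diagonal scaling matrix has the reciprocal scale factors on its diagonal. For $A = RS$, this gives $A^{-1} = S^{-1}R^{-1}$.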

Note that this applies even if the surface is not defined implicitly: we can find collisions with the transformed surface by applying the above inverse transform to the rays and calling the local collision detection code.
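The sketch below shows this approach in plain Python. It assumes an illustrative transform where $A$ is a rotation of 30 degrees about Z and $B = (10, 0, 0)$. The local surface is the $X = 0$ plane, solved exactly, so no Newton iteration is needed. The inverse-transformed direction $V' = A^{-1}V$ is not renormalized. As a result, the $t$ found in the local frame is directly valid for the global ray $P + tV$. The helper names are hypothetical.

```python
import math

def matvec(M, v):
    return tuple(sum(M[i][j] * v[j] for j in range(3)) for i in range(3))

def transpose(M):
    return [[M[j][i] for j in range(3)] for i in range(3)]

def rotation_z(theta):
    # Rotation matrix of angle theta (radians) about the Z axis
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]

# Illustrative transform X' = A X + B
A = rotation_z(math.radians(30))
A_inv = transpose(A)  # the inverse of a rotation is its transpose
B = (10.0, 0.0, 0.0)

def local_collide_plane(P, V):
    # Local surface: the X = 0 plane, F(x, y, z) = x, solved exactly
    return -P[0] / V[0]

def collide_transformed(P, V):
    # Inverse-transform the ray: P' = A^-1 (P - B), V' = A^-1 V
    P_local = matvec(A_inv, tuple(p - b for p, b in zip(P, B)))
    V_local = matvec(A_inv, V)
    t = local_collide_plane(P_local, V_local)
    # The same t gives the collision point on the global ray
    return t, tuple(p + t * v for p, v in zip(P, V))

P, V = (0.0, 1.0, 2.0), (1.0, 0.0, 0.0)
t, X = collide_transformed(P, V)
local_x = matvec(A_inv, tuple(x - b for x, b in zip(X, B)))[0]
print(t, X, local_x)  # local_x is ~0: X lies on the transformed plane
```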

We also need to convert vectors, typically surface normals, from the surface local frame to the global frame.

A vector $\overrightarrow{N}$ is the difference between its end point $E$ and start point $S$:

$$
\overrightarrow{N} = E - S
$$

So to transform the vector under the affine transformation, we can take the difference of its transformed endpoints:
$$
T(\overrightarrow{N}) = T(E) - T(S)
$$

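Since the translation $B$ cancels out in the difference, we get:
$$
T(\overrightarrow{N}) = (AE + B) - (AS + B) = A(E - S) = A\overrightarrow{N}
$$

This holds for any vector that is a difference of points, such as a tangent vector. A surface normal is a special case. To remain perpendicular to the transformed surface, it must be transformed by $(A^{-1})^T$ instead. This coincides with $A$ when $A$ is a rotation, and is proportional to $A$ when the scaling is uniform. Under non-uniform scaling, however, the two can point in different directions.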
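The check below verifies numerically that $T(E) - T(S) = A(E - S)$. It uses an illustrative affine transform that scales by 2 along X. The same example also shows a known caveat for surface normals under non-uniform scaling. Take a tangent vector and a normal of the plane $x + y = 0$. After transformation, $A\overrightarrow{n}$ is no longer perpendicular to the transformed tangent. By contrast, $(A^{-1})^T\overrightarrow{n}$ stays perpendicular. All values and names are assumptions for the sketch.

```python
def matvec(M, v):
    return tuple(sum(M[i][j] * v[j] for j in range(3)) for i in range(3))

def sub(u, v):
    return tuple(a - b for a, b in zip(u, v))

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

# Illustrative affine transform T(X) = A X + B with a non-uniform scale along X
A = [[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
A_inv_T = [[0.5, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]  # (A^-1)^T of this diagonal A
B = (1.0, -3.0, 4.0)

def transform_point(X):
    return tuple(a + b for a, b in zip(matvec(A, X), B))

# A vector as the difference of two points: B cancels out
E, S = (1.0, 2.0, 3.0), (0.5, -1.0, 2.0)
N = sub(E, S)
print(sub(transform_point(E), transform_point(S)), matvec(A, N))

# Plane x + y = 0: u is tangent to it, n is normal to it
u, n = (1.0, -1.0, 0.0), (1.0, 1.0, 0.0)
print(dot(matvec(A, u), matvec(A, n)))        # nonzero: A n is no longer a normal
print(dot(matvec(A, u), matvec(A_inv_T, n)))  # zero: (A^-1)^T n is still a normal
```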

## Adding anchors

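Similarly to the above, it can be useful to add a translation step before the scale and rotation, to model an "anchor". The anchor is the point on the shape that attaches to the global frame. Note that from here on, $A$ denotes the anchor point and $T$ the final translation, rather than the linear part and the whole transform as in the previous section. The full transform now has four steps: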

1. A translation $-A$ to account for the anchor
2. A scale $S$
3. A rotation $R$
4. A translation $T$ to position the shape in the global frame

$$
X' = RS(X - A) + T
$$

The inverse transform is:

$$
X = S^{-1}R^{-1}(X' - T) + A
$$

Let $X'$ be a point of the transformed surface that is also on the ray, i.e. $X' = P + tV$. Substituting this into the inverse transform gives the corresponding point in the surface local frame:

$$
X = S^{-1} R^{-1} (P-T) + A + t\, S^{-1}R^{-1}V
$$

And so we can compute "inverse transformed" rays:

$$
\begin{cases}
P' = S^{-1} R^{-1}(P-T) + A\\
V' = S^{-1}R^{-1}V
\end{cases}
$$

Vectors are transformed directly by the linear part $RS$ alone, since both translations cancel out:
$$
T(\overrightarrow{N}) = RS\overrightarrow{N}
$$
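The sketch below checks these formulas numerically using plain 3x3 lists. It uses an illustrative scale of 2 along X, a 90 degree rotation about Z, an anchor at $(2, 0, 0)$ and a position $(50, 5, 5)$. It verifies two things. First, the inverse transform undoes the direct one. Second, mapping a point $P' + tV'$ of the inverse-transformed ray back with the direct transform gives $P + tV$ for the same $t$. The values and helper names are assumptions for this example; this is not the notebook's `SurfaceTransform` class.

```python
import math

def matmul(M, N):
    return [[sum(M[i][k] * N[k][j] for k in range(3)) for j in range(3)] for i in range(3)]

def matvec(M, v):
    return tuple(sum(M[i][j] * v[j] for j in range(3)) for i in range(3))

def add(u, v):
    return tuple(a + b for a, b in zip(u, v))

def sub(u, v):
    return tuple(a - b for a, b in zip(u, v))

scale = 2.0
S = [[scale, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
S_inv = [[1.0 / scale, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]

c, s = math.cos(math.radians(90)), math.sin(math.radians(90))
R = [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]
R_inv = [[R[j][i] for j in range(3)] for i in range(3)]  # transpose

anchor = (2.0, 0.0, 0.0)
T = (50.0, 5.0, 5.0)

RS = matmul(R, S)
S_inv_R_inv = matmul(S_inv, R_inv)

def direct_points(X):
    # X' = R S (X - A) + T
    return add(matvec(RS, sub(X, anchor)), T)

def inverse_points(Xp):
    # X = S^-1 R^-1 (X' - T) + A
    return add(matvec(S_inv_R_inv, sub(Xp, T)), anchor)

def inverse_rays(P, V):
    # P' = S^-1 R^-1 (P - T) + A,  V' = S^-1 R^-1 V
    return inverse_points(P), matvec(S_inv_R_inv, V)

def direct_vectors(N):
    # Translations cancel out for vectors: only R S remains
    return matvec(RS, N)

X = (1.0, 2.0, 3.0)
print(inverse_points(direct_points(X)))  # approximately (1.0, 2.0, 3.0)

P, V = (0.0, 1.0, -1.0), (0.6, 0.8, 0.0)
P_loc, V_loc = inverse_rays(P, V)
t = 3.7
print(direct_points(add(P_loc, tuple(t * v for v in V_loc))), add(P, tuple(t * v for v in V)))
```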

In [1]:
import torch

"""
The transformation matrices returned from the functions in this file assume
the points on which the transformation will be applied are column vectors.
i.e. the R matrix is structured as

    R = [
            [Rxx, Rxy, Rxz],
            [Ryx, Ryy, Ryz],
            [Rzx, Rzy, Rzz],
        ]  # (3, 3)

This matrix can be applied to column vectors by matrix multiplication,
with the points on the right, e.g.

    points = [[0], [1], [2]]  # (3 x 1) xyz coordinates of a point
    transformed_points = R @ points

To apply the same matrix to points which are row vectors, the R matrix
can be transposed and multiplied with the points on the left:

e.g.
    points = [[0, 1, 2]]  # (1 x 3) xyz coordinates of a point
    transformed_points = points @ R.transpose(1, 0)
"""

def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
    """
    Return the rotation matrices for one of the rotations about an axis
    of which Euler angles describe, for each value of the angle given.

    Args:
        axis: Axis label "X", "Y" or "Z".
        angle: any shape tensor of Euler angles in radians

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """

    cos = torch.cos(angle)
    sin = torch.sin(angle)
    one = torch.ones_like(angle)
    zero = torch.zeros_like(angle)

    if axis == "X":
        R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
    elif axis == "Y":
        R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
    elif axis == "Z":
        R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
    else:
        raise ValueError("letter must be either X, Y or Z.")

    return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))


def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
    """
    Convert rotations given as Euler angles in radians to rotation matrices.

    Args:
        euler_angles: Euler angles in radians as tensor of shape (..., 3).
        convention: Convention string of three uppercase letters from
            {"X", "Y", and "Z"}.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    matrices = [
        _axis_angle_rotation(c, e)
        for c, e in zip(convention, torch.unbind(euler_angles, -1))
    ]
    # return functools.reduce(torch.matmul, matrices)
    return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])

In [14]:
import torchlensmaker as tlm
import torch
import torch.nn
import math

import pprint


class Outline:
    "An outline limits the extent of a surface in the YZ plane"

    def contains(self, points, tol=1e-6):
        raise NotImplementedError

    def max_radius(self):
        "Furthest distance to the X axis that's within the outline"
        raise NotImplementedError


class SquareOutline(Outline):
    "Square outline around the X axis"

    def __init__(self, side_length):
        self.side_length = side_length

    def contains(self, points, tol=1e-6):
        # Inside the square if both |y| and |z| are within half the side length
        return torch.logical_and(
            torch.abs(points[:, 1]) < self.side_length / 2,
            torch.abs(points[:, 2]) < self.side_length / 2,
        )

    def max_radius(self):
        return math.sqrt(2) * self.side_length / 2


class CircularOutline(Outline):
    "Fixed distance to the X axis"

    def __init__(self, diameter):
        self.diameter = diameter

    def contains(self, points, tol=1e-6):
        return torch.hypot(points[:, 1], points[:, 2]) <= self.diameter / 2

    def max_radius(self):
        return self.diameter / 2


class LocalSurface3D:
    """
    A local surface defines a 3D surface in a local reference frame.
    """

    def __init__(self, outline):
        self.outline = outline

    def local_collide(self, P, V):
        """
        Find collision points and surface normals of ray-surface intersection
        for parametric rays P+tV expressed in the surface local frame.

        Returns:
            t: Value of parameter t such that P + tV is on the surface
            normals: Normal vectors to the surface at the collision points
        """
        raise NotImplementedError

    def extent(self):
        """
        Extent along the X axis
        i.e. X coordinate of the point on the surface such that |X| is maximized
        """
        raise NotImplementedError

    def contains(self, points, tol=1e-6):
        raise NotImplementedError


class Plane(LocalSurface3D):
    "X=0 plane"

    def __init__(self, outline):
        self.outline = outline

    def samples2D(self, N):
        r = torch.linspace(0, self.outline.max_radius(), N)
        return torch.stack(
            (torch.zeros(N), r), dim=-1
        )

    def local_collide(self, P, V):
        # Exact intersection with the X=0 plane: P_x + t V_x = 0
        t = -P[:, 0] / V[:, 0]
        # The plane normal is constant, pointing towards -X
        local_normals = torch.tile(torch.tensor([-1.0, 0.0, 0.0]), (P.shape[0], 1))
        return t, local_normals

    def extent(self):
        return torch.zeros(1)

    def contains(self, points, tol=1e-6):
        return torch.logical_and(self.outline.contains(points),
                                 torch.abs(points[:, 0]) < tol)


class SquarePlane(Plane):
    def __init__(self, side_length):
        super().__init__(SquareOutline(side_length))


class CircularPlane(Plane):
    "aka disk"
    def __init__(self, diameter):
        super().__init__(CircularOutline(diameter))


class ImplicitSurface3D(LocalSurface3D):
    """
    Surface3D defined in implicit form: F(x,y,z) = 0
    """

    def __init__(self, outline):
        self.outline = outline

    def contains(self, points, tol=1e-6):
        return torch.logical_and(self.outline.contains(points),
                                 torch.abs(self.F(points)) < tol)
                
    def local_collide(self, P, V):

        # Initial guess is the intersection of rays with the X=0 plane
        init_t = -P[:, 0] / V[:, 0]
        
        t = intersect_newton_3D(self, P, V, init_t)

        local_points = P + t.unsqueeze(1).expand((-1, 3)) * V
        local_normals = self.F_grad(local_points)

        # If there is no intersection, newton's method won't converge
        # and points will not be on the surface
        # So verify intersection here and filter points
        # that aren't on the surface
        # TODO

        return t, local_normals

    
    def F(self, points):
        """
        Implicit equation for the 3D shape: F(x,y,z) = 0

        Args:
            points: tensor of shape (N, 3) where columns are X, Y, Z coordinates and N is the batch dimension

        Returns:
            F: value of F at the given points, tensor of shape (N,)
        """
        raise NotImplementedError

    def F_grad(self, points):
        """
        Gradient of F

        Args:
            points: tensor of shape (N, 3) where columns are X, Y, Z coordinates and N is the batch dimension

        Returns:
            F_grad: value of the gradient of F at the given points, tensor of shape (N, 3)
        """
        raise NotImplementedError


class Parabola(ImplicitSurface3D):
    def __init__(self, diameter, a):
        super().__init__(CircularOutline(diameter))
        self.a = a

    def samples2D(self, N):
        """
        Generate N sample points located on the shape's curve with r >= 0
        """

        r = torch.linspace(0, self.outline.max_radius(), N)
        x = self.a * r**2
        return torch.stack((x, r), dim=-1)

    def extent(self):
        r = self.outline.max_radius()
        return torch.as_tensor(self.a * r**2)

    def f(self, x, r):
        return self.a * torch.pow(r, 2) - x

    def f_grad(self, x, r):
        return torch.stack((-torch.ones_like(x), 2 * self.a * r), dim=-1)

    def F(self, points):
        x, y, z = points[:, 0], points[:, 1], points[:, 2]
        return self.a * (y**2 + z**2) - x

    def F_grad(self, points):
        x, y, z = points[:, 0], points[:, 1], points[:, 2]
        return torch.stack(
            (-torch.ones_like(x), 2 * self.a * y, 2 * self.a * z), dim=-1
        )


class Sphere(ImplicitSurface3D):
    def __init__(self, diameter, r):
        super().__init__(CircularOutline(diameter))
        assert (
            torch.abs(torch.as_tensor(r)) >= diameter / 2
        ), f"Sphere diameter ({diameter}) must be at most 2x the absolute value of its arc radius (2x|{r}|={2*abs(r)})"
        self.diameter = diameter
        self.K = 1.0 / r

    def extent(self):
        r = torch.as_tensor(self.outline.max_radius())
        K = self.K
        # Sag of the arc at the outline radius, same formula as samples2D
        return (K * r**2) / (1 + torch.sqrt(1 - r**2 * K**2))

    def samples2D(self, N):
        K = self.K
        r = torch.linspace(0, self.outline.max_radius(), N)
        x = (K * r**2) / (1 + torch.sqrt(1 - r**2 * K**2))
        return torch.stack((x, r), dim=-1)

    def F(self, points):
        x, y, z = points[:, 0], points[:, 1], points[:, 2]
        K = self.K
        r2 = y**2 + z**2
        return (K * r2) / (1 + torch.sqrt(1 - r2 * K**2)) - x

    def F_grad(self, points):
        x, y, z = points[:, 0], points[:, 1], points[:, 2]
        K = self.K
        r2 = y**2 + z**2
        denom = torch.sqrt(1 - r2 * K**2)
        return torch.stack(
            (-torch.ones_like(x), (K * y) / denom, (K * z) / denom), dim=-1
        )


def homogeneous_transform_matrix4(A, B):
    "Homogeneous 4x4 transform matrix for 3D transform AX+B"
    rows = torch.cat((A, B.unsqueeze(0).T), dim=1)
    return torch.cat((rows, torch.tensor([[0.0, 0.0, 0.0, 1.0]])), dim=0)


class BaseTransform:
    def direct_vectors(self, vectors):
        "Apply the transform to vectors"
        raise NotImplementedError

    def inverse_points(self, surface, points):
        "Apply the inverse transform to points"
        raise NotImplementedError
    
    def inverse_rays(self, P, V, surface):
        "Apply the inverse transform to rays"
        raise NotImplementedError

    def matrix4(self, surface):
        "Homogeneous coordinates 4x4 matrix representing the transform"
        raise NotImplementedError


class SurfaceTransform(BaseTransform):
    """
    Transform of the form X' = RS(X - A) + T
    where A is a surface anchor point determined by the surface shape
    """

    def __init__(self, scale, anchor, rotations, position):
        self.anchor = anchor

        # scale matrix
        self.S = torch.tensor([[scale, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
        self.S_inv = torch.tensor(
            [[1.0 / scale, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
        )

        # rotation matrix
        self.R = euler_angles_to_matrix(
            torch.deg2rad(torch.as_tensor(rotations)), "XYZ"
        )
        self.R_inv = self.R.T

        # position translation
        self.T = torch.as_tensor(position)

    def anchor_point(self, surface):
        "Get position of anchor of surface"
        if self.anchor == "origin":
            return torch.zeros(3)
        elif self.anchor == "extent":
            return torch.cat(
                (torch.atleast_1d(surface.extent()), torch.zeros(2)), dim=0
            )
        else:
            raise ValueError

    def direct_vectors(self, V):
        return (self.R @ self.S @ V.T).T

    def inverse_points(self, surface, P):
        S_inv, R_inv, T = self.S_inv, self.R_inv, self.T
        A = self.anchor_point(surface)
        return (S_inv @ R_inv @ (P-T).T).T + A

    def inverse_rays(self, P, V, surface):
        S_inv, R_inv, T = self.S_inv, self.R_inv, self.T
        A = self.anchor_point(surface)

        Ps = (S_inv @ R_inv @ (P-T).T).T + A
        Vs = (S_inv @ R_inv @ V.T).T

        return Ps, Vs

    def matrix4(self, surface):
        hom = homogeneous_transform_matrix4
        return hom(self.R @ self.S, self.T) @ hom(torch.eye(3), -self.anchor_point(surface))


def newton_delta(surface, P, V, t):
    "Compute the delta for one step of Newton's method"
    
    points = P + t.unsqueeze(1).expand_as(V) * V

    F = surface.F(points)
    F_grad = surface.F_grad(points)

    # Denominator will be zero if F_grad and V are orthogonal
    denom = torch.sum(F_grad * V, dim=1)

    return F / denom


def intersect_newton_3D(surface, P, V, init_t):
    """
    Surface-Ray collision detection in 3D using Newton's method

    Args:
        P: rays origin points
        V: rays unit vectors
        init_t: initial value for t

    Returns:
        t: tensor of t values such that P+tV are the collision points
           or nan if no collision is found
    """

    assert isinstance(P, torch.Tensor) and P.dim() == 2
    assert isinstance(V, torch.Tensor) and V.dim() == 2
    assert P.shape[0] == V.shape[0]
    assert P.shape[1] == V.shape[1] == 3

    # Initialize solutions t
    t = init_t

    with torch.no_grad():
        for _ in range(20):  # TODO parameters for newton iterations
            # TODO warning if stopping due to max iter (didn't converge)
            delta = newton_delta(surface, P, V, t)
            # TODO early stop if delta is small enough
            t = t - delta

    # One last Newton iteration outside of no_grad, so that gradients can flow back through t
    t = t - newton_delta(surface, P, V, t)
    
    return t



def surface_to_json(surface, matrix4):
    N = 100
    samples = surface.samples2D(N)

    obj = {"matrix": matrix4.tolist(), "samples": samples.tolist()}
    
    # outline
    if isinstance(surface.outline, SquareOutline):
        obj["side_length"] = surface.outline.side_length

    return obj


def rays_to_json(rays, length):
    rays_start = rays[:, :3]
    rays_end = rays_start + length * rays[:, 3:]
    return torch.hstack((rays_start, rays_end)).tolist()


def render(points, normals):
    groups = []

    groups.append(
        {
            "type": "surfaces",
            "data": [surface_to_json(s, t.matrix4(s)) for t, s in test_surfaces],
        }
    )

    groups.append(
        {
            "type": "rays",
            "data": rays_to_json(test_rays, 150),
        }
    )

    groups.append(
        {
            "type": "points",
            "data": points.tolist(),
            "color": "#ff0000",
        }
    )

    groups.append(
        {
            "type": "arrows",
            "data": [n.tolist() + p.tolist() + [1.0] for p, n in zip(points, normals)],
        }
    )

    # pprint.pprint(groups)

    tlm.viewer(groups)


def intersect(surface, P, V, transform):
    """
    Surface-rays collision detection

    Find collision points and normal vectors for the intersection of rays P+tV with
    a surface and a transform applied to that surface.

    Args:
        surface: surface to collide with
        P: (N, 3) tensor, rays origins
        V: (N, 3) tensor, rays vectors
        transform: transform applied to the surface

    Returns:
        points: collision points
        normals: surface normals at the collision points
    """

    # Convert rays to surface local frame
    Ps, Vs = transform.inverse_rays(P, V, surface)

    # Collision detection in the surface local frame
    t, local_normals = surface.local_collide(Ps, Vs)
    
    # t, points, normals, blocked = surface.local_collide(Ps, Vs)

    # Compute collision points and convert normals to global frame
    points = P + t.unsqueeze(1).expand((-1, 3)) * V
    normals = transform.direct_vectors(local_normals)

    return points, normals


def make_random_rays(num_rays, start_x, end_x, max_y):
    rays_start = (torch.rand((num_rays, 3)) * 2 - 1) * max_y
    rays_start[:, 0] = start_x

    rays_end = (torch.rand((num_rays, 3)) * 2 - 1) * max_y
    rays_end[:, 0] = end_x

    rays_vectors = torch.nn.functional.normalize(rays_end - rays_start, dim=1)

    return torch.hstack((rays_start, rays_vectors))


test_rays = make_random_rays(
    num_rays=50,
    start_x=-15,
    end_x=50,
    max_y=6,
)

# debug newton 3D:
# iteration plot of t
# history of collision point

test_surfaces = [
    #(SurfaceTransform(1.0, "origin", [0., 10., 0.], [0., 0., 0.]), Sphere(15.0, 1e6)),
    #(SurfaceTransform(1.0, "origin", [0., 0., 0.], [10., 0., -10.]), Sphere(25.0, 20)),
    #(SurfaceTransform(1.0, "origin", [0., 0., 0.], [20., 20., 0.]), Sphere(15.0, -10)),
    #(SurfaceTransform(1.0, "origin", [0., 0., 0.], [30., 0., 0.]), Parabola(15., -0.05)),
    #(SurfaceTransform(1.0, "origin", [0., 0., 0.], [40., 0., 0.]), Parabola(20., -0.04)),
    #(SurfaceTransform(1.0, "origin", [0., 0., 0.], [50., 0., 0.]), Parabola(30., 0.02)),
    #(SurfaceTransform(1.0, "origin", [0., 10., -10.], [60., 0., 0.]), Parabola(30., 0.05)),
    #(SurfaceTransform(1.0, "origin", [0., 0., 0.], [80., 0., 0.]), Plane(50.)),
    #(SurfaceTransform(1.0, "origin", [0., 0., 0.], [5., 0., -5.]), Plane(15.)),
    
    #(SurfaceTransform(1.0, "origin", [0.0, 10.0, -10.0], [100.0, 0.0, 0.0]), Parabola(30.0, -0.05)),
    #(SurfaceTransform(-1.0, "extent", [0.0, 20.0, -20.0], [100.0, 2.0, 5.0]), Parabola(20.0, 0.05)),

    #(SurfaceTransform(1.0, "origin", [0.0, 10.0, -10.0], [100.0, 0.0, 0.0]), Parabola(30.0, 0.05)),
    #(SurfaceTransform(1.0, "extent", [0.0, 20.0, -20.0], [100.0, 2.0, 5.0]), Parabola(20.0, 0.05)),

    #(SurfaceTransform(1.0, "extent", [0.0,  0.0, 0.0], [50.0, 5.0, 5.0]), Parabola(30.0, 0.05)),
    #(SurfaceTransform(1.0, "extent", [0.0, 10.0, 0.0], [50.0, 5.0, 5.0]), Parabola(30.0, 0.05)),
    #(SurfaceTransform(1.0, "extent", [0.0, 20.0, 0.0], [50.0, 5.0, 5.0]), Parabola(30.0, 0.05)),
    #(SurfaceTransform(1.0, "extent", [0.0, 30.0, 0.0], [50.0, 5.0, 5.0]), Parabola(30.0, 0.05)),
    #(SurfaceTransform(1.0, "extent", [0.0, 40.0, 0.0], [50.0, 5.0, 5.0]), Parabola(30.0, 0.05)),

    (SurfaceTransform(1.0, "origin", [0.0, 0.0, 0.0], [10.0, 0.0, 0.0]), Parabola(30., 0.05)),
    (SurfaceTransform(1.0, "origin", [0.0, 0.0, 0.0], [10.0, 0.0, 0.0]), SquarePlane(30.))
]

# TODO
# better handling of out of domain / no collision
# more testing of transforms
# test 3D refraction / reflection


def demo(rays):

    all_points = torch.empty((0, 3))
    all_normals = torch.empty((0, 3))
    P, V = rays[:, :3], rays[:, 3:6]

    for transform, surface in test_surfaces:

        points, normals = intersect(surface, P, V, transform)

        # Keep only finite points (drop nan and inf), TODO better way to handle no-collision?
        keep = torch.all(torch.isfinite(points), dim=1)
        points = points[keep, :]
        normals = normals[keep, :]

        if points.numel() > 0:
            all_points = torch.cat((all_points, points), dim=0)
            all_normals = torch.cat((all_normals, normals), dim=0)

    render(all_points, all_normals)


demo(test_rays)