# Math foundation

## Ray - surface intersection in 3D

### Definitions

A 3D ray is represented parametrically as the set of points $P + tV$ with $t \in \mathbb{R}$, where:

* $P$ is an origin point
* $V$ is a unit direction vector

A 3D surface is defined implicitly by a function $F$, as the set of points where $F(x,y,z) = 0$.

To compute the intersection of a ray with a surface, we look for the values of the unknown $t$ for which the point $P + tV$ is both on the ray and on the surface:

$$
F(P + tV) = 0
$$

We call the left-hand side $Q(t) = F(P + tV)$. Finding the intersections then means finding the roots of $Q$.

### Newton's method

In general there may be multiple solutions, or none. In practice we solve the equation iteratively from an initial guess. The iteration converges to at most one of the solutions, so this issue is minimized.

We can solve the intersection equation using Newton's method with the update step:

$$
t_{n+1} = t_n - \frac{Q(t_n)}{Q'(t_n)}
$$

Expanding $Q'$ with the multivariate chain rule, $Q'(t) = \frac{d}{dt} F(P + tV) = V \cdot \nabla F(P + tV)$, we get:

$$
t_{n+1} = t_n - \frac{F(P + t_nV)}{V \cdot \nabla F(P + t_n V)}
$$

where "$\cdot$" in the denominator is the dot product, and $\nabla F$ is the gradient of $F$:

$$
\nabla F = \left(\frac{\partial F}{\partial x}, \frac{\partial F}{\partial y}, \frac{\partial F}{\partial z}\right)
$$

A 3D surface must define functions to compute $F$ and $\nabla F$.

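$Q'$ is zero, and the update step is undefined, in two cases:

* The ray direction $V$ is perpendicular to the gradient $\nabla F$ at $P + t_nV$, meaning the ray is tangent to the level set of $F$ through that point.
* The gradient itself is zero, which should never be the case.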
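The sketch below implements the Newton iteration above in plain Python, without torch. It assumes the surface is given as two hypothetical callables, `F` and `grad_F`, that take a 3-tuple. The name `newton_intersect`, the tolerances and the stopping rule are illustrative choices, not the library's API.

```python
def newton_intersect(F, grad_F, P, V, t0=0.0, max_iter=50, tol=1e-10):
    """Solve Q(t) = F(P + tV) = 0 with Newton's method.

    Returns t, or None when Q'(t) vanishes or the iteration does not converge.
    """
    t = t0
    for _ in range(max_iter):
        X = tuple(p + t * v for p, v in zip(P, V))
        q = F(X)
        if abs(q) < tol:
            return t
        # Q'(t) = V . grad F(P + tV), by the multivariate chain rule
        dq = sum(v * g for v, g in zip(V, grad_F(X)))
        if abs(dq) < 1e-12:
            # Update step undefined: ray tangent to the level set, or zero gradient
            return None
        t = t - q / dq
    return None


# Example surface: sphere of radius 5 centered at the origin
def sphere_F(X):
    return X[0] ** 2 + X[1] ** 2 + X[2] ** 2 - 25.0


def sphere_grad(X):
    return (2 * X[0], 2 * X[1], 2 * X[2])


t = newton_intersect(sphere_F, sphere_grad, P=(-10.0, 0.0, 0.0), V=(1.0, 0.0, 0.0))
print(t)  # approximately 5.0
```

Starting from $t_0 = 0$, the iteration converges to the near intersection at $t = 5$ rather than the far one at $t = 15$. In general, which root is found depends on the initial guess.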

## Surface of revolution

### 2D shape definition

A 2D shape is defined in the $(x, r)$ plane with the implicit equation:
$$
f(x, r) = 0
$$

The X axis is the optical axis, and the R axis is the meridional axis, perpendicular to the optical axis.

Additionally, for working with lenses, we always require two conditions:

* $f(0, 0) = 0$: the curve passes through the origin.
* $f'_r(0, 0) = 0$: the curve is vertical at the origin.
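As an illustration, a spherical profile of radius $R$ can be written as the circle $f(x, r) = (x - R)^2 + r^2 - R^2$, centered at $(R, 0)$. This is a hypothetical example profile, not necessarily how the library defines its `Sphere` surface. The sketch checks both lens conditions with the analytic partial derivatives.

```python
R = 10.0  # hypothetical radius


def f(x, r):
    # Circle of radius R centered at (R, 0) in the (x, r) plane
    return (x - R) ** 2 + r ** 2 - R ** 2


def f_x(x, r):
    # Partial derivative of f with respect to x
    return 2 * (x - R)


def f_r(x, r):
    # Partial derivative of f with respect to r
    return 2 * r


assert f(0.0, 0.0) == 0.0    # the curve passes through the origin
assert f_r(0.0, 0.0) == 0.0  # the gradient points along X: the curve is vertical at the origin
print(f_x(0.0, 0.0))         # -20.0
```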

### Rotational symmetry

We want to create the corresponding 3D surface by rotation around the X axis. The R axis from before then becomes the radial axis of every meridional plane (a meridional plane is a plane that contains the optical axis X).

So $r$ is the distance from a point to the X axis. In 3D, in any meridional plane, $r = \sqrt{y^2 + z^2}$.

The 3D surface of revolution is defined so that its intersection with every meridional plane is the 2D curve:

$$
F(x,y,z) = f \left(x,  \sqrt{y^2 + z^2} \right) = 0
$$

Often the form $f(x, \sqrt{y^2 + z^2})$ can be simplified analytically to provide an efficient implementation of $F$.
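For example, with the hypothetical circular profile $f(x, r) = (x - R)^2 + r^2 - R^2$, $r$ only appears squared. The square root therefore cancels and $F$ simplifies to the sphere $(x - R)^2 + y^2 + z^2 - R^2$. The sketch compares both forms numerically.

```python
import math

R = 10.0  # hypothetical radius


def f(x, r):
    # 2D profile: circle of radius R centered at (R, 0) in the (x, r) plane
    return (x - R) ** 2 + r ** 2 - R ** 2


def F_generic(x, y, z):
    # Generic surface of revolution: evaluate f at r = sqrt(y^2 + z^2)
    return f(x, math.sqrt(y ** 2 + z ** 2))


def F_sphere(x, y, z):
    # Simplified form: r only appears squared in f, so the square root cancels
    return (x - R) ** 2 + y ** 2 + z ** 2 - R ** 2


for point in [(0.0, 0.0, 0.0), (1.0, 2.0, 3.0), (10.0, -4.0, 0.5)]:
    assert math.isclose(F_generic(*point), F_sphere(*point), abs_tol=1e-9)
```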

## Generic form of $\nabla F$ for surfaces of revolution

We have:

$$
F(x, y, z) = f \left(x,  \sqrt{y^2 + z^2} \right)
$$

Therefore:

$$
F'_x(x, y, z) = f'_x(x, \sqrt{y^2 + z^2})
$$
$$
F'_y(x,y,z) = \frac{y}{\sqrt{y^2 + z^2}} f_r' \left(x, \sqrt{y^2 + z^2} \right)
$$
$$
F'_z(x,y,z) = \frac{z}{\sqrt{y^2 + z^2}} f_r' \left(x, \sqrt{y^2 + z^2} \right)
$$

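However, for some curves $f$ this expression simplifies a lot. Shapes can therefore provide an optimized implementation of $\nabla F(x, y, z)$, or even directly of $\nabla F(x,y,z) \cdot V$.

Note that the factors $\frac{y}{\sqrt{y^2 + z^2}}$ and $\frac{z}{\sqrt{y^2 + z^2}}$ are undefined on the X axis, so an implementation must handle $y = z = 0$ separately.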
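The sketch below evaluates the generic gradient from the partial derivatives of $f$ and compares it with the optimized sphere gradient. It reuses the hypothetical circular profile $f(x, r) = (x - R)^2 + r^2 - R^2$. On the axis it assumes $f'_r(x, 0) = 0$, which is required for $F$ to be differentiable there. Under that assumption the $y$ and $z$ components are zero on the axis.

```python
import math

R = 10.0  # hypothetical radius


def f_x(x, r):
    return 2 * (x - R)


def f_r(x, r):
    return 2 * r


def grad_F_generic(x, y, z):
    """Gradient of F(x, y, z) = f(x, sqrt(y^2 + z^2)) via the chain rule.

    Assumes f'_r(x, 0) = 0, so the y and z components vanish on the X axis,
    where y / r and z / r are undefined.
    """
    r = math.sqrt(y ** 2 + z ** 2)
    if r == 0.0:
        return (f_x(x, 0.0), 0.0, 0.0)
    dr = f_r(x, r)
    return (f_x(x, r), y / r * dr, z / r * dr)


def grad_F_sphere(x, y, z):
    # Optimized version for this profile: gradient of (x-R)^2 + y^2 + z^2 - R^2
    return (2 * (x - R), 2 * y, 2 * z)


for point in [(1.0, 2.0, 3.0), (5.0, 0.0, 0.0), (-2.0, -7.0, 4.0)]:
    g1, g2 = grad_F_generic(*point), grad_F_sphere(*point)
    assert all(math.isclose(a, b, abs_tol=1e-9) for a, b in zip(g1, g2))
```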



## Collision detection with a 3D transform

Surfaces are defined in a local reference frame where they pass through the origin: $F(0, 0, 0) = 0$. To place a surface in 3D space, we want to apply a transform to it. Can we support scaling, rotation and translation?

Let's assume our 3D transform $T$ is affine and invertible. It maps input points $X$ to points $X' = T(X) = AX + B$, where $A$ is an invertible $3 \times 3$ matrix and $B$ is a translation vector.

Consider a point $X'$ on the transformed surface. By definition, undoing the transform puts it back on the original surface:

$$
F(T^{-1}(X')) = F(A^{-1}(X' - B)) = 0
$$

Given a parametric 3D ray: $P + tV$, finding the intersection with a transformed 3D surface is therefore solving:

$$
F( A^{-1}(P-B) + tA^{-1}V ) = 0
$$

This is useful because we can reuse the previous Newton solver approach. We apply the inverse transform to the rays, and use the $F$ function defined in the local frame:

$$
\begin{cases}
P' = A^{-1}(P - B)\\
V' = A^{-1}V
\end{cases}
$$

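In the common case, $A = RS$ is the product of a rotation $R$ and a scaling $S$. Then $A^{-1} = S^{-1}R^{-1}$ can be computed without a general matrix inversion:

* The inverse of a rotation matrix is its transpose.
* The inverse of a diagonal scaling matrix is the diagonal matrix of reciprocals.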
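The plain-Python sketch below builds $A^{-1} = S^{-1}R^T$ for a hypothetical transform: a 30° rotation around Z, a scale of 2 along X, and an arbitrary translation $B$. It checks that $A^{-1}A$ is the identity and converts a ray to the local frame. The helper names (`matmul`, `rotation_z`, `inverse_ray`, ...) are illustrative only.

```python
import math


def matmul(M, N):
    return [[sum(M[i][k] * N[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def matvec(M, v):
    return [sum(M[i][k] * v[k] for k in range(3)) for i in range(3)]


def transpose(M):
    return [[M[j][i] for j in range(3)] for i in range(3)]


def rotation_z(theta):
    # Rotation by theta radians around the Z axis
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


def scaling(sx, sy, sz):
    return [[sx, 0.0, 0.0], [0.0, sy, 0.0], [0.0, 0.0, sz]]


# Hypothetical transform X' = AX + B with A = RS
R = rotation_z(math.radians(30.0))
S = scaling(2.0, 1.0, 1.0)
A = matmul(R, S)
B = [1.0, 2.0, 3.0]

# A^-1 = S^-1 R^-1 = S^-1 R^T: no general matrix inversion needed
A_inv = matmul(scaling(1.0 / 2.0, 1.0, 1.0), transpose(R))

identity = matmul(A_inv, A)
assert all(math.isclose(identity[i][j], float(i == j), abs_tol=1e-12)
           for i in range(3) for j in range(3))


def inverse_ray(P, V):
    # P' = A^-1 (P - B),  V' = A^-1 V
    P_local = matvec(A_inv, [p - b for p, b in zip(P, B)])
    V_local = matvec(A_inv, V)
    return P_local, V_local
```

A local solver can then be run on $P' + tV'$. Since $A(P' + tV') + B = P + tV$, the resulting $t$ is directly valid for the global ray.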

Note that this applies even if the surface is not defined implicitly: we can find collisions with the transformed surface by applying the above inverse transform to the rays and calling the local collision detection code.

We also need to convert vectors, typically surface normals, from the surface local frame to the global frame.

A vector $\overrightarrow{N}$ is the difference between its end point $E$ and start point $S$:

$$
\overrightarrow{N} = E - S
$$

So to transform the vector under the affine transformation, we can take the difference of its transformed endpoints:
$$
T(\overrightarrow{N}) = T(E) - T(S)
$$

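After simplifying, the translation $B$ cancels out:
$$
T(\overrightarrow{N}) = A(E - S) = A\overrightarrow{N}
$$

This holds for vectors such as ray directions and tangents. Surface normals are a special case. To remain perpendicular to the transformed surface, they must in general be transformed by the inverse transpose $(A^{-1})^T$. This coincides with $A$ (up to a positive factor) only in two cases: when $A$ is a rotation combined with a uniform scaling, or with a scaling by $\pm 1$ along each axis.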
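The plain-Python check below uses hypothetical values. First, the translation $B$ cancels out of $T(E) - T(S)$. Second, it illustrates the standard caveat for normals: under a non-uniform scale, $A\overrightarrow{N}$ is no longer perpendicular to the transformed tangent, while $(A^{-1})^T\overrightarrow{N}$ is.

```python
import math


def matvec(M, v):
    return [sum(M[i][k] * v[k] for k in range(3)) for i in range(3)]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


# Hypothetical affine transform X' = AX + B, with a non-uniform scale along X
A = [[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
B = [5.0, -1.0, 0.5]


def T(X):
    return [ax + b for ax, b in zip(matvec(A, X), B)]


# A vector is the difference of two points: B cancels out
E, S = [1.0, 2.0, 3.0], [0.5, -1.0, 2.0]
N = [e - s for e, s in zip(E, S)]
difference = [te - ts for te, ts in zip(T(E), T(S))]
assert all(math.isclose(d, an) for d, an in zip(difference, matvec(A, N)))

# Normals: a tangent and a normal of some surface, perpendicular before the transform
tangent, normal = [1.0, 1.0, 0.0], [1.0, -1.0, 0.0]
A_inv_T = [[0.5, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]  # (A^-1)^T, A is diagonal here
print(dot(matvec(A, tangent), matvec(A, normal)))        # 3.0: A N is no longer a normal
print(dot(matvec(A, tangent), matvec(A_inv_T, normal)))  # 0.0: (A^-1)^T N still is
```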

## Adding anchors

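Building on the above, it can be useful to add a translation step before the scale and rotation, to model an "anchor". The anchor is the point on the shape that attaches to the global frame. Note that symbols are reused in this section: $A$ now denotes the anchor point and $T$ the final translation. Our full transform now has four steps: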

1. A translation $-A$ to account for the anchor
2. A scale $S$
3. A rotation $R$
4. A translation $T$ to position the shape in the global frame

$$
X' = RS(X - A) + T
$$

The inverse transform is:

$$
X = S^{-1}R^{-1}(X' - T) + A
$$
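To see where this comes from, we can solve the direct transform for $X$ by undoing one operation at a time, last one first, using $(RS)^{-1} = S^{-1}R^{-1}$:

$$
\begin{aligned}
X' &= RS(X - A) + T \\
X' - T &= RS(X - A) \\
S^{-1}R^{-1}(X' - T) &= X - A \\
S^{-1}R^{-1}(X' - T) + A &= X
\end{aligned}
$$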

When $X'$ is a point on the transformed surface that is also on the parametric ray $P + tV$, we can substitute $X' = P + tV$ and get:

$$
X = S^{-1} R^{-1} (P-T) + A + t\, S^{-1}R^{-1}V
$$

And so we can compute "inverse transformed" rays:

$$
\begin{cases}
P' = S^{-1} R^{-1}(P-T) + A\\
V' = S^{-1}R^{-1}V
\end{cases}
$$
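The plain-Python sketch below uses hypothetical values: $R$ a 20° rotation around Y, $S$ a scale of 2 along X, an anchor at $(3, 0, 0)$, and $T = (100, 2, 5)$. It checks that a point $P' + tV'$ of the inverse-transformed ray, mapped back with the direct transform, is the point $P + tV$ of the original ray for the same $t$. Note that $V'$ is generally not a unit vector. It should not be renormalized if $t$ is to keep the same meaning in both frames.

```python
import math


def matvec(M, v):
    return [sum(M[i][k] * v[k] for k in range(3)) for i in range(3)]


def transpose(M):
    return [[M[j][i] for j in range(3)] for i in range(3)]


def rotation_y(theta):
    # Rotation by theta radians around the Y axis
    c, s = math.cos(theta), math.sin(theta)
    return [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]]


# Hypothetical example values
R = rotation_y(math.radians(20.0))
S = [[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
S_inv = [[0.5, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
anchor = [3.0, 0.0, 0.0]
T = [100.0, 2.0, 5.0]


def forward(X):
    # X' = R S (X - anchor) + T
    Y = matvec(R, matvec(S, [x - a for x, a in zip(X, anchor)]))
    return [y + offset for y, offset in zip(Y, T)]


def inverse_ray(P, V):
    # P' = S^-1 R^-1 (P - T) + anchor,  V' = S^-1 R^-1 V
    R_inv = transpose(R)
    P_local = matvec(S_inv, matvec(R_inv, [p - offset for p, offset in zip(P, T)]))
    P_local = [p + a for p, a in zip(P_local, anchor)]
    V_local = matvec(S_inv, matvec(R_inv, V))
    return P_local, V_local


P, V = [0.0, 1.0, -2.0], [1.0, 0.0, 0.0]
P_local, V_local = inverse_ray(P, V)

# The same t describes the same point in both frames
t = 7.5
local_point = [p + t * v for p, v in zip(P_local, V_local)]
global_point = [p + t * v for p, v in zip(P, V)]
assert all(math.isclose(a, b, abs_tol=1e-9) for a, b in zip(forward(local_point), global_point))
print(math.sqrt(sum(v * v for v in V_local)))  # not 1.0, because of the scale
```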

The direct transform of vectors is (the translations cancel out, as before):
$$
T(\overrightarrow{N}) = RS\overrightarrow{N}
$$

import torchlensmaker as tlm
import torch
import torch.nn
import pprint

from torchlensmaker.rot3d import euler_angles_to_matrix


def homogeneous_transform_matrix4(A, B):
    "Homogeneous 4x4 transform matrix for 3D transform AX+B"
    rows = torch.cat((A, B.unsqueeze(0).T), dim=1)
    return torch.cat((rows, torch.tensor([[0.0, 0.0, 0.0, 1.0]])), dim=0)


class BaseTransform:
    def direct_vectors(self, vectors):
        "Apply the transform to vectors"
        raise NotImplementedError

    def inverse_points(self, surface, points):
        "Apply the inverse transform to points"
        raise NotImplementedError
    
    def inverse_rays(self, P, V, surface):
        "Apply the inverse transform to rays"
        raise NotImplementedError

    def matrix4(self, surface):
        "Homogeneous coordinates 4x4 matrix representing the transform"
        raise NotImplementedError


class SurfaceTransform(BaseTransform):
    """
    Transform of the form X' = RS(X - A) + T
    where A is a surface anchor point determined by the surface shape
    """

    def __init__(self, scale, anchor, rotations, position):
        self.anchor = anchor

        # scale matrix: only the X (optical) axis is scaled
        self.S = torch.tensor([[scale, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
        self.S_inv = torch.tensor(
            [[1.0 / scale, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
        )

        # rotation matrix
        self.R = euler_angles_to_matrix(
            torch.deg2rad(torch.as_tensor(rotations)), "XYZ"
        )
        self.R_inv = self.R.T

        # position translation
        self.T = torch.as_tensor(position)

    def anchor_point(self, surface):
        "Get position of anchor of surface"
        if self.anchor == "origin":
            return torch.zeros(3)
        elif self.anchor == "extent":
            return torch.cat(
                (torch.atleast_1d(surface.extent()), torch.zeros(2)), dim=0
            )
        else:
            raise ValueError(f"Unknown anchor: {self.anchor!r}")

    def direct_vectors(self, V):
        return (self.R @ self.S @ V.T).T

    def inverse_points(self, surface, P):
        S_inv, R_inv, T = self.S_inv, self.R_inv, self.T
        A = self.anchor_point(surface)
        return (S_inv @ R_inv @ (P-T).T).T + A

    def inverse_rays(self, P, V, surface):
        S_inv, R_inv, T = self.S_inv, self.R_inv, self.T
        A = self.anchor_point(surface)

        Ps = (S_inv @ R_inv @ (P-T).T).T + A
        Vs = (S_inv @ R_inv @ V.T).T

        return Ps, Vs

    def matrix4(self, surface):
        hom = homogeneous_transform_matrix4
        return hom(self.R @ self.S, self.T) @ hom(torch.eye(3), -self.anchor_point(surface))




def surface_to_json(surface, matrix4):
    N = 100
    samples = surface.samples2D(N)

    obj = {"matrix": matrix4.tolist(), "samples": samples.tolist()}
    
    # outline
    if isinstance(surface.outline, tlm.SquareOutline):
        obj["side_length"] = surface.outline.side_length

    return obj


def rays_to_json(rays, length):
    rays_start = rays[:, :3]
    rays_end = rays_start + length * rays[:, 3:]
    return torch.hstack((rays_start, rays_end)).tolist()


def render(points, normals):
    groups = []

    groups.append(
        {
            "type": "surfaces",
            "data": [surface_to_json(s, t.matrix4(s)) for t, s in test_surfaces],
        }
    )

    groups.append(
        {
            "type": "rays",
            "data": rays_to_json(test_rays, 150),
        }
    )

    groups.append(
        {
            "type": "points",
            "data": points.tolist(),
            "color": "#ff0000",
        }
    )

    groups.append(
        {
            "type": "arrows",
            "data": [n.tolist() + p.tolist() + [1.0] for p, n in zip(points, normals)],
        }
    )

    # pprint.pprint(groups)

    tlm.viewer(groups)


def intersect(surface, P, V, transform):
    """
    Surface-rays collision detection

    Find collision points and normal vectors for the intersection of rays P+tV with
    a surface and a transform applied to that surface.

    Args:
        surface: surface to collide with
        P: (N, 3) tensor, rays origins
        V: (N, 3) tensor, rays unit direction vectors
        transform: transform applied to the surface

    Returns:
        points: collision points
        normals: surface normals at the collision points
    """

    # Convert rays to surface local frame
    Ps, Vs = transform.inverse_rays(P, V, surface)

    # Collision detection in the surface local frame
    t, local_normals, valid = surface.local_collide(Ps, Vs)

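    # Compute collision points and convert normals to global frame.
    # The same t is valid in both frames because inverse_rays does not renormalize Vs.
    # direct_vectors applies R @ S, which is correct for normals when the scale is +-1.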
    points = P + t.unsqueeze(1).expand((-1, 3)) * V
    normals = transform.direct_vectors(local_normals)

    # remove non valid (non intersecting) points
    # do this before computing global frame?
    points = points[valid]
    normals = normals[valid]

    return points, normals


def make_random_rays(num_rays, start_x, end_x, max_y):
    rays_start = (torch.rand((num_rays, 3)) * 2 - 1) * max_y
    rays_start[:, 0] = start_x

    rays_end = (torch.rand((num_rays, 3)) * 2 - 1) * max_y
    rays_end[:, 0] = end_x

    rays_vectors = torch.nn.functional.normalize(rays_end - rays_start, dim=1)

    return torch.hstack((rays_start, rays_vectors))


test_rays = make_random_rays(
    num_rays=50,
    start_x=-15,
    end_x=50,
    max_y=20,
)

# debug newton 3D:
# iteration plot of t
# history of collision point

test_surfaces = [
    (SurfaceTransform(1.0, "origin", [0., 10., 0.], [0., 0., 0.]), tlm.surfaces.Sphere(15.0, 1e6)),
    (SurfaceTransform(1.0, "origin", [0., 0., 0.], [10., 0., -10.]), tlm.surfaces.Sphere(25.0, 20)),
    (SurfaceTransform(1.0, "origin", [0., 0., 0.], [20., 20., 0.]), tlm.surfaces.Sphere(15.0, -10)),
    (SurfaceTransform(1.0, "origin", [0., 0., 0.], [30., 0., 0.]), tlm.surfaces.Parabola(15., -0.05)),
    (SurfaceTransform(1.0, "origin", [0., 0., 0.], [40., 0., 0.]), tlm.surfaces.Parabola(20., -0.04)),
    (SurfaceTransform(1.0, "origin", [0., 0., 0.], [50., 0., 0.]), tlm.surfaces.Parabola(30., 0.02)),
    (SurfaceTransform(1.0, "origin", [0., 10., -10.], [60., 0., 0.]), tlm.surfaces.Parabola(30., 0.05)),
    (SurfaceTransform(1.0, "origin", [0., 0., 0.], [80., 0., 0.]), tlm.surfaces.CircularPlane(50.)),
    (SurfaceTransform(1.0, "origin", [0., 0., 0.], [5., 0., -5.]), tlm.surfaces.SquarePlane(15.)),
    
    (SurfaceTransform(1.0, "origin", [0.0, 10.0, -10.0], [100.0, 0.0, 0.0]), tlm.surfaces.Parabola(30.0, -0.05)),
    (SurfaceTransform(-1.0, "extent", [0.0, 20.0, -20.0], [100.0, 2.0, 5.0]), tlm.surfaces.Parabola(20.0, 0.05)),

    (SurfaceTransform(1.0, "origin", [0.0, 10.0, -10.0], [100.0, 0.0, 0.0]), tlm.surfaces.Parabola(30.0, 0.05)),
    (SurfaceTransform(1.0, "extent", [0.0, 20.0, -20.0], [100.0, 2.0, 5.0]), tlm.surfaces.Parabola(20.0, 0.05)),

    (SurfaceTransform(1.0, "extent", [0.0,  0.0, 0.0], [50.0, 5.0, 5.0]), tlm.surfaces.Parabola(30.0, 0.05)),
    (SurfaceTransform(1.0, "extent", [0.0, 10.0, 0.0], [50.0, 5.0, 5.0]), tlm.surfaces.Parabola(30.0, 0.05)),
    (SurfaceTransform(1.0, "extent", [0.0, 20.0, 0.0], [50.0, 5.0, 5.0]), tlm.surfaces.Parabola(30.0, 0.05)),
    (SurfaceTransform(1.0, "extent", [0.0, 30.0, 0.0], [50.0, 5.0, 5.0]), tlm.surfaces.Parabola(30.0, 0.05)),
    (SurfaceTransform(1.0, "extent", [0.0, 40.0, 0.0], [50.0, 5.0, 5.0]), tlm.surfaces.Parabola(30.0, 0.05)),

    (SurfaceTransform(1.0, "origin", [0.0, 0.0, 0.0], [10.0, 0.0, 0.0]), tlm.surfaces.Parabola(30., 0.05)),
    (SurfaceTransform(1.0, "origin", [0.0, 0.0, 0.0], [10.0, 0.0, 0.0]), tlm.surfaces.SquarePlane(30.))
]

# TODO
# more testing of transforms
# test 3D refraction / reflection


def demo(rays):

    all_points = torch.empty((0, 3))
    all_normals = torch.empty((0, 3))
    P, V = rays[:, :3], rays[:, 3:6]

    for transform, surface in test_surfaces:

        points, normals = intersect(surface, P, V, transform)

        if points.numel() > 0:
            all_points = torch.cat((all_points, points), dim=0)
            all_normals = torch.cat((all_normals, normals), dim=0)

    render(all_points, all_normals)


demo(test_rays)