# side quest

## Sampling dimensions

An optical system is defined by a model that represents a physical reality in which infinitely many rays of light emanate from light sources.

In practice, when simulating, we need to decide how many rays we use for computation. This is called **sampling** and the number of rays is also called the number of samples.

This is not to say that rays exist for all possible values of the ray variables. Sometimes they follow a particular distribution: for example, the wavelength may be restricted to between 400 and 600 nm, the angular distribution may be restricted to [-10, 10]°, and so on. But those distributions are continuous, and infinitely many rays can follow them.

A system defines multiple dimensions, corresponding to different continuous variables that uniquely describe a ray. For example, the coordinate of the source point of a ray is one possible dimension. Another one is the wavelength of the ray.

* Base dimension: This dimension is always present unless we simulate a single ray. It's the dimension along which we sample when all other variables are fixed.

-> rename to angular? at infinity angular becomes linear, but really it's still the emanating angle of the ray from the object

can actually not exist if simulating a single ray
  
* Object dimension
* Wavelength dimension
* System configuration

Additionally, the number of spatial dimensions can be either 2 or 3. When doing a 2D simulation, rays are restricted to a single meridional plane and don't have a z coordinate. This simplifies the simulation considerably, but at the price of ignoring skew rays.

Note that the meridional plane in the 2D simulation is abstract and not necessarily the Z=0 plane. If the system is not rotationally symmetric, 2D mode is not possible.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchlensmaker as tlm

from dataclasses import dataclass, replace

from torchlensmaker.tensorframe import TensorFrame

from torchlensmaker.rot2d import rot2d
from torchlensmaker.rot3d import euler_angles_to_matrix

from typing import Any

Tensor = torch.Tensor

@dataclass
class NewOpticalData:
    """State handed from one element of the optical sequence to the next."""

    # sampling information
    # (the elements in this file read the "dim", "dtype" and "base" keys)
    sampling: dict[str, Any]

    # transform kinematic chain
    # (elements extend it by returning a copy with extra transforms appended)
    transforms: list[tlm.TransformBase]

    # rays produced so far; default_input() creates it empty with the columns
    # RX, RY, VX, VY in 2D and RX, RY, RZ, VX, VY, VZ in 3D
    rays: TensorFrame


def default_input(sampling):
    """
    Build the initial NewOpticalData fed to the first element of a sequence.

    sampling: sampling dictionary, must provide the "dim" and "dtype" keys

    The result holds an empty ray frame (zero rows) whose columns depend on
    the number of spatial dimensions, and a kinematic chain made of a single
    identity transform.
    """
    dim = sampling["dim"]
    dtype = sampling["dtype"]

    # Position columns first, then direction columns; anything other than
    # dim == 2 is treated as 3D
    if dim == 2:
        columns = ["RX", "RY", "VX", "VY"]
    else:
        columns = ["RX", "RY", "RZ", "VX", "VY", "VZ"]

    empty_rays = TensorFrame(torch.empty((0, len(columns))), columns=columns)
    root_chain = [tlm.IdentityTransform(dim, dtype)]

    return NewOpticalData(sampling=sampling, transforms=root_chain, rays=empty_rays)

    

# light sources have to be in the optical sequence
# but not necessarily on the kinematic chain

class PointSourceAtInfinity(nn.Module):
    """
    A simple light source that models a perfect point at infinity.

    All rays are parallel, possibly with some incidence angle. Ray origins are
    sampled across the beam diameter along the base sampling dimension.
    """

    def __init__(self, beam_diameter, angle1=0., angle2=0.):
        """
        beam_diameter: diameter of the beam of parallel light rays
        angle1: angle of incidence with respect to the principal axis, in degrees
        angle2: second angle of incidence used in 3D, in degrees

        Values are stored as float64 tensors; both angles are converted to
        radians here, so the rest of the class works in radians only.
        """

        super().__init__()
        self.beam_diameter = torch.as_tensor(beam_diameter, dtype=torch.float64)
        self.angle1 = torch.deg2rad(torch.as_tensor(angle1, dtype=torch.float64))
        self.angle2 = torch.deg2rad(torch.as_tensor(angle2, dtype=torch.float64))

    def forward(self, inputs):
        """
        Sample new rays across the beam and stack them onto the incoming rays.

        inputs: optical data carrying `sampling` (keys "dim", "dtype", "base"),
                `transforms` (the kinematic chain) and `rays`

        Returns a copy of `inputs` whose rays also contain the new rays.
        In 3D the base dimension is used twice, giving base * base rays.
        """
        # Read every sampling parameter from the inputs, never from a
        # module-level variable: the source must work with whatever sampling
        # the caller passed in.
        sampling = inputs.sampling
        dim, dtype = sampling["dim"], sampling["dtype"]
        num_rays = sampling["base"]
        margin = 0.1  # TODO

        # rays origins, created directly in the sampling dtype
        D = self.beam_diameter
        RY = torch.linspace(-D / 2 + margin, D / 2 - margin, num_rays, dtype=dtype)

        if dim == 2:
            total_rays = num_rays
            RX = torch.zeros(total_rays, dtype=dtype)
            rays_origins = torch.column_stack((RX, RY))
        else:
            # use the same base dimension twice here
            # TODO could define different ones
            total_rays = num_rays * num_rays
            RX = torch.zeros(total_rays, dtype=dtype)
            prod = torch.cartesian_prod(RY, RY.clone())
            rays_origins = torch.column_stack((RX, prod[:, 0], prod[:, 1]))

        # rays vectors: unit X direction rotated by the incidence angle(s)
        if dim == 2:
            V = torch.tensor([1.0, 0.0], dtype=dtype)
            vect = rot2d(V, self.angle1)
        else:
            V = torch.tensor([1.0, 0.0, 0.0], dtype=dtype)
            # angle1 / angle2 were already converted to radians in __init__,
            # so they must not go through deg2rad a second time
            euler_angles = torch.stack((
                torch.zeros((), dtype=dtype),
                self.angle1.to(dtype=dtype),
                self.angle2.to(dtype=dtype),
            ))
            M = euler_angles_to_matrix(euler_angles, "ZYX").to(
                dtype=dtype
            )  # TODO need to support dtype in euler_angles_to_matrix
            vect = V @ M

        assert vect.dtype == dtype

        # every ray shares the same direction
        rays_vectors = torch.tile(vect, (total_rays, 1))

        # transform sources to the chain target
        transform = tlm.forward_kinematic(inputs.transforms)
        rays_origins = transform.direct_points(rays_origins)
        rays_vectors = transform.direct_vectors(rays_vectors)

        # TODO normalized coordinate along the base dimension:
        # (RY + self.beam_diameter / 2) / self.beam_diameter

        if dim == 2:
            columns = ["RX", "RY", "VX", "VY"]
        else:
            assert rays_origins.shape[1] == 3, rays_origins.shape
            assert rays_vectors.shape[1] == 3, rays_vectors.shape
            columns = ["RX", "RY", "RZ", "VX", "VY", "VZ"]

        new_rays = TensorFrame(
            torch.cat((rays_origins, rays_vectors), dim=1),
            columns=columns,
        )

        return replace(inputs, rays=inputs.rays.stack(new_rays))



class NewOpticalSurface(nn.Module):
    """
    Optical element wrapping a surface placed on the kinematic chain.

    surface: the underlying surface object (must provide extent())
    scale: uniform scale factor applied to the surface
    anchors: pair (first, second); each entry is either "extent" or anything
             else (e.g. "origin"), which adds no translation
    """

    def __init__(self, surface, scale=1., anchors=("origin", "origin")):
        super().__init__()
        self.surface = surface
        self.scale = scale
        self.anchors = anchors

        # Design notes:
        #
        # mode 1:  inline / chained / affecting
        #   surface transform = input.transforms - anchor + scale
        #   output transform = input.transforms - first anchor + second anchor
        #
        # mode 2: offline / free / independent
        #   surface transform = input.transforms + local transform - anchor + scale
        #   output transform = input.transforms
        #
        # how to support absolute position on chain?
        #
        # RS(X - A) + T
        # surface transform(X) = CSX - A
        # surface transform = anchor1 + scale + chain
        # output transform = chain + anchor1 + anchor2

    def _extent_vector(self, dim, dtype):
        "Surface extent along the first axis, zero-padded to `dim` components"
        along_axis = self.surface.extent().unsqueeze(0)
        padding = torch.zeros(dim - 1, dtype=dtype)
        return torch.cat((along_axis, padding), dim=0)

    def surface_transform(self, dim, dtype) -> list[tlm.TransformBase]:
        "Transform chain that applies to the underlying surface"

        # uniform scaling, together with its inverse
        forward_matrix = self.scale * torch.eye(dim, dtype=dtype)
        inverse_matrix = 1./self.scale * torch.eye(dim, dtype=dtype)
        chain = [tlm.LinearTransform(forward_matrix, inverse_matrix)]

        # the first anchor, when set to "extent", shifts the surface back by
        # its scaled extent and goes in front of the scaling
        if self.anchors[0] == "extent":
            shift = -self.scale * self._extent_vector(dim, dtype)
            chain = [tlm.TranslateTransform(shift)] + chain

        return chain

    def output_transform(self, dim, dtype) -> list[tlm.TransformBase]:
        "Transform chain that applies to the next element"

        first_anchor, second_anchor = self.anchors[0], self.anchors[1]
        chain = []

        # subtract first anchor
        if first_anchor == "extent":
            backward = - self.scale * self._extent_vector(dim, dtype)
            chain.append(tlm.TranslateTransform(backward))

        # add second anchor
        if second_anchor == "extent":
            onward = self.scale * self._extent_vector(dim, dtype)
            chain.append(tlm.TranslateTransform(onward))

        return chain

    def forward(self, inputs):
        """
        Return a copy of `inputs` whose kinematic chain is extended by this
        element's output transform. Rays are passed through unchanged for now.
        """
        sampling = inputs.sampling
        dim, dtype = sampling["dim"], sampling["dtype"]

        # computed but not used yet, see the TODOs below
        _surface_transform = self.surface_transform(dim, dtype)

        # TODO compute rays here
        # TODO intersect rays with surface

        extended_chain = inputs.transforms + self.output_transform(dim, dtype)
        return replace(inputs, transforms=extended_chain)


class NewGap(nn.Module):
    def __init__(self, offset: float | int | Tensor):
        super().__init__()
        assert(isinstance(offset, (float, int, torch.Tensor)))
        if isinstance(offset, torch.Tensor):
            assert(offset.dim() == 0)

        # Gap is always stored as float64, but it's converted to the sampling
        # dtype when creating the corresponding transform in forward()
        self.offset = torch.as_tensor(offset, dtype=torch.float64)

    def forward(self, inputs):
        dim, dtype = inputs.sampling["dim"], inputs.sampling["dtype"]
        
        translate_vector = torch.cat((self.offset.unsqueeze(0).to(dtype=dtype), torch.zeros(dim-1, dtype=dtype)))
        
        return replace(inputs, transforms=inputs.transforms + [tlm.TranslateTransform(translate_vector)])


class SurfaceArtist:
    """Rendering helpers for NewOpticalSurface elements."""

    @staticmethod
    def render_element(element: NewOpticalSurface, inputs, _outputs):
        """Render the element's surface, placed at the end of the input chain."""
        head = inputs.transforms[0]
        dim, dtype = head.dim, head.dtype

        # input chain followed by the element's own surface transform
        placement = tlm.forward_kinematic(
            inputs.transforms + element.surface_transform(dim, dtype)
        )

        # TODO find a way to group surfaces together?
        return tlm.viewer.render_surfaces(
            [element.surface], [placement], dim=placement.dim, N=10
        )

    @staticmethod
    def render_rays(element, inputs, outputs):
        """Render the input rays as segments of 10 times their direction vector."""
        # TODO dim, dtype as argument
        head = inputs.transforms[0]
        dim, _dtype = head.dim, head.dtype  # dtype is read but not needed yet

        if dim == 2:
            position_columns, direction_columns = ["RX", "RY"], ["VX", "VY"]
        else:
            position_columns = ["RX", "RY", "RZ"]
            direction_columns = ["VX", "VY", "VZ"]

        origins = inputs.rays.get(position_columns)
        directions = inputs.rays.get(direction_columns)
        endpoints = origins + 10*directions

        return tlm.viewer.render_rays(origins, endpoints)


class JointArtist:
    """Renders the position of the kinematic chain's end point (the joint)."""

    @staticmethod
    def render_element(element: NewOpticalSurface, inputs, _outputs):
        """Return a "points" group holding the chain origin mapped through the input chain."""
        head = inputs.transforms[0]
        dim, dtype = head.dim, head.dtype

        chain = tlm.forward_kinematic(inputs.transforms)
        local_origin = torch.zeros((dim,), dtype=dtype)
        joint = chain.direct_points(local_origin)

        return {
            "type": "points",
            "data": [joint.tolist()],
        }
    
# Maps an element type to the artist class used to render it.
# render_sequence() looks modules up with isinstance() and uses the first match.
artists_dict = {
    NewOpticalSurface: SurfaceArtist,
}

def inspect_stack(execute_list):
    """
    Debug helper: print the transform chains seen by every executed module.

    execute_list: iterable of (module, inputs, outputs) triples, where inputs
                  and outputs each expose a `transforms` iterable
    """
    for module, inputs, outputs in execute_list:
        print(type(module))

        # same layout for both sides: header, one transform per line, blank line
        for header, data in (
            ("inputs.transform:", inputs),
            ("outputs.transform:", outputs),
        ):
            print(header)
            for transform in data.transforms:
                print(transform)
            print()

def render_sequence(optics, sampling):
    """
    Run the optical sequence and build a viewer scene from it.

    optics: the sequence of optical elements to execute
    sampling: sampling dictionary, must provide the "dim" and "dtype" keys

    For every executed module, the joint position is rendered; modules with a
    registered artist in artists_dict also get their element rendered, plus
    their input rays when there are any.
    """
    dim, _dtype = sampling["dim"], sampling["dtype"]
    execute_list, _top_output = tlm.full_forward(optics, default_input(sampling))

    scene_kind = "2D" if dim == 2 else "3D"
    scene = tlm.viewer.new_scene(scene_kind)

    # inspect_stack(execute_list) can be called here for debugging

    for module, inputs, outputs in execute_list:
        # render chain join position for every module
        joint_group = JointArtist.render_element(module, inputs, outputs)
        scene["data"].append(joint_group)

        # first registered artist whose type matches this module, if any
        artist = next(
            (a for typ, a in artists_dict.items() if isinstance(module, typ)),
            None,
        )
        if artist is None:
            continue

        scene["data"].append(artist.render_element(module, inputs, outputs))

        # rays are only drawn for modules that have an artist
        if inputs.rays.numel() > 0:
            scene["data"].append(artist.render_rays(module, inputs, outputs))

    return scene


def view(optics, sampling):
    """Render the optical sequence with the given sampling and show it in the viewer."""
    # for debugging, the scene can be dumped with tlm.viewer.dump(scene, ndigits=2)
    tlm.viewer.show(render_sequence(optics, sampling))

# Demo surfaces: two parabolas and one sphere
surface = tlm.surfaces.Parabola(15.0, 0.020)
surface2 = tlm.surfaces.Parabola(15.0, 0.030)
surface3 = tlm.surfaces.Sphere(12, -20)

# Demo sequence: a point source at infinity followed by four surfaces
# separated by gaps. scale=-1 flips a surface; the anchors choose which
# point of each surface ("origin" or "extent") sits on the chain.
optics = nn.Sequential(
    PointSourceAtInfinity(10., 0),
    NewGap(5.0),
    NewOpticalSurface(surface, anchors=("origin", "extent")),
    NewGap(1.0),
    NewOpticalSurface(surface2, scale=-1, anchors=("extent", "origin")),
    NewGap(0.5),
    NewOpticalSurface(surface3, anchors=("extent", "extent")),
    NewGap(5.0),
    NewOpticalSurface(surface3, scale=-1, anchors=("extent", "extent")),
)


# Show the same sequence in 2D ...
sampling = {"dim": 2, "dtype": torch.float64, "base": 10}
view(optics, sampling)

# ... and in 3D, with the same number of samples along the base dimension
sampling = {"dim": 3, "dtype": torch.float64, "base": 10}
view(optics, sampling)