# Side quest: prototyping a kinematic chain of transforms for optical elements

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchlensmaker as tlm

from dataclasses import dataclass, replace

from torchlensmaker.tensorframe import TensorFrame

from torchlensmaker.raytracing import rot2d

# Short alias used in type annotations below
Tensor = torch.Tensor

@dataclass
class NewOpticalData:
    """Data passed from one element of an optical sequence to the next.

    Elements return an updated copy (via ``dataclasses.replace``) instead of
    mutating it.
    """

    # transform kinematic chain, in chain order; forward_kinematic() composes
    # it into a single transform
    transforms: list[tlm.TransformBase]

    # rays accumulated so far (e.g. columns RX, RY, VX, VY in 2D)
    rays: TensorFrame


def default_input(dim: int, dtype: torch.dtype):
    """Initial data fed to the first element of an optical sequence.

    The transform chain starts with a single identity transform, and the ray
    frame is empty with position/direction columns matching ``dim``.

    Args:
        dim: 2 for planar optics; any other value selects the 3D columns.
        dtype: dtype of the identity transform and of the (empty) ray tensor.
    """
    if dim == 2:
        columns = ["RX", "RY", "VX", "VY"]
    else:
        columns = ["RX", "RY", "RZ", "VX", "VY", "VZ"]

    # Size the tensor from the column list so the two always agree, and use
    # the requested dtype so the rays match the transforms
    rays = TensorFrame(torch.empty((0, len(columns)), dtype=dtype), columns=columns)

    return NewOpticalData(
        transforms=[tlm.IdentityTransform(dim, dtype)],
        rays=rays,
    )


def forward_kinematic(transforms: list[tlm.TransformBase]) -> tlm.ComposeTransform:
    """Compose a transform kinematic chain into a single transform.

    ``transforms`` is given in chain order (the first element is the root of
    the chain). The list is reversed before being handed to
    ``tlm.ComposeTransform``. Callers apply ``direct_points`` /
    ``direct_vectors`` on the result to place geometry at the chain's target.
    """
    return tlm.ComposeTransform(list(reversed(transforms)))


class PointSourceAtInfinity(nn.Module):
    def __init__(self, beam_diameter, angle=0.):
        """
        beam_diameter: diameter of the beam of parallel light rays
        angle: angle of incidence with respect to the principal axis, in degrees

        Rays are sampled uniformly across the beam diameter; the number of rays
        is read from sampling["rays"] in forward().
        """

        super().__init__()
        # Stored as float64 (like NewGap's offset) and converted to the
        # sampling dtype in forward(), so the rays match the transform chain
        self.beam_diameter = torch.as_tensor(beam_diameter, dtype=torch.float64)
        self.angle = torch.deg2rad(torch.as_tensor(angle, dtype=torch.float64))

    def forward(self, inputs, sampling):
        """Append a beam of parallel 2D rays to ``inputs.rays``.

        The beam is placed at the target of the input transform chain. The
        appended frame has columns RX, RY, VX, VY and "rays", a normalized
        coordinate along the beam diameter (0 and 1 correspond to the beam
        edges). Rays are created in ``sampling["dtype"]``; float32 is used if
        the key is absent, which was the previous fixed behavior.
        """
        dtype = sampling.get("dtype", torch.float32)
        num_rays = sampling["rays"]
        margin = 0.1  # TODO: inset from the beam edges, not configurable yet

        beam_diameter = self.beam_diameter.to(dtype=dtype)
        angle = self.angle.to(dtype=dtype)

        # Create new rays by sampling the beam diameter
        RX = torch.zeros(num_rays, dtype=dtype)
        RY = torch.linspace(
            -beam_diameter / 2 + margin,
            beam_diameter / 2 - margin,
            num_rays,
            dtype=dtype,
        )

        transform = forward_kinematic(inputs.transforms)

        rays_origins = torch.column_stack((RX, RY))
        vect = rot2d(torch.tensor([1.0, 0.0], dtype=dtype), angle)
        rays_vectors = torch.tile(vect, (num_rays, 1))

        # transform sources to the chain target
        rays_origins = transform.direct_points(rays_origins)
        rays_vectors = transform.direct_vectors(rays_vectors)

        # normalized coordinate along the base dimension
        coord_base = (RY + beam_diameter / 2) / beam_diameter

        new_rays = TensorFrame(
            torch.cat((rays_origins, rays_vectors, coord_base.unsqueeze(1)), dim=1),
            columns=["RX", "RY", "VX", "VY", "rays"],
        )

        return replace(inputs, rays=inputs.rays.stack(new_rays))



class NewOpticalSurface(tlm.Module):
    """Optical surface element placed on the transform kinematic chain.

    Args:
        surface: underlying surface; its ``extent()`` (a scalar tensor) is
            used by the "extent" anchor.
        scale: uniform scale factor, applied to the surface as the linear
            transform ``scale * I``. It also scales the extent anchor offset.
        anchors: pair ``(input_anchor, output_anchor)``. Each is "extent"
            (the point ``scale * extent`` along the first axis) or any other
            value, which is treated as "origin".
    """

    def __init__(self, surface, scale=1., anchors=("origin", "origin")):
        super().__init__()
        self.surface = surface
        self.scale = scale
        self.anchors = anchors

        # Design notes:
        # mode 1:  inline / chained / affecting
        # surface transform = input.transforms - anchor + scale
        # output transform = input.transforms - first anchor + second anchor

        # mode 2: offline / free / independent
        # surface transform = input.transforms + local transform - anchor + scale
        # output transform = input.transforms

        # how to support absolute position on chain?

        # RS(X - A) + T
        # surface transform(X) = CSX - A
        # surface transform = anchor1 + scale + chain
        # output transform = chain + anchor1 + anchor2

    def _extent_translation(self, dim, dtype) -> torch.Tensor:
        "Scaled offset of the extent anchor: scale * [extent, 0, ..., 0] in dtype"
        # Cast explicitly so the translation has the chain's dtype even when
        # the surface reports its extent in another dtype
        extent = self.surface.extent().to(dtype=dtype)
        return self.scale * torch.cat(
            (extent.unsqueeze(0), torch.zeros(dim - 1, dtype=dtype)), dim=0
        )

    def surface_transform(self, dim, dtype) -> list[tlm.TransformBase]:
        """Transform chain that applies to the underlying surface.

        Chain order: the input-anchor translation (only for an "extent" input
        anchor), then the uniform scale.
        """
        S = self.scale * torch.eye(dim, dtype=dtype)
        S_inv = 1. / self.scale * torch.eye(dim, dtype=dtype)
        scale = [tlm.LinearTransform(S, S_inv)]

        if self.anchors[0] == "extent":
            anchor = [tlm.TranslateTransform(-self._extent_translation(dim, dtype))]
        else:
            anchor = []

        return anchor + scale

    def output_transform(self, dim, dtype) -> list[tlm.TransformBase]:
        """Transform chain appended for the next element.

        Subtracts the input anchor, then adds the output anchor. With both
        anchors at "extent" the two translations cancel out.
        """
        chain = []
        if self.anchors[0] == "extent":
            chain.append(tlm.TranslateTransform(-self._extent_translation(dim, dtype)))
        if self.anchors[1] == "extent":
            chain.append(tlm.TranslateTransform(self._extent_translation(dim, dtype)))
        return chain

    def forward(self, inputs, sampling):
        dim, dtype = sampling["dim"], sampling["dtype"]

        # Not used yet: kept for the ray computation below
        _surface_transform = self.surface_transform(dim, dtype)

        # TODO compute rays here

        output_transform = self.output_transform(dim, dtype)

        return replace(inputs, transforms=inputs.transforms + output_transform)


class NewGap(tlm.Module):
    """Translation along the first axis, appended to the transform chain.

    Downstream elements therefore see their chain shifted by ``offset``
    along that axis.
    """

    def __init__(self, offset: float | int | Tensor):
        super().__init__()
        assert isinstance(offset, (float, int, torch.Tensor))
        if isinstance(offset, torch.Tensor):
            # only scalar tensors are accepted
            assert offset.dim() == 0

        # Kept in float64 here; forward() casts it to the sampling dtype when
        # building the translation
        self.offset = torch.as_tensor(offset, dtype=torch.float64)

    def forward(self, inputs, sampling):
        dim = sampling["dim"]
        dtype = sampling["dtype"]

        along_axis = self.offset.to(dtype=dtype).reshape(1)
        padding = torch.zeros(dim - 1, dtype=dtype)
        translation = tlm.TranslateTransform(torch.cat((along_axis, padding)))

        return replace(inputs, transforms=inputs.transforms + [translation])


class SurfaceArtist:
    """Viewer rendering for NewOpticalSurface elements and their incoming rays."""

    @staticmethod
    def render_element(element: NewOpticalSurface, inputs, _outputs):
        """Return a "surfaces" scene group for the element's surface.

        The surface is placed by the incoming transform chain extended with
        the element's own surface transform.
        """
        reference = inputs.transforms[0]
        local_chain = element.surface_transform(reference.dim, reference.dtype)
        placement = forward_kinematic(inputs.transforms + local_chain)

        rendered = tlm.viewer.render_surface(
            element.surface, placement, dim=placement.dim, N=10
        )
        return {"type": "surfaces", "data": [rendered]}

    @staticmethod
    def render_rays(element, inputs, outputs):
        """Return a "rays" scene group drawing each incoming 2D ray as the
        segment from its origin to origin + 10 * direction."""
        origins = inputs.rays.get(["RX", "RY"])
        directions = inputs.rays.get(["VX", "VY"])
        ends = origins + 10 * directions

        return {
            "type": "rays",
            "data": tlm.viewer.render_rays(origins, ends),
            "color": "#ffa724",
        }


class JointArtist:
    """Viewer rendering of the current joint of the transform kinematic chain."""

    @staticmethod
    def render_element(element: object, inputs, _outputs):
        """Return a "points" scene group with the chain's target point.

        render_sequence calls this for every executed module, not only optical
        surfaces, so ``element`` is typed as a generic object; it is not used.
        """
        dim, dtype = inputs.transforms[0].dim, inputs.transforms[0].dtype
        transform = forward_kinematic(inputs.transforms)
        # Image of the local origin under the composed chain
        joint = transform.direct_points(torch.zeros((dim,), dtype=dtype))

        return {"type": "points",
                "data": [joint.tolist()]}
    
# Module type -> artist class used by render_sequence. Executed modules whose
# type has no entry here are only drawn as a chain-joint point.
artists_dict = {
    NewOpticalSurface: SurfaceArtist,
}

def inspect_stack(execute_list):
    """Print each executed module's type with its input and output transform chains.

    ``execute_list`` holds ``(module, inputs, outputs)`` triples, such as the
    execution list from ``tlm.full_forward``; only ``.transforms`` is read
    from inputs and outputs. Each chain is followed by a blank line.
    """
    for module, inputs, outputs in execute_list:
        print(type(module))
        sections = (("inputs.transform:", inputs), ("outputs.transform:", outputs))
        for header, data in sections:
            print(header)
            for transform in data.transforms:
                print(transform)
            print()

def render_sequence(optics, sampling):
    """Run ``optics`` forward and build a viewer scene from its execution list.

    Every executed module contributes a chain-joint point. Modules with an
    entry in ``artists_dict`` also contribute their own rendering, plus their
    incoming rays when there are any.
    """
    dim = sampling["dim"]
    dtype = sampling["dtype"]
    execute_list, _top_output = tlm.full_forward(optics, default_input(dim, dtype), sampling)

    mode, camera = ("2D", "XY") if dim == 2 else ("3D", "orthographic")
    scene = {"data": [], "mode": mode, "camera": camera}
    groups = scene["data"]

    # For debugging: inspect_stack(execute_list)

    for module, inputs, outputs in execute_list:
        groups.append(JointArtist.render_element(module, inputs, outputs))

        matching = (artist for typ, artist in artists_dict.items() if isinstance(module, typ))
        for artist in matching:
            groups.append(artist.render_element(module, inputs, outputs))
            if inputs.rays.numel() > 0:
                groups.append(artist.render_rays(module, inputs, outputs))

    return scene


def view(optics, sampling):
    """Render ``optics`` under ``sampling`` and show the scene in the viewer."""
    # For debugging, dump the scene instead: tlm.viewer.dump(scene, ndigits=2)
    tlm.viewer.show(render_sequence(optics, sampling))

# Demo: a beam from a point source at infinity through two parabolic and two
# spherical surface elements separated by gaps
surface = tlm.surfaces.Parabola(15.0, 0.020)
surface2 = tlm.surfaces.Parabola(15.0, 0.030)
surface3 = tlm.surfaces.Sphere(12, -20)  # shared by the last two surface elements

optics = tlm.OpticalSequence(
    PointSourceAtInfinity(10., 45),
    NewGap(5.0), NewOpticalSurface(surface, anchors=("origin", "extent")),
    NewGap(1.0), NewOpticalSurface(surface2, scale=-1, anchors=("extent", "origin")),
    NewGap(0.5), NewOpticalSurface(surface3, anchors=("extent", "extent")),
    NewGap(5.0), NewOpticalSurface(surface3, scale=-1, anchors=("extent", "extent")),
)

sampling = dict(dim=2, dtype=torch.float64, rays=10)
view(optics, sampling)

# 3D is disabled: PointSourceAtInfinity only builds 2D rays (RX, RY, VX, VY)
# sampling = dict(dim=3, dtype=torch.float64, rays=10)
# view(optics, sampling)