# side quest

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchlensmaker as tlm

from dataclasses import dataclass, replace


# Shorthand for the tensor type used in annotations below (same object as torch.Tensor).
from torch import Tensor

@dataclass
class NewOpticalData:
    """Data passed from one element of the optical sequence to the next.

    Attributes:
        transforms: kinematic chain of transforms; each element returns a copy
            with its own transforms appended to the end of this list.
    """

    transforms: list[tlm.TransformBase]


def default_input(dim: int, dtype: torch.dtype):
    """Build the initial optical data: a chain holding a single identity transform."""
    identity = tlm.IdentityTransform(dim, dtype)
    return NewOpticalData(transforms=[identity])


class NewOpticalSurface(tlm.Module):
    """Optical surface element attached to the kinematic chain.

    Args:
        surface: underlying surface object. Its ``extent()`` is only queried
            when one of the anchors is "extent".
        scale: uniform scale factor applied to the surface (e.g. ``-1``
            negates all of its local coordinates).
        anchors: ``(entry, exit)`` pair, each "origin" or "extent".
            The entry anchor is the surface point placed on the incoming chain
            joint; the exit anchor is the surface point where the outgoing
            joint is placed. "origin" is the surface origin; "extent" is the
            point offset by ``scale * surface.extent()`` along the first axis.

    Raises:
        ValueError: if ``anchors`` is not a pair of known anchor names.
    """

    def __init__(self, surface, scale=1., anchors=("origin", "origin")):
        super().__init__()
        # Reject typos early: an unknown name would otherwise silently behave like "origin"
        if len(anchors) != 2 or any(a not in ("origin", "extent") for a in anchors):
            raise ValueError(
                f"anchors must be a pair of 'origin'/'extent' names, got {anchors!r}"
            )
        self.surface = surface
        self.scale = scale
        self.anchors = anchors

        # mode 1:  inline / chained / affecting
        # surface transform = input.transforms - anchor + scale
        # output transform = input.transforms - first anchor + second anchor

        # mode 2: offline / free / independent
        # surface transform = input.transforms + specific transform + anchor
        # output transform = input.transforms

        # how to support absolute position on chain?

        # RS(X - A) + T
        # surface transform(X) = CSX - A
        # surface transform = anchor1 + scale + chain
        # output transform = chain + anchor1 + anchor2

    def _anchor_translation(self, anchor, sign, dim, dtype) -> list[tlm.TransformBase]:
        "Translation for one anchor: [] for 'origin', else a shift of sign * scale * extent along the first axis"
        if anchor != "extent":
            return []
        # Convert to the sampling dtype (as NewGap does) so the translation
        # vector matches the dtype of the rest of the chain
        extent = torch.as_tensor(self.surface.extent(), dtype=dtype).reshape(1)
        T = sign * self.scale * torch.cat((extent, torch.zeros(dim - 1, dtype=dtype)), dim=0)
        return [tlm.TranslateTransform(T)]

    def surface_transform(self, dim, dtype) -> list[tlm.TransformBase]:
        "Transform chain that applies to the underlying surface"

        S = self.scale * torch.eye(dim, dtype=dtype)
        S_inv = 1. / self.scale * torch.eye(dim, dtype=dtype)
        scale = [tlm.LinearTransform(S, S_inv)]

        # Shift so that the entry anchor point lands on the incoming chain joint
        anchor = self._anchor_translation(self.anchors[0], -1, dim, dtype)

        return anchor + scale

    def output_transform(self, dim, dtype) -> list[tlm.TransformBase]:
        "Transform chain that applies to the next element"

        # subtract first anchor, add second anchor
        anchor1 = self._anchor_translation(self.anchors[0], -1, dim, dtype)
        anchor2 = self._anchor_translation(self.anchors[1], +1, dim, dtype)

        return anchor1 + anchor2

    def forward(self, inputs, sampling):
        dim, dtype = sampling["dim"], sampling["dtype"]

        # Not used yet; presumably needed once rays are computed (see TODO)
        _surface_transform = self.surface_transform(dim, dtype)

        # TODO compute rays here

        output_transform = self.output_transform(dim, dtype)

        return replace(inputs, transforms=inputs.transforms + output_transform)


class NewGap(tlm.Module):
    """Advance the kinematic chain by a fixed offset along the first axis.

    Args:
        offset: gap length, as a Python number or a 0-dim tensor.

    Raises:
        TypeError: if ``offset`` is not a float, int or tensor.
        ValueError: if ``offset`` is a tensor that is not 0-dim.
    """

    def __init__(self, offset: float | int | Tensor):
        super().__init__()
        # Explicit checks rather than `assert`, which is stripped under `python -O`
        if not isinstance(offset, (float, int, torch.Tensor)):
            raise TypeError(
                f"NewGap offset must be a float, int or 0-dim tensor, got {type(offset).__name__}"
            )
        if isinstance(offset, torch.Tensor) and offset.dim() != 0:
            raise ValueError(
                f"NewGap offset tensor must be 0-dim, got shape {tuple(offset.shape)}"
            )

        # Gap is always stored as float64, but it's converted to the sampling
        # dtype when creating the corresponding transform in forward()
        self.offset = torch.as_tensor(offset, dtype=torch.float64)

    def forward(self, inputs, sampling):
        dim, dtype = sampling["dim"], sampling["dtype"]

        # (offset, 0, ..., 0) in the sampling dtype
        translate_vector = torch.cat(
            (self.offset.unsqueeze(0).to(dtype=dtype), torch.zeros(dim - 1, dtype=dtype))
        )

        return replace(inputs, transforms=inputs.transforms + [tlm.TranslateTransform(translate_vector)])


class SurfaceArtist:
    """Render a NewOpticalSurface as a "surfaces" scene group."""

    @staticmethod
    def render_element(element: NewOpticalSurface, inputs, _outputs):
        last = inputs.transforms[-1]
        dim, dtype = last.dim, last.dtype

        # Forward kinematic chain: the incoming chain followed by the element's
        # own surface transforms, composed in reverse order
        chain = [*inputs.transforms, *element.surface_transform(dim, dtype)]
        transform = tlm.ComposeTransform(chain[::-1])

        rendered = tlm.viewer.render_surface(element.surface, transform, dim=transform.dim, N=10)
        return {"type": "surfaces", "data": [rendered]}


class JointArtist:
    """Render the chain joint position at an element's input as a "points" group."""

    @staticmethod
    def render_element(element: object, inputs, _outputs):
        # Called for every executed module, not only surfaces, so `element` is
        # typed generically; it is not used here.
        last = inputs.transforms[-1]
        dim, dtype = last.dim, last.dtype

        # Build the local origin in the chain's dtype (as SurfaceArtist does)
        # instead of torch's default dtype, so it matches the transforms
        origin = torch.zeros((1, dim), dtype=dtype)
        transform = tlm.ComposeTransform(list(reversed(inputs.transforms)))
        points = transform.direct_points(origin)
        return {"type": "points",
                "data": points.tolist()}
    
# Element type -> artist class used by render_sequence for elements of that type.
# (JointArtist is applied to every element separately.)
artists_dict: dict[type, type] = {NewOpticalSurface: SurfaceArtist}

def inspect_stack(execute_list):
    """Debug helper: print the input and output transform chains of each executed module."""
    for module, inputs, outputs in execute_list:
        print(type(module))
        for label, data in (("inputs", inputs), ("outputs", outputs)):
            print(f"{label}.transform:")
            for transform in data.transforms:
                print(transform)
            print()

def render_sequence(optics, dim, dtype):
    """Run ``optics`` forward and build a viewer scene from the executed modules.

    Every executed module gets a joint marker from JointArtist; modules whose
    type appears in ``artists_dict`` additionally get that artist's group.
    """
    sampling = {"dim": dim, "dtype": dtype}
    execute_list, _top_output = tlm.full_forward(optics, default_input(dim, dtype), sampling)

    mode, camera = ("2D", "XY") if dim == 2 else ("3D", "orthographic")
    scene = {"data": [], "mode": mode, "camera": camera}
    groups = scene["data"]

    # For debugging, call: inspect_stack(execute_list)

    for module, inputs, outputs in execute_list:
        # chain joint position first, then any type-specific artists
        groups.append(JointArtist.render_element(module, inputs, outputs))
        groups.extend(
            artist.render_element(module, inputs, outputs)
            for typ, artist in artists_dict.items()
            if isinstance(module, typ)
        )

    return scene


def view(optics, dim, dtype):
    """Render ``optics`` in ``dim`` dimensions and display the scene in the viewer."""
    # To print the scene instead: tlm.viewer.pprint(scene, ndigits=2)
    tlm.viewer.show(render_sequence(optics, dim, dtype))


# Demo system: two parabolic surfaces, then the same spherical surface used
# twice (the second copy with scale=-1), separated by gaps.
# NOTE(review): the meaning of the Parabola/Sphere constructor arguments is
# defined in torchlensmaker — confirm there.
surface = tlm.surfaces.Parabola(15.0, 0.020)
surface2 = tlm.surfaces.Parabola(15.0, 0.030)

surface3 = tlm.surfaces.Sphere(12, -20)

# anchors=(entry, exit): "extent" attaches the chain at the surface's extent
# point instead of its origin; NewGap advances the chain along the first axis.
optics = tlm.OpticalSequence(
    NewOpticalSurface(surface, anchors=("origin", "extent")),
    NewGap(1.0),
    NewOpticalSurface(surface2, scale=-1, anchors=("extent", "origin")),
    NewGap(0.5),
    NewOpticalSurface(surface3, anchors=("extent", "extent")),
    NewGap(5.0),
    NewOpticalSurface(surface3, scale=-1, anchors=("extent", "extent")),
)

# Render the same system in 2D and in 3D, both sampled in float64.
view(optics, 2, torch.float64)
view(optics, 3, torch.float64)