# side quest

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchlensmaker as tlm

from dataclasses import dataclass, replace


# Short alias for torch.Tensor, used in type annotations below (e.g. NewGap).
Tensor = torch.Tensor

@dataclass
class NewOpticalData:
    """Data passed from one element of an optical sequence to the next.

    Attributes:
        transforms: ordered chain of transforms accumulated so far; elements
            downstream compose this chain to get their total transform.
    """

    transforms: list[tlm.TransformBase]


def default_input(dim: int, dtype: torch.dtype):
    """Build the initial optical data: a chain holding only an identity transform.

    Args:
        dim: dimension passed to ``tlm.IdentityTransform``.
        dtype: torch dtype passed to ``tlm.IdentityTransform``.
    """
    identity = tlm.IdentityTransform(dim, dtype)
    return NewOpticalData(transforms=[identity])


class NewOpticalSurface(tlm.Module):
    """Optical element wrapping a surface; currently a pass-through.

    The surface is kept on ``self.surface`` so it can be drawn by an artist;
    ``forward`` does not modify the transform chain yet.
    """

    def __init__(self, surface):
        super().__init__()
        self.surface = surface
        # TODO: support optional extra transform that isn't added to the chain
        # TODO: support scale
        # TODO: support anchor

    def forward(self, inputs, sampling):
        """Return ``inputs`` unchanged; ``sampling`` is currently unused."""
        return inputs


class NewGap(tlm.Module):
    """Element that shifts the transform chain along the first axis.

    ``forward`` appends a ``tlm.TranslateTransform`` whose vector is
    ``(offset, 0, ..., 0)`` to the chain carried by its inputs.
    """

    def __init__(self, offset: float | int | Tensor):
        """
        Args:
            offset: gap length along the first axis; a Python number or a
                0-dim tensor.

        Raises:
            TypeError: if ``offset`` is not a float, int or torch.Tensor.
            ValueError: if ``offset`` is a tensor that is not 0-dimensional.
        """
        super().__init__()
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if not isinstance(offset, (float, int, torch.Tensor)):
            raise TypeError(
                "NewGap offset must be a float, int or 0-dim Tensor, "
                f"got {type(offset).__name__}"
            )
        if isinstance(offset, torch.Tensor) and offset.dim() != 0:
            raise ValueError(
                "NewGap offset tensor must be 0-dimensional, "
                f"got shape {tuple(offset.shape)}"
            )

        # Gap is always stored as float64, but it's converted to the sampling
        # dtype when creating the corresponding transform in forward()
        self.offset = torch.as_tensor(offset, dtype=torch.float64)

    def forward(self, inputs, sampling):
        """Return a copy of ``inputs`` with a translation appended to its chain.

        Args:
            inputs: dataclass instance with a ``transforms`` list
                (e.g. NewOpticalData).
            sampling: dict providing ``"dim"`` and ``"dtype"``.
        """
        dim, dtype = sampling["dim"], sampling["dtype"]

        offset = self.offset.to(dtype=dtype)
        # Translate along the first axis only. The zero padding is created on
        # the offset's device so torch.cat never mixes devices.
        padding = torch.zeros(dim - 1, dtype=dtype, device=offset.device)
        translate_vector = torch.cat((offset.unsqueeze(0), padding))

        # replace() builds a new dataclass instance and `+` builds a new list,
        # so the caller's transform chain is not mutated.
        return replace(
            inputs,
            transforms=inputs.transforms + [tlm.TranslateTransform(translate_vector)],
        )


class SurfaceArtist:
    """Turns a NewOpticalSurface into a "surfaces" group for the viewer scene."""

    @staticmethod
    def render_element(element: NewOpticalSurface, inputs, _outputs):
        """Render ``element.surface`` placed by the chain in ``inputs``.

        Args:
            element: the surface element to draw.
            inputs: data entering the element; its ``transforms`` chain
                positions the surface.
            _outputs: unused.

        Returns:
            Dict with ``"type": "surfaces"`` and a one-item ``"data"`` list.
        """
        # TODO: add anchor transform from the surface

        # The total transform up to this element is the composition of every
        # transform accumulated in the chain so far.
        total = tlm.ComposeTransform(inputs.transforms)
        rendered = tlm.viewer.render_surface(
            element.surface, total, dim=total.dim, N=10
        )
        return dict(type="surfaces", data=[rendered])

# Registry: element type -> artist class whose render_element() draws it.
# Matching is done with isinstance, so subclasses of a key use the same artist.
artists_dict: dict[type, type] = {
    NewOpticalSurface: SurfaceArtist,  # draws the wrapped surface
}

def inspect_stack(execute_list):
    """Print the input and output transform chains of each executed element.

    Args:
        execute_list: iterable of ``(module, inputs, outputs)`` triples where
            ``inputs`` and ``outputs`` expose a ``transforms`` sequence.
    """
    for element, before, after in execute_list:
        print(type(element))
        sections = (("inputs.transform:", before), ("outputs.transform:", after))
        for heading, data in sections:
            print(heading)
            for transform in data.transforms:
                print(transform)
            print()

def render_sequence(optics, dim, dtype, *, debug: bool = False):
    """Run ``optics`` forward and build a viewer scene from its elements.

    Args:
        optics: module evaluated with ``tlm.full_forward``.
        dim: 2 produces a "2D" scene with the "XY" camera; any other value
            produces a "3D" scene with the "orthographic" camera.
        dtype: torch dtype used for sampling and the initial identity transform.
        debug: when True, print each element's input/output transform chains
            (via ``inspect_stack``) before rendering.

    Returns:
        Scene dict with keys "data", "mode" and "camera". "data" holds one
        group per (executed element, matching artist) pair, in execution order.
    """
    sampling = {"dim": dim, "dtype": dtype}
    # Only the per-element execution trace is needed; the final output is unused.
    execute_list, _ = tlm.full_forward(optics, default_input(dim, dtype), sampling)

    if dim == 2:
        scene = {"data": [], "mode": "2D", "camera": "XY"}
    else:
        scene = {"data": [], "mode": "3D", "camera": "orthographic"}

    if debug:
        inspect_stack(execute_list)

    for module, inputs, outputs in execute_list:
        # Every registered type the module is an instance of contributes a group.
        for typ, artist in artists_dict.items():
            if isinstance(module, typ):
                scene["data"].append(artist.render_element(module, inputs, outputs))

    return scene


def view(optics, dim, dtype, *, print_scene: bool = False):
    """Render ``optics`` into a scene and display it with ``tlm.viewer.show``.

    Args:
        optics: module evaluated to build the scene.
        dim: scene dimension, forwarded to ``render_sequence``.
        dtype: torch dtype, forwarded to ``render_sequence``.
        print_scene: when True, pretty-print the scene with
            ``tlm.viewer.pprint`` (2 digits) before showing it.
    """
    scene = render_sequence(optics, dim, dtype)
    if print_scene:
        tlm.viewer.pprint(scene, ndigits=2)
    tlm.viewer.show(scene)


# Four elements sharing the same parabolic surface, each separated from the
# previous one by a 5.0 gap, shown as a 3D scene sampled in float64.
surface = tlm.surfaces.Parabola(35.0, 0.010)

optics = tlm.OpticalSequence(
    NewOpticalSurface(surface),
    *(
        element
        for _ in range(3)
        for element in (NewGap(5.0), NewOpticalSurface(surface))
    ),
)

view(optics, dim=3, dtype=torch.float64)