In [None]:
import torch
import torch.nn as nn
import torchlensmaker as tlm

from typing import Any

class TestSurface(tlm.SequentialElement):
    """Sequential element holding a surface that has no optical effect.

    Meant only for testing how surfaces are positioned: the wrapped surface
    is evaluated on the incoming rays, but only the kinematic transforms it
    returns are propagated. The ray fields of ``data`` are not replaced.
    """

    def __init__(self, surface):
        super().__init__()
        self.surface = surface

    def forward(self, data):
        """Evaluate the wrapped surface and forward only its kinematic transforms.

        Returns ``data.replace(dfk=..., ifk=...)``, using the last two entries
        of the surface's 7-tuple output.
        """
        surface_outputs = self.surface(data.P, data.V, data.dfk, data.ifk)
        # The surface yields exactly seven values. Only the transforms handed
        # to the next element in the chain are kept.
        (
            _t,
            _normals,
            _valid,
            _surface_dfk,
            _surface_ifk,
            next_dfk,
            next_ifk,
        ) = surface_outputs
        return data.replace(dfk=next_dfk, ifk=next_ifk)


class CollisionSurfaceArtist:
    """Artist drawing a spherical sag surface and the joint it hangs from.

    It reads ``module.diameter`` and ``module.C``, so it only suits modules
    exposing those attributes (such as ``tlm.SphereC``). It draws no rays.
    """

    def render_module(self, collective, module):
        """Describe ``module`` as a single spherical "surface-sag" item.

        The item is placed with the ``surface_dfk`` transform recorded in
        ``collective.output_tree`` for this module.
        """
        _inputs, outputs = collective.input_tree[module], collective.output_tree[module]

        # Recorded output layout:
        # (t, normals, valid, surface_dfk, surface_ifk, next_dfk, next_ifk)
        _t, _normals, _valid, surface_dfk, _surface_ifk, _next_dfk, _next_ifk = outputs

        diameter = module.diameter.item()
        curvature = module.C.item()
        placement = surface_dfk.tolist()

        sag_function = {"sag-type": "spherical", "C": curvature}
        surface_item = {
            "type": "surface-sag",
            "diameter": diameter,
            "sag-function": sag_function,
            "matrix": placement,
        }
        return [surface_item]

    def render_rays(self, collective, module) -> list[Any]:
        """Draw no rays for this module."""
        return []

    def render_joints(self, collective, module) -> list[Any]:
        """Return one point where the incoming ``dfk`` maps the local origin.

        NOTE(review): the origin is hard-coded as a 2-vector of float32, so this
        presumably only works for 2D scenes. Confirm before using it in 3D.
        """
        _P, _V, entry_dfk, _ifk = collective.input_tree[module]

        local_origin = torch.zeros((2,), dtype=torch.float32)
        joint_position = tlm.transform_points(entry_dfk, local_origin)

        return [{"type": "points", "data": [joint_position.tolist()], "layers": [0]}]


# 2D source at infinity whose pupil is sampled with 50 evenly spaced points.
source = tlm.PointSourceAtInfinity2D(
    8.0,
    sampler_pupil=tlm.LinspaceSampler1D(50),
)
# Evaluate the source once on the default 2D input.
# Note: `data` is not used further in this cell.
data = source(tlm.default_input(2))

# Two spherical surfaces separated by tlm.Gap(0.5). They are wrapped in
# TestSurface, so they only move the kinematic chain and leave the rays alone.
optics = tlm.Sequential(
    source,
    tlm.Gap(3),
    TestSurface(tlm.SphereC(10.0, C=0.15, anchors=(0.0, 1.0), scale=1)),
    tlm.Gap(0.5),
    TestSurface(tlm.SphereC(10.0, C=0.15, anchors=(1.0, 0.0), scale=-1)),
    tlm.Gap(0),
)

scene = tlm.show2d(
    optics,
    extra_artists={
        # ForwardArtist is handed the wrapped surface, presumably so each
        # TestSurface is drawn by the SphereC artist registered below.
        TestSurface: [tlm.ForwardArtist(lambda element: element.surface)],
        tlm.SphereC: [CollisionSurfaceArtist()],
    },
    controls={
        "show_optical_axis": True,
        "show_other_axes": True,
        "show_kinematic_joints": True,
    },
    return_scene=True,
)

print(scene)