In [None]:
import torch
import torch.nn as nn
import torchlensmaker as tlm

from typing import Any

# Sequential element that holds a surface with no optical behavior,
# used only to test positioning.
class TestSurface(tlm.SequentialElement):
    """Sequential element that wraps a surface purely to exercise positioning.

    The wrapped surface is called on the incoming data, but only the
    transform it returns last is propagated: ``forward`` replaces ``fk``
    with it and leaves every other field of ``data`` to ``data.replace``.
    """

    def __init__(self, surface):
        super().__init__()
        self.surface = surface

    def forward(self, data):
        # The surface call yields five values; only the final one (tf_next)
        # matters here. Keeping all five names preserves the arity check.
        _t, _normals, _valid, _tf_surface, next_fk = self.surface(
            data.P, data.V, data.fk
        )
        return data.replace(fk=next_fk)


class CollisionSurfaceArtist:
    """Render a surface module together with the joint its input transform defines.

    Meant to be registered in ``extra_artists`` for surface modules called as
    ``surface(P, V, tf_in)`` and returning
    ``(t, normals, valid, tf_surface, tf_next)``.
    """

    def render(self, collective, module):
        """Build render items for ``module``'s surface and its input joint.

        Args:
            collective: object exposing ``input_tree`` and ``output_tree``
                mappings indexed by module.
            module: the surface module to draw; its ``render()`` must return
                a dict (it is updated in place with a ``"matrix"`` key).

        Returns:
            A list: the surface render dict (with ``"matrix"`` set from the
            surface transform), followed by one ``"points"`` item marking
            where the input transform maps the origin.
        """
        # Only the input transform and the surface transform are needed; the
        # full-length destructuring still validates the tuple sizes.
        _, _, tf_in = collective.input_tree[module]
        _, _, _, tf_surface, _ = collective.output_tree[module]

        rendered_surface = module.render()
        rendered_surface["matrix"] = tf_surface.direct.tolist()

        # The joint is the image of the local origin under the input transform.
        # Match the transform's dtype/device instead of hard-coding float32 on
        # CPU so the product inside transform_points does not mix precisions
        # or devices. Dimension stays fixed at 2 (2D rendering only).
        direct_in = tf_in.direct
        dim = 2
        origin = torch.zeros((dim,), dtype=direct_in.dtype, device=direct_in.device)
        joint = tlm.transform_points(direct_in, origin)

        rendered_joints = [{"type": "points", "data": [joint.tolist()], "layers": [0]}]

        return [rendered_surface] + rendered_joints


source = tlm.PointSourceAtInfinity2D(8.0, sampler_pupil=tlm.LinspaceSampler1D(50))
data = source(tlm.default_input(2))

# surface2 reuses surface1.C, so both spheres are driven by the same
# curvature value; scale and anchors are flipped between the two.
surface1 = tlm.SphereByCurvature(
    10.0, C=0.05, anchors=(0.0, 1.0), scale=1, trainable=True
)
surface2 = tlm.SphereByCurvature(
    10.0, C=surface1.C, anchors=(1.0, 0.0), scale=-1
)

optics = tlm.Sequential(
    source,
    tlm.Gap(3),
    TestSurface(surface1),
    tlm.Gap(0.5),
    TestSurface(surface2),
    tlm.Gap(2),
    TestSurface(tlm.Disk(8.0)),
    tlm.Gap(0),
)

# TestSurface wrappers delegate drawing to their inner surface; the surfaces
# themselves are drawn by CollisionSurfaceArtist.
artists = {
    TestSurface: [tlm.ForwardArtist(lambda mod: mod.surface)],
    tlm.SphereByCurvature: [CollisionSurfaceArtist()],
    tlm.Disk: [CollisionSurfaceArtist()],
}
display_controls = {
    "show_optical_axis": True,
    "show_other_axes": True,
    "show_kinematic_joints": True,
}

scene = tlm.show2d(
    optics,
    extra_artists=artists,
    controls=display_controls,
    return_scene=True,
)

print(scene)