# Side quest: a custom optical surface element and its renderer

# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchlensmaker as tlm


class NewOpticalSurface(tlm.Module):
    """Prototype optical element that carries a surface and its placement.

    The element performs no computation yet: ``forward`` hands its inputs
    back untouched. It exists so that a renderer can read ``surface`` and
    ``transform`` off the module and draw it.
    """

    def __init__(self, surface, transform):
        super().__init__()
        # Both attributes are only read by rendering code; forward() ignores them.
        self.surface, self.transform = surface, transform

    def forward(self, inputs, sampling):
        # Identity pass-through; `sampling` is accepted but not used.
        passthrough = inputs
        return passthrough


class SurfaceArtist:
    """Turn a NewOpticalSurface into a scene group understood by tlm.viewer."""

    @staticmethod
    def render_element(element: NewOpticalSurface, _inputs, _outputs):
        """Return a ``"surfaces"`` group holding one rendered surface.

        ``_inputs`` and ``_outputs`` are accepted because callers pass the
        element's execution inputs/outputs, but they are not used here.
        """
        rendered = tlm.viewer.render_surface(
            element.surface,
            element.transform,
            dim=2,
            N=10,
        )
        return dict(type="surfaces", data=[rendered])

# Element type -> artist class. Lookups use isinstance(), so subclasses of a
# registered type are rendered by the same artist.
artists_dict = {NewOpticalSurface: SurfaceArtist}

def render_sequence(optics, artists=None):
    """Run ``optics`` forward and build a 2D viewer scene from its elements.

    Every module in the execution trace is checked against the artist
    registry with ``isinstance``. Each matching artist contributes one group
    to the scene. Modules with no registered artist are skipped.

    Args:
        optics: the optical sequence, executed with ``tlm.default_input`` and
            an empty sampling dict.
        artists: optional mapping of element type -> artist class. Defaults
            to the module-level ``artists_dict``.

    Returns:
        A scene dict ``{"data": [...groups], "mode": "2D", "camera": "XY"}``
        that can be passed to ``tlm.viewer.show``.
    """
    if artists is None:
        artists = artists_dict

    # full_forward also returns the final outputs. Only the per-module trace
    # is needed here, so that value is discarded. (It was previously bound to
    # `outputs` and then shadowed by the loop variable of the same name.)
    execute_list, _ = tlm.full_forward(optics, tlm.default_input, {})

    scene = {"data": [], "mode": "2D", "camera": "XY"}

    for module, inputs, outputs in execute_list:
        for typ, artist in artists.items():
            if isinstance(module, typ):
                scene["data"].append(artist.render_element(module, inputs, outputs))

    return scene


def view(optics):
    """Render ``optics`` with the registered artists and display the scene.

    For debugging, the intermediate scene can be inspected with
    ``tlm.viewer.pprint(scene, ndigits=2)`` before showing it.
    """
    tlm.viewer.show(render_sequence(optics))


# Build one parabolic surface, give it a placement transform, wrap it in a
# NewOpticalSurface and display it through the custom artist path (2D, XY).
# NOTE(review): the positional meaning of Parabola(35.0, 0.010) and of
# basic_transform's arguments is not visible here — check tlm's signatures.
surface = tlm.surfaces.Parabola(35.0, 0.010)
transform = tlm.basic_transform(1.0, "extent", 15., [40., 12])(surface)

# Design note: surfaces could also carry a chain-local transform, which would
# be composed with (added to) the chain's accumulated transform.

optics = tlm.OpticalSequence(
    NewOpticalSurface(surface, transform),
)

view(optics)

# In [None]:
# Parabolic shape whose `a` coefficient is a trainable nn.Parameter.
# NOTE(review): exact meaning of `height` / `a` is defined by tlm.Parabola — confirm.
shape = tlm.Parabola(height=15., a=nn.Parameter(torch.tensor(0.005)))

# Sequence: point source at infinity -> gap -> lens -> gap -> focal point.
# NOTE(review): `lens` is not defined in this cell, and `shape` above is never
# used. Unless an earlier cell defines `lens`, this raises NameError. It looks
# like a lens built from `shape` was intended here — confirm and define it,
# otherwise the trainable `a` is not part of optics.parameters().
optics = tlm.OpticalSequence(
    tlm.PointSourceAtInfinity(beam_diameter=20),
    tlm.Gap(10.),
    lens,
    tlm.Gap(45.0),
    tlm.FocalPoint(),
)

# Plot the system before optimization.
tlm.render_plt(optics)

# Optimize every parameter of the sequence with Adam (lr=1e-3) for 100
# iterations; sampling={"rays": 10} presumably draws 10 rays per step —
# confirm against tlm.optimize.
tlm.optimize(
    optics,
    optimizer = optim.Adam(optics.parameters(), lr=1e-3),
    sampling = {"rays": 10},
    num_iter = 100
)

# Plot the system after optimization.
tlm.render_plt(optics)