In [None]:
import torchlensmaker as tlm
import torch
import torch.optim as optim
import torch.nn as nn


# Open question: where should parameters be handled — automatically inside the shapes, or manually in the stack class?


class DevStack(nn.Module):
    """Development optical stack made of two refractive elements in series.

    Each element is a pair of parabolic surfaces with index 1.49 between them
    and 1.0 outside. The second surface of each element reuses ``shape1`` via
    ``tlm.Parabola.share(..., scale=-1.0)``. Only ``shape1``'s ``nn.Parameter``
    entries are registered on this module.
    """

    def __init__(self):
        super().__init__()
        self.shape1 = tlm.Parabola(20., init=[0.005])
        self.shape2 = tlm.Parabola.share(self.shape1, scale=-1.0)

        # Register shape1's trainable tensors on this module so that
        # ``self.parameters()`` exposes them to an optimizer.
        for pname, tensor in self.shape1.parameters().items():
            if not isinstance(tensor, nn.Parameter):
                continue
            self.register_parameter(pname, tensor)

    def forward(self, num_rays, hook=None):
        """Rebuild the optical stack from the current shapes and run it.

        Args:
            num_rays: number of rays, forwarded unchanged to the stack.
            hook: optional hook, forwarded unchanged to the stack.

        Returns:
            Whatever ``tlm.OpticalStack.forward`` returns.

        Side effect: the freshly built stack is stored on ``self.optics``
        (presumably so rendering code can inspect it — TODO confirm).
        """
        # Offset between the two surfaces of one element, and between elements.
        lens_gap = torch.tensor([0, 1])
        lens_spacing = torch.tensor([0, 10])

        # First element: front at the origin, back placed relative to the front's extent.
        front_a = tlm.Surface(self.shape1, pos=(0, 0), anchor="origin")
        back_a = tlm.Surface(self.shape2, pos=front_a.at("extent") + lens_gap, anchor="extent")

        # Second element: positioned relative to the first element's back surface.
        front_b = tlm.Surface(self.shape1, pos=back_a.at("origin") + lens_spacing, anchor="origin")
        back_b = tlm.Surface(self.shape2, pos=front_b.at("extent") + lens_gap, anchor="extent")

        elements = [
            tlm.ParallelBeamUniform(width=15., pos=(0, -10)),
            tlm.RefractiveSurface(front_a, (1.0, 1.49)),
            tlm.RefractiveSurface(back_a, (1.49, 1.0)),
            tlm.RefractiveSurface(front_b, (1.0, 1.49)),
            tlm.RefractiveSurface(back_b, (1.49, 1.0)),
            tlm.FocalPointLoss(pos=(0, 80)),
        ]
        self.optics = tlm.OpticalStack(elements)

        return self.optics.forward(num_rays, hook=hook)

optics = DevStack()

# Render the system before optimization (10 rays).
tlm.render_plt(optics, 10)

# Fit shape1's parameters with Adam, using 20 rays for 80 iterations.
tlm.optimize(
    optics,
    optimizer=optim.Adam(optics.parameters(), lr=5e-4),
    num_rays=20,
    num_iter=80,
)

# Render again to compare against the pre-optimization plot.
tlm.render_plt(optics, 10)

# Idea: mark which values should be optimized simply by wrapping them in a parameter:

# ReflectiveSurface(Parabola(tlm.Parameter(5.0)))
# vs
# ReflectiveSurface(Parabola(5.0), pos=(a, b))

# Idea: set up a linear stack with the same syntax (gap(), etc.),
# but preprocess the stack to remove the gaps and set up the surface origins/anchors.

