In [None]:
import torchlensmaker as tlm
import torch
import torch.optim as optim
import torch.nn as nn


class DevStack(tlm.Module):
    """Development optical stack for experimenting with focal-point optimization.

    A uniform parallel beam passes through two refractive surfaces that share
    the *same* ``tlm.Parabola`` object (``self.shape1``), then through a final
    gap into a ``tlm.FocalPointLoss``. Because the profile object is shared,
    its learnable ``a`` coefficient drives both surfaces at once.
    """

    def __init__(self):
        super().__init__()

        # The learnable quantity declared by this class: the parabola
        # coefficient ``a``, initialised to 0.005 on a profile of width 20.
        self.shape1 = tlm.Parabola(
            width=20.,
            a=nn.Parameter(torch.tensor(0.005)),
        )

        # NOTE(review): these pairs look like (index before, index after) the
        # surface; the second is the first reversed. Confirm against tlm.
        entry_indices = (1.0, 1.49)
        exit_indices = (1.49, 1.0)

        stages = [
            tlm.ParallelBeamUniform(width=15.),
            tlm.Gap(torch.tensor([0, 10])),
            # Front surface.
            tlm.RefractiveSurface(
                self.shape1, entry_indices, anchors=("origin", "extent")
            ),
            tlm.Gap(torch.tensor([0, 1])),
            # Back surface: same profile object with scale=-1. and swapped
            # anchors (presumably the mirror image of the front face).
            tlm.RefractiveSurface(
                self.shape1, exit_indices, scale=-1., anchors=("extent", "origin")
            ),
            tlm.Gap(torch.tensor([0, 80])),
            tlm.FocalPointLoss(),
        ]
        self.optics = nn.Sequential(*stages)

    def forward(self, inputs):
        """Feed ``inputs`` through every stage of ``self.optics`` in order."""
        pipeline = self.optics
        return pipeline(inputs)


optics = DevStack()


def make_inputs():
    """Return the inputs fed to both rendering and optimization.

    Defined once so the before-render, the optimization run and the
    after-render cannot drift apart. A new tensor is built on every call,
    matching the original script, which constructed a separate tensor per use.

    NOTE(review): the meaning of the leading ``10`` (presumably a ray/sample
    count) and of the 2-element tensor is defined by tlm; confirm there.
    """
    return (10, torch.tensor([0., 0.]))


print("Parameters")
for n, p in optics.named_parameters():
    print(n, p.detach().numpy())
print()

# Render the initial configuration before optimizing.
tlm.render_plt(optics, make_inputs())

tlm.optimize(
    optics,
    optimizer=optim.Adam(optics.parameters(), lr=1e-2),
    inputs=make_inputs(),
    num_iter=500,
)

# Render again after optimization, with identical inputs, for comparison.
tlm.render_plt(optics, make_inputs())


