# In[ ]:
import torchlensmaker as tlm
import torch
import torch.optim as optim
import torch.nn as nn


class DevStack(tlm.Module):
    """Optical stack with one trainable parabolic surface shared by a symmetric lens.

    The pipeline, run in order by ``forward``, is:
    ``ParallelBeamUniform(width=40) -> GapY(10) -> SymmetricLens -> GapY(80) -> FocalPointLoss``.

    Attributes:
        shape: ``tlm.Parabola`` whose ``a`` argument is an ``nn.Parameter``,
            so it is the value the optimizer updates.
        lens: ``tlm.SymmetricLens`` built from ``shape``. The ``(1.0, 1.5)``
            pair is presumably (outer, inner) refractive indices (TODO confirm
            against the tlm API).
        optics: the ``nn.Sequential`` pipeline that ``forward`` delegates to.
    """

    def __init__(self):
        super().__init__()

        # The only trainable value in this stack: Parabola's `a` argument
        # (presumably the quadratic coefficient; verify in tlm.Parabola).
        coeff_a = nn.Parameter(torch.tensor(0.005))
        self.shape = tlm.Parabola(width=40.0, a=coeff_a)

        # Both faces of the lens reuse the same shape object, so they share `a`.
        self.lens = tlm.SymmetricLens(self.shape, (1.0, 1.5), inner_thickness=8.0)

        stages = [
            tlm.ParallelBeamUniform(width=40.0),
            tlm.GapY(10),
            self.lens,
            tlm.GapY(80),
            tlm.FocalPointLoss(),
        ]
        self.optics = nn.Sequential(*stages)

    def forward(self, inputs):
        """Run ``inputs`` through the full optical pipeline and return its output."""
        return self.optics(inputs)


optics = DevStack()

# Single shared input tuple, used for both renders and for the optimization,
# so they can never drift apart when one is edited.
# NOTE(review): meaning of the leading 10 (presumably a ray/sample count) and
# of the 2D tensor (presumably a start point) is not visible here — confirm in tlm.
inputs = (10, torch.tensor([0.0, 0.0]))

print("Parameters")
for n, p in optics.named_parameters():
    # .cpu() so .numpy() also works if the module was moved to an accelerator.
    print(n, p.detach().cpu().numpy())
print()

# Render the system before optimization and report the lens thicknesses.
tlm.render_plt(optics, inputs)
print(optics.lens.inner_thickness(), optics.lens.outer_thickness())

tlm.optimize(
    optics,
    optimizer=optim.Adam(optics.parameters(), lr=1e-3),
    inputs=inputs,
    num_iter=50,
)

# Render again after optimization, on the same inputs, for comparison.
tlm.render_plt(optics, inputs)




# In[ ]:
# Convert the (optimized) lens with tlm.lens_to_part; the bare name on the
# last line keeps the result as the cell's output value.
lens_part = tlm.lens_to_part(optics.lens)
lens_part
