In [None]:
import torchlensmaker as tlm
import torch
import torch.optim as optim
import torch.nn as nn


class DevStack(tlm.Module):
    """Single-lens optical stack with one optimizable surface coefficient.

    The sequence built here is, in order: a point source at infinity
    (beam diameter 40), a gap of 10, a symmetric lens, a gap of 80, and a
    focal point. The only ``nn.Parameter`` created in this class is the
    parabola coefficient ``a``.
    """

    def __init__(self):
        super().__init__()

        # Optimizable: the parabola coefficient is wrapped in nn.Parameter
        # (initial value 0.005) so an optimizer can update it.
        coefficient = nn.Parameter(torch.tensor(0.005))
        self.shape = tlm.Parabola(height=40.0, a=coefficient)

        # The lens is built from the shape above.
        # NOTE(review): (1.0, 1.5) are presumably the outside/inside
        # refractive indices -- confirm against the SymmetricLens docs.
        self.lens = tlm.SymmetricLens(
            self.shape,
            (1.0, 1.5),
            inner_thickness=8.0,
        )

        # Elements are listed in the order they are passed to the sequence.
        elements = [
            tlm.PointSourceAtInfinity(beam_diameter=40),
            tlm.Gap(10),
            self.lens,
            tlm.Gap(80),
            tlm.FocalPoint(),
        ]
        self.optics = tlm.OpticalSequence(*elements)

    def forward(self, inputs, sampling):
        """Pass ``inputs`` and ``sampling`` to the optical sequence and return its result."""
        result = self.optics(inputs, sampling)
        return result


# Build the optical stack defined above.
optics = DevStack()

# List every named parameter of the model together with its current value.
print("Parameters")
for name, param in optics.named_parameters():
    print(name, param.detach().numpy())
print()

# Render the system before optimization, then print the values returned by
# the lens's inner_thickness() and outer_thickness().
tlm.render_plt(optics)
print(optics.lens.inner_thickness(), optics.lens.outer_thickness())

# Run 50 optimization iterations with Adam (lr=1e-3), passing {"rays": 10}
# as the sampling configuration.
adam = optim.Adam(optics.parameters(), lr=1e-3)
tlm.optimize(
    optics,
    optimizer=adam,
    sampling={"rays": 10},
    num_iter=50,
)

# Render again to show the system after optimization.
tlm.render_plt(optics)




In [None]:
# Pass the lens of the `optics` model to tlm.lens_to_part.
# NOTE(review): presumably this builds a 3D/manufacturable part from the lens,
# and the bare expression is kept so the notebook displays the returned
# value -- confirm against the torchlensmaker documentation.
tlm.lens_to_part(optics.lens)
