In [None]:
import torchlensmaker as tlm
import torch
import torch.optim as optim
import torch.nn as nn


class SymmetricLens(tlm.Module):
    """
    Lens built from one shape used for two mirrored refractive surfaces.

    The same ``shape`` object is given to both surfaces. The second surface
    is created with ``scale=-1.``, the refractive index pair reversed and the
    anchor pair reversed, with a ``tlm.GapY`` between the two surfaces.

    Args:
        shape: surface shape shared by both refractive surfaces.
        n: pair of refractive indices handed to the first surface; the
            second surface receives the same pair in reverse order.
        inner_thickness: gap value used with anchors ("origin", "extent")
            on the first surface.
        outer_thickness: gap value used with anchors ("origin", "origin")
            on the first surface.

    Raises:
        ValueError: if both thicknesses, or neither of them, are provided.
    """

    def __init__(self, shape, n, inner_thickness=None, outer_thickness=None):
        super().__init__()
        self.shape = shape

        has_inner = inner_thickness is not None
        has_outer = outer_thickness is not None

        # Both given or both missing: the thickness is ambiguous or undefined.
        if has_inner == has_outer:
            raise ValueError("Exactly one of inner/outer thickness must be given")

        if has_inner:
            front_anchors, gap_size = ("origin", "extent"), inner_thickness
        else:
            front_anchors, gap_size = ("origin", "origin"), outer_thickness

        # Stages are created in the order the rays traverse them.
        front = tlm.RefractiveSurface(self.shape, n, anchors=front_anchors)
        gap = tlm.GapY(gap_size)
        back = tlm.RefractiveSurface(
            self.shape,
            tuple(reversed(n)),
            scale=-1.,
            anchors=front_anchors[::-1],
        )

        self.optics = nn.Sequential(front, gap, back)

    def forward(self, inputs):
        """Pass ``inputs`` through first surface, gap and second surface."""
        outputs = self.optics(inputs)
        return outputs



class DevStack(tlm.Module):
    """
    Development optical stack: beam, gap, symmetric lens, gap, focal loss.

    The only ``nn.Parameter`` created here is the ``a`` argument of the
    parabola shape (initial value 0.005); that shape is shared by both
    surfaces of the lens.
    """

    def __init__(self):
        super().__init__()

        # Optimizable coefficient handed to the parabola shape.
        coefficient = nn.Parameter(torch.tensor(0.005))
        self.shape = tlm.Parabola(width=20., a=coefficient)

        stages = [
            tlm.ParallelBeamUniform(width=15.),
            tlm.GapY(10),
            SymmetricLens(self.shape, (1.0, 1.5), inner_thickness=10.0),
            tlm.GapY(80),
            tlm.FocalPointLoss(),
        ]
        self.optics = nn.Sequential(*stages)

    def forward(self, inputs):
        """Run ``inputs`` through every stage of the stack in order."""
        outputs = self.optics(inputs)
        return outputs


def _initial_inputs():
    """Build a fresh input tuple for each rendering / optimization call."""
    return (10, torch.tensor([0., 0.]))


# Instantiate the stack under development.
optics = DevStack()

# List every registered parameter together with its current value.
print("Parameters")
for name, param in optics.named_parameters():
    print(name, param.detach().numpy())
print()

# Render the stack before optimizing.
tlm.render_plt(optics, _initial_inputs())

# The optimizer is created only after the first rendering, as before.
adam = optim.Adam(optics.parameters(), lr=1e-3)
tlm.optimize(
    optics,
    optimizer=adam,
    inputs=_initial_inputs(),
    num_iter=500,
)

# Render again with the parameters left by the optimization run.
tlm.render_plt(optics, _initial_inputs())




In [None]:
# Wrap three stages of the optimized stack in a tlm.Lens, then convert it
# with tlm.lens_to_part.
# NOTE(review): per DevStack.__init__, optics.optics[2] is the SymmetricLens,
# optics.optics[3] is GapY(80) and optics.optics[4] is FocalPointLoss. If
# tlm.Lens expects (surface, gap, surface), the intended arguments are
# presumably the three entries of optics.optics[2].optics -- TODO confirm
# against the tlm.Lens signature.
lens = tlm.Lens(optics.optics[2], optics.optics[3], optics.optics[4])

part = tlm.lens_to_part(lens)
# Bare expression as the last statement of the cell; it has no effect when
# this file is run as a plain script.
part