# Triple Lens

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class Surface:
    """Minimal surface holding a single coefficient ``p`` stored as float64.

    If ``p`` is an ``nn.Parameter`` it remains an ``nn.Parameter`` after the
    float64 conversion, so :meth:`parameters` keeps reporting it as trainable.
    Any other value (Python number, plain tensor) is stored as a plain float64
    tensor and is not reported as a parameter.
    """

    def __init__(self, p):
        if isinstance(p, nn.Parameter):
            if p.dtype == torch.float64:
                # Already the target dtype: keep the caller's exact object, as
                # torch.as_tensor would have done.
                self.p = p
            else:
                # torch.as_tensor(param, dtype=...) returns a brand-new plain
                # Tensor when a conversion is needed, silently dropping the
                # nn.Parameter wrapper. Re-wrap the converted data instead.
                # NOTE: this new Parameter is a separate leaf from the caller's
                # original `p`; optimize this one (e.g. via Foo's parameters).
                self.p = nn.Parameter(
                    p.detach().to(torch.float64),
                    requires_grad=p.requires_grad,
                )
        else:
            self.p = torch.as_tensor(p, dtype=torch.float64)

    def parameters(self):
        """Return ``{"p": self.p}`` if ``p`` is trainable, else an empty dict."""
        if isinstance(self.p, nn.Parameter):
            return {"p": self.p}
        return {}

class Foo(nn.Module):
    """Module that wraps a surface and adopts its trainable tensors.

    ``surface`` must provide a ``parameters()`` method returning a
    name -> tensor mapping, plus an attribute ``p`` used in ``forward``.
    """

    def __init__(self, surface):
        super().__init__()

        self.surface = surface

        # Register each tensor the surface reports, so that it shows up in
        # this module's parameters()/named_parameters().
        surface_params = surface.parameters()
        for param_name in surface_params:
            tensor = surface_params[param_name]
            print(type(tensor), tensor.dtype)
            self.register_parameter(param_name, tensor)

    def forward(self, x):
        """Return ``(x + 5)`` scaled by the surface's ``p``."""
        shifted = x + 5
        return shifted * self.surface.p

# Experiment: wrap a float32 trainable scalar in a Surface (whose constructor
# converts its input to float64) and check which parameters Foo registers.
p = nn.Parameter(torch.tensor(0.5, dtype=torch.float32))
s = Surface(p)

f = Foo(s)


# Print every parameter registered on f. Note the loop variable rebinds the
# module-level name `p` whenever at least one parameter is found.
for name, p in f.named_parameters():
    print(name, p)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchlensmaker as tlm

# Diameter passed to the parabolic surface; the source below is also sized
# relative to it.
lens_diameter = 15.0

# Single parabolic surface; its coefficient `a` is created via tlm.parameter,
# presumably so that it is trainable (TODO confirm against torchlensmaker docs).
surface = tlm.Parabola(lens_diameter, a=tlm.parameter(-0.005))
# Lens built from that surface. NOTE(review): (1.0, 1.5) looks like a pair of
# refractive indices (outside, inside) — confirm.
lens = tlm.BiLens(surface, (1.0, 1.5), outer_thickness=0.5)

# Optical stack: source, gap, three lens elements separated by gaps, a final
# gap, then the focal point target. The SAME `lens` object is inserted three
# times, so all three elements share one `surface` (and one `a` coefficient).
# The source gets 0.9 * lens_diameter — presumably the beam diameter; confirm.
optics = nn.Sequential(
    tlm.PointSourceAtInfinity(0.9*lens_diameter),
    tlm.Gap(15),
    
    lens,
    tlm.Gap(5),
    lens,
    tlm.Gap(5),
    lens,
    
    tlm.Gap(80),
    tlm.FocalPoint(),
)

# Render the initial (pre-optimization) system in 2D and 3D.
tlm.show(optics, mode="2D")
tlm.show(optics, mode="3D")

In [None]:
# Optimize every parameter of `optics` with Adam at a small learning rate,
# then call .plot() on the returned result. Rays are sampled in float64;
# NOTE(review): meaning of "dim" and "base" sampling keys — confirm in tlm docs.
tlm.optimize(
    optics,
    optimizer=optim.Adam(optics.parameters(), lr=3e-5),
    sampling={"dim": 2, "dtype": torch.float64, "base": 10},
    num_iter=100,
).plot()

# Report the fitted parabola coefficient and the resulting lens thicknesses.
print("Final parabola parameter:", surface.a.item())
print("Outer thickness:", lens.outer_thickness())
print("Inner thickness:", lens.inner_thickness())

In [None]:
# Re-render the system after the optimization cell has run, in 2D and 3D.
tlm.show(optics, mode="2D")
tlm.show(optics, mode="3D")