# Triple Lens

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchlensmaker as tlm

In [None]:
    
# Width handed to the parabolic lens shape; the input beam is sized at 90% of it.
lens_width = float(15)

class Optics(tlm.Module):
    """Three identical lenses in a row, followed by a focal point.

    Every lens is built from the *same* two surface modules (and the same
    inner gap), and both surfaces wrap ``self.shape``. So the whole system
    has a single learnable parabola coefficient ``a``, which is shared by
    all six refracting faces.
    """

    def __init__(self):
        super().__init__()

        # The one learnable profile: a parabola with trainable coefficient a.
        self.shape = tlm.Parabola(lens_width, a=nn.Parameter(torch.tensor(-0.005)))

        # Entry face (index 1.0 -> 1.49) and its mirrored exit face (1.49 -> 1.0).
        front = tlm.RefractiveSurface(self.shape, (1.0, 1.49), anchors=("origin", "extent"))
        back = tlm.RefractiveSurface(self.shape, (1.49, 1.0), scale=-1, anchors=("extent", "origin"))

        # A single lens: front face, 5.0 of thickness, back face.
        # Reusing this tuple reuses the same module objects for every lens.
        single_lens = (front, tlm.Gap(5.0), back)

        elements = [
            tlm.PointSourceAtInfinity(beam_diameter=0.9 * lens_width),
            tlm.Gap(15.),
        ]
        for index in range(3):
            if index > 0:
                # Fresh spacer between consecutive lenses.
                elements.append(tlm.Gap(5.))
            elements.extend(single_lens)

        # Distance from the last lens surface to the focal point element.
        elements.append(tlm.Gap(80.))
        elements.append(tlm.FocalPoint())

        self.optics = tlm.OpticalSequence(*elements)

    def forward(self, inputs, sampling):
        # Pure delegation to the assembled optical sequence.
        return self.optics(inputs, sampling)


optics = Optics()

# Render the system before any optimization.
tlm.render_plt(optics)

# Adam over all parameters (deduplicated by nn.Module.parameters, so the
# shared parabola coefficient appears once), 10 rays per step, 100 steps.
adam = optim.Adam(optics.parameters(), lr=5e-4)
tlm.optimize(optics, optimizer=adam, sampling={"rays": 10}, num_iter=100)

# Render the system again after optimization.
tlm.render_plt(optics)