# Variable number of lenses

# Multiple Lenses Sequence

A more complex example of a system with multiple lenses in sequence. The number of lenses of each type is an easy-to-change parameter.

In [None]:
import math
import torch
import torch.nn as nn
import numpy as np

import torch.optim as optim
import torchlensmaker as tlm


class PulaskiStack(tlm.Module):
    """Optical stack with a configurable number of lenses of three types.

    The sequence built in ``__init__`` is, in order: a point source at
    infinity, a fixed gap of 10, ``nplano`` x (gap, ``self.plano``),
    ``nbiconvex`` x (gap, ``self.biconvex``), ``nrplano`` x (gap,
    ``self.rplano``), a gap of ``focal_length`` and a focal point.

    Each repetition of a lens type reuses the *same* lens object (the
    per-type lists are built by list multiplication), and ``self.plano`` and
    ``self.rplano`` are both built on ``self.shape_convex``.

    Args:
        lens_radius: First argument of both ``tlm.Parabola`` shapes; the
            source beam diameter is ``0.9 * lens_radius``.
        focal_length: Length of the last gap, just before the focal point.
        lens_min_thickness: Passed as ``outer_thickness`` to every lens.
        nplano: Number of (gap, ``self.plano``) repetitions.
        nbiconvex: Number of (gap, ``self.biconvex``) repetitions.
        nrplano: Number of (gap, ``self.rplano``) repetitions.
        lens_spacing: Length of the gap placed before every lens. When
            ``None`` (default) the module-level ``lens_spacing`` variable is
            used, which is what this class read implicitly before the
            parameter existed.

    Raises:
        NameError: If ``lens_spacing`` is ``None`` and no module-level
            ``lens_spacing`` is defined.
    """

    def __init__(
        self,
        lens_radius,
        focal_length,
        lens_min_thickness,
        nplano,
        nbiconvex,
        nrplano,
        lens_spacing=None,
    ):
        super().__init__()

        if lens_spacing is None:
            # Backward compatibility: existing callers rely on the
            # module-level setting, so resolve it explicitly here.
            try:
                lens_spacing = globals()["lens_spacing"]
            except KeyError:
                raise NameError("name 'lens_spacing' is not defined") from None

        # Each shape wraps an nn.Parameter, so it shows up in
        # self.parameters(). NOTE(review): the second Parabola argument is
        # presumably the parabola coefficient -- confirm against tlm.
        self.shape_convex = tlm.Parabola(lens_radius, nn.Parameter(torch.tensor(0.002)))
        self.shape_biconvex = tlm.Parabola(lens_radius, nn.Parameter(torch.tensor(0.005)))

        # NOTE(review): n=(1.0, 1.5) is presumably a pair of refractive
        # indices -- confirm against tlm.
        self.plano = tlm.PlanoLens(
            self.shape_convex,
            n=(1.0, 1.5),
            outer_thickness=lens_min_thickness,
            reverse=True,
        )

        self.biconvex = tlm.SymmetricLens(
            self.shape_biconvex,
            n=(1.0, 1.5),
            outer_thickness=lens_min_thickness,
        )

        # Same shape object as self.plano, opposite `reverse` flag.
        self.rplano = tlm.PlanoLens(
            self.shape_convex,
            n=(1.0, 1.5),
            outer_thickness=lens_min_thickness,
            reverse=False,
        )

        # List multiplication repeats references: every repetition of a
        # stage holds the same Gap and the same lens object.
        plano_stage = [tlm.Gap(lens_spacing), self.plano] * nplano
        biconvex_stage = [tlm.Gap(lens_spacing), self.biconvex] * nbiconvex
        rplano_stage = [tlm.Gap(lens_spacing), self.rplano] * nrplano

        self.optics = tlm.OpticalSequence(
            tlm.PointSourceAtInfinity(beam_diameter=0.9 * lens_radius),
            tlm.Gap(10.),
            *plano_stage,
            *biconvex_stage,
            *rplano_stage,
            tlm.Gap(focal_length),
            tlm.FocalPoint(),
        )

    def forward(self, inputs, sampling):
        """Delegate to the wrapped optical sequence and return its result."""
        return self.optics(inputs, sampling)
        

# Show 10 decimal digits when printing tensors.
torch.set_printoptions(precision=10)


# --- Design settings (module-level, read by the code below) ---

# Number of lenses of each type: plano-convex, bi-convex, reverse plano-convex.
nplano, nbiconvex, nrplano = 1, 1, 0

# Side of the square; the lens radius is half of that square's diagonal.
square_size = 30
lens_radius = math.sqrt(2) * square_size / 2

focal_length = 45.0
lens_spacing = 3.0
lens_min_thickness = 1.2


def regu_equalparam(optics):
    """Penalize differences between consecutive parameters of ``optics``.

    All parameters are flattened and concatenated in ``optics.parameters()``
    order, their absolute values are scaled by 1000, and the squared
    differences between neighbours are summed.

    The square is applied to each difference *before* summing: summing the
    raw differences first would telescope to ``last - first`` and ignore
    every parameter in between. For exactly two parameters both forms give
    the same value.

    Args:
        optics: Object exposing ``parameters()`` yielding tensors
            (e.g. a ``torch.nn.Module``).

    Returns:
        A 0-dim tensor; zero when there is a single scalar parameter.
    """
    # reshape (rather than view) also accepts non-contiguous parameters.
    flat = torch.cat([param.reshape(-1) for param in optics.parameters()])
    steps = torch.diff(1000 * torch.abs(flat))
    return torch.pow(steps, 2).sum()
    
def regu_equalthickness(optics):
    """Return 100 times the squared inner-thickness mismatch of two lenses.

    Compares ``optics.plano.inner_thickness()`` with
    ``optics.biconvex.inner_thickness()``; the result is zero when both
    values are equal.
    """
    mismatch = optics.plano.inner_thickness() - optics.biconvex.inner_thickness()
    penalty = torch.pow(mismatch, 2)
    return 100 * penalty

def demo_pulaski():
    """Build, render, optimize and summarize a ``PulaskiStack``.

    Reads the module-level design settings (``square_size``, ``lens_radius``,
    ``focal_length``, ``lens_min_thickness``, ``lens_spacing`` and the three
    lens counts), prints them, renders the stack, runs ``tlm.optimize`` with
    an Adam optimizer and ``regu_equalthickness`` as regularization, renders
    again, and prints a thickness line for every lens type whose count is
    positive, followed by the parameter list.

    Returns:
        The ``PulaskiStack`` instance that was passed to ``tlm.optimize``.
    """
    print("Lens design")
    print("Square size", square_size)
    print("Lens radius", lens_radius)
    print("Configuration", nplano, nbiconvex, nrplano)
    print("lens_min_thickness", lens_min_thickness)
    print("lens_spacing", lens_spacing)

    optics = PulaskiStack(lens_radius, focal_length, lens_min_thickness, nplano, nbiconvex, nrplano)

    # Render before optimization.
    tlm.render_plt(optics, force_uniform_source=False)

    print(list(optics.named_parameters()))

    tlm.optimize(
        optics,
        optimizer=optim.Adam(optics.parameters(), lr=5e-4),
        sampling={"rays": 10},
        num_iter=250,
        regularization=regu_equalthickness,
    )

    # Render again after optimization.
    tlm.render_plt(optics, force_uniform_source=False)

    def thickness_profile(lens):
        """Return (inner thickness, placeholder, outer thickness) of ``lens``."""
        # NOTE: the middle entry is a 0.0 placeholder; the thickness at half
        # the square size is not computed yet.
        return lens.inner_thickness(), 0.0, lens.outer_thickness()

    # (count, printed label, attribute on the stack) for each lens type;
    # only types that are actually present in the stack are reported.
    lens_summary = (
        (nplano, "Plano-convex", "plano"),
        (nbiconvex, "Bi-convex", "biconvex"),
        (nrplano, "Reverse plano-convex", "rplano"),
    )
    for count, label, attr in lens_summary:
        if count > 0:
            profile = thickness_profile(getattr(optics, attr))
            print("{} thickness {:.4f} {:.4f} {:.4f}".format(label, *profile))
    print()
    print(list(optics.parameters()))

    return optics

# Run the demo and keep the returned stack; the display calls below use it.
optics = demo_pulaski()

In [None]:
from IPython.display import display

# Show the result of tlm.lens_to_part for the plano and the bi-convex lens.
for lens in (optics.plano, optics.biconvex):
    display(tlm.lens_to_part(lens))