# Biconvex lens (Bézier spline)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchlensmaker as tlm

In [None]:
# Lens radius shared by the model below: the spline height is twice this value,
# the source beam diameter is 0.99 times that height, and the initial CY
# coefficients are expressed as multiples of it.
lens_radius = 15.0

class Optics(tlm.Module):
    """Point source at infinity -> gap -> symmetric spline lens -> gap -> focal point.

    The lens profile is a ``tlm.BezierSpline`` whose ``X``, ``CX`` and ``CY``
    coefficients are ``nn.Parameter`` tensors. The model exposes the profile
    as ``self.shape``, the lens as ``self.lens`` and the full element chain
    as ``self.optics``.
    """

    def __init__(self):
        super().__init__()

        # Spline coefficients, wrapped as nn.Parameter tensors.
        # NOTE(review): per the TODO below, X looks like the knot position and
        # CX/CY the control point coordinates - confirm against BezierSpline.
        knot_x = nn.Parameter(torch.tensor([3.0]))
        control_x = nn.Parameter(torch.tensor([4.8]))
        control_y = nn.Parameter(
            torch.tensor([0.2 * lens_radius, 1.2 * lens_radius])
        )

        self.shape = tlm.BezierSpline(
            height=lens_radius * 2,
            X=knot_x,
            CX=control_x,
            CY=control_y,
        )

        # NOTE(review): (1.0, 1.5) is presumably a pair of refractive indices -
        # confirm against the SymmetricLens signature.
        self.lens = tlm.SymmetricLens(
            self.shape,
            (1.0, 1.5),
            outer_thickness=2.0,
        )

        # TODO enforce CX > X, i.e. control point within knots

        # Elements in the order they are handed to the optical sequence.
        elements = (
            tlm.PointSourceAtInfinity(beam_diameter=0.99 * lens_radius * 2),
            tlm.Gap(15.0),
            self.lens,
            tlm.Gap(50.0),
            tlm.FocalPoint(),
        )
        self.optics = tlm.OpticalSequence(*elements)

    def forward(self, inputs, sampling):
        """Pass ``inputs`` and ``sampling`` to the optical sequence and return its result."""
        return self.optics(inputs, sampling)

# Instantiate the model defined above.
optics = Optics()

# Render the system with its initial parameter values.
tlm.render_plt(optics)

# Optimize the model's parameters with Adam (learning rate 5e-2),
# using a sampling of 10 rays and num_iter=150.
tlm.optimize(
    optics,
    optimizer=optim.Adam(optics.parameters(), lr=5e-2),
    sampling={"rays": 10},
    num_iter=150,
)

# Render again after the optimize call to compare with the first render.
tlm.render_plt(optics)

In [None]:
# Print whatever the lens shape reports as its parameters (the shape was built
# from the X, CX and CY nn.Parameter tensors in Optics.__init__).
# NOTE(review): if parameters() returns a generator rather than a dict/list,
# this prints only the generator object - confirm and wrap in list() if so.
print(optics.shape.parameters())