# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchlensmaker as tlm


class Optics(tlm.Module):
    """Optical system made of one parabolic reflective surface.

    A ``tlm.PointSourceAtInfinity`` beam (diameter 25) travels a 100-unit gap,
    hits a parabola of width 35, then travels a -45-unit gap before the
    ``tlm.FocalPointLoss`` stage. The parabola coefficient ``a`` is an
    ``nn.Parameter``, so it is the quantity an optimizer can adjust.
    """

    def __init__(self):
        super().__init__()

        # Trainable curvature coefficient, starting at -0.005.
        initial_a = torch.tensor(-0.005)
        self.shape = tlm.Parabola(width=35.0, a=nn.Parameter(initial_a))

        # Elements in the order the beam passes through them.
        # NOTE(review): the negative second gap presumably accounts for the
        # direction reversal after reflection — confirm against tlm.Gap docs.
        stages = [
            tlm.PointSourceAtInfinity(beam_diameter=25),
            tlm.Gap(100.0),
            tlm.ReflectiveSurface(self.shape),
            tlm.Gap(-45.0),
            tlm.FocalPointLoss(),
        ]
        self.optics = nn.Sequential(*stages)

    def forward(self, inputs):
        """Run ``inputs`` through every optical stage in sequence."""
        return self.optics(inputs)

optics = Optics()

# Show the system before optimization.
tlm.render_plt(optics)

# Tune the trainable parameters (the parabola coefficient) with Adam.
adam = optim.Adam(optics.parameters(), lr=1e-4)
tlm.optimize(optics, optimizer=adam, num_iter=60)

# Show the system again after optimization, for comparison.
tlm.render_plt(optics)