In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchlensmaker as tlm

# A simple reflecting telescope made of two concave mirrors

# In this example the mirror positions are kept fixed
# and their curvatures are optimized jointly

# Note that there is more than one solution, because the rays can cross several times before converging on the focal point.
# We want the solution in which the rays first cross at the focal point.
# TODO use image loss to account for flips

class Optics(tlm.Module):
    """Two-mirror reflecting telescope with learnable mirror curvatures.

    The optical stack applied in ``forward`` is, in order:
    a Y offset of -100, a uniform parallel beam of width 30, a Y offset of +100,
    the reflective parabolic primary, a Y offset of -80, the reflective
    circular-arc secondary, a Y offset of +100, and a focal-point loss.

    The ``nn.Parameter``s created here are the parabola coefficient ``a`` of
    the primary and the arc radius ``r`` of the secondary. The Y offsets are
    plain constants, so the mirror positions stay fixed during optimization.
    """

    def __init__(self):
        super().__init__()

        # Primary mirror: parabola with a learnable coefficient (initial guess).
        primary_a = nn.Parameter(torch.tensor(-0.0001))
        self.shape_primary = tlm.Parabola(width=35., a=primary_a)

        # Secondary mirror: circular arc with a learnable radius (initial guess).
        secondary_r = nn.Parameter(torch.tensor(450.0))
        self.shape_secondary = tlm.CircularArc(width=35., r=secondary_r)

        # Light source: offset, emit the beam, then offset back.
        source = [
            tlm.GapY(-100),
            tlm.ParallelBeamUniform(width=30.),
            tlm.GapY(100),
        ]
        # The two mirrors, separated by a fixed gap.
        mirrors = [
            tlm.ReflectiveSurface(self.shape_primary),
            tlm.GapY(-80),
            tlm.ReflectiveSurface(self.shape_secondary),
        ]
        # Where the loss is evaluated, measured from the secondary.
        detector = [
            tlm.GapY(100),
            tlm.FocalPointLoss(),
        ]
        self.optics = nn.Sequential(*source, *mirrors, *detector)

    def forward(self, inputs):
        """Propagate ``inputs`` through the optical stack and return the result."""
        result = self.optics(inputs)
        return result

optics = Optics()

# One shared definition of the model inputs, used by both renders and the
# optimizer so they cannot drift apart. NOTE(review): presumably
# (number of rays, 2D start position) — confirm against tlm's input convention.
inputs = (10, torch.tensor([0., 0.]))

# Render the initial (un-optimized) system.
tlm.render_plt(optics, inputs)

# Jointly optimize the mirror curvatures (the only nn.Parameters in Optics).
tlm.optimize(
    optics,
    optimizer=optim.Adam(optics.parameters(), lr=2e-4),
    inputs=inputs,
    num_iter=100,
)

# Render the system again after optimization, for comparison.
tlm.render_plt(optics, inputs)