# Reflecting Telescope

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchlensmaker as tlm

# A simple reflecting telescope made of two concave mirrors

# In this example we keep the position of the mirrors constant
# and optimize their curvature jointly

# Note that there is more than one solution, because rays can cross multiple times before converging on the focal point.
# We want the solution where the rays cross at the focal point for the first time.
# TODO: use an image loss to account for image flips
# (an earlier crossing is also known as an "intermediate image").

class Optics(tlm.Module):
    """Two-mirror reflecting telescope with learnable mirror curvatures.

    The primary mirror is a ``tlm.Parabola`` whose coefficient ``a`` is a
    trainable ``nn.Parameter``. The secondary mirror is a
    ``tlm.CircularArc`` whose radius ``r`` is a trainable ``nn.Parameter``.
    All ``tlm.Gap`` values are fixed constants, so only the two curvature
    parameters are learnable.
    """

    def __init__(self):
        super().__init__()

        # Primary mirror: parabola y = a * x^2, with a small initial coefficient.
        primary_coeff = nn.Parameter(torch.tensor(-0.0001))
        self.shape_primary = tlm.Parabola(height=35., a=primary_coeff)

        # Secondary mirror: circular arc with a learnable radius.
        secondary_radius = nn.Parameter(torch.tensor(450.0))
        self.shape_secondary = tlm.CircularArc(height=35., r=secondary_radius)

        # NOTE(review): gaps are presumably signed offsets along the optical
        # axis (negative values after a reflection travel back) — confirm.
        elements = [
            tlm.Gap(-100),
            tlm.PointSourceAtInfinity(beam_diameter=30),
            tlm.Gap(100),
            tlm.ReflectiveSurface(self.shape_primary),
            tlm.Gap(-80),
            tlm.ReflectiveSurface(self.shape_secondary),
            tlm.Gap(100),
            tlm.FocalPoint(),
        ]
        self.optics = tlm.OpticalSequence(*elements)

    def forward(self, inputs, sampling):
        """Run ``inputs`` through the optical sequence with the given ``sampling``."""
        sequence = self.optics
        return sequence(inputs, sampling)

optics = Optics()

# Plot the system before optimization.
tlm.render_plt(optics)

# Jointly optimize the mirror curvatures (the only parameters in Optics)
# with Adam, sampling 10 rays, for 100 iterations.
tlm.optimize(
    optics,
    optimizer=optim.Adam(optics.parameters(), lr=2e-4),
    sampling={"rays": 10},
    num_iter=100,
)

# Plot the system again after optimization.
tlm.render_plt(optics)