# Biconvex spherical lens

In [None]:
import torch
import torchlensmaker as tlm
import torch.optim as optim

import math

def get_all_gradients(model):
    """Concatenate the gradients of all parameters of ``model`` into one 1-D tensor.

    Parameters whose ``.grad`` is ``None`` (e.g. not reached by a backward
    pass) are skipped. Gradients are flattened in ``model.parameters()`` order.

    Args:
        model: an object with a ``parameters()`` method, typically a
            ``torch.nn.Module``.

    Returns:
        A 1-D tensor holding every available gradient element, or an empty
        tensor if no parameter currently has a gradient.
    """
    # reshape (not view) so gradients with non-contiguous layouts also work.
    grads = [p.grad.reshape(-1) for p in model.parameters() if p.grad is not None]
    if not grads:
        # torch.cat raises on an empty list; return an empty tensor instead.
        return torch.empty(0)
    return torch.cat(grads)

def train(optics, dim, dtype, num_iter, nshow=20, lr=4e-3):
    """Optimize the parameters of ``optics`` with Adam and print progress.

    Each iteration evaluates ``optics`` on the default input built by
    ``tlm.default_input`` from the sampling settings, backpropagates
    ``outputs.loss`` and takes one optimizer step.

    Args:
        optics: model whose call result exposes a scalar ``loss`` attribute.
        dim: value of the ``"dim"`` entry of the sampling dict.
        dtype: torch dtype used for the ``"dtype"`` entry of the sampling dict.
        num_iter: number of optimization steps.
        nshow: approximate number of progress lines to print; values below 1
            are treated as 1. The final iteration is always reported.
        lr: Adam learning rate (default 4e-3, the previously hard-coded value).

    Raises:
        RuntimeError: if any parameter gradient contains NaN.
    """
    optimizer = optim.Adam(optics.parameters(), lr=lr)
    sampling = {"dim": dim, "dtype": dtype, "base": 10}
    default_input = tlm.default_input(sampling)

    # Clamp nshow so nshow <= 0 cannot raise ZeroDivisionError or yield a
    # negative period, and keep the period >= 1 so `i % show_every` is valid.
    show_every = max(1, math.ceil(num_iter / max(1, nshow)))

    for i in range(num_iter):
        optimizer.zero_grad()

        # Evaluate the model and backpropagate its loss.
        outputs = optics(default_input)
        loss = outputs.loss
        loss.backward()

        # Stop before stepping: an optimizer step with NaN gradients would
        # propagate NaN into the parameters.
        grad = get_all_gradients(optics)
        if torch.isnan(grad).any():
            print("ERROR: nan in grad", grad)
            raise RuntimeError("nan in gradient, check your torch.where() =)")

        optimizer.step()

        # Report periodically, and always on the last iteration so the final
        # state of the optimization is visible.
        if i % show_every == 0 or i == num_iter - 1:
            iter_str = f"[{i:>3}/{num_iter}]"
            L_str = f"L= {loss.item():>6.3f} | grad norm= {torch.linalg.norm(grad)}"
            print(f"{iter_str} {L_str}")




# Spherical surface used for both lens faces. Its radius `r` is given as
# tlm.parameter(20), presumably registering it as a trainable parameter with
# initial value 20. Because this same object is passed to both
# RefractiveSurface layers below, the two faces share one radius parameter.
surface = tlm.surfaces.Sphere(diameter=15, r=tlm.parameter(20))

# Optical system, applied in sequence: source, gap, front face, gap
# (presumably the lens thickness), back face, gap, focal point.
# NOTE(review): beam_diameter (18.5) is larger than the surface diameter (15);
# presumably intentional so some rays miss the lens — confirm.
optics = tlm.Sequential(
    tlm.PointSourceAtInfinity(beam_diameter=18.5),
    tlm.Gap(10),
    # Front face; n=(1.0, 1.5) is presumably (index before, index after).
    tlm.RefractiveSurface(surface, n=(1.0, 1.5), anchors=("origin", "extent")),
    tlm.Gap(2),
    # Back face: same surface with scale=-1 (presumably mirrored) and the
    # index pair reversed relative to the front face.
    tlm.RefractiveSurface(surface, n=(1.5, 1.0), scale=-1, anchors=("extent", "origin")),
    tlm.Gap(30),
    # Presumably the element that produces the `loss` read by train() — confirm.
    tlm.FocalPoint(),
)

# List the model's parameters (name and current value) before training.
for name, p in optics.named_parameters():
    print(name, p)

# Render the initial, untrained system in 2D and 3D.
tlm.show(optics, mode="2D", end=20)
tlm.show(optics, mode="3D", end=20)

In [None]:
# Run 50 optimization steps with dim=2 sampling in float64; progress is
# printed every ceil(50 / 20) = 3 iterations.
train(optics, 2, torch.float64, num_iter=50, nshow=20)

In [None]:
# Render the system again after training, now with end=50 (the pre-training
# render used end=20). `end` presumably controls how far rays are drawn past
# the last element — confirm against the tlm.show documentation.
tlm.show(optics, mode="2D", end=50)
tlm.show(optics, mode="3D", end=50)