In [None]:
import torch
import torch.nn as nn
import torchlensmaker as tlm
import torch.optim as optim

# Parabolic profile y = A*x^2 with diameter 15. The coefficient A starts at
# 0.015 and is wrapped in tlm.parameter (presumably so it can be optimized
# later, e.g. with torch.optim — confirm).
surface = tlm.Parabola(diameter=15, A=tlm.parameter(0.015))

# BK7 lens built from `surface` with outer_thickness=1.0.
# Assumption: BiLens uses the same profile on both faces — confirm.
lens = tlm.BiLens(surface, material="BK7", outer_thickness=1.0)

# System in propagation order: point source at infinity -> wavelength range
# (500, 800; units presumably nm) -> gap -> lens -> gap -> focal point.
# NOTE(review): beam_diameter=18.5 is wider than the lens diameter (15), so the
# outermost rays presumably miss the lens; confirm this is intentional.
# NOTE(review): the following cell assembles its system with tlm.Sequential
# instead of nn.Sequential; consider aligning if tlm tooling expects it.
optics = nn.Sequential(
    tlm.PointSourceAtInfinity(beam_diameter=18.5),
    tlm.Wavelength(500, 800),
    tlm.Gap(10),
    lens,
    tlm.Gap(50),
    tlm.FocalPoint(),
)

# Leave the module as the cell's last expression so Jupyter renders its repr.
optics

In [None]:
# Spherical surface (diameter 15, R 7.5), shared by both refracting faces below.
s1 = tlm.SphereR(diameter=15, R=7.5)

# Components are built one at a time, in the same order the original nested
# expression constructed them.
point_source = tlm.PointSource(beam_angular_size=20)
source_gap = tlm.Gap(15)

# First face: collide rays with s1, then refract with material "SF10-nd"
# ("clamp" mode; semantics defined by tlm). Placed with anchors origin -> extent.
front_surface = tlm.KinematicSurface(
    nn.Sequential(
        tlm.CollisionSurface(s1),
        tlm.RefractiveBoundary("SF10-nd", "clamp"),
    ),
    s1,
    anchors=("origin", "extent"),
)

# Second face: same surface with scale=-1 (presumably mirrored — confirm),
# refracting back to "air", with the anchor order reversed.
back_surface = tlm.KinematicSurface(
    nn.Sequential(
        tlm.CollisionSurface(s1),
        tlm.RefractiveBoundary("air", "clamp"),
    ),
    s1,
    scale=-1,
    anchors=("extent", "origin"),
)

optics = tlm.Sequential(point_source, source_gap, front_surface, back_surface)

# 2D rendering of the system; `end=20` is passed straight through to tlm.show.
tlm.show(optics, dim=2, end=20)

In [None]:
# Print the full system, a blank separator line, then one inner element.
# optics[2] is the first KinematicSurface of the system built above
# (index 0 = PointSource, 1 = Gap). `.element` is presumably the nn.Sequential
# passed as its first argument, so `[0]` would be its CollisionSurface — confirm.
print(optics)
print()
print(optics[2].element[0])

In [None]:
from typing import Any, Iterator
from dataclasses import dataclass

from torchlensmaker.core.full_forward import forward_tree

# Default 2D, float64 input for the forward trace, with sampling {"base": 5}.
trace_input = tlm.default_input(sampling={"base": 5}, dim=2, dtype=torch.float64)

# forward_tree returns a pair, unpacked here as `ins` / `outs`
# (presumably per-module inputs and outputs keyed by module — confirm).
ins, outs = forward_tree(optics, trace_input)

In [None]:
# Entry recorded for the top-level `optics` module in the forward trace
# (presumably the output of the whole system — confirm), left as the last
# expression so Jupyter renders it.
optics_output = outs[optics]
optics_output