In [None]:
import math
import torch
import torch.optim as optim
import torchlensmaker as tlm
from xxchallenge import *

# Non-imaging rod
# Implicit cylinder spanning x in [-25, 25] (length 50, symmetric about its
# origin) with tau = 37.02 / 2 (presumably the rod radius — TODO confirm).
# The upper bound used to read `50.2`, a typo for `50/2`: it made the rod
# asymmetric and disagreed with the half-length of 25 used for rod_max_z.
xmin, xmax, tau = torch.tensor([-50/2, 50/2, 37.02/2], dtype=torch.float64).unbind()
cylinder = tlm.ImplicitCylinder(xmin, xmax, tau)

# Primary mirror
# 5x5 coefficient matrix, only referenced by the commented-out XYPolynomial term below.
C = tlm.parameter(torch.zeros((5, 5), dtype=torch.float64))
fixed_mask = torch.zeros_like(C, dtype=torch.bool)
fixed_mask[0, 0] = True  # Freeze constant term to 0
# Zero the gradient wherever fixed_mask is True, so gradient-based updates
# leave the [0, 0] entry alone.
C.register_hook(lambda grad: grad.masked_fill(fixed_mask, 0.))

# NOTE(review): this SagSurface is overwritten by the Parabola assignment just
# below, so it never reaches the optical model — confirm which one is intended.
primary = tlm.SagSurface(1200, tlm.SagSum([
    tlm.Parabolic(torch.tensor(-0.1630), normalize=True),
    #tlm.XYPolynomial(C, normalize=True)
]))

# This is the definition the optical model actually receives. The coefficient A
# is wrapped in tlm.parameter with initial value -0.35 (presumably so that it
# is picked up by optics.parameters() — TODO confirm).
primary = tlm.Parabola(1400, A=tlm.parameter(-0.35), normalize=True)


# Kinematics
# 1000 minus the centre-to-rim distance of the rod (half-length 25,
# radius 37.02 / 2), rounded down to an integer.
rod_max_z = math.floor(1000 - math.hypot(25, 37.02 / 2))
# Rod offsets, both starting at zero.
rod_x = tlm.parameter(0.0)
rod_y = tlm.parameter(0.0)
#rod_z = tlm.parameter(0.)


# Optical model
# Sequential chain of elements, listed in the order they are applied.
optics = tlm.Sequential(
    # The -1000 gap here and the +1000 gap after the first viewer plane sum to zero.
    tlm.Gap(-1000),
    # Light source supplied by the xxchallenge module, loaded with half=True.
    XXLightSource.load(half=True),
    RaysViewerPlane(1500, "mark1"),
    tlm.Gap(1000),
    # Primary mirror: a Rotate3D((5, 0)) followed by the reflective surface,
    # grouped in a SubChain (presumably so the rotation does not carry over to
    # the elements that follow — TODO confirm SubChain semantics).
    tlm.SubChain(
        tlm.Rotate3D((5, 0)),
        tlm.ReflectiveSurface(primary),
    ),
    #tlm.Gap(-470),
    # NOTE(review): looks like this places the following elements at the
    # absolute position (-500, 0, 0) — confirm against the tlm kinematics docs.
    tlm.AbsoluteTransform(tlm.TranslateTransform(torch.tensor([-500, 0, 0], dtype=torch.float64))),
    # Additional gap given by the rod_x parameter.
    tlm.Gap(rod_x),
    RaysViewerPlane(1500, "mark2"),
    #RaysViewerPlane(10000, "mark3"),
    # Rotation applied before the rod, then the non-imaging rod itself, built
    # from the implicit cylinder defined earlier in this cell.
    tlm.Rotate3D((90, 45)),
    NonImagingRod(cylinder),
)



# 3D preview of the system as defined above (this runs before the optimization
# cell), with the rod drawn by RodArtist and bounding cylinders switched on.
tlm.show3d(
    optics,
    sampling={"xx": 200},
    extra_artists={NonImagingRod: RodArtist()},
    controls={"show_bounding_cylinders": True},
    end=1000,
)

# optimize rod position


#torch.autograd.set_detect_anomaly(True)

In [None]:
# Run 30 Adam iterations (lr=1e-4) over every parameter exposed by `optics`,
# sampling with {"xx": 1000} and the viewer planes disabled, then plot the
# object returned by tlm.optimize.
tlm.optimize(
    optics,
    optimizer=optim.Adam(optics.parameters(), lr=1e-4),
    sampling={"xx": 1000, "disable_viewer": True},
    dim=3,
    num_iter=30,
).plot()

In [None]:
# 3D view of the system after the optimization cell, with a denser ray sampling
# and a slightly longer `end` value than the initial preview.
tlm.show3d(
    optics,
    sampling={"xx": 500},
    extra_artists={NonImagingRod: RodArtist()},
    end=1100,
    controls={"show_bounding_cylinders": True},
)

#print(primary.A.item())
#print(primary.sag_function.terms[1].coefficients)

In [None]:
import numpy as np
from scipy.spatial import Delaunay
from stl import mesh

def xxgrid(half_extent: float = 500.0, samples: int = 30) -> np.ndarray:
    """
    Build a square grid of 2D sampling points centred on the origin

    Args:
        half_extent: the grid spans [-half_extent, half_extent] on both axes
            (defaults to 500, the previously hard-coded value)
        samples: number of evenly spaced samples per axis (defaults to 30)

    Returns:
        Array of shape (samples * samples, 2). Each row is one point; the first
        coordinate varies fastest.
    """
    axis = np.linspace(-half_extent, half_extent, samples)
    X, Y = np.meshgrid(axis, axis)
    return np.stack((X, Y), -1).reshape(-1, 2)


def tess(points: np.ndarray, sag, filename: str):
    """
    Tessellate a sag function and save the resulting surface as an STL file

    Args:
        points: sampling points in the YZ plane, array of shape (N, 2).
            The array is not modified.
        sag: sag function to tessellate; must provide `sag_function.G(Y, Z, tau)`
            returning a tensor and a `tau()` method
        filename: stl output filename
    """

    points = np.asarray(points)

    # Triangulate the 2D sampling points; the triangle indices are reused
    # unchanged for the 3D vertices.
    tri = Delaunay(points)

    Y, Z = points[:, 0], points[:, 1]

    sag_values = sag.sag_function.G(torch.as_tensor(Y), torch.as_tensor(Z), sag.tau()).detach().numpy()

    # Flip the sag axis and convert mm to m. Done out of place so that the
    # caller's `points` array keeps its original values.
    points_m = points * 0.001
    sag_m = (sag_values * -1) * 0.001

    vertices = np.concatenate((points_m, sag_m[:, np.newaxis]), -1)
    faces = tri.simplices

    # Create the mesh: one (3, 3) block of vertex coordinates per triangle.
    part = mesh.Mesh(np.zeros(faces.shape[0], dtype=mesh.Mesh.dtype))
    part.vectors[:] = vertices[faces]

    part.save(filename)


# Export the primary mirror, sampled on the default grid, to an STL file.
tess(xxgrid(), primary, "primary.stl")

# Print position of non imaging rod: evaluate the model on a minimal input
# (one sample, viewer planes disabled) and print the output's target.
output = optics(
    tlm.default_input(
        {"xx": 1, "disable_viewer": True},
        dim=3,
        dtype=torch.float64,
    )
)
print(output.target())

In [None]:
## Notes on kinematic elements

# tlm.SubChain(transform=, children=)
# tlm.RelativeTransform(transform=)
# tlm.AbsoluteTransform(transform=)

# child: abs / next: abs
# child: prev / next: abs
# child: prev + t / next: abs
# :: XXX, tlm.AbsolutePosition

# child: abs / next: prev
# child: prev / next: prev
# child: prev + t / next: prev
# :: tlm.SubChain(tlm.Sequential(XXX, ...))

# child: abs / next: prev + t
# child: prev / next: prev + t
# child: prev + t / next: prev + t
# :: XXX

In [None]:
import torchlensmaker as tlm
import torch

# Small kinematics demo: one sphere surface is used by two reflective elements,
# with a 3D translation (x=10, z=5) between them.
# NOTE(review): this rebinds `optics`, replacing the model built in the first cell.
surface = tlm.Sphere(4, 6)

optics = tlm.Sequential(
    tlm.Marker("hello world"),
    tlm.ReflectiveSurface(surface),
    tlm.Translate3D(x=10, z=5),
    tlm.ReflectiveSurface(surface),
)

# Display the chain with the optical axis and kinematic joints visible.
tlm.show3d(
    optics,
    controls={"show_optical_axis": True, "show_kinematic_joints": True},
)