In [None]:
import math

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchlensmaker as tlm

from xxchallenge import *

# Subtracted from the 1000 gap that places the primary mirror in the optical model.
# NOTE(review): units presumably mm — confirm.
base_height: int = 35


class Corners(nn.Module):
    """Soft penalty on four fixed "corner" points evaluated against ``surface``.

    On every forward pass the four points (corner_x, ±half_extent, ±half_extent)
    are mapped into the element's local frame with ``data.tf().inverse_points``.
    ``surface.F`` is evaluated there and divided by ``f_scale``. Only negative
    values of the result are penalized. The mean of their squares, times
    ``weight``, is added to ``data.loss``.

    Args:
        surface: object exposing ``F(points)`` (the implicit surface function).
        corner_x: x coordinate shared by the four corners.
        half_extent: half the side of the square spanned by the corners in (y, z).
        weight: multiplier applied to the mean squared violation.
        f_scale: divisor applied to ``surface.F`` before squaring.

    The defaults reproduce the values this module previously hard-coded.
    """

    def __init__(self, surface, corner_x=-1000.0, half_extent=500.0, weight=1000.0, f_scale=1000.0):
        super().__init__()
        self.surface = surface
        self.corner_x = corner_x
        self.half_extent = half_extent
        self.weight = weight
        self.f_scale = f_scale

    def forward(self, data):
        """Return ``data`` with the corner penalty added to its ``loss``."""
        tf = data.tf()

        x, h = self.corner_x, self.half_extent
        corners = torch.tensor([
            [x, -h, -h],
            [x, h, -h],
            [x, -h, h],
            [x, h, h],
        ], dtype=torch.float64)

        # Transform corners to the local frame of the surface
        Ps = tf.inverse_points(corners)

        F = self.surface.F(Ps) / self.f_scale
        # clamp(max=0) keeps only negative F values; corners with F >= 0 cost nothing
        violation = torch.clamp(F, max=0.0)
        loss = self.weight * violation.pow(2).sum() / corners.shape[0]

        return data.replace(loss=data.loss + loss)


# Trainable parameters.
# A feeds tlm.Parabolic of the primary mirror; G0 is the rod's absolute x position.
A = tlm.parameter(-1.14)
# Float literal (like A) so the initial value can never become an integer tensor,
# which could not require gradients.
G0 = tlm.parameter(-255.0)

# XY polynomial coefficients: 13x13 float64 grid, initialised to zero.
C = tlm.parameter(torch.zeros((13, 13), dtype=torch.float64))
# Freeze coefficient (0, 0): its gradient is zeroed on every backward pass, so with the
# plain SGD used later it stays at its initial value of 0.
fixed_mask = torch.zeros_like(C, dtype=torch.bool)
fixed_mask[0, 0] = True
C.register_hook(lambda grad: grad.masked_fill(fixed_mask, 0.0))

# Rod cylinder. NOTE(review): the three values look like (x_min, x_max, radius), i.e. a rod of
# length 50 and diameter 37.02 centred on its local origin — confirm against tlm.ImplicitCylinder.
ROD_LENGTH = 50
ROD_DIAMETER = 37.02
cylinder = tlm.ImplicitCylinder(
    *torch.tensor(
        [-ROD_LENGTH / 2, ROD_LENGTH / 2, ROD_DIAMETER / 2], dtype=torch.float64
    ).unbind()
)
# Presumably records the data flowing through this point of the sequence (rod_data.value is read later).
rod_data = StoreVar(lambda data: data)

# Primary mirror: SagSum of a parabolic term (coefficient A) and an XY polynomial (coefficients C).
sag = tlm.SagSum([
    tlm.Parabolic(A=A, normalize=True),
    tlm.XYPolynomial(C, normalize=True),
])
# NOTE(review): 1800 is passed straight to tlm.SagSurface — presumably the aperture size; confirm.
primary = tlm.SagSurface(1800, sag)
primary_data = StoreVar(lambda data: data)

# Optical model
optics = tlm.Sequential(
    tlm.Gap(-1000),
    XXLightSource.load(),

    # Primary mirror (Corners adds a penalty on points evaluated against its surface)
    tlm.Gap(1000 - base_height),
    primary_data,
    Corners(primary),
    tlm.ReflectiveSurface(primary),

    # Fixed rod
    tlm.AbsolutePosition(x=G0),
    # tlm.Rotate3D(y=45),  # disabled
    rod_data,
    NonImagingRod(cylinder),
)

xxrender(optics, sampling={"xx": 500, "letter": "both"})

In [None]:
# One SGD param group per trainable parameter, each with its own learning rate.
param_groups = [
    {"params": [A], "lr": 0.003},
    {"params": [G0], "lr": 40},
    {"params": [C], "lr": 0.008},
]

record = tlm.optimize(
    optics,
    optimizer=optim.SGD(param_groups),
    sampling={"xx": 10000, "disable_viewer": True, "letter": "both"},
    dim=3,
    num_iter=20,
    nshow=20,
)

plot_record(record, param_groups, optics)
# NOTE(review): this is not the cell's last expression, so its return value is not displayed;
# confirm whether it is called for a side effect or should be wrapped in display().
record.best()

print()
print("Final values")
for param_name, param_value in (("A", A), ("G0", G0)):
    print(param_name, param_value)

# Rod position, reported as (target[1], target[2], -target[0])
target = rod_data.value.target()
rod_coords = (target[1].item(), target[2].item(), -target[0].item())
for label, coord in zip(("ROD X", "ROD Y", "ROD Z"), rod_coords):
    print(label, coord)

xxrender(optics, sampling={"xx": 500, "letter": "both"})

In [None]:
# Report the rod position in the XX frame, then export the primary mirror (plus sides) as an STL
# whose file name carries the rod's rounded Z coordinate.
target = rod_data.value.target()
rod_x = target[1].item()
rod_y = target[2].item()
rod_z_value = -target[0].item()
print("ROD POSITION IN XX FRAME")
print("ROD X", rod_x)
print("ROD Y", rod_y)
print("ROD Z", rod_z_value)
# round() returns an int, so values in (-0.5, 0) give "z0" rather than the "z-0" that
# f"{x:.0f}" would produce. It raises on NaN/inf instead of naming a file "znan"/"zinf".
rod_z = f"z{round(rod_z_value)}"
print(rod_z)

with torch.no_grad():
    # NOTE(review): 499 and 80 are passed straight to xxgrid — confirm their meaning.
    part_primary = tess_mirror(xxgrid(499, 80), primary_data.value.tf(), primary)

part_sides = makesides(part_primary.vectors.dtype)

# NOTE(review): `mesh` is not imported explicitly in this file; presumably numpy-stl's
# `stl.mesh` re-exported by `from xxchallenge import *` — consider importing it directly.
mesh.Mesh(np.concatenate([
    part_primary.data,
    part_sides.data,
])).save(f"parabolaXY-{rod_z}.stl")

