# In [None]:
import torchlensmaker as tlm
import torch
import torch.optim as optim
import torch.nn as nn

# TODO:
# - share parameters at the surface level rather than at the shape level
# - update all shapes to the new parameter sharing / initialization scheme
# - support absolute positioning along the X axis
# - rework all examples and tests
#
# Open question: either use a per-parameter learning rate adapted to each
# parameter's scale, or express all internal parameters in the same scale /
# units so that a single common learning rate works well.

class tlmModule(nn.Module):
    """
    Custom nn.Module to automatically register parameters of shapes.

    This is similar to how PyTorch's nn.Module automatically registers nn.Parameters
    that are assigned to it. But here, we register tlm shapes and their inner parameters.

    Assigning ``self.<name> = shape`` (currently only ``tlm.Parabola`` instances)
    registers every entry of ``shape.parameters()`` as a module parameter named
    ``<name>_<parameter_name>``, and keeps the shape itself reachable as
    ``self.<name>``. Assigning any value to ``<name>`` later first unregisters
    the shape previously stored there, along with the parameters registered for it.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Stored through nn.Module's __setattr__ (bypassing ours) so the mapping
        # is a plain instance attribute: attribute name -> shape object.
        super().__setattr__("_shapes", {})

    def __getattr__(self, name: str):
        # Only called when normal attribute lookup fails. Using .get() keeps the
        # failure an AttributeError (not KeyError) on instances whose __init__
        # has not set _shapes yet, so hasattr()/getattr(obj, name, default) work.
        shapes = self.__dict__.get("_shapes", {})
        if name in shapes:
            return shapes[name]
        return super().__getattr__(name)

    def __setattr__(self, name: str, value):
        # Any new value replaces a shape previously stored under this name.
        # Forgetting it first prevents two failures:
        # - nn.Module.register_parameter raising KeyError("attribute ... already
        #   exists") when an nn.Parameter is assigned to a former shape name;
        # - __getattr__ returning the stale shape instead of a newly assigned
        #   submodule.
        self._forget_shape(name)

        # TODO: support other tlm shape types, not only Parabola
        if isinstance(value, tlm.Parabola):
            for parameter_name, parameter in value.parameters().items():
                # TODO add shape name/id
                self.register_parameter(name + "_" + parameter_name, parameter)
            self.__dict__["_shapes"][name] = value
        else:
            super().__setattr__(name, value)

    def _forget_shape(self, name: str) -> None:
        """Unregister the shape stored as ``name`` (if any) and the parameters registered for it."""
        shapes = self.__dict__.get("_shapes")
        if shapes is None or name not in shapes:
            return
        old_shape = shapes.pop(name)
        # Mirror the naming scheme used in __setattr__ when the shape was registered.
        for parameter_name in old_shape.parameters():
            self._parameters.pop(name + "_" + parameter_name, None)

# Design sketch of a possible future stack API (pseudo-code, not executed):
#
#     RefractiveSurface(shape1, scale=-1, anchors=("origin", "extent"))
#     Gap((0, 5))
#
#     AbsolutePosition((6.0, 5.0))
#     RelativePosition(target=surfaceXX, offset=(...))
#
#     ReflectiveSurface(...)
#     VariableGap(init=0.25, min=0.0, max=5.0)
#     ReflectiveSurface(...)



class DevStack(tlmModule):
    """Development optical stack built around one optimizable parabolic shape.

    A ``tlm.ParallelBeamUniform`` source is followed by two refractive surfaces
    that both use ``self.shape1``. The second surface is mirrored with
    ``scale=-1.`` and uses the reversed ``(1.49, 1.0)`` pair, presumably
    refractive indices. A ``tlm.FocalPointLoss`` comes after a final gap.
    Because the shape is shared, optimizing its parameters affects both surfaces.
    """

    def __init__(self):
        super().__init__()

        # Optimizable shape. tlmModule registers its parameters as
        # "shape1_<parameter_name>" module parameters.
        self.shape1 = tlm.Parabola(width=20.0, a=nn.Parameter(torch.tensor(0.005)))

        # Alternatives for a second shape (currently unused):
        #   share parameters from another shape:  self.shape2 = self.shape1.flip()
        #   fixed, non-optimized value:           self.shape2 = tlm.Parabola(width=20.0, a=-0.005)

        # Free parameter. It is not referenced by the optical stack below yet.
        # Planned alternative: tlm.Parameter(torch.tensor(5.0), learning_scale=10)
        self.x = nn.Parameter(torch.tensor(5.0))

        layers = [
            tlm.ParallelBeamUniform(width=15.0),
            tlm.Gap(torch.tensor([0, 10])),
            # Entry and exit surfaces share the same shape object.
            tlm.RefractiveSurface(self.shape1, (1.0, 1.49), anchors=("origin", "extent")),
            tlm.Gap(torch.tensor([0, 1])),
            tlm.RefractiveSurface(self.shape1, (1.49, 1.0), scale=-1.0, anchors=("extent", "origin")),
            tlm.Gap(torch.tensor([0, 80])),
            tlm.FocalPointLoss(),
        ]
        self.optics = nn.Sequential(*layers)

    def forward(self, inputs):
        """Run ``inputs`` through the sequential optical stack and return its output."""
        stack = self.optics
        return stack(inputs)


def make_inputs():
    """Return a fresh copy of the input tuple used for rendering and optimization.

    A new tensor is built on every call, so no call site shares (or can mutate)
    another's input. The meaning of the fields is presumably
    (number of rays, 2D start position); TODO confirm against tlm's API.
    """
    return (10, torch.tensor([0.0, 0.0]))


optics = DevStack()

print("Parameters")
for param_name, param in optics.named_parameters():
    # .cpu() lets this print parameters that live on a GPU too (no-op on CPU).
    print(param_name, param.detach().cpu().numpy())
print()

# Before optimization.
tlm.render_plt(optics, make_inputs())

tlm.optimize(
    optics,
    optimizer=optim.Adam(optics.parameters(), lr=1e-2),
    inputs=make_inputs(),
    num_iter=500,
)

# After optimization.
tlm.render_plt(optics, make_inputs())


