In [None]:
import torchlensmaker as tlm
import torch
import torch.optim as optim
import torch.nn as nn

# TODO: update all shapes to the new sharing / initialization scheme.
# TODO: implement a non-shape parameter (e.g. VariableGap) to check that the architecture still works.

# TODO: nicer gap syntax.
# TODO: rework all examples / tests.

# Open question on learning rates — either:
#   - use a per-parameter learning rate adapted to each parameter's scale, or
#   - keep all internal parameters in the same scale / units so that a common learning rate works well.

class tlmModule(nn.Module):
    """
    Custom nn.Module to automatically register parameters of shapes.

    This is similar to how PyTorch's nn.Module automatically registers nn.Parameters
    that are assigned to it. But here, we register tlm shapes and their inner parameters.

    Assigning ``self.<name> = <tlm.Parabola>`` stores the shape in a private
    ``_shapes`` registry (returned on attribute access) and registers every entry of
    ``shape.parameters()`` on this module under ``"<name>_<parameter_name>"``, so they
    appear in ``named_parameters()`` / ``parameters()``.

    Reassigning or deleting such an attribute unregisters the previous shape and its
    parameters, so attribute access always reflects the latest assignment.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Go through nn.Module's __setattr__ directly so the registry is stored as a
        # plain instance attribute instead of being routed through our __setattr__.
        super().__setattr__("_shapes", {})

    def __getattr__(self, name: str):
        # Only reached when normal attribute lookup fails. Use .get() so that a
        # missing registry (instance not fully initialised) does not surface as a
        # KeyError: hasattr() / getattr(obj, name, default) only handle AttributeError.
        shapes = self.__dict__.get("_shapes", {})
        if name in shapes:
            return shapes[name]
        return super().__getattr__(name)

    def _forget_shape(self, name: str) -> None:
        """Unregister the shape stored as ``name`` and its parameters, if there is one."""
        shapes = self.__dict__.get("_shapes", {})
        old_shape = shapes.pop(name, None)
        if old_shape is None:
            return
        # Same naming scheme as the registration in __setattr__.
        for parameter_name in old_shape.parameters():
            self._parameters.pop(name + "_" + parameter_name, None)

    def __setattr__(self, name: str, value):
        # Drop any shape previously stored under this name. Otherwise __getattr__
        # (which checks _shapes before nn.Module's parameters/modules/buffers) would
        # keep returning the stale shape and its parameters would stay registered.
        self._forget_shape(name)

        if isinstance(value, tlm.Parabola):
            # Remove any regular attribute / parameter / buffer / submodule with this
            # name: a plain instance attribute would otherwise shadow the shape, since
            # __getattr__ is only called when normal lookup fails.
            try:
                super().__delattr__(name)
            except AttributeError:
                pass  # nothing to remove under this name
            for parameter_name, parameter in value.parameters().items():
                # TODO add shape name/id
                self.register_parameter(name + "_" + parameter_name, parameter)
            self.__dict__["_shapes"][name] = value
        else:
            super().__setattr__(name, value)

    def __delattr__(self, name: str):
        # nn.Module.__delattr__ does not know about _shapes, so handle shapes here.
        if name in self.__dict__.get("_shapes", {}):
            self._forget_shape(name)
        else:
            super().__delattr__(name)


class DevStack(tlmModule):
    """
    Development optical stack used to prototype parameter sharing and optimization.

    forward() builds two pairs of refractive surfaces (index 1.0 -> 1.49 -> 1.0), lit
    by a ``tlm.ParallelBeamUniform`` and scored by a ``tlm.FocalPointLoss``.

    Parameters set up in __init__:
      - ``shape1``: parabola whose ``a`` is wrapped in ``nn.Parameter`` (intended to
        be optimized; presumably exposed via ``Parabola.parameters()`` and registered
        by tlmModule — confirm);
      - ``shape2``: parabola with a plain float ``a = -0.005`` (intended to stay fixed);
      - ``x``: free ``nn.Parameter``, first component of the offset that places the
        second surface pair relative to the first.
    """

    def __init__(self):
        super().__init__()

        # Optimizable shape: `a` is wrapped in nn.Parameter.
        self.shape1 = tlm.Parabola(width=20., a=nn.Parameter(torch.tensor(0.005)))

        # Idea (not implemented): derive a shape from another one, e.g.
        #     self.shape2 = self.shape1.flip()

        # Fixed shape: plain float, not wrapped in nn.Parameter.
        self.shape2 = tlm.Parabola(width=20., a=-0.005)

        # Free (non-shape) parameter, used as an offset in forward().
        self.x = nn.Parameter(torch.tensor(5.0))
        # Idea (not implemented): per-parameter learning scale, e.g.
        #     self.x = tlm.Parameter(torch.tensor(5.0), learning_scale=10)

    # TODO improve hook situation
    def forward(self, num_rays):
        """
        Place the surfaces from the current parameter values and run the optics.

        Surfaces are constructed inside forward, so their positions are recomputed
        from the current values of ``self.x`` (and the shapes) on every call.

        Args:
            num_rays: input of the nn.Sequential pipeline, i.e. passed to
                ``tlm.ParallelBeamUniform``.

        Returns:
            The output of the last pipeline element, ``tlm.FocalPointLoss``.
        """
        # Float literals so that torch.stack and the position additions do not rely
        # on implicit int64 -> float type promotion.
        step = torch.tensor([0.0, 1.0])

        # First pair: surface1 anchored by its origin at (0, 0); surface2 anchored by
        # its extent at surface1's extent + (0, 1).
        surface1 = tlm.Surface(self.shape1, pos=(0, 0), anchor="origin")
        surface2 = tlm.Surface(self.shape2, pos=surface1.at("extent") + step, anchor="extent")

        # Second pair reuses the same shape objects (so shares their parameters) and is
        # placed relative to surface2's origin by the offset (x, 10).
        offset = torch.stack((self.x, torch.tensor(10.0)))
        surface3 = tlm.Surface(self.shape1, pos=surface2.at("origin") + offset, anchor="origin")
        surface4 = tlm.Surface(self.shape2, pos=surface3.at("extent") + step, anchor="extent")

        # Design notes (not implemented):
        # - Share shapes by constructing several surfaces from the same shape instead
        #   of cloning shapes, e.g.
        #       tlm.Surface(self.shape1, scale=-1, pos=(0, 0), anchor="origin")
        # - Alternative API: create surfaces in __init__ and move them in forward:
        #       self.S1 = Surface(...)                                           # __init__
        #       self.S1.move_origin(self.S2.origin() + torch.tensor([0, 10]))    # forward
        #       (and a move_extent counterpart)
        # - Declarative stack syntax:
        #       RefractiveSurface(shape1, scale=-1, anchors=("origin", "extent"))
        #       Gap((0, 5))
        #       AbsolutePosition((6.0, 5.0))
        #       RelativePosition(target=surfaceXX, offset=(...))
        #       ReflectiveSurface(...)
        #       VariableGap(init=0.25, min=0.0, max=5.0)
        #       ReflectiveSurface(...)
        #   e.g. ParallelBeamUniform(), Gap(), RefractiveSurface(); inputs: (rays, target)
        # - Add "current" position information to the data, as was implicit with
        #   relative positioning.

        # NOTE(review): the pipeline is re-created and re-registered as the `optics`
        # submodule on every call. It is kept as an attribute because external code
        # (e.g. tlm.render_plt or forward hooks) may look it up — confirm before
        # turning it into a local variable.
        self.optics = nn.Sequential(
            tlm.ParallelBeamUniform(width=15., pos=(0, -10)),

            tlm.RefractiveSurface(surface1, (1.0, 1.49)),
            tlm.RefractiveSurface(surface2, (1.49, 1.0)),

            tlm.RefractiveSurface(surface3, (1.0, 1.49)),
            tlm.RefractiveSurface(surface4, (1.49, 1.0)),

            tlm.FocalPointLoss(pos=(0, 80)),
        )

        return self.optics(num_rays)

# Build the stack and list every parameter it exposes (shape parameters are
# registered by tlmModule under "<attribute>_<parameter>").
optics = DevStack()

for param_name, param in optics.named_parameters():
    print(param_name, param)

# Render with 10 rays before optimizing...
tlm.render_plt(optics, 10)

# ...optimize all registered parameters with Adam...
tlm.optimize(
    optics,
    optimizer=optim.Adam(optics.parameters(), lr=1e-2),
    num_rays=20,
    num_iter=50,
)

# ...and render again afterwards.
tlm.render_plt(optics, 10)

# Idea: signal what should (or should not) be optimized simply by wrapping the value in a parameter:

# ReflectiveSurface(Parabola(tlm.Parameter(5.0)))
# vs
# ReflectiveSurface(Parabola(5.0), pos=(a, b))

# Idea: set up a linear stack with the same syntax (Gap(), etc.),
# but preprocess it to remove the gaps and set up the surfaces' origins/anchors.

