In [None]:
import torchlensmaker as tlm
import torch
import torch.optim as optim
import torch.nn as nn


# TODO: decide where parameters should be handled: automatically inside the
#       shapes, or manually in the stack class?
# TODO: design a syntax for gaps between surfaces
# TODO: rework all examples / tests

class tlmModule(nn.Module):
    """
    Custom nn.Module to automatically register parameters of shapes

    This is similar to how nn.Module automatically registers nn.Parameters
    that are assigned to it. But here, we register shapes and their inner parameters.

    Assigning a tlm.Parabola to an attribute stores the shape in ``self._shapes``
    and registers every entry of ``shape.parameters()`` on this module under the
    name ``<attribute>_<parameter name>``, so they show up in ``parameters()`` /
    ``named_parameters()``. Rebinding the attribute to any other value, or
    deleting it, unregisters the shape and those parameters again.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Created through nn.Module.__setattr__ directly so that our override
        # is not involved while the shape registry itself is being created.
        super().__setattr__("_shapes", {})

    def __getattr__(self, name):
        # Python only calls __getattr__ after normal lookup has failed, which is
        # always the case for shapes: they live in _shapes, not in __dict__.
        # Using .get() makes a lookup on a not-yet-initialized instance fall
        # through to nn.Module.__getattr__ and raise AttributeError instead of
        # KeyError (hasattr() and getattr(obj, name, default) only catch
        # AttributeError).
        shapes = self.__dict__.get("_shapes", {})
        if name in shapes:
            return shapes[name]
        return super().__getattr__(name)

    def _unregister_shape(self, name):
        """Remove the shape stored under `name`, if any, and the parameters registered for it."""
        shapes = self.__dict__.get("_shapes")
        if shapes is None:
            # Registry not created yet (still inside nn.Module.__init__).
            return
        shape = shapes.pop(name, None)
        if shape is not None:
            for n in shape.parameters():
                self._parameters.pop(f"{name}_{n}", None)

    def __setattr__(self, name, value):
        if isinstance(value, tlm.Parabola):
            # Replace any previous binding of `name`. Otherwise a previous
            # shape's parameters would stay registered. A plain attribute of
            # the same name would also shadow the new shape (normal lookup
            # finds __dict__ before __getattr__ runs). A parameter, buffer or
            # submodule of that name would stay registered next to it.
            self._unregister_shape(name)
            for store in (self.__dict__, self._parameters, self._buffers, self._modules):
                store.pop(name, None)
            for n, p in value.parameters().items():
                # TODO add shape name/id
                self.register_parameter(f"{name}_{n}", p)
            self.__dict__["_shapes"][name] = value
        else:
            # nn.Parameter / nn.Module values are not stored in __dict__, so
            # without this __getattr__ would keep returning the old shape.
            self._unregister_shape(name)
            super().__setattr__(name, value)

    def __delattr__(self, name):
        # Shapes are not in __dict__, so nn.Module.__delattr__ cannot find them.
        if name in self.__dict__.get("_shapes", {}):
            self._unregister_shape(name)
        else:
            super().__delattr__(name)


class DevStack(tlmModule):
    """
    Experimental optical stack made of two lenses that share the same shapes.

    A parallel beam goes through lens A (a shape1 surface then a shape2
    surface), then lens B (again shape1 then shape2, placed relative to lens
    A), and ends at a FocalPointLoss. shape1 holds a trainable coefficient and
    shape2 is constant. Both lenses reuse the same shape objects, so a change
    to shape1 affects both lenses.
    """

    def __init__(self):
        super().__init__()

        # Trainable shape: `a` is wrapped in nn.Parameter. tlmModule registers
        # whatever the shape reports from parameters() when it is assigned.
        self.shape1 = tlm.Parabola(width=20., a=nn.Parameter(torch.tensor(0.005)))

        # Alternative being explored: derive the second shape from the first.
        # self.shape2 = self.shape1.flip()

        # Constant shape: a plain float, not wrapped in nn.Parameter.
        self.shape2 = tlm.Parabola(width=20., a=-0.005)

    def _lens(self, origin_pos):
        """Build one lens: a shape1 surface at `origin_pos`, then a shape2 surface offset (0, 1) from its extent."""
        front = tlm.Surface(self.shape1, pos=origin_pos, anchor="origin")
        back = tlm.Surface(
            self.shape2,
            pos=front.at("extent") + torch.tensor([0, 1]),
            anchor="extent",
        )
        return front, back

    def forward(self, num_rays, hook=None):
        """
        Rebuild the optical stack from the current shapes and run it.

        `num_rays` and `hook` are passed unchanged to the stack's forward(),
        whose result is returned. The rebuilt stack is also kept on self.optics.
        """
        # Surfaces are recreated on every call rather than once in __init__.
        # Presumably this lets their positions follow the current shape
        # parameters.
        front_a, back_a = self._lens((0, 0))
        front_b, back_b = self._lens(back_a.at("origin") + torch.tensor([5, 10]))

        elements = [tlm.ParallelBeamUniform(width=15., pos=(0, -10))]
        for front, back in ((front_a, back_a), (front_b, back_b)):
            # NOTE(review): the index pairs look like (before, after) the
            # surface: 1.0 -> 1.49 entering the lens, 1.49 -> 1.0 leaving it.
            # Confirm against tlm.RefractiveSurface.
            elements.append(tlm.RefractiveSurface(front, (1.0, 1.49)))
            elements.append(tlm.RefractiveSurface(back, (1.49, 1.0)))
        elements.append(tlm.FocalPointLoss(pos=(0, 80)))

        self.optics = tlm.OpticalStack(elements)

        # NOTE(review): calling .forward() directly skips nn.Module hooks. If
        # OpticalStack is an nn.Module, self.optics(num_rays, hook=hook) is the
        # idiomatic call.
        return self.optics.forward(num_rays, hook=hook)

# Build the stack and list every parameter registered on it, including those
# that tlmModule registered from its shapes.
optics = DevStack()
for n, p in optics.named_parameters():
    print(n, p)

# Initial rendering. The second argument is presumably the number of rays to
# draw; confirm against tlm.render_plt.
tlm.render_plt(optics, 10)

# Fit the registered parameters with Adam through tlm.optimize.
optimizer = optim.Adam(optics.parameters(), lr=5e-4)
tlm.optimize(optics, optimizer=optimizer, num_rays=20, num_iter=50)

# Render again to compare with the initial state.
tlm.render_plt(optics, 10)

# Idea: signal whether a value should be optimized simply by wrapping it in a parameter:

# ReflectiveSurface(Parabola(tlm.Parameter(5.0)))
# vs
# ReflectiveSurface(Parabola(5.0), pos=(a, b))

# Idea: set up a linear stack with the same syntax (gap(), etc.),
# but preprocess the stack to remove the gaps and set up the surfaces' origins/anchors.

