# side quest

In [None]:
import torch
import torch.nn as nn
import torchlensmaker as tlm

from torchlensmaker.lenses import *


# Surfaces used to build the optical system below.
# NOTE(review): positional args look like (diameter, shape coefficient) for
# Parabola and (diameter, radius) for Sphere — confirm against tlm.surfaces.
surface1 = tlm.surfaces.Parabola(15.0, 0.020)
# NOTE(review): surface2 is never referenced in this cell — either a leftover
# or it was meant for the second lens face below; confirm before removing.
surface2 = tlm.surfaces.Parabola(15.0, 0.030)
surface3 = tlm.surfaces.Sphere(12, -20)

# A lens made of two refractive faces, both built from surface1, separated by Gap(1).
# The second face uses scale=-1, the reversed index pair and the swapped anchor
# order, presumably so that it mirrors the first face — verify with tlm docs.
lens = nn.Sequential(
    tlm.RefractiveSurface(surface1, n=(1.0, 1.5), anchors=("origin", "extent")),
    tlm.Gap(1),
    tlm.RefractiveSurface(surface1, n=(1.5, 1.0), scale=-1, anchors=("extent", "origin")),
)

# Full optical stack, in order: source, gap, the lens above, zero gap,
# aperture, gap, and a reflective surface built from surface3.
optics = nn.Sequential(
    tlm.PointSourceAtInfinity(18),
    tlm.Gap(5),
    lens,
    tlm.Gap(0),
    tlm.Aperture(10),
    tlm.Gap(15),
    tlm.ReflectiveSurface(surface3),
)



# Evaluate in 3 dimensions with double precision.
dim, dtype = 3, torch.float64

# anchor_thickness is not defined in this cell; presumably it comes from the
# `from torchlensmaker.lenses import *` above — confirm. The print labels call
# the "origin" measurement the inner thickness and "extent" the outer one.
it = anchor_thickness(lens, "origin", dim, dtype)
ot = anchor_thickness(lens, "extent", dim, dtype)
print("inner thickness", it)
print("outer thickness", ot)

# Render the same system once in 2D and once in 3D.
# NOTE(review): end=10 presumably controls how far rays are drawn past the
# last element — confirm against tlm.show.
tlm.show(optics, dim=2, end=10)
tlm.show(optics, dim=3, end=10)



In [None]:
# mode 1:  inline / chained / affecting
# surface transform = input.transforms - anchor + scale
# output transform = input.transforms - first anchor + second anchor

# mode 2: offline / free / independent
# surface transform = input.transforms + local transform - anchor + scale
# output transform = input.transforms

# Open question: how to support absolute positioning within a chain?

# RS(X - A) + T
# surface transform(X) = CSX - A
# surface transform = anchor1 + scale + chain
# output transform = chain + anchor1 + anchor2

In [None]:
import torch


def prod(A, B):
    """
    Cartesian product of 2 batched coordinate tensors of shape (N, D) and (M, D).

    Returns 2 tensors of shape (N*M, D). Row ``i*M + j`` of the first is
    ``A[i]`` and the same row of the second is ``B[j]``, so iterating both
    together visits every (A-row, B-row) pair exactly once.

    Raises:
        ValueError: if A or B is not 2-D, or if their second dimensions differ.
    """
    # Validate with real exceptions: `assert` statements are stripped under `python -O`.
    if A.dim() != 2 or B.dim() != 2:
        raise ValueError(
            f"prod expects 2-D tensors, got shapes {tuple(A.shape)} and {tuple(B.shape)}"
        )
    if A.shape[1] != B.shape[1]:
        raise ValueError(
            f"prod expects matching dimension D, got {A.shape[1]} and {B.shape[1]}"
        )

    N, M = A.shape[0], B.shape[0]
    D = A.shape[1]

    # Each row of A is repeated M times consecutively: A0, A0, ..., A1, A1, ...
    A = torch.repeat_interleave(A, M, dim=0)
    # The whole of B is stacked N times: B0, B1, ..., B0, B1, ...
    B = torch.tile(B, (N, 1))

    # Internal invariant guaranteed by construction (not input validation).
    assert A.shape == B.shape == (M * N, D)
    return A, B

## 2D
def test_2d():
    """Print and check the row-wise Cartesian product of a (2, 2) and a (3, 2) tensor."""
    P = torch.tensor([[1, 2],
                      [3, 4]])
    V = torch.tensor([[10, 11],
                      [12, 13],
                      [14, 15]])
    P, V = prod(P, V)

    print(P)
    print(V)

    # Each P row is repeated once per V row; V is tiled once per P row.
    # Assertions come after the prints so the output is still visible on failure.
    assert P.tolist() == [[1, 2]] * 3 + [[3, 4]] * 3
    assert V.tolist() == [[10, 11], [12, 13], [14, 15]] * 2

## 3D
def test_3d():
    """Print and check the row-wise Cartesian product of a (2, 3) and a (3, 3) tensor."""
    P = torch.tensor([[1, 2, 3], [4, 5, 6]])
    V = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
    P, V = prod(P, V)

    print(P)
    print(V)

    # Each P row is repeated once per V row; V is tiled once per P row.
    # Assertions come after the prints so the output is still visible on failure.
    assert P.tolist() == [[1, 2, 3]] * 3 + [[4, 5, 6]] * 3
    assert V.tolist() == [[10, 11, 12], [13, 14, 15], [16, 17, 18]] * 2

# Run both Cartesian-product demos in order: 2D first, then 3D.
for _demo in (test_2d, test_3d):
    _demo()