# In [None]:
import pytest
from functools import partial

import torch

from torchlensmaker.implicit_surfaces.sag_functions import (
    aspheric_sag_3d,
    conical_sag_3d,
    parabolic_sag_3d,
    spherical_sag_3d,
    xypolynomial_sag_3d,
)

import matplotlib.pyplot as plt


dtype, device = torch.float32, torch.device("cpu")

# Alternative sag functions to visualize; uncomment one to replace `sag` below.
# NOTE(review): the first alternative originally referenced `spherical_sag_2d`,
# which is not imported in this file. It is assumed that `spherical_sag_3d`
# (imported above) accepts the same `C` keyword — confirm before uncommenting.
# sag = partial(spherical_sag_3d, C=torch.tensor(-1 / 4, dtype=dtype, device=device))
# sag = partial(conical_sag_3d, C=torch.tensor(-1 / 2.0, dtype=dtype, device=device), K=torch.tensor(0.0, dtype=dtype, device=device))

# XY-polynomial sag with a 3x3 coefficient tensor drawn uniformly from [-1, 1).
# The draw is unseeded, so the surface differs on every run.
# The coefficients are moved to the configured dtype/device so that changing
# `dtype`/`device` above actually takes effect.
sag = partial(
    xypolynomial_sag_3d,
    coefficients=torch.distributions.uniform.Uniform(-1.0, 1.0)
    .sample((3, 3))
    .to(dtype=dtype, device=device),
)

# 1D sample coordinates of the plotting grid, created with the same
# dtype/device as the sag parameters.
yspace = torch.linspace(-1.0, 1.0, 100, dtype=dtype, device=device)
zspace = torch.linspace(-1.0, 1.0, 100, dtype=dtype, device=device)


def plot_sag_3d(sag):
    """Plot a 3D sag function and its gradient over the ``yspace`` x ``zspace`` grid.

    Creates a new figure with three side-by-side color maps: the sag value
    ``x`` and the two components of the gradient returned by ``sag``. Each
    panel is titled and has its own colorbar. The horizontal axis is ``y`` and
    the vertical axis is ``z``.

    Args:
        sag: Callable taking two 2D tensors ``(y, z)`` of identical shape and
            returning a pair ``(x, G_grad)``. ``G_grad`` must have a last
            dimension of size 2, because it is unbound into ``dy`` and ``dz``.

    The grid is built from the module-level 1D tensors ``yspace`` and ``zspace``.
    """
    # "ij" indexing: y varies along dim 0 and z along dim 1.
    y, z = torch.meshgrid(yspace, zspace, indexing="ij")

    x, G_grad = sag(y, z)
    dy, dz = G_grad.unbind(-1)

    # Wide figure so that three equal-aspect panels plus colorbars are not squeezed.
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # matplotlib converts tensors through NumPy, which raises for tensors that
    # require grad or are not on the CPU, so detach and move them explicitly.
    y_np = y.detach().cpu().numpy()
    z_np = z.detach().cpu().numpy()

    panels = (("x", x), ("dy", dy), ("dz", dz))
    for ax, (title, values) in zip(axes, panels):
        mesh = ax.pcolormesh(y_np, z_np, values.detach().cpu().numpy())
        fig.colorbar(mesh, ax=ax)
        ax.set_title(title)
        ax.set_xlabel("y")
        ax.set_ylabel("z")
        ax.set_aspect("equal")


# Denser 1D coordinate range (1000 samples over [-1, 1]). It is not used by
# the plot below; presumably it is kept for other cells — verify before removing.
rspace = torch.linspace(start=-1.0, end=1.0, steps=1000)

# Render the sag surface and its gradient components for the selected `sag`.
plot_sag_3d(sag=sag)