In [None]:
import pytest
from functools import partial

import torch

from torchlensmaker.implicit_surfaces.sag import (
    aspheric_sag_2d,
    conical_sag_2d,
    parabolic_sag_2d,
    spherical_sag_2d,
    sag_sum_2d,
)

import matplotlib.pyplot as plt


dtype, device = torch.float32, torch.device("cpu")

# Alternative single-term sags, kept for quick experimentation:
# sag = partial(spherical_sag_2d, C=torch.tensor(-1 / 4, dtype=dtype, device=device))
# sag = partial(conical_sag_2d, C=torch.tensor(-1 / 1.0, dtype=dtype, device=device), K=torch.tensor(-1.0, dtype=dtype, device=device))

# Seed for the random aspheric coefficients so the plotted surface is the
# same on every run; change it to explore other coefficient draws.
SEED = 0
rng = torch.Generator(device=device).manual_seed(SEED)

# Three aspheric coefficients drawn uniformly from [-1, 1), created with the
# same dtype/device as the other surface parameters.
aspheric_coefficients = (
    2.0 * torch.rand(3, generator=rng, dtype=dtype, device=device) - 1.0
)

# Composite sag built by sag_sum_2d from a spherical term (C = 1/2) and the
# random aspheric term above (presumably the individual sags are summed —
# confirm against the sag module).
sag = partial(
    sag_sum_2d,
    sags=[
        partial(
            spherical_sag_2d,
            C=torch.tensor(1 / 2.0, dtype=dtype, device=device),
        ),
        partial(
            aspheric_sag_2d,
            coefficients=aspheric_coefficients,
        ),
    ],
)


def _to_plottable(values):
    """Return *values* as a detached CPU tensor that matplotlib can consume."""
    # matplotlib converts inputs through numpy, which raises for tensors that
    # require grad or live on a non-CPU device.
    return torch.as_tensor(values).detach().cpu()


def plot_sag_2d(sag, rspace):
    """Plot a 2D sag profile and its gradient side by side.

    Args:
        sag: callable evaluated at ``rspace`` that returns a pair
            ``(X, X_grad)``; presumably the sag values and their derivative
            with respect to r (TODO confirm against the sag module).
        rspace: 1D sequence/tensor of radial coordinates at which to
            evaluate ``sag``.

    The left panel draws ``X`` on the horizontal axis against ``rspace`` on
    the vertical axis (with both coordinate axes drawn and equal aspect);
    the right panel draws ``X_grad`` against ``rspace`` the same way.
    """
    X, X_grad = sag(rspace)

    r = _to_plottable(rspace)
    X = _to_plottable(X)
    X_grad = _to_plottable(X_grad)

    _, (ax0, ax1) = plt.subplots(1, 2)

    # Left: the profile itself, with the coordinate axes drawn as thin lines.
    ax0.axvline(color="k", linewidth=0.5)
    ax0.axhline(color="k", linewidth=0.5)
    ax0.plot(X, r)
    ax0.set_aspect("equal")
    ax0.margins(x=2.0)  # extra horizontal padding around the profile

    # Right: the gradient returned by ``sag``.
    ax1.plot(X_grad, r)
    ax1.set_aspect("equal")


# Radial sample points in [-1, 1], created with the same dtype/device as the
# surface parameters they are evaluated against.
rspace = torch.linspace(-1.0, 1.0, 1000, dtype=dtype, device=device)

plot_sag_2d(sag, rspace)