In [None]:
import torch
import torchlensmaker as tlm

import matplotlib.pyplot as plt
import matplotlib.colors as colors

from functools import partial


def plot_sdf(func, *, extent=20.0, num=200, dtype=None, device=None):
    """Plot a 2D implicit function and its two gradient components.

    A square grid of sample points covering ``[-extent, extent]`` on both
    axes is built and passed to ``func``. A new figure with three panels is
    drawn:

    1. the function values, with a diverging colormap on a symmetric-log
       color scale;
    2. ``F_grad[..., 0]``;
    3. ``F_grad[..., 1]``.

    Args:
        func: Callable taking a ``(num, num, 2)`` tensor of (x, y) points and
            returning a pair ``(F, F_grad)``. Presumably ``F`` has shape
            ``(num, num)`` and ``F_grad`` has shape ``(num, num, 2)``.
            NOTE(review): confirm against ``tlm.sag_to_implicit_2d``.
        extent: Half-width of the square sampling window. Defaults to 20.0,
            the previously hard-coded window.
        num: Number of samples along each axis. Defaults to 200.
        dtype: dtype of the sample grid. ``None`` uses torch's default dtype.
        device: Device of the sample grid. ``None`` uses torch's default
            device.

    Returns:
        None. The figure is created through ``plt.subplots`` and left for the
        active matplotlib backend to display.
    """

    def _plottable(t):
        # matplotlib converts tensors through numpy, which fails for tensors
        # that require grad or live off the CPU.
        return t.detach().cpu() if isinstance(t, torch.Tensor) else t

    _, (ax0, ax1, ax2) = plt.subplots(1, 3)

    xspace = torch.linspace(-extent, extent, num, dtype=dtype, device=device)
    yspace = torch.linspace(-extent, extent, num, dtype=dtype, device=device)

    # "ij" is torch's historical default (X varies along dim 0). Passing it
    # explicitly silences the UserWarning emitted when it is omitted.
    # pcolormesh receives the explicit X/Y coordinate arrays, so orientation
    # stays consistent.
    X, Y = torch.meshgrid(xspace, yspace, indexing="ij")

    # Shape (num, num, 2): every entry is one (x, y) sample point.
    points = torch.stack((X, Y), dim=-1)

    F, F_grad = func(points)

    X, Y = _plottable(X), _plottable(Y)
    F, F_grad = _plottable(F), _plottable(F_grad)

    # Symmetric-log scale: linear for |F| < 0.05, logarithmic beyond it.
    # Values near zero and large magnitudes are therefore both readable.
    # The symmetric vmin/vmax put zero at the centre of the diverging map.
    norm = colors.SymLogNorm(linthresh=0.05, linscale=1.0, vmin=-20.0, vmax=20.0, base=10)
    ax0.pcolormesh(X, Y, F, cmap='seismic', norm=norm, shading='gouraud')
    ax0.set_aspect("equal")

    ax1.pcolormesh(X, Y, F_grad[..., 0])
    ax1.set_aspect("equal")

    ax2.pcolormesh(X, Y, F_grad[..., 1])
    ax2.set_aspect("equal")

# dtype/device shared by every tensor handed to the sag functions below.
dtype, device = torch.float32, torch.device("cpu")

# Another sag worth trying here: partial(tlm.parabolic_sag_2d, A=0.4)

# Spherical sag.
plot_sdf(
    tlm.sag_to_implicit_2d(
        partial(
            tlm.spherical_sag_2d,
            C=torch.tensor(1 / 0.8, dtype=dtype, device=device),
        )
    )
)

# Aspheric sag with three random coefficients drawn uniformly from [-1, 1).
# A seeded generator makes the figure reproducible between runs.
# Change the seed to explore other surfaces.
aspheric_seed = 0
aspheric_rng = torch.Generator(device=device).manual_seed(aspheric_seed)
aspheric_coefficients = torch.empty(3, dtype=dtype, device=device).uniform_(
    -1.0, 1.0, generator=aspheric_rng
)

plot_sdf(
    tlm.sag_to_implicit_2d(
        partial(tlm.aspheric_sag_2d, coefficients=aspheric_coefficients)
    )
)

# Conical sag.
plot_sdf(
    tlm.sag_to_implicit_2d(
        partial(
            tlm.conical_sag_2d,
            C=torch.tensor(1 / 15.0, dtype=dtype, device=device),
            K=torch.tensor(0.0, dtype=dtype, device=device),
        )
    )
)