# Surfaces SDF

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors

import torch
import torchlensmaker as tlm

def parabola(points):
    """Implicit function of the parabola ``x = 0.05 * y**2``.

    ``points`` is indexed as ``points[:, 0]`` (x) and ``points[:, 1]`` (y).
    The result is zero on the curve, negative where ``x > 0.05 * y**2`` and
    positive elsewhere.
    """
    x = points[:, 0]
    y = points[:, 1]
    return 0.05 * y**2 - x

def circle(points):
    """Signed distance to the circle of radius 5 centred on the origin.

    The norm is taken along ``dim=1`` of ``points``. The result is negative
    inside the circle, zero on it and positive outside.
    """
    radius = 5
    distance_to_origin = torch.linalg.vector_norm(points, dim=1)
    return distance_to_origin - radius

def circle2(points):
    """Unsigned distance to the circle of radius 5 centred on the origin.

    Same as the signed version but with the absolute value taken, so the
    result is never negative and is zero only on the circle itself.
    """
    signed = torch.linalg.vector_norm(points, dim=1) - 5
    return signed.abs()

def box(points):
    """Signed distance to an origin-centred, axis-aligned rectangle.

    The rectangle has half-width 6 along the first coordinate and 3 along
    the second. The result is the sum of two terms:

    * the norm of the per-axis overshoot, non-zero only outside the box;
    * the largest per-axis offset clamped to be non-positive, non-zero only
      inside the box.
    """
    half_extents = torch.tensor([6, 3])
    offset = points.abs() - half_extents

    outside_term = torch.linalg.vector_norm(offset.clamp(min=0.0), dim=1)
    inside_term = offset.max(dim=1).values.clamp(max=0.0)
    return outside_term + inside_term

def sdf_surface(surface):
    """Return a one-argument function that evaluates ``surface.f(points)``.

    ``surface.f`` is looked up each time the returned function is called,
    not when this wrapper is created.
    """
    return lambda points: surface.f(points)

# Diagnostic plots to produce for each surface class (12 in total):
#   - plot F(x, y): check that it is finite everywhere
#   - plot (grad F).x as a function of (x, y): check that it is finite everywhere
#   - plot (grad F).y as a function of (x, y): check that it is finite everywhere
#   - plot norm(grad F): check that it is non-zero
#   - for a few direction vectors V, plot (grad F) . V, with
#       V: (0, 1), (1, 0), (a, a),
#          (0, -1), (-1, 0), (-a, a), (a, -a), (-a, -a)

def plot_sdf(F, F_grad=None):
    """Plot a 2D scalar function and, optionally, its gradient components.

    A row of three axes is drawn: ``F``, the X component of ``F_grad`` and
    the Y component of ``F_grad``. Everything is sampled on a 500 x 500 grid
    covering ``[-10, 10] x [-10, 10]`` and coloured with one shared
    symmetric-log norm spanning ``[-20, 20]``.

    Args:
        F: callable that receives an ``(N, 2)`` float64 torch tensor of
            ``(x, y)`` points and returns ``N`` values.
        F_grad: optional callable that receives the same kind of tensor and
            returns a 2D result whose columns 0 and 1 are the X and Y
            components of the gradient. When ``None`` the two gradient axes
            are left empty.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """

    def as_array(values):
        # matplotlib converts its inputs through NumPy, which refuses tensors
        # that are still attached to an autograd graph. Detach explicitly
        # (and copy to the CPU) so such outputs can be plotted too.
        if isinstance(values, torch.Tensor):
            values = values.detach().cpu().numpy()
        return np.asarray(values)

    # Axes are:
    # ax_f: plot of F
    # ax_gfx: plot of X component of grad F
    # ax_gfy: plot of Y component of grad F
    fig, axes = plt.subplots(1, 3, figsize=(15, 12))
    ax_f, ax_gfx, ax_gfy = axes

    # Create a grid of X and Y coordinates
    x = np.linspace(-10, 10, 500)
    y = np.linspace(-10, 10, 500)
    X, Y = np.meshgrid(x, y)

    # Flatten the grid into an (N, 2) array of (x, y) points
    grid_points = np.stack((X, Y), axis=-1).reshape(-1, 2)

    # Evaluate F on every grid point and fold the result back onto the grid
    Z = as_array(F(torch.tensor(grid_points))).reshape(X.shape)

    # Symmetric log scale: linear within +/-0.05 of zero, logarithmic beyond.
    # Alternative considered: colors.CenteredNorm()
    norm = colors.SymLogNorm(linthresh=0.05, linscale=0.05, vmin=-20.0, vmax=20.0, base=10)
    mesh_style = dict(cmap='RdBu_r', norm=norm, shading='auto')

    ax_f.pcolormesh(X, Y, Z, **mesh_style)
    # Optional overlay:
    # ax_f.contour(X, Y, Z, colors="black", alpha=0.5, linewidths=1, norm=norm, levels=30)
    ax_f.set_title("f")

    if F_grad is not None:
        Zgrad = as_array(F_grad(torch.tensor(grid_points)))
        ax_gfx.pcolormesh(X, Y, Zgrad[:, 0].reshape(X.shape), **mesh_style)
        ax_gfy.pcolormesh(X, Y, Zgrad[:, 1].reshape(X.shape), **mesh_style)
        ax_gfx.set_title("(grad f).x")
        ax_gfy.set_title("(grad f).y")

    for ax in axes:
        ax.set_aspect("equal")

    plt.show()

def plot_sdf_surface(surface):
    """Plot ``surface.f`` and the components of ``surface.f_grad``.

    Convenience wrapper that hands both attributes of ``surface`` to
    ``plot_sdf`` and returns whatever it returns.
    """
    value_fn = surface.f
    gradient_fn = surface.f_grad
    return plot_sdf(value_fn, F_grad=gradient_fn)


# Demo calls. Each active line opens one figure; uncomment a line to plot
# that function or surface as well.

# Plain functions defined above. No gradient is passed, so only the "f"
# axis of the figure is filled in.
#plot_sdf(parabola)
plot_sdf(circle)
#plot_sdf(circle2)

#print(tlm.Sphere(5, 5).extent_x())

#plot_sdf(box)

# torchlensmaker surface objects: plots surface.f and both components of
# surface.f_grad.
#plot_sdf_surface(tlm.Sphere2(3, 8))
plot_sdf_surface(tlm.Sphere(5, 5))
#plot_sdf_surface(tlm.Sphere(3, 8))
plot_sdf_surface(tlm.Sphere3(10, 8))


# Further surfaces to try.
#plot_sdf_surface(tlm.Sphere(3, 9))
#plot_sdf_surface(tlm.Asphere(diameter=5, R=5, K=0, A4=0.))
#plot_sdf_surface(tlm.Asphere(diameter=20, R=-15, K=-1.2, A4=0.00045))
#plot_sdf_surface(tlm.Parabola(diameter=5, a=0.02))
