## Visualize implicit solver

In [None]:
import torch
import torchlensmaker as tlm

import matplotlib.pyplot as plt
import matplotlib.colors as colors

from functools import partial

from jaxtyping import Float

In [None]:
# Floating-point type and device shared by every tensor built in this cell.
dtype, device = torch.float32, torch.device("cpu")

# Alternative single-term sag functions that can be swapped in for the
# composite one below:
#sag = partial(tlm.parabolic_sag_2d, A=0.4)
#sag = partial(tlm.spherical_sag_2d, C=torch.tensor(1/1.5))

# Surface under test: tlm.sag_sum_2d applied to an aspheric term and a
# conical term. The parameter tensors use the same dtype/device as the rays
# so changing `dtype` or `device` above affects the whole cell consistently.
sag = partial(tlm.sag_sum_2d, sags=[
    partial(
        tlm.aspheric_sag_2d,
        coefficients=torch.tensor([0.5, -0.05], dtype=dtype, device=device),
    ),
    partial(
        tlm.conical_sag_2d,
        C=torch.tensor(-1/2.0, dtype=dtype, device=device),
        K=torch.tensor(0.5, dtype=dtype, device=device),
    ),
])

# Init rays: 100 rays whose first coordinate is fixed at -0.5, whose second
# coordinate spans [-1, 1], and whose direction angle sweeps -10..+10 degrees.
rays_y = torch.linspace(-1.0, 1.0, 100, dtype=dtype, device=device)
rays_theta = torch.deg2rad(
    torch.linspace(-10.0, 10.0, 100, dtype=dtype, device=device)
)
# P: ray origins, shape (100, 2). V: unit direction vectors (cos, sin), shape (100, 2).
P = torch.stack((torch.full_like(rays_y, -0.5), rays_y), dim=-1)
V = torch.stack((torch.cos(rays_theta), torch.sin(rays_theta)), dim=-1)


scene = tlm.new_scene("2D")

# render rays: each ray is drawn from its origin to the point 2.0 along its
# direction vector.
start = P
end = P + 2.0*V
scene["data"].append(tlm.render_rays(start, end, 0, {}, {}))

def render_iter(num_iter, color):
    """Build a point-render node from the Newton implicit solver's output.

    Reads the notebook-level ray origins ``P``, ray directions ``V`` and the
    surface function ``sag``.

    Args:
        num_iter: forwarded to ``tlm.implicit_solver_newton`` as its last
            argument (presumably the number of Newton iterations; confirm
            against the solver's signature).
        color: color handed to ``tlm.render_points`` for the markers.

    Returns:
        The node produced by ``tlm.render_points``.
    """
    implicit_surface = tlm.sag_to_implicit_2d(sag)
    distances = tlm.implicit_solver_newton(P, V, implicit_surface, num_iter)

    # One point per ray: the origin advanced along its direction by the solved
    # t. unsqueeze(-1) adds a trailing axis so t broadcasts over coordinates.
    collision_points = P + distances.unsqueeze(-1) * V
    return tlm.render_points(collision_points, color, radius=0.005)

# Overlay the solver's collision-point estimates for 1, 2 and 3 as the
# `num_iter` argument, one color per value, so the progression is visible.
for num_iter, color in ((1, "red"), (2, "white"), (3, "blue")):
    scene["data"].append(render_iter(num_iter, color))

tlm.display_scene(scene)