In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import math
import torch
import torchlensmaker as tlm

from torch.nn.functional import normalize

import matplotlib as mpl

Tensor = torch.Tensor


# Create the input grid tensor
def sample_grid(xlim, ylim, N):
    """Sample a regular N x N grid over the rectangle ``xlim`` x ``ylim``.

    Args:
        xlim: Pair whose first two items are the lower and upper X bounds.
        ylim: Pair whose first two items are the lower and upper Y bounds.
        N: Number of samples along each axis (both bounds included).

    Returns:
        A tuple ``(X, Y, points)`` where ``X`` and ``Y`` are the ``(N, N)``
        coordinate arrays produced by ``np.meshgrid`` and ``points`` is a
        ``(N * N, 2)`` torch tensor whose rows are the ``(x, y)`` pairs of
        the grid in row-major order.
    """
    xs = np.linspace(xlim[0], xlim[1], N)
    ys = np.linspace(ylim[0], ylim[1], N)
    grid_x, grid_y = np.meshgrid(xs, ys)

    # One (x, y) row per grid node, flattened in the same order as X and Y
    flat_xy = np.column_stack((grid_x.ravel(), grid_y.ravel()))
    return grid_x, grid_y, torch.tensor(flat_xy)


def Qplot(surface, P, V):
    """Visualize how a surface's collision solver behaves along one 2D ray.

    The ray is parameterized as ``P + t * V``. A new figure with three panels
    is created:

    * upper left: ``surface.f`` along the ray and the dot product of
      ``surface.f_grad`` with ``V``, sampled for ``t`` in [-12, 12], with the
      solver iterates overlaid;
    * right: ``surface.f`` over the square [-10, 10] x [-10, 10], the sampled
      ray, and the points visited by the solver;
    * lower left: ``surface.f`` at each solver iterate, titled with the value
      of ``surface.f`` at the final solution.

    The solver is started at the ``t`` where the ray crosses ``y = 0``
    (``-P[1] / V[1]``). When that value is not finite (``V[1] == 0``), the
    start falls back to ``t = 0``.

    Args:
        surface: Object providing ``f(points)``, ``f_grad(points)`` and a
            ``collision_algorithm`` attribute that is called as
            ``collision_algorithm(surface, P, V, init_t=..., history=True)``
            and returns ``(t_solve, t_history)``.
        P: Ray origin, a 1D tensor with 2 elements.
        V: Ray direction, a 1D tensor with 2 elements.

    Returns:
        None. The figure is created through pyplot and is neither shown nor
        closed here.
    """

    def as_numpy(tensor):
        # Tensor.numpy() refuses tensors that require grad, and that is what
        # matplotlib ends up calling; convert explicitly for every plot.
        return tensor.detach().cpu().numpy()

    N = 1000
    tspace = torch.linspace(-12, 12, N)

    # t plot: sample the ray, shape (N, 2) by broadcasting
    tpoints = P + tspace.unsqueeze(1) * V
    Q = surface.f(tpoints)
    Qgrad = torch.sum(surface.f_grad(tpoints) * V, dim=1)

    # Start the solver where the ray crosses y = 0
    init_t = torch.tensor([-P[1] / V[1]])
    if not torch.isfinite(init_t).all():
        # Ray parallel to the X axis: there is no crossing, start at the origin of the ray
        init_t = torch.zeros((1,), dtype=init_t.dtype)
    print("init_t", init_t)
    t_solve, t_history = surface.collision_algorithm(surface, P.unsqueeze(0), V.unsqueeze(0), init_t=init_t, history=True)

    # Solver iterates of the (single) ray and the matching points on the ray
    # NOTE(review): assumes t_history is (1, n_iterations) — confirm against the solver
    t_iterates = t_history[0, :]
    points_history = P + t_iterates.unsqueeze(1) * V
    final_point = P + t_solve[0] * V

    iterations = range(t_history.shape[1])
    f_history = as_numpy(surface.f(points_history))

    fig, axes = plt.subplot_mosaic([['upper left', 'right'],
                                    ['lower left', 'right']],
                                   figsize=(10, 5), layout="constrained")
    # Look panels up by label rather than relying on the dict ordering
    ax_qplot = axes['upper left']
    ax_splot = axes['right']
    ax_history = axes['lower left']

    # t plot: plot Q and Q grad
    t_values = as_numpy(tspace)
    ax_qplot.plot(t_values, as_numpy(Q), label="Q(t)")
    ax_qplot.plot(t_values, as_numpy(Qgrad), label="Q'(t)")
    ax_qplot.grid()
    ax_qplot.set_xlabel("t")
    ax_qplot.legend()

    # t plot: plot t history, colored by iteration index
    ax_qplot.scatter(as_numpy(t_iterates), f_history, c=iterations, cmap="viridis", marker="o")

    # Surface plot: plot F on a symmetric log color scale
    X, Y, ppoints = sample_grid((-10, 10), (-10, 10), 200)
    norm = colors.SymLogNorm(linthresh=0.05, linscale=0.05, vmin=-20.0, vmax=20.0, base=10)
    F_grid = as_numpy(surface.f(ppoints)).reshape(X.shape)
    ax_splot.pcolormesh(X, Y, F_grid, cmap='RdBu_r', norm=norm, shading='auto')

    # Surface plot: plot the line
    ray_xy = as_numpy(tpoints)
    ax_splot.plot(ray_xy[:, 0], ray_xy[:, 1], color="black", linewidth=1, marker="none")

    # Surface plot: points history, colored by iteration index
    history_xy = as_numpy(points_history)
    ax_splot.scatter(history_xy[:, 0], history_xy[:, 1], c=iterations, cmap="viridis", marker="+")

    # History plot: plot F at each iterate
    ax_history.plot(iterations, f_history, label="F(P+tV)")
    ax_history.legend()
    ax_history.set_xlabel("iteration")
    ax_history.set_title(f"final F = {surface.f(final_point.unsqueeze(0))[0].item():.6f}")

    ax_splot.set_title("F(x,y)")
    ax_splot.set_aspect("equal")

    fig.suptitle(str(surface.collision_algorithm))

# Experiments: each call below draws one Qplot figure for a given surface,
# collision solver configuration and ray (origin P, direction V).

# vertical ray on Y axis with Newton 0.8
Qplot(tlm.Sphere(30, R=30, collision=tlm.Newton(max_iter=15, max_delta=10, damping=0.8)),
      P=torch.tensor([0.0000,   3.0000], dtype=torch.float64),
      V=torch.tensor([0., 1.], dtype=torch.float64),
)

# vertical ray on Y axis with LM 0.1
Qplot(tlm.Sphere(30, R=30, collision=tlm.LM(max_iter=15, max_delta=10, damping=0.1)),
      P=torch.tensor([0.0000,   3.0000], dtype=torch.float64),
      V=torch.tensor([0., 1.], dtype=torch.float64),
)

# vertical ray (off the Y axis, x = 4.0043) with Newton 0.8
Qplot(tlm.Sphere(30, R=30, collision=tlm.Newton(max_iter=15, max_delta=10, damping=0.8)),
      P=torch.tensor([4.0043, -11.9740], dtype=torch.float64),
      V=torch.tensor([0., 1.], dtype=torch.float64),
)

# vertical ray (off the Y axis, x = 4.0043) with LM 0.1
Qplot(tlm.Sphere(30, R=30, collision=tlm.LM(max_iter=15, max_delta=10, damping=0.1)),
      P=torch.tensor([4.0043, -11.9740], dtype=torch.float64),
      V=torch.tensor([0., 1.], dtype=torch.float64),
)

# fails to converge case with Newton 0.8 (ray at -80 degrees from P = (1.8, 3))
theta = math.radians(-80)
Qplot(tlm.Sphere(6, 3, collision=tlm.Newton(max_iter=10, max_delta=100, damping=0.8)),
      P=torch.tensor([1.8, 3.], dtype=torch.float64),
      V=normalize(torch.tensor([math.cos(theta), math.sin(theta)], dtype=torch.float64), dim=0)
)

# fails to converge case with Newton 1.0 (ray at -40 degrees from P = (1.5, 3))
# NOTE(review): this comment used to say "Newton 0.1" while the call passes
# damping=1.0 -- confirm which value was intended.
theta = math.radians(-40)
Qplot(tlm.Sphere(6, 3, collision=tlm.Newton(max_iter=10, max_delta=100, damping=1.0)),
      P=torch.tensor([1.5, 3.], dtype=torch.float64),
      V=normalize(torch.tensor([math.cos(theta), math.sin(theta)], dtype=torch.float64), dim=0)
)

## LM slow convergence (damping 2.0, ray at 10 degrees from P = (1.8, 0))
theta = math.radians(10)
Qplot(tlm.Sphere(6, 3, collision=tlm.LM(max_iter=10, max_delta=100, damping=2.0)),
      P=torch.tensor([1.8, 0.], dtype=torch.float64),
      V=normalize(torch.tensor([math.cos(theta), math.sin(theta)], dtype=torch.float64), dim=0)
)

## LM with damping 1.0, same ray as the previous case
theta = math.radians(10)
Qplot(tlm.Sphere(6, 3, collision=tlm.LM(max_iter=10, max_delta=100, damping=1.0)),
      P=torch.tensor([1.8, 0.], dtype=torch.float64),
      V=normalize(torch.tensor([math.cos(theta), math.sin(theta)], dtype=torch.float64), dim=0)
)

# The remaining cases all use the same ray (P = (1.5, 3), -40 degrees) and
# only vary the solver and its parameters.

# Newton with damping 0.9
theta = math.radians(-40)
Qplot(tlm.Sphere(6, 3, collision=tlm.Newton(max_iter=10, max_delta=100, damping=0.9)),
      P=torch.tensor([1.5, 3.], dtype=torch.float64),
      V=normalize(torch.tensor([math.cos(theta), math.sin(theta)], dtype=torch.float64), dim=0),
)

# Newton with damping 0.5
theta = math.radians(-40)
Qplot(tlm.Sphere(6, 3, collision=tlm.Newton(max_iter=10, max_delta=100, damping=0.5)),
      P=torch.tensor([1.5, 3.], dtype=torch.float64),
      V=normalize(torch.tensor([math.cos(theta), math.sin(theta)], dtype=torch.float64), dim=0),
)

# GD with step size 0.1
theta = math.radians(-40)
Qplot(tlm.Sphere(6, 3, collision=tlm.GD(max_iter=10, max_delta=100, step_size=0.1)),
      P=torch.tensor([1.5, 3.], dtype=torch.float64),
      V=normalize(torch.tensor([math.cos(theta), math.sin(theta)], dtype=torch.float64), dim=0),
)

# LM with damping 0.5
theta = math.radians(-40)
Qplot(tlm.Sphere(6, 3, collision=tlm.LM(max_iter=10, max_delta=100, damping=0.5)),
      P=torch.tensor([1.5, 3.], dtype=torch.float64),
      V=normalize(torch.tensor([math.cos(theta), math.sin(theta)], dtype=torch.float64), dim=0),
)

# LM with damping 1.5
theta = math.radians(-40)
Qplot(tlm.Sphere(6, 3, collision=tlm.LM(max_iter=10, max_delta=100, damping=1.5)),
      P=torch.tensor([1.5, 3.], dtype=torch.float64),
      V=normalize(torch.tensor([math.cos(theta), math.sin(theta)], dtype=torch.float64), dim=0),
)

# LM with damping 2.5
theta = math.radians(-40)
Qplot(tlm.Sphere(6, 3, collision=tlm.LM(max_iter=10, max_delta=100, damping=2.5)),
      P=torch.tensor([1.5, 3.], dtype=torch.float64),
      V=normalize(torch.tensor([math.cos(theta), math.sin(theta)], dtype=torch.float64), dim=0),
)


# theta = math.radians(-40)
# Qplot(tlm.Sphere(6, 3),
#       P=torch.tensor([-2.5, 5.], dtype=torch.float64),
#       V=normalize(torch.tensor([math.cos(theta), math.sin(theta)], dtype=torch.float64), dim=0)
# )

# theta = math.radians(-40)
# Qplot(tlm.Sphere3(6, 3),
#       P=torch.tensor([1.5, 3.], dtype=torch.float64),
#       V=normalize(torch.tensor([math.cos(theta), math.sin(theta)], dtype=torch.float64), dim=0)
# )

# theta = math.radians(15.0)
# Qplot(tlm.Sphere(6, 6),
#       P=torch.tensor([2., 2.], dtype=torch.float64),
#       V=normalize(torch.tensor([math.cos(theta), math.sin(theta)], dtype=torch.float64), dim=0)
# )

# theta = math.radians(45.0)
# Qplot(tlm.Sphere(6, 6),
#       P=torch.tensor([2., 2.], dtype=torch.float64),
#       V=normalize(torch.tensor([math.cos(theta), math.sin(theta)], dtype=torch.float64), dim=0)
# )

# theta = math.radians(95.0)
# Qplot(tlm.Sphere(6, 6),
#       P=torch.tensor([2., 2.], dtype=torch.float64),
#       V=normalize(torch.tensor([math.cos(theta), math.sin(theta)], dtype=torch.float64), dim=0)
# )

# theta = math.radians(95.0)
# Qplot(tlm.Sphere(6, 6),
#       P=torch.tensor([2., 2.], dtype=torch.float64),
#       V=normalize(torch.tensor([math.cos(theta), math.sin(theta)], dtype=torch.float64), dim=0),
#       collision_function=gd_debug
# )

# theta = math.radians(95.0)
# Qplot(tlm.Sphere(6, 6),
#       P=torch.tensor([-2., 2.], dtype=torch.float64),
#       V=normalize(torch.tensor([math.cos(theta), math.sin(theta)], dtype=torch.float64), dim=0)
# )