In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import math
import torch
import torchlensmaker as tlm

from torch.nn.functional import normalize

from torchlensmaker.core.collision_detection import Newton, GD, LM, CollisionMethod, init_zeros, init_best_axis

from torchlensmaker.testing.collision_datasets import CollisionDataset
from torchlensmaker.testing.dataset_view import dataset_view

import matplotlib as mpl

from IPython.display import display, HTML

Tensor = torch.Tensor


# Create the input grid tensor
def sample_grid(xlim, ylim, N):
    """Sample a regular N-by-N grid over a rectangle.

    Args:
        xlim: indexable pair; ``xlim[0]`` and ``xlim[1]`` are the x endpoints.
        ylim: indexable pair; ``ylim[0]`` and ``ylim[1]`` are the y endpoints.
        N: number of samples along each axis (endpoints included).

    Returns:
        A tuple ``(X, Y, points)`` where ``X`` and ``Y`` are the numpy
        meshgrid arrays and ``points`` is a torch tensor of shape
        ``(N * N, 2)`` whose rows are the ``(x, y)`` coordinates of the grid
        nodes, listed in the row-major order of the meshgrid.
    """
    xs = np.linspace(xlim[0], xlim[1], N)
    ys = np.linspace(ylim[0], ylim[1], N)
    X, Y = np.meshgrid(xs, ys)

    # Flatten both coordinate grids and pair them up column-wise.
    pairs = np.column_stack((X.ravel(), Y.ravel()))
    return X, Y, torch.tensor(pairs)


def analysis_single_ray(surface, P, V):
    """Plot how the surface's collision method behaves on a single ray.

    First renders the dataset view for the one-ray dataset, then draws a
    two-panel figure:

    * top: Q(t) = F(P + tV) and Q'(t) = F_grad . V sampled along the ray,
      with the solver iterates overlaid and colored by iteration index;
    * bottom: |F(P + tV)| at each iterate, on a log scale, titled with the
      value of F at the returned solution.

    Args:
        surface: object providing ``f``, ``f_grad``, ``collision_method`` and
            ``testname()`` as used below.
        P: ray origin, a 2-element tensor.
        V: ray direction, a 2-element tensor.

    Returns:
        None. The function only produces plots / notebook output.
    """
    origins = P.unsqueeze(0)
    directions = V.unsqueeze(0)

    dataset = CollisionDataset("", surface, origins, directions)
    dataset_view(surface, dataset, rays_length=100)

    t_solve, t_history = surface.collision_method(surface, origins, directions, history=True)

    # Only one ray is solved, so row 0 carries its whole iteration history.
    # NOTE(review): assumes t_history is laid out as (rays, iterations) — confirm.
    # Detached because these values are only used for plotting.
    t_iter = t_history[0, :].detach()
    n_iter = t_history.shape[1]

    t_min = t_history.min().item()
    t_max = t_history.max().item()
    span = t_max - t_min

    # Sample t over the visited range, padded by one full span on each side.
    N = 1000
    tspace = torch.linspace(t_min - span, t_max + span, N)

    # Q(t) and its derivative along the ray (broadcast to an (N, 2) point set).
    tpoints = origins + tspace.unsqueeze(1) * directions
    Q = surface.f(tpoints)
    Qgrad = torch.sum(surface.f_grad(tpoints) * V, dim=1)

    # Points visited by the solver and the value of F at each of them.
    # No squeeze here: the row is already 1-D, and squeezing a one-entry
    # history would drop it to 0-d and break the unsqueeze.
    points_history = origins + t_iter.unsqueeze(1) * directions
    F_history = surface.f(points_history).detach()

    final_point = P + t_solve[0] * V
    final_F = surface.f(final_point.unsqueeze(0))[0].item()

    fig, axes = plt.subplots(2, 1, figsize=(10, 5))
    fig.tight_layout(pad=3, w_pad=1.8, h_pad=3)
    ax_t, ax_iter = axes

    # t plot: Q and Q'
    t_np = tspace.detach().numpy()
    ax_t.plot(t_np, Q.detach().numpy(), label="Q(t)=F(P+tV)")
    ax_t.plot(t_np, Qgrad.detach().numpy(), label="Q'(t) = F_grad . V")
    ax_t.grid()
    ax_t.set_xlabel("t")
    ax_t.legend()

    # t plot: solver iterates, colored by iteration index
    ax_t.scatter(t_iter.numpy(), F_history.numpy(), c=range(n_iter), cmap="viridis", marker="o")

    # History plot: |F| per iteration
    ax_iter.plot(range(n_iter), torch.abs(F_history).numpy(), label="|F(P+tV)|")
    ax_iter.legend()
    ax_iter.set_xlabel("iteration")
    ax_iter.set_title(f"final F = {final_F:.6f}")
    ax_iter.set_yscale("log")
    ax_iter.grid()
    ax_iter.set_ylim([1e-8, 100])

    fig.suptitle(surface.testname() + " " + str(surface.collision_method))
    # pyplot.show takes no figure argument; it displays the open figures.
    plt.show()
    display(HTML("<hr/>"))


# Solver presets compared in this notebook. Each preset combines an
# initial-guess function with one iterative step configuration. The keyword
# sets are shared, but every preset builds its own step object.
_NEWTON08_KWARGS = {"damping": 0.8, "max_iter": 15, "max_delta": 10}
_LM01_KWARGS = {"damping": 0.1, "max_iter": 50, "max_delta": 10}

# Newton, damping 0.8
newton08_zeros = CollisionMethod(init=init_zeros, step0=Newton(**_NEWTON08_KWARGS))
newton08_best_axis = CollisionMethod(init=init_best_axis, step0=Newton(**_NEWTON08_KWARGS))

# LM, damping 0.1
lm01_zeros = CollisionMethod(init=init_zeros, step0=LM(**_LM01_KWARGS))
lm01_best_axis = CollisionMethod(init=init_best_axis, step0=LM(**_LM01_KWARGS))
    
# Ray origins used by the cases below; every ray points along +Y.
_ON_AXIS_ORIGIN = [0.0000, 3.0000]
_OFF_AXIS_ORIGIN = [4.0043, -11.9740]

# (collision method, ray origin) pairs, analysed in this order.
_single_ray_cases = [
    # vertical ray on Y axis with Newton 0.8 - init zero
    (newton08_zeros, _ON_AXIS_ORIGIN),
    # vertical ray on Y axis with Newton 0.8 - init best axis        !! nan
    (newton08_best_axis, _ON_AXIS_ORIGIN),
    # vertical ray on Y axis with LM 0.1
    (lm01_zeros, _ON_AXIS_ORIGIN),
    # vertical ray with Newton 0.8
    (newton08_zeros, _OFF_AXIS_ORIGIN),
    # vertical ray with LM 0.1
    (lm01_zeros, _OFF_AXIS_ORIGIN),
]

# A fresh sphere and fresh tensors are built for every case.
for _method, _origin in _single_ray_cases:
    analysis_single_ray(
        tlm.Sphere(30, R=30, collision_method=_method),
        P=torch.tensor(_origin, dtype=torch.float64),
        V=torch.tensor([0., 1.], dtype=torch.float64),
    )
