# Algorithm compare - Sphere

In [None]:
# TODO

# Surface testing (without local_collide)
# - contains() returns True for all samples
# - normals() of samples are finite unit vectors

# Surface testing (with local_collide)
# - ray generators

# Implicit surface testing:
# - F and its gradient should be finite everywhere
# - F should be zero on samples
# - F should be nonzero outside of the bounding sphere

In [None]:
import torchlensmaker as tlm
import torch
import torch.nn as nn
from pprint import pprint

from torchlensmaker.core.rot2d import rot2d
from torchlensmaker.testing.basic_transform import basic_transform
from torchlensmaker.core.transforms import IdentityTransform
from torchlensmaker.testing.collision_datasets import *
from torchlensmaker.core.collision_detection import Newton, GD, LM, CollisionMethod, init_zeros, init_best_axis

from torchlensmaker.testing.dataset_view import dataset_view

import matplotlib.pyplot as plt


# analysis:
# given a dataset of rays expected to collide with the surface,
# and a list of algorithms, report for each algorithm:
# - the number of missed collisions
# - the distribution of the number of iterations needed to converge


def convergence_plot(dataset, methods, error_ylim=(1e-8, 1e-3)):
    """Plot convergence of collision detection for multiple algorithms.

    One subplot row is drawn per method. Each row shows, against the iteration
    index of the method's history:
      - left axis: Q(t) = surface.f(P + t * V) for every ray (one line per ray);
      - right axis (log scale): the RMS of Q(t) over all rays.

    Args:
        dataset: object with ``surface``, ``P``, ``V`` and ``name`` attributes.
            P and V are indexed as (N, D) ray origin / direction tensors
            (D == 2 for the datasets used in this notebook — TODO confirm 3D).
        methods: sequence of callables ``method(surface, P, V, history=True)``
            returning ``(t_solve, t_history)`` with ``t_history`` of shape (N, H).
        error_ylim: y-limits of the log-scale RMS error axis.

    Returns:
        The matplotlib Figure.
    """

    surface = dataset.surface
    P, V = dataset.P, dataset.V

    # move rays by a tiny bit, to avoid t=0 local minimum
    # that happens with constructed datasets
    # TODO augment dataset with different shifts
    #P, V = move_rays(P, V, 0.)

    fig, axes = plt.subplots(len(methods), 1, figsize=(10, 3 * len(methods)), layout="tight", squeeze=False)

    for axQ, method in zip(axes.flat, methods):
        _, t_history = method(surface, P, V, history=True)

        N, H = P.shape[0], t_history.shape[1]
        D = P.shape[1]
        assert t_history.shape == (N, H), (N, H)

        # Point reached by each ray at each iteration: P + t * V
        points_history = P.unsqueeze(1) + t_history.unsqueeze(2) * V.unsqueeze(1)  # (N, H, D)
        assert points_history.shape == (N, H, D)

        # Q(t) for every ray at every iteration (row = ray, column = iteration).
        # Computed once and reused by both plots. Autograd is disabled because
        # the values are only plotted.
        residuals = torch.empty((N, H))
        with torch.no_grad():
            for h in range(H):
                residuals[:, h] = surface.f(points_history[:, h, :])

        # plot Q(t), one line per ray
        for ray_residuals in residuals:
            axQ.plot(range(H), ray_residuals)

        axQ.set_xlabel("iteration")
        axQ.set_ylabel("Q(t)", rotation=0)
        axQ.set_title(f"{dataset.name} | {str(method)}")

        # plot total error: RMS of Q(t) over all rays, per iteration
        axE = axQ.twinx()
        axE.set_ylabel("error")
        axE.set_yscale("log")
        axE.set_ylim(error_ylim)

        error = torch.sqrt(torch.sum(residuals**2, dim=0) / N)
        assert error.shape == (H,)
        axE.plot(error, label="error")
        axE.legend()

    return fig


def collision_statistics(surface, dataset, methods, tol=1e-6):
    """Compute, print and return collision statistics for a dataset and several algorithms.

    For each method, the rays (dataset.P, dataset.V) are solved. The final
    iterate of the method's history, P + t_last * V, is then evaluated with
    ``surface.f``. A ray counts as a miss when its residual is not within
    ``tol`` of zero. Non-finite residuals (NaN from diverged rays) are counted
    as misses too.

    Args:
        surface: object whose ``f(points)`` is evaluated on the (N, D) final
            points; it is expected to return one residual per ray.
        dataset: object with ``P`` and ``V`` ray tensors indexed as (N, D).
        methods: callables ``method(surface, P, V, history=True)`` returning
            ``(t_solve, t_history)`` with ``t_history`` of shape (N, H).
        tol: absolute residual tolerance for a ray to count as a collision.

    Returns:
        A list with one dict per method, in the same order as ``methods``,
        with keys ``"method"`` (str), ``"error"`` (RMS residual, float) and
        ``"misses"`` (int). The same values are also printed, one line per
        method.
    """

    P, V = dataset.P, dataset.V
    N = P.shape[0]
    stats = []

    for method in methods:
        _, t_history = method(surface, P, V, history=True)
        assert t_history.dim() == 2 and t_history.shape[0] == N, (tuple(t_history.shape), N)

        # Only the final iterate matters for the statistics
        t_last = t_history[:, -1:]  # (N, 1)
        residuals = surface.f(P + t_last * V)

        # Written as "not within tol" rather than "> tol": NaN compares False
        # with everything, so "> tol" would silently count diverged rays as hits.
        misses = torch.sum(~(torch.abs(residuals) <= tol)).item()
        error = torch.sqrt(torch.sum(residuals**2) / N).item()

        print(f"{str(method): <20} error={error:.8f} ({misses} misses)")
        stats.append({"method": str(method), "error": error, "misses": misses})

    return stats

def unit(theta):
    """Return the +X unit vector [1, 0] rotated by ``theta`` degrees via ``rot2d``.

    ``theta`` may be a Python number or a tensor. It is converted to a tensor
    and from degrees to radians before the rotation.
    """
    angle_rad = torch.deg2rad(torch.as_tensor(theta))
    x_axis = torch.tensor([1.0, 0.0])
    return rot2d(x_axis, angle_rad)

# log plot error history (with fixed low ylim range like [0, 0.01], scaled by surface diameter?)

# algo idea:
# 1. set up multiple beam starts by sampling along the bounding sphere diameter
# 2. run N steps with max delta = sampling step size -- big steps to reach the local minimum region
# 3. keep only the best beam
# 4. run M steps with a smaller max delta -- finer adjustment
# 5. run 1 step for backwards

# Collision methods under comparison. Each one combines an initialization
# function (init_zeros / init_best_axis) with a first-stage solver (Newton / LM).
methods = [
    CollisionMethod(
        init=init_zeros,
        step0=Newton(0.8, max_iter=20, max_delta=10),
    ),

    CollisionMethod(
        init=init_best_axis,
        step0=Newton(0.8, max_iter=20, max_delta=10),
    ),

    CollisionMethod(
        init=init_zeros,
        step0=LM(0.1, max_iter=20, max_delta=10),
    ),

    CollisionMethod(
        init=init_best_axis,
        step0=LM(0.1, max_iter=20, max_delta=10),
    ),
]

surface = tlm.Sphere(30, 30)

# Ray dataset generators; each is called with the surface to build a dataset
generators = [
    normal_rays(offset=9.0, N=25),
    tangent_rays(offset=-0.6, N=15),
    fixed_rays(direction=unit(45), offset=30, N=15),
    fixed_rays(direction=unit(65), offset=30, N=15),
    fixed_rays(direction=unit(85), offset=30, N=15),
]

# initialization methods:
# zeros
# intersect with axis
# smart intersect with axis
# bbox sampling

for gen in generators:
    dataset = gen(surface)

    # tlmviewer view
    dataset_view(surface, dataset)

    # statistics
    collision_statistics(surface, dataset, methods)

    # convergence plots
    fig = convergence_plot(dataset, methods)
    # plt.show() does not take a figure. Standard backends accept only the
    # keyword `block`, and the Jupyter inline backend would interpret a
    # positional argument as `close`. Show, then close this figure explicitly
    # so figures don't accumulate across loop iterations.
    plt.show()
    plt.close(fig)

# individual test:
# 1 surface
# 1 ray generator

# batch test:
# 1 surface
# many ray generators
