# Collision detection analysis - dataset

## Iterative collision detection for implicit surfaces

This is a detailed description of the collision detection method used in Torch Lens Maker. First, some nomenclature:

* "algorithm" refers to one instance of Newton's method, Gradient Descent or Levenberg-Marquardt (LM), together with its configuration: number of iterations, maximum step size, damping parameter, etc.
* "method" refers to the overall collision detection procedure, which consists of an initialization followed by three phases, each of which uses a single algorithm.
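As a rough illustration of this nomenclature, the configuration can be modelled as two small records: an algorithm is a configured inner solver, and a method bundles an initialization with one algorithm per phase. The class and field names below (`AlgorithmConfig`, `MethodConfig`, `kind`, `damping`, `num_beams`) are hypothetical and are not Torch Lens Maker's API, which exposes `Newton`, `GD`, `LM` and `CollisionMethod` instead.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class AlgorithmConfig:
    """One configured inner algorithm (hypothetical sketch, not the library API)."""
    kind: str                        # "newton", "gd" or "lm"
    max_iter: int = 20               # number of steps to run
    max_delta: float = 10.0          # clamp on |delta t| per step
    damping: Optional[float] = None  # LM damping, unused by plain Newton


@dataclass(frozen=True)
class MethodConfig:
    """The overall procedure: one initialization plus three phases."""
    init: str                        # e.g. "zeros", "best_axis", "bbox"
    num_beams: int                   # B, starting values of t per ray
    coarse: AlgorithmConfig          # runs on all B beams of each ray
    fine: AlgorithmConfig            # runs on the single best beam
    differentiable: AlgorithmConfig  # one step, with gradients enabled


method = MethodConfig(
    init="bbox",
    num_beams=8,
    coarse=AlgorithmConfig("lm", max_iter=12, damping=0.1),
    fine=AlgorithmConfig("newton", max_iter=8),
    differentiable=AlgorithmConfig("newton", max_iter=1),
)
```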

**Step 1: Initialization**

Initialize the values of t, the parameter along each ray, so that candidate collision points are `P + t * V`. Several initialization methods are available. During the optimization, each ray can be associated with multiple t values, so that the search progresses from several starting values in parallel. This is akin to particle-based optimization, but the search here is much simpler because it is one-dimensional. Each of these starting values is called a "beam" in the source code.

Tensors in the code can therefore have up to three dimensions:
* N, the number of rays
* H, the number of iteration steps
* B, the number of beams per ray

**Step 2: Coarse phase**

Run a fixed number of steps of algorithm A on every beam of each ray (B beams per ray). The goal is to have at least one beam close to the global minimum.
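The coarse phase can be sketched on a toy 2D circle `x^2 + y^2 = R^2` standing in for the implicit surface. Every beam takes the same fixed number of damped (LM-style) steps on `Q(t) = F(P + tV)`, then the beam with the smallest `|Q|` is kept. The circle, the helper names and the damping value are assumptions chosen for illustration.

```python
import math

R = 5.0  # toy circle radius (assumption)


def Q(P, V, t):
    """Implicit function F(x, y) = x^2 + y^2 - R^2 evaluated along the ray."""
    x, y = P[0] + t * V[0], P[1] + t * V[1]
    return x * x + y * y - R * R


def dQ(P, V, t):
    """dQ/dt = grad F(P + tV) . V (chain rule)."""
    x, y = P[0] + t * V[0], P[1] + t * V[1]
    return 2.0 * (x * V[0] + y * V[1])


def lm_update(P, V, t, damping):
    """One damped step t <- t - Q Q' / (Q'^2 + damping)."""
    q, dq = Q(P, V, t), dQ(P, V, t)
    return t - q * dq / (dq * dq + damping)


def coarse_phase(P, V, beams, steps=10, damping=0.1):
    """Run a fixed number of steps on every beam, return the t of the best beam."""
    for _ in range(steps):
        beams = [lm_update(P, V, t, damping) for t in beams]
    return min(beams, key=lambda t: abs(Q(P, V, t)))


# Ray along +x at height y = 1; the beam starting at t = 10 sits where Q' = 0
# and never moves, but other beams reach an intersection.
P, V = (-10.0, 1.0), (1.0, 0.0)
t_best = coarse_phase(P, V, beams=[0.0, 5.0, 10.0, 15.0, 20.0])
```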

**Step 3: Fine phase**

Starting from the best beam of the coarse phase, run a fixed number of steps of algorithm B with a single beam. The goal here is to refine the solution to a high degree of precision.
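A minimal sketch of the fine phase on the same kind of toy circle, assuming plain Newton steps as algorithm B: a single beam, handed over already close to the root, converges quadratically to machine precision. The early exit on a residual tolerance is an illustrative choice, not necessarily what the library does.

```python
import math

R = 5.0
P, V = (-10.0, 1.0), (1.0, 0.0)  # toy ray: y = 1, pointing along +x


def Q(t):
    """Circle implicit function along the ray."""
    x, y = P[0] + t * V[0], P[1] + t * V[1]
    return x * x + y * y - R * R


def dQ(t):
    """Derivative of Q with respect to t."""
    x, y = P[0] + t * V[0], P[1] + t * V[1]
    return 2.0 * (x * V[0] + y * V[1])


def fine_phase(t, max_iter=8, tol=1e-12):
    """Refine a single beam with Newton steps t <- t - Q(t) / Q'(t)."""
    for _ in range(max_iter):
        q = Q(t)
        if abs(q) < tol:
            break
        t = t - q / dQ(t)
    return t


t_fine = fine_phase(5.0)          # e.g. a beam handed over by the coarse phase
t_exact = 10.0 - math.sqrt(24.0)  # first intersection, from x^2 + 1 = 25
```

Plain Newton is acceptable here only because the starting point is already near a root where `Q'` is far from zero; the coarse phase is what makes that assumption hold.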

**Step 4: Differentiable step**

Run a single step of algorithm C. The goal is to make the result differentiable in the PyTorch backward pass: every step except this one runs under `torch.no_grad()`, so gradients flow only through this final step.
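Why a single step is enough: if `t0` is the converged solution and is treated as a constant (which is what running the earlier phases under `torch.no_grad()` achieves), then one Newton step `t1 = t0 - Q(t0) / Q'(t0)` has the same value as `t0`, and its derivative with respect to a surface parameter is `-(dQ/dparam) / Q'`, which is exactly the implicit-function-theorem derivative of the true intersection. The sketch below checks this numerically on a toy circle whose radius plays the role of the parameter, using a central finite difference instead of autograd and assuming a Newton step as algorithm C.

```python
import math

P, V = (-10.0, 0.5), (1.0, 0.0)  # toy ray: y = 0.5, pointing along +x


def t_exact(R):
    """First intersection of the ray with the circle x^2 + y^2 = R^2."""
    return 10.0 - math.sqrt(R * R - 0.25)


def one_newton_step(t0, R):
    """Single Newton step from a fixed t0; only R is treated as a variable."""
    x, y = P[0] + t0 * V[0], P[1] + t0 * V[1]
    q = x * x + y * y - R * R         # Q(t0; R)
    dq = 2.0 * (x * V[0] + y * V[1])  # Q'(t0), does not depend on R
    return t0 - q / dq


R0, h = 5.0, 1e-4
t0 = t_exact(R0)  # "converged" value, held constant like a no_grad result
grad_one_step = (one_newton_step(t0, R0 + h) - one_newton_step(t0, R0 - h)) / (2 * h)
grad_exact = -R0 / math.sqrt(R0 * R0 - 0.25)  # d t_exact / dR
```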

## Choosing the inner algorithm and their parameters

We would like to use Newton's method because it converges fastest. But it has one problem: its update `-Q(t) / Q'(t)` is undefined when `Q'(t)` is zero, where `Q(t) = F(P + tV)` is the implicit function `F` (`surface.f` in the code) evaluated along the ray. By the chain rule, `Q'(t)` is the dot product of the surface gradient `∇F(P + tV)` with the ray direction `V`, so it vanishes whenever the surface normal and the ray's unit vector are orthogonal, which happens frequently for rays nearly tangent to the surface. To work around this, we use LM with a small damping factor, such as 0.1. When the dot product is close to zero, LM behaves like gradient descent; when it is far from zero, it behaves like Newton's method. The damping also prevents overshooting the target and helps with discontinuities in the implicit function.
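The trade-off between the two updates can be seen on scalars. The sketch assumes the textbook one-dimensional LM update for minimizing `Q^2 / 2`, namely `-Q Q' / (Q'^2 + damping)` with an additive damping term; the library's `LM` class may differ in details such as step clamping (`max_delta`).

```python
def newton_step(q, dq):
    """Newton update: delta t = -Q / Q'. Undefined when Q' = 0."""
    return -q / dq


def lm_step(q, dq, damping=0.1):
    """Scalar Levenberg-Marquardt update: delta t = -Q Q' / (Q'^2 + damping)."""
    return -q * dq / (dq * dq + damping)


# Far from tangency (|Q'| large): the LM step is almost the Newton step.
print(newton_step(1.0, 10.0), lm_step(1.0, 10.0))

# Near tangency (Q' close to 0): the LM step stays finite and matches a
# gradient descent step on Q^2 / 2 (gradient Q * Q') with step size 1 / damping.
print(lm_step(1.0, 1e-6))
```

With `damping = 0.1`, the LM step equals the Newton step scaled by `Q'^2 / (Q'^2 + 0.1)`, so it only deviates noticeably once `|Q'|` drops to around 0.3 or below.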

We use beam search and bounding-box (bbox) sampling initialization to find the global minimum and to avoid oscillation or convergence to a local minimum.
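A sketch of bounding-box sampling, assuming the surface is enclosed in an axis-aligned box: the classic slab method gives the interval of t over which the ray is inside the box, and the beams are spread evenly across it. The function names and the fallback to `t = 0` for rays that miss the box are hypothetical choices; Torch Lens Maker's own initializers may compute the interval differently.

```python
import math


def ray_box_interval(P, V, box_min, box_max):
    """Slab method: return (t_enter, t_exit) for the ray P + tV, or None on a miss."""
    t_enter, t_exit = -math.inf, math.inf
    for p, v, lo, hi in zip(P, V, box_min, box_max):
        if v == 0.0:
            if p < lo or p > hi:
                return None  # parallel to this slab and outside it
            continue
        t1, t2 = (lo - p) / v, (hi - p) / v
        t_enter = max(t_enter, min(t1, t2))
        t_exit = min(t_exit, max(t1, t2))
    if t_enter > t_exit:
        return None
    return t_enter, t_exit


def bbox_beams(P, V, box_min, box_max, num_beams):
    """Evenly spaced starting values of t inside the box interval."""
    interval = ray_box_interval(P, V, box_min, box_max)
    if interval is None:
        return [0.0] * num_beams  # hypothetical fallback for rays missing the box
    t_enter, t_exit = interval
    if num_beams == 1:
        return [0.5 * (t_enter + t_exit)]
    step = (t_exit - t_enter) / (num_beams - 1)
    return [t_enter + b * step for b in range(num_beams)]


# Circle of radius 5 enclosed in the box [-5, 5] x [-5, 5]
beams = bbox_beams((-10.0, 1.0), (1.0, 0.0), (-5.0, -5.0), (5.0, 5.0), num_beams=5)
```

Because the beams cover the whole stretch of the ray that can contain the surface, at least one of them usually starts in the basin of the global minimum instead of next to a spurious stationary point of `Q`.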

import torchlensmaker as tlm
import torch
import torch.nn as nn
from pprint import pprint

from torchlensmaker.core.rot2d import rot2d
from torchlensmaker.testing.basic_transform import basic_transform
from torchlensmaker.core.transforms import IdentityTransform
from torchlensmaker.testing.collision_datasets import *
from torchlensmaker.core.collision_detection import Newton, GD, LM, CollisionMethod, init_zeros, init_best_axis

from torchlensmaker.testing.dataset_view import dataset_view

from torchlensmaker.core.geometry import unit2d_rot

import matplotlib.pyplot as plt


# Analysis plan:
# given a dataset with known expected collisions and a list of methods,
# report for each method:
#   - the number of missed collisions
#   - the distribution of the number of iterations to converge


def convergence_plot(dataset, methods):
    "Plot convergence of collision detection for multiple methods"

    surface = dataset.surface
    P, V = dataset.P, dataset.V

    # move rays by a tiny bit, to avoid t=0 local minimum
    # that happens with constructed datasets
    # TODO augment dataset with different shifts
    #P, V = move_rays(P, V, 0.)
    
    fig, axes = plt.subplots(len(methods), 2, figsize=(10, 3*len(methods)), layout="tight", squeeze=False)

    for i, method in enumerate(methods):
        axQ1, axQ2 = axes[i]

        t_solve, t_history = method(surface, P, V, history=True)
        
        # Reshape tensors for broadcasting
        N, H = P.shape[0], t_history.shape[1]
        P_expanded = P.unsqueeze(1)  # Shape: (N, 1, 2)
        V_expanded = V.unsqueeze(1)  # Shape: (N, 1, 2)
        t_history_expanded = t_history.unsqueeze(2)  # Shape: (N, H, 1)
    
        # Compute points_history
        points_history = P_expanded + t_history_expanded * V_expanded  # Shape: (N, H, 2)
    
        assert t_history.shape == (N, H), (N, H)
        assert points_history.shape == (N, H, 2)
    
        # plot Q(t) for each ray on the left panel
        for ray_index in range(N):
            axQ1.plot(range(H), surface.f(points_history[ray_index, :, :]))

        axQ1.set_xlabel("iteration")
        axQ1.set_ylabel("Q(t)", rotation=0)
        axQ1.set_title(f"{dataset.name} | {str(method)}")

        # plot total error
        axE = axQ2.twinx()
        axE.set_ylabel("error")
        axE.set_yscale("log")
        axE.set_ylim([1e-8, 100])

        residuals = torch.ones((N, H))
        for h in range(H):
            residuals[:, h] = surface.f(points_history[:, h, :])

        error = torch.sqrt(torch.sum(residuals**2, dim=0) / N)
        assert error.shape == (H,)
        axE.plot(error, label="error")
        axE.legend()


    return fig


def collision_statistics(surface, dataset, methods):
    "Print collision statistics (RMS residual and number of misses) for each method on a dataset"

    P, V = dataset.P, dataset.V
    N = P.shape[0]
    tol = 1e-6

    for method in methods:
        t_solve = method(surface, P, V, history=False)

        # collision points P + tV, one per ray (t_solve has shape (N,))
        local_points = P + t_solve.unsqueeze(1) * V
        residuals = surface.f(local_points)

        # a ray counts as a miss if its final residual is above tolerance
        count = torch.sum(torch.abs(residuals) > tol).item()
        error = torch.sqrt(torch.sum(residuals**2) / N).item()

        print(f"{str(method): <20} error={error:.8f} ({count} misses)")

methods = [
    CollisionMethod(
        init=init_zeros,
        step0=Newton(0.8, max_iter=20, max_delta=10),
    ),

    CollisionMethod(
        init=init_best_axis,
        step0=Newton(0.8, max_iter=20, max_delta=10),
    ),
    
    CollisionMethod(
        init=init_zeros,
        step0=LM(0.1, max_iter=20, max_delta=10),
    ),

    CollisionMethod(
        init=init_best_axis,
        step0=LM(0.1, max_iter=20, max_delta=10),
    ),

    CollisionMethod(
        init=init_zeros,
        step0=LM(0.01, max_iter=20, max_delta=10),
    ),

    CollisionMethod(
        init=init_best_axis,
        step0=LM(0.01, max_iter=20, max_delta=10),
    ),
]


# (surface, ray generator, expected_collide)
test_cases = [
    (tlm.Sphere(30, 30), normal_rays(offset=9.0, N=25), True),
    (tlm.Sphere(30, 30), tangent_rays(offset=-0.6, N=15), True),
    (tlm.Sphere(30, 30), fixed_rays(direction=unit2d_rot(45), offset=30, N=15), True),
    (tlm.Sphere(30, 30), fixed_rays(direction=unit2d_rot(65), offset=30, N=15), True),
    (tlm.Sphere(30, 30), fixed_rays(direction=unit2d_rot(85), offset=30, N=15), True),
]

# initialization methods:
# zeros
# intersect with axis
# smart intersect with axis
# bbox sampling

for surface, gen, expected_collide in test_cases:
    dataset = gen(surface)

    # tlmviewer view
    dataset_view(surface, dataset)

    # statistics
    collision_statistics(surface, dataset, methods)
    
    # convergence plots
    fig = convergence_plot(dataset, methods)
    plt.show()

# individual test:
# 1 surface
# 1 ray generator

# batch test:
# 1 surface
# many ray generators
