# Test local collide

Generate rays with known collision behavior using the ray generators in `testing/collision_datasets.py`, and check that `surface.local_collide` correctly finds every collision (and reports no collision for rays that are expected to miss).

In [None]:
import torchlensmaker as tlm
import torch
import torch.nn as nn
from pprint import pprint

from torchlensmaker.testing.basic_transform import basic_transform
from torchlensmaker.core.transforms import IdentityTransform
from torchlensmaker.testing.collision_datasets import (
    normal_rays,
    tangent_rays,
    random_direction_rays,
    fixed_rays
)

from torchlensmaker.core.collision_detection import CollisionMethod, Newton, LM, init_zeros, init_best_axis

from torchlensmaker.core.surfaces import CircularPlane, Sphere, SphereR

from torchlensmaker.core.geometry import rotated_unit_vector, unit3d_rot

import matplotlib.pyplot as plt

import sys
import traceback


from torchlensmaker.testing.dataset_view import dataset_view


def check_collide(surface, dataset, expected_collide):
    """Check surface.local_collide against every ray of a testing dataset.

    Rays in the testing datasets are expressed in the surface's local frame,
    so they are passed to ``surface.local_collide`` directly. Checks:

    * output shapes and dtypes are consistent with the input rays,
    * every floating-point output is finite,
    * both the ``valid`` mask and ``surface.contains`` evaluated on the
      computed collision points equal ``expected_collide`` for every ray,
    * every returned normal has unit length.

    Args:
        surface: surface under test; must provide ``local_collide``,
            ``contains`` and ``dtype``.
        dataset: object with ray origins ``P`` of shape (N, D) and ray
            directions ``V`` compatible with it.
        expected_collide: whether every ray is expected to hit the surface
            (True) or to miss it (False).

    Raises:
        AssertionError: if any check fails.
    """
    P, V = dataset.P, dataset.V
    N, D = P.shape
    t, local_normals, valid = surface.local_collide(P, V)
    # Collision points along each ray: P + t * V
    local_points = P + t.unsqueeze(1).expand_as(V) * V

    # Check shapes
    assert t.dim() == 1 and t.shape[0] == N
    assert local_normals.dim() == 2 and local_normals.shape == (N, D)
    assert valid.dim() == 1 and valid.shape[0] == N
    assert local_points.dim() == 2 and local_points.shape == (N, D)

    # Check dtypes
    assert t.dtype == surface.dtype, (P.dtype, V.dtype, t.dtype, surface.dtype)
    assert local_normals.dtype == surface.dtype
    assert valid.dtype == torch.bool
    assert local_points.dtype == surface.dtype

    # Check isfinite. `valid` is a bool mask, which is always finite, so it
    # is not checked here.
    assert torch.all(torch.isfinite(t)), t
    assert torch.all(torch.isfinite(local_normals))
    assert torch.all(torch.isfinite(local_points))

    # Check valid mask and contains() are all 'expected_collide'.
    # Report how many rays disagree, which is what matters when debugging.
    valid_mismatch = int(torch.sum(valid != expected_collide).item())
    assert valid_mismatch == 0, (
        f"{valid_mismatch} of {N} rays have valid != {expected_collide}"
    )
    contains = surface.contains(local_points)
    contains_mismatch = int(torch.sum(contains != expected_collide).item())
    assert contains_mismatch == 0, (
        f"{contains_mismatch} of {N} rays have contains() != {expected_collide}"
    )

    # Check all normals are unit vectors
    norms = torch.linalg.vector_norm(local_normals, dim=1)
    assert torch.allclose(norms, torch.ones_like(norms)), norms

def unit(theta):
    """Return rotated_unit_vector for an angle given in degrees (dim=2).

    The angle is converted to radians as a one-element float64 tensor, and
    the leading batch dimension of size one is squeezed away from the result.
    """
    radians = torch.deg2rad(torch.tensor([theta], dtype=torch.float64))
    vectors = rotated_unit_vector(radians, dim=2)
    return vectors.squeeze(0)


# Test cases for check_collide, one tuple per case:
#   - surface: the surface under test
#   - generator: called as generator(surface) in the loop below to build the ray dataset
#   - expected_collide: expected value of both the valid mask and contains() for every ray
# NOTE(review): test_cases is reassigned further down in this cell before the
# loop runs, so these cases are not executed as the cell currently stands.
# (surface, generator, expected_collide)
test_cases = [
    # Plane
    (CircularPlane(30), normal_rays(dim=2, N=10, offset=5.0), True),
    
    # Sphere with curvature parameterization
    (Sphere(30, R=30), normal_rays(dim=2, N=25, offset=3.0), True),
    (Sphere(30, R=30), normal_rays(dim=2, N=25, offset=0.0), True),
    (Sphere(30, R=30), normal_rays(dim=2, N=25, offset=-3.0), True),
    (Sphere(30, R=30), tangent_rays(dim=2, N=25, offset=-0.6), True),
    (Sphere(30, R=30), tangent_rays(dim=2, N=25, offset=-2.0), True),
    (Sphere(30, R=30), tangent_rays(dim=2, N=25, offset=-4.0), True),
    (Sphere(30, R=30), tangent_rays(dim=2, N=25, offset=4.0), False),
    (Sphere(30, R=30), random_direction_rays(dim=2, N=25, offset=10.0), True),

    # For Sphere with R=-30 the tangent-ray offset signs are flipped relative
    # to R=30 (positive offsets collide, negative offsets miss).
    (Sphere(30, R=-30), normal_rays(dim=2, N=25, offset=3.0), True),
    (Sphere(30, R=-30), normal_rays(dim=2, N=25, offset=0.0), True),
    (Sphere(30, R=-30), normal_rays(dim=2, N=25, offset=-3.0), True),
    (Sphere(30, R=-30), tangent_rays(dim=2, N=25, offset=0.6), True),
    (Sphere(30, R=-30), tangent_rays(dim=2, N=25, offset=2.0), True),
    (Sphere(30, R=-30), tangent_rays(dim=2, N=25, offset=4.0), True),
    (Sphere(30, R=-30), tangent_rays(dim=2, N=25, offset=-4.0), False),
    (Sphere(30, R=-30), random_direction_rays(dim=2, N=25, offset=10.0), True),

    # Flat sphere (C=0)
    (Sphere(30, C=0), normal_rays(dim=2, N=25, offset=3.0), True),
    (Sphere(30, C=0), normal_rays(dim=2, N=25, offset=0.0), True),
    (Sphere(30, C=0), normal_rays(dim=2, N=25, offset=-3.0), True),

    ##

    # TODO same sign for offset in sphere and sphereR
    # NOTE(review): unlike the Sphere(R=-30) cases above, the SphereR(R=-30)
    # tangent cases below use the same offset signs as SphereR(R=30).
    
    # Sphere with radius parameterization
    (SphereR(30, R=30), normal_rays(dim=2, N=25, offset=3.0), True),
    (SphereR(30, R=30), normal_rays(dim=2, N=25, offset=0.0), True),
    (SphereR(30, R=30), normal_rays(dim=2, N=25, offset=-3.0), True),
    (SphereR(30, R=30), tangent_rays(dim=2, N=25, offset=-0.6), True),
    (SphereR(30, R=30), tangent_rays(dim=2, N=25, offset=-2.0), True),
    (SphereR(30, R=30), tangent_rays(dim=2, N=25, offset=-4.0), True),
    (SphereR(30, R=30), tangent_rays(dim=2, N=25, offset=4.0), False),
    (SphereR(30, R=30), random_direction_rays(dim=2, N=25, offset=10.0), True),

    (SphereR(30, R=-30), normal_rays(dim=2, N=25, offset=3.0), True),
    (SphereR(30, R=-30), normal_rays(dim=2, N=25, offset=0.0), True),
    (SphereR(30, R=-30), normal_rays(dim=2, N=25, offset=-3.0), True),
    (SphereR(30, R=-30), tangent_rays(dim=2, N=25, offset=-0.6), True),
    (SphereR(30, R=-30), tangent_rays(dim=2, N=25, offset=-2.0), True),
    (SphereR(30, R=-30), tangent_rays(dim=2, N=25, offset=-4.0), True),
    (SphereR(30, R=-30), tangent_rays(dim=2, N=25, offset=4.0), False),
    (SphereR(30, R=-30), random_direction_rays(dim=2, N=25, offset=10.0), True),

    # Exact half sphere, tangent, normal and random rays
    (SphereR(30, R=15), tangent_rays(dim=2, N=25, offset=4.0), False),
    (SphereR(30, R=15), tangent_rays(dim=2, N=25, offset=-4.0), True),
    (SphereR(30, R=15), normal_rays(dim=2, N=25, offset=4.0), True),
    (SphereR(30, R=15), random_direction_rays(dim=2, N=25, offset=10.0), True),

    # Exact half sphere, horizontal rays
    (SphereR(30, R=15), fixed_rays(dim=2, N=25, direction=torch.tensor([1.0, 0.0]), offset=1.0), True),
    (SphereR(30, R=-15), fixed_rays(dim=2, N=25, direction=torch.tensor([1.0, 0.0]), offset=1.0), True),
    (SphereR(30, R=-15), fixed_rays(dim=2, N=25, direction=torch.tensor([1.0, 0.0]), offset=-50.0), True),
    (SphereR(30, R=15), fixed_rays(dim=2, N=25, direction=torch.tensor([1.0, 0.0]), offset=-50.0), True),

    # Exact half sphere, vertical rays
    (SphereR(30, R=15), fixed_rays(dim=2, N=25, direction=torch.tensor([0.0, 1.0]), offset=1.0), True),
    (SphereR(30, R=-15), fixed_rays(dim=2, N=25, direction=torch.tensor([0.0, 1.0]), offset=1.0), True),
    (SphereR(30, R=-15), fixed_rays(dim=2, N=25, direction=torch.tensor([0.0, 1.0]), offset=-50.0), True),
    (SphereR(30, R=15), fixed_rays(dim=2, N=25, direction=torch.tensor([0.0, 1.0]), offset=-50.0), True),
]

# Collision methods combining an `init` function (init_zeros or
# init_best_axis) with a `step0` solver (Newton with damping=0.8, or LM with
# damping=0.1). Every method gets its own solver instance, and all solvers
# use max_iter=15 and max_delta=10.
newton_zeros = CollisionMethod(
    init=init_zeros, step0=Newton(damping=0.8, max_iter=15, max_delta=10)
)
newton_best_axis = CollisionMethod(
    init=init_best_axis, step0=Newton(damping=0.8, max_iter=15, max_delta=10)
)
lm_zeros = CollisionMethod(
    init=init_zeros, step0=LM(damping=0.1, max_iter=15, max_delta=10)
)
lm_best_axis = CollisionMethod(
    init=init_best_axis, step0=LM(damping=0.1, max_iter=15, max_delta=10)
)


# Cases noted as failing with particular collision methods (see the comment
# on each group). This assignment replaces the list defined above.
test_cases = [
    # Failing cases so far with LM
    (tlm.Sphere(30, 30), tangent_rays(dim=2, N=15, offset=-0.6), True),
    (tlm.Sphere(30, 30), fixed_rays(dim=2, N=15, direction=unit(45), offset=30), True),
    (tlm.Sphere(30, 30), fixed_rays(dim=2, N=15, direction=unit(65), offset=30), True),
    (tlm.Sphere(30, 30), fixed_rays(dim=2, N=15, direction=unit(85), offset=30), True),

    # Failing with Newton init_best_axis, because of nan in dot product
    #(tlm.Sphere(30, 30), fixed_rays(direction=torch.tensor([0., 1.0]), offset=30, N=15), True),

    # Failing with Newton init_zeros
    (tlm.Sphere(30, 16, collision_method=newton_zeros), fixed_rays(dim=2, N=25, direction=unit(90), offset=50), True),
    (tlm.Sphere(30, 16, collision_method=newton_zeros), fixed_rays(dim=2, N=25, direction=unit(80), offset=50), True),
    (tlm.Sphere(30, 16, collision_method=newton_zeros), fixed_rays(dim=2, N=25, direction=unit(70), offset=50), True),
    (tlm.Sphere(30, 16, collision_method=newton_zeros), fixed_rays(dim=2, N=25, direction=unit(60), offset=50), True),
    (tlm.Sphere(30, 16, collision_method=newton_zeros), fixed_rays(dim=2, N=25, direction=unit(50), offset=50), True),
    (tlm.Sphere(30, 16, collision_method=newton_zeros), fixed_rays(dim=2, N=25, direction=unit(40), offset=50), True),
]

# NOTE(review): test_cases is reset to an empty list here, so the check loop
# below runs no cases at all. Remove this line to run the cases listed above.
test_cases = []

# When True, render every dataset before checking it. When False, only
# failing cases are rendered.
show_all = True

# Names of datasets whose check failed, reported in a summary after the loop.
failures = []

for surface, gen, expected_collide in test_cases:
    dataset = gen(surface)

    if show_all:
        dataset_view(surface, dataset)

    # check collisions
    try:
        check_collide(surface, dataset, expected_collide)
    except AssertionError as err:
        traceback.print_tb(err.__traceback__)

        print("Test failed")
        print("dataset:", dataset.name)
        print("expected_collide:", expected_collide)
        print("AssertionError:", err)
        failures.append(dataset.name)

        # tlmviewer view. Skip it when show_all already rendered this dataset
        # above, so a failing case is not displayed twice.
        if not show_all:
            dataset_view(surface, dataset)
        # TODO add convergence visualization here

print(f"{len(failures)} / {len(test_cases)} test cases failed")


In [None]:
# 3D datasets of rays with a fixed direction (offset=1.0), for the exact half
# sphere SphereR(30, R=15) and for Sphere(30, R=16). Each entry is
# (surface, generator, expected_collide). A new surface is built for every
# direction, in the same order as listing the cases one by one.
test_cases = [
    (
        make_surface(30, **surface_kwargs),
        fixed_rays(dim=3, N=64, direction=torch.tensor(direction), offset=1.0),
        True,
    )
    for make_surface, surface_kwargs in ((SphereR, {"R": 15}), (Sphere, {"R": 16}))
    for direction in (
        [0.0, 1.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0],
        [1.0, 1.0, 1.0],
        [-1.0, -1.0, -1.0],
    )
]

# This cell only renders the datasets; expected_collide is not checked here.
for surface, gen, expected_collide in test_cases:
    dataset = gen(surface)
    dataset_view(surface, dataset)