# Test local collide

Generate rays with a known collision outcome using the ray generators in `testing/collision_datasets.py`, and check that `surface.local_collide` correctly finds all collisions (and reports no collision for rays expected to miss).

In [None]:
import torchlensmaker as tlm
import torch
import torch.nn as nn
from pprint import pprint

from torchlensmaker.testing.basic_transform import basic_transform
from torchlensmaker.core.transforms import IdentityTransform
from torchlensmaker.testing.collision_datasets import *

import matplotlib.pyplot as plt

import sys
import traceback


def dataset_view(surface, dataset, rays_length=100):
    """View a collision dataset testcase with tlmviewer.

    Renders a 2D scene titled with ``dataset.name`` containing:
    - the collision points P + t*V computed from ``surface.local_collide``,
      rendered together with the returned normals,
    - the ray origins P as grey points,
    - each ray as a segment from P - rays_length*V to P + rays_length*V,
    - the surface itself, drawn with an identity transform.

    Args:
        surface: surface under test (``local_collide`` and ``dtype`` are used).
        dataset: object providing ray origins ``dataset.P``, ray directions
            ``dataset.V`` and a display name ``dataset.name``.
        rays_length: distance along V drawn on each side of every ray origin.

    Raises:
        AssertionError: if ``dataset.P`` or ``dataset.V`` has non-finite values.
    """
    P, V = dataset.P, dataset.V

    # Validate the dataset before using it, so a NaN/inf dataset fails here
    # with a clear message instead of after collision and rendering work.
    assert torch.all(torch.isfinite(P)), f"dataset {dataset.name}: P has non-finite values"
    assert torch.all(torch.isfinite(V)), f"dataset {dataset.name}: V has non-finite values"

    t, local_normals, _ = surface.local_collide(P, V)

    # Collision points for every ray; rays not flagged valid are not filtered
    # out, so bad collisions are visible too.
    local_points = P + t.unsqueeze(1).expand_as(V) * V

    scene = tlm.viewer.new_scene("2D")
    scene["data"].extend(tlm.viewer.render_collisions(local_points, local_normals))
    scene["data"].append(tlm.viewer.render_points(P, color="grey"))

    rays_start = P - rays_length * V
    rays_end = P + rays_length * V
    scene["data"].append(tlm.viewer.render_rays(rays_start, rays_end, layer=0))

    surface_transform = IdentityTransform(dim=2, dtype=surface.dtype)
    scene["data"].append(
        tlm.viewer.render_surfaces([surface], [surface_transform], dim=2)
    )
    scene["title"] = dataset.name
    tlm.viewer.ipython_display(scene)
    #tlm.viewer.dump(scene, ndigits=2)


def check_collide(surface, dataset, expected_collide):
    """Check that surface.local_collide agrees with the expected collision outcome.

    Rays in the testing datasets are expressed in the surface's local frame,
    so ``local_collide`` is called on them directly.

    Args:
        surface: surface under test; must provide ``local_collide(P, V)``,
            ``contains(points)`` and a ``dtype`` attribute.
        dataset: object providing ray origins ``dataset.P`` and directions
            ``dataset.V``, both of shape (N, D).
        expected_collide: True if every ray must be reported as colliding,
            False if none may be.

    Raises:
        AssertionError: if any check fails. The message names the failing
            check so the caller can print something useful.
    """
    # Call local_collide, rays in testing datasets are in local frame
    P, V = dataset.P, dataset.V
    N, D = P.shape
    t, local_normals, valid = surface.local_collide(P, V)

    # Collision point of each ray: P + t * V
    local_points = P + t.unsqueeze(1).expand_as(V) * V

    # Check shapes
    assert t.dim() == 1 and t.shape[0] == N, (
        f"t has shape {tuple(t.shape)}, expected ({N},)"
    )
    assert local_normals.dim() == 2 and local_normals.shape == (N, D), (
        f"local_normals has shape {tuple(local_normals.shape)}, expected ({N}, {D})"
    )
    assert valid.dim() == 1 and valid.shape[0] == N, (
        f"valid has shape {tuple(valid.shape)}, expected ({N},)"
    )
    assert local_points.dim() == 2 and local_points.shape == (N, D), (
        f"local_points has shape {tuple(local_points.shape)}, expected ({N}, {D})"
    )

    # Check dtypes
    assert t.dtype == surface.dtype, f"t.dtype is {t.dtype}, expected {surface.dtype}"
    assert local_normals.dtype == surface.dtype, (
        f"local_normals.dtype is {local_normals.dtype}, expected {surface.dtype}"
    )
    assert valid.dtype == torch.bool, f"valid.dtype is {valid.dtype}, expected torch.bool"
    assert local_points.dtype == surface.dtype, (
        f"local_points.dtype is {local_points.dtype}, expected {surface.dtype}"
    )

    # Check valid mask is all 'expected_collide'.
    # The message reports how many rays DISAGREE with the expectation.
    assert torch.all(valid == expected_collide), (
        f"{(valid != expected_collide).sum().item()} / {N} rays have "
        f"valid != {expected_collide}"
    )

    # Check the collision points lie on the surface (or not, when no
    # collision is expected)
    contained = surface.contains(local_points)
    assert torch.all(contained == expected_collide), (
        f"{(contained != expected_collide).sum().item()} / {N} collision points "
        f"have surface.contains(...) != {expected_collide}"
    )

    # Check all normals are unit vectors
    norms = torch.linalg.vector_norm(local_normals, dim=1)
    assert torch.allclose(norms, torch.ones_like(norms)), (
        f"normals are not unit vectors "
        f"(max |norm - 1| = {(norms - 1).abs().max().item():.3g})"
    )


# (surface, generator, expected_collide)
#
# Test table consumed by the loop below:
# - surface: the surface under test.
# - generator: called as gen(surface) to build a dataset of rays
#   (dataset.P origins, dataset.V directions) in the surface's local frame.
# - expected_collide: True if every ray must be reported as colliding
#   (valid mask all True), False if none may be.
#
# NOTE(review): the meaning of the two Sphere/Sphere3 constructor arguments
# is not visible here (presumably diameter/height and signed radius) — confirm.
test_cases = [
    # Sphere(30, 30): normal rays at several offsets, tangent rays, random directions
    (Sphere(30, 30), normal_rays(offset=3.0, N=25), True),
    (Sphere(30, 30), normal_rays(offset=0.0, N=25), True),
    (Sphere(30, 30), normal_rays(offset=-3.0, N=25), True),
    (Sphere(30, 30), tangent_rays(offset=-0.6, N=25), True),
    (Sphere(30, 30), tangent_rays(offset=-2.0, N=25), True),
    (Sphere(30, 30), tangent_rays(offset=-4.0, N=25), True),
    (Sphere(30, 30), tangent_rays(offset=4.0, N=25), False),
    (Sphere(30, 30), random_direction_rays(offset=10.0, N=25), True),

    # Same cases with the second argument negated; the tangent_rays offsets
    # (and which one is expected to miss) flip sign accordingly
    (Sphere(30, -30), normal_rays(offset=3.0, N=25), True),
    (Sphere(30, -30), normal_rays(offset=0.0, N=25), True),
    (Sphere(30, -30), normal_rays(offset=-3.0, N=25), True),
    (Sphere(30, -30), tangent_rays(offset=0.6, N=25), True),
    (Sphere(30, -30), tangent_rays(offset=2.0, N=25), True),
    (Sphere(30, -30), tangent_rays(offset=4.0, N=25), True),
    (Sphere(30, -30), tangent_rays(offset=-4.0, N=25), False),
    (Sphere(30, -30), random_direction_rays(offset=10.0, N=25), True),

    # Same cases for Sphere3 (presumably an alternative sphere implementation — confirm)
    (Sphere3(30, 30), normal_rays(offset=3.0, N=25), True),
    (Sphere3(30, 30), normal_rays(offset=0.0, N=25), True),
    (Sphere3(30, 30), normal_rays(offset=-3.0, N=25), True),
    (Sphere3(30, 30), tangent_rays(offset=-0.6, N=25), True),
    (Sphere3(30, 30), tangent_rays(offset=-2.0, N=25), True),
    (Sphere3(30, 30), tangent_rays(offset=-4.0, N=25), True),
    (Sphere3(30, 30), tangent_rays(offset=4.0, N=25), False),
    (Sphere3(30, 30), random_direction_rays(offset=10.0, N=25), True),

    (Sphere3(30, -30), normal_rays(offset=3.0, N=25), True),
    (Sphere3(30, -30), normal_rays(offset=0.0, N=25), True),
    (Sphere3(30, -30), normal_rays(offset=-3.0, N=25), True),
    (Sphere3(30, -30), tangent_rays(offset=0.6, N=25), True),
    (Sphere3(30, -30), tangent_rays(offset=2.0, N=25), True),
    (Sphere3(30, -30), tangent_rays(offset=4.0, N=25), True),
    (Sphere3(30, -30), tangent_rays(offset=-4.0, N=25), False),
    (Sphere3(30, -30), random_direction_rays(offset=10.0, N=25), True),

    # Very large second argument (presumably a nearly flat surface — confirm);
    # only normal rays are tested
    (Sphere(30, 1e6), normal_rays(offset=3.0, N=25), True),
    (Sphere(30, 1e6), normal_rays(offset=0.0, N=25), True),
    (Sphere(30, 1e6), normal_rays(offset=-3.0, N=25), True),
    
    # Disabled cases (presumably not passing yet — TODO confirm before re-enabling)
    #(Sphere(30, 15), tangent_rays(offset=-4.0, N=25), False),
    #(Sphere(30, 15), random_direction_rays(offset=10.0, N=25), True),
]

# bug: Sphere3 f_grad when R is negative

# Run every test case. A failing case does not stop the run: its traceback and
# details are printed and its dataset is rendered with tlmviewer, then the loop
# continues. A summary is printed at the end so a clean run is visible too.
failures = []
for index, (surface, gen, expected_collide) in enumerate(test_cases):
    dataset = gen(surface)

    # check collisions
    try:
        check_collide(surface, dataset, expected_collide)
    except AssertionError as err:
        # err.__traceback__ is the same traceback sys.exc_info() would return here
        traceback.print_tb(err.__traceback__)

        print("Test failed")
        print("test case:", index)
        print("dataset:", dataset.name)
        print("expected_collide:", expected_collide)
        print("AssertionError:", err)
        failures.append(index)

        # tlmviewer view
        dataset_view(surface, dataset)

print(f"{len(test_cases) - len(failures)} / {len(test_cases)} test cases passed")
if failures:
    print("failed test cases:", failures)
