# Demo: 3D collision detection

In [None]:
import torchlensmaker as tlm
import torch
import torch.nn
import math

import pprint



def make_random_rays(num_rays, start_x, end_x, max_y):
    """Create ``num_rays`` random rays going from the plane x=start_x toward x=end_x.

    For each ray, a start point and an aim point are drawn with their y and z
    coordinates uniform in [-max_y, max_y) (for positive ``max_y``; note that
    ``max_y`` bounds z as well as y). Their x coordinates are fixed to
    ``start_x`` and ``end_x`` respectively.

    Returns:
        Tensor of shape (num_rays, 6). Columns 0-2 hold the start points and
        columns 3-5 hold the unit direction vectors from start to aim point.
    """

    def _random_points(x_value):
        # Uniform in [-max_y, max_y) on all three axes, then pin the x axis.
        pts = (torch.rand((num_rays, 3)) * 2 - 1) * max_y
        pts[:, 0] = x_value
        return pts

    # Draw order matters for reproducibility under a fixed seed: starts first.
    origins = _random_points(start_x)
    targets = _random_points(end_x)

    directions = torch.nn.functional.normalize(targets - origins, dim=1)
    return torch.cat([origins, directions], dim=1)


# 50 random rays starting on the plane x=-15 and aimed at points on x=50,
# with y and z within +/-20 (make_random_rays bounds both y and z by max_y).
# No seed is set here, so the rays differ on every run.
test_rays = make_random_rays(
    num_rays=50,
    start_x=-15,
    end_x=50,
    max_y=20,
)

# List of (transform, surface) pairs to intersect with the rays.
# NOTE(review): the SurfaceTransform arguments are presumably
# (scale, anchor, rotation angles, translation), judging by the x
# translations stepping along the ray direction. Confirm against
# tlm.SurfaceTransform before relying on this reading.
test_data = [
    # Spheres. The second Sphere argument is presumably the radius of
    # curvature (1e6 ~ nearly flat, negative = opposite bulge) -- TODO confirm.
    (tlm.SurfaceTransform(1.0, "origin", [0., 10., 0.], [0., 0., 0.]), tlm.surfaces.Sphere(15.0, 1e6)),
    (tlm.SurfaceTransform(1.0, "origin", [0., 0., 0.], [10., 0., -10.]), tlm.surfaces.Sphere(25.0, 20)),
    (tlm.SurfaceTransform(1.0, "origin", [0., 0., 0.], [20., 20., 0.]), tlm.surfaces.Sphere(15.0, -10)),
    # Parabolas with coefficients of both signs.
    (tlm.SurfaceTransform(1.0, "origin", [0., 0., 0.], [30., 0., 0.]), tlm.surfaces.Parabola(15., -0.05)),
    (tlm.SurfaceTransform(1.0, "origin", [0., 0., 0.], [40., 0., 0.]), tlm.surfaces.Parabola(20., -0.04)),
    (tlm.SurfaceTransform(1.0, "origin", [0., 0., 0.], [50., 0., 0.]), tlm.surfaces.Parabola(30., 0.02)),
    (tlm.SurfaceTransform(1.0, "origin", [0., 10., -10.], [60., 0., 0.]), tlm.surfaces.Parabola(30., 0.05)),
    # Planar surfaces.
    (tlm.SurfaceTransform(1.0, "origin", [0., 0., 0.], [80., 0., 0.]), tlm.surfaces.CircularPlane(50.)),
    (tlm.SurfaceTransform(1.0, "origin", [0., 0., 0.], [5., 0., -5.]), tlm.surfaces.SquarePlane(15.)),
    
    # Pair: same base transform with an "origin"-anchored parabola, then an
    # "extent"-anchored parabola with scale -1.0.
    (tlm.SurfaceTransform(1.0, "origin", [0.0, 10.0, -10.0], [100.0, 0.0, 0.0]), tlm.surfaces.Parabola(30.0, -0.05)),
    (tlm.SurfaceTransform(-1.0, "extent", [0.0, 20.0, -20.0], [100.0, 2.0, 5.0]), tlm.surfaces.Parabola(20.0, 0.05)),

    # Same pair layout as above, but with a positive coefficient on the first
    # parabola and scale 1.0 on the second.
    (tlm.SurfaceTransform(1.0, "origin", [0.0, 10.0, -10.0], [100.0, 0.0, 0.0]), tlm.surfaces.Parabola(30.0, 0.05)),
    (tlm.SurfaceTransform(1.0, "extent", [0.0, 20.0, -20.0], [100.0, 2.0, 5.0]), tlm.surfaces.Parabola(20.0, 0.05)),

    # Five identical "extent"-anchored parabolas at the same translation,
    # varying only the second component of the third argument (0..40),
    # presumably a rotation sweep -- TODO confirm.
    (tlm.SurfaceTransform(1.0, "extent", [0.0,  0.0, 0.0], [50.0, 5.0, 5.0]), tlm.surfaces.Parabola(30.0, 0.05)),
    (tlm.SurfaceTransform(1.0, "extent", [0.0, 10.0, 0.0], [50.0, 5.0, 5.0]), tlm.surfaces.Parabola(30.0, 0.05)),
    (tlm.SurfaceTransform(1.0, "extent", [0.0, 20.0, 0.0], [50.0, 5.0, 5.0]), tlm.surfaces.Parabola(30.0, 0.05)),
    (tlm.SurfaceTransform(1.0, "extent", [0.0, 30.0, 0.0], [50.0, 5.0, 5.0]), tlm.surfaces.Parabola(30.0, 0.05)),
    (tlm.SurfaceTransform(1.0, "extent", [0.0, 40.0, 0.0], [50.0, 5.0, 5.0]), tlm.surfaces.Parabola(30.0, 0.05)),

    # A parabola and a square plane sharing the exact same transform.
    (tlm.SurfaceTransform(1.0, "origin", [0.0, 0.0, 0.0], [10.0, 0.0, 0.0]), tlm.surfaces.Parabola(30., 0.05)),
    (tlm.SurfaceTransform(1.0, "origin", [0.0, 0.0, 0.0], [10.0, 0.0, 0.0]), tlm.surfaces.SquarePlane(30.))
]

# Unzip the pairs into parallel lists (passed to the viewer by demo()).
test_surfaces = [s for t, s in test_data]
test_transforms = [t for t, s in test_data]

# TODO:
# - test more transforms
# - test 3D refraction / reflection


def demo(rays):
    """Intersect ``rays`` with every surface in ``test_data`` and display the result.

    Args:
        rays: tensor with one ray per row. Columns 0:3 are read as start points
            and columns 3:6 as direction vectors (the layout returned by
            ``make_random_rays``).

    Side effects:
        Renders the rays, collision points/normals and all test surfaces with
        ``tlm.viewer`` and shows them.
    """
    # Use the argument, not the global ``test_rays`` (previously ``rays`` was
    # silently ignored for both intersection and rendering).
    P, V = rays[:, :3], rays[:, 3:6]

    hit_points = []
    hit_normals = []
    for transform, surface in test_data:
        points, normals = tlm.intersect(surface, P, V, transform)

        # Skip surfaces that no ray hit.
        if points.numel() > 0:
            hit_points.append(points)
            hit_normals.append(normals)

    # Concatenate once at the end instead of re-allocating on every surface.
    if hit_points:
        all_points = torch.cat(hit_points, dim=0)
        all_normals = torch.cat(hit_normals, dim=0)
    else:
        all_points = torch.empty((0, 3))
        all_normals = torch.empty((0, 3))

    groups = tlm.viewer.render(
        rays, all_points, all_normals, test_surfaces, test_transforms, rays_length=150
    )
    tlm.viewer.show(groups)


# Run the demo on the random rays generated at the top of this cell.
demo(test_rays)