In [None]:
import torchlensmaker as tlm
import torch
import math

from torchlensmaker.testing.collision_datasets import FixedRays

from typing import TypeAlias, Any

Tensor = torch.Tensor

from torchlensmaker.core.collision_detection import init_brd, CollisionAlgorithm, surface_f


def view_beams(surface, P, V, t, rays_length=100):
    """Render a surface, a bundle of rays and the points they reach at ``t``.

    Every row of ``t`` is one "beam": a single parameter value per ray.
    For beam ``b`` and ray ``n`` the point ``P[n] + t[b, n] * V[n]`` is drawn,
    so several candidate positions along each ray can be compared visually.

    Args:
        surface: surface handed to ``tlm.viewer.render_surface``.
        P: ray origins, unpacked as ``(N, D)``.
        V: ray directions, expanded alongside ``P`` to ``(B, N, D)``.
        t: ray parameters with ``t.shape[0]`` beams of ``N`` values each.
        rays_length: rays are drawn from ``P - rays_length*V`` to
            ``P + rays_length*V``.
    """
    num_rays, dim = P.shape
    num_beams = t.shape[0]
    assert t.shape[1] == P.shape[0]

    # Lift origins, directions and parameters onto a common (B, N, D) grid
    origins = P.expand((num_beams, -1, -1))
    steps = t.unsqueeze(-1).expand((num_beams, num_rays, dim))
    directions = V.expand((num_beams, -1, -1))
    beam_points = origins + steps * directions
    assert beam_points.shape == (num_beams, num_rays, dim)

    # All beams are rendered as one flat point cloud
    flat_points = beam_points.reshape((-1, dim))
    assert flat_points.shape == (num_beams * num_rays, dim)

    scene = tlm.viewer.new_scene("2D" if dim == 2 else "3D")
    layers = scene["data"]
    layers.append(tlm.viewer.render_surface(surface, dim))

    # Rays extend the same distance on both sides of their origin
    half_span = rays_length * V
    layers.append(tlm.viewer.render_rays(P - half_span, P + half_span, layer=0))

    layers.append(tlm.viewer.render_points(flat_points))

    tlm.viewer.ipython_display(scene)


###########

def demo():
    """Run the multi-beam Newton collision demo on a 2D sphere and display it.

    Rays produced by ``FixedRays`` are intersected with ``tlm.Sphere(30, 16)``:

    1. ``init_brd`` seeds ``B`` candidate parameters ("beams") per ray.
    2. A fixed number of Newton steps refines every candidate.
    3. For each ray, the candidate with the smallest ``|F|`` is kept.

    Both the initial beams and the selected points are shown with
    ``view_beams``.
    """
    surface = tlm.Sphere(30, 16)
    generator = FixedRays(dim=2, N=30, direction=tlm.unit2d_rot(40), offset=30, epsilon=0.05)

    # TODO ImplicitSurface.bounding_radius()
    br = math.sqrt((surface.diameter / 2) ** 2 + surface.extent_x() ** 2)

    P, V = generator(surface)
    N, D = P.shape
    B = 8  # number of candidate parameters (beams) per ray
    num_steps = 5  # Newton refinement steps applied below
    # NOTE(review): only algo.delta() is used here, so max_iter presumably has
    # no effect on this demo; the loop below runs num_steps times — confirm.
    algo = tlm.Newton(damping=0.8, max_iter=10, max_delta=br / (B - 1))

    init_t = init_brd(surface, P, V, B)
    view_beams(surface, P, V, init_t)

    # Refine every beam with a fixed number of Newton steps
    t = init_t
    for _ in range(num_steps):
        t = t - algo.delta(surface, P, V, t)
    assert t.shape == (B, N)

    # Keep the best beam per ray: the one with the smallest residual |F|
    # F :: (B, N); the second output (F_grad :: (B, N, D)) is unused here
    F, _ = surface_f(surface, P, V, t)
    assert F.shape == (B, N)
    _, indices = torch.min(torch.abs(F), dim=0)

    print("indices shape", indices.shape)
    print("indices", indices)
    assert indices.shape == (N,)
    # Pick t[indices[n], n] for every ray n
    best_t = torch.gather(t, 0, indices.unsqueeze(0)).squeeze(0)

    print(best_t.shape)

    # TODO when a ray has more than one collision, how do we choose which one
    # to keep deterministically? Preferably the one closest to t=0 for optics.
    # Idea: instead of min(|F|), filter |F| < tol and keep the candidate closest
    # to t=0 (or apply another condition).

    view_beams(surface, P, V, best_t.unsqueeze(0))

# Run the demo only when executed directly (a Jupyter cell also runs with
# __name__ == "__main__"), not when this module is imported.
if __name__ == "__main__":
    demo()