In [None]:
import torchlensmaker as tlm
import torch
import math

from torchlensmaker.testing.collision_datasets import FixedRays

from typing import TypeAlias

Tensor = torch.Tensor

from torchlensmaker.core.collision_detection import init_brd


def demo_brd_initialization(surface, P, V, rays_length=100, *, B=8):
    """Render a surface, its rays, and the BRD initialization points along them.

    Calls ``init_brd(surface, P, V, B)`` to obtain initial ray parameters,
    turns them into points ``P + t * V``, and displays the surface, the rays
    and those points in a ``tlm.viewer`` scene.

    Args:
        surface: surface passed to ``init_brd`` and rendered in the scene.
        P: ray origins; must be 2-D ``(N, D)`` (unpacked from ``P.shape``).
        V: ray directions; assumed to be ``(N, D)`` like ``P`` -- TODO confirm.
        rays_length: distance each drawn ray extends on both sides of ``P``.
        B: number of initialization samples requested from ``init_brd``
            (previously hard-coded to 8; default keeps the old behavior).
    """
    _, D = P.shape

    # :: (B, N) -- shape as annotated in the original; produced by init_brd
    init_t = init_brd(surface, P, V, B)
    print(init_t.shape)

    # :: (B, N, D) -- broadcasting (1, N, D) + (B, N, 1) * (1, N, D)
    # replaces the explicit .expand() of every operand.
    points = P.unsqueeze(0) + init_t.unsqueeze(-1) * V.unsqueeze(0)
    print(points.shape)

    # Flatten to (B*N, D) so every sample is rendered as one point.
    points = points.reshape((-1, D))
    print(points.shape)

    scene = tlm.viewer.new_scene("2D" if D == 2 else "3D")
    scene["data"].append(tlm.viewer.render_surface(surface, D))

    # Draw each ray as a segment centered on its origin.
    rays_start = P - rays_length * V
    rays_end = P + rays_length * V
    scene["data"].append(
        tlm.viewer.render_rays(rays_start, rays_end, layer=0)
    )

    scene["data"].append(tlm.viewer.render_points(points))

    tlm.viewer.ipython_display(scene)


# Demo: build a sphere surface and a fixed 2D ray set, then visualize the
# BRD initialization points for those rays.
surface = tlm.Sphere(30, 16)
generator = FixedRays(
    dim=2,
    N=10,
    direction=tlm.unit2d_rot(40),
    offset=30,
    epsilon=0.05,
)

# The generator yields ray origins and directions for the given surface.
P, V = generator(surface)

demo_brd_initialization(surface, P, V)