In [None]:
import torch
import torch.nn as nn
import torchlensmaker as tlm


def display_hit_miss_3d(source, surface, dist, *, pupil_samples=500):
    """
    Raytrace a light source against a surface and show hit / miss rays in tlmviewer.

    The source (a sequential system) is sampled in 3D and its rays are
    intersected with ``surface``. The displayed scene contains the surface plus
    two ray groups:

    * rays that hit, drawn in light green up to the collision point;
    * rays that miss, drawn in orange up to length ``dist``.

    Args:
        source: Sequential optical system used as the light source. Its 3D
            sampling is modified in place via ``set_sampling3d``.
        surface: Surface to collide the rays with. It is called with the rays'
            positions, directions and the ``fk`` value produced by ``source``
            (presumably the frame transform at the end of the sequence, which
            places the surface after the source -- TODO confirm).
        dist: Distance along each ray's direction at which missed rays are
            drawn.
        pupil_samples: Number of pupil samples passed to
            ``source.set_sampling3d``. Defaults to 500, the previously
            hard-coded value.

    Returns:
        None. Prints diagnostic messages and displays the scene.
    """
    # Sample the light source (mutates the source's sampling configuration)
    source.set_sampling3d(pupil=pupil_samples)
    data = source(tlm.default_input(dim=3))

    # Raytrace the surface; the normals and the last transform are not needed here
    t, _normals, valid, stf, _ntf = surface(data.P, data.V, data.fk)

    if not t.isfinite().all():
        print("Warning: surface collision returned some non-finite t values")
    if valid.all():
        print("All rays hit the surface")
    elif (~valid).all():
        print("All rays miss the surface")

    hit_mask = valid
    miss_mask = ~valid

    # Hit rays end at their collision point. Mask before computing P + t*V so
    # that non-finite t values of missed rays never enter the arithmetic.
    hit_start = data.P[hit_mask]
    hit_end = hit_start + t[hit_mask].unsqueeze(-1) * data.V[hit_mask]

    # Missed rays are drawn with length `dist` along their direction
    miss_start = data.P[miss_mask]
    miss_end = (data.P + dist * data.V)[miss_mask]

    # Render both ray groups. The third positional argument (0) is passed
    # through unchanged -- NOTE(review): confirm its meaning in tlm.render_rays.
    hit = tlm.render_rays(hit_start, hit_end, 0, default_color="lightgreen")
    miss = tlm.render_rays(miss_start, miss_end, 0, default_color="orange")

    # Render the surface, placed with the transform returned by the collision
    sdict = surface.render()
    sdict["matrix"] = stf.direct.tolist()

    # Assemble the scene manually
    scene = tlm.new_scene("3D")
    scene["data"] = [sdict, hit, miss]

    # Display
    scene["controls"] = {"show_optical_axis": True, "show_other_axes": True}
    tlm.display_scene(scene)

In [None]:
# Demo: a point source (built between two gaps) aimed at a disk of size 5;
# missed rays are drawn 10 units long.
point_source_system = tlm.Sequential(
    tlm.Gap(-1),
    tlm.PointSource(80),
    tlm.Gap(5),
)
disk_surface = tlm.Disk(5)
display_hit_miss_3d(point_source_system, disk_surface, 10)