In [None]:
# Batched refraction (Snell's law) in pytorch

import math
import torch

from torchlensmaker.raytracing import *
from torchlensmaker.physics import refraction

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.lines as mlines

from IPython.display import display, HTML

def plotv(v, **kwargs):
    """Draw the 2D vector ``v`` as a segment from the origin to ``(v[0], v[1])``.

    Any extra keyword arguments are forwarded unchanged to ``plt.plot``.
    """
    tip_x, tip_y = v[0], v[1]
    xs = [0, tip_x]
    ys = [0, tip_y]
    plt.plot(xs, ys, **kwargs)



def demo_batched_refraction(crit_option, theta_i: float, normal_angle: float, n1, n2):
    """
    Demo / test of the batched ``refraction`` function.

    Builds a fan of incident rays around ``theta_i``, refracts them all in a
    single batched call, prints diagnostics, and displays a plot showing the
    surface, its normal, the incident rays and the refracted rays.

    Args:
        crit_option: value forwarded as ``critical_angle=`` to ``refraction``.
        theta_i: central incident angle in radians, measured from the normal.
        normal_angle: angle in radians by which the unit x axis is rotated
            to obtain the surface normal.
        n1, n2: refractive indices forwarded to ``refraction``. A critical
            angle ``asin(n2 / n1)`` is reported and drawn only when n1 > n2.
    """

    # Critical angle (only defined when going to a lower index, n1 > n2)
    if n1 > n2:
        critical_angle = math.asin(n2 / n1)
        print("Critical angle: {:.2f} deg".format(math.degrees(critical_angle)))
    else:
        critical_angle = None
        print("No critical angle")

    # Surface normal vector
    normal = rot2d(torch.tensor([1., 0.]), normal_angle)
    B = 10

    # Make B incident rays evenly spread over 35 deg (+- 17.5 deg) around theta_i
    spread = math.radians(35)
    noise = torch.linspace(-spread / 2, spread / 2, B)
    # float() accepts Python floats as well as numpy / 0-dim tensor scalars
    all_theta_i = torch.full((B,), float(theta_i)) + noise
    V = torch.zeros(B, 2)
    for i in range(B):
        # Incident direction points toward the surface, hence the negation
        V[i] = torch.as_tensor(-rot2d(normal, all_theta_i[i]), dtype=torch.float32)

    # Use the same normal for all incident rays
    all_normal = torch.tile(torch.as_tensor(normal, dtype=torch.float32), (B, 1))

    # Sanity checks: unit vectors, and angle between -V and normal is theta_i
    unit = torch.ones(B)
    assert torch.allclose(torch.linalg.norm(V, dim=1), unit)
    assert torch.allclose(torch.linalg.norm(all_normal, dim=1), unit)
    assert torch.allclose(torch.sum(-V * all_normal, dim=1), torch.cos(all_theta_i))

    # Call refraction function
    refracted = refraction(V, all_normal, n1, n2, critical_angle=crit_option)

    # Check for nans: a ray is flagged as soon as ANY of its components is
    # non-finite (negate element-wise first, then reduce over the row).
    nonfinite_rows = (~torch.isfinite(refracted)).any(dim=1)
    number_of_nonfinite = int(nonfinite_rows.sum())
    if number_of_nonfinite > 0:
        print(f"Warning! {number_of_nonfinite} refracted rays contain nan!")

    if V.shape[0] != refracted.shape[0]:
        print(f"Warning! {V.shape[0]} incident rays but only {refracted.shape[0]} refracted rays.")

    # Rendering (plotv draws on the current axes, i.e. the figure created here)
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))

    # Draw the surface (perpendicular to the normal, both directions)
    plotv(rot2d(normal, math.pi / 2), color="lightblue")
    plotv(rot2d(normal, -math.pi / 2), color="lightblue")

    # Draw the normal
    plotv(normal, linestyle="--", color="grey")
    plotv(-normal, linestyle="--", color="grey")

    # Draw critical angle line
    if critical_angle is not None:
        plotv(1.5 * rot2d(normal, critical_angle), linestyle="--", color="lightgrey")
        plotv(1.5 * rot2d(normal, -critical_angle), linestyle="--", color="lightgrey")

    # Draw incident and refracted light rays
    for i in range(V.shape[0]):
        plotv(-V[i], color="orange")

    for i in range(refracted.shape[0]):
        plotv(refracted[i], color="red")

    ax.set_xlim([-1, 1])
    ax.set_ylim([-1, 1])
    title = f"critical_angle='{crit_option}' | n = ({n1}, {n2})"
    ax.set_title(title)
    ax.set_aspect("equal")

    orange_line = mlines.Line2D([], [], color='orange', label='incident')
    red_line = mlines.Line2D([], [], color='red', label='refracted')
    ax.legend(handles=[orange_line, red_line])

    display(fig)
    plt.close(fig)


# Values passed in turn as `critical_angle=` to the refraction function;
# one demo figure is produced per option.
crit_options = [
    'nan',
    'clamp',
    'drop',
]

# Disable interactive mode: the demo displays each figure explicitly.
plt.ioff()

# Plain Python floats (via the already-imported math module) match the
# `float` annotations of demo_batched_refraction.
theta_i = math.radians(-39.16)
normal_angle = math.radians(105)
n1, n2 = 1.5, 1.0

for c in crit_options:
    demo_batched_refraction(c, theta_i, normal_angle, n1, n2)
    display(HTML("<hr>"))