In [None]:
# Reflection in PyTorch

import matplotlib.pyplot as plt
import torch
import math

from torchlensmaker.raytracing import rot2d
from torchlensmaker.physics import reflection


def plotv(v, **kwargs):
    """Draw the 2D vector ``v`` as a segment from the origin to its tip.

    Any extra keyword arguments are forwarded unchanged to ``plt.plot``.
    """
    tip_x, tip_y = v[0], v[1]
    plt.plot([0, tip_x], [0, tip_y], **kwargs)
    

def demo_reflection(theta_i: float, normal_angle: float, num_rays: int = 10) -> None:
    """
    Reflect a fan of random rays off a flat surface, check the result, and plot it.

    Builds `num_rays` unit incident rays whose angle to the surface normal is
    drawn uniformly from [theta_i, theta_i + 20 degrees). It reflects them with
    `reflection`, then asserts that each reflected ray leaves at the mirrored
    angle (-theta_i relative to the normal). Finally it draws the surface, the
    normal, the incident rays (orange) and the reflected rays (red).

    theta_i: incident angle in radians (the demo call passes a 0-d tensor)
    normal_angle: angle in radians of the vector normal to the surface
    num_rays: number of random incident rays to generate (default 10)
    """
    B = num_rays

    # Surface normal vector: the x unit vector rotated by normal_angle
    normal = rot2d(torch.tensor([1., 0.]), normal_angle)

    # Make some random incident rays with angles in [theta_i, theta_i + 20 deg)
    all_theta_i = torch.full((B,), theta_i) + torch.deg2rad(torch.tensor(20.)) * torch.rand(B)
    V = torch.zeros(B, 2)
    for i in range(B):
        # -V is the normal rotated by the incident angle, so V points toward the surface
        V[i] = torch.as_tensor(-rot2d(normal, all_theta_i[i]), dtype=torch.float32)

    # Use the same normal for all incident rays: one float32 row per ray,
    # matching V's shape and dtype
    all_normal = torch.tile(torch.as_tensor(normal, dtype=torch.float32), (B, 1))

    # Sanity checks: unit vectors, and the angle between -V and the normal is theta_i
    assert torch.allclose(torch.linalg.norm(V, dim=1), torch.tensor(1.0))
    assert torch.allclose(torch.linalg.norm(all_normal, dim=1), torch.tensor(1.0))
    assert torch.allclose(torch.sum(-V * all_normal, dim=1), torch.cos(all_theta_i))

    reflected = reflection(V, all_normal)

    # Verify using the trigonometric version of reflection: the signed angle
    # from the normal to the reflected ray must equal -theta_i.
    # atan2(cross, dot) yields that angle directly in (-pi, pi]. Subtracting two
    # atan2 results instead can be off by 2*pi, e.g. when the normal points near +-pi.
    cross = all_normal[:, 0] * reflected[:, 1] - all_normal[:, 1] * reflected[:, 0]
    dot = torch.sum(all_normal * reflected, dim=1)
    theta_r = torch.arctan2(cross, dot)
    # Small absolute tolerance: float32 angles carry ~1e-7 rad of rounding error
    assert torch.allclose(theta_r, -all_theta_i, atol=1e-5)
    assert torch.allclose(torch.linalg.norm(reflected, dim=1), torch.tensor(1.0))

    # Rendering

    # Draw the surface (perpendicular to the normal, on both sides of the origin)
    plotv(rot2d(normal, math.pi/2), color="lightblue")
    plotv(rot2d(normal, -math.pi/2), color="lightblue")

    # Draw the normal
    plotv(normal, linestyle="--", color="grey")

    for i in range(B):
        # Draw incident light ray (drawn from the origin outward, i.e. -V)
        plotv(-V[i], color="orange")

        # Draw reflected light ray
        plotv(reflected[i], color="red")

    plt.gca().set_xlim([-1, 1])
    plt.gca().set_ylim([-1, 1])
    plt.gca().set_title("Reflection")
    plt.gca().set_aspect("equal")
    plt.show()


# Run the demo: rays arriving near 12.16 degrees onto a surface whose normal points at 105 degrees
plt.figure()
demo_reflection(
    theta_i=torch.deg2rad(torch.tensor(12.16)),
    normal_angle=torch.deg2rad(torch.tensor(105.0)),
)