In [None]:
# Snell's law in vector form

import torch
import matplotlib.pyplot as plt
import numpy as np

from torchlensmaker.raytracing import rot2d, rot2d_matrix

from torchlensmaker.raytracing import refraction, super_refraction, clamped_refraction

def plotv(v, **kwargs):
    """Draw the 2-D vector `v` as a segment from the origin to (v[0], v[1]).

    Extra keyword arguments are forwarded unchanged to `plt.plot`.
    """
    xs = [0, v[0]]
    ys = [0, v[1]]
    plt.plot(xs, ys, **kwargs)


def demo_super_refraction(refraction_function, theta_i, normal_angle, n1, n2):
    """
    Refract a single ray with `refraction_function`, print the incident and
    refracted angles, and plot the configuration in a new figure.

    refraction_function: called as refraction_function(ray, normal, n1, n2)
        with 1-D torch tensors; its result is converted to a NumPy array
    theta_i: incident angle in radians, measured from the surface normal
    normal_angle: angle of the surface normal in radians
    n1, n2: indices of refraction of the incident / transmitting media

    The figure title is "Normal refraction" when n1*sin(theta_i) and
    n2*sin(theta_r) agree (np.allclose), and "Modified refraction!" otherwise.
    """
    plt.figure()

    if n1 > n2:
        critical_angle = np.arcsin(n2 / n1)
        print("Critical angle", np.degrees(critical_angle))
    else:
        print("No critical angle")

    normal = rot2d(np.array([1, 0]), normal_angle)
    # Incident direction: the normal rotated by theta_i, flipped to point into the surface.
    R = -rot2d(normal, theta_i)

    assert np.allclose(np.linalg.norm(R), 1.0)
    assert np.allclose(np.linalg.norm(normal), 1.0)
    assert np.allclose(-R.dot(normal), np.cos(theta_i))

    # Work with a plain NumPy array from here on, so the angle math and plotting
    # don't rely on implicit torch/NumPy interop (and still work if the result tracks grads).
    R_refracted = (
        torch.as_tensor(refraction_function(torch.tensor(R), torch.tensor(normal), n1, n2))
        .detach()
        .cpu()
        .numpy()
    )

    # Signed angle from -normal to the refracted ray. The raw arctan2 difference
    # can be off by a full turn when the branch cut at +/-pi falls between the two
    # directions, so wrap it into [-pi, pi) before printing.
    raw_theta_r = np.arctan2(R_refracted[1], R_refracted[0]) - np.arctan2(-normal[1], -normal[0])
    theta_r = (raw_theta_r + np.pi) % (2 * np.pi) - np.pi

    print("theta_i", np.degrees(theta_i))
    print("theta_r", np.degrees(theta_r))

    # Snell's law check; sin is 2*pi-periodic, so the wrapping above does not affect it.
    correct_refraction = np.allclose(n1 * np.sin(theta_i), n2 * np.sin(theta_r))
    assert np.allclose(np.linalg.norm(R_refracted), 1.0)

    # Surface (light blue), normal (dashed grey), incident ray (orange), refracted ray (red)
    plotv(rot2d(normal, np.pi / 2), color="lightblue")
    plotv(rot2d(normal, -np.pi / 2), color="lightblue")
    plotv(normal, linestyle="--", color="grey")
    plotv(-normal, linestyle="--", color="grey")
    plotv(-R, color="orange")
    plotv(R_refracted, color="red")

    ax = plt.gca()
    ax.set_xlim([-1, 1])
    ax.set_ylim([-1, 1])
    ax.set_title("Normal refraction" if correct_refraction else "Modified refraction!")
    ax.set_aspect("equal")
    plt.show()




def plot(theta_i, normal_angle, n1, n2):
    """
    Run demo_super_refraction on one configuration for each refraction variant,
    in the order super_refraction, clamped_refraction, refraction.

    theta_i, normal_angle: angles in degrees (converted to radians here)
    n1, n2: indices of refraction
    """
    theta_i_rad = np.radians(theta_i)
    normal_angle_rad = np.radians(normal_angle)
    for refraction_fn in (super_refraction, clamped_refraction, refraction):
        demo_super_refraction(refraction_fn, theta_i_rad, normal_angle_rad, n1, n2)

# Incident angle 42.16 deg, surface normal at 105 deg, going from n1 = 1.0 into n2 = 1.5.
plot(theta_i=42.16, normal_angle=105, n1=1.0, n2=1.5)

In [None]:
# Batched refraction in PyTorch

import numpy as np
import matplotlib.pyplot as plt
import torch

from torchlensmaker.raytracing import rot2d


def plotv(v, **kwargs):
    """Plot vector `v` as a line from (0, 0) to its tip; kwargs are passed to plt.plot."""
    tip_x, tip_y = v[0], v[1]
    plt.plot([0, tip_x], [0, tip_y], **kwargs)


def refraction_batched(ray, normal, n1, n2):
    """
    Batched vector-based Snell's law

    ray: unit vectors of the incident rays, shape (..., D), e.g. (B, 2);
        a single unbatched vector of shape (D,) is also accepted
    normal: unit vectors normal to the surface, broadcastable against `ray`
        (e.g. (B, 2) or (2,)). The refracted ray is built along -normal, so the
        normal is expected to face back toward the incident side
        (ray . normal <= 0), as constructed in the demos.
    n1, n2: indices of refraction (floats), incident side then transmitted side

    Returns: unit vectors of the refracted rays, with the broadcast shape of
    `ray` and `normal`.

    Rays past the critical angle (total internal reflection) have no real
    refracted direction; their outputs are NaN (sqrt of a negative number).
    """
    # Projection of each ray on its normal; reduce over the last (vector) axis so any
    # leading batch shape works.
    cos_term = torch.sum(ray * normal, dim=-1, keepdim=True)

    # Tangential component, scaled by the index ratio
    R_perp = n1 / n2 * (ray - cos_term * normal)

    # Normal component chosen so that |R| == 1, pointing along -normal
    R_para = -torch.sqrt(1 - torch.sum(R_perp * R_perp, dim=-1, keepdim=True)) * normal

    R = R_perp + R_para

    # Renormalize to absorb rounding when inputs are only approximately unit length
    return R / torch.norm(R, dim=-1, keepdim=True)

def demo_batched_refraction(refraction_function, theta_i: float, normal_angle: float, n1, n2,
                            num_rays: int = 10, spread: float = np.radians(20), seed=None):
    """
    Plot a batch of random incident rays and their refractions through a flat
    surface, on the current matplotlib figure.

    refraction_function: batched refraction, called as
        refraction_function(rays, normals, n1, n2) with float32 tensors of
        shape (num_rays, 2)
    theta_i: lower bound of the incident angles, in radians; each ray's angle is
        drawn uniformly in [theta_i, theta_i + spread)
    normal_angle: angle of the vector normal to the surface, in radians
    n1, n2: indices of refraction
    num_rays: number of random incident rays (default 10)
    spread: width of the random incident-angle range, in radians (default 20 degrees)
    seed: optional integer seed for the random angles; None uses torch's global RNG
    """

    # Surface normal vector
    normal = rot2d(np.array([1, 0], dtype=np.float32), normal_angle)

    # Random incident angles in [theta_i, theta_i + spread)
    generator = None if seed is None else torch.Generator().manual_seed(seed)
    all_theta_i = torch.full((num_rays,), theta_i) + spread * torch.rand(num_rays, generator=generator)

    # Incident directions: the normal rotated by each angle, flipped to point into the surface
    V = torch.zeros(num_rays, 2)
    for i in range(num_rays):
        V[i] = torch.as_tensor(-rot2d(normal, all_theta_i[i]), dtype=torch.float32)

    # Use the same normal for all incident rays
    all_normal = torch.tile(torch.as_tensor(normal, dtype=torch.float32), (num_rays, 1))

    # Sanity checks
    assert np.allclose(np.linalg.norm(V, axis=1), 1.0)
    assert np.allclose(np.linalg.norm(all_normal, axis=1), 1.0)
    assert torch.allclose(torch.sum(-V * all_normal, dim=1), torch.cos(all_theta_i))

    refracted = refraction_function(V, all_normal, n1, n2)

    # Rendering

    # Draw the surface
    plotv(rot2d(normal, np.pi / 2), color="lightblue")
    plotv(rot2d(normal, -np.pi / 2), color="lightblue")

    # Draw the normal
    plotv(normal, linestyle="--", color="grey")
    plotv(-normal, linestyle="--", color="grey")

    for i in range(num_rays):
        # Draw incident light ray (pointing back toward its source)
        plotv(-V[i], color="orange")

        # Draw refracted light ray
        plotv(refracted[i], color="red")

    ax = plt.gca()
    ax.set_xlim([-1, 1])
    ax.set_ylim([-1, 1])
    ax.set_title("Refraction")
    ax.set_aspect("equal")
    plt.show()


plt.figure()
# From n1 = 1.5 into n2 = 1.0. The incident angles start at 12.16 deg, below the
# critical angle arcsin(1.0 / 1.5) ~ 41.8 deg.
demo_batched_refraction(
    refraction_batched,
    theta_i=np.radians(12.16),
    normal_angle=np.radians(105),
    n1=1.5,
    n2=1.0,
)