#### Setup

In [1]:
import os
import sys
import torch as t
from torch import Tensor
import einops
from ipywidgets import interact
import plotly.express as px
from ipywidgets import interact
from pathlib import Path
from IPython.display import display
from jaxtyping import Float, Int, Bool, Shaped, jaxtyped
import typeguard

# Make sure exercises are in the path.
# The exercises directory is derived from the current working directory: everything before
# the first occurrence of the chapter name is kept and "<chapter>/exercises" is appended,
# so the notebook has to be run from somewhere whose path contains the chapter name.
chapter = r"chapter0_fundamentals"
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
# Folder for this section; used later to locate pikachu.pt.
section_dir = exercises_dir / "part1_ray_tracing"
# Needed so that the plotly_utils and part1_ray_tracing imports below can be resolved.
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

from plotly_utils import imshow
from part1_ray_tracing.utils import render_lines_with_plotly, setup_widget_fig_ray, setup_widget_fig_triangle
import part1_ray_tracing.tests as tests

MAIN = __name__ == "__main__"

## 1D Image Rendering

In our initial setup, the camera will be a single point at the origin, and the screen will be the plane at x=1.

Objects in the world consist of triangles, where triangles are represented as 3 points in 3D space (so 9 floating point values per triangle). You can build any shape out of sufficiently many triangles and your Pikachu will be made from 412 triangles.

The camera will emit one or more rays, where a ray is represented by an origin point and a direction point. Conceptually, the ray is emitted from the origin and continues in the given direction until it intersects an object.

We have no concept of lighting or color yet, so for now we'll say that a pixel on our screen should show a bright color if a ray from the origin through it intersects an object, otherwise our screen should be dark.

### Exercise - implement make_rays_1d

Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪

You should spend up to 10-15 minutes on this exercise.

In [3]:
def make_rays_1d(num_pixels: int, y_limit: float) -> t.Tensor:
    '''
    Build one ray per pixel, all starting at the origin and passing through the screen at x=1.

    num_pixels: The number of pixels in the y dimension. Since there is one ray per pixel, this is also the number of rays.
    y_limit: At x=1, the rays should extend from -y_limit to +y_limit, inclusive of both endpoints.

    Returns: shape (num_pixels, num_points=2, num_dim=3) where the num_points dimension contains (origin, direction) and the num_dim dimension contains xyz.

    Example of make_rays_1d(9, 1.0): [
        [[0, 0, 0], [1, -1.0, 0]],
        [[0, 0, 0], [1, -0.75, 0]],
        [[0, 0, 0], [1, -0.5, 0]],
        ...
        [[0, 0, 0], [1, 0.75, 0]],
        [[0, 0, 0], [1, 1, 0]],
    ]
    '''
    # Origins (index 0 of the points axis) and all z components stay at zero.
    rays = t.zeros(num_pixels, 2, 3)
    # The direction point lies on the screen plane x=1, so its y value is exactly the
    # pixel's y coordinate on the screen.
    rays[:, 1, 0] = 1
    # Evenly spaced y coordinates covering [-y_limit, y_limit], endpoints included.
    rays[:, 1, 1] = t.linspace(-y_limit, y_limit, num_pixels)
    return rays


# Generate 9 rays with y_limit=10.0 and draw them; the returned figure is kept in `fig`.
rays1d = make_rays_1d(9, 10.0)

fig = render_lines_with_plotly(rays1d)

### Ray-Object Intersection

In [47]:
fig = setup_widget_fig_ray()
display(fig)

@interact
def response(seed=(0, 10, 1), v=(-2.0, 2.0, 0.01)):
    '''
    Slider callback: redraw the line through two random 2D points and the point at parameter v.

    seed: seeds torch's RNG, so each seed value yields a fixed pair of endpoints.
    v: position along the line, with v=0 at the first endpoint and v=1 at the second.
    '''
    t.manual_seed(seed)
    start, end = t.rand(2, 2)

    def point_at(param):
        # Point on the infinite line through `start` and `end`.
        return start + param * (end - start)

    line_lo = point_at(-2)
    line_hi = point_at(2)
    marker = point_at(v)

    with fig.batch_update():
        # Trace 0: the line drawn between parameters -2 and 2.
        fig.data[0].update({"x": (line_lo[0], line_hi[0]), "y": (line_lo[1], line_hi[1])})
        # Trace 1: the two endpoints themselves.
        fig.data[1].update({"x": [start[0], end[0]], "y": [start[1], end[1]]})
        # Trace 2: the point selected by the v slider.
        fig.data[2].update({"x": [marker[0]], "y": [marker[1]]})

FigureWidget({
    'data': [{'type': 'scatter', 'uid': 'f6e0391d-4a5b-4dc7-9c4d-39d03428bb48', 'x': [], 'y': []},
             {'marker': {'size': 12},
              'mode': 'markers',
              'type': 'scatter',
              'uid': '7dfa9a1f-ed45-44f3-a213-a135873c6955',
              'x': [],
              'y': []},
             {'marker': {'size': 12, 'symbol': 'x'},
              'mode': 'markers',
              'type': 'scatter',
              'uid': '1df4dd61-17b6-466a-bf72-d3768d72c4a9',
              'x': [],
              'y': []}],
    'layout': {'height': 500,
               'showlegend': False,
               'template': '...',
               'width': 600,
               'xaxis': {'range': [-1.5, 2.5]},
               'yaxis': {'range': [-1.5, 2.5]}}
})

interactive(children=(IntSlider(value=5, description='seed', max=10), FloatSlider(value=0.0, description='v', …

### Exercise - which segments intersect with the rays?

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵⚪⚪⚪⚪

You should spend up to 10-15 minutes on this exercise.

For each of the following segments, which camera rays from earlier intersect? You can do this by inspection or using render_lines_with_plotly.

In [5]:
# Three test segments, each written as a pair of (x, y, z) endpoints.
# All endpoints have z=0, and each segment keeps x constant (x=1, x=0.5 and x=2 respectively).
segments = t.tensor([
    [[1.0, -12.0, 0.0], [1, -6.0, 0.0]], 
    [[0.5, 0.1, 0.0], [0.5, 1.15, 0.0]], 
    [[2, 12.0, 0.0], [2, 21.0, 0.0]]
])

# Rebuild the 9 rays used earlier and plot them together with the segments.
rays1d = make_rays_1d(9, 10.0)
render_lines_with_plotly(rays1d, segments)

### Exercise - implement intersect_ray_1d

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪

You should spend up to 20-25 minutes on this exercise.

It involves some of today's core concepts: tensor manipulation, linear operations, etc.

Using torch.linalg.solve and torch.stack, implement the intersect_ray_1d function to solve the above matrix equation.

In [3]:
def intersect_ray_1d(ray: t.Tensor, segment: t.Tensor) -> bool:
    '''
    Decide whether a single ray hits a single line segment, ignoring the z coordinate.

    ray: shape (n_points=2, n_dim=3)  # O, D points
    segment: shape (n_points=2, n_dim=3)  # L_1, L_2 points

    Return True if the ray intersects the segment. The result is a plain Python bool,
    not a 0-d tensor.
    '''
    # Only x and y take part in the intersection test.
    O, D = ray[:, :2]
    L1, L2 = segment[:, :2]

    # O + u*D = L1 + v*(L2 - L1)   <=>   [D | L1 - L2] @ (u, v) = L1 - O
    A = t.stack((D, L1 - L2), dim=1)
    B = L1 - O
    try:
        u, v = t.linalg.solve(A, B)
    except RuntimeError:
        # The solver rejected the system as singular (ray parallel to the segment):
        # treated as no intersection.
        return False

    # u >= 0: the hit is on the forward side of the ray; 0 <= v <= 1: it lies between L1 and L2.
    # If the solver returned nan instead of raising, every comparison is False.
    return bool((u >= 0) & (v >= 0) & (v <= 1))

# Run the course-provided checks against the implementation above (general and special case).
tests.test_intersect_ray_1d(intersect_ray_1d)
tests.test_intersect_ray_1d_special_case(intersect_ray_1d)

All tests in `test_intersect_ray_1d` passed!
All tests in `test_intersect_ray_1d_special_case` passed!


```
from jaxtyping import jaxtyped
# Use your favourite typechecker: usually one of the two lines below.
from typeguard import typechecked as typechecker
from beartype import beartype as typechecker

@jaxtyped(typechecker=typechecker)
def foo(...):
```
and the old double-decorator syntax
```
@jaxtyped
@typechecker
def foo(...):
```
should no longer be used. (It will continue to work as it did before, but the new approach will produce more readable error messages.)
In particular note that `typechecker` must be passed via keyword argument; the following is not valid:
```
@jaxtyped(typechecker)
def foo(...):
```

  @jaxtyped


### Exercise - implement intersect_rays_1d

Difficulty: 🔴🔴🔴🔴⚪
Importance: 🔵🔵🔵🔵⚪

You should spend up to 25-30 minutes on this exercise.

It will be one of the most difficult today; certainly it has the biggest step in difficulty.

In [7]:
def intersect_rays_1d(rays: Float[Tensor, "nrays 2 3"], segments: Float[Tensor, "nsegments 2 3"]) -> Bool[Tensor, "nrays"]:
    '''
    For each ray, return True if it intersects any segment.

    rays: (origin, direction) point pairs; only the x and y components are used.
    segments: (L_1, L_2) endpoint pairs; only the x and y components are used.

    Ray/segment pairs whose 2x2 system is singular (ray parallel to the segment) count as
    non-intersecting instead of making the whole batched solve fail.
    '''
    n_rays = rays.shape[0]
    n_segments = segments.shape[0]

    # Drop z and pair every ray with every segment; the leading axis separates the two points.
    rays_b = einops.repeat(rays[..., :2], "nrays pts dims -> pts nrays nsegments dims", nsegments=n_segments)
    segments_b = einops.repeat(segments[..., :2], "nsegments pts dims -> pts nrays nsegments dims", nrays=n_rays)
    O, D = rays_b
    L1, L2 = segments_b

    # O + u*D = L1 + v*(L2 - L1)   <=>   [D | L1 - L2] @ (u, v) = L1 - O
    mat = t.stack((D, L1 - L2), dim=-1)  # (nrays, nsegments, 2, 2)
    vec = L1 - O                         # (nrays, nsegments, 2)

    # Swap singular systems for the identity so the batched solve succeeds,
    # then mask those pairs out of the final answer.
    is_singular = t.linalg.det(mat).abs() < 1e-8
    mat[is_singular] = t.eye(2, dtype=mat.dtype, device=mat.device)

    u, v = t.linalg.solve(mat, vec).unbind(dim=-1)

    # u >= 0: forward along the ray; 0 <= v <= 1: between the segment's endpoints.
    hits = (u >= 0) & (v >= 0) & (v <= 1) & ~is_singular
    return hits.any(dim=1)
        



# Run the course-provided checks for the batched version (general and special case).
tests.test_intersect_rays_1d(intersect_rays_1d)
tests.test_intersect_rays_1d_special_case(intersect_rays_1d)

All tests in `test_intersect_rays_1d` passed!
All tests in `test_intersect_rays_1d_special_case` passed!


## 2D Rays

### Exercise - implement make_rays_2d

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵⚪⚪⚪

You should spend up to 10-15 minutes on this exercise.

In [4]:
def make_rays_2d(num_pixels_y: int, num_pixels_z: int, y_limit: float, z_limit: float) -> Float[t.Tensor, "nrays 2 3"]:
    '''
    Build one ray per (y, z) pixel, all starting at the origin and passing through the screen at x=1.

    num_pixels_y: The number of pixels in the y dimension
    num_pixels_z: The number of pixels in the z dimension

    y_limit: At x=1, the rays should extend from -y_limit to +y_limit, inclusive of both.
    z_limit: At x=1, the rays should extend from -z_limit to +z_limit, inclusive of both.

    Returns: shape (num_rays=num_pixels_y * num_pixels_z, num_points=2, num_dims=3).
    Rays are ordered with y as the slow axis and z as the fast axis, so the result can be
    reshaped to (num_pixels_y, num_pixels_z, 2, 3).
    '''
    n_rays = num_pixels_y * num_pixels_z
    y_coords = t.linspace(-y_limit, y_limit, num_pixels_y)
    z_coords = t.linspace(-z_limit, z_limit, num_pixels_z)

    # Origins (index 0 of the points axis) stay at zero; only the direction points are filled in.
    rays = t.zeros((n_rays, 2, 3), dtype=t.float32)
    rays[:, 1, 0] = 1
    # "(y z)" flattening: each y value is repeated for every z, and the z values are tiled per y.
    rays[:, 1, 1] = einops.repeat(y_coords, "y -> (y z)", z=num_pixels_z)
    rays[:, 1, 2] = einops.repeat(z_coords, "z -> (y z)", y=num_pixels_y)
    return rays


# A 10x10 grid of rays with y_limit = z_limit = 0.3, drawn for visual inspection.
rays_2d = make_rays_2d(10, 10, 0.3, 0.3)
render_lines_with_plotly(rays_2d)

#### Triangles

In [9]:
# Vertices of a demo triangle (all with z=0), unpacked both per vertex (A, B, C) and per
# coordinate (x, y, z).
# NOTE(review): A, B and C are module-level names that a later cell reassigns; the slider
# callback below reads them at call time, so it will pick up whatever they currently hold.
one_triangle = t.tensor([[0, 0, 0], [3, 0.5, 0], [2, 3, 0]])
A, B, C = one_triangle
x, y, z = one_triangle.T

# NOTE(review): this rebinds `fig` (and `response` below), which the earlier ray widget cell
# also defines; that earlier callback looks `fig` up globally when its sliders move.
fig = setup_widget_fig_triangle(x, y, z)

@interact(u=(-0.5, 1.5, 0.01), v=(-0.5, 1.5, 0.01))
def response(u=0.0, v=0.0):
    '''Slider callback: move the marker (trace 2) to the point A + u*(B - A) + v*(C - A).'''
    P = A + u * (B - A) + v * (C - A)
    # Only the x and y components of P are plotted.
    fig.data[2].update({"x": [P[0]], "y": [P[1]]})

display(fig)

interactive(children=(FloatSlider(value=0.0, description='u', max=1.5, min=-0.5, step=0.01), FloatSlider(value…

FigureWidget({
    'data': [{'marker': {'size': 12},
              'mode': 'markers+text',
              'text': [A, B, C],
              'textfont': {'size': 18},
              'textposition': 'middle left',
              'type': 'scatter',
              'uid': '8cc4d7f7-56b4-40e1-8039-94e8f7cb9e6b',
              'x': array([0., 3., 2.], dtype=float32),
              'y': array([0. , 0.5, 3. ], dtype=float32)},
             {'mode': 'lines',
              'type': 'scatter',
              'uid': 'd238fcca-b201-487a-a6a8-59108f1e3003',
              'x': [0.0, 3.0, 2.0, 0.0],
              'y': [0.0, 0.5, 3.0, 0.0]},
             {'marker': {'size': 12, 'symbol': 'x'},
              'mode': 'markers',
              'type': 'scatter',
              'uid': '72ea7a3d-f221-4011-9002-2d96f5e3fbae',
              'x': [0.0],
              'y': [0.0]}],
    'layout': {'height': 600,
               'showlegend': False,
               'template': '...',
               'title': {'text': 'Barycen

### Exercises - implement triangle_ray_intersects

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵⚪⚪

You should spend up to 15-20 minutes on this exercise.

In [5]:
Point = Float[Tensor, "points=3"]

def triangle_ray_intersects(A: Point, B: Point, C: Point, O: Point, D: Point) -> bool:
    '''
    A: shape (3,), one vertex of the triangle
    B: shape (3,), second vertex of the triangle
    C: shape (3,), third vertex of the triangle
    O: shape (3,), origin point
    D: shape (3,), direction point

    Return True if the ray and the triangle intersect. The result is a plain Python bool,
    not a 0-d tensor.
    '''
    # O + s*D = A + u*(B - A) + v*(C - A)   <=>   [-D | B - A | C - A] @ (s, u, v) = O - A
    mat = t.stack((-D, B - A, C - A), dim=1)
    vec = O - A
    try:
        s, u, v = t.linalg.solve(mat, vec)
    except RuntimeError:
        # The solver rejected the system as singular (e.g. ray parallel to the triangle's
        # plane): treated as no intersection.
        return False

    # (u, v) are the coordinates of the hit point relative to A; it is inside the triangle
    # iff both are non-negative and they sum to at most 1.
    return bool((u >= 0) & (v >= 0) & (u + v <= 1))

    
# Run the course-provided checks against the implementation above.
tests.test_triangle_ray_intersects(triangle_ray_intersects)

All tests in `test_triangle_ray_intersects` passed!


## Single-Triangle Rendering

### Exercise - implement raytrace_triangle

Difficulty: 🔴🔴🔴🔴⚪
Importance: 🔵🔵🔵🔵⚪

You should spend up to 15-20 minutes on this exercise.

This is about as hard as `intersect_rays_1d`, although hopefully you should find it more familiar.

In [7]:
def raytrace_triangle(
    rays: Float[Tensor, "nrays rayPoints=2 dims=3"],
    triangle: Float[Tensor, "trianglePoints=3 dims=3"]
) -> Bool[Tensor, "nrays"]:
    '''
    For each ray, return True if the triangle intersects that ray.

    Rays whose 3x3 system is singular (ray parallel to the triangle's plane) are reported
    as non-intersecting instead of making the whole batched solve fail.
    '''
    n_rays = rays.shape[0]

    O, D = rays.unbind(dim=1)  # each (nrays, 3)
    # Copy the triangle's vertices once per ray so they line up with O and D.
    A, B, C = einops.repeat(triangle, "pts dims -> pts nrays dims", nrays=n_rays)

    # O + s*D = A + u*(B - A) + v*(C - A)   <=>   [-D | B - A | C - A] @ (s, u, v) = O - A
    mat = t.stack([-D, B - A, C - A], dim=-1)  # (nrays, 3, 3)
    vec = O - A                                # (nrays, 3)

    # Swap singular systems for the identity so the batched solve succeeds,
    # then mask those rays out of the final answer.
    is_singular = t.linalg.det(mat).abs() < 1e-8
    mat[is_singular] = t.eye(3, dtype=mat.dtype, device=mat.device)

    s, u, v = t.linalg.solve(mat, vec).unbind(dim=-1)

    # Inside the triangle iff u and v are non-negative and sum to at most 1.
    return (u >= 0) & (v >= 0) & (u + v <= 1) & ~is_singular

# Test triangle: all three vertices have x=1.
# NOTE(review): this reassigns the module-level A, B, C that the triangle widget cell also uses.
A = t.tensor([1, 0.0, -0.5])
B = t.tensor([1, -0.5, 0.0])
C = t.tensor([1, 0.5, 0.5])
num_pixels_y = num_pixels_z = 30
y_limit = z_limit = 0.5

# Plot triangle & rays
test_triangle = t.stack([A, B, C], dim=0)
rays2d = make_rays_2d(num_pixels_y, num_pixels_z, y_limit, z_limit)
# Stacking A, B, C twice and regrouping in pairs yields the three edges (A, B), (C, A), (B, C).
triangle_lines = t.stack([A, B, C, A, B, C], dim=0).reshape(-1, 2, 3)
render_lines_with_plotly(rays2d, triangle_lines)

# Calculate and display intersections
intersects = raytrace_triangle(rays2d, test_triangle)
# One value per ray, laid out as a (y, z) grid and converted to 0/1 for plotting.
img = intersects.reshape(num_pixels_y, num_pixels_z).int()
imshow(img, origin="lower", width=600, title="Triangle (as intersected by rays)")

### Debugging Tools

In [None]:
def raytrace_triangle_with_bug(
    rays: Float[Tensor, "nrays rayPoints=2 dims=3"],
    triangle: Float[Tensor, "trianglePoints=3 dims=3"]
) -> Bool[Tensor, "nrays"]:
    '''
    For each ray, return True if the triangle intersects that ray.

    The name is kept from the debugging exercise; the bugs themselves are fixed here:
      * `rays.size[0]` subscripted a method; it is now the call `rays.size(0)`.
      * `rays.unbind(-1)` split along the xyz axis; origin and direction live on dim 1.
      * `t.stack` lacked `dim=-1`, so the result was not a batch of 3x3 matrices whose
        columns are -D, B - A and C - A.
    '''
    NR = rays.size(0)

    # Copy the triangle's vertices once per ray so they line up with O and D.
    A, B, C = einops.repeat(triangle, "pts dims -> pts NR dims", NR=NR)

    O, D = rays.unbind(dim=1)

    # O + s*D = A + u*(B - A) + v*(C - A)   <=>   [-D | B - A | C - A] @ (s, u, v) = O - A
    mat = t.stack([- D, B - A, C - A], dim=-1)

    # Swap singular systems for the identity so the batched solve succeeds;
    # those rays are masked out of the result below.
    dets = t.linalg.det(mat)
    is_singular = dets.abs() < 1e-8
    mat[is_singular] = t.eye(3)

    vec = O - A

    sol = t.linalg.solve(mat, vec)
    s, u, v = sol.unbind(dim=-1)

    return ((u >= 0) & (v >= 0) & (u + v <= 1) & ~is_singular)


# Render the same rays and triangle as in the raytrace_triangle cell, this time through the
# debugging-exercise implementation. Relies on rays2d, test_triangle and the pixel counts
# defined in that earlier cell.
intersects = raytrace_triangle_with_bug(rays2d, test_triangle)
img = intersects.reshape(num_pixels_y, num_pixels_z).int()
imshow(img, origin="lower", width=600, title="Triangle (as intersected by rays)")

### Mesh Rendering

In [9]:
# Load the mesh from the section directory.
# presumably a float tensor of shape (ntriangles, 3, 3), as raytrace_mesh expects — TODO confirm
# NOTE(review): t.load unpickles the file's contents, so only load files from a trusted source.
with open(section_dir / "pikachu.pt", "rb") as f:
    triangles = t.load(f)

### Exercise - implement raytrace_mesh

Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵🔵🔵🔵⚪

You should spend up to 20-25 minutes on this exercise.

This is the main function we've been building towards, and marks the end of the core exercises. It should involve a lot of repurposed code from the last exercise.

In [13]:
def raytrace_mesh(
    rays: Float[Tensor, "nrays rayPoints=2 dims=3"],
    triangles: Float[Tensor, "ntriangles trianglePoints=3 dims=3"]
) -> Float[Tensor, "nrays"]:
    '''
    For each ray, return the distance to the closest intersecting triangle, or infinity.

    The value returned for a ray is the smallest ray parameter s (from O + s*D) over all
    triangles it hits. Ray/triangle pairs whose 3x3 system is singular (ray parallel to the
    triangle's plane) count as misses instead of making the whole batched solve fail.
    '''
    n_rays = rays.shape[0]
    n_triangles = triangles.shape[0]

    # Pair every ray with every triangle; the leading axis separates the individual points.
    O, D = einops.repeat(rays, "nr pts dims -> pts nr nt dims", nt=n_triangles)
    A, B, C = einops.repeat(triangles, "nt pts dims -> pts nr nt dims", nr=n_rays)

    # O + s*D = A + u*(B - A) + v*(C - A)   <=>   [-D | B - A | C - A] @ (s, u, v) = O - A
    mat = t.stack([-D, B - A, C - A], dim=-1)  # (nrays, ntriangles, 3, 3)
    vec = O - A                                # (nrays, ntriangles, 3)

    # Swap singular systems for the identity so the batched solve succeeds,
    # then treat those pairs as misses.
    is_singular = t.linalg.det(mat).abs() < 1e-8
    mat[is_singular] = t.eye(3, dtype=mat.dtype, device=mat.device)

    s, u, v = t.linalg.solve(mat, vec).unbind(dim=-1)

    # Inside the triangle iff u and v are non-negative and sum to at most 1.
    intersects = (u >= 0) & (v >= 0) & (u + v <= 1) & ~is_singular
    # Misses must never win the minimum below.
    s[~intersects] = float("inf")
    return einops.reduce(s, "nrays ntriangles -> nrays", "min")

# Resolution and extent of the rendered image.
num_pixels_y = 250
num_pixels_z = 250
y_limit = z_limit = 1

rays = make_rays_2d(num_pixels_y, num_pixels_z, y_limit, z_limit)
# Move every ray's origin from (0, 0, 0) to (-2, 0, 0); the direction points are left unchanged.
rays[:, 0] = t.tensor([-2, 0.0, 0.0])
dists = raytrace_mesh(rays, triangles)
# Rays that hit nothing come back as infinity, so finiteness marks the hits.
intersects = t.isfinite(dists).view(num_pixels_y, num_pixels_z)
dists_square = dists.view(num_pixels_y, num_pixels_z)
# Two images stacked on a new leading axis, shown side by side via facet_col=0 below.
img = t.stack([intersects, dists_square], dim=0)

fig = px.imshow(img, facet_col=0, origin="lower", color_continuous_scale="magma", width=1000)
fig.update_layout(coloraxis_showscale=False)
# Replace the facet annotation texts with descriptive titles.
for i, text in enumerate(["Intersects", "Distance"]): 
    fig.layout.annotations[i]['text'] = text
fig.show()