# Ray Marching with Taichi

In [None]:
!pip install taichi

In [None]:
import taichi as ti
import numpy as np
import matplotlib.pyplot as plt

## Ray Marching
**Ray marching** is a rendering technique used to visualize 3D scenes, particularly effective for complex shapes that are difficult to represent with traditional polygonal meshes.

- **Basic Concept:** Instead of solving directly for the exact point where a ray intersects a surface (as classical ray tracing does), ray marching advances along the ray in discrete steps and checks at each point whether it has reached a surface.
- **How It Works:** For each pixel on the screen, cast a ray from the camera into the 3D scene. March along the ray, repeatedly asking "have I reached an object yet?", until you either hit something or exceed a maximum distance (or step count).
- **Distance Fields:** The key ingredient is a signed distance function (SDF). For any point in space it returns the shortest distance to the nearest surface: positive outside an object, zero on its surface, and negative inside. In practice, a ray counts as a hit once this distance drops below a small threshold (epsilon).
- **Sphere Marching:** A more efficient variant (also known as sphere tracing) uses the SDF value itself as the step size. Since no surface is closer than that distance, you can safely jump that far without stepping over anything, provided the SDF never overestimates the true distance. This dramatically reduces the number of steps needed.
- **Advantages:** Ray marching excels at rendering smooth, organic shapes, fractals, and procedural geometry. Operations such as blending shapes together, warping space, and adding fine (even fractal) detail become simple arithmetic on the distance functions.
- **Applications:** Popular in demoscene productions, shader art, and procedural graphics. It is particularly useful for creating landscapes, clouds, and abstract mathematical visualizations.
- **Key Insight:** While traditional 3D graphics break objects into triangles, ray marching treats the entire scene as a mathematical function, allowing for complex, smooth surfaces with relatively simple code.
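
The difference between naive fixed-step marching and sphere marching can be made concrete with a small pure-Python sketch, independent of Taichi. It marches a single ray toward a unit sphere at the origin twice: once with a constant step of 0.01 and once using the SDF value as the step. The function names, step size, and thresholds are illustrative assumptions, not values taken from this notebook.

```python
import math

def sdf_sphere(p, center=(0.0, 0.0, 0.0), radius=1.0):
    """Signed distance to a sphere: negative inside, zero on the surface, positive outside."""
    return math.dist(p, center) - radius

def point_on_ray(origin, direction, t):
    """origin + t * direction, component by component."""
    return tuple(o + t * d for o, d in zip(origin, direction))

def fixed_step_march(origin, direction, step=0.01, max_dist=100.0, eps=1e-3):
    """Naive marching: advance by a constant step until the SDF reports a surface."""
    t, steps = 0.0, 0
    while t < max_dist:
        steps += 1
        if sdf_sphere(point_on_ray(origin, direction, t)) < eps:
            return t, steps
        t += step
    return None, steps

def sphere_march(origin, direction, max_steps=1024, max_dist=100.0, eps=1e-3):
    """Sphere marching: the SDF value is a safe step, since no surface is closer than that."""
    t = 0.0
    for steps in range(1, max_steps + 1):
        d = sdf_sphere(point_on_ray(origin, direction, t))
        if d < eps:
            return t, steps
        t += d
        if t > max_dist:
            break
    return None, steps

# Ray starting 5 units from the sphere center, aimed straight at it
origin, direction = (0.0, 0.0, 5.0), (0.0, 0.0, -1.0)
fixed = fixed_step_march(origin, direction)
traced = sphere_march(origin, direction)
print(fixed)
print(traced)
```

Both marches report a hit at roughly t = 4 (the ray starts 5 units from the center of a radius-1 sphere), but the fixed-step march needs about 400 SDF evaluations, while sphere marching reaches the surface in 2, because its first step jumps the entire empty gap at once.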

<img src="https://www.researchgate.net/profile/Guillaume-Francois-3/publication/220792126/figure/fig4/AS:668488666451987@1536391526826/Notations-and-principle-of-a-classical-ray-marching-algorithm-to-compute-single.png" height=500 width=700>

Ray marching with sphere marching, where each step equals the distance to the nearest surface:

<img src="https://www.tylerbovenzi.com/RayMarch/Assets/figure3.png" height=400 width=700>

## Taichi Vectors and Fields
### Vector:
- A data type that groups a small, fixed number of scalar values together (like a mathematical vector)
- Represents a single multi-component value, such as a point or a direction in 3D space
- Used for positions, velocities, colors, or any grouped data
- Behaves like a single entity with multiple components and supports vector operations such as `.dot()`, `.norm()`, and `.normalized()`

### Taichi Field:
- A data structure that stores collections of data in memory
- Can be thought of as a multi-dimensional array that lives on the GPU or CPU, depending on the backend chosen in `ti.init`
- Declared up front with a data type and a shape, e.g. `ti.field(ti.f32, shape=(100, 100))`
- Stores actual simulation data that persists between kernel calls
- Can contain scalars, vectors, or matrices as elements

### Key Differences:
- **Storage:** Vectors are values used in computations, while fields are persistent memory structures that hold your simulation state.
- **Scope:** A vector created inside a kernel is a local variable that disappears when the kernel returns, while a field's contents persist across kernel calls for the lifetime of the program. (A `ti.Vector` created in Python scope, such as `camera_pos` below, is captured by kernels as a constant.)
- **Declaration:** Vectors are created on the fly; fields must be declared in advance with a specific shape and element type.
- **Usage Pattern:** You typically store vectors inside fields. For example, `velocity = ti.Vector.field(3, ti.f32, shape=(100, 100))` creates a field where each element is a 3D vector.
- **Access:** Fields are indexed by position (`field[i, j]`); vectors expose their components through `.x`, `.y`, `.z` or through indexing (`v[0]`).

Think of fields as containers that hold your data, while vectors are the mathematical objects that represent multi-component values within those containers.

For example, the following code snippet declares a 2D field of 2D vectors:

```python
# Declares a 3x3 vector field comprising 2D vectors
f = ti.Vector.field(n=2, dtype=float, shape=(3, 3))
```


In [None]:
# Initialize Taichi
ti.init(arch=ti.gpu)  # Use the GPU for parallel computation (Taichi falls back to the CPU if none is available)

# Image dimensions
width, height = 800, 600

# Create image buffer
image = ti.Vector.field(3, dtype=ti.f32, shape=(width, height))

# Camera and scene parameters
camera_pos = ti.Vector([0.0, 2.0, 5.0])
sphere_center = ti.Vector([0.0, 1.0, 0.0])
sphere_radius = 1.0
ground_y = 0.0

In [None]:
@ti.func
def sdf_sphere(p, center, radius):
    """Signed distance function for a sphere"""
    return (p - center).norm() - radius

@ti.func
def sdf_plane(p, normal, d):
    """Signed distance function for a plane (normal must be unit length)"""
    return p.dot(normal) + d

@ti.func
def scene_sdf(p):
    """Combine all objects in the scene"""
    # Sphere SDF
    sphere_dist = sdf_sphere(p, sphere_center, sphere_radius)

    # Ground plane SDF (plane at y = 0, normal pointing up)
    plane_dist = sdf_plane(p, ti.Vector([0.0, 1.0, 0.0]), -ground_y)

    # Return minimum distance (union of objects)
    return ti.min(sphere_dist, plane_dist)
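
The two SDFs are plain formulas, `norm(p - center) - radius` for the sphere and `dot(p, normal) + d` for the plane, and the union is just their minimum, so the sign convention can be checked outside Taichi. The following NumPy mirror of `sdf_sphere`, `sdf_plane`, and `scene_sdf` is a sketch for checking values by hand; the `_np` names and the sample points are hypothetical, chosen so nothing overwrites the Taichi functions if pasted into the notebook:

```python
import numpy as np

SPHERE_CENTER = np.array([0.0, 1.0, 0.0])
SPHERE_RADIUS = 1.0
GROUND_Y = 0.0
UP = np.array([0.0, 1.0, 0.0])  # unit normal of the ground plane

def sdf_sphere_np(p, center, radius):
    """Negative inside the sphere, zero on its surface, positive outside."""
    return float(np.linalg.norm(p - center) - radius)

def sdf_plane_np(p, normal, d):
    """Exact signed distance only when `normal` has unit length."""
    return float(np.dot(p, normal) + d)

def scene_sdf_np(p):
    # Union of the objects = the smaller of their distances
    return min(sdf_sphere_np(p, SPHERE_CENTER, SPHERE_RADIUS),
               sdf_plane_np(p, UP, -GROUND_Y))

samples = {
    "above the sphere": np.array([0.0, 3.0, 0.0]),   # 1.0 from the sphere top, 3.0 from the ground
    "sphere center": np.array([0.0, 1.0, 0.0]),      # inside the sphere, so negative
    "beside the sphere": np.array([3.0, 0.5, 0.0]),  # ground (0.5) is closer than the sphere (~2.04)
}
for name, p in samples.items():
    print(f"{name}: {scene_sdf_np(p):.3f}")
```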

In [None]:
@ti.func
def estimate_normal(p):
    """Estimate surface normal using gradient of SDF"""
    eps = 1e-4
    dx = ti.Vector([eps, 0.0, 0.0])
    dy = ti.Vector([0.0, eps, 0.0])
    dz = ti.Vector([0.0, 0.0, eps])

    normal = ti.Vector([
        scene_sdf(p + dx) - scene_sdf(p - dx),
        scene_sdf(p + dy) - scene_sdf(p - dy),
        scene_sdf(p + dz) - scene_sdf(p - dz)
    ])

    return normal.normalized()

@ti.func
def ray_march(origin, direction):
    """Ray marching (sphere tracing); returns the hit distance along the ray, or -1.0 on a miss"""
    max_steps = 1024
    max_distance = 100.0
    epsilon = 1e-3

    # 1. Distance travelled so far and whether the ray hit an object
    t = 0.0
    hit = 0
    for step in range(max_steps):
        # 2. Point along the ray at the current distance, and its distance to the scene
        d = scene_sdf(origin + t * direction)
        # 3. Closer than the threshold: count it as a hit
        if d < epsilon:
            hit = 1
            break
        # 4. Travelled past the maximum distance: no object along this ray
        if t > max_distance:
            break
        # 5. Safe step: no surface is closer than d
        t += d

    # Single return at the end (some Taichi versions reject return inside loops)
    result = -1.0
    if hit == 1:
        result = t
    return result

@ti.func
def get_ray_direction(x, y):
    """Calculate ray direction from camera through pixel (x, y)"""
    # Convert pixel coordinates to normalized device coordinates
    ndc_x = (x / width - 0.5) * 2.0
    ndc_y = (y / height - 0.5) * 2.0 * (height / width)

    # Simple perspective projection: image plane at z = -1 (about a 90 degree horizontal field of view)
    direction = ti.Vector([ndc_x, ndc_y, -1.0]).normalized()
    return direction
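
`estimate_normal` relies on the gradient of an SDF pointing away from the surface, and approximates that gradient with central differences. For a sphere the exact normal is known, `(p - center) / radius`, so a quick NumPy check (a sketch outside Taichi; the `_np` helpers and the test point are hypothetical) can confirm that the finite-difference estimate matches it. The `1 / (2 * eps)` factor of a true central difference is dropped, just as in `estimate_normal`, because normalizing the vector cancels it:

```python
import numpy as np

CENTER = np.array([0.0, 1.0, 0.0])
RADIUS = 1.0

def sphere_sdf_np(p):
    return np.linalg.norm(p - CENTER) - RADIUS

def estimate_normal_np(sdf, p, eps=1e-4):
    """Central differences along x, y, z; normalization removes the 1/(2*eps) scale."""
    grad = np.array([sdf(p + eps * e) - sdf(p - eps * e) for e in np.eye(3)])
    return grad / np.linalg.norm(grad)

p = np.array([0.6, 1.8, 0.0])        # on the sphere: |p - CENTER| = |(0.6, 0.8, 0)| = 1
approx = estimate_normal_np(sphere_sdf_np, p)
exact = (p - CENTER) / RADIUS        # analytic sphere normal
print(approx, exact)
print(np.allclose(approx, exact, atol=1e-6))  # True
```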

In [None]:
@ti.func
def shade(hit_point, normal):
    """Simple shading calculation"""
    # Light position
    light_pos = ti.Vector([2.0, 4.0, 3.0])
    light_dir = (light_pos - hit_point).normalized()

    # Ambient + diffuse lighting
    ambient = 0.2
    diffuse = ti.max(0.0, normal.dot(light_dir)) * 0.8

    # Different colors for sphere and ground
    sphere_dist = sdf_sphere(hit_point, sphere_center, sphere_radius)
    ground_dist = sdf_plane(hit_point, ti.Vector([0.0, 1.0, 0.0]), -ground_y)

    color = ti.Vector([0.0, 0.0, 0.0])  # Default color

    if sphere_dist < ground_dist:
        # Sphere color (red)
        color = ti.Vector([0.8, 0.3, 0.3])
    else:
        # Ground color (gray with checkerboard pattern)
        checker_size = 1.0
        checker_x = ti.floor(hit_point.x / checker_size)
        checker_z = ti.floor(hit_point.z / checker_size)
        checker = (checker_x + checker_z) % 2

        if checker < 1:
            color = ti.Vector([0.7, 0.7, 0.7])
        else:
            color = ti.Vector([0.3, 0.3, 0.3])

    return color * (ambient + diffuse)
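
The checkerboard works because the parity of `floor(x) + floor(z)` flips every time a tile boundary is crossed. It also depends on `%` returning 0 or 1 for negative sums, which is the Python convention (Taichi documents its `%` as following Python semantics). A C-style `fmod` would instead return -1 for odd negative sums, and `checker < 1` would then give every tile with a negative index sum the light color. A pure-Python sketch of both conventions, using a hypothetical `checker_parity` helper:

```python
import math

def checker_parity(x, z, size=1.0):
    """Tile parity as computed in shade(): floor the coordinates, add, take % 2 (Python-style)."""
    return (math.floor(x / size) + math.floor(z / size)) % 2

xs = (-1.5, -0.5, 0.5, 1.5)  # four neighbouring tiles along x, with z fixed at 0.5
python_style = [checker_parity(x, 0.5) for x in xs]
c_style = [math.fmod(math.floor(x) + math.floor(0.5), 2) for x in xs]  # C-like remainder

print(python_style)              # [0, 1, 0, 1]: alternates on both sides of x = 0
print([v < 1 for v in c_style])  # [True, True, True, False]: the two negative tiles match
```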

In [None]:
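@ti.kernel
def render():
    """Main rendering kernel - runs in parallel"""
    for x, y in image:
        # 1. Get ray direction for this pixel
        direction = get_ray_direction(x, y)

        # 2. Perform ray marching from the camera (negative result means a miss)
        t = ray_march(camera_pos, direction)

        if t >= 0.0:
            # 3. Hit: calculate hit point and normal, shade accordingly
            hit_point = camera_pos + t * direction
            normal = estimate_normal(hit_point)
            image[x, y] = shade(hit_point, normal)
        else:
            # 4. No hit: assign a background color (sky blue, an arbitrary choice)
            image[x, y] = ti.Vector([0.5, 0.7, 1.0])
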

In [None]:
print("Rendering scene...")

# Render the scene
render()

# Convert Taichi field to numpy array for display
img_np = image.to_numpy()

# The field is indexed [x, y] with y pointing up; rotate so row 0 is the top of the image
img_np = np.rot90(img_np, k=1)

# Display the image
plt.figure(figsize=(12, 9))
plt.imshow(img_np)
plt.title("SDF-based Ray Marcher - Sphere on Ground Plane")
plt.axis('off')
plt.tight_layout()
plt.show()

print("Rendering complete!")
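
Why `np.rot90(..., k=1)`? The Taichi field is indexed `image[x, y]` with `y` increasing upward, while `plt.imshow` treats the first array axis as rows drawn from the top down. A small NumPy sketch with a 4x3 stand-in array (the sizes and the marker pixel are arbitrary) shows that one counter-clockwise rotation puts the pixel with `x = 0` and the largest `y` in the top-left corner, where it belongs:

```python
import numpy as np

width, height = 4, 3
field = np.zeros((width, height))  # indexed [x, y], like the Taichi image field
field[0, height - 1] = 1.0         # mark the top-left pixel (x = 0, largest y)

img = np.rot90(field, k=1)         # shape (height, width); row 0 holds the largest y
print(img.shape)                   # (3, 4)
print(np.argwhere(img == 1.0))     # [[0 0]]: the marker lands in the top-left corner
```

An equivalent alternative is `plt.imshow(img_np.transpose(1, 0, 2), origin='lower')`, which swaps the axes and tells Matplotlib that row 0 is the bottom of the image.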