In [None]:
pip install -q git+https://github.com/cpgoodri/jax_transformations3d.git

In [8]:
%load_ext autoreload
%autoreload 2

import jax
from jax import jit, vmap
from functools import partial
import jax.numpy as jnp
import matplotlib.pyplot as plt
import jax_transformations3d as jts
import colorsys
import numpy as np
import math
import time
from tqdm.auto import tqdm, trange

from raytrace import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [124]:
# Benchmark cell: accumulate k passes of `full_ray_trace` with a fixed camera jitter
# and print the elapsed wall-clock time in milliseconds.
#
# Recorded timings (bounce count, ms as printed below):
#   4  -> 31
#   8  -> 61.3, 71
#   16 -> 121-122
#
# NOTE(review): this cell reads res_x, res_y, x_persp, y_persp, camera_persp and
# `spheres`, which are defined in the scene-configuration cell further down the
# notebook; run that cell first (a fresh Run All in file order fails here).

sphere_pos, sphere_radius, mat_color, em_color, em_strength, mat = stack_dict_list(
    spheres
)

key = jax.random.PRNGKey(0)
result_img = jnp.zeros((res_y, res_x, 3))

# Sub-pixel camera jitter, drawn once and shared by every pass of this benchmark.
# It gets its own split key: `subkey` is only assigned inside the loop below, so it
# must not be read here.
key, jitter_key = jax.random.split(key)
x_offset, y_offset = (jax.random.uniform(jitter_key, (2,)) - 0.5) * 0.005
ray_pos, ray_dirs = get_init(
    res_x, res_y, x_persp, y_persp, camera_persp, x_offset, y_offset
)

# Make sure the accumulator is materialised before starting the clock.
result_img.block_until_ready()
t0 = time.time()

k = 10
for i in range(k):
    # Fresh per-pixel keys for every pass.
    key, subkey = jax.random.split(key, 2)
    # NOTE(review): the key grid is laid out (res_x, res_y, ...) while the image is
    # (res_y, res_x, 3); presumably full_ray_trace accounts for this — confirm.
    key_grid = jax.random.split(subkey, res_x * res_y).reshape((res_x, res_y, -1))

    result_img += full_ray_trace(
        ray_pos,
        ray_dirs,
        key_grid,
        sphere_pos,
        sphere_radius,
        mat_color,
        em_color,
        em_strength,
        mat,
    )

# JAX dispatches work asynchronously; wait for the result before stopping the clock.
result_img.block_until_ready()
print((time.time() - t0) * 1000)

# Map the 95th-percentile value to 1.0 and clip explicitly, so imshow does not have
# to clip (and warn about) the brightest values.
plt.imshow(jnp.clip(result_img / jnp.quantile(result_img.flatten(), 0.95), 0, 1))
plt.show()

In [4]:
# Camera / output-image configuration.
base_res = 256
x_persp, y_persp = 1.5 * 3, 1 * 3
res_x, res_y = int(base_res * x_persp), int(base_res * y_persp)
camera_persp = 12


# Default scene: a row of n spheres whose hue and "mat" value sweep with their
# position along y, followed by two very large spheres (ground and ceiling light).
n = 5
spheres = [
    dict(
        pos=[5, -5 + t * 10, 0],
        radius=1,
        mat_color=colorsys.hls_to_rgb(t * (1 - 1 / n), 0.5, 1),
        em_color=colorsys.hls_to_rgb(t * (1 - 1 / n), 0.5, 1),
        em_strength=0,
        mat=max(t, 0.01),
    )
    for t in jnp.linspace(0, 1, n)
] + [
    # Alternative (disabled): dim emissive sphere behind the row.
    # {'pos': [10000+15, 0, 0], 'radius': 10000, 'mat_color': [1, 1, 1], 'em_color': [1, 1, 1], 'em_strength': 0.01, 'mat': 0.5},
    # Ground: huge white sphere whose top surface sits at z = 1.
    dict(
        pos=[5, 0, 40000],
        radius=40000 - 1,
        mat_color=[1, 1, 1],
        em_color=[1, 1, 1],
        em_strength=0,
        mat=1,
    ),
    # Back wall (disabled):
    # {'pos': [40000, 0, 0], 'radius': 40000-100, 'mat_color': [0.1, .1, 0.1], 'em_color': [1, 1, 1], 'em_strength': 0, 'mat': 1},
    # Ceiling light: huge sphere with a very small emission strength.
    dict(
        pos=[5, 0, -40000],
        radius=40000 - 5000,
        mat_color=[1, 1, 1],
        em_color=[1, 1, 1],
        em_strength=0.000001,
        mat=1,
    ),
    # Further disabled alternatives:
    # {'pos': [10, 5, -5], 'radius': 4, 'mat_color': [1, 1, 1], 'em_color': [1, 1, 1], 'em_strength': 0, 'mat': 0},
    # {'pos': [5-math.sin(math.pi*i*2)*10, math.cos(math.pi*i*2)*10, -4], 'radius': 1, 'mat_color': [0, 0, 0], 'em_color': [1, 1, 1], 'em_strength': 3, 'mat': 1},
    # {'pos': [6, -10, 0], 'radius': 0.5, 'mat_color': [0, 0, 0], 'em_color': [0, 0.5, 1], 'em_strength': 40, 'mat': 1},
]

In [None]:
def render_img(spheres, n_samples=1000):
    """Accumulate `n_samples` jittered ray-trace passes of a sphere scene.

    Every pass draws a small random camera offset (in [-0.0025, 0.0025) on each
    axis), rebuilds the camera rays with `get_init`, draws one PRNG key per pixel
    and adds the output of `ray_trace` into a running sum. The sum is returned
    un-normalised; callers rescale it for display.

    Camera settings come from the notebook globals res_x, res_y, x_persp, y_persp
    and camera_persp. The seed is fixed (PRNGKey(0)), so identical inputs give an
    identical image.

    Args:
        spheres: list of sphere dicts with the keys "pos", "radius", "mat_color",
            "em_color", "em_strength" and "mat", packed via `stack_dict_list`.
        n_samples: number of passes to accumulate (default 1000).

    Returns:
        The accumulated image; the accumulator starts as zeros of shape
        (res_y, res_x, 3).
    """
    sphere_pos, sphere_radius, mat_color, em_color, em_strength, mat = stack_dict_list(
        spheres
    )

    key = jax.random.PRNGKey(0)
    result_img = jnp.zeros((res_y, res_x, 3))

    for _ in trange(n_samples):
        # JAX keys must never be reused: draw independent keys for the camera
        # jitter and for the per-pixel key grid, otherwise the jitter is derived
        # from the same random bits as the first pixel keys.
        key, jitter_key, pixel_key = jax.random.split(key, 3)
        x_offset, y_offset = (jax.random.uniform(jitter_key, (2,)) - 0.5) * 0.005
        ray_pos, ray_dirs = get_init(
            res_x, res_y, x_persp, y_persp, camera_persp, x_offset, y_offset
        )
        # NOTE(review): the key grid is laid out (res_x, res_y, ...) while the image
        # is (res_y, res_x, 3); presumably ray_trace accounts for this — confirm.
        key_grid = jax.random.split(pixel_key, res_x * res_y).reshape((res_x, res_y, -1))
        # Defensive copies kept from the original; presumably in case ray_trace
        # donates or mutates its ray buffers — confirm whether still needed.
        result_img += ray_trace(
            ray_pos.copy(),
            ray_dirs.copy(),
            key_grid,
            sphere_pos,
            sphere_radius,
            mat_color,
            em_color,
            em_strength,
            mat,
        )

    return result_img

In [None]:
N_FRAMES = 90  # frames in one full orbit of the light
n = 5  # number of coloured spheres in the row (loop-invariant, hoisted)

images = []

# `phase` sweeps [0, 1) and places the emissive sphere on a circle of radius 10
# around (x=5, y=0) at z = -4. endpoint=False because phase 1.0 is the same light
# position as phase 0.0: including it would duplicate the first frame, which shows
# up as a stutter when the frames are played back-to-back as a loop.
for phase in tqdm(jnp.linspace(0, 1, N_FRAMES, endpoint=False)):
    angle = math.pi * phase * 2
    spheres = [
        {
            "pos": [5, -5 + t * 10, 0],
            "radius": 1,
            "mat_color": colorsys.hls_to_rgb(t * (1 - 1 / n), 0.5, 1),
            "em_color": [1, 1, 1],
            "em_strength": 0,
            "mat": max(t, 0.01),
        }
        for t in jnp.linspace(0, 1, n)
    ] + [
        # Ground: huge blue-tinted sphere.
        {
            "pos": [5, 0, 40001],
            "radius": 40000,
            "mat_color": [0, 0.2, 0.4],
            "em_color": [1, 1, 1],
            "em_strength": 0,
            "mat": 1,
        },
        # {'pos': [5, 0, 40001], 'radius': 40000, 'mat_color': [1, 1, 1], 'em_color': [1, 1, 1], 'em_strength': 0, 'mat': 1},
        # Orbiting light.
        {
            "pos": [
                5 - math.sin(angle) * 10,
                math.cos(angle) * 10,
                -4,
            ],
            "radius": 1,
            "mat_color": [0, 0, 0],
            "em_color": [1, 1, 1],
            "em_strength": 3,
            "mat": 1,
        },
        # {'pos': [6, -10, 0], 'radius': 0.5, 'mat_color': [0, 0, 0], 'em_color': [0, 0.5, 1], 'em_strength': 40, 'mat': 1},
    ]
    result_img = render_img(spheres)
    # Optional per-frame preview:
    # plt.figure(figsize=(15, 15))
    # result_img = result_img.at[:400].set(0)
    # plt.imshow(result_img/result_img[300:].mean(axis=-1).max()*16, interpolation='none')
    # plt.show()
    images.append(result_img)

In [None]:
# Still image of the default scene: n hue-swept spheres in a row, a huge white
# ground sphere and a huge, faintly emissive ceiling sphere.
# Relies on `n` (sphere count) having been set by an earlier cell.
spheres = [
    dict(
        pos=[5, -5 + t * 10, 0],
        radius=1,
        mat_color=colorsys.hls_to_rgb(t * (1 - 1 / n), 0.5, 1),
        em_color=colorsys.hls_to_rgb(t * (1 - 1 / n), 0.5, 1),
        em_strength=0,
        mat=max(t, 0.01),
    )
    for t in jnp.linspace(0, 1, n)
] + [
    # Alternative (disabled): dim emissive sphere behind the row.
    # {'pos': [10000+15, 0, 0], 'radius': 10000, 'mat_color': [1, 1, 1], 'em_color': [1, 1, 1], 'em_strength': 0.01, 'mat': 0.5},
    # Ground.
    dict(
        pos=[5, 0, 40000],
        radius=40000 - 1,
        mat_color=[1, 1, 1],
        em_color=[1, 1, 1],
        em_strength=0,
        mat=1,
    ),
    # Back wall (disabled):
    # {'pos': [40000, 0, 0], 'radius': 40000-100, 'mat_color': [0.1, .1, 0.1], 'em_color': [1, 1, 1], 'em_strength': 0, 'mat': 1},
    # Ceiling light.
    dict(
        pos=[5, 0, -40000],
        radius=40000 - 5000,
        mat_color=[1, 1, 1],
        em_color=[1, 1, 1],
        em_strength=0.000001,
        mat=1,
    ),
    # Further disabled alternatives:
    # {'pos': [10, 5, -5], 'radius': 4, 'mat_color': [1, 1, 1], 'em_color': [1, 1, 1], 'em_strength': 0, 'mat': 0},
    # {'pos': [5-math.sin(math.pi*i*2)*10, math.cos(math.pi*i*2)*10, -4], 'radius': 1, 'mat_color': [0, 0, 0], 'em_color': [1, 1, 1], 'em_strength': 3, 'mat': 1},
    # {'pos': [6, -10, 0], 'radius': 0.5, 'mat_color': [0, 0, 0], 'em_color': [0, 0.5, 1], 'em_strength': 40, 'mat': 1},
]
result_img = render_img(spheres)

# Normalise by the largest per-pixel channel mean from row 300 downwards.
display_scale = result_img[300:].mean(axis=-1).max()
plt.figure(figsize=(15, 15))
# result_img = result_img.at[:400].set(0)
plt.imshow(result_img / display_scale, interpolation="none")
plt.show()

In [None]:
import imageio

# Encode the orbit animation (frames collected in `images` by the animation loop)
# as an MP4. Only `images` feeds the video, so this cell does no rendering itself.
images_procs = jnp.stack(images)

# Exposure: map the 95th percentile of all channel values, ignoring the first 250
# rows of every frame, to full white.
max_v = jnp.quantile(images_procs[:, 250:].flatten(), 0.95)

# By construction ~5% of the values exceed max_v. Clip to [0, 1] before scaling
# and casting so they saturate at 255, instead of relying on an out-of-range
# float -> uint8 conversion (which is not guaranteed to saturate).
images_procs = (jnp.clip(images_procs / max_v, 0, 1) * 255).astype(jnp.uint8)

# The frame list is repeated so the video plays the orbit twice.
imageio.mimsave("video.mp4", list(images_procs) * 2, fps=20)

In [None]:
# Arc scene: n spheres placed on a half-ellipse (semi-axes 4.8 in x, 4 in y)
# around (x=5, y=0) at z = 2, followed by a ground sphere and two emissive spheres.
n = 5
spheres = [
    dict(
        pos=[5 + math.sin(t * math.pi) * 4 * 1.2, math.cos(t * math.pi) * 4, 2],
        radius=1,
        mat_color=colorsys.hls_to_rgb(t * (1 - 1 / n), 0.5, 0.93),
        em_color=[1, 1, 1],
        em_strength=0,
        mat=max(t, 0.01),
    )
    for t in jnp.linspace(0, 1, n)
] + [
    # Ground: huge blue-tinted sphere with a small emission.
    dict(
        pos=[5, 0, 40003],
        radius=40000,
        mat_color=[0, 0.2, 0.4],
        em_color=[1, 1, 1],
        em_strength=0.002,
        mat=1,
    ),
    # White light above the arc.
    dict(
        pos=[5, 0, -4],
        radius=1.5,
        mat_color=[0, 0, 0],
        em_color=[1, 1, 1],
        em_strength=3,
        mat=1,
    ),
    # Small, very bright blue light off to the side.
    dict(
        pos=[6, -10, 0],
        radius=0.5,
        mat_color=[0, 0, 0],
        em_color=[0, 0.5, 1],
        em_strength=40,
        mat=1,
    ),
]

In [None]:
"""@jit
def ray_trace_iter(ray_pos, ray_dirs, spheres, sphere_radii):
    closest_hit = jnp.nanargmin(dst, axis=-1)

    # find closest hits
    did_hit_result = jnp.take_along_axis(did_hit, indices=closest_hit[..., None], axis=-1)
    dst_result = jnp.take_along_axis(dst, indices=closest_hit[..., None], axis=-1)
    hit_point_result = jnp.take_along_axis(hit_point, indices=closest_hit[..., None, None], axis=-1).squeeze()
    normal_result = jnp.take_along_axis(normal, indices=closest_hit[..., None, None], axis=-1).squeeze()

    return did_hit_result, dst_result, hit_point_result, normal_result
"""