In [None]:
"""Figure 6: ablation of the sphere tracing weights"""
%reload_ext autoreload
%autoreload 2

import sys
import os
sys.path.append(os.path.abspath('../'))
from common import *
from disk_tracing import *

import mitsuba as mi

from sdf2d.shapes import DiskSDF, UnionSDF

fig_name = 'sphere_tracing_weights_ablation'

# Scene: smooth union of two disks. The ray fan below starts just outside the
# first operand (accessed later as `sdf.sdf1`).
disk1 = DiskSDF(mi.Vector2f(0.5, 0.24), mi.Float(0.25))
disk2 = DiskSDF(mi.Vector2f(0.3, 0.8), mi.Float(0.15))
# NOTE(review): k presumably controls the blend sharpness of the smooth union — confirm in sdf2d.shapes.
sdf = UnionSDF(disk1, disk2, k=64, smooth=True)

# Plot colors (RGB in [0, 1]).
RAY_COLOR = [0.2, 0.3, 0.5]
WARP_T_COLOR = [0.8, 0.2, 0.2]
SURFACE_COLOR = [1.0, 0.9, 0.7]
BG_COLOR = [0.95, 0.95, 0.95]

draw_sdf = False      # True: coolwarm SDF heatmap + isolines; False: flat inside/outside fill
base_size = 4         # figure size per panel (matplotlib figsize units, inches)
n_cols = 2
n_rows = 1
n_isolines = 31       # number of SDF isolines, only drawn when draw_sdf is True
y_offset = -0.15      # y passed to ax.set_title: negative places the caption below each panel
fontsize = 14
# clear=True: if figure 1 is still open (e.g. interactive backends), re-running
# this cell would otherwise add a second set of subplots on top of the old ones.
fig = plt.figure(1, figsize=(n_cols * base_size, n_rows * base_size * 1.02),
                 constrained_layout=False, clear=True)
gs = fig.add_gridspec(n_rows, n_cols, wspace=0.025, hspace=0.025)
circle_pos = -0.05    # ray origin position on disk 1, as a fraction of a full turn (angle = circle_pos * 2*pi)
ray_res = 1024        # number of rays in the fan
res = 512             # background grid resolution per side
params = [False, True]  # use_approach_weighting for panel (a) and panel (b)

labels = [r"(a) $\mathbf{x} + t \bm{\omega}$ without $w_{\mathrm{dist}}$", r"(b) $\mathbf{x} + t \bm{\omega}$ with $w_{\mathrm{dist}}$"]

from matplotlib.lines import Line2D

# Background: the SDF sampled on a res x res grid over [0, 1]^2. It is the same
# for both panels, so evaluate it once instead of once per column.
y, x = dr.meshgrid(*[dr.linspace(mi.Float, 0, 1, res) for _ in range(2)], indexing='ij')
p = mi.Vector2f(x, y)
sdf_values = sdf.eval(p)
x_grid = np.reshape(x, (res, res))
y_grid = np.reshape(y, (res, res))
sdf_grid = np.reshape(sdf_values, (res, res))
r = 0.6  # symmetric colormap range (vmin=-r, vmax=r) for the draw_sdf heatmap

for col, (label, use_approach_weighting) in enumerate(zip(labels, params)):
    ax = fig.add_subplot(gs[col])

    # Ray fan: origin just outside the first disk at angle circle_pos * 2*pi.
    # Angles t in [0, pi) are mapped by T so that t=0 gives (-n.y, n.x) and
    # t=pi/2 gives the SDF gradient direction n, i.e. the fan covers the
    # half-plane the gradient points into. The ray is rebuilt per panel so each
    # intersect_sdf_simple call gets a fresh Ray2f.
    ray_origin = mi.Point2f(dr.sin(circle_pos * 2 * dr.pi), dr.cos(circle_pos * 2 * dr.pi)) * (sdf.sdf1.r + 1e-4) + sdf.sdf1.p
    n = dr.normalize(sdf.eval_grad(ray_origin))
    t = dr.arange(mi.Float, ray_res) / ray_res * dr.pi
    T = mi.Matrix2f([[-n.y, n.x],
                     [n.x, n.y]])
    d = T @ dr.normalize(mi.Vector2f(dr.cos(t), dr.sin(t)))
    return_points = dr.width(t) == 1
    dr.set_flag(dr.JitFlag.LoopRecord, not return_points)
    ray = mi.Ray2f(ray_origin, d, 0.0, [])
    t, warp_t, points, dists, weight_integral = intersect_sdf_simple(sdf, ray, symbolic=True, use_approach_weighting=use_approach_weighting)

    if draw_sdf:
        ax.imshow(sdf_grid, interpolation='none', extent=[0, 1, 0, 1],
                  cmap='coolwarm', vmin=-r, vmax=r, origin='lower')
    else:
        colors = dr.select(sdf_values < 0, mi.Vector3f(SURFACE_COLOR), mi.Vector3f(BG_COLOR))
        ax.imshow(np.reshape(colors, (res, res, 3)), interpolation='none', extent=[0, 1, 0, 1], origin='lower')
    ax.contour(x_grid, y_grid, sdf_grid, levels=[0], colors='k')

    # Points reached at the warp distance along each ray.
    warp_p = ray(warp_t)
    if use_approach_weighting:
        # Opacity fades linearly to zero once the point is 0.1 or further
        # (in |sdf|) from the zero level set.
        warp_weight = dr.maximum(1 - dr.abs(sdf.eval(warp_p)) / 0.1, 0.0)
        # Multiply by weight integral to capture degenerate case of weights being zero
        warp_weight *= dr.clamp(weight_integral, 0.0, 1.0)
    else:
        # Panel without w_dist: every warp point is drawn fully opaque. Keyed on
        # the flag (not the column index) so reordering `params` stays correct.
        warp_weight = np.ones(ray_res)

    # Per-point RGBA: fixed RGB, alpha = warp weight.
    point_color = np.empty((ray_res, 4))
    point_color[:, :3] = WARP_T_COLOR
    point_color[:, 3] = warp_weight

    if draw_sdf:
        ax.contour(x_grid, y_grid, sdf_grid,
                   levels=np.linspace(-1, 1, n_isolines), alpha=0.9, colors='k', linewidths=0.2)
    ax.scatter(warp_p.x, warp_p.y, marker='.', color=point_color, edgecolor='none', s=50, zorder=11)
    ax.scatter(ray_origin.x, ray_origin.y, marker='.', color="white", edgecolor=[0.2, 0.3, 0.6], s=300, zorder=20, lw=1.5)

    # Draw every freq-th ray as an arrow from the shared origin to its warp point.
    freq = 50
    ray_d = np.array(warp_p - ray.o)[::freq, :]
    # All rays share one origin; repeat it once per drawn arrow (writable copy).
    ray_o = np.broadcast_to(np.array(ray.o)[::freq, :], ray_d.shape).copy()
    arrow_color = np.tile([*RAY_COLOR, 1.0], (len(ray_d), 1))

    ax.quiver(ray_o[:, 0], ray_o[:, 1], ray_d[:, 0], ray_d[:, 1], scale=1, scale_units='xy', color=arrow_color, zorder=15)

    ax.set_title(label, y=y_offset, fontsize=fontsize)
    ax.set_ylim(0.43, 0.85)
    ax.set_xlim(0.1, 0.7)
    disable_ticks(ax)

    ax.text(ray_origin.x[0] + 0.02, ray_origin.y[0] - 0.02, "$\\mathbf{x}$", ha='left')

    # Callout labelling the ray parameterization, drawn on the first panel only.
    if col == 0:
        ax.text(ray_origin.x[0] + 0.1, 0.73, r"$\mathbf{x} + t \bm{\omega}$", ha='left')
        ax.add_line(Line2D([0.48, 0.51], [0.72, 0.73], color='k'))

# Zero data margins on the current axes of figure 1, i.e. the most recently
# added (right-hand) panel. Both panels already set explicit x/y limits.
fig.gca().margins(0, 0)
# Uncomment to write the figure to disk: save_fig(fig_name)

In [None]:
# 