In [None]:
"""Figure 5: sphere tracing steps visualization"""

%reload_ext autoreload
%autoreload 2

import sys
import os
sys.path.append(os.path.abspath('../'))
from common import *
from disk_tracing import *

import mitsuba as mi

from sdf2d.shapes import Grid2d, UnionSDF

import skfmm

# Build the 2D test shape: rasterise a binary mask on a sdf_res x sdf_res grid
# and convert it to a signed distance field stored in a Grid2d.
sdf_res = 512
y, x = dr.meshgrid(*[dr.linspace(mi.Float, 0, 1, sdf_res) for i in range(2)], indexing='ij')
# Compress the x coordinate so the curve below is evaluated over x in [0, 0.9].
x = x * 0.9
# True for samples whose y lies above a wavy curve in (scaled) x ...
mask = y > (dr.sin(10*x - 4) * x*0.7 + 0.9 - 0.7 * dr.sin(x))
# ... and also for every sample with scaled x > 0.8 (unscaled x above ~0.89).
mask = mask | (x > 0.8) 
# Back to a 2D numpy array. The vertical flip presumably matches the row
# orientation Grid2d expects -- NOTE(review): confirm against Grid2d.
mask = np.flipud(np.reshape(mask, (sdf_res, sdf_res)))
# skfmm.distance(phi, dx) returns the signed distance to the zero contour of phi.
# phi = 0.5 - mask is negative where mask is True, so the SDF is negative inside
# the masked region. The grid spacing passed as dx is 1 / sdf_res.
sdf = Grid2d(skfmm.distance(0.5 - mask, 1 / sdf_res))

# Output name for save_fig.
fig_name = 'diff_sphere_tracing'

def eval_warp_field(sdf, res, eval_pos=None):
    """Evaluate the SDF warp field ``V(x) = -f(x) * grad f(x) / |grad f(x)|^2``.

    Args:
        sdf: SDF object exposing ``eval(points)`` and ``eval_grad(points)``.
        res: Samples per axis of the default evaluation grid. Unused when
            ``eval_pos`` is supplied.
        eval_pos: Optional ``mi.Vector2f`` of query positions. When ``None``,
            a ``res x res`` grid of cell-centred points in [0, 1]^2 is used.

    Returns:
        ``(x_v, y_v, warp)``: x and y coordinates of the query points and the
        warp vector at each of them. The direction factor is detached, so
        derivatives of ``warp`` only flow through the SDF value.
    """
    if eval_pos is not None:
        p_v = eval_pos
        x_v, y_v = eval_pos.x, eval_pos.y
    else:
        # Cell-centred samples: half a cell in from each boundary.
        half_cell = 0.5 / res
        axis = dr.linspace(mi.Float, 0 + half_cell, 1 - half_cell, res)
        y_v, x_v = dr.meshgrid(axis, axis, indexing='ij')
        p_v = mi.Vector2f(x_v, y_v)

    grad = sdf.eval_grad(p_v)
    value = sdf.eval(p_v)
    direction = dr.detach(grad / dr.squared_norm(grad))
    return x_v, y_v, -value * direction

# Plot colours (RGB triples in [0, 1]).
RAY_COLOR = [0.2, 0.3, 0.5]
WARP_T_COLOR = [0.8, 0.2, 0.2]
SURFACE_COLOR = [1.0, 0.9, 0.7]
BG_COLOR = [0.95, 0.95, 0.95]
# If True, panel (d) shows the SDF colour map instead of a flat inside/outside fill.
draw_sdf = False

# Vertical extent shown in every panel (x always spans [0, 1]).
y_min = 0.0
y_max = 0.9

# Number of isolines drawn over the SDF colour maps.
n_isolines = 33
# Title position in axes coordinates; a negative value puts the titles below the panels.
y_offset = -0.13
fontsize = 12
base_size = 4  # NOTE(review): not referenced in this cell.
n_rows = 1
n_cols = 4

# One row of n_cols panels spanning the text width. The 1.02 factor makes the
# figure slightly taller than n_rows / n_cols (presumably room for the titles).
total_width = TEXT_WIDTH
aspect = n_rows / n_cols * 1.02

# NOTE(review): plt.figure(1, ...) reuses figure number 1 if it is still open,
# in which case new axes are added to that existing figure.
fig = plt.figure(1, figsize=(total_width, total_width * aspect), constrained_layout=False)
gs = fig.add_gridspec(n_rows, n_cols, wspace=0.025, hspace=0.025)

# ---------------------------- Plot 0 ------------------------------------
# Panel (a): SDF colour map with its zero level set and isolines, overlaid with
# the warp field V(x) evaluated on a coarse grid.
ax = fig.add_subplot(gs[0])

# Dense res x res evaluation grid over [0, 1]^2; x, y, p are reused by later panels.
res = 1024
# Number of warp-field arrows per axis.
vector_field_res = 12
y, x = dr.meshgrid(*[dr.linspace(mi.Float, 0, 1, res) for i in range(2)], indexing='ij')
p = mi.Vector2f(x, y)

sdf_values = sdf.eval(p)
# Symmetric colour-map limits, so the diverging map is centred on the zero level set.
r = 0.6
ax.imshow(np.reshape(sdf_values, (res, res)), interpolation='none', extent=[0,1,0,1], cmap='coolwarm', vmin=-r, vmax=r, origin='lower')
# Zero level set (the shape boundary).
ax.contour(np.reshape(x, (res, res)), np.reshape(y, (res, res)), np.reshape(sdf_values, (res, res)), levels=[0], colors='k')
# n_isolines thin isolines evenly spaced over [-1, 1].
ax.contour(np.reshape(x, (res, res)), np.reshape(y, (res, res)), np.reshape(sdf_values, (res, res)), 
                    levels=np.linspace(-1,1,n_isolines), alpha=0.9, colors='k', linewidths=0.2)
disable_ticks(ax)
x_v, y_v, warp = eval_warp_field(sdf, vector_field_res)
# Fixed arrow scale (no autoscaling) for the warp vectors.
ax.quiver(x_v, y_v, warp.x, warp.y, scale=4)
ax.set_title(r"(a) $\mathcal{V}(\mathbf{x}, \bm{\pi})$", y=y_offset, fontsize=fontsize)
ax.set_ylim(y_min, y_max)
ax.set_xlim(0, 1)

# ---------------------------- Plot 1 ------------------------------------
# Panel (b): derivative of the warp field with respect to the shape parameter
# pi, obtained by forward-mode AD through the SDF parameter `p.y`.
ax = fig.add_subplot(gs[1])

# For a union SDF the differentiated parameter lives on its first child.
diff_param = sdf.sdf1.p if isinstance(sdf, UnionSDF) else sdf.p
dr.enable_grad(diff_param.y)
dr.set_grad(diff_param.y, 0.0)

x_v, y_v, warp = eval_warp_field(sdf, vector_field_res)

# dr.forward seeds the parameter's gradient with 1 and propagates it, so
# dr.grad(warp) is then d(warp)/d(p.y).
dr.forward(diff_param.y)
warp = dr.grad(warp)

# Background: same SDF colour map, zero level set and isolines as panel (a),
# evaluated on the dense grid `p` defined there.
sdf_values = sdf.eval(p)
r = 0.6
ax.imshow(np.reshape(sdf_values, (res, res)), interpolation='none', extent=[0,1,0,1], cmap='coolwarm', vmin=-r, vmax=r, origin='lower')
ax.contour(np.reshape(x, (res, res)), np.reshape(y, (res, res)), np.reshape(sdf_values, (res, res)), levels=[0], colors='k')
ax.contour(np.reshape(x, (res, res)), np.reshape(y, (res, res)), np.reshape(sdf_values, (res, res)),
                    levels=np.linspace(-1,1,n_isolines), alpha=0.9, colors='k', linewidths=0.2)

# Hide the arrows in the rightmost column, rows 2 .. res-7, where the pi arrow
# and its label are drawn below.
hidden = np.zeros((vector_field_res, vector_field_res), dtype=bool)
hidden[2:-6, -1:] = True
keep = ~hidden.ravel()

x_v = np.array(x_v)[keep]
y_v = np.array(y_v)[keep]
warp_x = np.array(warp.x)[keep]
warp_y = np.array(warp.y)[keep]

ax.quiver(x_v, y_v, warp_x, warp_y)

# Upward arrow labelled pi; the white stroke keeps the label readable on the colour map.
ax.arrow(0.93, 0.2, 0.0, 0.3, head_width=0.03, overhang=0.2, facecolor='k', length_includes_head=True, zorder=10)
txt = ax.text(0.94, 0.35, r'$\pi$')
txt.set_path_effects([path_effects.withStroke(linewidth=1.0, foreground='white')])

ax.set_title(r"(b) $\partial_{\pi}\mathcal{V}(\mathbf{x}, \bm{\pi})$", y=y_offset, fontsize=fontsize)
disable_ticks(ax)
ax.set_ylim(y_min, y_max)
ax.set_xlim(0, 1)
# ---------------------------- Plot 2 ------------------------------------
# Panel (c): sphere-tracing steps of a single horizontal ray, drawn over a
# colour map of the per-point step weights.
ax = fig.add_subplot(gs[2])
ray = mi.Ray2f([0.0, 0.6], [1, 0])
its_t, warp_t_integral, points, dists, _ = intersect_sdf_simple(sdf, ray)
# Sphere-tracing sample points and the radii drawn around each of them.
points = np.array(points)
dists = np.array(dists)
# Upper colour-map limit for the weights (replaced below when using log space).
weight_vmax = 35000
use_log_space = False
# Evaluate on the dense grid `p` from panel (a).
sdf_values = sdf.eval(p)
sdf_grad_values = sdf.eval_grad(p)

weights = sphere_tracing_step_weight(ray, sdf_values, sdf_grad_values)
if use_log_space:
    weights = np.log(weights)
    weight_vmax = 8

im_weights = ax.imshow(np.reshape(weights, (res, res)), interpolation='none', extent=[0,1,0,1], cmap='coolwarm', vmin=0, vmax=weight_vmax, origin='lower')
ax.contour(np.reshape(x, (res, res)), np.reshape(y, (res, res)), np.reshape(sdf_values, (res, res)), levels=[0], colors='k')
ax.contour(np.reshape(x, (res, res)), np.reshape(y, (res, res)), np.reshape(sdf_values, (res, res)),
                    levels=np.linspace(-1,1,3), alpha=0.9, colors='k', linewidths=0.2, zorder=10)

# Sample points, each with a translucent filled circle and a white outline.
# Distinct loop names keep the grid `p` and other module-level names intact.
ax.scatter(points[:, 0], points[:, 1], color='red', zorder=30)
for center, radius in zip(points, dists):
    ax.add_patch(matplotlib.patches.Circle(center, radius, facecolor=[1,1,1], alpha=0.1, lw=0.5, edgecolor=None, zorder=20))
    ax.add_patch(matplotlib.patches.Circle(center, radius, fill=None, alpha=1.0, lw=0.5, edgecolor='white', zorder=22))

# Arrow from the ray origin along the ray direction for length its_t
# (presumably the intersection distance returned by intersect_sdf_simple).
ray_origin = np.array(ray.o).ravel()
ray_offset = np.array(ray.d).ravel() * np.array(its_t).ravel()
ax.arrow(ray_origin[0], ray_origin[1], ray_offset[0], ray_offset[1], head_width=0.03, overhang=0.2, facecolor='k', length_includes_head=True)
ax.set_ylim(y_min, y_max)
ax.set_xlim(0, 1)
disable_ticks(ax)
ax.set_title("(c) Sphere tracing steps", y=y_offset, fontsize=fontsize)

# Small horizontal colour bar for the weights. The rectangle is in figure
# coordinates, hand-tuned so that it appears to sit inside panel (c).
cbar_ax = fig.add_axes([0.652, 0.23, 0.05, 0.01])
cb = fig.colorbar(im_weights, cax=cbar_ax, orientation='horizontal')
cb.set_label(label='Weight', size=fontsize)
cbar_ax.xaxis.set_ticks_position('top')
cbar_ax.xaxis.get_label().set_color('white')
cbar_ax.tick_params(axis='x', colors='white', labelsize=9)

# ---------------------------- Plot 3 ------------------------------------
# Panel (d): evaluation distance warp_t_integral as a function of ray height,
# for horizontal rays that start at x = 0 and travel along +x.
ax = fig.add_subplot(gs[3])
# Set to True to overlay d(warp)/d(pi) arrows at the evaluation points.
draw_warp_vectors = False

n_rays = 128
ray_d = mi.Vector2f(dr.zeros(mi.Float, n_rays) + 1, dr.zeros(mi.Float, n_rays))
ray = mi.Ray2f(mi.Point2f(0.0, dr.linspace(mi.Float, 0, 1, n_rays)), ray_d)
its_t, warp_t_integral, points, dists, _ = intersect_sdf_simple(sdf, ray)

r = 0.6
if draw_sdf:
    # Reuses sdf_values evaluated on the dense grid for panel (c).
    ax.imshow(np.reshape(sdf_values, (res, res)), interpolation='none', extent=[0,1,0,1], cmap='coolwarm', vmin=-r, vmax=r, origin='lower')
    ax.contour(np.reshape(x, (res, res)), np.reshape(y, (res, res)), np.reshape(sdf_values, (res, res)),
                    levels=np.linspace(-1,1,n_isolines), alpha=0.9, colors='k', linewidths=0.2, zorder=10)
else:
    # Flat two-colour fill: SURFACE_COLOR where the SDF is negative, BG_COLOR elsewhere.
    colors = dr.select(sdf_values < 0, mi.Vector3f(SURFACE_COLOR), mi.Vector3f(BG_COLOR))
    ax.imshow(np.reshape(colors, (res, res, 3)), interpolation='none', extent=[0,1,0,1], origin='lower')

# Rays start at x = 0 with direction +x, so the distance t along a ray is also
# the x coordinate of the corresponding point.
ray_o_values = np.array(ray.o.y)
ax.plot(warp_t_integral, ray_o_values, color=WARP_T_COLOR, lw=3)

if draw_warp_vectors:
    # Warp-field derivative at the evaluation points, as in panel (b).
    diff_param = sdf.sdf1.p if isinstance(sdf, UnionSDF) else sdf.p
    dr.enable_grad(diff_param.y)
    dr.set_grad(diff_param.y, 0.0)
    x_v, y_v, warp = eval_warp_field(sdf, vector_field_res, eval_pos=dr.detach(ray(warp_t_integral)))
    dr.forward(diff_param.y)
    warp = dr.grad(warp)
    # Subsample the n_rays evaluation points so the arrows stay readable.
    freq = 19
    x_v = np.array(x_v)[::freq]
    y_v = np.array(y_v)[::freq]
    warp_x = np.array(warp.x)[::freq]
    warp_y = np.array(warp.y)[::freq]
    ax.quiver(x_v, y_v, warp_x, warp_y, zorder=20)

ax.contour(np.reshape(x, (res, res)), np.reshape(y, (res, res)), np.reshape(sdf_values, (res, res)), levels=[0], colors='k')

# Trace a few cell-centred rays just to draw them as arrows of length warp_t_integral.
n_rays = 8
ray_d = mi.Vector2f(dr.zeros(mi.Float, n_rays) + 1, dr.zeros(mi.Float, n_rays))
ray = mi.Ray2f(mi.Point2f(0.0, dr.linspace(mi.Float, 0 + 0.5 / n_rays, 1 - 0.5 / n_rays, n_rays)), ray_d)
its_t, warp_t_integral, points, dists, _ = intersect_sdf_simple(sdf, ray)

for origin, direction, length in zip(np.array(ray.o), np.array(ray.d), np.array(warp_t_integral)):
    offset = direction * length
    ax.arrow(origin[0], origin[1], offset[0], offset[1], head_width=0.03, overhang=0.2, edgecolor=RAY_COLOR, facecolor=RAY_COLOR, length_includes_head=True, zorder=10)
ax.set_ylim(y_min, y_max)
ax.set_xlim(0, 1)
disable_ticks(ax)
ax.set_title("(d) Evaluation distance", y=y_offset, fontsize=fontsize)

# Applies to the current axes (panel (d)).
plt.margins(0, 0)
# save_fig(fig_name)