In [1]:
%load_ext autoreload
%autoreload 2

In [13]:

import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt
import numpy as np
import os
import trimesh
import b3d
from jax.scipy.spatial.transform import Rotation as Rot
from b3d import Pose
import rerun as rr
import functools

# NaN checking is left off. Setting this flag to True makes JAX raise at the
# first operation that produces a NaN (see the JAX "Debugging NaNs" docs), which
# is the quickest way to locate the source of the NaN gradients recorded below.
jax.config.update("jax_debug_nans", False)

In [3]:
from demos.differentiable_renderer.utils import (
    center_and_width_to_vertices_faces_colors, rr_log_gt, ray_from_ij,
    fx, fy, cx, cy
)
from demos.differentiable_renderer.rendering import all_pairs, render, renderer, project_pixel_to_plane


In [16]:

# Ground-truth particles: one xyz center, one width and one RGB colour each.
# The commented-out rows are a third particle used in earlier runs.
particle_centers = jnp.array(
    [
        [0.0, 0.0, 1.0],
        [0.2, 0.2, 2.0],
        # [0., 0., 5.]
    ]
)
particle_widths = jnp.array([0.1, 0.3])
particle_colors = jnp.array(
    [
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        # [0.0, 0.0, 1.0]
    ]
)

# Build one small mesh per particle, then flatten the particle axis so the
# renderer sees a single triangle soup. `vertices_og` keeps the per-particle
# axis so the vertices can later be translated particle by particle.
# `triangle_to_particle_index` maps every flat triangle back to its particle.
vertices_og, faces, colors, triangle_to_particle_index = jax.vmap(
    center_and_width_to_vertices_faces_colors
)(jnp.arange(len(particle_centers)), particle_centers, particle_widths, particle_colors)
vertices = vertices_og.reshape(-1, 3)
faces = faces.reshape(-1, 3)
colors = colors.reshape(-1, 3)
triangle_to_particle_index = triangle_to_particle_index.reshape(-1)

# Hard rasterization of the ground-truth scene from the identity pose.
_, _, triangle_id_image, depth_image = renderer.rasterize(
    Pose.identity()[None, ...], vertices, faces, jnp.array([[0, len(faces)]])
)

# Triangle ids are treated as 1-based, with 0 meaning "no triangle hit".
# Each pixel is mapped to its particle index, or to -1 for background.
# `jnp.where` makes the background case explicit instead of relying on index
# -1 wrapping around to the last triangle and then being masked out.
particle_intersected = jnp.where(
    triangle_id_image > 0,
    triangle_to_particle_index[triangle_id_image - 1],
    -1,
)

# Row 0 of `extended_colors` is the background colour, so shifting the
# particle index by +1 turns background (-1) into row 0 and particle k into
# row k + 1.
blank_color = jnp.array([0.1, 0.1, 0.1])  # gray for unintersected particles
extended_colors = jnp.concatenate([blank_color[None], particle_colors], axis=0)
color_image = extended_colors[particle_intersected + 1]

# Per-triangle colours (each triangle inherits its particle's colour) for the
# soft renderer.
triangle_colors = particle_colors[triangle_to_particle_index]


In [18]:

# Connect to a running rerun viewer and log the ground-truth particles.
rr.init("softras_5")
rr.connect("127.0.0.1:8812")
rr_log_gt("gt", particle_centers, particle_widths, particle_colors)

# Soft-rasterizer hyperparameters, handed to `render` as a single tuple.
# NOTE(review): presumably the SoftRas-style sigma (edge sharpness), gamma
# (depth-aggregation temperature) and epsilon (background term). Confirm in
# demos.differentiable_renderer.rendering.render.
SIGMA, GAMMA, EPSILON = 1e-4, 1e-4, 1e-5
hyperparams = (SIGMA, GAMMA, EPSILON)

# Soft render of the ground-truth scene. It is also the reference image that
# compute_error compares against further down.
rendered_soft = render(vertices, faces, triangle_colors, hyperparams)

# Hard (rasterized) colour image next to the soft render, for visual comparison.
rr.log("c/gt", rr.Image(color_image), timeless=True)
rr.log("c/rendered", rr.Image(rendered_soft), timeless=True)


In [22]:

def compute_error(centers):
    """L1 image error for a candidate set of particle centers.

    Soft-renders the scene with the particles moved to ``centers`` (via
    ``render_from_centers``) and returns the sum of absolute per-pixel,
    per-channel differences from the notebook-level reference image
    ``rendered_soft``. This scalar is the objective differentiated with
    ``jax.grad`` below.
    """
    residual = render_from_centers(centers) - rendered_soft
    return jnp.abs(residual).sum()

def render_from_centers(new_particle_centers, render_hyperparams=None):
    """Soft-render the scene after moving each particle to a new center.

    The particle meshes are not rebuilt. Instead, every vertex of particle k in
    ``vertices_og`` is translated by ``new_particle_centers[k] - particle_centers[k]``.
    Mesh topology (``faces``, ``triangle_to_particle_index``) and particle
    colours stay fixed, so ``jax.grad`` can be taken w.r.t. the centers.

    Args:
        new_particle_centers: one xyz row per particle, same shape as the
            notebook-level ``particle_centers``.
        render_hyperparams: optional override for the tuple passed to
            ``render``. When None (default), the notebook-level
            ``hyperparams`` is used, which preserves the original behaviour.

    Returns:
        The output of ``render`` for the translated, flattened vertices.
    """
    if render_hyperparams is None:
        render_hyperparams = hyperparams
    particle_center_delta = new_particle_centers - particle_centers
    # (num_particles, 3) -> (num_particles, 1, 3), so each particle's offset
    # broadcasts over all of that particle's vertices in vertices_og.
    new_vertices = vertices_og + jnp.expand_dims(particle_center_delta, 1)
    return render(
        new_vertices.reshape(-1, 3),
        faces.reshape(-1, 3),
        particle_colors[triangle_to_particle_index],
        render_hyperparams,
    )

# Perturb both particles in x (+0.05 and -0.05) and check whether the L1
# error has a usable gradient at the perturbed configuration.
particle_centers_shifted = jnp.array(
    [
        [0.05, 0.0, 1.0],
        [0.15, 0.2, 2.0],
        # [0., 0., 5.]
    ]
)
rendered_shifted = render_from_centers(particle_centers_shifted)
rr.log("shifted", rr.Image(rendered_shifted), timeless=True)

# value_and_grad evaluates the error and its gradient in one pass. Calling
# compute_error and then jax.grad(compute_error) ran the soft renderer twice.
# NOTE: the recorded run of this cell produced an all-NaN gradient; the next
# cells narrow this down to individual pixels/vertices.
error_shifted, grad_shifted = jax.value_and_grad(compute_error)(particle_centers_shifted)
print("ERROR:")
print(error_shifted)
print("GRAD:")
print(grad_shifted)

ERROR:
62.857605
GRAD:
[[nan nan nan]
 [nan nan nan]]


In [23]:
from demos.differentiable_renderer.rendering import get_pixel_color, WINDOW

# Rasterize the ground-truth geometry to get the per-pixel triangle ids that
# get_pixel_color consumes, then pad by WINDOW on every side with -1
# (presumably "no triangle") so lookups near the border stay in range.
# NOTE(review): the ids come from the unshifted `vertices`, while the pixel
# colour below is evaluated at `vertices_shifted`. Confirm this matches what
# `render` does internally when reproducing the NaN gradient above.
uvs, _, triangle_id_image, depth_image = renderer.rasterize(
    Pose.identity()[None, ...], vertices, faces, jnp.array([[0, len(faces)]])
)
triangle_intersected_padded = jnp.pad(triangle_id_image, pad_width=WINDOW, constant_values=-1)

# Same per-particle translation that render_from_centers applies:
# (num_particles, 3) offsets broadcast over each particle's vertices.
particle_center_delta = particle_centers_shifted - particle_centers
vertices_shifted = vertices_og + particle_center_delta[:, None, :]

# Pixel index to probe with get_pixel_color.
ij = jnp.array([37, 37])
def get_pixel_color_from_vertices(ij, vertices):
    """Scalar summary of the soft-rendered colour at pixel ``ij``.

    Only ``vertices`` varies. Faces, per-triangle colours, the padded
    triangle-id image and the hyperparameters are read from notebook globals,
    so ``jax.grad(..., argnums=1)`` gives the gradient w.r.t. the flattened
    vertex positions. Returns the sum of all entries of get_pixel_color's
    output (presumably the colour channels).
    """
    pixel_value = get_pixel_color(
        ij,
        vertices,
        faces,
        triangle_colors,
        triangle_intersected_padded,
        hyperparams,
    )
    return pixel_value.sum()
# Forward evaluation at the probe pixel `ij` with the shifted geometry.
color = get_pixel_color_from_vertices(ij, vertices_shifted.reshape(-1, 3))
# Disabled diagnostic: vmap the per-pixel vertex gradient over all 100x100
# pixels, then build a (100, 100) float mask of pixels whose gradient contains
# any NaN. NOTE(review): presumably disabled because it is expensive; confirm
# before re-enabling.
# grads = jax.vmap(
#     jax.grad(get_pixel_color_from_vertices, argnums=1),
#     in_axes=(0, None)
# )(all_pairs(100, 100), vertices_shifted.reshape(-1, 3))
# isnan_img = jnp.any(jnp.isnan(grads), axis=(1, 2)).reshape(100, 100).astype(float)


In [24]:
# Gradient of the probe-pixel colour w.r.t. every flattened vertex. Rows 0-3
# are particle 0's vertices and rows 4-7 particle 1's. In the recorded run only
# rows 0-2 are NaN.
jax.grad(lambda flat_vertices: get_pixel_color_from_vertices(ij, flat_vertices))(
    vertices_shifted.reshape(-1, 3)
)

Array([[nan, nan, nan],
       [nan, nan, nan],
       [nan, nan, nan],
       [ 0.,  0.,  0.],
       [ 0.,  0.,  0.],
       [ 0.,  0.,  0.],
       [ 0.,  0.,  0.],
       [ 0.,  0.,  0.]], dtype=float32)

In [15]:
# (num_particles, vertices_per_particle, 3). With the two-particle config
# above this should be (2, 4, 3). The (3, 4, 3) recorded below comes from an
# earlier three-particle run (see the commented-out third particle) and is
# stale.
tuple(vertices_shifted.shape)

(3, 4, 3)