In [None]:
import jax
import jax.numpy as np
import jax.numpy as jnp
import os
import bayes3d as b
import matplotlib.pyplot as plt

In [None]:
'''
JAX implementation of Soft Rasterizer (softras)
(c) 2021 Kartik Chandra; see MIT license attached

Soft Rasterizer: A Differentiable Renderer for Image-based 3D Reasoning
Shichen Liu, Tianye Li, Weikai Chen, and Hao Li (ICCV 2019)
https://arxiv.org/abs/1904.01786
https://github.com/ShichenLiu/SoftRas
'''

def get_pixel(left=-1, top=1, right=1., bottom=-1., xres=50, yres=50):
    '''
    Builds a regular grid of sample points covering a rectangular viewport.

    Columns run from `left` to `right` and rows from `top` to `bottom`, with
    both endpoints included. Every sample is given z = 0.

    Parameters:
        left (float): x coordinate of the first column
        top (float): y coordinate of the first row
        right (float): x coordinate of the last column
        bottom (float): y coordinate of the last row
        xres (int): number of columns
        yres (int): number of rows
    Returns:
        pixel (xres*yres x 1 x 3): sample coordinates in row-major order; the
            singleton middle axis lets the samples broadcast against per-face arrays
        shape (tuple): (yres, xres), the shape to reshape per-sample outputs to
            in order to get an image
    '''
    columns = np.linspace(left, right, xres)
    rows = np.linspace(top, bottom, yres)
    grid_x, grid_y = np.meshgrid(columns, rows)
    grid_z = np.zeros_like(grid_x)
    samples = np.stack((grid_x, grid_y, grid_z), axis=-1)
    return samples.reshape(-1, 1, 3), grid_z.shape

# Small constant added to denominators below to avoid division by zero.
eps = 1e-8

def softras(mesh, pixel, C, SIGMA=1e-1, GAMMA=1e-1):
    '''
    Differentiably rasterizes a mesh using the SoftRas algorithm.

    The z coordinate of the mesh is used only as depth; x and y are used
    directly as image-plane coordinates (no perspective divide).

    Parameters:
        mesh (T x 3[face] x 3[xyz]): mesh, as list of triples of vertices
        pixel (N x 1 x 3[xyz]): sample locations at which to render; the
            singleton axis broadcasts against the T faces
        C: per-face texture brightness. NOTE: kept in the signature for
            callers, but not read by this implementation -- the returned
            value depends on geometry only.
        SIGMA (float): sharpness of the soft inside/outside test
        GAMMA (float): sharpness of the soft depth aggregation
    Returns:
        image (N): rendered value per sample, should be reshaped to form image
    '''
    # Per-vertex depth of every face (T x 3) and the mesh flattened onto z = 0.
    Zbuf = mesh[:, :, 2]
    proj = mesh.at[:, :, 2].set(0)

    def dot(a, b):
        # Dot product over the last axis, keeping that axis for broadcasting.
        return (a * b).sum(axis=-1, keepdims=True)

    def d2_point_to_finite_edge(i):
        # Squared distance from every sample to the edge segment running from
        # vertex i to vertex i+1 of every triangle; result is N x T.
        A = proj[:, i, :]
        B = proj[:, (i + 1) % 3, :]
        Va = B - A
        Vp = pixel - A
        # Parameter of the closest point on the infinite line, then clamped
        # to [0, 1] so the closest point stays on the segment.
        projln = dot(Vp, Va) / (dot(Va, Va) + eps)
        projpt = np.clip(projln, 0, 1) * Va[None, :, :]
        out = dot(Vp - projpt, Vp - projpt)
        return out[:, :, 0]

    # Squared distance to the nearest of the three edges.
    d2 = np.minimum(
        np.minimum(d2_point_to_finite_edge(0), d2_point_to_finite_edge(1)),
        d2_point_to_finite_edge(2)
    )

    def signed_area_to_point(i):
        # Signed area of the triangle (vertex i, vertex i+1, sample): half the
        # z component of the cross product of the two edge vectors.
        A = proj[:, i, :]
        B = proj[:, (i + 1) % 3, :]
        Va = B - A
        area = np.cross(Va, pixel - A)[:, :, 2] / 2
        return area

    Aa = signed_area_to_point(0)
    Ab = signed_area_to_point(1)
    Ac = signed_area_to_point(2)
    Aabc = Aa + Ab + Ac + eps
    # +1 where all three signed areas share a sign (sample inside the
    # triangle), -1 otherwise.
    in_triangle =\
        np.equal(np.sign(Aa), np.sign(Ab)).astype('float32') *\
        np.equal(np.sign(Aa), np.sign(Ac)).astype('float32') * 2 - 1

    # Soft coverage: signed squared edge distance (plus a fixed 0.02 offset)
    # scaled by SIGMA and squashed to (0, 1).
    D = jax.nn.sigmoid(in_triangle * (d2 + 0.02) / SIGMA)

    # Barycentric weights, clipped to the triangle and renormalised.
    bary = np.stack([Aa, Ab, Ac], axis=2) / Aabc[:, :, None]
    bary_clipped = np.clip(bary, 0, 1)
    bary_clipped = bary_clipped / (bary_clipped.sum(axis=2, keepdims=True) + eps)

    # Aa/Ab/Ac are the weights of vertices 2/0/1 respectively, hence the roll
    # that reorders the depths to (z2, z0, z1) before interpolating.
    Zb = (bary_clipped * np.roll(Zbuf, 1, axis=1)).sum(axis=2)
    # Rescale depth to [0, 1] with the smallest z mapped to 1. The eps keeps
    # this finite when every interpolated depth is identical (a single
    # sample/triangle or a flat mesh), where the range would otherwise be 0/0.
    Zb = (Zb.max() - Zb) / (Zb.max() - Zb.min() + eps)

    # Softmax-style aggregation over faces; the exponent is clipped so that
    # small GAMMA values cannot overflow.
    Zbe = np.exp(np.clip(Zb / GAMMA, -20., 20.))
    DZbe = D * Zbe
    # exp(eps / GAMMA) is the extra term in the denominator besides the faces.
    w = DZbe / (DZbe.sum(axis=1, keepdims=True) + np.exp(eps / GAMMA))
    return (w * DZbe).sum(axis=1)

In [None]:
# Base PRNG key for the notebook; later cells derive new keys from it with jax.random.split.
key = jax.random.PRNGKey(7)

In [None]:
# Load the sample cube from the bayes3d assets directory.
mesh_path = os.path.join(b.utils.get_assets_dir(),"sample_objs/cube.obj")
m = b.utils.load_mesh(mesh_path)
vertices = jnp.array(m.vertices)
faces = jnp.array(m.faces)
# width x width grid of sample points over get_pixel's default viewport.
width = 100
pixel, size = get_pixel( xres=width, yres=width)
# Gather the vertex coordinates of each face, dropping the last four faces.
mesh = vertices[faces][:-4]
# One uniform random value per remaining face, drawn from a fixed seed.
C = jax.random.uniform(jax.random.PRNGKey(3), shape=(1, mesh.shape[0]))
# Presumably the axis-aligned bounding box of the vertices -- confirm against b.utils.aabb.
print(b.utils.aabb(vertices))

In [None]:
# Load model 14 from the `bop/ycbv/models` asset directory and scale it by 1/100.
model_dir = os.path.join(b.utils.get_assets_dir(),"bop/ycbv/models")
idx = 14
# Model files are named obj_<6-digit zero-padded id>.ply.
mesh_path = os.path.join(model_dir, f"obj_{idx:06d}.ply")
m = b.utils.load_mesh(mesh_path)
m = b.utils.scale_mesh(m, 1.0/100.0)

vertices = jnp.array(m.vertices)
# Keep at most the first 8000 faces.
faces = jnp.array(m.faces)[:8000]
# One random value per face of THIS model. Sizing from `faces` rather than the
# `mesh` array left over from another cell keeps C consistent with what
# render_img rasterises and removes the dependency on that stale global.
C = jax.random.uniform(jax.random.PRNGKey(3), shape=(1, faces.shape[0]))
# Presumably the axis-aligned bounding box of the vertices -- confirm against b.utils.aabb.
print(b.utils.aabb(vertices))

In [None]:
def render_img(quat):
    '''
    Soft-rasterises the notebook's current mesh at the orientation `quat`.

    Reads the globals `vertices`, `faces`, `pixel` and `C`. The pose combines
    the rotation matrix obtained from `quat` with the fixed position
    (0, 0, 3), and rendering uses SIGMA=0.02 and GAMMA=100.0.

    Parameters:
        quat: orientation, in whatever form b.quaternion_to_rotation_matrix accepts
    Returns:
        flat array with one value per entry of `pixel`; reshape with `size`
        to view it as an image
    '''
    rotation = b.quaternion_to_rotation_matrix(quat)
    position = jnp.array([0.0, 0.0, 3.0])
    pose = b.transform_from_rot_and_pos(rotation, position)
    # Index the transformed vertices by face to get per-face vertex triples.
    posed_faces = b.apply_transform(vertices, pose)[faces]
    return softras(posed_faces, pixel, C, SIGMA=0.02, GAMMA=100.0)

In [None]:
# Advance the notebook key. split() with its default count returns two keys,
# so index 1 is valid; asking for a single key and then reading index 1 would
# address an element that does not exist.
key = jax.random.split(key)[1]
# Sample the ground-truth rotation (0.0001 is presumably a concentration
# parameter -- confirm against b.distributions.vmf_jit).
R = b.distributions.vmf_jit(key, 0.0001)
print(R)
gt_orientation = b.rotation_matrix_to_quaternion(R)
print(gt_orientation)
# Round-trip back to a matrix so it can be compared by eye with R above.
print(b.quaternion_to_rotation_matrix(gt_orientation))
# Target image that the optimisation below tries to reproduce.
gt_img = render_img(gt_orientation)
plt.imshow(gt_img.reshape(size))
plt.colorbar()

In [None]:
def loss(quat):
    '''
    Sum of squared differences between the render at `quat` and the global `gt_img`.
    '''
    residual = render_img(quat) - gt_img
    return jnp.sum(residual ** 2)

# Jitted function returning (loss value, gradient with respect to quat).
value_and_grad_loss = jax.jit(jax.value_and_grad(loss))

In [None]:
# Advance the notebook key. split() with its default count returns two keys,
# so index 1 is valid; asking for a single key and then reading index 1 would
# address an element that does not exist.
key = jax.random.split(key)[1]
# Random starting rotation for the optimisation, converted to a quaternion.
R_start = b.distributions.vmf_jit(key, 0.0001)
estimated_quat = b.rotation_matrix_to_quaternion(R_start)
# Show what the starting estimate renders to.
reconstruction = render_img(estimated_quat)
plt.imshow(reconstruction.reshape(size))
plt.colorbar()

In [None]:
# Plain gradient descent on the orientation: 100 steps with step size 0.01.
# `quats` records the estimate after every step so the trajectory can be
# inspected afterwards.
quats = []
for _ in range(100):
    loss_val, gradient_quat = value_and_grad_loss(estimated_quat)
    print(loss_val)
    estimated_quat -= gradient_quat * 0.01
    quats.append(estimated_quat)


In [None]:
# Side-by-side comparison: ground-truth render on the left, render of the current estimate on the right.
plt.imshow(jnp.hstack([render_img(gt_orientation).reshape(size), render_img(estimated_quat).reshape(size)]))

In [None]:
# key = jax.random.split(key,1)[1]
# With the split above commented out, `key` is reused as-is from whichever cell last set it.
random_pose = b.distributions.vmf_jit(key, 0.0001)
# NOTE(review): `vertices` is overwritten in place here, so re-running this cell
# transforms the already-transformed vertices again, and render_img (which reads
# the global `vertices`) sees the moved mesh afterwards.
vertices = b.apply_transform(vertices, random_pose) 
# Shift the mesh by 3 along z.
vertices = vertices + jnp.array([0.0, 0.0, 3.0]) 
# Per-face vertex triples, dropping the last four faces.
mesh = vertices[faces][:-4]
width = 100
pixel, size = get_pixel( xres=width, yres=width)
# One uniform random value per remaining face, drawn from a fixed seed.
C = jax.random.uniform(jax.random.PRNGKey(3), shape=(1, mesh.shape[0]))

# Render directly with softras, using different SIGMA/GAMMA than render_img does.
img = softras(mesh, pixel, C,  SIGMA=0.01, GAMMA=10.0)
plt.imshow(img.reshape(size))
plt.colorbar()

In [None]:
# b.   <- leftover attribute-completion fragment; commented out because a bare "b." is a SyntaxError and stops Restart & Run All.

In [None]:
from matplotlib import pyplot as plt

In [None]:
# Display the z component of the sample grid as an image (get_pixel sets it to zero everywhere).
plt.imshow(pixel[:, 0, 2].reshape(size))

In [None]:
# Inspect the sample array's shape; get_pixel reshapes it to (number of samples, 1, 3).
pixel.shape

In [None]:
# Load model 14 from the `bop/ycbv/models` asset directory again and scale it by 1/100.
model_dir = os.path.join(b.utils.get_assets_dir(),"bop/ycbv/models")
idx = 14
# Builds the file name obj_<id left-padded with zeros to 6 characters>.ply.
mesh_path = os.path.join(model_dir,"obj_" + "{}".format(idx).rjust(6, '0') + ".ply")
m = b.utils.load_mesh(mesh_path)
m = b.utils.scale_mesh(m, 1.0/100.0)