In [127]:
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt
import numpy as np
import os
import trimesh
import b3d
from jax.scipy.spatial.transform import Rotation as Rot
from b3d import Pose
import rerun as rr
import functools

rr.init("demo22.py")
rr.connect("127.0.0.1:8812")

image_width=100
image_height=100
fx=50.0
fy=50.0
cx=50.0
cy=50.0
near=0.001
far=16.0
renderer = b3d.Renderer(
    image_width, image_height, fx, fy, cx, cy, near, far
)

## Render color
from pathlib import Path
mesh_path = Path(b3d.__file__).parents[1] / "assets/006_mustard_bottle/textured_simple.obj"
mesh = trimesh.load(mesh_path)

vertices = jnp.array(mesh.vertices) * 20.0
vertices = vertices - vertices.mean(0)
faces = jnp.array(mesh.faces)
vertex_colors = jnp.array(mesh.visual.to_color().vertex_colors)[...,:3] / 255.0
ranges = jnp.array([[0, len(faces)]])

# vertices = jnp.array([
#     [0.0, 0.0, 0.0],
#     [1.0, 0.0, 0.0],
#     [0.0, 1.0, 0.0],
# ])
# faces = jnp.array([
#     [0, 1, 2]
# ])
# ranges = jnp.array([[0, 1]])
# vertex_colors = jnp.array([
#     [1.0, 0.0, 0.0],
#     [1.0, 0.0, 0.0],
#     [1.0, 0.0, 0.0],
# ])

gt_pose = Pose.from_position_and_target(
    jnp.array([1.5, 1.5, 0.0]),
    jnp.zeros(3)
).inv()

target_image, depth = renderer.render_attribute(gt_pose.as_matrix()[None,...], vertices, faces, ranges, vertex_colors)
rr.log("/rgb", rr.Image(target_image), timeless=True)


In [128]:
jnp.linalg.norm(jnp.zeros(3))

Array(0., dtype=float32)

In [129]:
jnp.max

<function jax._src.numpy.reductions.max(a: 'ArrayLike', axis: 'Axis' = None, out: 'None' = None, keepdims: 'bool' = False, initial: 'ArrayLike | None' = None, where: 'ArrayLike | None' = None) -> 'Array'>

In [131]:
print(jax.grad(lambda p: jnp.linalg.norm(p + 1e-20))(jnp.zeros(3)))

@functools.partial(
    jnp.vectorize,
    signature="(3)->()",
    excluded=(
        1,
    ),
)
def point_to_line_distance(point, vector):
    unit_vector = vector / jnp.linalg.norm(vector)
    return jnp.linalg.norm(point - (1e-5 + jnp.dot(point, unit_vector)) * unit_vector) + 1e-5

print(jax.grad(lambda p: 1.0 /point_to_line_distance(p, jnp.ones(3)).sum())(jnp.zeros(3)))

[0.5773518 0.5773518 0.5773518]
[0. 0. 0.]


In [167]:

def render(pose):
    width = 3
    _, _, triangle_ids, _ = renderer.render(pose.as_matrix()[None,...], vertices, faces, ranges)
    vertices_transformed_by_pose = pose.apply(vertices)
    triangle_ids_padded = jnp.pad(triangle_ids, pad_width=[(width, width)])
    ijs = jnp.moveaxis(jnp.mgrid[: image_height, : image_width], 0, -1)
    ijs_padded  = jnp.pad(ijs, pad_width=[(width, width),(width, width), (0,0)])

    @functools.partial(
        jnp.vectorize,
        signature="(2)->(3),(m,m),(m,m)",
    )
    def get_mixed_color(ij):
        i,j = ij
        triangle_ids_patch = jax.lax.dynamic_slice(
            triangle_ids_padded,
            jnp.array([i,j]),
            (2*width+1, 2*width+1),
        )
        ijs_patch = jax.lax.dynamic_slice(
            ijs_padded,
            jnp.array([i,j, 0]),
            (2*width+1, 2*width+1,2),
        )

        pixel_vectors = jnp.concatenate([(ijs_patch - jnp.array([cx, cy])) / jnp.array([fx, fy]), jnp.ones_like(ijs_patch[...,0:1])], axis=-1)

        pixel_vector = pixel_vectors[width, width]

        center_points_of_triangle =(
            vertices_transformed_by_pose[faces[triangle_ids_patch-1]]
        ).mean(-2) * (triangle_ids_patch > 0)[...,None] + (1.0 - (triangle_ids_patch > 0)[...,None]) * pixel_vectors * 20.0

        distances = point_to_line_distance(center_points_of_triangle, pixel_vector) + 1e-5
        # return distances
        weights = 1 / (distances + 1e-5)
        # return weights
        normalized_weights = weights / weights.sum()
        
        colors_patch = (vertex_colors[faces[triangle_ids_patch-1]] * (triangle_ids_patch > 0)[...,None,None]).mean(-2)
        final_color = (colors_patch * normalized_weights[...,None]).sum(0).sum(0)
        return jnp.clip(final_color, 0.0, 1.0), triangle_ids_patch, normalized_weights
    mixed_image, triangle_ids_patch, normalized_weights = get_mixed_color(ijs)
    return mixed_image, triangle_ids_patch, normalized_weights
render_jit = jax.jit(render)

In [172]:
# pose = Pose.from_position_and_target(
#     jnp.array([1.3, 1.5, 0.0]),
#     jnp.zeros(3)
# ).inv()

mixed_image,triangle_ids_patch,normalized_weights = render_jit(gt_pose)
rr.log("/rgb/reconstruction", rr.Image(mixed_image), timeless=True)


In [173]:
triangle_ids_patch[20,34]

Array([[    0,     0,     0,     0,     0,     0,     0],
       [    0,     0,     0,     0,     0,     0,     0],
       [    0,     0,     0,     0,     0,     0,     0],
       [    0,     0,     0,     0,     0, 11348, 14065],
       [    0,     0,     0,     0, 11235, 13792, 13792],
       [    0,     0,     0, 11229, 13805, 13790, 13790],
       [    0,     0, 11186, 13745, 13745, 13711, 13784]], dtype=int32)

: 

In [159]:
pose = Pose.from_position_and_target(
    jnp.array([1.3, 1.1, 0.0]),
    jnp.zeros(3)
).inv()

mixed_image = render_jit(pose)
rr.log("/rgb/reconstruction", rr.Image(mixed_image), timeless=True)


grad_func = jax.jit(jax.value_and_grad(lambda pose: jnp.linalg.norm(render(pose) - target_image)))
print(grad_func(pose))

for _ in range(100):
    loss, pose_grad = grad_func(pose)
    pose = pose - pose_grad * 0.001
    print(loss)
    mixed_image = render_jit(pose)
    rr.log("/rgb/reconstruction", rr.Image(mixed_image), timeless=True)



(Array(36.9046, dtype=float32), Pose(position=Array([ 0.73913133,  1.5306431 , -0.7605345 ], dtype=float32), quaternion=Array([ 1.4641305 , -0.0424298 ,  0.08400118, -1.1915321 ], dtype=float32)))
36.9046
36.79632
36.81175
36.74235
36.803078
36.815353
36.88841
36.726166
36.725086
36.697395
36.74521
36.788677
36.873352
36.886784
36.80979
36.844276
36.864563
36.925213
36.93096
36.920807
36.95823
36.947903
36.93476
36.87172
36.80154
36.81649
36.745106
36.71426
36.670208
36.665283
36.68799
36.688618
36.629105
36.608887
36.575424
36.520374
36.486187
36.483585
36.42747
36.41817
36.412933
36.388382
36.400772
36.33639
36.35362
36.278988
36.293106
36.290493
36.25327
36.271286
36.172806
36.226723
36.192474
36.20666
36.217766
36.179432
36.2262
36.14478
36.155922
36.144325
36.093166
36.033974
36.012035
36.01805
35.97001
35.953606
35.883945
35.85473
35.84537
35.734154
35.754444
35.710728
35.740326
35.67199
35.70551
35.689983
35.682835
35.63262
35.614235
35.642967
35.556
35.588215
35.5674
35.51017
3

In [121]:
pose_grad

Pose(position=Array([ 0.05088093, -0.04757963, -0.02065831], dtype=float32), quaternion=Array([-0.00589564, -0.00238853,  0.00535257, -0.00713056], dtype=float32))

In [66]:
mixed_image = render(render, pose, 0, 0)

NameError: name 'mixed_image' is not defined