In [127]:
import jax.numpy as jnp
import jax
import matplotlib.pyplot as plt
import numpy as np
import os
import trimesh
import b3d
from jax.scipy.spatial.transform import Rotation as Rot
from b3d import Pose
import rerun as rr
import functools

rr.init("demo22.py")
rr.connect("127.0.0.1:8812")

image_width=100
image_height=100
fx=50.0
fy=50.0
cx=50.0
cy=50.0
near=0.001
far=16.0
renderer = b3d.Renderer(
    image_width, image_height, fx, fy, cx, cy, near, far
)

## Render color
from pathlib import Path
mesh_path = Path(b3d.__file__).parents[1] / "assets/006_mustard_bottle/textured_simple.obj"
mesh = trimesh.load(mesh_path)

vertices = jnp.array(mesh.vertices) * 20.0
vertices = vertices - vertices.mean(0)
faces = jnp.array(mesh.faces)
vertex_colors = jnp.array(mesh.visual.to_color().vertex_colors)[...,:3] / 255.0
ranges = jnp.array([[0, len(faces)]])

# vertices = jnp.array([
#     [0.0, 0.0, 0.0],
#     [1.0, 0.0, 0.0],
#     [0.0, 1.0, 0.0],
# ])
# faces = jnp.array([
#     [0, 1, 2]
# ])
# ranges = jnp.array([[0, 1]])
# vertex_colors = jnp.array([
#     [1.0, 0.0, 0.0],
#     [1.0, 0.0, 0.0],
#     [1.0, 0.0, 0.0],
# ])

gt_pose = Pose.from_position_and_target(
    jnp.array([1.5, 1.5, 0.0]),
    jnp.zeros(3)
).inv()

target_image, depth = renderer.render_attribute(gt_pose.as_matrix()[None,...], vertices, faces, ranges, vertex_colors)
rr.log("/rgb", rr.Image(target_image), timeless=True)


In [128]:
jnp.linalg.norm(jnp.zeros(3))

Array(0., dtype=float32)

In [129]:
jnp.max

<function jax._src.numpy.reductions.max(a: 'ArrayLike', axis: 'Axis' = None, out: 'None' = None, keepdims: 'bool' = False, initial: 'ArrayLike | None' = None, where: 'ArrayLike | None' = None) -> 'Array'>

In [131]:
print(jax.grad(lambda p: jnp.linalg.norm(p + 1e-20))(jnp.zeros(3)))

@functools.partial(
    jnp.vectorize,
    signature="(3)->()",
    excluded=(
        1,
    ),
)
def point_to_line_distance(point, vector):
    unit_vector = vector / jnp.linalg.norm(vector)
    return jnp.linalg.norm(point - (1e-5 + jnp.dot(point, unit_vector)) * unit_vector) + 1e-5

print(jax.grad(lambda p: 1.0 /point_to_line_distance(p, jnp.ones(3)).sum())(jnp.zeros(3)))

[0.5773518 0.5773518 0.5773518]
[0. 0. 0.]


In [143]:

def render(pose):
    width = 3
    _, _, triangle_ids, _ = renderer.render(pose.as_matrix()[None,...], vertices, faces, ranges)
    vertices_transformed_by_pose = pose.apply(vertices)
    triangle_ids_padded = jnp.pad(triangle_ids, pad_width=[(width, width)])
    ijs = jnp.moveaxis(jnp.mgrid[: image_height, : image_width], 0, -1)
    ijs_padded  = jnp.pad(ijs, pad_width=[(width, width),(width, width), (0,0)])

    @functools.partial(
        jnp.vectorize,
        signature="(2)->(3)",
    )
    def get_mixed_color(ij):
        i,j = ij
        triangle_ids_patch = jax.lax.dynamic_slice(
            triangle_ids_padded,
            jnp.array([i,j]),
            (2*width+1, 2*width+1),
        )
        ijs_patch = jax.lax.dynamic_slice(
            ijs_padded,
            jnp.array([i,j, 0]),
            (2*width+1, 2*width+1,2),
        )

        pixel_vectors = jnp.concatenate([(ijs_patch - jnp.array([cx, cy])) / jnp.array([fx, fy]), jnp.ones_like(ijs_patch[...,0:1])], axis=-1)

        pixel_vector = pixel_vectors[width, width]

        center_points_of_triangle =(
            vertices_transformed_by_pose[faces[triangle_ids_patch-1]]
        ).mean(-2) * (triangle_ids_patch > 0)[...,None] + (1.0 - (triangle_ids_patch > 0)[...,None]) * pixel_vectors * 10.0

        distances = point_to_line_distance(center_points_of_triangle, pixel_vector) + 1e-10
        # return distances
        weights = 1 / (distances + 1e-10)**2
        # return weights
        normalized_weights = weights / weights.sum()
        
        colors_patch = (vertex_colors[faces[triangle_ids_patch-1]] * (triangle_ids_patch > 0)[...,None,None]).mean(-2)
        final_color = (colors_patch * normalized_weights[...,None]).sum(0).sum(0)
        return jnp.clip(final_color, 0.0, 1.0)
    mixed_image = get_mixed_color(ijs)
    return mixed_image
render_jit = jax.jit(render)

In [144]:
pose = Pose.from_position_and_target(
    jnp.array([1.3, 1.5, 0.0]),
    jnp.zeros(3)
).inv()

mixed_image = render_jit(pose)
rr.log("/rgb/reconstruction", rr.Image(mixed_image), timeless=True)


In [147]:
pose = Pose.from_position_and_target(
    jnp.array([1.3, 1.1, 0.0]),
    jnp.zeros(3)
).inv()

mixed_image = render_jit(pose)
rr.log("/rgb/reconstruction", rr.Image(mixed_image), timeless=True)


grad_func = jax.jit(jax.value_and_grad(lambda pose: jnp.linalg.norm(render(pose) - target_image)))
print(grad_func(pose))

for _ in range(100):
    loss, pose_grad = grad_func(pose)
    pose = pose - pose_grad * 0.001
    print(loss)
    mixed_image = render_jit(pose)
    rr.log("/rgb/reconstruction", rr.Image(mixed_image), timeless=True)



(Array(34.69789, dtype=float32), Pose(position=Array([ 3.6451669 ,  2.092086  , -0.49895382], dtype=float32), quaternion=Array([ 1.9135023 , -1.0528479 , -0.3247652 , -0.34368014], dtype=float32)))
34.69789
34.631855
34.69833
34.697136
34.60359
34.631237
34.82761
34.918217
34.95508
35.083687
35.092987
34.986073
35.021503
35.07751
35.02602
35.043663
35.03674
35.094616
35.135773
35.177944
35.16743
35.22364
35.178173
35.18571
35.156693
35.184204
35.20005
35.1813
35.196766
35.22987
35.263264
35.28628
35.35964
35.411514
35.436497
35.43949
35.445797
35.537056
35.55499
35.668697
35.67976
35.650578
35.66643
35.768482
35.874866
35.82325
35.87125
35.936207
35.926006
35.938522
36.075573
36.034767
35.97648
35.918198
35.90619
35.96716
35.92005
35.94612
35.89702
35.948463
35.956238
36.042915
36.139084
36.143463
36.235386
36.299976
36.333378
36.42592
36.59348
36.653328
36.754612
36.76366
36.851364
36.914455
36.95285
37.063515
37.158
37.251373
37.26795
37.314274
37.36315
37.47473
37.585007
37.71823
37

In [121]:
pose_grad

Pose(position=Array([ 0.05088093, -0.04757963, -0.02065831], dtype=float32), quaternion=Array([-0.00589564, -0.00238853,  0.00535257, -0.00713056], dtype=float32))

In [66]:
mixed_image = render(render, pose, 0, 0)

NameError: name 'mixed_image' is not defined