In [1]:
import torch
import b3d
import os
import numpy as np
import time
import jax.numpy as jnp
import jax
import trimesh

In [2]:
# Pinhole camera for the 64x64 renders: focal length 32 px on both axes, principal
# point at the image centre; near/far planes are passed straight to b3d.Renderer.
width, height = 64, 64
fx = fy = 32.0
cx = cy = 32.0
near, far = 0.001, 16.0
renderer = b3d.Renderer(width, height, fx, fy, cx, cy, near, far)

In [3]:
def rotation_from_axis_angle(axis, angle):
    """Build a 3x3 rotation matrix via Rodrigues' formula.

    R = cos(angle) * I + (1 - cos(angle)) * u u^T + sin(angle) * [u]_x,
    where u is `axis` scaled to unit length and [u]_x is its skew-symmetric
    cross-product matrix.

    Args:
        axis (jnp.ndarray): Rotation axis, shape (3,). Need not be unit length,
            but must be non-zero (a zero axis divides by a zero norm -> NaNs).
        angle (float): Rotation angle in radians (right-hand rule about `axis`).
    Returns:
        jnp.ndarray: The rotation matrix. Shape (3, 3)
    """
    unit = axis / jnp.linalg.norm(axis)
    cos_t = jnp.cos(angle)
    sin_t = jnp.sin(angle)
    ux, uy, uz = unit[0], unit[1], unit[2]
    # [u]_x such that [u]_x @ v == cross(u, v)
    cross_matrix = jnp.array(
        [
            [0.0, -uz, uy],
            [uz, 0.0, -ux],
            [-uy, ux, 0.0],
        ]
    )
    symmetric_part = cos_t * jnp.eye(3) + (1.0 - cos_t) * jnp.outer(unit, unit)
    return symmetric_part + sin_t * cross_matrix

def transform_from_rot(rotation):
    """Embed a 3x3 rotation into a 4x4 homogeneous pose with zero translation.

    Args:
        rotation (jnp.ndarray): The rotation matrix. Shape (3, 3)
    Returns:
        jnp.ndarray: [[rotation, 0], [0, 0, 0, 1]]. Shape (4, 4)
    """
    # Start from the identity so the translation column and bottom row are already
    # correct, then overwrite the upper-left 3x3 block.
    return jnp.eye(4).at[:3, :3].set(rotation)

def transform_from_axis_angle(axis, angle):
    """Build a 4x4 homogeneous pose that rotates by `angle` about `axis`.

    The translation part is zero; only the upper-left 3x3 block is non-trivial.

    Args:
        axis (jnp.ndarray): Rotation axis, shape (3,); normalized internally.
        angle (float): Rotation angle in radians.
    Returns:
        jnp.ndarray: The pose matrix. Shape (4, 4)
    """
    rotation = rotation_from_axis_angle(axis, angle)
    return transform_from_rot(rotation)



In [4]:
# A single 90-degree rotation about +z (not used by the later cells shown here).
r_mat = transform_from_axis_angle(jnp.array([0, 0, 1]), jnp.pi / 2)

# One rotation about +z per video frame. Note: linspace includes both endpoints and
# 17*pi/4 = pi/4 + 4*pi, so the first and last frames share the same orientation.
vec_transform_axis_angle = jax.vmap(transform_from_axis_angle, in_axes=(None, 0))
rotation_angles = jnp.linspace(jnp.pi / 4, 17 * jnp.pi / 4, 64)
rots = vec_transform_axis_angle(jnp.array([0, 0, 1]), rotation_angles)

In [5]:
# Load the YCB cracker-box mesh and register it as the only object in the library.
mesh_path = os.path.join(
    b3d.get_root_path(),
    "assets/shared_data_bucket/ycb_video_models/models/003_cracker_box/textured_simple.obj",
)
mesh = trimesh.load(mesh_path)

object_library = b3d.MeshLibrary.make_empty_library()
object_library.add_trimesh(mesh)

# Camera at (0.15, 0.15, 0) aimed at the origin. Its inverse is composed with every
# per-frame rotation (presumably expressing the object pose in the camera frame).
camera_pose = b3d.Pose.from_position_and_target(
    jnp.array([0.15, 0.15, 0.0]),
    jnp.array([0.0, 0.0, 0.0]),
)
cam_inv_pose = camera_pose.inv()

in_place_rots = b3d.Pose.from_matrix(rots)
compound_pose = cam_inv_pose @ in_place_rots


In [6]:
# Batch shape of the composed poses: one entry per rotation angle in `rots`.
compound_pose.shape

(64,)

In [7]:

# Render every pose in one batched call. The inserted axis gives each frame a single
# object pose (assumed layout: frames x objects -- TODO confirm against b3d). The
# one-row range [[0, n_faces]] starts at face 0; whether the second entry is read as
# a count or an end index, it covers every face in the library.
poses_per_frame = compound_pose[:, None, ...]
n_faces = len(object_library.faces)
face_ranges = jnp.array([[0, n_faces]])
rgbs, depths = renderer.render_attribute_many(
    poses_per_frame,
    object_library.vertices,
    object_library.faces,
    face_ranges,
    object_library.attributes,
)

In [8]:
# (T, H, W, C): one 64x64 image with 3 channels per rendered pose.
rgbs.shape

(64, 64, 64, 3)

In [9]:
# np.savez returns None, so its result must not be bound (the previous
# `frames = np.savez(...)` silently set `frames` to None). The array is written to
# frames_64.npz (the .npz suffix is appended) under the default key 'arr_0'.
np.savez('frames_64', np.asarray(rgbs))

In [10]:
# np.load on an .npz returns an NpzFile that keeps the file open; use it as a context
# manager so the handle is closed. Indexing reads 'arr_0' fully into memory, so
# `frames` remains valid after the archive is closed.
with np.load('frames_64.npz') as npz_archive:
    frames = npz_archive['arr_0']
frames.shape

(64, 64, 64, 3)

In [11]:
# Quantize the rendered frames (assumed float in [0, 1]) to 8-bit RGB for CoTracker.
# Round instead of truncating, and clip so out-of-range values cannot wrap in uint8.
# The dtype guard keeps the cell idempotent: re-running it must not scale
# already-converted frames by 255 a second time.
if np.issubdtype(frames.dtype, np.floating):
    frames = np.clip(np.rint(frames * 255), 0, 255).astype(np.uint8)
print(frames.shape)

# Fall back to CPU when no GPU is present (much slower, but the cell still runs).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker2").to(device)

video = torch.tensor(frames).permute(0, 3, 1, 2)[None].float().to(device)  # B T C H W

grid_size = 25


def _time_cotracker(model, video, grid_size):
    """Run `model` on `video` and return (tracks, visibility, elapsed_seconds).

    CUDA kernels launch asynchronously, so the device is synchronized before the
    clock starts and before it stops; otherwise the timer could capture only the
    kernel launches rather than the GPU work itself.
    """
    if video.is_cuda:
        torch.cuda.synchronize(video.device)
    start = time.perf_counter()
    tracks, visibility = model(video, grid_size=grid_size)  # B T N 2,  B T N 1
    if video.is_cuda:
        torch.cuda.synchronize(video.device)
    return tracks, visibility, time.perf_counter() - start


# Two identical timed runs (presumably so the second excludes one-off warm-up costs).
for _ in range(2):
    pred_tracks, pred_visibility, elapsed = _time_cotracker(cotracker, video, grid_size)
    print("Cotracker took ", elapsed, " seconds")


(64, 64, 64, 3)


Using cache found in /home/esli/.cache/torch/hub/facebookresearch_co-tracker_main


Cotracker took  5.535351514816284  seconds
Cotracker took  5.504244565963745  seconds


In [12]:
# Upsample every 64x64 frame to 512x512 for a legible track overlay; vmap applies the
# per-image resize along the leading (time) axis of `frames`.
resize_each_frame = jax.vmap(jax.image.resize, in_axes=(0, None, None))
frames_resized = resize_each_frame(frames, (512, 512, 3), "linear")

video_resized = (
    torch.tensor(np.array(frames_resized))
    .permute(0, 3, 1, 2)[None]
    .float()
    .to(device)
)  # B T C H W


In [14]:

# NOTE(review): import kept next to its use instead of in the imports cell; `cotracker`
# may only be importable after the torch.hub.load call above has loaded the hub
# repo -- confirm before hoisting it.
from cotracker.utils.visualizer import Visualizer

# Tracks were computed on `video`, so rescale them to the upsampled `video_resized`.
# The factor is derived from the tensor shapes (B T C H W) rather than hard-coded,
# so the overlay stays aligned if either resolution changes.
track_scale = video_resized.shape[-1] / video.shape[-1]
assert video_resized.shape[-2] / video.shape[-2] == track_scale, (
    "expected the same upsampling factor for height and width"
)

vis = Visualizer(save_dir=".", pad_value=120, linewidth=3)
vis.visualize(video_resized, pred_tracks * track_scale, pred_visibility, filename='video_64')

# Persist the tracks at the original (unscaled) resolution for downstream use.
pred_tracks_ = pred_tracks.cpu().numpy()
pred_visibility_ = pred_visibility.cpu().numpy()
np.savez("cotracker_output_64.npz", pred_tracks=pred_tracks_, pred_visibility=pred_visibility_)

Video saved to ./video_64.mp4
