This notebook creates the ground-truth videos that we will track throughout the lifecycle of this project. It uses a GenJAX model to generate rendered images and saves them to disk, including any related parameters and metadata as a JSON file.

In [1]:
import os
import sys
sys.path.append("../../")
import jax
import genjax
import bayes3d as b
import jax.numpy as jnp
import numpy as np
from models import *
from utils import *
from viz import *
from renderer_setup import *
import pickle

# Keep a handle on whatever genjax.pretty() returns. Presumably this call switches on
# GenJAX's pretty-printing for the notebook session -- TODO confirm against the GenJAX docs.
# `console` is not referenced again in the cells visible here.
console = genjax.pretty()

In [None]:
##################################################################################################
###                                         SCENE 1                                            ###
##################################################################################################
# Scene 1 ground truth: draw one importance-weighted trace from model v1 (renderer setup v1)
# with a fixed PRNG seed, store the trace outputs plus every input under "scene_1", and
# preview the depth video.
SCENE = 1

# Integer seed for the JAX PRNG key; it is also written to the metadata.
key_number = 314159

# Keyword arguments forwarded verbatim to the renderer/mesh setup helper.
RENDERER_ARGS = dict(
    height=50,
    width=50,
    focal_length=250,
    near=0.1,
    far=20,
    ids=range(1, 22),
)
setup_renderer_and_meshes_v1(**RENDERER_ARGS)

# Model arguments. They reach the model positionally via tuple(MODEL_ARGS.values()), so the
# insertion order below is significant. N_total_vec and all_box_dims read from b.RENDERER and
# therefore have to be built after the setup call above.
MODEL_ARGS = dict(
    T_vec=jnp.zeros(5),  # presumably one entry per time step -- TODO confirm in models
    N_total_vec=jnp.arange(len(b.RENDERER.meshes)),
    N_vec=jnp.zeros(1),
    all_box_dims=b.RENDERER.model_box_dims,
    pose_bounds=jnp.array([[-0.1, -0.1, 1.5], [0.1, 0.1, 2]]),
    outlier_volume=jnp.float32(1000.0),
    init_vel_params=jnp.array([0.01, 10000.0]),
    dynamics_params=jnp.array([[0.005, 10000.0], [0.005, 10000.0]]),
    variance_params=jnp.array([0.00000000001, 10000.0]),
    outlier_prob_params=jnp.array([-0.01, 10000.0]),
)

# Entries placed in the choice map that is handed to the importance call.
# NOTE(review): ids=range(1, 22) enumerates 21 ids while 'indices' selects 23 -- confirm this
# index is valid for the meshes loaded by setup_renderer_and_meshes_v1.
CHOICE_MAP_ARGS = dict(
    variance=0.0001,
    outlier_prob=0.0001,
    indices=jnp.array([23]),
)
chm = genjax.choice_map(CHOICE_MAP_ARGS)

# Sample the trace; `weight` is returned by the call but not used further in this cell.
key = jax.random.PRNGKey(key_number)
weight, trace = model_v1_importance_jit(
    key,
    chm,
    tuple(MODEL_ARGS.values()),
)

# Bundle the inputs together with the trace's score, first return value and initial pose.
metadata = dict(
    key_number=key_number,
    RENDERER_ARGS=RENDERER_ARGS,
    MODEL_ARGS=MODEL_ARGS,
    CHOICE_MAP_ARGS=CHOICE_MAP_ARGS,
    score=trace.get_score(),
    rendered=trace.get_retval()[0],
    model_version=1,
    renderer_setup_version=1,
    init_pose=trace["init_pose"],
)
save_metadata(metadata, "scene_{}".format(SCENE))


# Preview the ground-truth depth frames (kept as the last expression so the cell displays it).
video_from_trace(trace, rendered_addr=("depths", "depths"), framerate=5)

In [None]:
##################################################################################################
###                                         SCENE 2                                            ###
##################################################################################################
# Scene 2 ground truth: draw one importance-weighted trace from model v1 (renderer setup v1)
# with a fixed PRNG seed and a length-1 T_vec, store the result under "scene_2", and preview
# the depth video.
SCENE = 2

# Integer seed for the JAX PRNG key; it is also written to the metadata.
key_number = 314159

# Keyword arguments forwarded verbatim to the renderer/mesh setup helper.
RENDERER_ARGS = dict(
    height=50,
    width=50,
    focal_length=250,
    near=0.1,
    far=20,
    ids=range(1, 22),
)
setup_renderer_and_meshes_v1(**RENDERER_ARGS)

# Model arguments. They reach the model positionally via tuple(MODEL_ARGS.values()), so the
# insertion order below is significant. N_total_vec and all_box_dims read from b.RENDERER and
# therefore have to be built after the setup call above.
MODEL_ARGS = dict(
    T_vec=jnp.zeros(1),  # presumably one entry per time step -- TODO confirm in models
    N_total_vec=jnp.arange(len(b.RENDERER.meshes)),
    N_vec=jnp.zeros(1),
    all_box_dims=b.RENDERER.model_box_dims,
    pose_bounds=jnp.array([[-0.1, -0.1, 1.5], [0.1, 0.1, 2]]),
    outlier_volume=jnp.float32(1000.0),
    init_vel_params=jnp.array([0.01, 10000.0]),
    dynamics_params=jnp.array([[0.005, 10000.0], [0.005, 10000.0]]),
    variance_params=jnp.array([0.00000000001, 10000.0]),
    outlier_prob_params=jnp.array([-0.01, 10000.0]),
)

# Entries placed in the choice map that is handed to the importance call.
# NOTE(review): ids=range(1, 22) enumerates 21 ids while 'indices' selects 23 -- confirm this
# index is valid for the meshes loaded by setup_renderer_and_meshes_v1.
CHOICE_MAP_ARGS = dict(
    variance=0.0001,
    outlier_prob=0.0001,
    indices=jnp.array([23]),
)
chm = genjax.choice_map(CHOICE_MAP_ARGS)

# Sample the trace; `weight` is returned by the call but not used further in this cell.
key = jax.random.PRNGKey(key_number)
weight, trace = model_v1_importance_jit(
    key,
    chm,
    tuple(MODEL_ARGS.values()),
)

# Bundle the inputs together with the trace's score, first return value and initial pose.
metadata = dict(
    key_number=key_number,
    RENDERER_ARGS=RENDERER_ARGS,
    MODEL_ARGS=MODEL_ARGS,
    CHOICE_MAP_ARGS=CHOICE_MAP_ARGS,
    score=trace.get_score(),
    rendered=trace.get_retval()[0],
    model_version=1,
    renderer_setup_version=1,
    init_pose=trace["init_pose"],
)
save_metadata(metadata, "scene_{}".format(SCENE))


# Preview the ground-truth depth frames (kept as the last expression so the cell displays it).
video_from_trace(trace, rendered_addr=("depths", "depths"), framerate=5)

In [None]:
##################################################################################################
###                                         SCENE 3                                            ###
##################################################################################################
# Scene 3 ground truth: draw one importance-weighted trace from model v1 (renderer setup v1)
# with a fixed PRNG seed and a length-25 T_vec, store the result under "scene_3", and preview
# the depth video.
SCENE = 3

# Integer seed for the JAX PRNG key; it is also written to the metadata.
key_number = 222222

# Keyword arguments forwarded verbatim to the renderer/mesh setup helper.
RENDERER_ARGS = dict(
    height=50,
    width=50,
    focal_length=250,
    near=0.1,
    far=20,
    ids=range(1, 22),
)
setup_renderer_and_meshes_v1(**RENDERER_ARGS)

# Model arguments. They reach the model positionally via tuple(MODEL_ARGS.values()), so the
# insertion order below is significant. N_total_vec and all_box_dims read from b.RENDERER and
# therefore have to be built after the setup call above.
MODEL_ARGS = dict(
    T_vec=jnp.zeros(25),  # presumably one entry per time step -- TODO confirm in models
    N_total_vec=jnp.arange(len(b.RENDERER.meshes)),
    N_vec=jnp.zeros(1),
    all_box_dims=b.RENDERER.model_box_dims,
    pose_bounds=jnp.array([[-0.1, -0.1, 1.5], [0.1, 0.1, 2]]),
    outlier_volume=jnp.float32(1000.0),
    init_vel_params=jnp.array([0.01, 10000.0]),
    dynamics_params=jnp.array([[0.005, 10000.0], [0.005, 10000.0]]),
    variance_params=jnp.array([0.00000000001, 10000.0]),
    outlier_prob_params=jnp.array([-0.01, 10000.0]),
)

# Entries placed in the choice map that is handed to the importance call.
# NOTE(review): ids=range(1, 22) enumerates 21 ids while 'indices' selects 23 -- confirm this
# index is valid for the meshes loaded by setup_renderer_and_meshes_v1.
CHOICE_MAP_ARGS = dict(
    variance=0.0001,
    outlier_prob=0.0001,
    indices=jnp.array([23]),
)
chm = genjax.choice_map(CHOICE_MAP_ARGS)

# Sample the trace; `weight` is returned by the call but not used further in this cell.
key = jax.random.PRNGKey(key_number)
weight, trace = model_v1_importance_jit(
    key,
    chm,
    tuple(MODEL_ARGS.values()),
)

# Bundle the inputs together with the trace's score, first return value and initial pose.
metadata = dict(
    key_number=key_number,
    RENDERER_ARGS=RENDERER_ARGS,
    MODEL_ARGS=MODEL_ARGS,
    CHOICE_MAP_ARGS=CHOICE_MAP_ARGS,
    score=trace.get_score(),
    rendered=trace.get_retval()[0],
    model_version=1,
    renderer_setup_version=1,
    init_pose=trace["init_pose"],
)
save_metadata(metadata, "scene_{}".format(SCENE))


# Preview the ground-truth depth frames (kept as the last expression so the cell displays it).
video_from_trace(trace, rendered_addr=("depths", "depths"), framerate=5)

In [None]:
##################################################################################################
###                                         SCENE 4                                            ###
##################################################################################################
# Scene 4 ground truth: draw one importance-weighted trace from model v2 (renderer setup v1)
# with a fixed PRNG seed, store the result under "scene_4", and preview the depth video.
SCENE = 4

# Integer seed for the JAX PRNG key; it is also written to the metadata.
key_number = 222222

# Keyword arguments forwarded verbatim to the renderer/mesh setup helper.
RENDERER_ARGS = dict(
    height=50,
    width=50,
    focal_length=250,
    near=0.1,
    far=20,
    ids=range(1, 22),
)
setup_renderer_and_meshes_v1(**RENDERER_ARGS)

# Model arguments. They reach the model positionally via tuple(MODEL_ARGS.values()), so the
# insertion order below is significant. N_total_vec and all_box_dims read from b.RENDERER and
# therefore have to be built after the setup call above.
MODEL_ARGS = dict(
    T_vec=jnp.zeros(5),  # presumably one entry per time step -- TODO confirm in models
    N_total_vec=jnp.arange(len(b.RENDERER.meshes)),
    N_vec=jnp.zeros(1),
    all_box_dims=b.RENDERER.model_box_dims,
    pose_bounds=jnp.array([[-0.1, -0.1, 1.5], [0.1, 0.1, 2]]),
    outlier_volume=jnp.float32(1000.0),
    init_vel_params=jnp.array([0.01, 10000.0]),
    dynamics_params=jnp.array([[0.005, 10000.0], [0.005, 10000.0]]),
    variance_params=jnp.array([0.00000000001, 10000.0]),
    outlier_prob_params=jnp.array([-0.01, 10000.0]),
)

# Entries placed in the choice map that is handed to the importance call.
# NOTE(review): ids=range(1, 22) enumerates 21 ids while 'indices' selects 23 -- confirm this
# index is valid for the meshes loaded by setup_renderer_and_meshes_v1.
CHOICE_MAP_ARGS = dict(
    variance=0.0001,
    outlier_prob=0.0001,
    indices=jnp.array([23]),
)
chm = genjax.choice_map(CHOICE_MAP_ARGS)

# Sample the trace; `weight` is returned by the call but not used further in this cell.
key = jax.random.PRNGKey(key_number)
weight, trace = model_v2_importance_jit(
    key,
    chm,
    tuple(MODEL_ARGS.values()),
)

# Bundle the inputs together with the trace's score, first return value and initial pose.
metadata = dict(
    key_number=key_number,
    RENDERER_ARGS=RENDERER_ARGS,
    MODEL_ARGS=MODEL_ARGS,
    CHOICE_MAP_ARGS=CHOICE_MAP_ARGS,
    score=trace.get_score(),
    rendered=trace.get_retval()[0],
    model_version=2,
    renderer_setup_version=1,
    init_pose=trace["init_pose"],
)
save_metadata(metadata, "scene_{}".format(SCENE))


# Preview the ground-truth depth frames (kept as the last expression so the cell displays it).
video_from_trace(trace, rendered_addr=("depths", "depths"), framerate=5)

In [None]:
####### ~~~~~~~~~~~~~~~~~~~`DEPRECATED DEPRECATED DEPRECATED DO NOT USE` ~~~~~~~~~~~~~~~~~#######
##################################################################################################
###                                         SCENE 5                                            ###
##################################################################################################
# Scene 5 (deprecated) ground truth: draw one importance-weighted trace from model v3
# (renderer setup v2) with a fixed PRNG seed, store the result under "scene_5", and preview
# the depth video.
SCENE = 5

# Integer seed for the JAX PRNG key; it is also written to the metadata.
key_number = 3532452

# Keyword arguments forwarded verbatim to the renderer/mesh setup helper (no 'ids' here).
RENDERER_ARGS = dict(
    height=50,
    width=50,
    focal_length=250,
    near=0.1,
    far=20,
)
setup_renderer_and_meshes_v2(**RENDERER_ARGS)

# Model arguments. They reach the model positionally via tuple(MODEL_ARGS.values()), so the
# insertion order below is significant. N_total_vec and all_box_dims read from b.RENDERER and
# therefore have to be built after the setup call above.
MODEL_ARGS = dict(
    T_vec=jnp.zeros(5),
    N_total_vec=jnp.arange(len(b.RENDERER.meshes)),
    N_vec=jnp.zeros(1),
    all_box_dims=b.RENDERER.model_box_dims,
    pose_bounds=jnp.array([[-0.1, -0.1, 1.5], [0.1, 0.1, 2]]),
    outlier_volume=jnp.float32(1000.0),
    init_vel_params=jnp.array([0.01, 10000.0]),
    dynamics_params=jnp.array([[0.005, 10000.0], [0.005, 10000.0]]),
    variance_params=jnp.array([0.00000000001, 10000.0]),
    outlier_prob_params=jnp.array([-0.01, 10000.0]),
)

# Append one PRNG key (seeded with 71 * position) and one empty choice map per T_vec entry.
# They are added last, in this order, so they become the final two positional model arguments.
MODEL_ARGS.update(
    keys=[jax.random.PRNGKey(71 * step) for step in range(MODEL_ARGS["T_vec"].shape[0])],
    chms=[genjax.empty_choice_map() for _ in range(MODEL_ARGS["T_vec"].shape[0])],
)

# Entries placed in the choice map that is handed to the importance call.
CHOICE_MAP_ARGS = dict(
    variance=0.0001,
    outlier_prob=0.0001,
    indices=jnp.array([0]),
)
chm = genjax.choice_map(CHOICE_MAP_ARGS)

# Sample the trace; `weight` is returned by the call but not used further in this cell.
key = jax.random.PRNGKey(key_number)
weight, trace = model_v3_importance_jit(
    key,
    chm,
    tuple(MODEL_ARGS.values()),
)

# Bundle the inputs together with the trace's score, first return value and initial pose.
metadata = dict(
    key_number=key_number,
    RENDERER_ARGS=RENDERER_ARGS,
    MODEL_ARGS=MODEL_ARGS,
    CHOICE_MAP_ARGS=CHOICE_MAP_ARGS,
    score=trace.get_score(),
    rendered=trace.get_retval()[0],
    model_version=3,
    renderer_setup_version=2,
    init_pose=trace["init_pose"],
)
save_metadata(metadata, "scene_{}".format(SCENE))


# Preview the ground-truth depth frames (kept as the last expression so the cell displays it).
video_from_trace(trace, rendered_addr=("depths", "depths"), framerate=5)

In [2]:
##################################################################################################
###                                         SCENE 6                                            ###
##################################################################################################
# Scene 6 ground truth: sample an initial pose and velocity, roll model v4 forward for T
# steps (one importance call per step), collect one rendered frame and one score per step,
# store everything under "scene_6", and preview the resulting video.
SCENE = 6

# random PRNG key
key_number = 685988
# renderer args
RENDERER_ARGS = {
    'height': 50,
    'width': 50,
    'focal_length': 250,
    'near': 0.1,
    'far': 20
}
# init renderer
setup_renderer_and_meshes_v2(**RENDERER_ARGS)

# Time Step
T = 5
# model args -- passed positionally via .values(), so the key order matters.
# 'pose' and 'velocity' are placeholders that are filled in per step (on a copy) below;
# the saved MODEL_ARGS keeps them as None.
MODEL_ARGS = {
    'pose' : None,
    'velocity' : None,
    'N_total_vec' : jnp.arange(len(b.RENDERER.meshes)),
    'N_vec': jnp.zeros(1),
    'outlier_volume': jnp.float32(1000.0),
    'vel_params': jnp.array([0.005, 10000.0]),
    'variance_params': jnp.array([0.00000000001, 10000.0]),
    'outlier_prob_params': jnp.array([-0.01, 10000.0])
}

# A JAX PRNG key must not be consumed more than once: reusing it makes the "random" draws
# identical/correlated. Derive an independent subkey for the initial pose, the initial
# velocity, and each transition step. Everything is still fully determined by key_number.
# NOTE: this changes the sampled values relative to runs that reused a single key, so
# re-running this cell regenerates (and overwrites) the saved scene with new ground truth.
key = jax.random.PRNGKey(key_number)
init_pose_key, init_vel_key, rollout_key = jax.random.split(key, 3)
step_keys = jax.random.split(rollout_key, T)

INIT_VEL_PARAMS = jnp.array([0.01, 10000.0])
POSE_BOUNDS = jnp.array([[-0.1, -0.1, 1.5], [0.1, 0.1, 2]])
INIT_POSE = b.uniform_pose.sample(init_pose_key, POSE_BOUNDS[0], POSE_BOUNDS[1])
INIT_VELOCITY = b.gaussian_vmf_pose.sample(init_vel_key, jnp.eye(4), *INIT_VEL_PARAMS)

# choice map args
CHOICE_MAP_ARGS = {
    'variance': 0.0001,
    'outlier_prob': 0.0001,
    'indices': jnp.array([0])
}
chm = genjax.choice_map(CHOICE_MAP_ARGS)

# Transition through the HMM
pose = INIT_POSE
velocity = INIT_VELOCITY
scores = []
# Frame 0 is rendered directly from the initial pose; later frames come from the traces.
rendered_images = [b.RENDERER.render(pose[None,...], CHOICE_MAP_ARGS["indices"])[...,:3]]

for t in range(T):
    # Fill pose/velocity on a copy: existing keys keep their position, so the positional
    # order is unchanged and MODEL_ARGS itself is never mutated.
    step_args = {**MODEL_ARGS, "pose": pose, "velocity": velocity}
    _, trace = model_v4_importance_jit(step_keys[t], chm, tuple(step_args.values()))
    scores.append(trace.get_score())
    retval = trace.get_retval()
    # retval[1] carries the next (pose, velocity); retval[0] is the frame for this step.
    pose, velocity = retval[1]
    rendered_images.append(retval[0])

# save metadata
metadata = {
    "T" : T,
    "INIT_VEL_PARAMS":INIT_VEL_PARAMS,
    "POSE_BOUNDS":POSE_BOUNDS,
    "INIT_POSE":INIT_POSE,
    "INIT_VELOCITY" :INIT_VELOCITY,
    "key_number": key_number,
    "RENDERER_ARGS": RENDERER_ARGS,
    "MODEL_ARGS" : MODEL_ARGS,
    "CHOICE_MAP_ARGS" : CHOICE_MAP_ARGS,
    "scores" : scores,
    "rendered" : jnp.stack(rendered_images),
    "model_version" : 4,
    "renderer_setup_version" : 2,
}
save_metadata(metadata, "scene_{}".format(SCENE))


# see depth GT
video_from_rendered(metadata["rendered"], framerate = 5)

[E rasterize_gl.cpp:121] OpenGL version reported as 4.6


Increasing frame buffer size to (width, height, depth) = (64, 64, 1024)


In [7]:
##################################################################################################
###                                         SCENE 7                                            ###
##################################################################################################
# Scene 7 ground truth: sample an initial pose and velocity, roll model v4 forward for T
# steps (one importance call per step), collect one rendered frame and one score per step,
# store everything under "scene_7", and preview the resulting video.
SCENE = 7

# random PRNG key
key_number = 90878
# renderer args
RENDERER_ARGS = {
    'height': 50,
    'width': 50,
    'focal_length': 250,
    'near': 0.1,
    'far': 20
}
# init renderer
setup_renderer_and_meshes_v2(**RENDERER_ARGS)

# Time Step
T = 50
# model args -- passed positionally via .values(), so the key order matters.
# 'pose' and 'velocity' are placeholders that are filled in per step (on a copy) below;
# the saved MODEL_ARGS keeps them as None.
MODEL_ARGS = {
    'pose' : None,
    'velocity' : None,
    'N_total_vec' : jnp.arange(len(b.RENDERER.meshes)),
    'N_vec': jnp.zeros(1),
    'outlier_volume': jnp.float32(1000.0),
    'vel_params': jnp.array([0.0005, 10000.0]),
    'variance_params': jnp.array([0.00000000001, 10000.0]),
    'outlier_prob_params': jnp.array([-0.01, 10000.0])
}

# A JAX PRNG key must not be consumed more than once: reusing it makes the "random" draws
# identical/correlated. Derive an independent subkey for the initial pose, the initial
# velocity, and each transition step. Everything is still fully determined by key_number.
# NOTE: this changes the sampled values relative to runs that reused a single key, so
# re-running this cell regenerates (and overwrites) the saved scene with new ground truth.
key = jax.random.PRNGKey(key_number)
init_pose_key, init_vel_key, rollout_key = jax.random.split(key, 3)
step_keys = jax.random.split(rollout_key, T)

INIT_VEL_PARAMS = jnp.array([0.001, 10000.0])
POSE_BOUNDS = jnp.array([[-0.1, -0.1, 1.5], [0.1, 0.1, 2]])
INIT_POSE = b.uniform_pose.sample(init_pose_key, POSE_BOUNDS[0], POSE_BOUNDS[1])
INIT_VELOCITY = b.gaussian_vmf_pose.sample(init_vel_key, jnp.eye(4), *INIT_VEL_PARAMS)

# choice map args
CHOICE_MAP_ARGS = {
    'variance': 0.0001,
    'outlier_prob': 0.0001,
    'indices': jnp.array([0])
}
chm = genjax.choice_map(CHOICE_MAP_ARGS)

# Transition through the HMM
pose = INIT_POSE
velocity = INIT_VELOCITY
scores = []
# Frame 0 is rendered directly from the initial pose; later frames come from the traces.
rendered_images = [b.RENDERER.render(pose[None,...], CHOICE_MAP_ARGS["indices"])[...,:3]]

for t in range(T):
    # Fill pose/velocity on a copy: existing keys keep their position, so the positional
    # order is unchanged and MODEL_ARGS itself is never mutated.
    step_args = {**MODEL_ARGS, "pose": pose, "velocity": velocity}
    _, trace = model_v4_importance_jit(step_keys[t], chm, tuple(step_args.values()))
    scores.append(trace.get_score())
    retval = trace.get_retval()
    # retval[1] carries the next (pose, velocity); retval[0] is the frame for this step.
    pose, velocity = retval[1]
    rendered_images.append(retval[0])

# save metadata
metadata = {
    "T" : T,
    "INIT_VEL_PARAMS":INIT_VEL_PARAMS,
    "POSE_BOUNDS":POSE_BOUNDS,
    "INIT_POSE":INIT_POSE,
    "INIT_VELOCITY" :INIT_VELOCITY,
    "key_number": key_number,
    "RENDERER_ARGS": RENDERER_ARGS,
    "MODEL_ARGS" : MODEL_ARGS,
    "CHOICE_MAP_ARGS" : CHOICE_MAP_ARGS,
    "scores" : scores,
    "rendered" : jnp.stack(rendered_images),
    "model_version" : 4,
    "renderer_setup_version" : 2,
}
save_metadata(metadata, "scene_{}".format(SCENE))


# see depth GT
video_from_rendered(metadata["rendered"], framerate = 5)

[E rasterize_gl.cpp:121] OpenGL version reported as 4.6


Increasing frame buffer size to (width, height, depth) = (64, 64, 1024)
