This notebook creates the ground-truth videos that we will track throughout the lifecycle of this project. It uses a GenJAX model to generate rendered images and saves them to disk, including any related parameters and metadata, as a JSON file.

In [1]:
# Imports and setup shared by the ground-truth scene cells below.
import os
import sys
# Make project-local modules importable. The relative path resolves against the
# notebook's working directory; presumably models/utils/viz/renderer_setup live
# there — TODO confirm.
sys.path.append("../../")
import jax
import genjax
import bayes3d as b
import jax.numpy as jnp
import numpy as np
# Star imports supply the helpers used by the scene cells (e.g.
# setup_renderer_and_meshes_v1, model_v1_importance_jit, model_v2_importance_jit,
# save_metadata, video_from_trace). Which module defines which name is not
# visible here. Order matters: a later star import shadows earlier names.
# NOTE(review): the recorded output of this cell is
#   AttributeError: 'function' object has no attribute 'simulate'
# and a later cell then failed with NameError for setup_renderer_and_meshes_v1.
# That is consistent with one of these imports raising before renderer_setup
# was imported (possibly a genjax API/version mismatch) — TODO confirm.
from models import *
from utils import *
from viz import *
from renderer_setup import *
# NOTE(review): `pickle` is not used in the visible cells — remove if unused.
import pickle

# genjax pretty-printing console (presumably for rich display of traces).
console = genjax.pretty()

AttributeError: 'function' object has no attribute 'simulate'

In [None]:
##################################################################################################
###                                         SCENE 1                                            ###
##################################################################################################
# Ground-truth scene generated with model v1 and renderer setup v1.
SCENE = 1

# Seed for the PRNG key used to sample the trace (recorded in the metadata so
# the scene can be regenerated).
key_number = 314159

# Renderer configuration; also written into the saved metadata.
# NOTE(review): `ids` is a `range` and several values below are jnp arrays,
# none of which the stdlib `json` module can serialize — confirm that
# `save_metadata` converts them before writing.
RENDERER_ARGS = dict(
    height=50,
    width=50,
    focal_length=250,
    near=0.1,
    far=20,
    ids=range(1, 22),  # 21 ids: 1..21
)
# Presumably (re)initialises the global `b.RENDERER` read below — TODO confirm.
setup_renderer_and_meshes_v1(**RENDERER_ARGS)

# Model arguments. Insertion order matters: the values are handed to the model
# positionally via `tuple(MODEL_ARGS.values())`.
MODEL_ARGS = dict(
    T_vec=jnp.zeros(5),  # length 5; presumably the number of timesteps — TODO confirm
    N_total_vec=jnp.arange(len(b.RENDERER.meshes)),
    N_vec=jnp.zeros(1),
    all_box_dims=b.RENDERER.model_box_dims,
    pose_bounds=jnp.array([[-0.1, -0.1, 1.5], [0.1, 0.1, 2]]),
    outlier_volume=jnp.float32(1000.0),
    init_vel_params=jnp.array([0.01, 10000.0]),
    dynamics_params=jnp.array([[0.005, 10000.0], [0.005, 10000.0]]),
    variance_params=jnp.array([1e-11, 10000.0]),
    outlier_prob_params=jnp.array([-0.01, 10000.0]),
)

# Choices to constrain when importance-sampling the trace.
# NOTE(review): `ids` spans 21 mesh ids but `indices` holds 23. If this indexes
# the loaded meshes, JAX clamps out-of-bounds indices silently — confirm
# that 23 < len(b.RENDERER.meshes).
CHOICE_MAP_ARGS = dict(
    variance=0.0001,
    outlier_prob=0.0001,
    indices=jnp.array([23]),
)
chm = genjax.choice_map(CHOICE_MAP_ARGS)

# Sample a trace under the constraints in `chm` (the weight is not used here).
key = jax.random.PRNGKey(key_number)
weight, trace = model_v1_importance_jit(key, chm, tuple(MODEL_ARGS.values()))

# Everything needed to reproduce and compare against this scene.
metadata = dict(
    key_number=key_number,
    RENDERER_ARGS=RENDERER_ARGS,
    MODEL_ARGS=MODEL_ARGS,
    CHOICE_MAP_ARGS=CHOICE_MAP_ARGS,
    score=trace.get_score(),
    rendered=trace.get_retval()[0],
    model_version=1,
    renderer_setup_version=1,
    init_pose=trace["init_pose"],
)
save_metadata(metadata, f"scene_{SCENE}")

# Depth ground truth, displayed as this cell's output.
video_from_trace(trace, rendered_addr=("depths", "depths"), framerate=5)

In [None]:
##################################################################################################
###                                         SCENE 2                                            ###
##################################################################################################
# Ground-truth scene generated with model v1 and renderer setup v1.
# Same seed as Scene 1; only the length of `T_vec` differs (1 instead of 5).
SCENE = 2

# Seed for the PRNG key used to sample the trace (recorded in the metadata so
# the scene can be regenerated).
key_number = 314159

# Renderer configuration; also written into the saved metadata.
# NOTE(review): `ids` is a `range` and several values below are jnp arrays,
# none of which the stdlib `json` module can serialize — confirm that
# `save_metadata` converts them before writing.
RENDERER_ARGS = dict(
    height=50,
    width=50,
    focal_length=250,
    near=0.1,
    far=20,
    ids=range(1, 22),  # 21 ids: 1..21
)
# Presumably (re)initialises the global `b.RENDERER` read below — TODO confirm.
setup_renderer_and_meshes_v1(**RENDERER_ARGS)

# Model arguments. Insertion order matters: the values are handed to the model
# positionally via `tuple(MODEL_ARGS.values())`.
MODEL_ARGS = dict(
    T_vec=jnp.zeros(1),  # length 1; presumably the number of timesteps — TODO confirm
    N_total_vec=jnp.arange(len(b.RENDERER.meshes)),
    N_vec=jnp.zeros(1),
    all_box_dims=b.RENDERER.model_box_dims,
    pose_bounds=jnp.array([[-0.1, -0.1, 1.5], [0.1, 0.1, 2]]),
    outlier_volume=jnp.float32(1000.0),
    init_vel_params=jnp.array([0.01, 10000.0]),
    dynamics_params=jnp.array([[0.005, 10000.0], [0.005, 10000.0]]),
    variance_params=jnp.array([1e-11, 10000.0]),
    outlier_prob_params=jnp.array([-0.01, 10000.0]),
)

# Choices to constrain when importance-sampling the trace.
# NOTE(review): `ids` spans 21 mesh ids but `indices` holds 23. If this indexes
# the loaded meshes, JAX clamps out-of-bounds indices silently — confirm
# that 23 < len(b.RENDERER.meshes).
CHOICE_MAP_ARGS = dict(
    variance=0.0001,
    outlier_prob=0.0001,
    indices=jnp.array([23]),
)
chm = genjax.choice_map(CHOICE_MAP_ARGS)

# Sample a trace under the constraints in `chm` (the weight is not used here).
key = jax.random.PRNGKey(key_number)
weight, trace = model_v1_importance_jit(key, chm, tuple(MODEL_ARGS.values()))

# Everything needed to reproduce and compare against this scene.
metadata = dict(
    key_number=key_number,
    RENDERER_ARGS=RENDERER_ARGS,
    MODEL_ARGS=MODEL_ARGS,
    CHOICE_MAP_ARGS=CHOICE_MAP_ARGS,
    score=trace.get_score(),
    rendered=trace.get_retval()[0],
    model_version=1,
    renderer_setup_version=1,
    init_pose=trace["init_pose"],
)
save_metadata(metadata, f"scene_{SCENE}")

# Depth ground truth, displayed as this cell's output.
video_from_trace(trace, rendered_addr=("depths", "depths"), framerate=5)

In [None]:
##################################################################################################
###                                         SCENE 3                                            ###
##################################################################################################
# Ground-truth scene generated with model v1 and renderer setup v1, using a
# longer `T_vec` (length 25) and seed 222222.
SCENE = 3

# Seed for the PRNG key used to sample the trace (recorded in the metadata so
# the scene can be regenerated).
key_number = 222222

# Renderer configuration; also written into the saved metadata.
# NOTE(review): `ids` is a `range` and several values below are jnp arrays,
# none of which the stdlib `json` module can serialize — confirm that
# `save_metadata` converts them before writing.
RENDERER_ARGS = dict(
    height=50,
    width=50,
    focal_length=250,
    near=0.1,
    far=20,
    ids=range(1, 22),  # 21 ids: 1..21
)
# Presumably (re)initialises the global `b.RENDERER` read below — TODO confirm.
setup_renderer_and_meshes_v1(**RENDERER_ARGS)

# Model arguments. Insertion order matters: the values are handed to the model
# positionally via `tuple(MODEL_ARGS.values())`.
MODEL_ARGS = dict(
    T_vec=jnp.zeros(25),  # length 25; presumably the number of timesteps — TODO confirm
    N_total_vec=jnp.arange(len(b.RENDERER.meshes)),
    N_vec=jnp.zeros(1),
    all_box_dims=b.RENDERER.model_box_dims,
    pose_bounds=jnp.array([[-0.1, -0.1, 1.5], [0.1, 0.1, 2]]),
    outlier_volume=jnp.float32(1000.0),
    init_vel_params=jnp.array([0.01, 10000.0]),
    dynamics_params=jnp.array([[0.005, 10000.0], [0.005, 10000.0]]),
    variance_params=jnp.array([1e-11, 10000.0]),
    outlier_prob_params=jnp.array([-0.01, 10000.0]),
)

# Choices to constrain when importance-sampling the trace.
# NOTE(review): `ids` spans 21 mesh ids but `indices` holds 23. If this indexes
# the loaded meshes, JAX clamps out-of-bounds indices silently — confirm
# that 23 < len(b.RENDERER.meshes).
CHOICE_MAP_ARGS = dict(
    variance=0.0001,
    outlier_prob=0.0001,
    indices=jnp.array([23]),
)
chm = genjax.choice_map(CHOICE_MAP_ARGS)

# Sample a trace under the constraints in `chm` (the weight is not used here).
key = jax.random.PRNGKey(key_number)
weight, trace = model_v1_importance_jit(key, chm, tuple(MODEL_ARGS.values()))

# Everything needed to reproduce and compare against this scene.
metadata = dict(
    key_number=key_number,
    RENDERER_ARGS=RENDERER_ARGS,
    MODEL_ARGS=MODEL_ARGS,
    CHOICE_MAP_ARGS=CHOICE_MAP_ARGS,
    score=trace.get_score(),
    rendered=trace.get_retval()[0],
    model_version=1,
    renderer_setup_version=1,
    init_pose=trace["init_pose"],
)
save_metadata(metadata, f"scene_{SCENE}")

# Depth ground truth, displayed as this cell's output.
video_from_trace(trace, rendered_addr=("depths", "depths"), framerate=5)

In [2]:
##################################################################################################
###                                         SCENE 4                                            ###
##################################################################################################
# Ground-truth scene generated with model v2 (renderer setup still v1).
# NOTE(review): the recorded NameError for `setup_renderer_and_meshes_v1` in
# this cell's output is consistent with the import cell failing before
# `renderer_setup` was imported, not with a problem in this cell's code.
SCENE = 4

# Seed for the PRNG key used to sample the trace (recorded in the metadata so
# the scene can be regenerated).
key_number = 222222

# Renderer configuration; also written into the saved metadata.
# NOTE(review): `ids` is a `range` and several values below are jnp arrays,
# none of which the stdlib `json` module can serialize — confirm that
# `save_metadata` converts them before writing.
RENDERER_ARGS = dict(
    height=50,
    width=50,
    focal_length=250,
    near=0.1,
    far=20,
    ids=range(1, 22),  # 21 ids: 1..21
)
# Presumably (re)initialises the global `b.RENDERER` read below — TODO confirm.
setup_renderer_and_meshes_v1(**RENDERER_ARGS)

# Model arguments. Insertion order matters: the values are handed to the model
# positionally via `tuple(MODEL_ARGS.values())`.
MODEL_ARGS = dict(
    T_vec=jnp.zeros(5),  # length 5; presumably the number of timesteps — TODO confirm
    N_total_vec=jnp.arange(len(b.RENDERER.meshes)),
    N_vec=jnp.zeros(1),
    all_box_dims=b.RENDERER.model_box_dims,
    pose_bounds=jnp.array([[-0.1, -0.1, 1.5], [0.1, 0.1, 2]]),
    outlier_volume=jnp.float32(1000.0),
    init_vel_params=jnp.array([0.01, 10000.0]),
    dynamics_params=jnp.array([[0.005, 10000.0], [0.005, 10000.0]]),
    variance_params=jnp.array([1e-11, 10000.0]),
    outlier_prob_params=jnp.array([-0.01, 10000.0]),
)

# Choices to constrain when importance-sampling the trace.
# NOTE(review): `ids` spans 21 mesh ids but `indices` holds 23. If this indexes
# the loaded meshes, JAX clamps out-of-bounds indices silently — confirm
# that 23 < len(b.RENDERER.meshes).
CHOICE_MAP_ARGS = dict(
    variance=0.0001,
    outlier_prob=0.0001,
    indices=jnp.array([23]),
)
chm = genjax.choice_map(CHOICE_MAP_ARGS)

# Sample a trace from model v2 under the constraints in `chm`
# (the weight is not used here).
key = jax.random.PRNGKey(key_number)
weight, trace = model_v2_importance_jit(key, chm, tuple(MODEL_ARGS.values()))

# Everything needed to reproduce and compare against this scene.
metadata = dict(
    key_number=key_number,
    RENDERER_ARGS=RENDERER_ARGS,
    MODEL_ARGS=MODEL_ARGS,
    CHOICE_MAP_ARGS=CHOICE_MAP_ARGS,
    score=trace.get_score(),
    rendered=trace.get_retval()[0],
    model_version=2,
    renderer_setup_version=1,
    init_pose=trace["init_pose"],
)
save_metadata(metadata, f"scene_{SCENE}")

# Depth ground truth, displayed as this cell's output.
video_from_trace(trace, rendered_addr=("depths", "depths"), framerate=5)

NameError: name 'setup_renderer_and_meshes_v1' is not defined