This notebook creates ground-truth videos that we will track throughout the lifecycle of this project. It uses a GenJAX model to generate rendered images and saves them to disk, including any related parameters and metadata as a JSON file.

In [1]:
import os
import sys
sys.path.append("../../")
import jax
import genjax
import bayes3d as b
from scipy.spatial.transform import Rotation as R

import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
from models import *
from utils import *
from viz import *
from renderer_setup import *
import pickle
from dataclasses import dataclass
from genjax.generative_functions.distributions import ExactDensity

# Presumably switches genjax into its pretty-printing display mode; the returned object is
# kept as the global `console` (not referenced again in the visible cells).
console = genjax.pretty()

In [None]:
##################################################################################################
###                                         SCENE 1                                            ###
##################################################################################################
# Scene 1: model v1 + renderer setup v1, seed 314159, T_vec of length 5 (presumably 5 time
# steps). Samples one trace with model_v1_importance_jit, saves its metadata under
# "scenes_old/scene_1" and shows a depth video of the sampled trace.
SCENE = 1

# random PRNG key
key_number = 314159
# renderer args
RENDERER_ARGS = {
    'height': 50,
    'width': 50,
    'focal_length': 250,
    'near': 0.1,
    'far': 20,
    'ids': range(1, 22)
}
# init renderer (presumably (re)initialises b.RENDERER, which MODEL_ARGS reads below,
# so this must run before MODEL_ARGS is built)
setup_renderer_and_meshes_v1(**RENDERER_ARGS)

# model args
# These are passed to the model positionally via tuple(MODEL_ARGS.values()), so the
# insertion order of the keys is significant.
MODEL_ARGS = {
    'T_vec': jnp.zeros(5),
    'N_total_vec' : jnp.arange(len(b.RENDERER.meshes)),
    'N_vec': jnp.zeros(1),
    'all_box_dims' : b.RENDERER.model_box_dims,
    'pose_bounds': jnp.array([[-0.1, -0.1, 1.5], [0.1, 0.1, 2]]),
    'outlier_volume': jnp.float32(1000.0),
    'init_vel_params': jnp.array([0.01, 10000.0]),
    'dynamics_params': jnp.array([[0.005, 10000.0], [0.005, 10000.0]]),
    'variance_params': jnp.array([0.00000000001, 10000.0]),
    'outlier_prob_params': jnp.array([-0.01, 10000.0])
}

# choice map args (values constrained in the importance call below)
# NOTE(review): 'indices' is 23 while RENDERER_ARGS['ids'] is range(1, 22) (21 ids). If this
# indexes the renderer's mesh list and only those 21 meshes are loaded, it is out of range and
# JAX clamps out-of-bounds gather indices silently -- confirm against setup_renderer_and_meshes_v1.
CHOICE_MAP_ARGS = {
    'variance': 0.0001,
    'outlier_prob': 0.0001,
    'indices': jnp.array([23])
}
chm = genjax.choice_map(CHOICE_MAP_ARGS)

# generate trace (the importance weight is not used below)
key = jax.random.PRNGKey(key_number)
weight, trace = model_v1_importance_jit(key, chm, tuple(MODEL_ARGS.values()))

# save metadata
# "rendered" is the model's first return value; "init_pose" is the sampled "init_pose" choice.
metadata = {
    "key_number": key_number,
    "RENDERER_ARGS": RENDERER_ARGS,
    "MODEL_ARGS" : MODEL_ARGS,
    "CHOICE_MAP_ARGS" : CHOICE_MAP_ARGS,
    "score" : trace.get_score(),
    "rendered" : trace.get_retval()[0],
    "model_version" : 1,
    "renderer_setup_version" : 1,
    "init_pose" : trace["init_pose"]
}
save_metadata(metadata, "scenes_old/scene_{}".format(SCENE))


# see depth GT
video_from_trace(trace, rendered_addr = ("depths", "depths"), framerate = 5)

In [None]:
##################################################################################################
###                                         SCENE 2                                            ###
##################################################################################################
# Scene 2: identical to Scene 1 (model v1, renderer setup v1, seed 314159) except that
# T_vec has length 1. Metadata goes to "scenes_old/scene_2".
SCENE = 2

# random PRNG key
key_number = 314159
# renderer args
RENDERER_ARGS = {
    'height': 50,
    'width': 50,
    'focal_length': 250,
    'near': 0.1,
    'far': 20,
    'ids': range(1, 22)
}
# init renderer (must run before MODEL_ARGS, which reads b.RENDERER)
setup_renderer_and_meshes_v1(**RENDERER_ARGS)

# model args (passed positionally via tuple(MODEL_ARGS.values()); key order matters)
MODEL_ARGS = {
    'T_vec': jnp.zeros(1),
    'N_total_vec' : jnp.arange(len(b.RENDERER.meshes)),
    'N_vec': jnp.zeros(1),
    'all_box_dims' : b.RENDERER.model_box_dims,
    'pose_bounds': jnp.array([[-0.1, -0.1, 1.5], [0.1, 0.1, 2]]),
    'outlier_volume': jnp.float32(1000.0),
    'init_vel_params': jnp.array([0.01, 10000.0]),
    'dynamics_params': jnp.array([[0.005, 10000.0], [0.005, 10000.0]]),
    'variance_params': jnp.array([0.00000000001, 10000.0]),
    'outlier_prob_params': jnp.array([-0.01, 10000.0])
}

# choice map args
# NOTE(review): index 23 vs ids range(1, 22) -- see the same question raised for Scene 1.
CHOICE_MAP_ARGS = {
    'variance': 0.0001,
    'outlier_prob': 0.0001,
    'indices': jnp.array([23])
}
chm = genjax.choice_map(CHOICE_MAP_ARGS)

# generate trace (the importance weight is not used below)
key = jax.random.PRNGKey(key_number)
weight, trace = model_v1_importance_jit(key, chm, tuple(MODEL_ARGS.values()))

# save metadata
metadata = {
    "key_number": key_number,
    "RENDERER_ARGS": RENDERER_ARGS,
    "MODEL_ARGS" : MODEL_ARGS,
    "CHOICE_MAP_ARGS" : CHOICE_MAP_ARGS,
    "score" : trace.get_score(),
    "rendered" : trace.get_retval()[0],
    "model_version" : 1,
    "renderer_setup_version" : 1,
    "init_pose" : trace["init_pose"]
}
save_metadata(metadata, "scenes_old/scene_{}".format(SCENE))


# see depth GT
video_from_trace(trace, rendered_addr = ("depths", "depths"), framerate = 5)

In [None]:
##################################################################################################
###                                         SCENE 3                                            ###
##################################################################################################
# Scene 3: model v1 + renderer setup v1 with seed 222222 and T_vec of length 25.
# Metadata goes to "scenes_old/scene_3".
SCENE = 3

# random PRNG key
key_number = 222222
# renderer args
RENDERER_ARGS = {
    'height': 50,
    'width': 50,
    'focal_length': 250,
    'near': 0.1,
    'far': 20,
    'ids': range(1, 22)
}
# init renderer (must run before MODEL_ARGS, which reads b.RENDERER)
setup_renderer_and_meshes_v1(**RENDERER_ARGS)

# model args (passed positionally via tuple(MODEL_ARGS.values()); key order matters)
MODEL_ARGS = {
    'T_vec': jnp.zeros(25),
    'N_total_vec' : jnp.arange(len(b.RENDERER.meshes)),
    'N_vec': jnp.zeros(1),
    'all_box_dims' : b.RENDERER.model_box_dims,
    'pose_bounds': jnp.array([[-0.1, -0.1, 1.5], [0.1, 0.1, 2]]),
    'outlier_volume': jnp.float32(1000.0),
    'init_vel_params': jnp.array([0.01, 10000.0]),
    'dynamics_params': jnp.array([[0.005, 10000.0], [0.005, 10000.0]]),
    'variance_params': jnp.array([0.00000000001, 10000.0]),
    'outlier_prob_params': jnp.array([-0.01, 10000.0])
}

# choice map args
# NOTE(review): index 23 vs ids range(1, 22) -- see the same question raised for Scene 1.
CHOICE_MAP_ARGS = {
    'variance': 0.0001,
    'outlier_prob': 0.0001,
    'indices': jnp.array([23])
}
chm = genjax.choice_map(CHOICE_MAP_ARGS)

# generate trace (the importance weight is not used below)
key = jax.random.PRNGKey(key_number)
weight, trace = model_v1_importance_jit(key, chm, tuple(MODEL_ARGS.values()))

# save metadata
metadata = {
    "key_number": key_number,
    "RENDERER_ARGS": RENDERER_ARGS,
    "MODEL_ARGS" : MODEL_ARGS,
    "CHOICE_MAP_ARGS" : CHOICE_MAP_ARGS,
    "score" : trace.get_score(),
    "rendered" : trace.get_retval()[0],
    "model_version" : 1,
    "renderer_setup_version" : 1,
    "init_pose" : trace["init_pose"]
}
save_metadata(metadata, "scenes_old/scene_{}".format(SCENE))


# see depth GT
video_from_trace(trace, rendered_addr = ("depths", "depths"), framerate = 5)

In [None]:
##################################################################################################
###                                         SCENE 4                                            ###
##################################################################################################
# Scene 4: same arguments as Scene 1 (T_vec length 5) but seed 222222 and the trace is
# sampled from model v2 (model_v2_importance_jit). Metadata goes to "scenes_old/scene_4".
SCENE = 4

# random PRNG key
key_number = 222222
# renderer args
RENDERER_ARGS = {
    'height': 50,
    'width': 50,
    'focal_length': 250,
    'near': 0.1,
    'far': 20,
    'ids': range(1, 22)
}
# init renderer (must run before MODEL_ARGS, which reads b.RENDERER)
setup_renderer_and_meshes_v1(**RENDERER_ARGS)

# model args (passed positionally via tuple(MODEL_ARGS.values()); key order matters)
MODEL_ARGS = {
    'T_vec': jnp.zeros(5),
    'N_total_vec' : jnp.arange(len(b.RENDERER.meshes)),
    'N_vec': jnp.zeros(1),
    'all_box_dims' : b.RENDERER.model_box_dims,
    'pose_bounds': jnp.array([[-0.1, -0.1, 1.5], [0.1, 0.1, 2]]),
    'outlier_volume': jnp.float32(1000.0),
    'init_vel_params': jnp.array([0.01, 10000.0]),
    'dynamics_params': jnp.array([[0.005, 10000.0], [0.005, 10000.0]]),
    'variance_params': jnp.array([0.00000000001, 10000.0]),
    'outlier_prob_params': jnp.array([-0.01, 10000.0])
}

# choice map args
# NOTE(review): index 23 vs ids range(1, 22) -- see the same question raised for Scene 1.
CHOICE_MAP_ARGS = {
    'variance': 0.0001,
    'outlier_prob': 0.0001,
    'indices': jnp.array([23])
}
chm = genjax.choice_map(CHOICE_MAP_ARGS)

# generate trace with model v2 (the importance weight is not used below)
key = jax.random.PRNGKey(key_number)
weight, trace = model_v2_importance_jit(key, chm, tuple(MODEL_ARGS.values()))

# save metadata
metadata = {
    "key_number": key_number,
    "RENDERER_ARGS": RENDERER_ARGS,
    "MODEL_ARGS" : MODEL_ARGS,
    "CHOICE_MAP_ARGS" : CHOICE_MAP_ARGS,
    "score" : trace.get_score(),
    "rendered" : trace.get_retval()[0],
    "model_version" : 2,
    "renderer_setup_version" : 1,
    "init_pose" : trace["init_pose"]
}
save_metadata(metadata, "scenes_old/scene_{}".format(SCENE))


# see depth GT
video_from_trace(trace, rendered_addr = ("depths", "depths"), framerate = 5)

In [None]:
####### ~~~~~~~~~~~~~~~~~~~`DEPRECATED DEPRECATED DEPRECATED DO NOT USE` ~~~~~~~~~~~~~~~~~#######
##################################################################################################
###                                         SCENE 5                                            ###
##################################################################################################
# Scene 5 (deprecated): model v3 + renderer setup v2, seed 3532452. In addition to the usual
# arguments, model v3 receives a list of per-step PRNG keys and per-step (empty) choice maps.
SCENE = 5

# random PRNG key
key_number = 3532452
# renderer args
RENDERER_ARGS = {
    'height': 50,
    'width': 50,
    'focal_length': 250,
    'near': 0.1,
    'far': 20
}
# init renderer (must run before MODEL_ARGS, which reads b.RENDERER)
setup_renderer_and_meshes_v2(**RENDERER_ARGS)

# model args (passed positionally via tuple(MODEL_ARGS.values()); key order matters)
MODEL_ARGS = {
    'T_vec': jnp.zeros(5),
    'N_total_vec' : jnp.arange(len(b.RENDERER.meshes)),
    'N_vec': jnp.zeros(1),
    'all_box_dims' : b.RENDERER.model_box_dims,
    'pose_bounds': jnp.array([[-0.1, -0.1, 1.5], [0.1, 0.1, 2]]),
    'outlier_volume': jnp.float32(1000.0),
    'init_vel_params': jnp.array([0.01, 10000.0]),
    'dynamics_params': jnp.array([[0.005, 10000.0], [0.005, 10000.0]]),
    'variance_params': jnp.array([0.00000000001, 10000.0]),
    'outlier_prob_params': jnp.array([-0.01, 10000.0])
}

# Appended after the dict literal, so these become the last two positional model arguments:
# one PRNG key (PRNGKey(0), PRNGKey(71), PRNGKey(142), ...) and one empty choice map per
# entry of T_vec.
MODEL_ARGS['keys'] = [jax.random.PRNGKey(71*i) for i in range(MODEL_ARGS["T_vec"].shape[0])]
MODEL_ARGS['chms'] = [genjax.empty_choice_map() for i in range(MODEL_ARGS["T_vec"].shape[0])]

# choice map args
CHOICE_MAP_ARGS = {
    'variance': 0.0001,
    'outlier_prob': 0.0001,
    'indices': jnp.array([0])
}
chm = genjax.choice_map(CHOICE_MAP_ARGS)

# generate trace with model v3 (the importance weight is not used below)
key = jax.random.PRNGKey(key_number)
weight, trace = model_v3_importance_jit(key, chm, tuple(MODEL_ARGS.values()))

# save metadata
metadata = {
    "key_number": key_number,
    "RENDERER_ARGS": RENDERER_ARGS,
    "MODEL_ARGS" : MODEL_ARGS,
    "CHOICE_MAP_ARGS" : CHOICE_MAP_ARGS,
    "score" : trace.get_score(),
    "rendered" : trace.get_retval()[0],
    "model_version" : 3,
    "renderer_setup_version" : 2,
    "init_pose" : trace["init_pose"]
}
save_metadata(metadata, "scenes_old/scene_{}".format(SCENE))


# see depth GT
video_from_trace(trace, rendered_addr = ("depths", "depths"), framerate = 5)

In [None]:
##################################################################################################
###                                         SCENE 6                                            ###
##################################################################################################
# Scene 6: model v4 + renderer setup v2, seed 685988, T = 5. Unlike Scenes 1-5, the initial
# pose/velocity are sampled here, and the sequence is built by calling the single-step model
# v4 T times in a Python loop, feeding each step's (pose, velocity) into the next call.
SCENE = 6

# random PRNG key
key_number = 685988
# renderer args
RENDERER_ARGS = {
    'height': 50,
    'width': 50,
    'focal_length': 250,
    'near': 0.1,
    'far': 20
}
# init renderer (must run before MODEL_ARGS, which reads b.RENDERER)
setup_renderer_and_meshes_v2(**RENDERER_ARGS)

# Time Step
T = 5
# model args (passed positionally via tuple(MODEL_ARGS.values()); key order matters).
# 'pose' and 'velocity' are placeholders overwritten on every loop iteration below.
MODEL_ARGS = {
    'pose' : None,
    'velocity' : None,
    'N_total_vec' : jnp.arange(len(b.RENDERER.meshes)),
    'N_vec': jnp.zeros(1),
    'outlier_volume': jnp.float32(1000.0),
    'vel_params': jnp.array([0.005, 10000.0]),
    'variance_params': jnp.array([0.00000000001, 10000.0]),
    'outlier_prob_params': jnp.array([-0.01, 10000.0])
}

# NOTE(review): the same `key` is used for INIT_POSE, INIT_VELOCITY and for every
# model_v4_importance_jit call in the loop below, so each step reuses identical random draws.
# JAX keys should be split per use (e.g. jax.random.split(key, T)). Left unchanged here because
# fixing it would alter the previously saved ground truth -- confirm before changing.
key = jax.random.PRNGKey(key_number)
INIT_VEL_PARAMS = jnp.array([0.01, 10000.0])
POSE_BOUNDS = jnp.array([[-0.1, -0.1, 1.5], [0.1, 0.1, 2]])
INIT_POSE = b.uniform_pose.sample(key, POSE_BOUNDS[0], POSE_BOUNDS[1])
INIT_VELOCITY = b.gaussian_vmf_pose.sample(key, jnp.eye(4), *INIT_VEL_PARAMS)

# choice map args
CHOICE_MAP_ARGS = {
    'variance': 0.0001,
    'outlier_prob': 0.0001,
    'indices': jnp.array([0])
}
chm = genjax.choice_map(CHOICE_MAP_ARGS)

# Transition through the HMM
pose = INIT_POSE
velocity = INIT_VELOCITY
scores = []
# Frame 0 is rendered directly from the initial pose (keeping the first 3 channels of the last
# axis), so the final video has T + 1 frames.
rendered_images = [b.RENDERER.render(pose[None,...], CHOICE_MAP_ARGS["indices"])[...,:3]]

for t in range(T):
    MODEL_ARGS["pose"] = pose
    MODEL_ARGS["velocity"] = velocity
    _, trace = model_v4_importance_jit(key, chm, tuple(MODEL_ARGS.values()))
    scores.append(trace.get_score())
    # The model's second return value is the (pose, velocity) state for the next step.
    pose, velocity = trace.get_retval()[1]
    rendered_images.append(trace.get_retval()[0])

# Clear the per-step placeholders so the saved MODEL_ARGS doesn't hold only the last step's state.
MODEL_ARGS["pose"] = None
MODEL_ARGS["velocity"] = None

# save metadata
metadata = {
    "T" : T,
    "INIT_VEL_PARAMS":INIT_VEL_PARAMS,
    "POSE_BOUNDS":POSE_BOUNDS,
    "INIT_POSE":INIT_POSE,
    "INIT_VELOCITY" :INIT_VELOCITY,
    "key_number": key_number,
    "RENDERER_ARGS": RENDERER_ARGS,
    "MODEL_ARGS" : MODEL_ARGS,
    "CHOICE_MAP_ARGS" : CHOICE_MAP_ARGS,
    "scores" : scores,
    "rendered" : jnp.stack(rendered_images),
    "model_version" : 4,
    "renderer_setup_version" : 2,
}
save_metadata(metadata, "scenes_old/scene_{}".format(SCENE))


# see depth GT
video_from_rendered(metadata["rendered"], framerate = 5)

In [None]:
##################################################################################################
###                                         SCENE 7                                            ###
##################################################################################################
# Scene 7: same loop-based model v4 setup as Scene 6, but seed 90878, T = 50 and smaller
# velocity parameters (vel_params 0.0005, INIT_VEL_PARAMS 0.001).
SCENE = 7

# random PRNG key
key_number = 90878
# renderer args
RENDERER_ARGS = {
    'height': 50,
    'width': 50,
    'focal_length': 250,
    'near': 0.1,
    'far': 20
}
# init renderer (must run before MODEL_ARGS, which reads b.RENDERER)
setup_renderer_and_meshes_v2(**RENDERER_ARGS)

# Time Step
T = 50
# model args (passed positionally via tuple(MODEL_ARGS.values()); key order matters).
# 'pose' and 'velocity' are placeholders overwritten on every loop iteration below.
MODEL_ARGS = {
    'pose' : None,
    'velocity' : None,
    'N_total_vec' : jnp.arange(len(b.RENDERER.meshes)),
    'N_vec': jnp.zeros(1),
    'outlier_volume': jnp.float32(1000.0),
    'vel_params': jnp.array([0.0005, 10000.0]),
    'variance_params': jnp.array([0.00000000001, 10000.0]),
    'outlier_prob_params': jnp.array([-0.01, 10000.0])
}

# NOTE(review): as in Scene 6, one `key` is reused for INIT_POSE, INIT_VELOCITY and every
# step of the loop below, so all 50 steps share identical random draws. Splitting the key
# would change the saved ground truth -- confirm before fixing.
key = jax.random.PRNGKey(key_number)
INIT_VEL_PARAMS = jnp.array([0.001, 10000.0])
POSE_BOUNDS = jnp.array([[-0.1, -0.1, 1.5], [0.1, 0.1, 2]])
INIT_POSE = b.uniform_pose.sample(key, POSE_BOUNDS[0], POSE_BOUNDS[1])
INIT_VELOCITY = b.gaussian_vmf_pose.sample(key, jnp.eye(4), *INIT_VEL_PARAMS)

# choice map args
CHOICE_MAP_ARGS = {
    'variance': 0.0001,
    'outlier_prob': 0.0001,
    'indices': jnp.array([0])
}
chm = genjax.choice_map(CHOICE_MAP_ARGS)

# Transition through the HMM
pose = INIT_POSE
velocity = INIT_VELOCITY
scores = []
# Frame 0 is rendered directly from the initial pose; the loop appends T more frames.
rendered_images = [b.RENDERER.render(pose[None,...], CHOICE_MAP_ARGS["indices"])[...,:3]]

for t in range(T):
    MODEL_ARGS["pose"] = pose
    MODEL_ARGS["velocity"] = velocity
    _, trace = model_v4_importance_jit(key, chm, tuple(MODEL_ARGS.values()))
    scores.append(trace.get_score())
    # The model's second return value is the (pose, velocity) state for the next step.
    pose, velocity = trace.get_retval()[1]
    rendered_images.append(trace.get_retval()[0])

# Clear the per-step placeholders before MODEL_ARGS is saved.
MODEL_ARGS["pose"] = None
MODEL_ARGS["velocity"] = None

# save metadata
metadata = {
    "T" : T,
    "INIT_VEL_PARAMS":INIT_VEL_PARAMS,
    "POSE_BOUNDS":POSE_BOUNDS,
    "INIT_POSE":INIT_POSE,
    "INIT_VELOCITY" :INIT_VELOCITY,
    "key_number": key_number,
    "RENDERER_ARGS": RENDERER_ARGS,
    "MODEL_ARGS" : MODEL_ARGS,
    "CHOICE_MAP_ARGS" : CHOICE_MAP_ARGS,
    "scores" : scores,
    "rendered" : jnp.stack(rendered_images),
    "model_version" : 4,
    "renderer_setup_version" : 2,
}
save_metadata(metadata, "scenes_old/scene_{}".format(SCENE))


# see depth GT
video_from_rendered(metadata["rendered"], framerate = 5)

In [None]:
##################################################################################################
###                                         SCENE 8                                            ###
##################################################################################################
# Scene 8: model v5 wrapped in a genjax UnfoldCombinator (T + 1 unrolled steps) instead of a
# Python loop. Seed 90878, renderer setup v2, T = 50; x/y initial-pose lower bounds are 0.0.
SCENE = 8

# random PRNG key
key_number = 90878
# renderer args
RENDERER_ARGS = {
    'height': 50,
    'width': 50,
    'focal_length': 250,
    'near': 0.1,
    'far': 20
}
# init renderer (must run before MODEL_ARGS, which reads b.RENDERER)
setup_renderer_and_meshes_v2(**RENDERER_ARGS)

# Time Step
T = 50
# model args (splatted positionally after (T, INIT_STATE) in the importance call; order matters)
MODEL_ARGS = {
    'N_total_vec' : jnp.arange(len(b.RENDERER.meshes)),
    'N_vec': jnp.zeros(1),
    'outlier_volume': jnp.float32(1000.0),
    'vel_params': jnp.array([0.0005, 10000.0]),
    'variance_params': jnp.array([0.00000000001, 10000.0]),
    'outlier_prob_params': jnp.array([-0.01, 10000.0])
}

# NOTE(review): `key` is used for INIT_POSE, INIT_VELOCITY and the unfold importance call;
# JAX recommends a separate (split) key per use. Left as-is to keep saved scenes reproducible.
key = jax.random.PRNGKey(key_number)
INIT_VEL_PARAMS = jnp.array([0.001, 10000.0])
POSE_BOUNDS = jnp.array([[-0.0, -0.0, 1.5], [0.1, 0.1, 2]])
INIT_POSE = b.uniform_pose.sample(key, POSE_BOUNDS[0], POSE_BOUNDS[1])
INIT_VELOCITY = b.gaussian_vmf_pose.sample(key, jnp.eye(4), *INIT_VEL_PARAMS)

# choice map args: one value per unfold index 0..T (variance/outlier_prob have shape (T+1,),
# indices has shape (T+1, 1)).
CHOICE_MAP_ARGS = {
    'variance': jnp.repeat(0.0001,T+1),
    'outlier_prob': jnp.repeat(0.0001,T+1),
    'indices': jnp.tile(jnp.array([[0]]), (T+1,1))
}
chm = genjax.index_choice_map(
        jnp.arange(0,T+1), genjax.choice_map(
            CHOICE_MAP_ARGS
        )
    )

# First rendered image is known
rendered_images = [b.RENDERER.render(INIT_POSE[None,...], CHOICE_MAP_ARGS["indices"][0])[...,:3]]

# unfold the model
# The initial unfold state is (rendered image, pose, velocity).
INIT_STATE = (rendered_images[0], INIT_POSE, INIT_VELOCITY)
model_v5_unfold = genjax.UnfoldCombinator.new(model_v5, T+1)

_, trace = jax.jit(model_v5_unfold.importance)(key, chm, (T, INIT_STATE, *tuple(MODEL_ARGS.values())))

# The unfold's return value holds the rendered frames at [0] and the poses at [1].
rendered_images = trace.get_retval()[0]
poses = trace.get_retval()[1]
score = trace.get_score()

# save metadata
metadata = {
    "T" : T,
    "INIT_VEL_PARAMS":INIT_VEL_PARAMS,
    "POSE_BOUNDS":POSE_BOUNDS,
    "INIT_POSE":INIT_POSE,
    "INIT_VELOCITY" :INIT_VELOCITY,
    "key_number": key_number,
    "RENDERER_ARGS": RENDERER_ARGS,
    "MODEL_ARGS" : MODEL_ARGS,
    "CHOICE_MAP_ARGS" : CHOICE_MAP_ARGS,
    "poses" : poses,
    "score" : score,
    "rendered" : rendered_images,
    "model_version" : 5,
    "renderer_setup_version" : 2,
}
save_metadata(metadata, "scenes_old/scene_{}".format(SCENE))


# see depth GT
video_from_rendered(metadata["rendered"], framerate = 24)

In [None]:
##################################################################################################
###                                         SCENE 8b                                           ###
##################################################################################################
# Scene 8b: variant of Scene 8 using model_v5b, whose unfold state here has no velocity
# component: INIT_STATE is (rendered image, pose). vel_params are [0.01, 100.0].
SCENE = '8b'

# random PRNG key
key_number = 90878
# renderer args
RENDERER_ARGS = {
    'height': 50,
    'width': 50,
    'focal_length': 250,
    'near': 0.1,
    'far': 20
}
# init renderer (must run before MODEL_ARGS, which reads b.RENDERER)
setup_renderer_and_meshes_v2(**RENDERER_ARGS)

# Time Step
T = 50
# model args (splatted positionally after (T, INIT_STATE) in the importance call; order matters)
MODEL_ARGS = {
    'N_total_vec' : jnp.arange(len(b.RENDERER.meshes)),
    'N_vec': jnp.zeros(1),
    'outlier_volume': jnp.float32(1000.0),
    'vel_params': jnp.array([0.01, 100.0]),
    'variance_params': jnp.array([0.00000000001, 10000.0]),
    'outlier_prob_params': jnp.array([-0.01, 10000.0])
}

# NOTE(review): the same `key` seeds INIT_POSE and the unfold importance call.
key = jax.random.PRNGKey(key_number)
POSE_BOUNDS = jnp.array([[-0.0, -0.0, 1.5], [0.1, 0.1, 2]])
INIT_POSE = b.uniform_pose.sample(key, POSE_BOUNDS[0], POSE_BOUNDS[1])

# choice map args: one value per unfold index 0..T.
CHOICE_MAP_ARGS = {
    'variance': jnp.repeat(0.0001,T+1),
    'outlier_prob': jnp.repeat(0.0001,T+1),
    'indices': jnp.tile(jnp.array([[0]]), (T+1,1))
}
chm = genjax.index_choice_map(
        jnp.arange(0,T+1), genjax.choice_map(
            CHOICE_MAP_ARGS
        )
    )

# First rendered image is known
rendered_images = [b.RENDERER.render(INIT_POSE[None,...], CHOICE_MAP_ARGS["indices"][0])[...,:3]]

# unfold the model (2-element state: rendered image, pose)
INIT_STATE = (rendered_images[0], INIT_POSE)
model_v5b_unfold = genjax.UnfoldCombinator.new(model_v5b, T+1)

_, trace = jax.jit(model_v5b_unfold.importance)(key, chm, (T, INIT_STATE, *tuple(MODEL_ARGS.values())))

rendered_images = trace.get_retval()[0]
poses = trace.get_retval()[1]
score = trace.get_score()

# save metadata (no velocity entries: this scene samples no initial velocity)
metadata = {
    "T" : T,
    "POSE_BOUNDS":POSE_BOUNDS,
    "INIT_POSE":INIT_POSE,
    "key_number": key_number,
    "RENDERER_ARGS": RENDERER_ARGS,
    "MODEL_ARGS" : MODEL_ARGS,
    "CHOICE_MAP_ARGS" : CHOICE_MAP_ARGS,
    "poses" : poses,
    "score" : score,
    "rendered" : rendered_images,
    "model_version" : '5b',
    "renderer_setup_version" : 2,
}
save_metadata(metadata, "scenes_old/scene_{}".format(SCENE))


# see depth GT
video_from_rendered(metadata["rendered"], framerate = 24)

In [None]:
##################################################################################################
###                                         SCENE 9                                            ###
##################################################################################################
# Scene 9: model_v6 traced as a single generative function (no UnfoldCombinator) with
# renderer setup v2. The length of MODEL_ARGS['T'] (presumably the number of time steps) is
# T = 50. Metadata goes to "scenes_old/scene_9".
SCENE = 9

# random PRNG key
# (a candidate seed, 3434565, used to be assigned here and was immediately overwritten by
# this one; only 90868576 was ever used, so the dead assignment has been dropped)
key_number = 90868576
# renderer args
RENDERER_ARGS = {
    'height': 50,
    'width': 50,
    'focal_length': 250,
    'near': 0.1,
    'far': 20
}
# init renderer
setup_renderer_and_meshes_v2(**RENDERER_ARGS)

# Time Step
T = 50
# model args (passed positionally via tuple(MODEL_ARGS.values()); key order matters).
# NOTE(review): vel_params has three entries here, unlike the two-entry vectors used with
# models v4/v5 -- its meaning is defined by model_v6.
MODEL_ARGS = {
    'T' : jnp.zeros(T),
    'outlier_volume': jnp.float32(1000.0),
    'vel_params': jnp.array([0.01, 0.01, 0]),
    'variance_params': jnp.array([0.00000000001, 10000.0]),
    'outlier_prob_params': jnp.array([-0.01, 10000.0])
}

key = jax.random.PRNGKey(key_number)

# choice map args
CHOICE_MAP_ARGS = {
    'variance': 0.0001,
    'outlier_prob': 0.0001
}
chm = genjax.choice_map(CHOICE_MAP_ARGS)

_, trace = jax.jit(model_v6.importance)(key, chm, tuple(MODEL_ARGS.values()))

rendered_images = trace.get_retval()[0]
score = trace.get_score()

# save metadata
metadata = {
    "T" : T,
    "key_number": key_number,
    "RENDERER_ARGS": RENDERER_ARGS,
    "MODEL_ARGS" : MODEL_ARGS,
    "CHOICE_MAP_ARGS" : CHOICE_MAP_ARGS,
    "score" : score,
    "rendered" : rendered_images,
    "model_version" : 6,
    "renderer_setup_version" : 2,
}
save_metadata(metadata, "scenes_old/scene_{}".format(SCENE))


# see depth GT
video_from_rendered(metadata["rendered"], framerate = 10)

In [None]:
# Build the model_v5 unfold (T + 1 unrolled steps) and JIT its importance method once, so that
# later cells (Scene 10) can reuse the same compiled callable. Scene 10 also relies on the
# global `T` set here, since it does not define T itself.
# Time Step
T = 50
model_v5_unfold = genjax.UnfoldCombinator.new(model_v5, T+1)
model_v5_unfold_importance_jit = jax.jit(model_v5_unfold.importance)

In [None]:
##################################################################################################
###                                         SCENE 10                                            ###
##################################################################################################
# Scene 10: model v5 unfold with near-static dynamics (vel_params [0.000005, 1e7], zero-mean
# initial velocity) and a narrow initial-pose x/y range [0, 0.01]. Seed 12345.
# Uses `T` and `model_v5_unfold_importance_jit` from the preceding cell -- run that cell first.
SCENE = 10

# random PRNG key
key_number = 12345
# renderer args
RENDERER_ARGS = {
    'height': 50,
    'width': 50,
    'focal_length': 250,
    'near': 0.1,
    'far': 20
}
# init renderer (must run before MODEL_ARGS, which reads b.RENDERER)
setup_renderer_and_meshes_v2(**RENDERER_ARGS)


# model args (splatted positionally after (T, INIT_STATE) in the importance call; order matters)
MODEL_ARGS = {
    'N_total_vec' : jnp.arange(len(b.RENDERER.meshes)),
    'N_vec': jnp.zeros(1),
    'outlier_volume': jnp.float32(1000.0),
    'vel_params': jnp.array([0.000005, 10000000.0]),
    'variance_params': jnp.array([0.00000000001, 10000.0]),
    'outlier_prob_params': jnp.array([-0.01, 10000.0])
}

# NOTE(review): `key` is shared by INIT_POSE, INIT_VELOCITY and the importance call.
key = jax.random.PRNGKey(key_number)
INIT_VEL_PARAMS = jnp.array([0.000, 100000.0])
POSE_BOUNDS = jnp.array([[-0.0, -0.0, 1.5], [0.01, 0.01, 2]])
INIT_POSE = b.uniform_pose.sample(key, POSE_BOUNDS[0], POSE_BOUNDS[1])
INIT_VELOCITY = b.gaussian_vmf_pose.sample(key, jnp.eye(4), *INIT_VEL_PARAMS)

# choice map args: one value per unfold index 0..T.
CHOICE_MAP_ARGS = {
    'variance': jnp.repeat(0.0001,T+1),
    'outlier_prob': jnp.repeat(0.0001,T+1),
    'indices': jnp.tile(jnp.array([[0]]), (T+1,1))
}
chm = genjax.index_choice_map(
        jnp.arange(0,T+1), genjax.choice_map(
            CHOICE_MAP_ARGS
        )
    )

# First rendered image is known
rendered_images = [b.RENDERER.render(INIT_POSE[None,...], CHOICE_MAP_ARGS["indices"][0])[...,:3]]

# unfold the model
INIT_STATE = (rendered_images[0], INIT_POSE, INIT_VELOCITY)


# Reuses the jitted unfold built in the previous cell.
_, trace = model_v5_unfold_importance_jit(key, chm, (T, INIT_STATE, *tuple(MODEL_ARGS.values())))

rendered_images = trace.get_retval()[0]
poses = trace.get_retval()[1]
score = trace.get_score()

# save metadata
metadata = {
    "T" : T,
    "INIT_VEL_PARAMS":INIT_VEL_PARAMS,
    "POSE_BOUNDS":POSE_BOUNDS,
    "INIT_POSE":INIT_POSE,
    "INIT_VELOCITY" :INIT_VELOCITY,
    "key_number": key_number,
    "RENDERER_ARGS": RENDERER_ARGS,
    "MODEL_ARGS" : MODEL_ARGS,
    "CHOICE_MAP_ARGS" : CHOICE_MAP_ARGS,
    "poses" : poses,
    "score" : score,
    "rendered" : rendered_images,
    "model_version" : 5,
    "renderer_setup_version" : 2,
}
save_metadata(metadata, "scenes_old/scene_{}".format(SCENE))


# see depth GT
video_from_rendered(metadata["rendered"], framerate = 24)

In [None]:
##################################################################################################
###                                         SCENE 11                                            ###
##################################################################################################
# Scene 11: same arguments and seed as Scene 8 (model v5 unfold, T = 50) but rendered with
# renderer setup v3 (setup_renderer_and_meshes_v3).
SCENE = 11

# random PRNG key
key_number = 90878
# renderer args
RENDERER_ARGS = {
    'height': 50,
    'width': 50,
    'focal_length': 250,
    'near': 0.1,
    'far': 20
}
# init renderer (must run before MODEL_ARGS, which reads b.RENDERER)
setup_renderer_and_meshes_v3(**RENDERER_ARGS)

# Time Step
T = 50
# model args (splatted positionally after (T, INIT_STATE) in the importance call; order matters)
MODEL_ARGS = {
    'N_total_vec' : jnp.arange(len(b.RENDERER.meshes)),
    'N_vec': jnp.zeros(1),
    'outlier_volume': jnp.float32(1000.0),
    'vel_params': jnp.array([0.0005, 10000.0]),
    'variance_params': jnp.array([0.00000000001, 10000.0]),
    'outlier_prob_params': jnp.array([-0.01, 10000.0])
}

# NOTE(review): `key` is shared by INIT_POSE, INIT_VELOCITY and the importance call.
key = jax.random.PRNGKey(key_number)
INIT_VEL_PARAMS = jnp.array([0.001, 10000.0])
POSE_BOUNDS = jnp.array([[-0.0, -0.0, 1.5], [0.1, 0.1, 2]])
INIT_POSE = b.uniform_pose.sample(key, POSE_BOUNDS[0], POSE_BOUNDS[1])
INIT_VELOCITY = b.gaussian_vmf_pose.sample(key, jnp.eye(4), *INIT_VEL_PARAMS)

# choice map args: one value per unfold index 0..T.
CHOICE_MAP_ARGS = {
    'variance': jnp.repeat(0.0001,T+1),
    'outlier_prob': jnp.repeat(0.0001,T+1),
    'indices': jnp.tile(jnp.array([[0]]), (T+1,1))
}
chm = genjax.index_choice_map(
        jnp.arange(0,T+1), genjax.choice_map(
            CHOICE_MAP_ARGS
        )
    )

# First rendered image is known
rendered_images = [b.RENDERER.render(INIT_POSE[None,...], CHOICE_MAP_ARGS["indices"][0])[...,:3]]

# unfold the model
INIT_STATE = (rendered_images[0], INIT_POSE, INIT_VELOCITY)
model_v5_unfold = genjax.UnfoldCombinator.new(model_v5, T+1)

_, trace = jax.jit(model_v5_unfold.importance)(key, chm, (T, INIT_STATE, *tuple(MODEL_ARGS.values())))

rendered_images = trace.get_retval()[0]
poses = trace.get_retval()[1]
score = trace.get_score()

# save metadata
metadata = {
    "T" : T,
    "INIT_VEL_PARAMS":INIT_VEL_PARAMS,
    "POSE_BOUNDS":POSE_BOUNDS,
    "INIT_POSE":INIT_POSE,
    "INIT_VELOCITY" :INIT_VELOCITY,
    "key_number": key_number,
    "RENDERER_ARGS": RENDERER_ARGS,
    "MODEL_ARGS" : MODEL_ARGS,
    "CHOICE_MAP_ARGS" : CHOICE_MAP_ARGS,
    "poses" : poses,
    "score" : score,
    "rendered" : rendered_images,
    "model_version" : 5,
    "renderer_setup_version" : 3,
}
save_metadata(metadata, "scenes_old/scene_{}".format(SCENE))


# see depth GT
video_from_rendered(metadata["rendered"], framerate = 24)

In [None]:
##################################################################################################
###                                         SCENE 11b   VERY CUSTOMIZED DO NOT COPY            ###
##################################################################################################
# Scene 11b: hand-tuned variant of Scene 11 (renderer setup v3, model v5 unfold, seed 4567)
# with very large velocity concentration parameters (1e13 / 1e14).
SCENE = '11b'

# random PRNG key
key_number = 4567
# renderer args
RENDERER_ARGS = {
    'height': 50,
    'width': 50,
    'focal_length': 250,
    'near': 0.1,
    'far': 20
}
# init renderer (must run before MODEL_ARGS, which reads b.RENDERER)
setup_renderer_and_meshes_v3(**RENDERER_ARGS)

# Time Step
T = 50
# model args (splatted positionally after (T, INIT_STATE) in the importance call; order matters)
MODEL_ARGS = {
    'N_total_vec' : jnp.arange(len(b.RENDERER.meshes)),
    'N_vec': jnp.zeros(1),
    'outlier_volume': jnp.float32(1000.0),
    'vel_params': jnp.array([0.0007, 10000000000000.0]),
    'variance_params': jnp.array([0.00000000001, 10000.0]),
    'outlier_prob_params': jnp.array([-0.01, 10000.0])
}

key = jax.random.PRNGKey(key_number)
INIT_VEL_PARAMS = jnp.array([0.001, 100000000000000.0])
# NOTE(review): the x/y "upper" bounds (-0.2) are below the "lower" bounds (0.0). Whether
# b.uniform_pose accepts inverted bounds (e.g. sampling in [-0.2, 0]) depends on its
# implementation; presumably deliberate given the DO NOT COPY banner -- confirm.
POSE_BOUNDS = jnp.array([[-0.0, -0.0, 1.5], [-0.2, -0.2, 2]])
INIT_POSE = b.uniform_pose.sample(key, POSE_BOUNDS[0], POSE_BOUNDS[1])
INIT_VELOCITY = b.gaussian_vmf_pose.sample(key, jnp.eye(4), *INIT_VEL_PARAMS)

# choice map args: one value per unfold index 0..T.
CHOICE_MAP_ARGS = {
    'variance': jnp.repeat(0.0001,T+1),
    'outlier_prob': jnp.repeat(0.0001,T+1),
    'indices': jnp.tile(jnp.array([[0]]), (T+1,1))
}
chm = genjax.index_choice_map(
        jnp.arange(0,T+1), genjax.choice_map(
            CHOICE_MAP_ARGS
        )
    )

# First rendered image is known
rendered_images = [b.RENDERER.render(INIT_POSE[None,...], CHOICE_MAP_ARGS["indices"][0])[...,:3]]

# unfold the model
INIT_STATE = (rendered_images[0], INIT_POSE, INIT_VELOCITY)
model_v5_unfold = genjax.UnfoldCombinator.new(model_v5, T+1)

_, trace = jax.jit(model_v5_unfold.importance)(key, chm, (T, INIT_STATE, *tuple(MODEL_ARGS.values())))

rendered_images = trace.get_retval()[0]
poses = trace.get_retval()[1]
score = trace.get_score()

# save metadata
metadata = {
    "T" : T,
    "INIT_VEL_PARAMS":INIT_VEL_PARAMS,
    "POSE_BOUNDS":POSE_BOUNDS,
    "INIT_POSE":INIT_POSE,
    "INIT_VELOCITY" :INIT_VELOCITY,
    "key_number": key_number,
    "RENDERER_ARGS": RENDERER_ARGS,
    "MODEL_ARGS" : MODEL_ARGS,
    "CHOICE_MAP_ARGS" : CHOICE_MAP_ARGS,
    "poses" : poses,
    "score" : score,
    "rendered" : rendered_images,
    "model_version" : 5,
    "renderer_setup_version" : 3,
}
save_metadata(metadata, "scenes_old/scene_{}".format(SCENE))


# see depth GT
video_from_rendered(metadata["rendered"], framerate = 24)

In [None]:
##################################################################################################
###                                         SCENE 12                                            ###
##################################################################################################
# Scene 12: same arguments and seed as Scene 8 (renderer setup v2, T = 50) but unfolding
# model_v5b instead of model_v5.
SCENE = 12

# random PRNG key
key_number = 90878
# renderer args
RENDERER_ARGS = {
    'height': 50,
    'width': 50,
    'focal_length': 250,
    'near': 0.1,
    'far': 20
}
# init renderer (must run before MODEL_ARGS, which reads b.RENDERER)
setup_renderer_and_meshes_v2(**RENDERER_ARGS)

# Time Step
T = 50
# model args (splatted positionally after (T, INIT_STATE) in the importance call; order matters)
MODEL_ARGS = {
    'N_total_vec' : jnp.arange(len(b.RENDERER.meshes)),
    'N_vec': jnp.zeros(1),
    'outlier_volume': jnp.float32(1000.0),
    'vel_params': jnp.array([0.0005, 10000.0]),
    'variance_params': jnp.array([0.00000000001, 10000.0]),
    'outlier_prob_params': jnp.array([-0.01, 10000.0]),
}

# NOTE(review): `key` is shared by INIT_POSE, INIT_VELOCITY and the importance call.
key = jax.random.PRNGKey(key_number)
INIT_VEL_PARAMS = jnp.array([0.001, 10000.0])
POSE_BOUNDS = jnp.array([[-0.0, -0.0, 1.5], [0.1, 0.1, 2]])
INIT_POSE = b.uniform_pose.sample(key, POSE_BOUNDS[0], POSE_BOUNDS[1])
INIT_VELOCITY = b.gaussian_vmf_pose.sample(key, jnp.eye(4), *INIT_VEL_PARAMS)

# choice map args: one value per unfold index 0..T.
CHOICE_MAP_ARGS = {
    'variance': jnp.repeat(0.0001,T+1),
    'outlier_prob': jnp.repeat(0.0001,T+1),
    'indices': jnp.tile(jnp.array([[0]]), (T+1,1))
}
chm = genjax.index_choice_map(
        jnp.arange(0,T+1), genjax.choice_map(
            CHOICE_MAP_ARGS
        )
    )

# First rendered image is known
rendered_images = [b.RENDERER.render(INIT_POSE[None,...], CHOICE_MAP_ARGS["indices"][0])[...,:3]]

# unfold the model
# NOTE(review): Scene 8b unfolds model_v5b with a 2-element state (rendered image, pose), but
# this scene passes a 3-element state that also includes INIT_VELOCITY -- confirm which state
# layout model_v5b currently expects.
INIT_STATE = (rendered_images[0], INIT_POSE, INIT_VELOCITY)
model_v5b_unfold = genjax.UnfoldCombinator.new(model_v5b, T+1)

_, trace = jax.jit(model_v5b_unfold.importance)(key, chm, (T, INIT_STATE, *tuple(MODEL_ARGS.values())))

rendered_images = trace.get_retval()[0]
poses = trace.get_retval()[1]
score = trace.get_score()

# save metadata
metadata = {
    "T" : T,
    "INIT_VEL_PARAMS":INIT_VEL_PARAMS,
    "POSE_BOUNDS":POSE_BOUNDS,
    "INIT_POSE":INIT_POSE,
    "INIT_VELOCITY" :INIT_VELOCITY,
    "key_number": key_number,
    "RENDERER_ARGS": RENDERER_ARGS,
    "MODEL_ARGS" : MODEL_ARGS,
    "CHOICE_MAP_ARGS" : CHOICE_MAP_ARGS,
    "poses" : poses,
    "score" : score,
    "rendered" : rendered_images,
    "model_version" : '5b',
    "renderer_setup_version" : 2,
}
save_metadata(metadata, "scenes_old/scene_{}".format(SCENE))


# see depth GT
video_from_rendered(metadata["rendered"], framerate = 24)

In [None]:
##################################################################################################
###                                         SCENE 13                                            ###
##################################################################################################
# Ground-truth scene generated with model_v5c on the v2 renderer setup.
SCENE = 13

# random PRNG key
key_number = 90878
# renderer args
RENDERER_ARGS = {
    'height': 50,
    'width': 50,
    'focal_length': 250,
    'near': 0.1,
    'far': 20
}
# init renderer
setup_renderer_and_meshes_v2(**RENDERER_ARGS)

# Time Step (the unfold combinator below is built with length T + 1)
T = 50
# model args -- insertion order matters: the values are forwarded positionally
# to the model right after (T, INIT_STATE).
MODEL_ARGS = {
    'N_total_vec' : jnp.arange(len(b.RENDERER.meshes)),
    'N_vec': jnp.zeros(1),
    'outlier_volume': jnp.float32(1000.0),
    'vel_params': jnp.array([0.0005, 10000.0]),
    'variance_params': jnp.array([0.00000000001, 10000.0]),
    'outlier_prob_params': jnp.array([-0.01, 10000.0]),
}

# NOTE(review): the same `key` seeds INIT_POSE, INIT_VELOCITY and the
# importance call below. JAX recommends splitting keys rather than reusing
# them; left unchanged so re-running reproduces the saved ground truth.
key = jax.random.PRNGKey(key_number)
INIT_VEL_PARAMS = jnp.array([0.001, 10000.0])
POSE_BOUNDS = jnp.array([[-0.0, -0.0, 1.5], [0.1, 0.1, 2]])
INIT_POSE = b.uniform_pose.sample(key, POSE_BOUNDS[0], POSE_BOUNDS[1])
INIT_VELOCITY = b.gaussian_vmf_pose.sample(key, jnp.eye(4), *INIT_VEL_PARAMS)

# choice map args: one observed value per unfold step (T + 1 of them)
CHOICE_MAP_ARGS = {
    'variance': jnp.repeat(0.0001,T+1),
    'outlier_prob': jnp.repeat(0.0001,T+1),
    'indices': jnp.tile(jnp.array([[0]]), (T+1,1))
}
chm = genjax.index_choice_map(
        jnp.arange(0,T+1), genjax.choice_map(
            CHOICE_MAP_ARGS
        )
    )

# First rendered image is known
rendered_images = [b.RENDERER.render(INIT_POSE[None,...], CHOICE_MAP_ARGS["indices"][0])[...,:3]]

# unfold the model
INIT_STATE = (rendered_images[0], INIT_POSE, INIT_VELOCITY)
# Build and call the v5c combinator under a single name. The metadata below
# records model_version '5c', so this cell must not pick up a combinator that
# some other cell happened to leave behind.
model_v5c_unfold = genjax.UnfoldCombinator.new(model_v5c, T+1)

_, trace = jax.jit(model_v5c_unfold.importance)(key, chm, (T, INIT_STATE, *tuple(MODEL_ARGS.values())))

# retval[0] is stored as the rendered frames, retval[1] as the poses.
rendered_images = trace.get_retval()[0]
poses = trace.get_retval()[1]
score = trace.get_score()

# save metadata
metadata = {
    "T" : T,
    "INIT_VEL_PARAMS":INIT_VEL_PARAMS,
    "POSE_BOUNDS":POSE_BOUNDS,
    "INIT_POSE":INIT_POSE,
    "INIT_VELOCITY" :INIT_VELOCITY,
    "key_number": key_number,
    "RENDERER_ARGS": RENDERER_ARGS,
    "MODEL_ARGS" : MODEL_ARGS,
    "CHOICE_MAP_ARGS" : CHOICE_MAP_ARGS,
    "poses" : poses,
    "score" : score,
    "rendered" : rendered_images,
    "model_version" : '5c',
    "renderer_setup_version" : 2,
}
save_metadata(metadata, "scenes_old/scene_{}".format(SCENE))


# see depth GT
video_from_rendered(metadata["rendered"], framerate = 24)

In [None]:
# Switch to the v1 renderer / mesh set and display the names of the loaded meshes.
RENDERER_ARGS = dict(height=50, width=50, focal_length=250, near=0.1, far=20)

setup_renderer_and_meshes_v1(**RENDERER_ARGS)
b.RENDERER.mesh_names

In [None]:
# Print one "<index> : <mesh name>" line per loaded mesh.
for i, x in enumerate(b.RENDERER.mesh_names):
    print(i, ":", x)

In [None]:
##################################################################################################
###                                         SCENE 14  MULTI-USE                                ###
##################################################################################################
# Reusable cell: edit SCENE / key_number / rend_id and re-run to generate another
# single-object scene with model_v5.
# NOTE(review): this cell does not set up a renderer itself; it uses whatever
# b.RENDERER and RENDERER_ARGS are current (the metadata below records
# renderer_setup_version 1, i.e. the v1 setup cell is expected to have run).
SCENE = '14a'

# PRNG seed and the mesh index to render.
key_number = 1234567
rend_id = 25

# Number of time steps; the unfold combinator below is built with length T + 1.
T = 50

# Model arguments. Insertion order matters: the values are forwarded
# positionally to the model right after (T, INIT_STATE).
MODEL_ARGS = dict(
    N_total_vec=jnp.arange(len(b.RENDERER.meshes)),
    N_vec=jnp.zeros(1),
    outlier_volume=jnp.float32(1000.0),
    vel_params=jnp.array([0.0005, 10000.0]),
    variance_params=jnp.array([1e-11, 10000.0]),
    outlier_prob_params=jnp.array([-0.01, 10000.0]),
)

# NOTE(review): the same `key` seeds INIT_POSE, INIT_VELOCITY and the
# importance call below. JAX recommends splitting keys rather than reusing
# them; left unchanged so re-running reproduces the saved ground truth.
key = jax.random.PRNGKey(key_number)
INIT_VEL_PARAMS = jnp.array([0.001, 10000.0])
# Both rows have x = y = 0, so the sampled position presumably varies only in z.
POSE_BOUNDS = jnp.array([[0.0, 0, 1.5], [0.0, 0.0, 2]])
INIT_POSE = b.uniform_pose.sample(key, *POSE_BOUNDS)
INIT_VELOCITY = b.gaussian_vmf_pose.sample(key, jnp.eye(4), *INIT_VEL_PARAMS)

# Observed values for every unfold step (T + 1 of them); mesh `rend_id` throughout.
CHOICE_MAP_ARGS = {
    'variance': jnp.repeat(0.0001, T + 1),
    'outlier_prob': jnp.repeat(0.0001, T + 1),
    'indices': jnp.tile(jnp.array([[rend_id]]), (T + 1, 1)),
}
chm = genjax.index_choice_map(jnp.arange(T + 1), genjax.choice_map(CHOICE_MAP_ARGS))

# Render the first frame directly; it is part of the initial unfold state.
rendered_images = [
    b.RENDERER.render(INIT_POSE[None, ...], CHOICE_MAP_ARGS["indices"][0])[..., :3]
]
INIT_STATE = (rendered_images[0], INIT_POSE, INIT_VELOCITY)

# Unfold model_v5 over time and run importance with `chm` as constraints.
model_v5_unfold = genjax.UnfoldCombinator.new(model_v5, T + 1)
_, trace = jax.jit(model_v5_unfold.importance)(
    key, chm, (T, INIT_STATE, *MODEL_ARGS.values())
)

# retval[0] is stored as the rendered frames, retval[1] as the poses.
rendered_images = trace.get_retval()[0]
poses = trace.get_retval()[1]
score = trace.get_score()

# Persist everything needed to reproduce / inspect this scene; force_save lets
# the cell be re-run over an existing scene directory.
metadata = {
    "T": T,
    "INIT_VEL_PARAMS": INIT_VEL_PARAMS,
    "POSE_BOUNDS": POSE_BOUNDS,
    "INIT_POSE": INIT_POSE,
    "INIT_VELOCITY": INIT_VELOCITY,
    "key_number": key_number,
    "RENDERER_ARGS": RENDERER_ARGS,
    "MODEL_ARGS": MODEL_ARGS,
    "CHOICE_MAP_ARGS": CHOICE_MAP_ARGS,
    "poses": poses,
    "score": score,
    "rendered": rendered_images,
    "model_version": 5,
    "renderer_setup_version": 1,
}
save_metadata(metadata, f"scenes_old/scene_{SCENE}", force_save=True)

# Preview the ground-truth frames as a video.
video_from_rendered(metadata["rendered"], framerate=24)

In [None]:
# simple physics scene gen
# Self-contained cell: re-imports everything so it can run on a fresh kernel.
# It renders a cube translating in -x behind a static occluder, without any
# generative model.
import os
import sys
sys.path.append("../../")
import jax
import genjax
import bayes3d as b
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
from models import *
from utils import *
from viz import *
from renderer_setup import *
import pickle
from dataclasses import dataclass
from scipy.spatial.transform import Rotation as R

from genjax.generative_functions.distributions import ExactDensity

console = genjax.pretty()

# cx is intentionally not width / 2. Assuming the usual pinhole projection
# u = fx * x / z + cx, the cube's path x in [0.35, -0.15] at z = 2 maps to
# roughly u in [69, 6], i.e. the sweep stays inside the 80-px-wide image.
intrinsics = b.Intrinsics(
    height=50,
    width=80,
    fx=250.0, fy=250.0,
    cx=25.0, cy=25.0,
    near=0.1, far=20.0
)

b.setup_renderer(intrinsics)
# Mesh order defines the ids used by render_many below: cube first, occluder second.
b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(),"sample_objs/cube.obj"),scaling_factor=0.05)
b.RENDERER.add_mesh_from_file(os.path.join(b.utils.get_assets_dir(),"sample_objs/occulder.obj"),scaling_factor=0.05)

num_frames = 51

# Cube trajectory: start at x = 0.35 and translate by -0.01 in x every frame.
poses = [b.t3d.transform_from_pos(jnp.array([0.35, 0.05, 2]))]
delta_pose = b.t3d.transform_from_pos(jnp.array([-0.01, 0.0, 0.0]))

for t in range(num_frames-1):
    poses.append(poses[-1].dot(delta_pose))
poses = jnp.stack(poses)
print("Number of frames: ", poses.shape[0])

# Static occluder, rotated 90 degrees about x and placed in front of the cube's path.
occ_pose = b.t3d.transform_from_rot_and_pos(
    R.from_euler('zyx', [0, 0, 90], degrees=True).as_matrix(),
    jnp.array([0.1, 0.0, 1.7])
)

# Repeat the occluder pose once per frame so it lines up with `poses`
# (tied to num_frames rather than a hard-coded frame count).
occ_poses = jnp.tile(occ_pose[None,...],(num_frames,1,1))

# Shape (num_frames, 2, ...): per frame, [cube pose, occluder pose].
all_poses = jnp.stack([poses, occ_poses], axis=1)

rendered_images = b.RENDERER.render_many(all_poses,  jnp.array([0,1]))

# save metadata
metadata = {
    # NOTE(review): this is the frame count (51); the model-based scene cells
    # store T = number of frames - 1. Kept as-is for existing consumers.
    "T" : num_frames,
    "occ_poses":occ_poses,
    "poses":poses,
    # No generative model is used for this scene, so there are no model args
    # to record. The key is kept (as None) so loaders expecting it still find it,
    # instead of saving another cell's stale MODEL_ARGS (undefined on a fresh kernel).
    "MODEL_ARGS" : None,
    "rendered" : rendered_images,
}
save_metadata(metadata, "scenes_old/physics_simple", force_save = True)

video_from_rendered(rendered_images[...,:3], framerate = 24)