This notebook creates the ground-truth videos that we will track throughout the lifecycle of this project. It uses a GenJAX model to generate rendered images and saves them to disk, including any related parameters and metadata as a JSON file.

In [2]:
import os
import sys
sys.path.append("../../")
import jax
import genjax
import bayes3d as b
from scipy.spatial.transform import Rotation as R

import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
from models import *
from utils import *
from viz import *
from renderer_setup import *
import pickle
from dataclasses import dataclass
from genjax.generative_functions.distributions import ExactDensity

# NOTE(review): presumably enables GenJAX's pretty-printing and returns the
# console object used for it -- confirm against the installed genjax version.
console = genjax.pretty()

In [5]:
# Scene 1: cube moving around
#
# Sets up the renderer, samples an initial pose and velocity, runs the unfolded
# single-object model for time steps 0..T under fixed per-step constraints, and
# saves the resulting frames together with every parameter used to make them.

RENDERER_ARGS = {
    'height': 50,
    'width': 50,
    'focal_length': 250,
    'near': 0.1,
    'far': 20
}

setup_renderer_and_meshes_v3(**RENDERER_ARGS)
SCENE = '1'

# random PRNG key
key_number = 65567567
# Value written into the per-step 'indices' constraint and passed to the
# renderer for the first frame (presumably an index into b.RENDERER.meshes --
# TODO confirm).
rend_id = 0

# Time Step
T = 50
# model args
# NOTE: these values are passed to the model positionally, in this dict's
# insertion order (see the importance call below), so the order of the
# entries matters.
MODEL_ARGS = {
    'N_total_vec' : jnp.arange(len(b.RENDERER.meshes)),
    'N_vec': jnp.zeros(1),
    'vel_params': jnp.array([0.0005, 10000.0]),
    'variance_params': jnp.array([0.00000000001, 10000.0]),
    'outlier_prob_params': jnp.array([-0.01, 10000.0])
}

key = jax.random.PRNGKey(key_number)
# A JAX PRNG key must be consumed only once: feeding the same key to several
# samplers makes their draws share the same random bits. Derive one
# independent subkey per consumer instead of reusing `key` directly.
# NOTE: this changes the sampled values relative to runs that reused `key`,
# so the scene written to disk is regenerated with different draws.
velocity_key, pose_key, trace_key = jax.random.split(key, 3)

INIT_VEL_PARAMS = jnp.array([0.001, 10000.0])
INIT_VELOCITY = b.gaussian_vmf_pose.sample(velocity_key, jnp.eye(4), *INIT_VEL_PARAMS)
# Row 0 / row 1 are the two bound arguments handed to uniform_pose.sample.
POSE_BOUNDS = jnp.array([[0.0, 0, 1.5], [0.0, 0.0, 2]])
INIT_POSE = b.uniform_pose.sample(pose_key, POSE_BOUNDS[0], POSE_BOUNDS[1])

# choice map args
# One constrained value per time step (T+1 entries, matching the unfold length).
CHOICE_MAP_ARGS = {
    'variance': jnp.repeat(0.0001,T+1),
    'outlier_prob': jnp.repeat(0.0001,T+1),
    'indices': jnp.tile(jnp.array([[rend_id]]), (T+1,1))
}
chm = genjax.index_choice_map(
        jnp.arange(0,T+1), genjax.choice_map(
            CHOICE_MAP_ARGS
        )
    )

# First rendered image is known
# Only the first three channels of the renderer output are kept.
rendered_images = [b.RENDERER.render(INIT_POSE[None,...], CHOICE_MAP_ARGS["indices"][0])[...,:3]]

# unfold the model
INIT_STATE = (rendered_images[0], INIT_POSE, INIT_VELOCITY)
model_single_object_unfold = genjax.UnfoldCombinator.new(model_single_object, T+1)

# Importance-sample a trace consistent with the constraints in `chm`;
# the returned weight is not needed here.
_, trace = jax.jit(model_single_object_unfold.importance)(trace_key, chm, (T, INIT_STATE, *tuple(MODEL_ARGS.values())))

# From here on `rendered_images` holds the frames returned by the trace,
# replacing the one-element list built above.
rendered_images = trace.get_retval()[0]
poses = trace.get_retval()[1]
score = trace.get_score()

# save metadata
metadata = {
    "T" : T,
    'INIT_VELOCITY' : INIT_VELOCITY,
    'INIT_VEL_PARAMS' : INIT_VEL_PARAMS,
    "POSE_BOUNDS": POSE_BOUNDS,
    "INIT_POSE":INIT_POSE,
    "key_number": key_number,
    "RENDERER_ARGS": RENDERER_ARGS,
    "MODEL_ARGS" : MODEL_ARGS,
    "CHOICE_MAP_ARGS" : CHOICE_MAP_ARGS,
    "poses" : poses,
    "score" : score,
    "rendered" : rendered_images,
    "model_name" : "model_single_object",
    "renderer_setup_version" : 3,
}
save_metadata(metadata, "scenes/scene_{}".format(SCENE), force_save = True)


# see depth GT
video_from_rendered(metadata["rendered"], framerate = 24)

[E rasterize_gl.cpp:121] OpenGL version reported as 4.6


Increasing frame buffer size to (width, height, depth) = (64, 64, 1024)
