In [1]:
# Evaluate a trained NeRF model (NerfMLP) on held-out test camera poses
from train_nerf import create_train_state
from model import NerfMLP, render_image
from datasets import matrix_from_filename
from flax.training import checkpoints
from jax import random
import jax.numpy as jnp
import os
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

# Training hyper-parameters. Here they are only passed through to
# create_train_state (which needs them to rebuild the train-state structure).
config = {
    "epochs": 100,
    "batch_size": 2048,
    "init_lr": 5e-5,
    "transition_steps": 7500,
    "decay_rate": 0.9,
    "min_lr": 5e-6,
    "model_dir": os.path.abspath("./checkpoints"),
}

# Build a freshly initialised train state; it serves as the structural
# template that the saved checkpoint is restored into.
rng = random.PRNGKey(0)
rng, input_rng = random.split(rng)
target_state = create_train_state(NerfMLP(), config, input_rng)
state = checkpoints.restore_checkpoint(ckpt_dir=config["model_dir"], target=target_state)

# restore_checkpoint hands back `target` itself when no checkpoint file is
# found; without this check we would silently evaluate an untrained model.
if state is target_state:
    raise FileNotFoundError(f"No checkpoint found in {config['model_dir']}")

# Held-out test views: each name selects a camera-pose matrix file in the
# dataset's pose/ folder; the shared camera intrinsics come from one file.
test_names = [
    "2_test_0000",
    "2_test_0016",
    "2_test_0055",
    "2_test_0093",
    "2_test_0160",
]
test_poses = [
    matrix_from_filename(f"bottles/bottles/pose/{test_name}.txt")
    for test_name in test_names
]
camera_cal = matrix_from_filename("bottles/bottles/intrinsics.txt")

# Render every test pose, save the RGB image to test_imgs/, and display the
# RGB render next to its depth map.
image_size = jnp.array([800, 800])
os.makedirs("test_imgs", exist_ok=True)  # Image.save does not create missing dirs

for name, pose in zip(test_names, test_poses):
    # Use the freshly split subkey for rendering and keep `rng` only as the
    # carry, so no PRNG key is consumed twice.
    rng, input_rng = random.split(rng)
    # NOTE(review): the recorded traceback (InvalidRngError, "rngs are: 0.001")
    # suggests these positional arguments do not line up with render_image's
    # signature in model.py -- TODO confirm the parameter order there (or pass
    # the arguments by keyword).
    rgbd = np.asarray(
        render_image(state, state.params, pose, camera_cal, image_size, input_rng)
    )

    # Channels 0-2 are RGB. Clip before quantising so values slightly outside
    # [0, 1] do not wrap around when cast to uint8.
    image = (np.clip(rgbd[:, :, :3], 0.0, 1.0) * 255).astype(np.uint8)
    Image.fromarray(image).save(f"test_imgs/{name}.png")

    # Channel 3 is depth. Its range is not known here (render_image is opaque),
    # so min-max normalise it to keep the uint8 conversion from wrapping.
    depth = rgbd[:, :, 3]
    d_min, d_max = float(depth.min()), float(depth.max())
    if d_max > d_min:
        depth = (depth - d_min) / (d_max - d_min)
    else:
        depth = np.zeros_like(depth)
    depth = (depth * 255).astype(np.uint8)

    # Separate axes so the depth map no longer overdraws the RGB render.
    fig, (ax_rgb, ax_depth) = plt.subplots(1, 2, figsize=(10, 5))
    ax_rgb.imshow(image)
    ax_rgb.set_title(name)
    ax_depth.imshow(depth)
    ax_depth.set_title(f"{name} (depth)")
    plt.show()
    # Image.fromarray(depth).save(f"test_imgs/{name}_depth.png")



InvalidRngError: RNGs should be of shape (2,) or KeyArray in Module NerfMLP, but rngs are: 0.001 (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.InvalidRngError)