In [1]:
import jax
import jax.numpy as jnp
import optax
from inr_utils.sampling import NeRFSyntheticScenesSampler
from model_components.inr_modules import NeRF
from model_components.inr_layers import SirenLayer
from model_components.inr_layers import ClassicalPositionalEncoding

# Set JAX to run on CPU
# jax.config.update('jax_platform_name', 'cpu')

2025-01-22 14:44:24.605377: W external/xla/xla/service/gpu/nvptx_compiler.cc:893] The NVIDIA driver's CUDA version is 12.4 which is older than the PTX compiler version 12.6.68. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [3]:


# Step 1: Initialize the NeRF model
key = jax.random.PRNGKey(0)
nerf_model = NeRF.from_config(
    in_size=(3, 2),
    out_size=(1, 3),
    bottle_size=256,
    block_length=4,
    block_width=512,
    num_blocks=2,
    condition_length=None,
    condition_width=None,
    layer_type=SirenLayer,
    activation_kwargs={"w0": 30.0},
    key=key,
    initialization_scheme=None,
    initialization_scheme_kwargs=None,
    positional_encoding_layer=ClassicalPositionalEncoding,
    # positional_encoding_layer=ClassicalPositionalEncoding.from_config(num_frequencies=3),
    # positional_encoding_layer=None,
    num_splits=1,
    post_processor=None,
    num_coarse_samples=64,
    num_fine_samples=128,
    use_viewdirs=True,
    near=2.0,
    far=6.0,
    noise_std=1.0,
    white_bkgd=True,
    lindisp=False
)

# Step 2: Initialize the sampler
sampler = NeRFSyntheticScenesSampler(
    split='train',
    name="vase",#'armchair_dataset_small',
    batch_size=10,
    poses_per_batch=10,
    base_path="./synthetic_scenes/",#"example_data/nerfdata",
    size_limit=-1
)

In [7]:
import pdb
import traceback

try:
    # Step 3: Use the sampler to generate rays and ground truth data
    key, sample_key = jax.random.split(key)
    ray_origins, ray_directions, sample_key, ground_truth = sampler(sample_key)

    # Step 4: Perform a forward pass through the NeRF model
    randomized = True
    output = jax.vmap(lambda ro, rd, ra: nerf_model(ro, rd, ra, key=sample_key), (0, 0, None))(ray_origins, ray_directions, randomized)

    # Print the output
    print(output)

except Exception as e:
    print(e)
    traceback.print_exc()
    pdb.post_mortem()


Traceback (most recent call last):
  File "/tmp/ipykernel_32273/2477426486.py", line 11, in <module>
    output = jax.vmap(lambda ro, rd, ra: nerf_model(ro, rd, ra, key=sample_key), (0, 0, None))(ray_origins, ray_directions, randomized)
  File "/home/simon/miniconda3/envs/inr_edu_24/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/simon/miniconda3/envs/inr_edu_24/lib/python3.10/site-packages/jax/_src/api.py", line 994, in vmap_f
    out_flat = batching.batch(
  File "/home/simon/miniconda3/envs/inr_edu_24/lib/python3.10/site-packages/jax/_src/linear_util.py", line 187, in call_wrapped
    return self.f_transformed(*args, **kwargs)
  File "/home/simon/miniconda3/envs/inr_edu_24/lib/python3.10/site-packages/jax/_src/interpreters/batching.py", line 589, in _batch_outer
    outs, trace = f(tag, in_dims, *in_vals)
  File "/home/simon/miniconda3/envs/inr_edu_24/lib/python3.10/site-packages/jax

Too many indices: 1-dimensional array indexed with 2 regular indices.
> [0;32m/home/simon/miniconda3/envs/inr_edu_24/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py[0m(12383)[0;36m_canonicalize_tuple_index[0;34m()[0m
[0;32m  12381 [0;31m  [0;32mif[0m [0mnum_dimensions_consumed[0m [0;34m>[0m [0marr_ndim[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m  12382 [0;31m    [0mindex_or_indices[0m [0;34m=[0m [0;34m"index"[0m [0;32mif[0m [0mnum_dimensions_consumed[0m [0;34m==[0m [0;36m1[0m [0;32melse[0m [0;34m"indices"[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m> 12383 [0;31m    raise IndexError(
[0m[0;32m  12384 [0;31m        [0;34mf"Too many indices: {arr_ndim}-dimensional array indexed "[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m  12385 [0;31m        f"with {num_dimensions_consumed} regular {index_or_indices}.")
[0m
> [0;32m/home/simon/miniconda3/envs/inr_edu_24/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py[0m(12064)[0;36m_index