## Imports and initialization

In [1]:
from __future__ import annotations

import math
from collections import namedtuple
from math import cos, sin

import numpy as np
import torch
import PIL.Image
import rerun as rr  # pip install rerun-sdk
import rerun.blueprint as rrb


## Define neural field class

First, we define the neural field class, which can be used to represent any continuous ND signal, i.e., it maps an ND point to another ND point. In this notebook we fit fields that map from 2D image coordinates to RGB colors. This way the network weights can be interpreted as encoding a continuous image.

In [73]:
class NeuralField(torch.nn.Module):
    """Simple neural field composed of positional encoding, MLP, and activation function.

    NOTE: this is currently a placeholder. The constructor only stores the
    configuration (no layers are created) and `forward` returns uniformly
    random values instead of evaluating a network.

    Args:
        num_layers: Number of MLP layers (only used in the description string for now).
        dim_hidden: Number of neurons per hidden layer (only used in the description string for now).
        dim_in: Dimensionality of the input points.
        dim_out: Dimensionality of the output produced for each input point.
    """

    def __init__(
        self,
        num_layers: int,
        dim_hidden: int,
        dim_in: int = 2,
        dim_out: int = 3,
    ) -> None:
        super().__init__()

        self.num_layers = num_layers
        self.dim_hidden = dim_hidden
        self.dim_in = dim_in
        self.dim_out = dim_out

    def __str__(self) -> str:
        """Return a short human-readable description of the field's configuration."""
        return f"{self.num_layers}-layer MLP ({self.dim_hidden} neurons)"

    def forward(self, input_points: torch.Tensor) -> torch.Tensor:
        """Compute output for given input points.

        Args:
            input_points: Batch of query points; only its length (first
                dimension) is used by the current placeholder implementation.

        Returns:
            Tensor of shape `(len(input_points), dim_out)` filled with uniform
            random values in `[0, 1)`.
        """
        num_points = len(input_points)
        # Honor the configured output dimensionality instead of a hard-coded 3.
        return torch.rand(num_points, self.dim_out)
    

## Initialize and visualize neural fields

Now we create a few neural fields with different parameters and visualize their output as images. We assume that images are fit to the unit square (coordinates from 0 to 1), so we query a dense grid (with some additional margin to observe out-of-training behavior) to retrieve the image from the network. Note that the positional encoding determines how quickly the neural field varies out-of-the-box. This corresponds to the amount of detail that the field can easily represent, but also determines how the field extrapolates outside of the training region.

In [81]:
# One untrained field per hidden-layer width; all of them use 4 layers.
fields = [NeuralField(num_layers=4, dim_hidden=width) for width in (32, 64, 128, 256)]
# Number of training iterations performed so far, tracked separately per field.
total_iterations = [0] * len(fields)

rr.init("rerun_example_cube")

# Layout: a grid with one 2D view per field on top, one shared loss plot below.
field_views = [
    rrb.Spatial2DView(name=str(field), origin=f"field_{i}")
    for i, field in enumerate(fields)
]
loss_view = rrb.TimeSeriesView(
    name="Losses",
    origin="/",
    defaults=[rr.components.AggregationPolicyBatch("average")],
)
blueprint = rrb.Blueprint(
    rrb.Vertical(
        rrb.Grid(*field_views),
        loss_view,
        row_shares=[0.7, 0.3],
    ),
    collapse_panels=True,
)

# Log a named line series per field (statically) so each loss curve carries
# the field's description as its name.
for i, field in enumerate(fields):
    rr.log(f"loss/field_{i}", rr.SeriesLine(name=str(field)), static=True)

rr.notebook_show(blueprint=blueprint, width=1050, height=800)

@torch.no_grad()
def log_field_as_image(
    entity_path: str,
    field: NeuralField,
    min_uv: tuple[float, float],
    max_uv: tuple[float, float],
    uv_resolution: tuple[int, int],
) -> None:
    """Sample a neural field on a regular 2D grid and log the result as an image.

    Args:
        entity_path: Rerun entity path the image is logged to.
        field: Field to evaluate; called with an `(N, 2)` tensor of uv points.
        min_uv: Lower `(u, v)` bound of the sampled region.
        max_uv: Upper `(u, v)` bound of the sampled region.
        uv_resolution: Number of samples along `(u, v)`.
    """
    u_values = torch.linspace(min_uv[0], max_uv[0], uv_resolution[0])
    v_values = torch.linspace(min_uv[1], max_uv[1], uv_resolution[1])
    # NOTE(review): the offset is 0.5 / resolution, not half of the actual grid
    # step (which depends on max_uv - min_uv) -- confirm this is intended.
    uv_points = torch.cartesian_prod(u_values, v_values) + 0.5 / torch.tensor(uv_resolution)  # 0.5 is the center of a pixel
    predictions = field(uv_points)
    # cartesian_prod varies v fastest, so the flat predictions unflatten to
    # (u, v, channels); -1 infers the channel count from the field's output.
    image_prediction = torch.clamp(predictions.reshape(uv_resolution[0], uv_resolution[1], -1), 0, 1)
    # Swap the first two axes so that v indexes rows and u indexes columns.
    image_prediction = image_prediction.permute(1, 0, 2)
    rr.log(entity_path, rr.Image(image_prediction.numpy(force=True)))


# Log every field once at iteration 0, sampling slightly beyond the unit square.
rr.set_time_sequence("iteration", 0)
for field_index, neural_field in enumerate(fields):
    log_field_as_image(
        entity_path=f"field_{field_index}",
        field=neural_field,
        min_uv=(-0.1, -0.1),
        max_uv=(1.1, 1.1),
        uv_resolution=(100, 100),
    )

[2024-06-27T15:36:20Z DEBUG re_chunk::batcher] creating new chunk batcher config=ChunkBatcherConfig { flush_tick: 8ms, flush_num_bytes: 1048576, flush_num_rows: 18446744073709551615, max_chunk_rows_if_unsorted: 256, max_commands_in_flight: None, max_chunks_in_flight: None, hooks: BatcherHooks { on_insert: None, on_release: Some(ArrowChunkReleaseCallback("0x23e98020310")) } }
[2024-06-27T15:36:20Z DEBUG re_sdk::recording_stream] setting recording info app_id=rerun_example_cube rec_id=ac0064a7-1de4-423d-8284-9dd6e792faea
[2024-06-27T15:36:20Z DEBUG re_chunk::batcher] creating new chunk batcher config=ChunkBatcherConfig { flush_tick: 8ms, flush_num_bytes: 1048576, flush_num_rows: 18446744073709551615, max_chunk_rows_if_unsorted: 256, max_commands_in_flight: None, max_chunks_in_flight: None, hooks: BatcherHooks { on_insert: None, on_release: Some(ArrowChunkReleaseCallback("0x23e98020a50")) } }
[2024-06-27T15:36:20Z DEBUG re_sdk::recording_stream] setting recording info app_id=rerun_example

Viewer()

## Train neural field

Now we train the neural fields for a fixed number of iterations. If you run the cell twice, we continue training where we left off. To reset the fields, run the previous cell again.

In [67]:
# Training configuration.
field_ids = [0, 1, 2, 3]  # shorten this list to train only some of the fields
num_iterations = 10000
batch_size = 1000
learning_rate = 1e-4

for field_id in field_ids:
    field = fields[field_id]
    loss_entity = f"loss/field_{field_id}"
    for iteration in range(num_iterations):
        # Advance this field's own counter so re-running the cell continues
        # its timeline instead of starting over.
        total_iterations[field_id] += 1
        rr.set_time_sequence("iteration", total_iterations[field_id])
        # TODO: replace the random placeholder with a real training step, e.g.
        # train_field(fields[field_id], num_iterations, batch_size, learning_rate)
        placeholder_loss = torch.rand(1).item()
        rr.log(loss_entity, rr.Scalar(placeholder_loss))