## Imports and initialization

In [None]:
from __future__ import annotations

import io
import math
import requests
from collections import namedtuple
from math import cos, sin
import itertools

import numpy as np
import torch
import PIL.Image
import rerun as rr  # pip install rerun-sdk
import rerun.blueprint as rrb

## Define neural field class

First, we define the neural field class, which can be used to represent any continuous signal: it maps a point in an N-dimensional input space to a value in an M-dimensional output space. In this notebook we fit fields that map 2D image coordinates to RGB colors, so the network weights can be interpreted as encoding a continuous image.

In [None]:
class NeuralField(torch.nn.Module):
    """Simple neural field composed of positional encoding, MLP, and activation function.

    The field is an MLP of ``num_layers`` linear layers mapping ``dim_in``-dimensional
    input points to ``dim_out``-dimensional outputs. If ``pe_sigma`` is given, the
    first layer's weights are initialized from N(0, pe_sigma^2) and its output is
    passed through ``sin``, i.e. a sinusoidal positional encoding with (trainable)
    Gaussian-initialized frequencies. Otherwise the first layer uses ReLU like
    the other hidden layers.

    Args:
        num_layers: Number of linear layers, including the first and the output
            layer. Must be at least 2.
        dim_hidden: Width of every hidden layer.
        dim_in: Dimension of the input points.
        dim_out: Dimension of the output values.
        activation: Output activation. ``"sigmoid"`` squashes outputs to [0, 1];
            any other value leaves the output layer linear.
        pe_sigma: Standard deviation of the first layer's weights for the
            sinusoidal encoding, or ``None`` to disable the encoding.

    Raises:
        ValueError: If ``num_layers`` is smaller than 2.
    """

    def __init__(
        self,
        num_layers: int,
        dim_hidden: int,
        dim_in: int = 2,
        dim_out: int = 3,
        activation: str = "sigmoid",
        pe_sigma: float | None = None,
    ):
        super().__init__()

        # forward() treats linears[0] as the input/encoding layer and linears[-1]
        # as the output layer; with a single layer both would be the same module
        # and it would be applied twice (shape mismatch unless dim_in == dim_out).
        if num_layers < 2:
            raise ValueError(f"num_layers must be at least 2, got {num_layers}")

        self.num_layers = num_layers
        self.dim_hidden = dim_hidden
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.activation = activation
        self.pe_sigma = pe_sigma

        # num_layers + 1 sizes -> num_layers consecutive Linear layers.
        sizes = [dim_in, *([dim_hidden] * (num_layers - 1)), dim_out]
        self.linears = torch.nn.ModuleList(
            torch.nn.Linear(in_size, out_size)
            for in_size, out_size in zip(sizes[:-1], sizes[1:])
        )

        if self.pe_sigma is not None:
            # The weight scale sets the encoding's frequencies and therefore how
            # quickly the untrained field varies over its input domain.
            torch.nn.init.normal_(self.linears[0].weight, 0.0, self.pe_sigma)

    def __str__(self) -> str:
        return f"{self.num_layers} lay., {self.dim_hidden} neu., pe σ: {self.pe_sigma}"

    def forward(self, input_points: torch.Tensor) -> torch.Tensor:
        """Compute output for given input points.

        Args:
            input_points: Tensor whose last dimension has size ``dim_in``,
                e.g. shape ``(N, dim_in)``.

        Returns:
            Tensor with the same leading dimensions and last dimension ``dim_out``.
            Values lie in [0, 1] when ``activation == "sigmoid"``.
        """
        first = self.linears[0](input_points)
        out = torch.relu(first) if self.pe_sigma is None else torch.sin(first)

        for linear in self.linears[1:-1]:
            out = torch.relu(linear(out))

        out = self.linears[-1](out)

        if self.activation == "sigmoid":
            out = torch.sigmoid(out)

        return out

## Initialize and visualize neural fields

Now we create a few neural fields with different parameters and visualize their outputs as images. We assume that images fill the unit square [0, 1] × [0, 1], so we query the network on a dense grid (with some additional margin to observe behavior outside the training region) to retrieve the image. Note that the positional-encoding scale σ determines how quickly the neural field varies before any training. This corresponds to the amount of detail that the field can easily represent, but it also determines how the field extrapolates outside of the training region.

In [None]:
# Identical architectures that differ only in the positional-encoding sigma,
# so the effect of the encoding scale can be compared side by side.
fields = [
    NeuralField(num_layers=5, dim_hidden=128, pe_sigma=pe_sigma)
    for pe_sigma in (5, 15, 30, 100)
]
# Number of optimization steps each field has received so far; the training
# cell uses it as that field's value on the "iteration" timeline.
total_iterations = [0] * len(fields)

rr.init("rerun_example_neural_field")

# Layout: the target image plus one 2D view per field in a grid on top, and
# a shared time-series plot of the training losses underneath.
blueprint = rrb.Blueprint(
    rrb.Vertical(
        rrb.Grid(
            rrb.Spatial2DView(name="Target", origin="target"),
            *[
                rrb.Spatial2DView(name=str(field), origin=f"field_{i}")
                for i, field in enumerate(fields)
            ],
        ),
        rrb.TimeSeriesView(
            name="Losses",
            origin="/",
            # Average scalar samples that collapse onto the same plot position
            # so long loss curves remain readable.
            defaults=[rr.components.AggregationPolicyBatch("average")],
            plot_legend=rrb.Corner2D.LeftTop,
        ),
        row_shares=[0.7, 0.3],
    ),
    collapse_panels=True,
)
# Static series styling: label each loss curve with its field's configuration.
for i, field in enumerate(fields):
    rr.log(f"loss/field_{i}", rr.SeriesLine(name=str(field)), static=True)

rr.notebook_show(blueprint=blueprint, width=1050, height=600)

@torch.no_grad()
def log_field_as_image(
    entity_path: str,
    field: NeuralField,
    min_uv: tuple[float, float],
    max_uv: tuple[float, float],
    uv_resolution: tuple[int, int],
):
    """Render ``field`` on a regular uv grid and log the result to Rerun as an image.

    The rectangle spanned by ``min_uv`` and ``max_uv`` is divided into
    ``uv_resolution[0] x uv_resolution[1]`` pixels, and the field is queried
    at the center of each pixel.

    Args:
        entity_path: Rerun entity path the image is logged to.
        field: Neural field to query. It is called on a ``(U*V, 2)`` tensor of
            uv points and should return one value vector per point.
        min_uv: Lower (u, v) corner of the rendered region.
        max_uv: Upper (u, v) corner of the rendered region.
        uv_resolution: Number of pixels along u (image width) and v (image height).
    """
    num_u, num_v = uv_resolution
    # Half a pixel along each axis. Sampling from (min + half) to (max - half)
    # puts every sample exactly at a pixel center and keeps it inside the region.
    half_u = 0.5 * (max_uv[0] - min_uv[0]) / num_u
    half_v = 0.5 * (max_uv[1] - min_uv[1]) / num_v
    u_values = torch.linspace(min_uv[0] + half_u, max_uv[0] - half_u, num_u)
    v_values = torch.linspace(min_uv[1] + half_v, max_uv[1] - half_v, num_v)
    # cartesian_prod varies v fastest, so a row-major reshape yields (u, v, channels).
    uv_points = torch.cartesian_prod(u_values, v_values)
    predictions = field(uv_points)
    # -1 keeps the field's own output dimension instead of hard-coding 3 channels.
    image_prediction = predictions.reshape(num_u, num_v, -1).clamp(0, 1)
    # Images are indexed (row = v, column = u).
    image_prediction = image_prediction.permute(1, 0, 2)
    rr.log(entity_path, rr.Image(image_prediction.numpy(force=True)))


# Show each field's untrained output at the start of the "iteration" timeline,
# rendered over the unit square plus a 0.1 margin on every side.
rr.set_time_sequence("iteration", 0)
for idx, neural_field in enumerate(fields):
    log_field_as_image(
        f"field_{idx}",
        neural_field,
        (-0.1, -0.1),
        (1.1, 1.1),
        (100, 100),
    )

## Train neural field

Now we train the neural fields for a fixed number of iterations. If you run the cell again, training continues from the current weights (the optimizer is re-created, so its internal state starts fresh). To reset the fields, run the previous cell again.

In [None]:
field_ids = [0, 1, 2, 3]  # if you only want to train one of the fields
num_iterations = 10000
batch_size = 1000
learning_rate = 1e-3
log_image_period = 10


response = requests.get(
    "https://storage.googleapis.com/rerun-example-datasets/example_images/tiger.jpg",
    timeout=30,
)
# response = requests.get("https://storage.googleapis.com/rerun-example-datasets/example_images/bird.jpg", timeout=30)
# Fail loudly on HTTP errors instead of handing an error page to PIL.
response.raise_for_status()
# Force 3 channels so grayscale/CMYK files still match the fields' RGB outputs,
# then scale to [0, 1]. np.array copies, giving torch a writable buffer.
pil_image = PIL.Image.open(io.BytesIO(response.content)).convert("RGB")
target_image = torch.from_numpy(np.array(pil_image)).float() / 255

rr.log("target", rr.Image(target_image))

image_height, image_width = target_image.shape[:2]
# Scales uv in [0, 1) to (column, row) pixel indices.
uv_to_pixel = torch.tensor([image_width, image_height])

try:
    # A fresh optimizer (with fresh Adam moment estimates) is created every time
    # this cell runs; only the field weights carry over between runs.
    parameters = itertools.chain.from_iterable(
        fields[field_id].parameters() for field_id in field_ids
    )
    optimizer = torch.optim.Adam(parameters, lr=learning_rate)
    for iteration in range(num_iterations):
        optimizer.zero_grad()

        # Sample random uv positions and look up the color of the pixel each one falls in.
        target_uvs = torch.rand(batch_size, 2)
        target_jis = (target_uvs * uv_to_pixel).long()
        target_rgbs = target_image[target_jis[:, 1], target_jis[:, 0]]

        # The fields share no parameters, so each loss is backpropagated on its
        # own before a single joint optimizer step.
        for field_id in field_ids:
            field = fields[field_id]
            total_iterations[field_id] += 1

            predicted_rgbs = field(target_uvs)
            loss = torch.nn.functional.mse_loss(predicted_rgbs, target_rgbs)

            rr.set_time_sequence("iteration", total_iterations[field_id])
            rr.log(f"loss/field_{field_id}", rr.Scalar(loss.item()))
            loss.backward()

        optimizer.step()

        if iteration % log_image_period == 0:
            for field_id in field_ids:
                # Stamp each image with its own field's iteration count; otherwise
                # all images would inherit the time set by the last field above.
                rr.set_time_sequence("iteration", total_iterations[field_id])
                log_field_as_image(
                    f"field_{field_id}",
                    fields[field_id],
                    (-0.1, -0.1),
                    (1.1, 1.1),
                    (100, 100),
                )
except KeyboardInterrupt:
    print("Training stopped.")