# tiny-cuda-nn Function Fit

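Train a `tinycudann.NetworkWithInputEncoding` (a multiresolution hash-grid encoding feeding a fully fused MLP)
to approximate a smooth synthetic RGB function of 2D coordinates in $[0, 1]^2$. This demonstrates the core
tiny-cuda-nn workflow: encode the inputs, evaluate the fused network, and train it with a standard PyTorch optimizer.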

In [1]:
import math
import torch
import tinycudann as tcnn

from pathlib import Path

def find_repo_root(start: Path) -> Path:
    """Return the nearest directory at or above ``start`` that looks like the repo root.

    A directory qualifies when it contains both ``README.md`` and
    ``CMakeLists.txt``. ``start`` itself is checked first, then each of its
    ancestors from nearest to farthest.

    Raises:
        RuntimeError: if no directory on the path up to the filesystem root qualifies.
    """
    markers = ("README.md", "CMakeLists.txt")
    candidates = (start, *start.parents)
    match = next(
        (folder for folder in candidates if all((folder / m).exists() for m in markers)),
        None,
    )
    if match is None:
        raise RuntimeError("Could not find tiny-cuda-nn repository root.")
    return match

# Walk up from the kernel's working directory to the checkout root (the first
# directory holding both README.md and CMakeLists.txt) and report where it is.
ROOT = find_repo_root(Path.cwd().resolve())
print("Repository root:", ROOT)

# Fail fast: every tensor and the model in this notebook live on a CUDA device.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required for tinycudann examples.")

# Seed PyTorch's RNGs so the torch.rand() batches drawn during training repeat
# from run to run.
torch.manual_seed(7)
device = torch.device("cuda")

Repository root: /media/tunguz/3139-3535/tiny-cuda-nn


In [2]:
def target_fn(xy: torch.Tensor) -> torch.Tensor:
    """Evaluate the synthetic RGB target at a batch of 2D points.

    Args:
        xy: points indexed as ``xy[:, 0]`` (x) and ``xy[:, 1]`` (y); the
            training and evaluation cells pass shape (N, 2). Extra columns,
            if any, are ignored.

    Returns:
        Tensor of shape (N, 3). Each channel is a value in [-1, 1] mapped
        affinely to [0, 1]:
            R = sin(2*pi*x) * cos(2*pi*y)
            G = sin(4*pi*(x + y))
            B = cos(2*pi*(x - y))
    """
    two_pi = 2.0 * math.pi
    # Slice (rather than index) so each coordinate keeps shape (N, 1).
    x, y = xy[:, 0:1], xy[:, 1:2]
    red = torch.sin(two_pi * x) * torch.cos(two_pi * y)
    green = torch.sin(4.0 * math.pi * (x + y))
    blue = torch.cos(two_pi * (x - y))
    channels = torch.cat((red, green, blue), dim=1)
    return (channels + 1.0) * 0.5

# Input encoding: tiny-cuda-nn multiresolution hash grid. Key meanings follow
# tiny-cuda-nn's configuration documentation.
encoding_config = dict(
    otype="HashGrid",
    n_levels=12,
    n_features_per_level=2,
    log2_hashmap_size=15,
    base_resolution=8,
    per_level_scale=1.5,
)

# Network: fully fused MLP. The Sigmoid output keeps predictions in (0, 1),
# which matches the range of target_fn.
network_config = dict(
    otype="FullyFusedMLP",
    activation="ReLU",
    output_activation="Sigmoid",
    n_neurons=64,
    n_hidden_layers=2,
)

model = tcnn.NetworkWithInputEncoding(
    n_input_dims=2,   # (x, y)
    n_output_dims=3,  # (R, G, B)
    encoding_config=encoding_config,
    network_config=network_config,
)
model = model.to(device)
# Request JIT fusion only when this tinycudann build reports support. The run
# log shows it can still fall back to the non-JIT path on first use.
model.jit_fusion = tcnn.supports_jit_fusion()

optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2)

In [3]:
steps = 300
batch_size = 8192
# Log step 0, then every `log_every` steps, and always the final step. This is
# derived from `steps` so changing the run length cannot drop the last report.
log_every = 50
losses = []

for step in range(steps):
    # Fresh uniform samples in [0, 1)^2 each step; the target is analytic.
    xy = torch.rand((batch_size, 2), device=device)
    target = target_fn(xy)
    pred = model(xy)
    loss = torch.mean((pred - target) ** 2)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Tensor.item() already returns a Python float; no float() wrapper needed.
    losses.append(loss.item())
    is_last_step = step == steps - 1
    if step == 0 or (step + 1) % log_every == 0 or is_last_step:
        print(f"step={step:03d} mse={losses[-1]:.6f}")

Failed to JIT-compile `forward_network_with_input_encoding`. Disabling JIT.


step=000 mse=0.104096
step=049 mse=0.000713
step=099 mse=0.000075
step=149 mse=0.000034
step=199 mse=0.000021
step=249 mse=0.000015
step=299 mse=0.000012


In [4]:
# Evaluate on a dense, regular res x res grid covering [0, 1]^2 (endpoints
# included) instead of the random batches used during training.
with torch.no_grad():
    res = 128
    lin = torch.linspace(0.0, 1.0, res, device=device)
    yy, xx = torch.meshgrid(lin, lin, indexing="ij")
    # Flatten to (res*res, 2) with columns ordered (x, y), as target_fn expects.
    coords = torch.stack((xx.flatten(), yy.flatten()), dim=1)

    pred = model(coords)
    tgt = target_fn(coords)
    mse = (pred - tgt).pow(2).mean().item()
    # PSNR for values in [0, 1] is 10*log10(1 / mse); the floor avoids log10(0).
    psnr = -10.0 * math.log10(max(mse, 1e-12))

print("\n".join((
    f"initial mse={losses[0]:.6f}",
    f"final training-step mse={losses[-1]:.6f}",
    f"grid mse={mse:.6f}, psnr={psnr:.2f} dB",
)))

# Sanity checks: training made progress and the fit is reasonably close.
assert losses[-1] < losses[0], "Training did not reduce loss."
assert mse < 0.03, "Function fit did not converge sufficiently."

initial mse=0.104096
final training-step mse=0.000012
grid mse=0.000013, psnr=48.92 dB
