# Image Fit With PyTorch Binding

A compact variant of the image-learning demo (`samples/mlp_learning_an_image_pytorch.py`).
It trains a tiny-cuda-nn network on randomly sampled UV coordinates of `data/images/albert.jpg` and verifies that the full-image reconstruction error decreases.

In [1]:
import json
import math
import numpy as np
import torch
import tinycudann as tcnn
from PIL import Image

from pathlib import Path

def find_repo_root(start: Path) -> Path:
    """Walk upward from *start* and return the first directory that looks like the repo root.

    A directory qualifies when it contains both ``README.md`` and
    ``CMakeLists.txt``. *start* itself is checked first, then each of its
    ancestors in order from nearest to farthest.

    Raises:
        RuntimeError: if neither *start* nor any ancestor qualifies.
    """
    markers = ("README.md", "CMakeLists.txt")
    candidates = (start, *start.parents)
    match = next(
        (folder for folder in candidates if all((folder / name).exists() for name in markers)),
        None,
    )
    if match is None:
        raise RuntimeError("Could not find tiny-cuda-nn repository root.")
    return match

cwd = Path.cwd().resolve()
ROOT = find_repo_root(cwd)
print("Repository root:", ROOT)

# Bail out early: the tinycudann examples cannot run without a CUDA device.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required for tinycudann examples.")

# Seed the torch and numpy RNGs so random sampling below is repeatable.
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(13)
device = torch.device("cuda")

Repository root: /media/tunguz/3139-3535/tiny-cuda-nn


In [2]:
image_path = ROOT / "data" / "images" / "albert.jpg"
# Open the image in a context manager so the file handle is released
# deterministically rather than relying on Pillow's lazy-load internals.
with Image.open(image_path) as pil_img:
    img_np = np.asarray(pil_img.convert("RGB"), dtype=np.float32) / 255.0
# [H, W, 3] float32 RGB in [0, 1], resident on the GPU.
img = torch.from_numpy(img_np).to(device)
height, width, _ = img.shape

print(f"Loaded image: {image_path} ({width}x{height})")

config = json.loads((ROOT / "data" / "config_hash.json").read_text(encoding="utf-8"))
model = tcnn.NetworkWithInputEncoding(
    n_input_dims=2,   # (x, y) UV coordinate
    n_output_dims=3,  # RGB color
    encoding_config=config["encoding"],
    network_config=config["network"],
).to(device)
# Opt into JIT fusion only when tcnn reports that it is supported.
model.jit_fusion = tcnn.supports_jit_fusion()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

Loaded image: /media/tunguz/3139-3535/tiny-cuda-nn/data/images/albert.jpg (3250x4333)


In [3]:
def sample_image(coords: torch.Tensor, image: "torch.Tensor | None" = None) -> torch.Tensor:
    """Bilinearly sample colors from an image at normalized coordinates.

    Args:
        coords: ``[N, 2]`` tensor with columns ``(x, y)`` in ``[0, 1]``.
            ``(0, 0)`` maps to the top-left pixel and ``(1, 1)`` to the
            bottom-right pixel. Values outside ``[0, 1]`` are clamped to the
            image border.
        image: optional ``[H, W, C]`` tensor to sample from. When omitted,
            the global ``img`` is used.

    Returns:
        ``[N, C]`` tensor of interpolated colors.
    """
    if image is None:
        image = img
    h, w = image.shape[0], image.shape[1]

    # Clamp first: clamping only the integer indices would leave the
    # fractional weights negative (or > 1) for out-of-range coords, which
    # extrapolates colors beyond the border instead of holding them.
    coords = coords.clamp(0.0, 1.0)
    xs = coords[:, 0] * (w - 1)
    ys = coords[:, 1] * (h - 1)

    x0 = torch.floor(xs).long().clamp(0, w - 1)
    y0 = torch.floor(ys).long().clamp(0, h - 1)
    # At the right/bottom edge the neighbor is the pixel itself (weight 0).
    x1 = (x0 + 1).clamp(max=w - 1)
    y1 = (y0 + 1).clamp(max=h - 1)

    wx = (xs - x0.float()).unsqueeze(1)
    wy = (ys - y0.float()).unsqueeze(1)

    c00 = image[y0, x0]
    c10 = image[y0, x1]
    c01 = image[y1, x0]
    c11 = image[y1, x1]

    return (
        c00 * (1.0 - wx) * (1.0 - wy)
        + c10 * wx * (1.0 - wy)
        + c01 * (1.0 - wx) * wy
        + c11 * wx * wy
    )


def evaluate_full_image_mse(*, net=None, image=None, rows_per_chunk: int = 256) -> float:
    """Return the per-pixel MSE of the network's reconstruction of the whole image.

    Each pixel is queried at its normalized ``(x, y)`` coordinate in
    ``[0, 1]``, using the same convention as ``sample_image``. The prediction
    is clamped to ``[0, 1]`` and compared against the image.

    Args:
        net: callable mapping ``[N, 2]`` coords to ``[N, C]`` colors.
            Defaults to the global ``model``.
        image: ``[H, W, C]`` target tensor. Defaults to the global ``img``.
        rows_per_chunk: image rows evaluated per forward pass. Chunking bounds
            peak memory instead of querying all ``H * W`` pixels at once.

    Returns:
        Mean squared error over all ``H * W * C`` elements.

    Raises:
        ValueError: if ``rows_per_chunk`` is less than 1.
    """
    if rows_per_chunk < 1:
        raise ValueError(f"rows_per_chunk must be >= 1, got {rows_per_chunk}")
    net = model if net is None else net
    image = img if image is None else image
    h, w, channels = image.shape

    with torch.no_grad():
        xs = torch.linspace(0.0, 1.0, w, device=image.device)
        ys = torch.linspace(0.0, 1.0, h, device=image.device)
        # Accumulate in float64 so summing millions of squared errors
        # does not lose precision.
        sq_err_sum = torch.zeros((), dtype=torch.float64, device=image.device)
        for r0 in range(0, h, rows_per_chunk):
            r1 = min(r0 + rows_per_chunk, h)
            yy, xx = torch.meshgrid(ys[r0:r1], xs, indexing="ij")
            coords = torch.stack([xx.reshape(-1), yy.reshape(-1)], dim=1)
            pred = net(coords).reshape(r1 - r0, w, channels).clamp(0.0, 1.0)
            sq_err_sum += torch.sum((pred - image[r0:r1]) ** 2, dtype=torch.float64)
        return (sq_err_sum / image.numel()).item()

In [4]:
start_mse = evaluate_full_image_mse()
print(f"full-image MSE before training: {start_mse:.6f}")

steps = 120
batch_size = 16384
log_every = 20  # log step 0, every `log_every`-th step, and the final step
for step in range(steps):
    # Fresh uniform UV samples every step; targets come from bilinear lookup.
    coords = torch.rand((batch_size, 2), device=device)
    target = sample_image(coords)
    pred = model(coords)
    # NOTE(review): tcnn output may be half precision; the subtraction with
    # the float32 target promotes to float32 before the loss is reduced.
    loss = torch.mean((pred - target) ** 2)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    # Derived from `steps` instead of a hard-coded set, so changing the step
    # count keeps the log schedule (and the final step) in sync.
    if step == 0 or (step + 1) % log_every == 0 or step == steps - 1:
        print(f"step={step:03d} mse={loss.item():.6f}")

end_mse = evaluate_full_image_mse()
# PSNR for signals in [0, 1]; floor the MSE so a perfect fit avoids log10(0).
psnr = -10.0 * math.log10(max(end_mse, 1e-12))
print(f"full-image MSE after training:  {end_mse:.6f}")
print(f"final PSNR: {psnr:.2f} dB")

# Explicit raise rather than `assert`, so the check still runs under `python -O`.
if not end_mse < start_mse:
    raise AssertionError("Image training did not improve reconstruction error.")

Failed to JIT-compile `inference_mp_network_with_input_encoding`. Disabling JIT.


full-image MSE before training: 0.222253
step=000 mse=0.221711
step=019 mse=0.010469
step=039 mse=0.004558
step=059 mse=0.003020
step=079 mse=0.002450
step=099 mse=0.001875
step=119 mse=0.001596
full-image MSE after training:  0.001681
final PSNR: 27.74 dB
