In [1]:
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch import nn
import numpy as np

import dataset
import rendering
import model
import train

import mcubes
import trimesh
from model import NeRFLightning
from config import (
    DATA_DIR,
    IMG_SIZE,
    LEARNING_RATE,
    TN,
    TF,
    NB_BINS,
    GAMMA,
    ACCELERATOR,
    DEVICES,
    PRECISION,
    MAX_EPOCHS,
    SCALE,
)

In [None]:
# Load the trained NeRF and sample its density field on a regular N x N x N grid
# spanning [-SCALE, SCALE] along each axis.
CKPT_DIR = "models/epoch=16-step=83670.ckpt"  # path to a single checkpoint file
N = 100  # grid resolution per axis (N**3 query points in total)
CHUNK_SIZE = 65_536  # points per forward pass; lower this if the GPU runs out of memory

device = "cuda" if torch.cuda.is_available() else "cpu"
# map_location lets a GPU-trained checkpoint load on a CPU-only machine, and
# .to(device) guarantees the weights live on the same device as the query points.
lit_nerf = NeRFLightning.load_from_checkpoint(CKPT_DIR, map_location=device).to(device)
lit_nerf.eval()

axis = torch.linspace(-SCALE, SCALE, N)
# indexing="ij" (the historical default, now explicit to silence the warning) makes
# the first grid vary along dim 0, so density[i, j, k] <-> (axis[i], axis[j], axis[k]).
grid_x, grid_y, grid_z = torch.meshgrid(axis, axis, axis, indexing="ij")
xyz = torch.stack((grid_x, grid_y, grid_z), dim=-1).reshape(-1, 3)

density_chunks = []
with torch.inference_mode():
    # Query in chunks so N**3 points never have to fit through the MLP at once.
    for chunk in torch.split(xyz, CHUNK_SIZE):
        points = chunk.to(device)
        # An all-zero viewing direction is passed and the first output is discarded;
        # NOTE(review): assumes the density output ignores the direction input — confirm in model.py.
        _, chunk_density = lit_nerf(points, torch.zeros_like(points))
        density_chunks.append(chunk_density.cpu())
density = torch.cat(density_chunks).numpy().reshape(N, N, N)

In [None]:
# Extract an iso-surface from the sampled density grid with marching cubes.
ISO_FACTOR = 1.0  # iso-level expressed as a multiple of the mean sampled density
iso_level = ISO_FACTOR * np.mean(density)
vertices, triangles = mcubes.marching_cubes(density, iso_level)

# marching_cubes returns vertices in grid-index units (0 .. N-1 along each axis).
# Map them back onto the [-SCALE, SCALE] range the grid was sampled from:
# index 0 -> -SCALE, index N-1 -> +SCALE. (Dividing by N instead would give
# [0, (N-1)/N] and lose the scene's scale and offset.)
voxel_size = 2 * SCALE / (N - 1)
mesh = trimesh.Trimesh(vertices * voxel_size - SCALE, triangles)
mesh.show()