# Testing

In [15]:
import torch
import numpy as np
import imageio
import matplotlib.pyplot as plt
import os
from torch.utils.data import DataLoader
from dataset import get_rays
from rendering import rendering
import torch.nn as nn
from ml_helpers import training
from model import Voxel, NeRF

In [16]:
test_o, test_d, target_px_values = get_rays('fruits', mode='test')

In [17]:
device = 'cuda'
model = torch.load('model_nerf_fruits')
tn = 8.
tf = 12.
H = 400
W = 400

In [18]:
def mse2psnr(mse):
    return 20 * np.log10(1 / np.sqrt(mse))

@torch.no_grad()
def test(model, o, d, tn, tf, nb_bins=100, chunk_size=10, target=None):
    o = o.chunk(chunk_size)
    d = d.chunk(chunk_size)

    image = []
    for o_batch, d_batch in zip(o, d):
        img_batch = rendering(model, o_batch, d_batch, tn, tf, nb_bins=nb_bins, device=o_batch.device)
        image.append(img_batch)

    image = torch.cat(image)
    image = image.reshape(H, W, 3).data.cpu().numpy()

    if target is not None:
        mse = ((image - target) ** 2).mean()
        psnr = mse2psnr(mse)
        return image, mse, psnr

    return image

In [19]:
import gc
torch.cuda.empty_cache()
gc.collect()

9

In [20]:
# TEST_IMG = 1
# img, mse, psnr = test(model, torch.from_numpy(test_o[TEST_IMG]).to(device).float(), torch.from_numpy(test_d[TEST_IMG]).to(device).float(),
#                     tn, tf, nb_bins=100, chunk_size=10, target=target_px_values[TEST_IMG].reshape(400, 400, 3))

In [21]:
# fix, ax = plt.subplots(1, 2, figsize=(10, 5))
# ax[0].imshow(img)
# ax[1].imshow(target_px_values[TEST_IMG].reshape(400, 400, 3))
# plt.show()

In [22]:
# mse, psnr

# Mesh Extraction

In [23]:
import torch
import numpy as np
import imageio
import matplotlib.pyplot as plt
import os
from torch.utils.data import DataLoader
from dataset import get_rays
from rendering import rendering
import torch.nn as nn
from ml_helpers import training
from model import Voxel, NeRF
import mcubes
import trimesh

In [24]:
device = 'cuda'
model = torch.load('model_nerf_fruits')
tn = 8.
tf = 12.
H = 720
W = 1280

In [25]:
N = 100
scale = 1.5

x = torch.linspace(-scale, scale, N)
y = torch.linspace(-scale, scale, N)
z = torch.linspace(-scale, scale, N)
x, y, z = torch.meshgrid((x, y, z))

xyz = torch.cat((x.reshape(-1, 1),
                y.reshape(-1, 1),
                z.reshape(-1, 1)), dim=1)

In [26]:
with torch.no_grad():
    _, density = model.forward(xyz.to(device), torch.zeros_like(xyz).to(device))

density = density.cpu().numpy().reshape(N, N, N)

In [27]:
vertices, traingles = mcubes.marching_cubes(density, 5 * np.mean(density))

In [28]:
mesh = trimesh.Trimesh(vertices, traingles)
mesh.show()