# NeRF Test: render time and PSNR of a trained LLFF scene under different sample counts

In [1]:
import os
import sys

# Repository root = parent directory of the notebook's working directory.
__basedir__ = os.path.dirname(os.path.realpath('.'))
print(__basedir__)

# Put the root at the front of the import path (only once) so the project
# modules imported in the next cell can be found.
if __basedir__ not in sys.path:
    sys.path.insert(0, __basedir__)

# Set before TensorFlow is imported in the next cell: allocate GPU memory on
# demand rather than reserving it all up front.
os.environ.update(TF_FORCE_GPU_ALLOW_GROWTH='true')

d:\NeRF\nerf-lighting


In [2]:
import tensorflow as tf
import numpy as np
import imageio
import json
import random
import time
import cv2

from run_nerf_helpers import *
from nerf_renderer import *
from misc_helpers import *
from load_llff import load_llff_data

import matplotlib.pyplot as plt

In [3]:
# Selects which trained run is evaluated: 1 -> rubiks/cropped (model_500000
# checkpoint), any other value -> rubiks2 (model_350000 checkpoint).
# The misspelled name is kept because later cells reference it.
HEIRARCHICAL_SAMPLING_METHOD = 2

In [8]:
def load_data(args):
    """Load the LLFF scene named by ``args.datadir`` and derive render bounds.

    Args:
        args: config object; this function reads ``datadir``, ``factor``,
            ``spherify`` and ``no_ndc``.

    Returns:
        Tuple ``(images, poses, H, W, focal, near, far, bds, render_poses)``.
        ``poses`` is cut down to its ``[:, :3, :4]`` part, ``H``/``W`` are ints
        and ``focal`` is read from the first pose. ``near``/``far`` are 0 and 1
        unless ``args.no_ndc`` is set, in which case they come from ``bds``.
    """
    scene_dir = os.path.normcase(os.path.join(__basedir__, args.datadir))
    images, poses, bds, render_poses, _holdout = load_llff_data(
        scene_dir, args.factor, recenter=True, bd_factor=.75, spherify=args.spherify)

    # Height, width and focal length are stored in the last column of the first pose.
    hwf = poses[0, :3, -1]
    H, W, focal = int(hwf[0]), int(hwf[1]), hwf[2]

    poses = poses[:, :3, :4]
    print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)

    print('DEFINING BOUNDS')
    if not args.no_ndc:
        # Fixed unit bounds (presumably for NDC-space rays).
        near, far = 0., 1.
    else:
        # Pull the nearest scene bound in by 10%; use the farthest bound as-is.
        near = tf.reduce_min(bds) * .9
        far = tf.reduce_max(bds) * 1.
    print('NEAR FAR', near, far)

    return images, poses, H, W, focal, near, far, bds, render_poses

In [9]:
# Pick the training run whose config is evaluated (weights are loaded in a later cell).
if HEIRARCHICAL_SAMPLING_METHOD == 1:
    args_file = '../data/rubiks/cropped/logs/rubiks/args.txt'
else:
    args_file = '../data/rubiks2/logs/rubiks2/args.txt'
args = ConfigReader(args_file)

# Input embedders for points and, when enabled, viewing directions.
embed_fn, input_ch = get_embedder(args.multires, args.i_embed)

input_ch_views = 0
embeddirs_fn = None
if args.use_viewdirs:
    embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)

output_ch = 4
skips = [4]
model = init_nerf_model(
    D=args.netdepth, W=args.netwidth,
    input_ch=input_ch, output_ch=output_ch, skips=skips,
    input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs)

# Always built: the weight-loading cell calls model_fine.set_weights unconditionally.
model_fine = init_nerf_model(
    D=args.netdepth_fine, W=args.netwidth_fine,
    input_ch=input_ch, output_ch=output_ch, skips=skips,
    input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs)


def network_query_fn(inputs, viewdirs, network_fn):
    """Evaluate ``network_fn`` via ``run_network`` with this run's embedders and chunk size."""
    return run_network(
        inputs, viewdirs, network_fn,
        embed_fn=embed_fn,
        embeddirs_fn=embeddirs_fn,
        netchunk=args.netchunk)


render_kwargs = {
    'network_query_fn': network_query_fn,
    'perturb': args.perturb,
    'N_importance': args.N_importance,
    'network_fine': model_fine,
    'N_samples': args.N_samples,
    'network_fn': model,
    'use_viewdirs': args.use_viewdirs,
    'white_bkgd': args.white_bkgd,
    'raw_noise_std': args.raw_noise_std,
    'chunk': args.chunk,
    # Must agree with load_data(), which returns fixed [0, 1] bounds unless
    # args.no_ndc is set; hard-coding False broke runs trained with NDC.
    'ndc': not args.no_ndc,
    'lindisp': args.lindisp
}

print(render_kwargs)

MODEL_NERF 63 27 <class 'int'> <class 'int'> True
(None, 90) (None, 63) (None, 27)
MODEL_NERF 63 27 <class 'int'> <class 'int'> True
(None, 90) (None, 63) (None, 27)
{'network_query_fn': <function network_query_fn at 0x00000263363FAAF8>, 'perturb': 1.0, 'N_importance': 128, 'network_fine': <tensorflow.python.keras.engine.functional.Functional object at 0x000002634DBCEFC8>, 'N_samples': 64, 'network_fn': <tensorflow.python.keras.engine.functional.Functional object at 0x000002634DB9F708>, 'use_viewdirs': True, 'white_bkgd': False, 'raw_noise_std': 1.0, 'chunk': 32768, 'ndc': False, 'lindisp': False}


In [10]:
# Checkpoint files for the run selected by HEIRARCHICAL_SAMPLING_METHOD.
if HEIRARCHICAL_SAMPLING_METHOD == 1:
    model_coarse_weights_file = '../data/rubiks/cropped/logs/rubiks/model_500000.npy'
    model_fine_weights_file = '../data/rubiks/cropped/logs/rubiks/model_fine_500000.npy'
else:
    model_coarse_weights_file = '../data/rubiks2/logs/rubiks2/model_350000.npy'
    model_fine_weights_file = '../data/rubiks2/logs/rubiks2/model_fine_350000.npy'

# allow_pickle=True is needed for these files (presumably object arrays of
# weight tensors); only load checkpoints produced by our own training runs.
model.set_weights(np.load(model_coarse_weights_file, allow_pickle=True))
model_fine.set_weights(np.load(model_fine_weights_file, allow_pickle=True))

In [11]:
# Load the scene named by args.datadir; near/far depend on args.no_ndc (see load_data).
images, poses, H, W, focal, near, far, bds, render_poses = load_data(args)

Loaded image data (500, 500, 3, 132) [ 500.         500.        1098.8775858]
Loaded d:\nerf\nerf-lighting\.\data\rubiks2\ 2.2962166918186226 11.288918670814095
Data:
(132, 3, 5) (132, 500, 500, 3) (132, 2)
HOLDOUT view is 128
Loaded llff (132, 500, 500, 3) (120, 3, 5) [ 500.      500.     1098.8776] ./data/rubiks2/
DEFINING BOUNDS
NEAR FAR tf.Tensor(0.46544135, shape=(), dtype=float32) tf.Tensor(2.5425055, shape=(), dtype=float32)


In [14]:
# Sweep coarse (N_samples) and fine (N_importance) sample counts on a single
# view, recording render time and PSNR against the ground-truth image.
i = 43  # view index; later cells reuse `i`, `rgb` and `acc` from this sweep
N_SAMPLES_SWEEP = [16]
N_IMPORTANCE_SWEEP = [8]
SWEEP_OUT_DIR = '../outputs/ns_ni_test2'
os.makedirs(SWEEP_OUT_DIR, exist_ok=True)  # plt.imsave / np.save do not create directories

results = []

for ns in N_SAMPLES_SWEEP:
    for ni in N_IMPORTANCE_SWEEP:
        print(ns, ni)
        # Copy rather than mutate render_kwargs, so cells run after the sweep
        # still render with the config's sample counts.
        sweep_kwargs = {**render_kwargs, 'N_samples': ns, 'N_importance': ni}

        start = time.time()
        rgb, disp, acc, _ = render(H, W, focal, near=near, far=far, c2w=poses[i, :3, :4],
                                   retraw=False, **sweep_kwargs)
        elapsed = time.time() - start
        psnr = mse2psnr(img2mse(images[i], rgb))
        _psnr = psnr.numpy().item()
        # Row layout: [N_samples, N_importance, render seconds, PSNR]
        results.append([ns, ni, elapsed, _psnr])

        plt.imsave(f'{SWEEP_OUT_DIR}/{ns}_{ni}_{round(elapsed, 3)}_{round(_psnr, 3)}.png', rgb.numpy())

results = np.array(results, dtype=np.float32)
np.save(f'{SWEEP_OUT_DIR}/results.npy', results)

print('FINISHED')

16 8
FINISHED


In [None]:
# One row per sweep run: [N_samples, N_importance, render seconds, PSNR].
results

In [None]:
# Scratch check of NumPy rounding; unrelated to the analysis and safe to delete.
np.round(2.22222, 3)

In [None]:
# PSNR vs. N_importance, one line per N_samples value actually present in
# `results` (the previous hard-coded slices/labels assumed a 64/32/16 x 5 sweep
# and mislabelled any other sweep).
fig, ax = plt.subplots()
for ns in np.unique(results[:, 0])[::-1]:  # descending, matching the old 64/32/16 legend order
    rows = results[results[:, 0] == ns]
    rows = rows[np.argsort(rows[:, 1])]  # connect points in N_importance order
    ax.plot(rows[:, 1], rows[:, 3], label=f'n_samples = {int(ns)}')
ax.set(xlabel='N_importance', ylabel='PSNR', title='PSNR vs. fine sample count')
ax.legend();

In [None]:
# Render time over the (N_samples, N_importance) grid; colour also encodes time.
x, y, z = results[:, 0], results[:, 1], results[:, 2]
fig = plt.figure(figsize=(6, 6))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(x, y, z,
           linewidths=1, alpha=.7,
           edgecolor='k',
           s=200,
           c=z)
# Label the axes so the figure can be read without the sweep code.
ax.set_xlabel('N_samples')
ax.set_ylabel('N_importance')
ax.set_zlabel('render time (s)')
plt.show()

In [None]:
# Add a trailing channel axis to `acc` (from the last render) so it can be
# plotted next to the RGB images.
acc_ = acc.numpy()[..., np.newaxis]
acc_.dtype  # displayed to check the dtype handed to plot_images

In [None]:
# Ground truth, rendering and acc map for view `i` side by side
# (NOTE(review): 1, 3 presumably means rows, cols — confirm in plot_images).
plot_images([images[i], rgb, acc_], 1, 3)

In [None]:
# Save ground truth, rendering and acc map for view `i` from the sweep.
ACC_OUT_DIR = '../outputs/acc_test'
os.makedirs(ACC_OUT_DIR, exist_ok=True)  # plt.imsave does not create missing directories
plt.imsave(f'{ACC_OUT_DIR}/{i}_original.png', images[i])
plt.imsave(f'{ACC_OUT_DIR}/{i}_rgb.png', rgb.numpy())
plt.imsave(f'{ACC_OUT_DIR}/{i}_acc.png', acc.numpy(), cmap='gray')

In [None]:
def get_psnr(images, poses, near, far, kwargs=None):
    """Render every pose and score it against its ground-truth image.

    Uses the notebook globals ``H``, ``W``, ``focal`` and, unless ``kwargs``
    is given, ``render_kwargs``.

    Args:
        images: sequence of ground-truth images, one per pose.
        poses: camera poses, each indexable as ``pose[:3, :4]``; same length
            as ``images``.
        near, far: ray bounds passed through to ``render``.
        kwargs: optional render keyword arguments; defaults to the global
            ``render_kwargs`` so existing calls behave as before.

    Returns:
        ``(rgbs, psnrs)``: lists holding the rendered image and its PSNR for
        each view, in input order.

    Raises:
        ValueError: if ``images`` and ``poses`` differ in length.
    """
    if kwargs is None:
        kwargs = render_kwargs
    if len(images) != len(poses):
        raise ValueError(f'got {len(images)} images but {len(poses)} poses')

    rgbs = []
    psnrs = []
    for k, (target, pose) in enumerate(zip(images, poses)):
        print('Rendering ', k + 1)
        rgb, _, _, _ = render(H, W, focal, near=near, far=far, c2w=pose[:3, :4],
                              retraw=False, **kwargs)
        rgbs.append(rgb)
        psnrs.append(mse2psnr(img2mse(rgb, target)))
        # No `del rgb`: the tensor stays referenced from `rgbs`, so deleting the
        # local name would free nothing.

    return rgbs, psnrs

In [None]:
# Every 8th view as an evaluation subset (the get_psnr call below renders all
# views, not this subset).
test_images = images[::8]
test_poses = poses[::8]
test_images.shape, test_poses.shape

In [None]:
# Renders every loaded view with the current render_kwargs — one full render per image, so slow.
rgbs, psnrs = get_psnr(images, poses, near, far)

In [None]:
# NOTE(review): with a 1 x 1 grid this presumably shows only the first render — confirm in plot_images.
plot_images(rgbs, 1, 1)

In [None]:
# Per-view PSNR tensors, in the same order as images/poses.
psnrs

In [None]:
# Save ground-truth view 82 for reference (assumes ../outputs already exists).
plt.imsave('../outputs/res_sampling_test_img82.png', images[82])

In [None]:
def mean_psnr(x):
    """Return the mean of a sequence of PSNR values as a NumPy scalar."""
    return tf.reduce_mean(x).numpy()

In [None]:
# Grid of all renders: 14 x 10 = 140 slots, enough for the 132 views loaded above
# (presumably rows, cols — confirm in plot_images).
plot_images(rgbs, 14, 10, scale=1.0)

In [None]:
# Renders from index 128 onward (4 views when 132 are loaded) in a 2 x 2 grid.
plot_images(rgbs[128:], 2, 2)

In [None]:
# Mean PSNR over all rendered views.
mean_psnr(psnrs)

In [None]:
# Mean PSNR over every 8th view (the test stride).
mean_psnr(psnrs[::8])

In [None]:
n_views = images.shape[0]
# Every 8th view is held out for testing ...
i_test = np.arange(n_views)[::8]
# ... and all remaining views count as training views.
i_train = np.setdiff1d(np.arange(n_views), i_test)
train_psnrs = [psnrs[k] for k in i_train]

In [None]:
# Mean PSNR over the training views (every view not in i_test). This cell
# previously repeated the test-view mean and never used train_psnrs.
mean_psnr(train_psnrs)