In [None]:
import os
import torch

import numpy as np
import pprint

import matplotlib.pyplot as plt

import run_nerf
import run_nerf_helpers

### Load trained network weights

In [None]:
basedir = './logs'
run = 'ship'  # which pretrained run to load: 'ship' or 'lego'

if run == 'ship':
    expname = 'blender_paper_ship'
    ckpt = '050000.tar'
    dataset = 'ship' # used when config file of pretrained model not provided

elif run == 'lego':
    expname = 'lego_test'
    ckpt = '200000.tar'
    dataset = 'lego'

else:
    # Fail early with a clear message instead of a NameError on `expname` below.
    raise ValueError("unknown run {!r}; expected 'ship' or 'lego'".format(run))


# Only make CUDA tensors the default when a GPU is present: the device selection
# below falls back to "cpu", and torch.cuda.FloatTensor is unavailable on
# CPU-only installs.
if torch.cuda.is_available():
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

# Prefer the config saved next to the experiment; fall back to the repository's
# per-dataset config when the pretrained model did not ship one.
config = os.path.join(basedir, expname, 'config.txt')
if not os.path.isfile(config):
    config = os.path.join('./configs', dataset+'.txt')

print('Args:')
with open(config, 'r') as config_file:  # context manager closes the handle
    print(config_file.read())


parser = run_nerf.config_parser()
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pass argv as a list of tokens rather than one formatted string: a string is
# split on whitespace, which breaks config/checkpoint paths containing spaces.
ckpt_path = os.path.join(basedir, expname, ckpt)
args = parser.parse_args([
    '--config', config,
    '--ft_path', ckpt_path,
    '--device', device,
])

# Create the NeRF model, loading weights from `--ft_path`.
_, render_kwargs_test, start, grad_vars, models = run_nerf.create_nerf(args)

# Function used below to query the network on arbitrary 3D points.
net_fn = render_kwargs_test['network_query_fn']

# Render a low-resolution view from a camera at (0, 0, 4) to check that the
# model was loaded correctly.
c2w = torch.eye(4, dtype=torch.float32)  # camera-to-world pose, identity rotation
c2w[2,-1] = 4.  # move the camera to z = 4

near = 2.
far = 6.
bds_dict = {
    'near' : near,
    'far' : far,
}
render_kwargs_test.update(bds_dict)

H, W, focal = 800, 800, 1200.
down = 8  # render at 1/down of the full resolution for a quick preview
H, W, focal = H//down, W//down, focal/down

# Pinhole intrinsics with the principal point at the image centre.
K = np.array([
    [focal, 0, 0.5*W],
    [0, focal, 0.5*H],
    [0, 0, 1]
])

with torch.no_grad():
    preview_rgb, _, _, _ = run_nerf.render(H, W, K, chunk=args.chunk, c2w=c2w[:3,:4], **render_kwargs_test)
# Named preview_* so this image is never confused with the per-vertex `rgb`
# colours computed later in the notebook.
preview_img = np.clip(preview_rgb.cpu().numpy(), 0, 1)

fig, ax = plt.subplots()
ax.imshow(preview_img)
ax.set_title('Preview render ({}x{}, near={}, far={})'.format(W, H, near, far))
ax.axis('off')
plt.show()


### Query the network on a dense 3D grid of points

In [None]:
N = 256      # grid cells per axis (N+1 samples per axis)
bound = 1.2  # query the cube [-bound, bound]^3
t = np.linspace(-1*bound, bound, N+1)

# np.meshgrid defaults to indexing='xy': with three inputs, grid axis 0 follows
# the *second* coordinate (y) and axis 1 the first (x), i.e.
# query_pts[i, j, k] == (t[j], t[i], t[k]). Grid-index coordinates derived from
# `sigma` (e.g. marching-cubes vertices) therefore have x and y swapped.
query_pts = np.stack(np.meshgrid(t, t, t), -1).astype(np.float32)
print(query_pts.shape)
sh = query_pts.shape
flat = torch.tensor(query_pts.reshape([-1,3]))

# Query the fine network in chunks to bound memory; zero view directions are
# passed because only density is needed here.
fn = lambda i0, i1 : net_fn(flat[i0:i1,None,:], viewdirs=torch.zeros_like(flat[i0:i1]), network_fn=render_kwargs_test['network_fine'])
chunk = 1024*64
with torch.no_grad():
    raw = torch.cat([fn(i, i+chunk) for i in range(0, flat.shape[0], chunk)], 0)
raw = torch.reshape(raw, list(sh[:-1]) + [-1])
raw = raw.cpu().numpy()
# Last output channel is used as density; clamp negatives to zero.
sigma = np.maximum(raw[...,-1], 0.)

print(raw.shape)
fig, ax = plt.subplots()
ax.hist(sigma.ravel(), log=True)  # sigma is already non-negative
ax.set(title='Density (sigma) on the query grid', xlabel='sigma', ylabel='count (log scale)')
plt.show()

### Marching cubes with [PyMCubes](https://github.com/pmneila/PyMCubes)
Change `threshold` to extract the isosurface at a different density (sigma) level.

In [None]:
import mcubes

# Density level of the extracted isosurface; higher values keep only denser
# regions (smaller occupied fraction).
threshold = 50.
print('fraction occupied', (sigma > threshold).mean())

vertices, triangles = mcubes.marching_cubes(sigma, threshold)
print('done', vertices.shape, triangles.shape)

# Vertices come back in grid-index units [0, N]; map them to world
# coordinates [-bound, bound] (same spacing as the linspace grid).
vertices = (2 * bound * vertices) / N - bound

### Read vertex colors from NeRF

In [None]:

# The density grid was built with np.meshgrid's default indexing='xy', so grid
# axis 0 is y and axis 1 is x. Swap the first two vertex coordinates to put the
# mesh back into world (x, y, z) order before querying colours.
# NOTE: this is not idempotent. Re-run the marching-cubes cell before re-running
# this one, otherwise the axes get swapped back.
vertices = vertices[...,[1,0,2]]
flat = torch.tensor(vertices, dtype=torch.float32)
if flat.shape[0] == 0:
    raise ValueError('marching cubes produced no vertices; try a lower `threshold`')

# Same chunked fine-network query as the density grid (reuses `chunk`).
# Zero view directions give a single view-independent colour per vertex.
fn = lambda i0, i1 : net_fn(flat[i0:i1,None,:], viewdirs=torch.zeros_like(flat[i0:i1]), network_fn=render_kwargs_test['network_fine'])
with torch.no_grad():
    raw = torch.cat([fn(i, i+chunk) for i in range(0, flat.shape[0], chunk)], 0)
raw = torch.reshape(raw, [flat.shape[0], -1])
# Sigmoid of the first three channels gives RGB in [0, 1]. torch.sigmoid is
# equivalent to 1/(1+exp(-x)) but does not raise overflow warnings for large
# negative logits.
rgb = torch.sigmoid(raw[...,:3]).cpu().numpy()
raw = raw.cpu().numpy()
print(raw.shape)
print(rgb.shape)

### Live preview with [trimesh](https://github.com/mikedh/trimesh)


In [None]:
import trimesh

# Set to False to view the bare geometry without NeRF vertex colours.
plot_texture = True

# vertex_colors=None is trimesh's default, i.e. an uncoloured mesh.
mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=rgb if plot_texture else None)
mesh.show()

### Export mesh

In [None]:
# Write the mesh as <dataset>.obj in the working directory; the header string
# is passed through to the OBJ exporter.
obj_path = dataset + '.obj'
_ = mesh.export(obj_path, header='https://github.com/sayoriaaa/nerf-pytorch')

### Import mesh

In [None]:
# Reload the file written by the export cell. Reuse the `dataset` chosen by
# `run` at the top of the notebook; hard-coding a name here would overwrite it
# and load the wrong mesh for any run other than 'ship'.
mesh = trimesh.load(dataset+'.obj')
mesh.show()