In [28]:
%matplotlib inline
from run_dnerf import config_parser, create_nerf
import matplotlib.pyplot as plt
import torch
from load_blender import pose_spherical
from run_dnerf import render_path
from run_dnerf_helpers import to8b
import numpy as np
import mcubes
import trimesh

In [29]:
# set cuda
# Only make CUDA the default tensor type when a GPU is actually present, so the
# default type and `device` always agree (the original forced CUDA tensors even
# on the CPU fallback path).
use_cuda = torch.cuda.is_available()
if use_cuda:
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
device = torch.device("cuda" if use_cuda else "cpu")

# get config file
config_file = "configs/mutant.txt"
parser = config_parser()
# Pass an argv-style list rather than one string: a list is accepted by every
# argparse-compatible parser and keeps paths containing spaces intact.
args = parser.parse_args(['--config', config_file])

# set render params
# NOTE(review): hwf is presumably [height, width, focal length] -- confirm against render_path.
hwf = [400, 400, 555.555]
# create_nerf returns five values; only the second (test-time render kwargs) is used here.
_, render_kwargs_test, _, _, _ = create_nerf(args)
# Override the near/far bounds used when rendering.
render_kwargs_test.update({'near' : 2., 'far' : 6.})

NeRF type selected: direct_temporal
Found ckpts ['./logs/mutant/800000.tar']
Reloading from ./logs/mutant/800000.tar


## Generate Debug Frame

In [30]:
def generate_img(time, azimuth, elevation):
    """Render one frame for the given time and camera angles.

    Args:
        time: value in [0, 1].
        azimuth: angle in [-180, 180], forwarded to ``pose_spherical``.
        elevation: angle in [-180, 180], forwarded to ``pose_spherical``.

    Returns:
        The first (and only) frame produced by ``render_path`` after it has
        been passed through ``to8b``.

    Raises:
        AssertionError: if any argument lies outside its allowed range.
    """
    assert 0. <= time <= 1.
    assert -180 <= azimuth <= 180
    assert -180 <= elevation <= 180

    # Build a single pose and give it a leading axis of size one so that
    # render_path receives a batch containing exactly one pose.
    # NOTE(review): the 4.0 passed to pose_spherical is presumably the camera distance -- confirm.
    single_pose = pose_spherical(azimuth, elevation, 4.0)
    render_poses = single_pose.unsqueeze(0).to(device)
    render_times = torch.Tensor([time]).to(device)

    # Pure inference: no gradients are needed while rendering.
    with torch.no_grad():
        frames, _ = render_path(
            render_poses,
            render_times,
            hwf,
            args.chunk,
            render_kwargs_test,
            render_factor=args.render_factor,
        )

    return to8b(frames)[0]

In [32]:
# generate
# Time and camera settings for the single debug frame rendered below.
# NOTE: `time` is also read later by the reconstruction cell (`reconstruct(pts, time)`),
# so changing it here changes which time step gets meshed.
time = .5  # in [0,1]
azimuth = -180  # in [-180,180]
elevation = -20  # in [-180,180]
img = generate_img(time, azimuth, elevation)

# display
# Draw the rendered frame on figure number 2 with a 20x6 inch canvas.
plt.figure(2, figsize=(20,6))
plt.imshow(img)
plt.show()

  0%|          | 0/1 [00:00<?, ?it/s]


RuntimeError: CUDA out of memory. Tried to allocate 720.00 MiB (GPU 0; 10.92 GiB total capacity; 1.58 GiB already allocated; 170.31 MiB free; 2.44 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

## Reconstruct

In [26]:
# Regular lattice of (N+1)^3 sample positions spanning [min_rec, max_rec] on every axis.
N = 256
min_rec, max_rec = -1.2, 1.2
t = np.linspace(min_rec, max_rec, N+1)

# indexing='ij' guarantees query_pts[i, j, k] == (t[i], t[j], t[k]).
# numpy's default indexing='xy' swaps the first two axes, which would make the
# grid-index -> coordinate mapping applied to the marching-cubes vertices later
# in the notebook exchange x and y (i.e. produce a mirrored mesh).
# The lattice is built directly in float32: casting `t` first yields the same
# values as casting the stacked grid afterwards, at half the peak memory.
t32 = t.astype(np.float32)
query_pts = np.stack(np.meshgrid(t32, t32, t32, indexing='ij'), -1)
sh = query_pts.shape  # (N+1, N+1, N+1, 3); read by reconstruct() to restore the grid layout
pts = torch.Tensor(query_pts.reshape([-1,3]))  # flat (num_points, 3) list of query positions


def batchify(fn, chunk):
    """Wrap ``fn`` so it is applied to its input in slices of ``chunk`` rows.

    Args:
        fn: callable taking a tensor and returning a tensor.
        chunk: number of rows (along dimension 0) per call, or ``None`` to
            disable batching.

    Returns:
        ``fn`` itself when ``chunk`` is ``None``; otherwise a function that
        calls ``fn`` on consecutive row slices and concatenates the results
        along dimension 0.
    """
    if chunk is None:
        return fn

    def ret(inputs):
        num_rows = inputs.shape[0]
        # With no rows there would be nothing to concatenate (torch.cat rejects
        # an empty list), so let fn handle the empty tensor directly.
        if num_rows == 0:
            return fn(inputs)
        pieces = [fn(inputs[start:start + chunk]) for start in range(0, num_rows, chunk)]
        # torch.cat is the canonical name; torch.concat is only an alias in newer releases.
        return torch.cat(pieces, 0)

    return ret

def reconstruct(points, time, chunk=1024*64):
    """Query the network at ``points`` for one ``time`` and return a clamped value grid.

    Args:
        points: 2-D tensor of query positions, one row per point. Its row count
            must match the number of cells in the notebook-level grid shape
            ``sh`` (all but the last entry), which is used to reshape the result.
        time: time value attached to every point.
        chunk: number of points sent to the network per call. Lower it to
            reduce peak memory; the default matches the previously hard-coded
            value.

    Returns:
        NumPy array shaped ``sh[:-1]`` holding the last output channel of the
        network, clamped below at zero.
        NOTE(review): that channel is presumably the volume density -- confirm
        against the network definition.
    """
    net_fn = render_kwargs_test['network_query_fn']
    network = render_kwargs_test['network_fn']

    def query(batch):
        # View directions are all zeros; the time input is a single column
        # filled with `time`, one row per point.
        result = net_fn(
            batch[:, None, :],
            viewdirs=torch.zeros_like(batch),
            ts=torch.ones_like(batch)[:, 0:1] * time,
            network_fn=network,
        )
        # Only the first element of the query result is used; move it to host memory.
        return result[0].cpu().numpy()

    raw = np.concatenate(
        [query(points[start:start + chunk]) for start in range(0, points.shape[0], chunk)], 0
    )
    # Restore the grid layout, leaving the channel axis last.
    raw = np.reshape(raw, list(sh[:-1]) + [-1])
    sigma = np.maximum(raw[..., -1], 0.)
    return sigma
    
# Evaluate the network on the whole grid with autograd disabled (inference only).
# `time` is the notebook-level value assigned in the "Generate Debug Frame" section.
with torch.no_grad():
    sigma = reconstruct(pts, time)

In [27]:
# reconstruct
# Iso-level for the surface; chosen by trial and error rather than derived.
threshold = 40
vertices, triangles = mcubes.marching_cubes(sigma, threshold)

# display
# Rescale the vertices from grid-index units to the [min_rec, max_rec] range
# sampled by the grid (assumes marching_cubes returns index-space coordinates).
grid_extent = max_rec - min_rec
world_vertices = vertices / N * grid_extent + min_rec
mesh = trimesh.smoothing.filter_laplacian(
    trimesh.Trimesh(world_vertices, triangles),
    iterations=3,
)
mesh.show()