In [1]:
import torch
from utils import *
from collections import defaultdict
import matplotlib.pyplot as plt
import time

from models.rendering import *
from models.nerf import *

import metrics

from datasets import dataset_dict
from datasets.llff import *

# Let cuDNN benchmark and cache the fastest kernels for the input shapes it sees.
torch.backends.cudnn.benchmark = True

# Image size as (width, height); images are later reshaped with
# view(img_wh[1], img_wh[0], ...), i.e. (H, W, ...).
img_wh = 640, 480

# `dataset` holds the dataset *class* registered under 'llff_nocs', not an
# instance -- it is constructed with split/kwargs in the next cell.
# Previously tried alternatives (kept for reference):
#   dataset_dict['llff']('/home/ubuntu/data/nerf_example_data/nerf_llff_data/fern/',
#                        'test_train', spheric_poses=False, img_wh=img_wh)
#   dataset_dict['llff_nocs']('data/scene_1', img_wh=img_wh)
dataset = dataset_dict['llff_nocs']

print("dataset", dataset)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
dataset datasets.llff.LLFFDatasetNOCS


In [2]:
from torch.utils.data import DataLoader

# Constructor arguments for the dataset class selected in the previous cell.
kwargs = dict(
    root_dir='data/scene_1',
    img_wh=tuple(img_wh),
    spheric_poses=False,
    val_num=1,
)

# NOTE(review): despite the name, this is the 'val' split. Judging by the
# recorded output, each item is a whole image's worth of rays (H*W rows),
# so batch_size=1024 batches images, not rays.
train_dataset = dataset(split='val', **kwargs)

dataloader = DataLoader(
    train_dataset,
    batch_size=1024,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

self.focal 584.5808479748345
val image is data/scene_1/images/0023_color.png


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [None]:
import torch
import numpy as np

# Sanity check: torch.nonzero returns an (N, ndim) matrix of indices, while
# np.nonzero returns a tuple of ndim index arrays. Stacking the numpy tuple
# and transposing it should reproduce torch's layout exactly.
a = torch.tensor([[0.6, 0.0, 0.0, 0.0],
                  [0.0, 0.4, 0.0, 0.0],
                  [0.0, 0.0, 1.2, 0.0],
                  [0.0, 0.0, 0.0, -0.4]])
b = torch.nonzero(a)
print(b.shape)

c = np.transpose(np.asarray(np.nonzero(a.numpy())))
# Print c's own shape. The original transposed c a second time here, which
# reported (ndim, N) instead of the (N, ndim) array actually compared below.
print(c.shape)
print(np.nonzero(a.numpy()))

print(np.equal(b.numpy(), c))
# Single yes/no answer instead of eyeballing the elementwise grid above.
assert np.array_equal(b.numpy(), c), "torch.nonzero and np.nonzero layouts differ"

In [3]:
# Dump the key -> tensor shape of every batch, then the number of batches seen.
print("len train dataset", len(train_dataset))
n_batches = 0
for n_batches, batch in enumerate(dataloader, start=1):
    for key, value in batch.items():
        print("k, v", key, value.shape)
print(n_batches)

len train dataset 1
k, v rays torch.Size([1, 307200, 8])
k, v rgbs torch.Size([1, 307200, 3])
k, v instance_mask torch.Size([1, 307200])
k, v instance_mask_weight torch.Size([1, 307200])
k, v instance_ids torch.Size([1, 307200])
1


In [None]:
# Positional-encoding modules for sample positions and view directions; the
# argument is presumably the number of frequency bands -- confirm in models.nerf.
embedding_xyz, embedding_dir = Embedding(10), Embedding(4)

nerf_coarse, nerf_fine = NeRF(), NeRF()

# NOTE(review): this is a 'lego' checkpoint while the data is an LLFF-style
# scene (data/scene_1); confirm this pairing is intended.
# Alternative: ckpt_path = 'ckpts_old/fern/epoch=29.ckpt'
ckpt_path = 'ckpts_old/lego/epoch=15.ckpt'

# Load each network's weights from the shared checkpoint, then move it to the
# GPU in inference mode.
for _net, _key in ((nerf_coarse, 'nerf_coarse'), (nerf_fine, 'nerf_fine')):
    load_ckpt(_net, ckpt_path, model_name=_key)
    _net.cuda().eval()
del _net, _key

In [None]:
models = {'coarse': nerf_coarse, 'fine': nerf_fine}
embeddings = {'xyz': embedding_xyz, 'dir': embedding_dir}

# Sampling settings forwarded positionally to render_rays.
N_samples = 64
N_importance = 64
use_disp = False
chunk = 1024*32*4  # max number of rays passed to a single render_rays call

@torch.no_grad()
def f(rays):
    """Do batched inference on rays using chunk.

    Args:
        rays: tensor whose first dimension indexes rays. It is sliced into
            consecutive pieces of at most ``chunk`` rows, and each piece is
            rendered with ``render_rays``.

    Returns:
        defaultdict mapping every output key of ``render_rays`` to the
        concatenation (along dim 0) of that output across all chunks.
        It is empty when ``rays`` has no rows.
    """
    n_rays = rays.shape[0]
    # `dataset` is the dataset *class* (it is called as a constructor when
    # building train_dataset), so read the background flag from the instance.
    white_back = train_dataset.white_back
    results = defaultdict(list)
    for start in range(0, n_rays, chunk):
        # The two positional zeros follow use_disp in render_rays' signature
        # (presumably perturb and noise_std, i.e. deterministic sampling --
        # confirm against models.rendering).
        rendered_ray_chunks = render_rays(models,
                                          embeddings,
                                          rays[start:start + chunk],
                                          N_samples,
                                          use_disp,
                                          0,
                                          0,
                                          N_importance,
                                          chunk,
                                          white_back,
                                          test_time=True)
        for k, v in rendered_ray_chunks.items():
            results[k].append(v)

    # Replacing values for existing keys is safe while iterating items().
    for k, v in results.items():
        results[k] = torch.cat(v, 0)
    return results

In [None]:
# Render the first view of the instantiated dataset. `dataset` is the class,
# so it must not be indexed directly.
sample = train_dataset[0]
rays = sample['rays'].cuda()

t0 = time.perf_counter()  # monotonic, high-resolution clock for intervals
results = f(rays)
# CUDA work is launched asynchronously; wait for it before reading the clock.
torch.cuda.synchronize()
print(time.perf_counter() - t0)

In [None]:
# img_wh is (width, height); rendered outputs are flat per-ray, so reshape to (H, W, ...).
W, H = img_wh
img_gt = sample['rgbs'].view(H, W, 3)
img_pred = results['rgb_fine'].view(H, W, 3).cpu().numpy()
alpha_pred = results['opacity_fine'].view(H, W).cpu().numpy()
depth_pred = results['depth_fine'].view(H, W)

print('PSNR', metrics.psnr(img_gt, img_pred).item())

# Build the 2x2 grid explicitly. The previous plt.subplots() + plt.subplot(22x)
# combination left a stray full-figure Axes behind the panels on recent
# Matplotlib, and tight_layout() ran before any panel existed.
fig, axes = plt.subplots(2, 2, figsize=(15, 8))
panels = [
    ('GT', img_gt),
    ('pred', img_pred),
    ('depth', visualize_depth(depth_pred).permute(1, 2, 0)),
    ('opacity', alpha_pred),  # previously computed but never shown
]
for ax, (title, image) in zip(axes.flat, panels):
    ax.set_title(title)
    ax.imshow(image)
fig.tight_layout()
plt.show()