# Ray cast visualizations

In [1]:
# Imports

import torch
import torch.nn as nn
import numpy as np
import open3d
import trimesh
import tensorflow as tf
from tensorflow.compat.v1 import enable_eager_execution
enable_eager_execution()
import matplotlib.pyplot as plt

ModuleNotFoundError: No module named 'open3d'

In [3]:
# Data Loading
# The archive bundles the photos, their camera-to-world poses and the focal length.
data = np.load('tiny_nerf_data.npz')
images, poses, focal = data['images'], data['poses'], data['focal']
H, W = images.shape[1:3]
print(images.shape, poses.shape, focal)

# View 101 is held out for evaluation; training keeps the first 100 views (RGB only).
testimg, testpose = images[101], poses[101]
images, poses = images[:100, ..., :3], poses[:100]

plt.imshow(testimg)
plt.axis("off")
plt.show()


FileNotFoundError: [Errno 2] No such file or directory: 'tiny_nerf_data.npz'

In [None]:
# Helper functions

def init_weights(m):
    """Initialise an ``nn.Linear`` layer in place for use with ``nn.Module.apply``.

    Weights get Xavier-uniform initialisation and the bias (if any) is filled with 0.01.
    Modules that are not ``nn.Linear`` are left untouched.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight)
    # Layers built with bias=False have m.bias = None; filling it would raise.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.01)
        
def pos_enc(a, L):
    """NeRF positional encoding of the coordinates in ``a``.

    For each frequency 2**i * pi (i = 0 .. L-1), appends sin and cos of the scaled
    input, then the raw input itself, concatenated along the last axis.

    Args:
        a: tensor of shape (..., D).
        L: number of frequency bands.

    Returns:
        Tensor of shape (..., D * (2 * L + 1)). For D = 3 this is 3*2*L + 3, the
        input width expected by ``NeRF``'s first layer.
    """
    features = []
    for i in range(L):
        freq = (2.0 ** i) * np.pi
        features.append(torch.sin(freq * a))
        features.append(torch.cos(freq * a))
    features.append(a)
    # Concatenate along the feature axis so batched point sets keep their leading dims.
    return torch.cat(features, dim=-1)

class NeRF(nn.Module):
    """Plain MLP mapping positionally-encoded 3-D points to (rgb, density).

    Layout: Linear(3*2*Lp+3 -> 256) + ReLU, seven Linear(256 -> 256) + ReLU blocks,
    then Linear(256 -> 4). There is no skip connection and no view-direction input.
    """

    def __init__(self, Lp):
        super().__init__()
        # Number of encoding frequencies; forward() must use the same value so the
        # encoded width matches the first layer.
        self.Lp = Lp

        layers = [nn.Linear(3 * 2 * Lp + 3, 256), nn.ReLU()]
        for _ in range(7):
            layers += [nn.Linear(256, 256), nn.ReLU()]
        layers.append(nn.Linear(256, 4))
        self.nerf = nn.Sequential(*layers)

        self.apply(init_weights)

    def forward(self, input):
        """Return (sigmoid rgb, relu density) for points of shape (..., 3).

        rgb has shape (..., 3); density has shape (...).
        """
        rgba = self.nerf(pos_enc(input, self.Lp))
        # Index the channel axis with `...` so batched inputs such as (1, N, 3) work;
        # `[:, :3]` would slice the point axis for 3-D inputs.
        return torch.sigmoid(rgba[..., :3]), torch.relu(rgba[..., 3])
    

        

In [None]:
# Model with 6 encoding frequencies, trained with Adam.
nerf = NeRF(6)
optim = torch.optim.Adam(nerf.parameters(), lr=1e-3, betas=(0.9, 0.99))

In [None]:
# Tf compares
def get_rays(H, W, focal, c2w, scale=10):
    """TensorFlow ray generator used to cross-check the torch ray caster.

    Builds a grid of W columns by H rows whose pixel positions are spaced ``scale``
    pixels apart and centred on the full image of size (H*scale, W*scale). Calling it
    with (H_full / scale, W_full / scale) therefore samples every ``scale``-th pixel.

    Returns:
        (rays_o, rays_d, dirs): world-space ray origins and directions of shape
        (H, W, 3), plus the camera-space directions before rotation.
    """
    i, j = tf.meshgrid(tf.range(W, dtype=tf.float32), tf.range(H, dtype=tf.float32), indexing='xy')
    # Pixel offsets from the full-image centre. The previous hard-coded `*10 - 50`
    # is the special case scale=10 on a 10x10 grid (100-px image).
    x = i * scale - W * scale / 2
    y = j * scale - H * scale / 2
    dirs = tf.stack([x / focal, -y / focal, -tf.ones_like(i)], -1)
    rays_d = tf.reduce_sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1)
    rays_o = tf.broadcast_to(c2w[:3, -1], tf.shape(rays_d))
    return rays_o, rays_d, dirs
# Reference rays from the TF implementation on a 10x10 grid (every 10th pixel).
rays_o, rays_d, dirs = get_rays(H/10, W/10, focal, poses[1])

## Cross-check: rotating the same camera-space directions with torch must reproduce
## TF's world-space rays_d for the same pose (poses[1]).
c2w = torch.from_numpy(poses[1])
dirs_cam = torch.from_numpy(dirs.numpy()).reshape(1, -1, 3)
gur = torch.bmm(dirs_cam, c2w[:3, :3].T.unsqueeze(0))
np.testing.assert_allclose(gur.numpy().reshape(rays_d.shape), rays_d.numpy(), rtol=1e-5, atol=1e-5)

## Jittered depth samples along each TF ray between 5 and 8.
N_samples = 10
z_vals = tf.linspace(5.0, 8.0, N_samples)
z_vals += tf.random.uniform(list(rays_o.shape[:-1]) + [N_samples]) * (8.0-5.0)/N_samples
pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None]

In [None]:
import torch as th

def cast_rays(H, W, focal, scale, c2w, ns):
    """Sample points along camera rays through every ``scale``-th pixel of an H x W image.

    Args:
        H, W: full image height and width in pixels.
        focal: focal length in pixels.
        scale: pixel stride of the sampled grid (H // scale rows, W // scale columns).
        c2w: camera-to-world pose tensor; only c2w[:3, :3] and c2w[:3, -1] are read.
        ns: number of samples per ray.

    Returns:
        points: (1, ny*nx, 3) camera-space directions on the z = -1 plane, row-major
            over the sampled grid.
        points_move: (1, ns*ny*nx, 3) those directions scaled per sample, ordered
            sample-major (all pixels of sample 0, then sample 1, ...).
        points_world: points_move rotated by c2w[:3, :3] (no translation added).
    """
    nx = W // scale  # columns come from the image width
    ny = H // scale  # rows come from the image height

    # Sample the front plane. torch.meshgrid uses 'ij' indexing by default, so yy is
    # the row index and xx the column index, both shaped (ny, nx).
    yy, xx = torch.meshgrid(torch.linspace(0, ny - 1, ny), torch.linspace(0, nx - 1, nx))
    ones = torch.ones(1, nx * ny, 1)

    ix = xx.reshape(1, -1, 1) * scale
    iy = yy.reshape(1, -1, 1) * scale
    points = torch.cat([(ix - W / 2) / focal, -(iy - H / 2) / focal, -ones], dim=-1)

    # NOTE(review): -R^T t is the world origin expressed in camera coordinates; it is
    # used as a per-axis near bound, with +5 as the far bound. Confirm this is the
    # intended depth range.
    camtrans = -torch.matmul(c2w[:3, -1], c2w[:3, :3])
    near_distance = camtrans
    far_distance = camtrans + 5

    # Jittered sample positions. NOTE(review): t + U(0, 1)/ns can exceed 1, so the
    # last samples may fall slightly past far_distance.
    t = torch.linspace(0, 1, ns)
    t_noisy = t.view(1, ns, 1, 1) + torch.empty(1, ns, nx * ny, 1).uniform_(0, 1) / ns

    t_scale = t_noisy * far_distance + (1 - t_noisy) * near_distance
    points_move = t_scale * points.view(1, 1, nx * ny, 3)

    # NOTE(review): only the rotation is applied; adding c2w[:3, -1] would give true
    # world coordinates.
    points_world = torch.bmm(points_move.view(1, -1, 3), c2w[:3, :3].T.unsqueeze(0))

    return points, points_move.view(1, -1, 3), points_world

def raytrace(z, sigma_a, rgb):
    """Alpha-composite per-sample colours along rays (NeRF volume rendering).

    Args:
        z: (..., N) sample depths along each ray, increasing along the last axis.
        sigma_a: (..., N) densities at those samples (assumed non-negative).
        rgb: (..., N, 3) colours at those samples.

    Returns:
        (rgb_map, depth_map, acc_map) with shapes (..., 3), (...), (...).
    """
    # Interval to the next sample; the last sample gets an effectively infinite one.
    dists = torch.cat([z[..., 1:] - z[..., :-1], torch.full_like(z[..., :1], 1e10)], dim=-1)
    alpha = 1. - torch.exp(-sigma_a * dists)

    # Exclusive cumulative product (transmittance before each sample). torch.cumprod
    # has no `exclusive` flag (that is tf.math.cumprod), so shift a leading 1 in.
    survive = 1. - alpha + 1e-10
    shifted = torch.cat([torch.ones_like(survive[..., :1]), survive[..., :-1]], dim=-1)
    weights = alpha * torch.cumprod(shifted, dim=-1)

    rgb_map = (weights[..., None] * rgb).sum(dim=-2)
    # Weight this call's own depths rather than a global sample tensor.
    depth_map = (weights * z).sum(dim=-1)
    acc_map = weights.sum(dim=-1)

    return rgb_map, depth_map, acc_map

In [None]:
# Train
import time

N_samples = 64
N_iters = 1000
psnrs = []
iternums = []
i_plot = 25
RAY_SCALE = 10  # render every RAY_SCALE-th pixel in each direction to keep iterations cheap


def _render_view(pose_np, n_samples, scale):
    """Render a downsampled view with `nerf` from a camera-to-world numpy pose.

    Returns raytrace's (rgb_map, depth_map, acc_map), one row per cast pixel,
    row-major over the (H // scale, W // scale) grid.
    """
    _, pts_cam, pts_world = cast_rays(H, W, focal, scale, torch.from_numpy(pose_np), n_samples)
    n_pix = pts_cam.shape[1] // n_samples
    rgb_pts, sigma_pts = nerf(pts_world)
    # cast_rays returns points sample-major; regroup to (pixel, sample) so that
    # compositing runs along the last axis.
    rgb_pts = rgb_pts.reshape(n_samples, n_pix, 3).transpose(0, 1)
    sigma_pts = sigma_pts.reshape(n_samples, n_pix).T
    # Camera-space sample points have z = -depth.
    depths = (-pts_cam[..., 2]).reshape(n_samples, n_pix).T
    return raytrace(depths, sigma_pts, rgb_pts)


def _target_pixels(img, scale):
    """Pick every `scale`-th pixel of an (H, W, C) image as an (N, 3) tensor, row-major."""
    rows, cols = img.shape[0] // scale, img.shape[1] // scale
    sub = img[:rows * scale:scale, :cols * scale:scale, :3]
    return torch.from_numpy(np.ascontiguousarray(sub)).reshape(-1, 3)


t = time.time()
for i in range(N_iters+1):

    img_i = np.random.randint(images.shape[0])
    target = _target_pixels(images[img_i], RAY_SCALE)

    # Render from the sampled image's own pose so prediction and target match.
    rgb, _, _ = _render_view(poses[img_i], N_samples, RAY_SCALE)
    loss = torch.mean((rgb - target) ** 2)

    optim.zero_grad()
    loss.backward()
    optim.step()

    if i%i_plot==0:
        print(i, (time.time() - t) / i_plot, 'secs per iter')
        t = time.time()

        # Render the holdout view for logging (no gradients needed)
        with torch.no_grad():
            rgb, depth, acc = _render_view(testpose, N_samples, RAY_SCALE)
            test_loss = torch.mean((rgb - _target_pixels(testimg, RAY_SCALE)) ** 2)
        psnr = -10. * torch.log10(test_loss)

        psnrs.append(psnr.item())
        iternums.append(i)

        plt.figure(figsize=(10,4))
        plt.subplot(121)
        plt.imshow(rgb.reshape(H // RAY_SCALE, W // RAY_SCALE, 3).numpy())
        plt.title(f'Iteration: {i}')
        plt.subplot(122)
        plt.plot(iternums, psnrs)
        plt.title('PSNR')
        plt.show()

print('Done')

In [None]:
import open3d

# Ray samples for training poses 0 and 2 (stride 10, 10 samples per ray).
# pw comes from pose 0; p, pm and pw_2 come from the pose-2 call.
ray_batches = [cast_rays(H, W, focal, 10, torch.from_numpy(poses[idx]), 10) for idx in (0, 2)]
pw = ray_batches[0][2]
p, pm, pw_2 = ray_batches[1]

In [None]:
# Reference mesh; take an independent array copy of its vertices for plotting.
mesh = trimesh.load('./mesh.obj')
verts = np.array(mesh.vertices, copy=True)

In [None]:
# Point clouds
viz = open3d.JVisualizer()


def _colored_cloud(xyz, color):
    """Wrap an (N, 3) array as an open3d PointCloud painted one uniform colour."""
    cloud = open3d.geometry.PointCloud()
    cloud.points = open3d.utility.Vector3dVector(xyz)
    cloud.paint_uniform_color(np.array(color, dtype=np.float32))
    return cloud


pcd1 = _colored_cloud(p[0].numpy(), [1, 0, 0])     # camera-space front plane (red)
pcd2 = _colored_cloud(pm[0].numpy(), [0, 1, 0])    # scaled samples (green)
pcd3 = _colored_cloud(pw[0].numpy(), [0, 0, 1])    # rotated samples, pose 0 (blue)
pcd4 = _colored_cloud(pw_2[0].numpy(), [1, 0, 1])  # rotated samples, pose 2 (magenta)

# Mesh vertices: downsample first, then paint cyan.
pcd5 = open3d.geometry.PointCloud()
pcd5.points = open3d.utility.Vector3dVector(verts)
pcd5 = pcd5.uniform_down_sample(100)
pcd5.paint_uniform_color(np.array([0, 1, 1], dtype=np.float32))

In [None]:
# Mesh (pcd5) together with the camera-space front plane (pcd1).
viz = open3d.JVisualizer()
for geom in (pcd5, pcd1):
    viz.add_geometry(geom)
viz.show()

In [None]:
# Mesh (pcd5) together with the scaled camera-space samples (pcd2).
viz = open3d.JVisualizer()
for geom in (pcd5, pcd2):
    viz.add_geometry(geom)
viz.show()

In [None]:
# Mesh (pcd5) together with the rotated samples (pcd3).
viz = open3d.JVisualizer()
for geom in (pcd5, pcd3):
    viz.add_geometry(geom)
viz.show()

In [None]:
# Training image 0, for comparison with the point clouds.
first_view = images[0]
plt.imshow(first_view)
plt.axis("off")

In [None]:
# Mesh (pcd5) with the rotated samples of both poses (pcd3, pcd4).
viz = open3d.JVisualizer()
for geom in (pcd5, pcd3, pcd4):
    viz.add_geometry(geom)
viz.show()

## Debug

In [None]:
# This cell reads `pws`, so compute it here (same expression as the clamp cell)
# to keep the notebook runnable top-to-bottom.
pws = pw.clamp(-1.5, 1.5)

pcd5 = open3d.geometry.PointCloud()
pcd5.points = open3d.utility.Vector3dVector(pws[0].numpy())
pcd5.paint_uniform_color(np.array([1,0,1],dtype=np.float32))

# NOTE(review): `pts_tp` is not defined anywhere in this notebook and pcd6 is not
# used by any active code, so it is only built when `pts_tp` exists.
if 'pts_tp' in globals():
    pcd6 = open3d.geometry.PointCloud()
    pcd6.points = open3d.utility.Vector3dVector(pts_tp)
    pcd6.paint_uniform_color(np.array([1,1,0],dtype=np.float32))

In [None]:
pws = torch.clamp(pw, min=-1.5, max=1.5)  # limit world-space samples to a +-1.5 box

In [None]:
# Show pcd5 and pcd4 together, then release the visualizer.
viz = open3d.JVisualizer()
for geom in (pcd5, pcd4):
    viz.add_geometry(geom)
viz.show()
del viz

In [None]:
# Reload the mesh from the same file as the earlier mesh cell, for the projection checks below.
mesh=trimesh.load('./mesh.obj')

In [None]:
g = np.array(mesh.vertices, copy=True)  # independent copy of the mesh vertices

In [None]:
# Negate the z coordinate of the mesh vertices (x and y unchanged). Rebuilding from
# mesh.vertices makes the cell idempotent: re-running it no longer flips z back.
g = np.array(mesh.vertices)
g[:, 2] = -g[:, 2]

In [None]:
# Mesh vertices with x and y swapped (same expression as the later `k` cell),
# computed here because this cell reads `k`.
k = g[:, (1, 0, 2)]

pcd4 = open3d.geometry.PointCloud()
pcd4.points = open3d.utility.Vector3dVector(k)
pcd4 = pcd4.uniform_down_sample(100)
pcd4.paint_uniform_color(np.array([0,0,1],dtype=np.float32))


In [None]:
# x/y-swapped mesh vertices multiplied on the right by the rotation block of training
# pose 2 (same expression as the later `points` cell), computed here because this
# cell reads `points`.
points = torch.bmm(torch.from_numpy(g[:, (1, 0, 2)]).float().unsqueeze(0),
                   torch.from_numpy(poses[2])[:3, :3].unsqueeze(0))

pcd5 = open3d.geometry.PointCloud()
pcd5.points = open3d.utility.Vector3dVector(points[0].numpy())
pcd5 = pcd5.uniform_down_sample(100)
pcd5.paint_uniform_color(np.array([1,1,1],dtype=np.float32))

In [None]:
# Swap the x and y columns of the vertex array and wrap it as a tensor.
k = g[:, [1, 0, 2]]
l = torch.from_numpy(k)

In [None]:
# Homogeneous coordinates (x, y, z, 1) of the swapped vertices, shape (1, N, 4).
lm = torch.cat([l.float(), torch.ones(l.shape[0], 1, dtype=torch.float32)], dim=-1).unsqueeze(0)

In [None]:
# Multiply the swapped vertices (as row vectors) by the rotation block of pose 2.
c2w = torch.from_numpy(poses[2])
points = torch.bmm(l.float()[None], c2w[None, :3, :3])

In [None]:
# First two (swapped) vertex coordinates, scaled by the focal length.
z = l.narrow(1, 0, 2)
pts = z * focal

In [None]:
# Overlay the focal-scaled mesh (x, y) on training image 2, shifted to the image centre
# (W / 2, H / 2) instead of a hard-coded 50 px.
# NOTE(review): no pose or perspective divide is applied, and image rows grow downwards
# while the ray caster's camera y axis points up — confirm this is the intended check.
plt.imshow(images[2])
plt.scatter(pts[:, 0] + W / 2, pts[:, 1] + H / 2, alpha=0.3)

In [None]:
# Shape of the focal-scaled points: one row per mesh vertex, two columns.
pts.shape

In [None]:
# Training image 2 on its own, without the overlay.
plt.imshow(images[2])

In [None]:
# Inspect the camera-to-world pose matrix of training image 0.
poses[0]