In [None]:
try: # install all dependencies in colab 
    import google.colab
    !sudo apt-get update -y
    !sudo apt-get install -y libnvidia-gl-555 vulkan-tools glslang-tools vulkan-validationlayers-dev
    !pip install pyav==13.1.0
    !pip install git+https://github.com/rendervous/rendervous_project.git
except:
    print("Executing locally")

In [2]:
import torch
import rendervous as rdv
import matplotlib.pyplot as plt
import vulky.datasets as datasets
import numpy as np
from tqdm import tqdm


# load the disney cloud as a tensor
cloud = datasets.Volumes.disney_cloud

camera = rdv.PerspectiveCameraSensor(512, 512, rdv.look_at_poses((-2.8, 0.2, -.4)))

# create a grid map as reference
bmin, bmax = rdv.normalized_box(cloud)
grid = rdv.Grid3D(cloud, bmin, bmax)

# create a latent grid to represent the compact feature
latent = torch.nn.Parameter(torch.zeros(16, 16, 16, 8, device=rdv.device()))
latent_grid = rdv.Grid3D(latent, bmin, bmax)
# create a MLP to represent the scene
def dense(input_dim, output_dim):
    k = 1 / input_dim
    A_0 = torch.nn.Parameter((torch.rand(output_dim, input_dim, device=rdv.device())*2 - 1)*np.sqrt(k))
    B_0 = torch.nn.Parameter((2 * torch.rand(output_dim, device=rdv.device())-1)*np.sqrt(k))
    return A_0 @ rdv.X + rdv.const[B_0]

maps = [dense(8, 32), rdv.relu, dense(32, 32), rdv.relu, dense(32, 32), rdv.relu, dense(32, 1)]
mlp = None
for m in maps: mlp = m if mlp is None else mlp.then(m)

rep_map = latent_grid.then(mlp)

# train the representation
bmin, bmax = bmin.to(rdv.device()), bmax.to(rdv.device())
opt = torch.optim.NAdam(list(mlp.parameters())+[latent], lr=0.002)
sch = torch.optim.lr_scheduler.OneCycleLR(opt, 0.002, 1000)
steps_iterator = tqdm(range(1000))
for s in steps_iterator:
    with torch.no_grad():
        x = torch.rand(32*1024, 3, device=rdv.device()) * (bmax - bmin) + bmin
        ref_values = grid(x)

    opt.zero_grad()
    inf_values = rep_map(x)
    loss = torch.nn.functional.mse_loss(ref_values, inf_values, reduction='sum')
    loss.backward()
    opt.step()
    # sch.step()
    steps_iterator.set_description_str(f"Loss: {loss.item()}")

Loss: 4733.5:  53%|█████▎    | 532/1000 [01:14<01:05,  7.19it/s]          


KeyboardInterrupt: 