In [9]:
import condorgmm
import torch
import gsplat
from condorgmm.ng.torch_utils import render_rgbd
import numpy as np
import matplotlib.pyplot as plt

device = torch.device("cuda:0")

condorgmm.rr_init("coarse_models_sweep")

# Scene: YCB test video 48; the cells below work off its first frame.
video = condorgmm.data.YCBTestVideo(48)
frame0 = video[0]

# Pinhole intrinsics of the first frame. The fitting cell back-projects depth
# with exactly these values (xyz_from_depth_image(depth, fx, fy, cx, cy)), so
# the renderer's K must be built from them as well.
fx, fy, cx, cy = frame0.intrinsics
height, width = frame0.depth.shape

# World-to-camera transform: identity, i.e. the camera sits at the origin of the
# frame the Gaussians are expressed in.
viewmat = torch.eye(4, device=device, dtype=torch.float32)

# Intrinsic matrix handed to the renderer. The principal point comes from the
# frame intrinsics (cx, cy), not the image centre (width / 2, height / 2): the
# two may differ, and using the centre would shift every render relative to the
# back-projected point cloud the Gaussians are initialised from.
K = torch.tensor(
    [
        [fx, 0.0, cx],
        [0.0, fy, cy],
        [0.0, 0.0, 1.0],
    ],
    device=device,
    dtype=torch.float32,
)

# Sampling/fitting mask used by the next cell: every pixel counts as "object".
object_mask = np.full(frame0.depth.shape, True)

# --- Smoke test: render one red Gaussian on the optical axis at z = 1. ---
# Poses are 7-vectors, presumably position (3) followed by a quaternion with the
# scalar part last — TODO confirm against render_rgbd.
camera_posquat = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], device=device, requires_grad=True)
posquat = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], device=device, requires_grad=True)
means = torch.tensor([[0.0, 0.0, 1.0]], device=device, requires_grad=True)
quats = torch.tensor([[0.0, 0.0, 0.0, 1.0]], device=device, requires_grad=True)
# Stored in the same unconstrained space as the fitting cell: the render call
# applies torch.exp to `scales` and torch.sigmoid to `opacities`. So store
# log(0.01) for a 0.01-sized Gaussian (the raw 0.01 would render at ~1.01), and
# a logit of 4.0 for opacity ~0.98 (the raw 1.0 would render at ~0.73).
scales = torch.full((1, 3), float(np.log(0.01)), device=device, requires_grad=True)
opacities = torch.full((1,), 4.0, device=device, requires_grad=True)
rgbs = torch.tensor([[1.0, 0.0, 0.0]], device=device, requires_grad=True)

# NOTE(review): the last recorded run of this cell failed inside gsplat with
# "AttributeError: module 'gsplat_cuda' has no attribute
# 'projection_ewa_3dgs_fused_fwd'". That looks like a mismatch between the
# installed gsplat Python package and its compiled CUDA extension rather than a
# problem with this cell — TODO confirm and rebuild/reinstall gsplat.
rendered_rgb, rendered_depth, rendered_silhouette = render_rgbd(
    camera_posquat,
    posquat,
    means,
    quats,
    torch.exp(scales),
    torch.sigmoid(opacities),
    rgbs,
    viewmat[None],
    K[None],
    frame0.width,
    frame0.height,
)



AttributeError: module 'gsplat_cuda' has no attribute 'projection_ewa_3dgs_fused_fwd'

In [None]:
# The sweep is expected to set `num_gaussians` before this cell runs; the
# fallback only exists so the cell also works on a fresh kernel (its value is an
# arbitrary choice).
NUM_GAUSSIANS_DEFAULT = 1000
SAMPLING_SEED = 0
INIT_SCALE = 0.01  # initial Gaussian size, in the units of the back-projected points
INIT_OPACITY_LOGIT = 4.0  # sigmoid(4.0) ~= 0.982

num_gaussians = globals().get("num_gaussians", NUM_GAUSSIANS_DEFAULT)
print(f"Running for {num_gaussians} Gaussians")

# Seeded so a given num_gaussians always selects the same pixels across runs.
sample_rng = np.random.default_rng(SAMPLING_SEED)
sampled_indices = sample_rng.choice(object_mask.sum(), num_gaussians, replace=False)

# One Gaussian per sampled pixel: position from back-projecting the depth image
# with the frame intrinsics, colour from that pixel's RGB value. The
# back-projection is computed once and then subsampled.
xyz_image = condorgmm.utils.common.xyz_from_depth_image(frame0.depth, fx, fy, cx, cy)
spatial_means_np = xyz_image[object_mask][sampled_indices].astype(np.float32)
rgb_means_np = frame0.rgb[object_mask][sampled_indices].astype(np.float32)

means = torch.tensor(
    spatial_means_np, device=device, requires_grad=True, dtype=torch.float32
)
rgbs = torch.tensor(
    rgb_means_np, device=device, requires_grad=True, dtype=torch.float32
)

# Identity rotation for every Gaussian, same [0, 0, 0, 1] layout as the
# single-Gaussian test.
quats = (
    torch.tensor([0.0, 0.0, 0.0, 1.0], device=device, dtype=torch.float32)
    .repeat(num_gaussians, 1)
    .requires_grad_()
)

# `scales` are log-sizes (the render call applies torch.exp). Draw strictly
# positive sizes uniformly in [0.5, 1.5) * INIT_SCALE before taking the log.
# log(randn * 0.01) would be NaN for every negative randn draw.
init_gen = torch.Generator().manual_seed(SAMPLING_SEED)
scales = (
    torch.log(INIT_SCALE * (0.5 + torch.rand(num_gaussians, 3, generator=init_gen)))
    .to(device=device, dtype=torch.float32)
    .requires_grad_()
)

# `opacities` are logits (the render call applies torch.sigmoid).
opacities = torch.full(
    (num_gaussians,),
    INIT_OPACITY_LOGIT,
    device=device,
    dtype=torch.float32,
    requires_grad=True,
)

# Fitting targets: the observed RGB and depth of frame 0 as float tensors on the GPU.
target_img = torch.tensor(frame0.rgb, device=device).float()
target_depth = torch.tensor(frame0.depth, device=device).float()

# Float copy of the (all-True) pixel mask, presumably for masking the loss in the
# training loop — TODO confirm against the loop body.
object_mask_torch = torch.tensor(object_mask, device=device).float()

# `scales` and `opacities` were created with requires_grad=True, so nothing more
# is needed here. The previous `x.requires_grad_ = True` lines did not enable
# gradients: they replaced the tensor's `requires_grad_` method with a bool.

# Object and camera poses are reset to identity. They are NOT in the optimiser's
# parameter list below (nor is `quats`), so they stay fixed during fitting.
posquat = torch.tensor(
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], device=device, requires_grad=True
)
camera_posquat = torch.tensor(
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], device=device, requires_grad=True
)

# Setup optimizer: only means, colours, opacities and scales are fitted.
LEARNING_RATE = 2e-3
optimizer = torch.optim.Adam([means, rgbs, opacities, scales], lr=LEARNING_RATE)
n_steps = 1500
# NOTE(review): `tqdm` is not imported in any visible cell; add
# `from tqdm.auto import tqdm` to the imports cell so this runs on a fresh kernel.
pbar = tqdm(range(n_steps))
for step in pbar:
    optimizer.zero_grad()

    # Forward pass
    rendered_rgb, rendered_depth, rendered_silhouette = render_rgbd(
        camera_posquat,
        posquat,
        means,
        quats,
        torch.exp(scales),
        torch.sigmoid(opacities),
        rgbs,
        viewmat[None],
        K[None],
        frame0.width,
        frame0.height,
    )