In [None]:
# Copyright (c) Meta Platforms, Inc. and affiliates.

## 1. Imports and Model Loading

In [None]:
import os
import uuid
import imageio
import numpy as np
import torch
from IPython.display import Image as ImageDisplay
from pytorch3d.transforms import Transform3d
from pytorch3d.renderer import look_at_view_transform

from inference import Inference, ready_gaussian_for_video_rendering, load_image, display_image, make_scene, render_video

In [None]:
# Resolve the pipeline config relative to the notebook's working directory.
PATH = os.getcwd()
TAG = "hf"  # checkpoint variant: sub-directory of ../checkpoints
checkpoint_dir = f"{PATH}/../checkpoints/{TAG}"
config_path = f"{checkpoint_dir}/pipeline.yaml"

# Build the inference pipeline from that config (compile=False).
inference = Inference(config_path, compile=False)

## 2. Load the input image and object masks to lift to 3D (multiple objects)

In [None]:
from utils import load_image, load_masks, create_gaussians_from_depth, render_frame, create_gaussians_object

In [None]:
# Kubric4D scene to lift to 3D. Set the KUBRIC4D_PATH environment variable to use a
# local copy of the dataset instead of the default cluster location.
DATASET_PATH = os.environ.get("KUBRIC4D_PATH", "/mnt/lustre/work/geiger/gwb987/data/kubric4d")
SCENE_NAME = "scn02719"
DATA_PATH = os.path.join(DATASET_PATH, SCENE_NAME)
FRAMES_PATH = os.path.join(DATA_PATH, "frames_p0_v0")  # viewpoint 0


def _list_frame_names(frames_dir, prefix, suffix=".png"):
    """Return the sorted file names in `frames_dir` that start with `prefix` and end with `suffix`.

    Raises:
        FileNotFoundError: if no file matches, so a wrong dataset path or scene name
            fails with a readable message instead of an IndexError on `[0]`.
    """
    names = sorted(f for f in os.listdir(frames_dir) if f.startswith(prefix) and f.endswith(suffix))
    if not names:
        raise FileNotFoundError(f"No '{prefix}*{suffix}' files found in {frames_dir}")
    return names


# RGBA frames (rgba_0000.png, ...); the first frame is the one lifted to 3D.
IMAGE_NAMES = _list_frame_names(FRAMES_PATH, "rgba_")
IMAGE_PATHS = [os.path.join(FRAMES_PATH, name) for name in IMAGE_NAMES]
IMAGE_PATH = IMAGE_PATHS[0]
image = load_image(IMAGE_PATH)
# drop alpha channel; everything downstream works on RGB only
image = image[..., :3]

H, W, _ = image.shape

# Segmentation frames (segmentation_0000.png, ...) matching the RGBA frames.
MASK_NAMES = _list_frame_names(FRAMES_PATH, "segmentation_")
MASK_PATHS = [os.path.join(FRAMES_PATH, name) for name in MASK_NAMES]
MASK_PATH = MASK_PATHS[0]

masks = load_masks(MASK_PATH)
display_image(image, masks)

print(f"Image shape: {image.shape}, dtype: {image.dtype}, min: {image.min()}, max: {image.max()}")

In [None]:
# Pinhole intrinsics used to back-project depth into a point map.
# Assumption baked in here: the focal length in pixels equals the image size along
# each axis (fx = W, fy = H), and the principal point is the image centre.
# NOTE(review): confirm these match the camera the Kubric scene was rendered with.
fx, fy = float(W), float(H)
cx, cy = W / 2.0, H / 2.0

K_matrix = np.array(
    [
        [fx, 0.0, cx],
        [0.0, fy, cy],
        [0.0, 0.0, 1.0],
    ]
)

In [None]:
# Choose the source of scene geometry:
#   True  -> run the pipeline's MoGe depth model on the RGB image;
#   False -> back-project the dataset's depth_*.tiff map with the intrinsics above.
using_moge = True

if not using_moge:

    # Load the first depth map of the sequence (depth_0000.tiff, ...).
    DEPTH_NAMES = sorted([f for f in os.listdir(FRAMES_PATH) if f.startswith("depth_") and f.endswith(".tiff")])
    DEPTH_PATHS = [os.path.join(FRAMES_PATH, name) for name in DEPTH_NAMES]
    DEPTH_PATH = DEPTH_PATHS[0]
    depth_map = load_image(DEPTH_PATH, to_uint8=False)

    print(f"Using camera intrinsics: fx={fx}, fy={fy}, cx={cx}, cy={cy}")
    print("Radial depth map with shape:", depth_map.shape, "dtype:", depth_map.dtype, "min:", depth_map.min(), "max:", depth_map.max())

    from utils import radial_to_z_depth

    # The stored depth is treated as radial (distance along each pixel ray); the
    # pinhole back-projection below needs depth along the optical (z) axis.
    depth_map_z = radial_to_z_depth(depth_map, fx, fy, cx, cy)
    print(f"Converted radial depth map to z-depth map with shape: {depth_map_z.shape}, dtype: {depth_map_z.dtype}, min: {depth_map_z.min()}, max: {depth_map_z.max()}")

    # Back-project every pixel (u, v) with the pinhole model:
    #   x = (u - cx) * z / fx,   y = (v - cy) * z / fy
    v_coords, u_coords = np.meshgrid(np.arange(H), np.arange(W), indexing='ij')  # both (H, W)
    z = depth_map_z
    x = (u_coords - cx) * z / fx
    y = (v_coords - cy) * z / fy

    points = np.stack((x, y, z), axis=-1)  # (H, W, 3)
    pointmap = torch.from_numpy(points).float()  # (H, W, 3)

    print(f"Generated pointmap with shape: {pointmap.shape}, min: {pointmap.min():.3f}, max: {pointmap.max():.3f}")

else:

    # Depth model bundled with the inference pipeline.
    depth_model = inference._pipeline.depth_model

    # `image` is RGB only (alpha was dropped when loading), so there is no mask channel
    # here; feed the three colour channels as a (3, H, W) tensor.
    loaded_image = inference._pipeline.image_to_float(image)
    loaded_image = torch.from_numpy(loaded_image)
    loaded_image_rgb = loaded_image.permute(2, 0, 1).contiguous()[:3]

    # Run depth inference without gradients, in fp16 autocast.
    with torch.no_grad():
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            depth_output = depth_model(loaded_image_rgb)

    # The pointmap is still in the model's (R3) convention here; the next cell converts
    # it to the PyTorch3D convention.
    pointmap = depth_output["pointmaps"]
    # Cast to float32 in case autocast produced fp16 before handing depth to numpy consumers.
    depth_map_z = depth_output["depth"].float().cpu().numpy()

In [None]:
# Convert the pointmap from the R3 camera convention to PyTorch3D's.
# The rotation comes from a camera at (0, 0, -1) looking at the origin with up = -y;
# for these arguments it should reduce to diag(-1, -1, 1), i.e. x and y negated
# (NOTE(review): verify against pytorch3d's look_at_rotation).
# NOTE: `pointmap` is overwritten, so re-running this cell applies the flip twice.
target_device = pointmap.device

r3_to_p3d_R, r3_to_p3d_T = look_at_view_transform(
    eye=np.array([[0, 0, -1]]),
    at=np.array([[0, 0, 0]]),
    up=np.array([[0, -1, 0]]),
    device=target_device,
)

camera_convention_transform = Transform3d(device=target_device)
camera_convention_transform = camera_convention_transform.rotate(r3_to_p3d_R)
pointmap = camera_convention_transform.transform_points(pointmap)

print(f"Pointmap shape: {pointmap.shape}")
print(f"Pointmap min: {pointmap.min()}, max: {pointmap.max()}")

In [None]:
# Sanity check: build Gaussians directly from the depth map and image (no generative
# model involved), passing output_path so the helper can write a .ply, then render
# them back from the input camera.
output_path = "gaussians.ply"
gaussians = create_gaussians_from_depth(
    image=image,
    depth=depth_map_z,
    fx=fx, fy=fy,
    cx=cx, cy=cy,
    normalize_depth=False,
    output_path=output_path,
)

# Input camera: 3x3 intrinsics from the same fx/fy/cx/cy used above, identity camera-to-world pose.
K = torch.from_numpy(K_matrix).float()
c2w = torch.eye(4)

rendered_frame, rendered_alpha = render_frame(gaussians, c2w=c2w, K=K, w=W, h=H)

# Side-by-side comparison: original image vs. re-rendered splats.
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))
panels = (
    (ax1, image, 'Original Image'),
    (ax2, rendered_frame.cpu().numpy(), 'Rendered from Gaussian Splats'),
)
for panel_ax, panel_image, panel_title in panels:
    panel_ax.imshow(panel_image)
    panel_ax.set_title(panel_title, fontsize=14)

plt.tight_layout()
plt.show()

In [None]:
# Visualize the pointmap two ways: XYZ encoded as RGB, and the Z channel as a depth image.
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable


def _normalize_to_unit(channel):
    """Min-max scale an array to [0, 1] using only its finite entries.

    Non-finite entries (inf/NaN, e.g. pixels where the depth model has no valid
    estimate) are mapped to 0, and a constant channel maps to all zeros instead of
    dividing by zero.
    """
    finite = np.isfinite(channel)
    scaled = np.zeros(channel.shape, dtype=np.float64)
    if not finite.any():
        return scaled
    lo = channel[finite].min()
    span = channel[finite].max() - lo
    if span > 0:
        scaled = np.where(finite, (channel - lo) / span, 0.0)
    return scaled


# pointmap is assumed to be (H, W, 3) here — TODO confirm for the MoGe branch.
pointmap_np = pointmap.detach().cpu().numpy()

# Map each coordinate axis independently to one colour channel.
color_map = np.stack([_normalize_to_unit(pointmap_np[..., axis]) for axis in range(3)], axis=-1)

# Mask non-finite depths so a single inf does not wreck the colour scale.
depth_z = np.ma.masked_invalid(pointmap_np[..., 2])

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))

# Visualize color encoding of pointmap
ax1.imshow(color_map)
ax1.set_title('Pointmap Color Visualization', fontsize=14)
ax1.axis('off')

# Visualize the depth (z-coordinate) with matching colorbar height
im = ax2.imshow(depth_z, cmap='plasma')
ax2.set_title('Pointmap Depth Visualization', fontsize=14)
ax2.axis('off')

# Colorbar axis carved out of ax2 so it matches the image height.
divider = make_axes_locatable(ax2)
cax = divider.append_axes("right", size="5%", pad=0.05)
plt.colorbar(im, cax=cax, label='Depth (Z-coordinate)')

plt.tight_layout()
plt.show()

In [None]:
# Visualize 3D point cloud using Plotly (works in remote Jupyter notebooks)
import plotly.graph_objects as go

# Flatten the (H, W, 3) pointmap and the matching image pixels to (N, 3).
points_3d = pointmap_np.reshape(-1, 3)
# NOTE(review): assumes `image` holds 0-255 RGB values (e.g. uint8); a float [0, 1]
# image would need scaling by 255 before building the colour strings.
colors_rgb = image
colors_flat = colors_rgb.reshape(-1, 3)

# Keep at most `max_nr_points` points by taking every `downsample_factor`-th one.
# Ceil division guarantees the budget is respected; clouds already under budget are
# kept whole unless `min_downsample_factor` is raised to thin them further.
max_nr_points = 100000
min_downsample_factor = 1
downsample_factor = max(min_downsample_factor, -(-points_3d.shape[0] // max_nr_points))
points_downsampled = points_3d[::downsample_factor]
colors_downsampled = colors_flat[::downsample_factor]

# Drop points with non-finite coordinates (e.g. invalid depth), which cannot be plotted.
valid_mask = np.isfinite(points_downsampled).all(axis=1)
points_filtered = points_downsampled[valid_mask]
colors_filtered = colors_downsampled[valid_mask]

print(f"Original points: {len(points_3d):,}")
print(f"Downsampled points: {len(points_downsampled):,}")
print(f"Filtered points: {len(points_filtered):,}")

# Create RGB color strings for Plotly
rgb_colors = [f'rgb({r},{g},{b})' for r, g, b in colors_filtered]

# Create 3D scatter plot
fig = go.Figure(data=[go.Scatter3d(
    x=points_filtered[:, 0],
    y=points_filtered[:, 1],
    z=points_filtered[:, 2],
    mode='markers',
    marker=dict(
        size=2,
        color=rgb_colors,
        opacity=0.8
    ),
    text=[f'({x:.2f}, {y:.2f}, {z:.2f})' for x, y, z in points_filtered],
    hoverinfo='text'
)])

# Update layout for better visualization
fig.update_layout(
    title='3D Point Cloud from Depth',
    scene=dict(
        xaxis_title='X',
        yaxis_title='Y',
        zaxis_title='Z',
        aspectmode='data',
        camera=dict(
            eye=dict(x=1.5, y=1.5, z=1.5)
        )
    ),
    width=1000,
    height=800,
    margin=dict(l=0, r=0, t=40, b=0)
)

# Show interactive plot in notebook
fig.show()

print("\nInteractive controls:")
print("- Click and drag to rotate")
print("- Scroll to zoom")
print("- Double-click to reset view")

## 3. Generate Gaussian Splats

In [None]:
# Run the pipeline once per object mask.
# The pointmap is shared across all objects in the same scene.
import time

outputs = []
for mask in masks:
    # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
    start_time = time.perf_counter()
    output = inference(image, mask, seed=42, pointmap=pointmap)
    end_time = time.perf_counter()
    print(f"Inference time for mask: {end_time - start_time:.2f} seconds")
    outputs.append(output)

## 4. Visualize the Gaussian Splats of the Scene

In [None]:
# Merge the per-object reconstructions into a single scene of Gaussians.
object_outputs = list(outputs)
scene_gs = make_scene(*object_outputs)

In [None]:
# Visualize the Gaussian centres of the merged scene as a coloured 3D point cloud
# using Plotly (works in remote Jupyter notebooks).
import plotly.graph_objects as go

# Detach so .numpy() below works even if these tensors track gradients.
gaussians_xyz = scene_gs.get_xyz.detach()  # (N, 3)
# assumes get_features is (N, 1, 3), i.e. only the DC spherical-harmonic term — TODO confirm
gaussians_features = scene_gs.get_features.detach().squeeze(1)  # (N, 3)
print(f"xyz shape: {gaussians_xyz.shape}, min: {gaussians_xyz.min()}, max: {gaussians_xyz.max()}")
print(f"features shape: {gaussians_features.shape}, min: {gaussians_features.min()}, max: {gaussians_features.max()}")

def SH2RGB(sh):
    """Map degree-0 spherical-harmonic coefficients to RGB: rgb = C0 * sh + 0.5.

    C0 = 1 / (2 * sqrt(pi)) is the constant degree-0 SH basis value. The result is
    not clamped; callers clip it to [0, 1].
    """
    C0 = 0.28209479177387814
    rgb = sh * C0 + 0.5
    return rgb

# convert features to rgb (sh0)
gaussians_rgb = SH2RGB(gaussians_features)
# clip to [0, 1]
gaussians_rgb = torch.clamp(gaussians_rgb, 0.0, 1.0)
print(f"rgb shape: {gaussians_rgb.shape}, min: {gaussians_rgb.min()}, max: {gaussians_rgb.max()}")

# Gaussian centres and their 0-255 colours as (N, 3) numpy arrays; round instead of
# truncating so e.g. 0.999 * 255 maps to 255 rather than 254.
points_3d = gaussians_xyz.cpu().numpy()
colors_flat = np.round(gaussians_rgb.cpu().numpy() * 255).astype(np.uint8)

# Keep at most `max_nr_points` points by taking every `downsample_factor`-th one.
# Ceil division guarantees the budget is respected; scenes already under budget are
# kept whole unless `min_downsample_factor` is raised to thin them further.
max_nr_points = 100000
min_downsample_factor = 1
downsample_factor = max(min_downsample_factor, -(-points_3d.shape[0] // max_nr_points))
points_downsampled = points_3d[::downsample_factor]
colors_downsampled = colors_flat[::downsample_factor]

# No outlier filtering is applied to the Gaussian centres.
points_filtered = points_downsampled
colors_filtered = colors_downsampled

print(f"Original points: {len(points_3d):,}")
print(f"Downsampled points: {len(points_downsampled):,}")
print(f"Filtered points: {len(points_filtered):,}")

# Create RGB color strings for Plotly
rgb_colors = [f'rgb({r},{g},{b})' for r, g, b in colors_filtered]

# Create 3D scatter plot
fig = go.Figure(data=[go.Scatter3d(
    x=points_filtered[:, 0],
    y=points_filtered[:, 1],
    z=points_filtered[:, 2],
    mode='markers',
    marker=dict(
        size=2,
        color=rgb_colors,
        opacity=0.8
    ),
    text=[f'({x:.2f}, {y:.2f}, {z:.2f})' for x, y, z in points_filtered],
    hoverinfo='text'
)])

# Update layout for better visualization
fig.update_layout(
    title='Predicted 3D Point Cloud',
    scene=dict(
        xaxis_title='X',
        yaxis_title='Y',
        zaxis_title='Z',
        aspectmode='data',
        camera=dict(
            eye=dict(x=1.5, y=1.5, z=1.5)
        )
    ),
    width=1000,
    height=800,
    margin=dict(l=0, r=0, t=40, b=0)
)

# Show interactive plot in notebook
fig.show()

print("\nInteractive controls:")
print("- Click and drag to rotate")
print("- Scroll to zoom")
print("- Double-click to reset view")

In [None]:
# Bring the merged scene back from the PyTorch3D convention to the R3 camera
# convention of the input pointmap, export it, and render it from the input view.

# Get the denormalized xyz coordinates
xyz_unnormalized = scene_gs.get_xyz  # This applies: xyz * aabb[3:] + aabb[:3]
# Build every transform on the device the Gaussians live on (no dependency on `pointmap`).
xyz_device = xyz_unnormalized.device

# Same R3 -> PyTorch3D rotation that was applied to the pointmap before inference.
r3_to_p3d_R, r3_to_p3d_T = look_at_view_transform(
    eye=np.array([[0, 0, -1]]),
    at=np.array([[0, 0, 0]]),
    up=np.array([[0, -1, 0]]),
    device=xyz_device,
)

# Inverse transform (PyTorch3D -> R3): the inverse of a rotation is its transpose.
p3d_to_r3_R = r3_to_p3d_R.transpose(1, 2)

camera_convention_transform = Transform3d(device=xyz_device).rotate(p3d_to_r3_R)
xyz = camera_convention_transform.transform_points(xyz_unnormalized)

# New Gaussians object: rotated centres, every other attribute copied unchanged.
# NOTE(review): only the centres are rotated; the per-Gaussian rotations (`rots`) are
# not composed with p3d_to_r3_R, so anisotropic splats keep their old orientation.
new_scene_gs = create_gaussians_object(
    xyz=xyz,
    features=scene_gs.get_features,
    scales=scene_gs.get_scaling,
    rots=scene_gs.get_rotation,
    opacities=scene_gs.get_opacity,
)

# Export the merged scene, creating the output directory if needed.
# NOTE(review): this saves the un-rotated `scene_gs`, not `new_scene_gs` — confirm
# which convention consumers of the .ply expect.
ply_dir = os.path.join(PATH, "gaussians", "kubric4d")
os.makedirs(ply_dir, exist_ok=True)
scene_gs.save_ply(os.path.join(ply_dir, f"{SCENE_NAME}.ply"))

# Render from the original camera viewpoint: identity camera-to-world (camera at the
# origin looking along the z-axis) and the pinhole intrinsics defined earlier.
c2w = torch.eye(4)
K = torch.from_numpy(K_matrix).float()

rendered_frame, _ = render_frame(
    new_scene_gs,
    c2w=c2w,
    K=K,
    w=W,
    h=H,
)

# Display the rendered frame alongside the original image
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))

ax1.imshow(image)
ax1.set_title('Original Image', fontsize=14)
ax1.axis('off')

ax2.imshow(rendered_frame.cpu().numpy())
ax2.set_title('Rendered from Gaussian Splats', fontsize=14)
ax2.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Render a video of the merged scene with render_video (r=1, fov=60, 512 px) and
# save it as a looping GIF next to the exported .ply.
video_scene_gs = ready_gaussian_for_video_rendering(scene_gs)
video = render_video(
    video_scene_gs,
    r=1,
    fov=60,
    resolution=512,
)["color"]

# save video as gif, creating the output directory if needed
gif_dir = os.path.join(PATH, "gaussians", "kubric4d")
os.makedirs(gif_dir, exist_ok=True)
imageio.mimsave(
    os.path.join(gif_dir, f"{SCENE_NAME}.gif"),
    video,
    format="GIF",
    # Per-frame duration for 30 fps, expressed in milliseconds.
    # NOTE(review): imageio's GIF duration unit depends on the installed version/plugin
    # (ms for the Pillow writer, seconds for the legacy one) — confirm.
    duration=1000 / 30,
    loop=0,  # 0 means loop indefinitely
)

# notebook display; the random query string stops the browser from showing a cached, stale GIF
ImageDisplay(url=f"gaussians/kubric4d/{SCENE_NAME}.gif?cache_invalidator={uuid.uuid4()}")