In [None]:
# Copyright (c) Meta Platforms, Inc. and affiliates.

## 1. Imports and Model Loading

In [None]:
import os
import uuid
import imageio
import numpy as np
import torch
from IPython.display import Image as ImageDisplay
from pytorch3d.transforms import Transform3d

from inference import Inference, ready_gaussian_for_video_rendering, load_image, load_masks, display_image, make_scene, render_video, interactive_visualizer

In [None]:
# Everything below is resolved relative to the directory the notebook runs in.
PATH = os.getcwd()
TAG = "hf"
config_path = f"{PATH}/../checkpoints/{TAG}/pipeline.yaml"

# Build the inference pipeline from that config; compile=False is forwarded as-is.
inference = Inference(
    config_path,
    compile=False,
)

## 2. Load input image to lift to 3D (multiple objects)

In [None]:
IMAGE_PATH = f"{PATH}/images/shutterstock_stylish_kidsroom_1640806567/image.png"

# The folder containing the image is also where the masks are read from,
# and its name doubles as the scene identifier used for output files.
scene_dir = os.path.dirname(IMAGE_PATH)
IMAGE_NAME = os.path.basename(scene_dir)

image = load_image(IMAGE_PATH)
masks = load_masks(scene_dir, extension=".png")
display_image(image, masks)

In [None]:
# Estimate a pointmap for the full image with the pipeline's depth model (MoGe).
# NOTE(review): get_mask is imported here but not referenced in the visible cells.
from sam3d_objects.data.dataset.tdfy.img_and_mask_transforms import get_mask

depth_model = inference._pipeline.depth_model

# Convert the loaded image to a float tensor. The last channel is split off as
# the mask; the first three channels are the RGB input for the depth model.
image_float = inference._pipeline.image_to_float(image)
loaded_image = torch.from_numpy(image_float)
loaded_mask = loaded_image[..., -1]
# Move the last axis to the front (assumes a channel-last image), then keep 3 channels.
channels_first = loaded_image.permute(2, 0, 1).contiguous()
loaded_image_rgb = channels_first[:3]

# Gradients are disabled and the forward pass runs under float16 autocast on CUDA.
with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.float16):
    depth_output = depth_model(loaded_image_rgb)

# Pointmap exactly as returned by the depth model; the conversion to the
# PyTorch3D camera convention is applied in the following cell.
pointmap = depth_output["pointmaps"]

In [None]:
# Camera convention transformation (R3 -> PyTorch3D)
from pytorch3d.renderer import look_at_view_transform

# Always start from the raw depth-model output instead of from `pointmap`
# itself. Otherwise re-running this cell would rotate an already rotated
# pointmap and silently undo (or double-apply) the convention change.
pointmap_r3 = depth_output["pointmaps"]

r3_to_p3d_R, r3_to_p3d_T = look_at_view_transform(
    eye=np.array([[0, 0, -1]]),
    at=np.array([[0, 0, 0]]),
    up=np.array([[0, -1, 0]]),
    device=pointmap_r3.device,
)

# Only the rotation part is applied; r3_to_p3d_T is not used.
camera_convention_transform = Transform3d(device=pointmap_r3.device).rotate(r3_to_p3d_R)
pointmap = camera_convention_transform.transform_points(pointmap_r3)

print(f"Pointmap shape: {pointmap.shape}")
print(f"Pointmap min: {pointmap.min()}, max: {pointmap.max()}")

In [None]:
# Visualize pointmap
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable


def _normalize_to_unit_range(channel):
    """Rescale ``channel`` to [0, 1] using only its finite values.

    Non-finite entries (NaN / inf) are ignored when the range is computed and
    are mapped to 0 in the result, so a few invalid pixels cannot blank the
    whole visualization. A channel with no finite values, or whose finite
    values are all equal, yields an all-zero array instead of dividing by zero.
    For an all-finite, non-constant channel the result equals
    ``(channel - channel.min()) / (channel.max() - channel.min())``.
    """
    finite = np.isfinite(channel)
    if not finite.any():
        return np.zeros_like(channel)
    lo = channel[finite].min()
    span = channel[finite].max() - lo
    if span == 0:
        return np.zeros_like(channel)
    scaled = (channel - lo) / span
    return np.where(finite, scaled, 0)


# Assuming pointmap is a tensor of shape (H, W, 3)
pointmap_np = pointmap.cpu().numpy()

# Map position to RGB colors for visualization (x -> R, y -> G, z -> B)
normed_x = _normalize_to_unit_range(pointmap_np[..., 0])
normed_y = _normalize_to_unit_range(pointmap_np[..., 1])
normed_z = _normalize_to_unit_range(pointmap_np[..., 2])
color_map = np.stack([normed_x, normed_y, normed_z], axis=-1)

# Create figure with subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))

# Visualize color encoding of pointmap
ax1.imshow(color_map)
ax1.set_title('Pointmap Color Visualization', fontsize=14)
ax1.axis('off')

# Visualize the depth (z-coordinate) with matching colorbar height
im = ax2.imshow(pointmap_np[..., 2], cmap='plasma')
ax2.set_title('Pointmap Depth Visualization', fontsize=14)
ax2.axis('off')

# Create colorbar with same height as the image
divider = make_axes_locatable(ax2)
cax = divider.append_axes("right", size="5%", pad=0.05)
plt.colorbar(im, cax=cax, label='Depth (Z-coordinate)')

plt.tight_layout()
plt.show()

In [None]:
# Visualize 3D point cloud using Plotly (works in remote Jupyter notebooks)
import plotly.graph_objects as go

# Reshape pointmap to (N, 3) where N = H * W
points_3d = pointmap_np.reshape(-1, 3)

# Get colors from the original image (RGB values in range [0, 255])
image_rgb = loaded_image_rgb.cpu().numpy().transpose(1, 2, 0)  # Convert to (H, W, 3)
# Clip before the cast so values marginally outside [0, 1] cannot wrap around in uint8.
colors_rgb = np.clip(image_rgb * 255, 0, 255).astype(np.uint8)  # Convert to 0-255 range
colors_flat = colors_rgb.reshape(-1, 3)

# Downsample for better performance (every Nth point).
# Ceil division guarantees at most max_nr_points survive, and clouds that are
# already within budget are kept whole (factor 1) instead of being thinned.
max_nr_points = 100000  # Adjust this based on performance needs
downsample_factor = max(1, -(-points_3d.shape[0] // max_nr_points))
points_downsampled = points_3d[::downsample_factor]
colors_downsampled = colors_flat[::downsample_factor]

# Filter outliers: keep points within distance 3.0 of the origin.
# Non-finite points fail the comparison and are dropped as well.
valid_mask = np.linalg.norm(points_downsampled, axis=1) < 3.0
points_filtered = points_downsampled[valid_mask]
colors_filtered = colors_downsampled[valid_mask]

print(f"Original points: {len(points_3d):,}")
print(f"Downsampled points: {len(points_downsampled):,}")
print(f"Filtered points: {len(points_filtered):,}")

# Create RGB color strings for Plotly
rgb_colors = [f'rgb({r},{g},{b})' for r, g, b in colors_filtered]

# Create 3D scatter plot
fig = go.Figure(data=[go.Scatter3d(
    x=points_filtered[:, 0],
    y=points_filtered[:, 1],
    z=points_filtered[:, 2],
    mode='markers',
    marker=dict(
        size=2,
        color=rgb_colors,
        opacity=0.8
    ),
    text=[f'({x:.2f}, {y:.2f}, {z:.2f})' for x, y, z in points_filtered],
    hoverinfo='text'
)])

# Update layout for better visualization
fig.update_layout(
    title='3D Point Cloud from MoGe Depth Estimation',
    scene=dict(
        xaxis_title='X',
        yaxis_title='Y',
        zaxis_title='Z',
        aspectmode='data',
        camera=dict(
            eye=dict(x=1.5, y=1.5, z=1.5)
        )
    ),
    width=1000,
    height=800,
    margin=dict(l=0, r=0, t=40, b=0)
)

# Show interactive plot in notebook
fig.show()

print("\nInteractive controls:")
print("- Click and drag to rotate")
print("- Scroll to zoom")
print("- Double-click to reset view")

## 3. Generate Gaussian Splats

In [None]:
def _lift_object(mask):
    """Run inference for a single object mask.

    Every object belongs to the same scene, so each call receives the same
    image, the same pointmap and the same fixed seed.
    """
    return inference(image, mask, seed=42, pointmap=pointmap)


# One inference result per mask, in the order the masks were loaded.
outputs = [_lift_object(mask) for mask in masks]

## 4. Visualize Gaussian Splat of the Scene
### a. Animated GIF

In [None]:
# Combine the per-object outputs into one scene, then pass the result through
# the helper that readies it for video rendering.
scene_gs = ready_gaussian_for_video_rendering(make_scene(*outputs))

In [None]:
# export gaussian splatting (as point cloud)
ply_path = f"{PATH}/gaussians/multi_wt_moge/{IMAGE_NAME}.ply"
# Make sure the output folder exists so the export also works on a fresh checkout.
os.makedirs(os.path.dirname(ply_path), exist_ok=True)
scene_gs.save_ply(ply_path)

In [None]:
# render video
video = render_video(
    scene_gs,
    r=1,
    fov=60,
    resolution=512,
)["color"]

# save video as gif
gif_path = f"{PATH}/gaussians/multi_wt_moge/{IMAGE_NAME}.gif"
# Make sure the output folder exists before writing the file.
os.makedirs(os.path.dirname(gif_path), exist_ok=True)
imageio.mimsave(
    gif_path,
    video,
    format="GIF",
    # Per-frame duration targeting 30 fps playback.
    # NOTE(review): the unit (ms vs. s) depends on the imageio version/plugin - confirm.
    duration=1000 / 30,
    loop=0,  # 0 means loop indefinitely
)

# notebook display; the random query string stops the browser from showing a cached GIF
ImageDisplay(url=f"gaussians/multi_wt_moge/{IMAGE_NAME}.gif?cache_invalidator={uuid.uuid4()}")