In [1]:
import condorgmm
import rerun as rr
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Initialise rerun logging for this notebook (see condorgmm.rr_init).
condorgmm.rr_init("condorgmm")

scene = 49  # YCB-Video test scene id
video = condorgmm.data.YCBTestVideo(scene)
frame = video[0]  # only this first frame is used in the rest of the notebook

fig, ax = plt.subplots()
ax.imshow(frame.rgb)
ax.set_title(f"YCB-V test scene {scene}, frame 0")
ax.axis("off");

In [None]:
import condorgmm.warp_gmm as warp_gmm
import warp as wp

STRIDE = 10  # sample every STRIDE-th pixel (rows and columns) as a Gaussian seed
MIN_VALID_DEPTH = 0.01  # depth readings at or below this are treated as missing
N_EM_STEPS = 5

frame_warp = frame.as_warp()

# `pose` was never defined in this notebook (it leaked from an earlier kernel session),
# so this cell failed on Restart & Run All. Use the identity: the GMM is then expressed
# in this frame's camera coordinates. Later cells only read the camera pose back from
# the GMM, so they stay consistent. Substitute a world-frame pose here if one is wanted.
camera_pose = condorgmm.Pose.identity()

# Seed means from a strided grid of back-projected depth pixels and their RGB colours,
# dropping pixels without a valid depth reading.
spatial_means = condorgmm.xyz_from_depth_image(
    frame.depth.astype(np.float32), *frame.intrinsics
)[::STRIDE, ::STRIDE].reshape(-1, 3)
rgb_means = frame.rgb[::STRIDE, ::STRIDE].reshape(-1, 3).astype(np.float32)
valid = spatial_means[:, 2] > MIN_VALID_DEPTH
spatial_means = spatial_means[valid]
rgb_means = rgb_means[valid]

spatial_means = camera_pose.apply(spatial_means).astype(np.float32)

# Pixel mask of valid depth, as a warp array (same threshold as the seeding above).
mask = wp.array(frame.depth > MIN_VALID_DEPTH, dtype=wp.bool)
gmm = warp_gmm.gmm_warp_from_numpy(spatial_means, rgb_means)
gmm.camera_posquat = wp.array(camera_pose.posquat.astype(np.float32))

warp_gmm_state = warp_gmm.initialize_state(gmm=gmm, frame=frame)
warp_gmm_state.hyperparams.window_half_width = 5
warp_gmm_state.mask = mask

# One forward pass, then a few EM refinement steps on this single frame.
warp_gmm.warp_gmm_forward(frame_warp, warp_gmm_state)
for _ in range(N_EM_STEPS):
    warp_gmm.warp_gmm_EM_step(frame_warp, warp_gmm_state)

assert warp_gmm_state.gmm.is_valid()

# Two copies of the current camera posquat, bundled with the GMM state.
# NOTE(review): `state` is not used further in this notebook — presumably the input
# format of a tracking step elsewhere; confirm or remove.
two_prev_camera_poses = (
    warp_gmm_state.gmm.camera_posquat.numpy(),
    warp_gmm_state.gmm.camera_posquat.numpy(),
)
state = (
    warp_gmm_state,
    two_prev_camera_poses,
)

warp_gmm.rr_log_gmm_warp(warp_gmm_state.gmm, "gmm", fill_mode="wireframe")

In [None]:
# Keep a reference to the fitted GMM so later cells can swap in a subset and restore it.
# NOTE: this is an alias, not a copy; re-running this cell after a subset GMM has been
# installed would capture the subset instead.
old_gmm = warp_gmm_state.gmm

# Re-score the frame with a much wider per-Gaussian window than the 5 used while fitting.
SCORING_WINDOW_HALF_WIDTH = 50
warp_gmm_state.hyperparams.window_half_width = SCORING_WINDOW_HALF_WIDTH
warp_gmm.warp_gmm_forward(frame_warp, warp_gmm_state)

plt.matshow(warp_gmm_state.log_score_image.numpy())
plt.colorbar()
plt.title(f"Log score image (window_half_width={SCORING_WINDOW_HALF_WIDTH})");

In [None]:
import imageio.v2 as imageio
from pathlib import Path

from tqdm.auto import tqdm

# Build a GIF that enables the fitted Gaussians one at a time. Each frame shows the
# projected Gaussian centres over the RGB image (left: earlier ones in blue, the newest
# in red with its evaluation window) and the log-score image computed with only the
# Gaussians enabled so far (right).

SEED = 0  # makes the random Gaussian ordering (and therefore the GIF) reproducible
# Gaussians shown first, in this order; the remaining valid ones follow in random order.
FIRST_GAUSSIAN_IDS = np.array([317, 318, 291])
GAUSSIAN_STEP = 2  # keep every GAUSSIAN_STEP-th remaining Gaussian to bound GIF length
GIF_FRAME_DURATION = 0.2


def _match_shape(image: np.ndarray, shape: tuple) -> np.ndarray:
    """Crop `image` to `shape`, padding any missing rows/columns with white (255).

    bbox_inches="tight" can change the saved PNG size slightly between frames, and the
    GIF writer needs identical frame shapes. Assumes all frames share a channel count.
    """
    fitted = np.full(shape, 255, dtype=image.dtype)
    n_rows = min(shape[0], image.shape[0])
    n_cols = min(shape[1], image.shape[1])
    fitted[:n_rows, :n_cols] = image[:n_rows, :n_cols]
    return fitted


output_dir = Path("gaussian_progression")
output_dir.mkdir(exist_ok=True)

fx, fy, cx, cy = frame.intrinsics
image_height, image_width = frame.rgb.shape[:2]

# Gaussian centres -> transformed by object pose 0 -> camera frame.
# NOTE(review): every Gaussian is moved by object_posquats[0]; if `assignments` holds
# per-Gaussian object ids, this should index object_posquats by assignment — confirm.
camera_pose_inv = condorgmm.Pose(warp_gmm_state.gmm.camera_posquat.numpy()).inv()
spatial_means = warp_gmm_state.gmm.spatial_means.numpy()
spatial_means = condorgmm.Pose(warp_gmm_state.gmm.object_posquats.numpy()[0]).apply(spatial_means)
points_in_camera = camera_pose_inv.apply(spatial_means)

# Pinhole projection. Points with z <= 0 produce inf/nan here; they are rejected by the
# mask, so silence the divide/invalid warnings instead of letting them flood the output.
with np.errstate(divide="ignore", invalid="ignore"):
    pixel_x = fx * points_in_camera[:, 0] / points_in_camera[:, 2] + cx
    pixel_y = fy * points_in_camera[:, 1] / points_in_camera[:, 2] + cy
    valid_gaussians = (
        (pixel_x >= 0) & (pixel_x < image_width)
        & (pixel_y >= 0) & (pixel_y < image_height)
        & (points_in_camera[:, 2] > 0)
    )
valid_indices = np.where(valid_gaussians)[0]
print(f"{len(valid_indices)} Gaussians project into the image")

# Ordering: the hand-picked Gaussians first (only those that actually project into the
# image), then a seeded random subsample of the rest. Each id appears at most once.
rng = np.random.default_rng(SEED)
first_ids = FIRST_GAUSSIAN_IDS[np.isin(FIRST_GAUSSIAN_IDS, valid_indices)]
remaining_ids = rng.permutation(np.setdiff1d(valid_indices, first_ids))
ordering = np.concatenate([first_ids, remaining_ids[::GAUSSIAN_STEP]])

# Integer pixel positions of the ordered Gaussians (all finite: ordering is valid-only).
pixel_cols = pixel_x[ordering].astype(int)
pixel_rows = pixel_y[ordering].astype(int)

n_gaussians = len(warp_gmm_state.gmm.assignments.numpy())
indices = np.zeros(n_gaussians, dtype=np.bool_)  # Gaussians enabled so far
window_half_width = warp_gmm_state.hyperparams.window_half_width
image_files = []

try:
    for i, index in enumerate(tqdm(ordering, desc="frames")):
        indices[index] = True
        warp_gmm_state.gaussian_mask = wp.array(indices, dtype=wp.bool)
        warp_gmm.warp_gmm_forward(frame_warp, warp_gmm_state)

        fig, (ax_rgb, ax_score) = plt.subplots(1, 2, figsize=(10, 5))

        # Left: RGB frame with the projected centres of the enabled Gaussians.
        ax_rgb.imshow(frame.rgb)
        # Pin the limits to the image so the window rectangle cannot enlarge the axes.
        ax_rgb.set_xlim(0, image_width)
        ax_rgb.set_ylim(image_height, 0)  # reversed: image rows grow downwards
        ax_rgb.plot(pixel_cols[:i], pixel_rows[:i], "b.", markersize=10)
        ax_rgb.add_patch(
            plt.Rectangle(
                (pixel_cols[i] - window_half_width, pixel_rows[i] - window_half_width),
                2 * window_half_width,
                2 * window_half_width,
                fill=False,
                color="red",
                linewidth=2,
                clip_on=True,  # keep the rectangle inside the axes
            )
        )
        ax_rgb.plot(pixel_cols[i], pixel_rows[i], "r.", markersize=10)

        # Right: log-score image using only the Gaussians enabled so far.
        ax_score.imshow(warp_gmm_state.log_score_image.numpy())

        for ax in (ax_rgb, ax_score):
            ax.axis("off")
            ax.set_frame_on(False)
            ax.margins(0, 0)

        filename = output_dir / f"frame_{i:03d}.png"
        fig.savefig(filename, bbox_inches="tight", pad_inches=0, dpi=100)
        plt.close(fig)
        image_files.append(filename)
finally:
    # The loop mutates the shared state; re-enable every Gaussian so later cells score
    # with the full model instead of whatever subset the loop stopped at.
    warp_gmm_state.gaussian_mask = wp.array(np.ones(n_gaussians, dtype=np.bool_), dtype=wp.bool)

# Assemble the GIF (skipped when no Gaussian projects into the image).
if image_files:
    images = [np.asarray(imageio.imread(f)) for f in image_files]
    target_shape = images[0].shape
    images = [_match_shape(img, target_shape) for img in images]
    imageio.mimsave(output_dir / "gaussian_progression.gif", images, duration=GIF_FRAME_DURATION)

In [None]:
valid_indices.size  # number of Gaussians whose centre projects into the image

In [None]:
# Swap in a GMM containing only a few hand-picked Gaussians to inspect their contribution.
# NOTE(review): `mask_gmm` is not defined anywhere in this notebook (it must come from a
# deleted cell or another session), so this cell fails on Restart & Run All — define it
# in a cell above before relying on this one.
kept_gaussian_ids = np.array([317, 318, 291])
new_gmm = mask_gmm(old_gmm, kept_gaussian_ids)
warp_gmm_state.gmm = new_gmm

# gaussian_mask is indexed per Gaussian of the *current* GMM. Reset it to the subset's
# size so stale entries (e.g. left by the progression loop) don't toggle the wrong ones.
warp_gmm_state.gaussian_mask = wp.array(
    np.ones(len(new_gmm.assignments.numpy()), dtype=np.bool_), dtype=wp.bool
)

len(warp_gmm_state.gmm.assignments.numpy())




In [None]:

# Restore the full fitted GMM and save its log-score image.
warp_gmm_state.gmm = old_gmm
# Earlier cells may leave a partial or wrongly sized gaussian_mask on the state; re-enable
# every Gaussian of the restored GMM so the saved image reflects the full model.
warp_gmm_state.gaussian_mask = wp.array(
    np.ones(len(old_gmm.assignments.numpy()), dtype=np.bool_), dtype=wp.bool
)
warp_gmm.warp_gmm_forward(frame_warp, warp_gmm_state)

plt.matshow(warp_gmm_state.log_score_image.numpy())
plt.axis('off')
plt.gca().set_frame_on(False)
plt.margins(0,0)
plt.savefig("log_score_image.png", bbox_inches='tight', pad_inches=0)
