In [None]:
import condorgmm
import warp as wp
wp.init()
from condorgmm.utils.common import get_assets_path
import condorgmm.data as data
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
from condorgmm.warp_gmm.state import State
import importlib
import condorgmm.warp_gmm.adam
# Reload the optimizer module *before* binding `Adam`: `importlib.reload`
# re-executes the module but does not rebind names that were already imported
# with `from ... import ...`, so importing first would leave `Adam` pointing at
# the pre-reload class.
importlib.reload(condorgmm.warp_gmm.adam)
from condorgmm.warp_gmm.adam import Adam
condorgmm.rr_init("test")

In [None]:
# Synthetic test frame: an all-zero 50x50 colour image whose every pixel has depth 2.0.
H, W = 50, 50
rgb = np.zeros((H, W, 3), dtype=np.float32)
depth = 2.0 * np.ones((H, W), dtype=np.float32)
# presumably (fx, fy, cx, cy) with the principal point at the image centre — confirm against condorgmm.Frame
intrinsics = np.array([100.0, 100.0, W / 2.0, H / 2.0])
# presumably xyz position followed by an xyzw quaternion (identity pose) — confirm against condorgmm.Frame
camera_pose = np.array([0.0] * 6 + [1.0])
frame = condorgmm.Frame(
    rgb=rgb,
    depth=depth,
    intrinsics=intrinsics,
    camera_pose=camera_pose,
)

# Warp-side counterpart of the frame, reused by the warp_gmm cells.
frame_warp = frame.as_warp()

In [None]:
import warp as wp
import numpy as np
from tqdm import tqdm
import rerun as rr


# presumably (re)initialises the rerun recording under the name "warp" for this cell's logs — confirm against condorgmm.rr_init
condorgmm.rr_init("warp")

@wp.func
def safe_normalize(quaternion: wp.quat):
    # Divide the quaternion by its norm. 1e-10 is added under the square root
    # so that an all-zero quaternion does not produce a division by zero.
    squared_norm = wp.dot(quaternion, quaternion)
    return quaternion / wp.sqrt(squared_norm + 1e-10)

# Warp kernel: per-point negative log-likelihood under a single 3D Gaussian.
# One thread handles one entry of `points` and writes its value to `losses[idx]`.
# Only element 0 of `mean`, `quat_imag`, `quat_real` and `log_scales` is read,
# so exactly one Gaussian component is evaluated.
@wp.kernel
def compute_gaussian_loss(
    points: wp.array(dtype=wp.vec3, ndim=1),  # observed points, one per thread
    mean: wp.array(dtype=wp.vec3, ndim=1),  # Gaussian mean (element 0 is used)
    quat_imag: wp.array(dtype=wp.vec3, ndim=1),  # quaternion x, y, z (element 0 is used)
    quat_real: wp.array(dtype=wp.float32, ndim=1), 
    log_scales: wp.array(dtype=wp.vec3, ndim=1),  # log of the per-axis scales (element 0 is used)
    losses: wp.array(dtype=wp.float32, ndim=1),  # output, written at the thread index
):
    idx = wp.tid()
    # Guard against threads launched beyond the end of `points`.
    if idx >= points.shape[0]:
        return
        
    # Get point
    p = points[idx]
    
    # Compute difference from mean
    diff = p - mean[0]
    
    # Build covariance from quaternion and scales.
    # The quaternion is assembled as (imaginary xyz, real w) and normalised first.
    quaternion = safe_normalize(wp.quat(quat_imag[0], quat_real[0]))
    rot_matrix = wp.quat_to_matrix(quaternion)
    scales = wp.vec3(wp.exp(log_scales[0][0]), wp.exp(log_scales[0][1]), wp.exp(log_scales[0][2]))

    # Small constant added to the covariance diagonal before inversion.
    diagonal_epsilon = wp.vec3(1e-10, 1e-10, 1e-10)
    # cov = R * diag(scales) * diag(scales) * R^T + epsilon * I
    cov = (
        rot_matrix
        * wp.diag(scales)
        * wp.diag(scales)
        * wp.transpose(rot_matrix)
    ) + wp.diag(diagonal_epsilon)
    cov_inv = wp.inverse(cov)
    # Log-determinant computed from the scales alone (2 * sum of log scales);
    # the diagonal epsilon added to `cov` above is not included here.
    log_det_cov = 2.0 * (wp.log(scales[0]) + wp.log(scales[1]) + wp.log(scales[2]))

    # Despite the name, this is the *negative* log-density:
    # 0.5 * (3 * log(2*pi) + log|cov| + diff^T cov^-1 diff).
    logpdf = 0.5 * (
        3.0 * wp.log(2.0 * wp.pi)
        + log_det_cov
        + wp.dot(diff, cov_inv * diff)
    )

    losses[idx] = logpdf


    # Reference (disabled): the same computation written with jax.numpy.
    # scale_matrix = jnp.diag(scales)
    # cov = rot_matrix @ jnp.diag(scales**2) @ rot_matrix.T
    
    # # Compute multivariate Gaussian negative log likelihood
    # centered = point - mean
    # log_det = jnp.sum(jnp.log(scales)) * 2  # log determinant of covariance
    # mahalanobis = jnp.sum(centered * (jnp.linalg.solve(cov, centered)))

    # return 0.5 * (3 * jnp.log(2 * jnp.pi) + log_det + mahalanobis)





def score_fn(warp_observed_points, mean, quat_imag, quat_real, log_scales, loss):
    """Launch ``compute_gaussian_loss`` with one thread per observed point.

    The per-point values are written into ``loss`` in place; nothing is returned.
    """
    num_points = warp_observed_points.shape[0]
    kernel_inputs = (warp_observed_points, mean, quat_imag, quat_real, log_scales, loss)
    wp.launch(kernel=compute_gaussian_loss, dim=(num_points,), inputs=kernel_inputs)

# Turn the frame's depth image into xyz points, flatten to rows of three
# coordinates, and upload them as a warp vec3 array.
warp_observed_points = wp.array(np.array(condorgmm.xyz_from_depth_image(frame.depth, *frame.intrinsics).reshape(-1,3)), dtype=wp.vec3)

# Parameters of the single Gaussian being fitted (each array holds one entry).
mean = wp.array(np.array([[0.0, 0.0, 2.2]], dtype=np.float32), dtype=wp.vec3, requires_grad=True)
quat_imag = wp.array(np.zeros((1,3), dtype=np.float32), dtype=wp.vec3, requires_grad=True)
quat_real = wp.array(np.ones(1, dtype=np.float32), dtype=wp.float32, requires_grad=True)
log_scales = wp.array(np.log(0.1 * np.ones((1,3), dtype=np.float32)), dtype=wp.vec3, requires_grad=True)

# Per-point loss buffer; score_fn fills it in place. Printed before and after
# one evaluation as a quick sanity check.
losses = wp.zeros(warp_observed_points.shape[0], dtype=wp.float32, requires_grad=True)
print(losses.numpy())
score_fn(warp_observed_points, mean, quat_imag, quat_real, log_scales, losses)
print(losses.numpy())

# Import condorgmm's Adam under an alias. warp.optim.Adam is deliberately NOT
# imported under the bare name `Adam`: that would shadow the condorgmm `Adam`
# bound in the setup cell, which later cells call with a per-parameter lr list.
from condorgmm.warp_gmm.adam import Adam as Adam2

# Only the mean and the log-scales are stepped; the quaternion arrays are passed
# to the kernel but are not handed to the optimizer.
params_to_optimize = [mean, log_scales]
optimizer = Adam2(params_to_optimize, lr=[0.01, 0.01])

condorgmm.rr_set_time(0)
condorgmm.rr_log_cloud(warp_observed_points.numpy(), channel="observed_points")

# Seed for the backward pass: a gradient of one for every entry of `losses`,
# i.e. the objective being differentiated is the sum of the per-point values.
backward = wp.ones(len(losses), dtype=wp.float32, requires_grad=True)

num_steps = 20000
pbar = tqdm(range(num_steps))
scores = []
for step in pbar:
    tape = wp.Tape()
    with tape:
        score_fn(warp_observed_points, mean, quat_imag, quat_real, log_scales, losses)
    tape.backward(grads={losses: backward})
    optimizer.step([x.grad for x in params_to_optimize])
    # Sum of per-point negative log-likelihoods from this step's forward pass
    # (computed before the parameter update above was applied).
    logpdf = losses.numpy().sum()
    pbar.set_description(f"Loss: {logpdf}")
    scores.append(logpdf)
    # Clear the gradients recorded on the tape before the next iteration.
    tape.zero()
    wp.synchronize()

    condorgmm.rr_set_time(step)

    rr.log(
        "gmm_warp",
        rr.Ellipsoids3D(
            centers=mean.numpy(),
            half_sizes=np.exp(log_scales.numpy()),
            quaternions=np.concatenate([quat_imag.numpy(), quat_real.numpy()[...,None]], axis=-1),
        ),
    )
plt.plot(scores)

In [None]:
# Gradient-based fit of a one-component warp GMM to `frame`, optimising only the
# log spatial scales. Relies on `frame`, `frame_warp`, `Adam2` and
# `warp_observed_points` defined by earlier cells.
import condorgmm.warp_gmm as warp_gmm
import condorgmm.warp_gmm.kernels
condorgmm.rr_init("gradient_test")

from tqdm import tqdm

# One component centred at depth 2.0 with scale 0.05 on every axis.
gmm = warp_gmm.gmm_warp.gmm_warp_from_numpy(
    spatial_means=np.array([[0.0, 0.0, 2.0]],dtype=np.float32),
    rgb_means=np.array([[0.0, 0.0, 0.0]],dtype=np.float32),
    log_spatial_scales=np.log(np.array([[0.05, 0.05, 0.05]],dtype=np.float32)),
    # quaternions_imaginary=np.array([[0.0, 0.1, -1.0]],dtype=np.float32),
)
warp_gmm_state = warp_gmm.initialize_state(gmm=gmm, frame=frame)
# Hyperparameters are overridden on the state after initialisation.
warp_gmm_state.hyperparams.window_half_width = 20
warp_gmm_state.hyperparams.outlier_volume = 1e6
warp_gmm_state.hyperparams.outlier_probability = 0.001
condorgmm.rr_set_time(0)
condorgmm.rr_log_frame(frame)

# Log the initial GMM and show the initial per-pixel log-score image.
warp_gmm.rr_log_gmm_warp(warp_gmm_state.gmm)
warp_gmm.warp_gmm_forward(frame_warp, warp_gmm_state)
plt.matshow(warp_gmm_state.log_score_image.numpy())

# Gradients are enabled on every GMM parameter array, although only the log
# spatial scales are handed to the optimizer below.
warp_gmm_state.gmm.spatial_means.requires_grad = True
warp_gmm_state.gmm.log_spatial_scales.requires_grad = True
warp_gmm_state.gmm.quaternions_imaginary.requires_grad = True
warp_gmm_state.gmm.quaternions_real.requires_grad = True
warp_gmm_state.gmm.rgb_means.requires_grad = True
print(warp_gmm_state.gmm.spatial_means.numpy())


# NOTE(review): SGD is imported but not used anywhere in this cell.
from warp.optim import SGD
params_to_optimize = [warp_gmm_state.gmm.log_spatial_scales]
optimizer = Adam2(params_to_optimize, lr=[1e-3])
 
condorgmm.rr_set_time(0)
condorgmm.rr_log_cloud(warp_observed_points.numpy(), channel="observed_points")

num_steps = 20000
pbar = tqdm(range(num_steps))
scores = []
for step in pbar:
    # Record the forward pass so it can be differentiated.
    tape = wp.Tape()
    with tape:
        condorgmm.warp_gmm.kernels.warp_gmm_forward(
            frame_warp,
            warp_gmm_state,
        )

    # Seed the backward pass with `warp_gmm_state.backward` as the gradient of the log-score image.
    tape.backward(grads={warp_gmm_state.log_score_image: warp_gmm_state.backward})
    optimizer.step([x.grad for x in params_to_optimize])
    # Sum of the log-score image from this step's forward pass; it is reported
    # under the label "Loss" in the progress bar.
    loss = warp_gmm_state.log_score_image.numpy().sum()
    pbar.set_description(f"Loss: {loss}")
    scores.append(loss)
    tape.zero()
    wp.synchronize()

    # Log the GMM to rerun only every 100th step.
    if step % 100 == 0:
        condorgmm.rr_set_time(step)
        warp_gmm.rr_log_gmm_warp(warp_gmm_state.gmm, size_scalar=1.0)

plt.plot(scores)

In [None]:
from scipy.spatial.transform import Rotation as R

# Sample from the GMM
num_samples = 1000

# Get the parameters of component 0 as NumPy values.
# The quaternion parts get their own names on purpose: `quat_imag` and
# `quat_real` are the warp arrays of the single-Gaussian fit and are still read
# through `.numpy()` by a later cell, so they must not be rebound here.
spatial_mean = warp_gmm_state.gmm.spatial_means.numpy()[0]
spatial_scale = np.exp(warp_gmm_state.gmm.log_spatial_scales.numpy()[0])
gmm_quat_imag = warp_gmm_state.gmm.quaternions_imaginary.numpy()[0]
gmm_quat_real = warp_gmm_state.gmm.quaternions_real.numpy()[0]

# Construct rotation matrix from the quaternion, ordered (x, y, z, w) as
# scipy's Rotation.from_quat expects.
quat = np.concatenate([gmm_quat_imag, [gmm_quat_real]])
rot_matrix = R.from_quat(quat).as_matrix()

# Generate all samples at once: draw standard-normal rows z, then scale, rotate
# and translate each row, i.e. sample = rot_matrix @ (spatial_scale * z) + spatial_mean.
z = np.random.normal(0, 1, (num_samples, 3)).astype(np.float32)
samples = ((spatial_scale * z) @ rot_matrix.T + spatial_mean).astype(np.float32)

# Visualize samples
condorgmm.rr_set_time(0)
condorgmm.rr_log_cloud(samples, channel="gmm_samples")


In [None]:
# Fit a Gaussian to the original points in closed form (sample mean and sample
# covariance) as a reference for the optimised fits.
# Relies on `num_samples` from the sampling cell above.
# NOTE(review): `mean` and `samples` rebind names that earlier cells used for other objects.
points = condorgmm.xyz_from_depth_image(frame.depth, *frame.intrinsics).reshape(-1,3)
mean = np.mean(points, axis=0)
# rowvar=False: each row of `points` is one observation, each column one coordinate.
cov = np.cov(points, rowvar=False)

# Draw samples from the fitted Gaussian and log them for visual comparison.
samples = np.random.multivariate_normal(mean, cov, size=num_samples)
condorgmm.rr_log_cloud(samples, channel="gmm_samples_analytical")

In [None]:
# Fit a one-component warp GMM to `frame` with five EM steps, then show the
# resulting per-pixel log-score image. Relies on `frame` and `frame_warp`.
import condorgmm.warp_gmm as warp_gmm
import condorgmm.warp_gmm.kernels

from tqdm import tqdm

condorgmm.rr_set_time(0)
condorgmm.rr_log_frame(frame)

# Start from a component at depth 1.0; only the means are given here, every
# other GMM parameter is left to gmm_warp_from_numpy.
gmm = warp_gmm.gmm_warp.gmm_warp_from_numpy(
    spatial_means=np.array([[0.0, 0.0, 1.0]],dtype=np.float32),
    rgb_means=np.array([[0.0, 0.0, 0.0]],dtype=np.float32),
)
warp_gmm_state = warp_gmm.initialize_state(gmm=gmm, frame=frame)
warp_gmm_state.hyperparams.window_half_width = 20
warp_gmm_state.hyperparams.outlier_volume = 1e6
# Log the GMM before fitting.
warp_gmm.rr_log_gmm_warp(warp_gmm_state.gmm)


for _ in range(5):
    warp_gmm.warp_gmm_EM_step(frame_warp, warp_gmm_state)
# Log the GMM after fitting and evaluate it once more to obtain the score image.
warp_gmm.rr_log_gmm_warp(warp_gmm_state.gmm)
warp_gmm.warp_gmm_forward(frame_warp, warp_gmm_state)

plt.matshow(warp_gmm_state.log_score_image.numpy())

In [None]:
import condorgmm.warp_gmm as warp_gmm
import condorgmm.warp_gmm.kernels
# Bind condorgmm's Adam explicitly instead of relying on whichever `Adam` an
# earlier cell left in the namespace (warp.optim also exports a class with that
# name). The optimizer below is built with a per-parameter learning-rate list,
# the calling convention this notebook uses with condorgmm's Adam.
from condorgmm.warp_gmm.adam import Adam
condorgmm.rr_init("gradient_test")

from tqdm import tqdm

# One component slightly behind the observed depth (2.1) with scale 0.1 on every axis.
gmm = warp_gmm.gmm_warp.gmm_warp_from_numpy(
    spatial_means=np.array([[0.0, 0.0, 2.1]],dtype=np.float32),
    rgb_means=np.array([[0.0, 0.0, 0.0]],dtype=np.float32),
    log_spatial_scales=np.log(np.array([[0.1, 0.1, 0.1]],dtype=np.float32)),
    # quaternions_imaginary=np.array([[0.0, 0.1, -1.0]],dtype=np.float32),
)
warp_gmm_state = warp_gmm.initialize_state(gmm=gmm, frame=frame)
warp_gmm_state.hyperparams.window_half_width = 20
warp_gmm_state.hyperparams.outlier_volume = 1e6
warp_gmm_state.hyperparams.outlier_probability = 0.0
condorgmm.rr_set_time(0)
condorgmm.rr_log_frame(frame)

# Log the initial GMM and show the initial per-pixel log-score image.
warp_gmm.rr_log_gmm_warp(warp_gmm_state.gmm)
warp_gmm.warp_gmm_forward(frame_warp, warp_gmm_state)
plt.matshow(warp_gmm_state.log_score_image.numpy())

# Gradients are enabled on every GMM parameter array; only the spatial means
# and log spatial scales are handed to the optimizer.
warp_gmm_state.gmm.spatial_means.requires_grad = True
warp_gmm_state.gmm.log_spatial_scales.requires_grad = True
warp_gmm_state.gmm.quaternions_imaginary.requires_grad = True
warp_gmm_state.gmm.quaternions_real.requires_grad = True
warp_gmm_state.gmm.rgb_means.requires_grad = True
params_to_optimize = [warp_gmm_state.gmm.spatial_means, warp_gmm_state.gmm.log_spatial_scales]
# Learning rates are listed in the same order as params_to_optimize: 1e-10 for
# the means (effectively frozen) and 1e-2 for the log-scales.
optimizer = Adam(params_to_optimize, lr=[1e-10, 1e-2])
print(warp_gmm_state.gmm.spatial_means.numpy())

In [None]:
num_timesteps = 1000
# tqdm is required here: the loop calls pbar.set_description, which a plain
# range object does not provide.
pbar = tqdm(range(num_timesteps))
scores = []
for step in pbar:
    # Record the forward pass so it can be differentiated.
    tape = wp.Tape()
    with tape:
        condorgmm.warp_gmm.kernels.warp_gmm_forward(
            frame_warp,
            warp_gmm_state,
        )

    # Seed the backward pass with `warp_gmm_state.backward` as the gradient of the log-score image.
    tape.backward(grads={warp_gmm_state.log_score_image: warp_gmm_state.backward})

    # NOTE(review): unlike the other optimisation loops in this notebook, this
    # loop never calls tape.zero() after stepping — confirm whether gradients
    # are meant to carry over between iterations here.
    optimizer.step([x.grad for x in params_to_optimize])

    # The quaternions are not optimised, so no re-normalisation of
    # quaternions_imaginary / quaternions_real is performed after the step.

    # Evaluate again (outside the tape) so the recorded score reflects the
    # parameters after this step's update.
    warp_gmm.warp_gmm_forward(frame_warp, warp_gmm_state)
    score = warp_gmm_state.log_score_image.numpy().sum()
    scores.append(score)
    pbar.set_description(f"Log score: {score}")

    condorgmm.rr_set_time(step)
    warp_gmm.rr_log_gmm_warp(warp_gmm_state.gmm)

# Final evaluation so the score image is defined even when num_timesteps is 0.
warp_gmm.warp_gmm_forward(frame_warp, warp_gmm_state)
plt.matshow(warp_gmm_state.log_score_image.numpy())

In [None]:
# Plot the `scores` list filled by the most recently run optimisation loop.
plt.plot(scores)

In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit
import numpy as np
import rerun as rr

condorgmm.rr_init("jax")

# PRNG key; NOTE(review): nothing in the visible cells consumes it (it was only
# needed by the disabled synthetic-data alternative below).
key = jax.random.PRNGKey(0)
# Observed cloud: the depth image of `frame` turned into xyz points and
# flattened to rows of three coordinates.
observed_points = jnp.array(condorgmm.xyz_from_depth_image(frame.depth, *frame.intrinsics).reshape(-1,3)) 

condorgmm.rr_log_cloud(observed_points, channel="observed_points")
# Disabled alternative: a random Gaussian blob instead of the frame's points.
# n_points = 1000
# observed_points = jax.random.normal(key, (n_points, 3)) * 0.1 + jnp.array([1.0, 2.0, -0.5])


# Initialize parameters
def init_params():
    """Return the starting parameters of the single Gaussian as a dict.

    Keys:
        'log_scales': log of the three per-axis scales, each starting at 0.1.
        'quat': rotation quaternion ordered [x, y, z, w], starting at identity.
        'mean': centre of the Gaussian, starting at (0, 0, 2.1).
    """
    initial_scales = jnp.ones(3) * 0.1
    identity_quat = jnp.array([0.0, 0.0, 0.0, 1.0])
    initial_mean = jnp.array([0., 0., 2.1])

    params = {}
    params['log_scales'] = jnp.log(initial_scales)
    params['quat'] = identity_quat
    params['mean'] = initial_mean
    return params

# Quaternion rotation
def quat_rotate(q, v):
    """Rotate the 3-vector ``v`` by the quaternion ``q`` ordered [x, y, z, w].

    ``q`` is used as given; it is not normalised here, so the result is a pure
    rotation only when ``q`` has unit norm.
    """
    qx, qy, qz, qw = q
    axis = jnp.array([qx, qy, qz])
    scaled_term = v * (qw * qw - 0.5)
    cross_term = jnp.cross(axis, v) * qw
    parallel_term = axis * jnp.dot(axis, v)
    return 2.0 * (scaled_term + cross_term + parallel_term)

# Negative log likelihood for a single point
def point_nll(params, point):
    """Negative log-likelihood of one 3D point under a rotated, axis-scaled Gaussian.

    The covariance is ``R @ diag(scales**2) @ R.T`` where ``R`` is the rotation
    matrix of the normalised quaternion and ``scales = exp(log_scales)``.

    Args:
        params: dict with 'log_scales' (3,), 'quat' (4,) ordered [x, y, z, w],
            and 'mean' (3,).
        point: the 3D point to score, shape (3,).

    Returns:
        Scalar ``0.5 * (3*log(2*pi) + log|cov| + mahalanobis)``.
    """
    log_scales = params['log_scales']
    quat = params['quat'] / jnp.linalg.norm(params['quat'])  # normalize quaternion
    mean = params['mean']
    x, y, z, w = quat[0], quat[1], quat[2], quat[3]

    # Rotation matrix of the unit quaternion (x, y, z, w).
    rot_matrix = jnp.array([
        [1 - 2*(y**2 + z**2), 2*(x*y - z*w), 2*(x*z + y*w)],
        [2*(x*y + z*w), 1 - 2*(x**2 + z**2), 2*(y*z - x*w)],
        [2*(x*z - y*w), 2*(y*z + x*w), 1 - 2*(x**2 + y**2)]
    ])

    # Because cov = R diag(s^2) R^T with R orthogonal, cov^-1 = R diag(s^-2) R^T,
    # so the Mahalanobis distance is |diag(1/s) R^T (point - mean)|^2. This avoids
    # assembling cov and solving a linear system, which loses accuracy when one
    # scale becomes much smaller than the others.
    centered = point - mean
    whitened = (rot_matrix.T @ centered) * jnp.exp(-log_scales)
    mahalanobis = jnp.sum(whitened ** 2)

    # log|cov| = 2 * sum(log s), taken directly from the log-scales.
    log_det = 2 * jnp.sum(log_scales)

    return 0.5 * (3 * jnp.log(2 * jnp.pi) + log_det + mahalanobis)

# Total loss for all points
@jit
def total_loss(params, points):
    """Mean of ``point_nll`` over every row of ``points`` for shared ``params``."""
    # Map over the leading axis of `points` only; `params` is broadcast unchanged.
    per_point_nll = jax.vmap(point_nll, in_axes=(None, 0))(params, points)
    return jnp.mean(per_point_nll)

# Gradient function
# Jitted gradient of total_loss with respect to its first argument (the params dict).
grad_fn = jit(grad(total_loss))

# NOTE(review): rebinds `points`, which an earlier cell also assigns.
points = observed_points

params = init_params()

# Log the initial (unfitted) Gaussian as an ellipsoid, under its own entity path
# so it stays separate from the per-step "jax_gmm_fitted" logs.
rr.log(
    "jax_gmm_initial",
    rr.Ellipsoids3D(
        centers=jnp.array([params['mean']]),
        half_sizes=jnp.exp(params['log_scales']),
        quaternions=params['quat'],
    ),
)



# Initialize Adam optimizer state
from jax.example_libraries.optimizers import adam
opt_init, opt_update, get_params = adam(step_size=0.01)
opt_state = opt_init(params)

@jax.jit
def update_params(opt_state, points, step=0):
    """Apply one Adam update and return the new optimizer state.

    Args:
        opt_state: current optimizer state.
        points: observed points passed through to the gradient function.
        step: index of the current optimisation step, forwarded to the
            optimizer's update function. It is an explicit (traced) argument
            because a jitted function that reads a Python global only sees the
            value that global had when the function was first traced.

    Returns:
        The optimizer state after the update.
    """
    grads = grad_fn(get_params(opt_state), points)
    return opt_update(step, grads, opt_state)

num_steps = 5000
pbar = tqdm(range(num_steps))
scores = []
for i in pbar:

    opt_state = update_params(opt_state, points, i)
    params = get_params(opt_state)
    loss = total_loss(params, points)
    scores.append(loss)
    pbar.set_description(f"Loss: {loss:.4f}")
    condorgmm.rr_set_time(i)
    # Log the current fit as an ellipsoid at this step's time.
    rr.log(
        "jax_gmm_fitted",
        rr.Ellipsoids3D(
            centers=jnp.array([params['mean']]),
            half_sizes=jnp.exp(params['log_scales']),
            quaternions=params['quat'],
        ),
    )



print("\nFitted parameters:")
print("Scales:", jnp.exp(params['log_scales']))
print("Quaternion:", params['quat'])
print("Mean:", params['mean'])


# Log the final fit once more after the loop.
rr.log(
    "jax_gmm_fitted",
    rr.Ellipsoids3D(
        centers=jnp.array([params['mean']]),
        half_sizes=jnp.exp(params['log_scales']),
        quaternions=params['quat'],
    ),
)

plt.plot(scores)

In [None]:
# Display the fitted JAX parameter dict.
params

In [None]:
# Assemble the (x, y, z, w) quaternion of the warp single-Gaussian fit.
# Requires `quat_imag` and `quat_real` to be the warp arrays created in the
# warp fitting cell: both must expose `.numpy()`.
np.concatenate([quat_imag.numpy(), quat_real.numpy()[...,None]], axis=-1)

In [None]:
# Display the observed point cloud copied back from warp as a NumPy array.
warp_observed_points.numpy()