In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from IPython.display import Markdown as md
import bayes3d._mkl.trimesh_to_gaussians

# Render the module's docstring (title + usage example) as Markdown.
# The docstring lives in the dunder attribute `__doc__`.
md(bayes3d._mkl.trimesh_to_gaussians.__doc__)


# Trimesh to Gaussians
> Pretty much self-explanatory

**Example:**
```python
from bayes3d._mkl.trimesh_to_gaussians import (
    patch_trimesh, 
    uniformly_sample_from_mesh, 
    ellipsoid_embedding, 
    get_mean_colors, 
    pack_transform
)
import trimesh
import numpy as np
import jax.numpy as jnp
import jax
from jax import jit, vmap
from sklearn.mixture import GaussianMixture
from bayes3d._mkl.utils import keysplit

# SEED
key = jax.random.PRNGKey(0)

# LOAD MESH
# -------------------
mesh = load_mesh(...)
mesh = patch_trimesh(mesh)

# SAMPLE FROM MESH
# ----------------
key = keysplit(key)
n = 20_000
xs, cs = uniformly_sample_from_mesh(key, n, mesh, with_color=True)

# GMM CONFIG
# ----------
key = keysplit(key)
n_components = 150
noise        = 0.0; 
X            = xs + np.random.randn(*xs.shape)*noise
means_init   = np.array(uniformly_sample_from_mesh(key, n_components, mesh, with_color=False)[0]);

# FIT THE GMM
# -----------
gm = GaussianMixture(n_components=n_components, 
                     tol=1e-3, max_iter=100, 
                     covariance_type="full", 
                     means_init=means_init).fit(X)

mus        = gm.means_
covs       = gm.covariances_
labels     = gm.predict(X)
choleskys  = vmap(ellipsoid_embedding)(covs)
transforms = vmap(pack_transform, (0,0,None))(mus, choleskys, 2.0)
mean_colors, nums = get_mean_colors(cs, gm.n_components, labels)
```


In [3]:
from bayes3d._mkl.trimesh_to_gaussians import (
    patch_trimesh, 
    sample_from_mesh,
    barycentric_to_mesh as _barycentric_to_mesh,
    uniformly_sample_from_mesh, 
    ellipsoid_embedding, 
    get_mean_colors, 
    pack_transform,
    transform_from_gaussian
)
import trimesh
import numpy as np
import jax.numpy as jnp
import jax
from jax import jit, vmap
from sklearn.mixture import GaussianMixture
from bayes3d._mkl.utils import keysplit


# SEED
# Seed both RNGs in play:
# - JAX, via an explicit key.
# - NumPy's global RNG, which is drawn from by the `np.random.randn` jitter
#   below and by sklearn's GaussianMixture (it is constructed with the
#   default `random_state=None`).
SEED = 0
key = jax.random.PRNGKey(SEED)
np.random.seed(SEED)

# Batched barycentric_to_mesh: vectorized over the first two arguments, the
# third argument (presumably the mesh — TODO confirm) is shared by the batch.
barycentric_to_mesh = vmap(_barycentric_to_mesh, (0, 0, None))

In [4]:
import traceviz.client
import numpy as np
from traceviz.proto import  viz_pb2
import json
import matplotlib.pyplot as plt

In [5]:
def load_mesh(t, data_dir="data/flag_objs"):
    """Load frame `t` of the flag mesh sequence and apply `patch_trimesh` to it.

    Args:
        t: Frame index; the file read is ``<data_dir>/flag_t_<t>.obj``.
        data_dir: Directory containing the per-frame OBJ files. Defaults to
            the path previously hard-coded here.

    Returns:
        The mesh loaded by `trimesh.load`, after `patch_trimesh`.
    """
    import os

    path = os.path.join(data_dir, f"flag_t_{t}.obj")
    mesh = trimesh.load(path)
    return patch_trimesh(mesh)

In [6]:
# LOAD MESH
# -------------------
# Start from frame t = 0; the tracking loop further down steps through t = 1..4.
mesh = load_mesh(t=0)

In [7]:
# Size of the mesh along each coordinate axis (max minus min vertex coordinate).
extent = mesh.vertices.max(axis=0) - mesh.vertices.min(axis=0)
print(extent)

[2.797654 2.322004 0.816412]


In [11]:
# SAMPLE FROM MESH
# ----------------
# Small exploratory draw. sample_from_mesh returns three arrays — presumably
# points, barycentric coordinates and face indices (TODO confirm). The key
# split here also shifts every later key, so keep it even if unused.
n = 100
key = keysplit(key)
xs, ps, fs = sample_from_mesh(key, n, mesh)

In [None]:
def fit(key, mesh, means_init, precisions_init, covariance_type="full", iter=20, noise=0.0,
        n=20_000, tol=1e-3):
    """Fit a warm-started Gaussian mixture to points sampled from `mesh`'s surface.

    NOTE(review): `fit` is defined again in a later cell, and that definition
    shadows this one — consider deleting one of them.

    Args:
        key: PRNG key; a sub-key derived from it with `keysplit` drives the
            surface sampling.
        mesh: Mesh passed to `uniformly_sample_from_mesh`.
        means_init: Initial means, shape (K, D). K sets the number of components.
        precisions_init: Initial precisions in the layout sklearn expects for
            `covariance_type` (e.g. (K,) for "spherical", (K, D, D) for "full").
        covariance_type: "full", "tied", "diag" or "spherical".
        iter: Maximum number of EM iterations (`max_iter`).
        noise: Std. dev. of isotropic Gaussian jitter added to the samples.
        n: Number of surface samples to fit to.
        tol: EM convergence threshold.

    Returns:
        mus: Fitted means, shape (K, D).
        covs: Fitted covariances expanded to full matrices, shape (K, D, D),
            whatever `covariance_type` was.
        mean_colors: Per-component mean colors from `get_mean_colors`.
    """
    # SAMPLE FROM MESH
    # ----------------
    _, key = keysplit(key, 1, 1)
    xs, cs = uniformly_sample_from_mesh(key, n, mesh, with_color=True)

    # GMM CONFIG
    # ----------
    n_components = means_init.shape[0]
    # Jitter is drawn from NumPy's global RNG (not `key`); seed NumPy for
    # reproducibility. The draw happens even when noise == 0 so the global
    # RNG state seen by the GMM below is unchanged.
    X = xs + np.random.randn(*xs.shape) * noise

    # FIT THE GMM
    # -----------
    gm = GaussianMixture(n_components=n_components,
                         tol=tol, max_iter=iter,
                         covariance_type=covariance_type,
                         means_init=means_init,
                         precisions_init=precisions_init).fit(X)

    mus = gm.means_
    n_dims = mus.shape[1]
    # sklearn stores `covariances_` in a type-dependent layout; expand every
    # layout to a stack of full (K, D, D) matrices.
    if gm.covariance_type == "spherical":   # (K,): one variance per component
        covs = gm.covariances_[:, None, None] * jnp.eye(n_dims)[None, :, :]
    elif gm.covariance_type == "diag":      # (K, D): per-axis variances
        covs = gm.covariances_[:, :, None] * jnp.eye(n_dims)[None, :, :]
    elif gm.covariance_type == "tied":      # (D, D): shared by all components
        covs = jnp.broadcast_to(gm.covariances_, (n_components, n_dims, n_dims))
    else:                                   # "full": already (K, D, D)
        covs = gm.covariances_

    labels = gm.predict(X)
    mean_colors, _ = get_mean_colors(cs, gm.n_components, labels)

    return mus, covs, mean_colors

In [89]:
# SAMPLE FROM MESH
# ----------------
# Surface samples (positions xs, colors cs) that the mixture is fit to.
key = keysplit(key)
n = 10_000
xs, cs = uniformly_sample_from_mesh(key, n, mesh, with_color=True)

# GMM CONFIG
# ----------
# X is the jittered training data; component means start at points drawn
# uniformly from the surface. Statement order matters: each line below
# consumes RNG state (JAX key or NumPy's global RNG).
key = keysplit(key)
n_components = 100
noise = 0.01
X = xs + noise * np.random.randn(*xs.shape)
means_init = np.array(
    uniformly_sample_from_mesh(key, n_components, mesh, with_color=False)[0]
)

# FIT THE GMM
# -----------
gm = GaussianMixture(
    n_components=n_components,
    tol=1e-3,
    max_iter=100,
    covariance_type="spherical",
    means_init=means_init,
).fit(X)

mus = gm.means_
# A spherical mixture stores one variance per component; expand it to a stack
# of isotropic 3x3 covariance matrices.
covs = (
    gm.covariances_[:, None, None] * jnp.eye(3)[None, :, :]
    if gm.covariance_type == "spherical"
    else gm.covariances_
)
labels = gm.predict(X)
transforms = vmap(transform_from_gaussian, (0, 0, None))(mus, covs, 2.0)
mean_colors, nums = get_mean_colors(cs, gm.n_components, labels)

print(f"""
n: {n}
n_components: {n_components}
noise: {noise}
""")


n: 10000
n_components: 100
noise: 0.01



In [90]:
# Push the fitted Gaussians to the traceviz viewer. First a "setup" message
# with an empty payload is sent, then a "Gaussians2" message carrying
# per-component transforms and colors. Components whose `nums` entry is not
# positive (presumably: no samples assigned) are dropped.
stub = traceviz.client.connect()  # one connection, reused for both messages

msg = viz_pb2.Message()
msg.payload.json = json.dumps({"type": "setup"})
msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg([]))
print('response: ', stub.Broadcast(msg))

occupied = np.asarray(nums) > 0
msg = viz_pb2.Message()
msg.payload.json = json.dumps({"type": "Gaussians2"})
msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg({
        'transforms': np.array(transforms)[occupied],
        'colors':     np.array(mean_colors)[occupied],
}))
print('response: ', stub.Broadcast(msg))


response:  listener_identifiers: "2023-11-22T15:51:30.513873 :: ipv4:127.0.0.1:53162"

response:  listener_identifiers: "2023-11-22T15:51:30.513873 :: ipv4:127.0.0.1:53162"



In [91]:
def fit(key, mesh, means_init, precisions_init, covariance_type="full", iter=20, noise=0.0,
        n=20_000, tol=1e-3):
    """Fit a warm-started Gaussian mixture to points sampled from `mesh`'s surface.

    NOTE(review): an earlier cell also defines `fit`; this definition shadows
    it — consider deleting the earlier one.

    Args:
        key: PRNG key; a sub-key derived from it with `keysplit` drives the
            surface sampling.
        mesh: Mesh passed to `uniformly_sample_from_mesh`.
        means_init: Initial means, shape (K, D). K sets the number of components.
        precisions_init: Initial precisions in the layout sklearn expects for
            `covariance_type` (e.g. (K,) for "spherical", (K, D, D) for "full").
        covariance_type: "full", "tied", "diag" or "spherical".
        iter: Maximum number of EM iterations (`max_iter`).
        noise: Std. dev. of isotropic Gaussian jitter added to the samples.
        n: Number of surface samples to fit to.
        tol: EM convergence threshold.

    Returns:
        mus: Fitted means, shape (K, D).
        covs: Fitted covariances expanded to full matrices, shape (K, D, D),
            whatever `covariance_type` was.
        mean_colors: Per-component mean colors from `get_mean_colors`.
    """
    # SAMPLE FROM MESH
    # ----------------
    _, key = keysplit(key, 1, 1)
    xs, cs = uniformly_sample_from_mesh(key, n, mesh, with_color=True)

    # GMM CONFIG
    # ----------
    n_components = means_init.shape[0]
    # Jitter is drawn from NumPy's global RNG (not `key`); seed NumPy for
    # reproducibility. The draw happens even when noise == 0 so the global
    # RNG state seen by the GMM below is unchanged.
    X = xs + np.random.randn(*xs.shape) * noise

    # FIT THE GMM
    # -----------
    gm = GaussianMixture(n_components=n_components,
                         tol=tol, max_iter=iter,
                         covariance_type=covariance_type,
                         means_init=means_init,
                         precisions_init=precisions_init).fit(X)

    mus = gm.means_
    n_dims = mus.shape[1]
    # sklearn stores `covariances_` in a type-dependent layout; expand every
    # layout to a stack of full (K, D, D) matrices.
    if gm.covariance_type == "spherical":   # (K,): one variance per component
        covs = gm.covariances_[:, None, None] * jnp.eye(n_dims)[None, :, :]
    elif gm.covariance_type == "diag":      # (K, D): per-axis variances
        covs = gm.covariances_[:, :, None] * jnp.eye(n_dims)[None, :, :]
    elif gm.covariance_type == "tied":      # (D, D): shared by all components
        covs = jnp.broadcast_to(gm.covariances_, (n_components, n_dims, n_dims))
    else:                                   # "full": already (K, D, D)
        covs = gm.covariances_

    labels = gm.predict(X)
    mean_colors, _ = get_mean_colors(cs, gm.n_components, labels)

    return mus, covs, mean_colors

In [92]:
# Sanity check: precisions (inverse covariances) of the current fit.
# Uses `covs` directly; `CVs` is only created in the next cell, so referencing
# it here fails on a fresh kernel.
prec = vmap(jnp.linalg.inv)(covs)
prec.shape

(100, 3, 3)

In [93]:
# Track the Gaussians frame by frame. Each frame's fit is warm-started from
# the previous frame's means and precisions. The precisions are the (0, 0)
# entry of the inverted covariances, i.e. 1/variance for the isotropic
# covariances of a spherical fit.
MUs, CVs, CLs = [mus], [covs], [mean_colors]

for t in range(1, 5):
    print(t, end="\r")
    mesh = load_mesh(t)
    key = keysplit(key)

    prev_precisions = vmap(jnp.linalg.inv)(CVs[-1])[:, 0, 0]
    mus, covs, mean_colors = fit(
        key, mesh, MUs[-1], prev_precisions,
        covariance_type="spherical", iter=10, noise=0.0,
    )

    MUs.append(mus)
    CVs.append(covs)
    CLs.append(mean_colors)

1



2



3



4

In [103]:
# Visualize frame t of the tracked Gaussians in traceviz. All components are
# drawn in a uniform grey (the fitted mean colors are not shown), with
# channel 3 set to 1 — presumably alpha, assuming RGBA colors (TODO confirm).
t = 4
mus = MUs[t]
covs = CVs[t]
mean_colors = CLs[t]
transforms = vmap(transform_from_gaussian, (0, 0, None))(mus, covs, 3.0)

colors = 0.4 * jnp.ones_like(mean_colors)
colors = colors.at[:, 3].set(1.)

stub = traceviz.client.connect()  # one connection, reused for both messages

msg = viz_pb2.Message()
msg.payload.json = json.dumps({"type": "setup"})
msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg([]))
print('response: ', stub.Broadcast(msg))

msg = viz_pb2.Message()
msg.payload.json = json.dumps({"type": "Gaussians2"})
msg.payload.data.MergeFrom(traceviz.client.to_pytree_msg({
        'transforms': np.array(transforms),
        'colors':     np.array(colors),
}))
print('response: ', stub.Broadcast(msg))


response:  listener_identifiers: "2023-11-22T16:16:10.810408 :: ipv4:127.0.0.1:38474"

response:  listener_identifiers: "2023-11-22T16:16:10.810408 :: ipv4:127.0.0.1:38474"

