In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import jax
import jax.numpy as jnp
from genjax import Mask

import condorgmm.condor.tiling as t
from condorgmm.condor.types import Intrinsics
from condorgmm.utils.jax import unproject
from condorgmm.utils.jax import xyz_to_pixel_coordinates

In [None]:
# Toy tiling problem: an 8 x 4 (image_height x image_width) image cut into
# 2 x 2-pixel tiles, small enough to inspect the tiling results by eye.
# NOTE(review): the positional Intrinsics arguments are presumably
# (fx, fy, cx, cy, near, far) — TODO confirm against the Intrinsics definition.
# If so, cx=4.0 equals image_width and cy=2.0 is a quarter of image_height,
# i.e. the principal point is not at the image centre — confirm intended.
# NOTE(review): n_gaussians=16, while the cells below build one mean per pixel
# (32 means) — confirm whether this field must match the number of means.
config = t.GridTilingConfig(
    intrinsics=Intrinsics(
        2.0, 2.0, 4.0, 2.0,  # presumably fx, fy, cx, cy
        1e-5, 1e5,  # presumably near / far clipping depths
        image_height=8,
        image_width=4,
    ),
    tile_size_x=2,
    tile_size_y=2,
    max_n_gaussians_per_tile=4,
    n_gaussians=16,
)

In [None]:
# Build one 3-D point per pixel: pixel centres (row + 0.5, col + 0.5) in
# row-major order, unprojected at depths 1, 2, ..., n_pixels. The grid size is
# read from the intrinsics (not hard-coded), so this cell follows `config`.
intrinsics = config.intrinsics
height, width = int(intrinsics.image_height), int(intrinsics.image_width)
n_pixels = height * width

# jnp.indices gives (2, height, width) row/col grids; flattening row-major and
# transposing yields (n_pixels, 2) with columns [row + 0.5, col + 0.5].
pixel_coords_og = jnp.indices((height, width)).reshape(2, -1).T + 0.5
depths = jnp.arange(1, n_pixels + 1)

# Map over pixels only; the intrinsics are shared (unbatched) arguments.
# NOTE(review): column 0 holds the row (y) coordinate and is passed as
# unproject's first positional argument — confirm that matches unproject's
# pixel-axis convention.
coords_3d = jax.vmap(unproject, in_axes=(0, 0, 0, None, None, None, None))(
    pixel_coords_og[:, 0],
    pixel_coords_og[:, 1],
    depths,
    intrinsics.fx,
    intrinsics.fy,
    intrinsics.cx,
    intrinsics.cy,
)
t.GridTiling.from_gaussian_means(config, coords_3d).tile_to_gaussians

In [None]:
# Project the 3-D means back to pixel coordinates with the config's
# intrinsics. The two columns of `pixel_coords` are then used to assign each
# Gaussian to a tile via `config.pixel_coordinate_to_tile_index`.
intrinsics = config.intrinsics
pixel_coords = xyz_to_pixel_coordinates(
    coords_3d,
    intrinsics.fx,
    intrinsics.fy,
    intrinsics.cx,
    intrinsics.cy,
)

In [None]:
# Tile index for every Gaussian, computed from its projected pixel position.
# vmap runs over the rows of `pixel_coords`; column 0 and column 1 of each row
# are the two coordinate arguments of `pixel_coordinate_to_tile_index`.
gaussian_to_tile = jax.vmap(
    lambda coord: config.pixel_coordinate_to_tile_index(coord[0], coord[1])
)(pixel_coords)

In [None]:
# Leading axis is the vmap axis: one tile-index entry per Gaussian.
jnp.shape(gaussian_to_tile)

In [None]:
# M: per-tile capacity from the config; passed as M to `_compress_to_M` below.
M = config.max_n_gaussians_per_tile
# K: number of slots in the oversized per-tile buffer `tile_to_gaussian_large`.
K = 64
# R: number of random slots (out of K) each Gaussian is written into.
R = 4
# Fixed PRNG key so the random slot assignment is reproducible across re-runs.
key = jax.random.key(0)

In [None]:
# Number of Gaussians actually placed: one row of `gaussian_to_tile` each.
# NOTE(review): this can differ from `config.n_gaussians`; confirm which one
# downstream code expects.
n_gaussians = len(gaussian_to_tile)
# Tile-grid dimensions from the config.
n_tiles_y = config.n_tiles_y
n_tiles_x = config.n_tiles_x

In [None]:
# Each Gaussian draws R slot indices uniformly from [0, K). Draws are
# independent, so the same slot can be drawn more than once.
gaussian_to_R_idxs = jax.random.randint(
    key, shape=(n_gaussians, R), minval=0, maxval=K
)
# Oversized per-tile buffer with K slots per tile. It is filled with -1, which
# stays in every slot no Gaussian index (always >= 0) is written to.
tile_to_gaussian_large = jnp.full((n_tiles_y, n_tiles_x, K), -1, dtype=int)

In [None]:
# (n_gaussians, R) random slot indices in [0, K), one row per Gaussian.
gaussian_to_R_idxs

In [None]:
# Expected (n_tiles_y, n_tiles_x, K).
jnp.shape(tile_to_gaussian_large)

In [None]:
# Shape check for a per-Gaussian tile index tiled across the R slots:
# expected (n_gaussians, R).
jnp.shape(jnp.repeat(gaussian_to_tile[:, :1], R, axis=1))

In [None]:
# Write each Gaussian's id into its R random slots within its own tile.
# The (n, 1) tile indices broadcast against the (n, R) slot indices, and the
# (n, 1) Gaussian ids broadcast across the R slots, so no explicit repeat is
# needed.
# NOTE: if two Gaussians in the same tile draw the same slot, only one id
# survives, and which one is implementation-defined in JAX scatter.
# NOTE(review): JAX drops out-of-range positive indices in `.at[].set`, but
# negative indices wrap around (NumPy semantics). If
# `pixel_coordinate_to_tile_index` can return -1 for off-image points, those
# Gaussians land in the last tile row/column — confirm.
tgl2 = tile_to_gaussian_large.at[
    gaussian_to_tile[:, 0:1],  # (n, 1) index along the n_tiles_y axis
    gaussian_to_tile[:, 1:2],  # (n, 1) index along the n_tiles_x axis
    gaussian_to_R_idxs,  # (n, R) slot index within the tile's K slots
].set(jnp.arange(n_gaussians)[:, None])

In [None]:
# Apply `_compress_to_M` to every tile's length-K slot buffer. The outer vmap
# runs over the n_tiles_y axis of `tgl2`, the inner one over n_tiles_x.
# Presumably this packs the written (non -1) entries into M slots — exact
# behavior is defined by GridTiling.
# NOTE(review): `_compress_to_M` is a private GridTiling helper; its signature
# may change without notice.
jax.vmap(
    jax.vmap(lambda idxs_large: t.GridTiling._compress_to_M(idxs_large, M))
)(tgl2)