In [26]:
# Automatically re-import modules (e.g. the locally developed `hugegp`) before each
# cell runs. Re-running this cell in a live kernel only prints a harmless
# "already loaded" notice.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [27]:
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.tree_util import Partial

import hugegp as gp
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm

# Root PRNG key for the whole notebook; every random draw below is split off it.
SEED = 137
rng = jr.key(SEED)

In [28]:
# Covariance kernel shared by graph generation and the log-det cells.
# `Partial` (JAX's pytree-compatible functools.partial) pins the hyperparameters.
# Their exact meaning is defined by `gp.matern_cov_p`. Presumably `p` selects
# the Matern order, `sigma` the amplitude and `cutoff` a length scale (TODO confirm).
covariance = Partial(
    gp.matern_cov_p,
    p=0,
    sigma=1.0,
    cutoff=1.0,
)

In [None]:
# Build the strict neighbor graph on 1,000,000 standard-normal 2-D points.
# NOTE: at this size XLA may log a slow "constant folding" message while compiling
# graph construction. That is a compile-time notice, not an error.
rng, key = jr.split(rng)
original_points = jr.normal(key, shape=(1_000_000, 2))
# Smaller inputs, handy for quick iteration:
# original_points = jr.normal(key, shape=(2**16-1, 2))
# original_points = jnp.mgrid[-3:3:300j, -3:3:300j].reshape(2, -1).T
points, neighbors, offsets, indices = gp.build_strict_graph(original_points, n0=2000, k=10)
gp.check_graph(points, neighbors, offsets)
# `offsets` is a tuple (it is concatenated with `(0,)`). If its entries are cumulative
# level boundaries, these differences are per-level sizes. TODO confirm.
level_sizes = jnp.diff(jnp.array((0,) + offsets))
print(f'Depth: {len(offsets)}')

# jax.jit reads globals referenced by the lambda only when it (re)traces. The
# lazy-graph cell below rebinds `offsets`, so closing over `offsets` directly would
# silently bake the *lazy* offsets into this function on any retrace after that
# cell has run. Snapshot under a name nothing else rebinds.
strict_offsets = offsets
strict_generate_func = jax.jit(lambda p, n, x: gp.generate(p, n, strict_offsets, covariance, x))

2025-08-28 11:26:17.206130: E external/xla/xla/service/slow_operation_alarm.cc:73] Constant folding an instruction is taking > 1s:

  %concatenate = s32[9998000,3]{1,0} concatenate(%iota.1, %constant.121, %broadcast.472), dimensions={1}, metadata={op_name="jit(query_preceding_neighbors)/jit(main)/vmap(while)/body/gather" source_file="/home/bendodge/hugegp/hugegp/tree.py" source_line=47}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2025-08-28 11:26:17.581921: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 1.375986127s
Constant folding an instruction is taking > 1s:

  %concatenate = s32[9998000,3]{1,0} concate

In [17]:
# Evaluate `gp.generate_logdet` on the current graph globals.
# `offsets` and `covariance` (positions 2 and 3) are static jit arguments,
# so they must be hashable; `offsets` is a tuple.
jax.jit(
    gp.generate_logdet,
    static_argnums=(2, 3),
)(points, neighbors, offsets, covariance)

Array(-2.8253562e+06, dtype=float32)

In [8]:
# Time 100 draws through the strict generator. The first call includes JIT
# compilation. `block_until_ready()` waits for JAX's async dispatch and returns the
# array itself, so tqdm's per-iteration timing reflects finished work.
# NOTE(review): reads the globals `points`/`neighbors`, which the lazy-graph cell
# rebinds. Run this before that cell (as Restart & Run All does).
for i in tqdm(range(100)):
    rng, key = jr.split(rng)
    n_samples = len(original_points)
    xi = jr.normal(key, shape=(n_samples,))
    values = strict_generate_func(points, neighbors, xi).block_until_ready()

  0%|          | 0/100 [00:00<?, ?it/s]

In [24]:
# Build the lazy graph variant on the same points. `factor=1.1` is a
# `gp.build_lazy_graph` tuning parameter (TODO document its meaning).
# NOTE: this rebinds the globals points/neighbors/offsets/indices/level_sizes
# that the strict-graph cells above also use.
points, neighbors, offsets, indices = gp.build_lazy_graph(original_points, n0=2000, k=10, factor=1.1)
gp.check_graph(points, neighbors, offsets)
level_sizes = jnp.diff(jnp.array((0,) + offsets))
print(f'Depth: {len(offsets)}')

# jax.jit reads globals referenced by the lambda only at (re)trace time. Snapshot
# the offsets under a dedicated name, so a later rebinding of `offsets` cannot leak
# into this function when it retraces.
lazy_offsets = offsets
lazy_generate_func = jax.jit(lambda p, n, x: gp.generate(p, n, lazy_offsets, covariance, x))

Depth: 67


In [25]:
# Evaluate `gp.generate_logdet` on the current graph globals (the lazy graph when
# run in order). `offsets` and `covariance` (positions 2 and 3) are static jit
# arguments and must be hashable.
jax.jit(
    gp.generate_logdet,
    static_argnums=(2, 3),
)(points, neighbors, offsets, covariance)

Array(-2.824382e+06, dtype=float32)

In [11]:
# Time 100 draws through the lazy generator. The first call includes JIT
# compilation. `block_until_ready()` waits for JAX's async dispatch and returns the
# array itself, so tqdm's per-iteration timing reflects finished work.
for i in tqdm(range(100)):
    rng, key = jr.split(rng)
    n_samples = len(original_points)
    xi = jr.normal(key, shape=(n_samples,))
    values = lazy_generate_func(points, neighbors, xi).block_until_ready()

  0%|          | 0/100 [00:00<?, ?it/s]