In [13]:
from functools import partial
from pprint import pprint

import jax
import jax.numpy as jnp

import tune_jax

In [2]:
# Install (upgrading / force-reinstalling) the tune-jax release from PyPI.
# If tune-jax is not installed yet, this cell must run before the
# `import tune_jax` cell at the top of the notebook.
# NOTE(review): `uv pip install` targets the environment uv resolves, which is
# not necessarily the Jupyter kernel's interpreter -- confirm for your setup.
!uv pip install -U --force-reinstall -q tune-jax
# or, to install the development version directly from GitHub:
# !uv pip install -U git+https://github.com/rdyro/tune-jax.git

In [3]:
@partial(jax.jit, static_argnames=("steps",))
def fn(x, W, steps=20):
  """Toy workload used below to demonstrate timing and tuning.

  Repeats ``steps`` times: ``x <- (x @ W) / ||x||``, where ``||x||`` is the
  L2 norm over the last axis of ``x`` *before* the multiplication.

  ``steps`` is static (a Python int), so the loop is unrolled at trace time
  and each distinct value of ``steps`` compiles to a separate program.
  """
  out = x
  for _ in range(steps):
    row_norm = jnp.linalg.norm(out, axis=-1, keepdims=True)
    out = jnp.matmul(out, W) / row_norm
  return out

### Just timing the function

In [4]:
# With no `hyperparams`, tune_jax simply profiles `fn` once as-is;
# the first call with these input shapes triggers the measurement.
x, W = jnp.ones((128, 1024)), jnp.ones((1024, 1024))

fn_tune = tune_jax.tune(fn)
fn_tune(x, W)
fn_tune.timing_results

{0: TimingResult(hyperparams={}, t_mean=3.7385333333333336e-05, t_std=2.3795424396767122e-08)}

In [None]:
# Same input shapes as the call above, so this call does not trigger re-tuning.
fn_tune(x, W).shape

(128, 1024)

### Tuning

In [5]:
# Smaller batch; sweep the static `steps` argument of `fn`.
x, W = jnp.ones((16, 1024)), jnp.ones((1024, 1024))

hyperparams = dict(steps=[1, 2, 3, 30, 50])
fn_tune = tune_jax.tune(fn, hyperparams=hyperparams)
fn_tune(x, W)  # first call with these shapes runs the tuning sweep
print(fn_tune.timing_results)
print(tune_jax.tabulate(fn_tune))

{0: TimingResult(hyperparams={'steps': 1}, t_mean=1.978333333333333e-06, t_std=1.247219128924741e-09), 1: TimingResult(hyperparams={'steps': 2}, t_mean=2.891666666666667e-06, t_std=2.819377393838738e-08), 2: TimingResult(hyperparams={'steps': 3}, t_mean=3.924333333333333e-06, t_std=9.177266598624108e-09), 3: TimingResult(hyperparams={'steps': 30}, t_mean=2.7376e-05, t_std=1.4142135623730016e-09), 4: TimingResult(hyperparams={'steps': 50}, t_mean=4.4640999999999995e-05, t_std=1.1032074449833296e-07)}
  id    steps    t_mean (s)    t_std (s)
----  -------  ------------  -----------
   0        1    1.9783e-06   1.2472e-09
   1        2    2.8917e-06   2.8194e-08
   2        3    3.9243e-06   9.1773e-09
   3       30    2.7376e-05   1.4142e-09
   4       50    4.4641e-05   1.1032e-07


### Config

In [14]:
# Show tune_jax's global configuration (the knobs adjusted in the cells below).
pprint(tune_jax.CONFIG)

_Config(allow_fallback_timing=False,
        must_find_at_least_profiler_result_fraction=0.5,
        profiling_samples=5)


In [15]:
# How many profiling samples to collect; the slowest and the fastest sample
# are discarded before the statistics are computed. Parsing each xprof profile
# can take up to several seconds, so 5 samples may add as much as ~25 s of
# profiling time.
tune_jax.CONFIG.profiling_samples = 5

In [16]:
# Whether tune_jax may fall back to Python-side timing when profiler (xprof)
# results are not available. With False (as set here), tune_jax raises an
# error instead of falling back.
tune_jax.CONFIG.allow_fallback_timing = False

In [17]:
# Some hyperparameter settings may not run correctly under the profiler, or
# their timings may fail to parse from xprof. This is the minimum fraction of
# settings that must yield a profiler result; below it, tune_jax falls back
# to Python timing.
tune_jax.CONFIG.must_find_at_least_profiler_result_fraction = 0.5

### Subsampling

In [18]:
x, W = jnp.ones((16, 1024)), jnp.ones((1024, 1024))

# Same sweep as before, but only `sample_num` of the hyperparameter
# settings are actually tuned (3 of the 5 here).
hyperparams = dict(steps=[1, 2, 3, 30, 50])
fn_tune = tune_jax.tune(fn, hyperparams=hyperparams, sample_num=3)
fn_tune(x, W)
print(fn_tune.timing_results)
print(tune_jax.tabulate(fn_tune))

{0: TimingResult(hyperparams={'steps': 1}, t_mean=1.974333333333333e-06, t_std=1.699673171197584e-09), 1: TimingResult(hyperparams={'steps': 2}, t_mean=2.8996666666666665e-06, t_std=9.92382094871841e-08), 2: TimingResult(hyperparams={'steps': 3}, t_mean=3.694666666666667e-06, t_std=3.3993463423952513e-09)}
  id    steps    t_mean (s)    t_std (s)
----  -------  ------------  -----------
   0        1    1.9743e-06   1.6997e-09
   1        2    2.8997e-06   9.9238e-08
   2        3    3.6947e-06   3.3993e-09


### If all hyperparams fail

In [19]:
@partial(jax.jit, static_argnames=("steps",))
def fn(x, W, steps=20):
  """Same signature as the working `fn` defined earlier, but always fails.

  The ValueError is raised as soon as JAX traces the function, so no
  hyperparameter setting can ever compile. This is used to demonstrate how
  tune_jax reports a complete tuning failure.
  """
  raise ValueError("Refusing to compile")

In [20]:
# If every hyperparameter setting fails to compile (a common failure mode),
# tune_jax logs each setting's compilation error -- a wall of text -- and then
# raises `ValueError: No hyperparameters compiled successfully`.
# The error is caught and shown here so that "Restart & Run All" continues
# past this demonstration cell.
try:
  tune_jax.tune(fn, hyperparams={"steps": list(range(10))})(x, W)
except ValueError as e:
  print(f"{type(e).__name__}: {e}")

ERROR:tune_jax:Hyperparameters (0,) failed to compile with message:
Traceback (most recent call last):
  File "/home/rdyro_google_com/venv/lib/python3.12/site-packages/tune_jax/tuning.py", line 109, in _try_call
    _ = jax.jit(fn).lower(*args_val, **kws_val).compile()
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rdyro_google_com/venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/rdyro_google_com/venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 325, in jit_lower
    return jit_trace(jit_func, *args, **kwargs).lower()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rdyro_google_com/venv/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/rdyro_google_com/venv/lib/python3.12/site-packages/

ValueError: No hyperparameters compiled successfully

### Logger config — e.g., are the inputs concrete?

In [21]:
# Lower tune_jax's log level to DEBUG to see its internal decisions, e.g.
# whether the input arguments are concrete or random values must be picked.
tune_jax.logger.setLevel("DEBUG")

In [22]:
# Called eagerly with a concrete Python int: all arguments are concrete, so
# tune_jax profiles with the given values instead of picking random ones.
tune_jax.tune(lambda x: x)(1)

DEBUG:tune_jax:All arguments are concrete, no need to pick random values.
DEBUG:tune_jax:All keyword arguments are concrete, no need to pick random values.
Compiling...: 100%|██████████| 1/1 [00:00<00:00, 46.95it/s]
Profiling tpu:   0%|          | 0/5 [00:00<?, ?it/s]DEBUG:tune_jax:Saving optimization profile to `/tmp/tuning_profile_2025-09-09_17:42:17_qvtc60m1`
Profiling tpu: 100%|██████████| 5/5 [00:00<00:00,  6.35it/s]
DEBUG:tune_jax:
  id    t_mean (s)    t_std (s)
----  ------------  -----------
   0    5.9400e-07   8.1650e-10
DEBUG:tune_jax:optimal hyperparams: {}


1

In [24]:
# Under an outer jax.jit the argument is an abstract tracer rather than a
# concrete value, so tune_jax selects random input arguments to profile with
# (presumably matching the tracer's shape/dtype -- confirm in tune_jax).
jax.jit(tune_jax.tune(lambda x: x))(1)

DEBUG:tune_jax:Selecting random input arguments.
DEBUG:tune_jax:All keyword arguments are concrete, no need to pick random values.
Compiling...: 100%|██████████| 1/1 [00:00<00:00, 46.47it/s]
Profiling tpu:   0%|          | 0/5 [00:00<?, ?it/s]DEBUG:tune_jax:Saving optimization profile to `/tmp/tuning_profile_2025-09-09_17:42:36_5811u_2_`
Profiling tpu: 100%|██████████| 5/5 [00:00<00:00,  6.31it/s]
DEBUG:tune_jax:
  id    t_mean (s)    t_std (s)
----  ------------  -----------
   0    5.9400e-07   0.0000e+00
DEBUG:tune_jax:optimal hyperparams: {}


Array(1, dtype=int32, weak_type=True)