# STATIC: Sparse Transition-Accelerated Trie Index for Constrained Decoding (JAX)

This notebook demonstrates the basic usage of the STATIC library for high-throughput constrained decoding on hardware accelerators. We start with some basic imports.

```python
import jax
import jax.numpy as jnp
import numpy as np
from static_decoding.csr_utils import build_static_index
from static_decoding.decoding_jax import RandomModel, sparse_transition_jax
```

## (1) Index Construction

We synthesize a set of valid Semantic IDs (SIDs) and build the static index with $d=2$ dense lookup layers (`dense_lookup_layers=2`).
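One way to read `dense_lookup_layers` is that the first $d$ decoding steps are answered by direct array indexing instead of a trie walk. The sketch below assumes a hypothetical layout: a length-$V$ boolean `start_mask` for the first token, a $V \times V$ boolean `dense_mask` for valid two-token prefixes, and a $V \times V$ `dense_states` table mapping each valid prefix to a state id (`-1` if invalid). The helper `build_dense_tables` and this layout are illustrative assumptions, not STATIC's actual format. The last lines show why $d$ stays small: these tables grow as $V^d$.

```python
import numpy as np

def build_dense_tables(sids: np.ndarray, vocab_size: int):
    """Hypothetical dense tables for the first d=2 trie levels (not STATIC's real format)."""
    # Depth 0: which tokens can start a valid SID.
    start_mask = np.zeros(vocab_size, dtype=bool)
    start_mask[sids[:, 0]] = True
    # Depth 1: which (first, second) token pairs are valid prefixes.
    dense_mask = np.zeros((vocab_size, vocab_size), dtype=bool)
    dense_mask[sids[:, 0], sids[:, 1]] = True
    # Map each valid length-2 prefix to a state id (-1 = invalid) for the sparse levels.
    prefixes = np.unique(sids[:, :2], axis=0)
    dense_states = np.full((vocab_size, vocab_size), -1, dtype=np.int32)
    dense_states[prefixes[:, 0], prefixes[:, 1]] = np.arange(len(prefixes), dtype=np.int32)
    return start_mask, dense_mask, dense_states

toy = np.array([[0, 1, 3], [0, 2, 1], [2, 2, 0]], dtype=np.int32)
start_mask, dense_mask, dense_states = build_dense_tables(toy, vocab_size=4)
print(start_mask.tolist())     # [True, False, True, False]
print(dense_mask[0].tolist())  # [False, True, True, False]
print(dense_states[2, 2])      # 2

# Memory of the depth-1 tables for V=2048, d=2 (V**2 entries each):
V = 2048
print(V * V * np.dtype(bool).itemsize)      # 4194304 bytes for a boolean mask
print(V * V * np.dtype(np.int32).itemsize)  # 16777216 bytes for int32 state ids
```

Each further dense level would multiply these sizes by another factor of $V$, which is why deeper levels are better served by a sparse structure.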

```python
# Parameters: vocabulary size V, SID length L, and number of SIDs N in the restricted vocabulary
V, L, N = 2048, 8, 10000

# Generate random SIDs, drop duplicates, and sort them lexicographically.
# np.unique(axis=0) already returns rows in lexicographic order; the explicit
# lexsort (its last key, column 0, is the primary key) just makes the ordering explicit.
valid_sids = np.random.randint(0, V, size=(N, L), dtype=np.int32)
valid_sids = np.unique(valid_sids, axis=0)
valid_sids = valid_sids[np.lexsort([valid_sids[:, i] for i in range(L - 1, -1, -1)])]

# Build the STATIC index with d=2 dense lookup layers
(packed_csr, indptr, layer_max_branches,
 start_mask, d_mask, d_states) = build_static_index(valid_sids, V, dense_lookup_layers=2)
```
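For the levels past the dense layers, the index name and the returned `packed_csr`/`indptr` pair point to a sparse, CSR-style transition table. The sketch below assumes a hypothetical layout: every unique prefix is a trie node (node 0 is the empty prefix), `indptr[u]:indptr[u+1]` delimits node `u`'s outgoing edges, and each packed row is a `(token, child_node)` pair sorted by token. `build_prefix_csr` is an illustrative helper, not part of `static_decoding`, and the library's real `packed_csr` may be laid out differently.

```python
import numpy as np

def build_prefix_csr(sids: np.ndarray):
    """Hypothetical CSR trie over unique, lexicographically sorted SIDs."""
    n, length = sids.shape
    node_of = np.zeros((n, length + 1), dtype=np.int64)  # column t: node id of each row's length-t prefix
    next_id = 1                                          # node 0 is the root (empty prefix)
    for t in range(1, length + 1):
        _, inverse = np.unique(sids[:, :t], axis=0, return_inverse=True)
        inverse = inverse.reshape(-1)                    # guard against NumPy-version shape differences
        node_of[:, t] = inverse + next_id
        next_id += int(inverse.max()) + 1
    # Token at position t moves from the depth-t node to the depth-(t+1) node.
    parents = node_of[:, :length].reshape(-1)
    tokens = sids.reshape(-1).astype(np.int64)
    children = node_of[:, 1:].reshape(-1)
    # Deduplicate shared edges; rows come out sorted by (parent, token).
    edges = np.unique(np.stack([parents, tokens, children], axis=1), axis=0)
    counts = np.bincount(edges[:, 0], minlength=next_id)
    indptr = np.concatenate([[0], np.cumsum(counts)])
    packed = edges[:, 1:]                                # (token, child_node) pairs
    return packed, indptr, next_id

toy = np.array([[0, 1], [0, 2], [1, 2]], dtype=np.int32)
packed, indptr, num_nodes = build_prefix_csr(toy)
print(indptr.tolist())                           # [0, 2, 4, 5, 5, 5, 5]
print(packed[indptr[1]:indptr[2]].tolist())      # [[1, 3], [2, 4]]: children of prefix [0]
```

The appeal of this layout is that a node's children occupy one contiguous slice, so a step needs only two `indptr` reads and one bounded gather; the largest `indptr` gap per level is the kind of bound `layer_max_branches` presumably records.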

## (2) Hardware-accelerated Decoding

The constrained beam search is JIT-compiled by XLA into TPU machine code (or GPU/CPU code on other backends), so the constraint-masking logic runs on the accelerator inside the compiled decoding loop and adds negligible overhead to inference time.
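To stay inside a compiled loop, the masking step must use fixed shapes. The sketch below shows one way to do that under stated assumptions: a hypothetical CSR layout (per-node `indptr` offsets into a flat `packed_tokens` array) and a static per-level bound `max_branch` on the number of children. `mask_logits_csr` is illustrative, not the library's kernel; it gathers at most `max_branch` candidate tokens per beam, scatters them into a dense mask, and sets every other logit to $-\infty$.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

@partial(jax.jit, static_argnames=("max_branch",))
def mask_logits_csr(logits, states, packed_tokens, indptr, max_branch):
    """Set logits of tokens with no outgoing trie edge to -inf (hypothetical CSR layout)."""
    batch, vocab = logits.shape
    start = indptr[states]                          # (batch,) first edge of each current node
    count = indptr[states + 1] - start              # (batch,) number of children
    offsets = jnp.arange(max_branch)                # (K,) static upper bound on children
    valid = offsets[None, :] < count[:, None]       # (batch, K) real edges vs padding
    edge_idx = jnp.where(valid, start[:, None] + offsets[None, :], 0)
    tokens = packed_tokens[edge_idx]                # (batch, K) candidate next tokens
    rows = jnp.arange(batch)[:, None]
    # Scatter-add (not set) so padded slots can never overwrite a real edge.
    hits = jnp.zeros((batch, vocab), jnp.int32).at[rows, tokens].add(valid.astype(jnp.int32))
    return jnp.where(hits > 0, logits, -jnp.inf)

# Toy index: node 0 -> tokens {0, 1}, node 1 -> {1, 2}, node 2 -> {2}
packed_tokens = jnp.array([0, 1, 1, 2, 2], dtype=jnp.int32)
indptr = jnp.array([0, 2, 4, 5, 5, 5, 5], dtype=jnp.int32)
states = jnp.array([0, 1, 2], dtype=jnp.int32)
masked = mask_logits_csr(jnp.zeros((3, 3)), states, packed_tokens, indptr, max_branch=2)
allowed = np.isfinite(np.asarray(masked))
print(allowed.astype(int))
```

The cost of this step scales with `max_branch` rather than with the number of valid SIDs, which is the property that makes a sparse transition table attractive on accelerators.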

```python
# Move index components to the accelerator as JAX arrays
packed_csr_j = jnp.array(packed_csr)
indptr_j = jnp.array(indptr)
start_mask_j = jnp.array(start_mask)
d_mask_j = jnp.array(d_mask)
d_states_j = jnp.array(d_states)

# Instantiate the mock model.
# In a real scenario, this would be your Flax model wrapper.
model = RandomModel(vocab_size=V)

# Execute JIT-compiled constrained decoding
print("Executing constrained retrieval...")
key = jax.random.PRNGKey(42)

outputs = sparse_transition_jax(
    model=model,
    key=key,
    batch_size=2,
    beam_size=10,
    tokens_per_beam=20,
    start_token=0,
    max_sample_len=L,
    vocab_size=V,
    max_branch_factors=layer_max_branches,
    packed_csr=packed_csr_j,
    csr_indptr=indptr_j,
    start_mask=start_mask_j,
    dense_mask=d_mask_j,
    dense_states=d_states_j,
    d_dense=2,  # must match dense_lookup_layers passed to build_static_index
)
print(f"Decoding complete. Output shape: {outputs.shape}")
```
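Once logits are masked, the beam update itself is ordinary beam search: add each beam's running score to its token log-probabilities, flatten beams × vocabulary, and keep the top `beam_size` candidates. Because masked tokens carry $-\infty$, none of them is selected as long as at least `beam_size` valid continuations exist. This is a generic sketch; `beam_step` is a hypothetical name and ignores whatever per-beam pre-filtering `tokens_per_beam` controls inside `sparse_transition_jax`.

```python
import jax
import jax.numpy as jnp

def beam_step(beam_scores, masked_logits, beam_size):
    """One generic beam-search step over constraint-masked logits (illustrative)."""
    # beam_scores: (batch, beams) cumulative log-probs; masked_logits: (batch, beams, V)
    log_probs = jax.nn.log_softmax(masked_logits, axis=-1)  # masked tokens stay -inf
    total = beam_scores[..., None] + log_probs              # (batch, beams, V)
    batch, beams, vocab = total.shape
    flat = total.reshape(batch, beams * vocab)
    top_scores, top_idx = jax.lax.top_k(flat, beam_size)    # best (beam, token) pairs
    parent_beam = top_idx // vocab                          # beam each candidate extends
    next_token = top_idx % vocab                            # token it appends
    return top_scores, parent_beam, next_token

beam_scores = jnp.array([[0.0, -1.0]])
masked_logits = jnp.array([[[1.0, 0.0, -jnp.inf],
                            [0.0, -jnp.inf, -jnp.inf]]])
scores, parent_beam, next_token = beam_step(beam_scores, masked_logits, beam_size=3)
print(parent_beam.tolist(), next_token.tolist())
```

Here the three surviving candidates are (beam 0, token 0), (beam 1, token 0), and (beam 0, token 1); masked tokens never appear.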

```python
# Print the first five decoded SIDs for the first batch element
print(f'Sampled SIDs:\n{outputs[0, :5]}')
```

```
Sampled SIDs:
[[ 636  594  492 1507  975 1279 1217 1336]
 [1452 1444 1944 1301 1245  886  738  783]
 [1511 1513  744 1899 1913 1397  147  854]
 [1620  139  893 1747  797  687 1003  429]
 [1583  962 1214 1765 1498  445   80   30]]
```


## (3) Verification

Finally, we verify that every decoded sequence (all beams across the batch) is one of the valid SIDs used to build the index.

```python
# 1. Convert the JAX array to NumPy for host-side verification
decoded_sids_np = np.array(outputs).reshape(-1, L)

# 2. Re-create the valid set from the original NumPy SIDs
valid_set = {tuple(row) for row in valid_sids}

# 3. Count decoded SIDs that appear in the valid set
valid_count = sum(1 for sid in decoded_sids_np if tuple(sid) in valid_set)

print(f"Verification: {valid_count}/{decoded_sids_np.shape[0]} beams are valid.")
assert valid_count == decoded_sids_np.shape[0], "Error: Constraint violation detected!"
```

```
Verification: 20/20 beams are valid.
```
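The set-of-tuples check is fine for 20 beams. For much larger batches, a vectorized alternative gives every distinct row an integer id with `np.unique(..., axis=0, return_inverse=True)` and then tests membership with `np.isin`. `rows_in` is a hypothetical helper name used only for illustration.

```python
import numpy as np

def rows_in(query: np.ndarray, reference: np.ndarray) -> np.ndarray:
    """Boolean per query row: True if that exact row occurs in reference."""
    both = np.concatenate([reference, query], axis=0)
    # Assign each distinct row an integer id; reshape guards against NumPy-version
    # differences in the shape of the inverse array.
    _, inverse = np.unique(both, axis=0, return_inverse=True)
    inverse = inverse.reshape(-1)
    ref_ids, query_ids = inverse[: len(reference)], inverse[len(reference):]
    return np.isin(query_ids, ref_ids)

reference = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
query = np.array([[4, 5, 6], [1, 2, 4], [1, 2, 3]], dtype=np.int32)
print(rows_in(query, reference))  # [ True False  True]
```

Applied to this notebook, the equivalent check would be `assert rows_in(decoded_sids_np, valid_sids).all()`.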
