# STATIC: Sparse Transition-Accelerated Trie Index for Constrained Decoding (JAX)

This notebook demonstrates the basic usage of the STATIC library for high-throughput constrained decoding on hardware accelerators. We start with some basic imports.

In [6]:
import jax
import jax.numpy as jnp
import numpy as np
from csr_utils import build_sparse_matrix_fast
from static.decoding_jax import sparse_transition_packed_jit

## (1) Index Construction

We synthesize a set of valid Semantic IDs (SIDs) and build the STATIC index using $d=2$ dense lookup layers.

In [39]:
# Parameters: vocab size, SID sequence length, and number of SIDs in the restricted vocabulary
V, L, N = 2048, 8, 10000

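# Generate random SIDs, drop duplicates, and sort them lexicographically
# (np.lexsort treats the last key as primary, so column 0 is the primary key)
valid_sids = np.random.randint(0, V, size=(N, L), dtype=np.int32)
valid_sids = np.unique(valid_sids, axis=0)
valid_sids = valid_sids[np.lexsort([valid_sids[:, i] for i in range(L-1, -1, -1)])]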

# Build STATIC Index
packed_csr, indptr, lmb, start_mask, d_mask, d_states = build_sparse_matrix_fast(valid_sids, V, d=2)
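
As a rough mental model of what such an index encodes, the sketch below builds a prefix trie over a handful of toy SIDs and flattens it into CSR-style arrays, so that the tokens allowed after any prefix can be read from one contiguous slice. This is an illustration only: the helper names `build_trie_csr` and `allowed_next` are hypothetical, and the layout is an assumption that does not reproduce the packed format or the dense lookup layers returned by `build_sparse_matrix_fast`.

```python
def build_trie_csr(sids):
    """Flatten a prefix trie over `sids` into CSR-style arrays.

    State 0 is the root. The outgoing edges of state s are stored in
    tokens[indptr[s]:indptr[s + 1]], with the matching child states
    at the same positions in next_state.
    """
    state_of = {(): 0}   # prefix -> state id, assigned on first appearance
    children = {}        # state id -> {token: child state id}
    for sid in sorted(set(map(tuple, sids))):
        for depth in range(len(sid)):
            parent = state_of[sid[:depth]]
            child_prefix = sid[:depth + 1]
            if child_prefix not in state_of:
                state_of[child_prefix] = len(state_of)
            children.setdefault(parent, {})[sid[depth]] = state_of[child_prefix]

    indptr, tokens, next_state = [0], [], []
    for state in range(len(state_of)):
        for token, child in sorted(children.get(state, {}).items()):
            tokens.append(token)
            next_state.append(child)
        indptr.append(len(tokens))
    return indptr, tokens, next_state


def allowed_next(prefix, indptr, tokens, next_state):
    """Walk the trie along `prefix` and return the tokens allowed next."""
    state = 0
    for tok in prefix:
        row = tokens[indptr[state]:indptr[state + 1]]
        if tok not in row:
            return []  # the prefix itself is not part of any valid SID
        state = next_state[indptr[state] + row.index(tok)]
    return tokens[indptr[state]:indptr[state + 1]]


toy_sids = [(1, 2, 3), (1, 2, 5), (1, 4, 0), (7, 2, 3)]
toy_indptr, toy_tokens, toy_next = build_trie_csr(toy_sids)

print(allowed_next((), toy_indptr, toy_tokens, toy_next))      # [1, 7]
print(allowed_next((1, 2), toy_indptr, toy_tokens, toy_next))  # [3, 5]
print(allowed_next((2,), toy_indptr, toy_tokens, toy_next))    # []
```

Because each state's edges occupy one slice of a flat array, a lookup needs no pointer chasing, which is the property that makes this kind of layout a natural fit for accelerators.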

## (2) Hardware-accelerated Decoding

The constrained beam search is JIT-compiled by JAX for the target accelerator (e.g., a TPU), so the constraint masking runs on-device as part of the compiled decoding step and is designed to add negligible overhead to inference time.
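
To make the idea of constraint masking concrete, here is a minimal pure-Python sketch of a single constrained expansion step: scores of tokens that would leave the valid set are replaced with negative infinity before the top-k selection, so they can never be chosen. The function name `constrained_top_k` and the toy scores are hypothetical, and this is not the code path inside `sparse_transition_packed_jit`.

```python
import math


def constrained_top_k(log_probs, allowed_tokens, k):
    """Return up to k (token, score) pairs, restricted to allowed tokens."""
    allowed = set(allowed_tokens)
    # Mask: disallowed tokens get -inf so they always rank last.
    masked = [lp if tok in allowed else -math.inf
              for tok, lp in enumerate(log_probs)]
    ranked = sorted(range(len(masked)), key=lambda t: masked[t], reverse=True)
    # Drop masked entries in case fewer than k tokens are allowed.
    return [(t, masked[t]) for t in ranked[:k] if masked[t] > -math.inf]


# Toy vocabulary of 5 tokens; tokens 0 and 2 score best but are not allowed.
toy_log_probs = [-0.1, -2.0, -0.5, -3.0, -1.0]
print(constrained_top_k(toy_log_probs, allowed_tokens=[1, 3, 4], k=2))
# [(4, -1.0), (1, -2.0)]
```

On an accelerator, the same step would typically be expressed in vectorized form, for example `jnp.where(mask, logits, -jnp.inf)` followed by `jax.lax.top_k`, so that it can be fused into the compiled decoding loop.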

In [40]:
# Move components to accelerator device (JAX Arrays)
packed_csr_j = jnp.array(packed_csr)
indptr_j = jnp.array(indptr)
start_mask_j = jnp.array(start_mask)
d_mask_j = jnp.array(d_mask)
d_states_j = jnp.array(d_states)

# Execute JIT-compiled constrained decoding
print("Executing constrained retrieval...")
key = jax.random.PRNGKey(42)
outputs = sparse_transition_packed_jit(
    key, batch_size=2, beam_size=10, tokens_per_beam=20, 
    pad_token=0, max_sample_len=L, vocab_size=V, 
    max_branch_factors=lmb, packed_csr=packed_csr_j, 
    csr_indptr=indptr_j, start_mask=start_mask_j, 
    dense_mask=d_mask_j, dense_states=d_states_j, d_dense=2
)
print(f"Decoding complete. Output shape: {outputs.shape}")

Executing constrained retrieval...
Decoding complete. Output shape: (2, 10, 8)


In [41]:
# Print the first five beams (SIDs) decoded for the first batch element
print(f'Sampled SIDs:\n{outputs[0, :5]}')

Sampled SIDs:
[[1322 1862 1880  826  683  860 1321 1808]
 [1583  871   18  999 1634  640  230 1053]
 [ 746 1198  836  136  721 1601 1556 1851]
 [1511  461  203  864  582  937  620 1067]
 [1820 1323 1198 1274   93  314  656  841]]


## (3) Verification

Finally, we verify that every decoded sequence is a member of the valid SID set, i.e., that decoding never leaves the restricted vocabulary.

In [42]:
# 1. Convert JAX array to NumPy for host-side verification
decoded_sids_np = np.array(outputs).reshape(-1, L)

# 2. Build a set of tuples from the original NumPy valid_sids for O(1) membership checks
valid_set = {tuple(row) for row in valid_sids}

# 3. Perform the count using the NumPy-converted results
valid_count = sum(1 for sid in decoded_sids_np if tuple(sid) in valid_set)

print(f"Verification: {valid_count}/{decoded_sids_np.shape[0]} beams are valid.")
assert valid_count == decoded_sids_np.shape[0], "Error: Constraint violation detected!"

Verification: 20/20 beams are valid.
