# STATIC Constrained Decoding with OpenOneRec-1.7B

This notebook demonstrates **STATIC** (Sparse Transition-Accelerated Trie Index for Constrained Decoding) applied to a real generative retrieval model — [OneRec-1.7B](https://huggingface.co/OpenOneRec/OneRec-1.7B) from OpenOneRec.

**Key idea:** OneRec generates 3-level Semantic IDs (SIDs) to retrieve items. STATIC enforces that every decoded SID belongs to a valid item catalog, using a hybrid dense + CSR constraint index that runs entirely on GPU with zero Python-loop overhead.

We use **synthetic SIDs** (real item catalogs aren't publicly available) to demonstrate that constraint enforcement works correctly with real model logits.
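
To make the key idea concrete, here is a minimal pure-Python sketch of a single constrained decoding step: logits of tokens that would leave the catalog are replaced by negative infinity before the top-k selection. The helper name `constrained_top_k` and the toy values are illustrative assumptions and not part of the STATIC library, which performs the equivalent masking with vectorized GPU lookups.

```python
import math

def constrained_top_k(logits, allowed, k):
    """Return up to k highest-scoring token ids among `allowed` (hypothetical helper)."""
    # Mask every token that is not a valid continuation.
    masked = [x if i in allowed else -math.inf for i, x in enumerate(logits)]
    ranked = sorted(range(len(masked)), key=lambda i: masked[i], reverse=True)
    # Drop masked entries so fewer than k ids are returned when few are valid.
    return [i for i in ranked[:k] if masked[i] != -math.inf]

# Toy vocabulary of 6 tokens; only tokens 1, 3 and 4 continue a valid SID prefix.
logits = [2.5, 0.1, 3.0, 1.7, -0.4, 0.9]
print(constrained_top_k(logits, allowed={1, 3, 4}, k=2))
```

Tokens 0 and 2 have the highest raw logits but can never be selected, which is the guarantee the rest of the notebook checks on real model logits.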

## 1. Install & Imports

In [4]:
!pip install -q transformers accelerate
# Do not silence stderr: a failed clone is what makes the editable install fail.
# Install from the explicit local path only if the clone succeeded.
!git clone https://github.com/google-research/static-constraint-decoding.git && pip install -q -e ./static-constraint-decoding

In [5]:
import torch
import torch.nn as nn
import numpy as np
from static_decoding.csr_utils import build_static_index
from static_decoding.decoding_pt import sparse_transition_torch
from transformers import AutoModelForCausalLM, AutoTokenizer

## 2. Load OneRec-1.7B

OneRec-1.7B is a Qwen3-based model fine-tuned for generative retrieval. Its vocabulary extends Qwen3's base vocab (151,669 tokens) with three levels of SID tokens:

| Level | Tokens | Model IDs |
|-------|--------|-----------|
| a | `<s_a_0>` ... `<s_a_8191>` | `[base, base+8192)` |
| b | `<s_b_0>` ... `<s_b_8191>` | `[base+8192, base+16384)` |
| c | `<s_c_0>` ... `<s_c_8191>` | `[base+16384, base+24576)` |

Here `base` is the model ID of `<s_a_0>`. The three SID blocks are followed by the `<|sid_begin|>` and `<|sid_end|>` delimiters.
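
The layout in the table amounts to simple offset arithmetic. The sketch below shows the mapping in both directions; the helper names are hypothetical, and `BASE = 151_669` is an assumption (SID tokens appended directly after the base vocabulary). The notebook itself reads the real offsets from the tokenizer instead of hard-coding them.

```python
CODEBOOK_SIZE = 8192
BASE = 151_669  # ASSUMED model ID of <s_a_0>; read it from the tokenizer in practice

def sid_to_model_id(level, code, base=BASE, codebook_size=CODEBOOK_SIZE):
    """Map (level, code), with level 0/1/2 for a/b/c, to a model vocabulary ID."""
    if not 0 <= code < codebook_size:
        raise ValueError(f"code {code} outside [0, {codebook_size})")
    return base + level * codebook_size + code

def model_id_to_sid(model_id, base=BASE, codebook_size=CODEBOOK_SIZE, levels=3):
    """Inverse mapping: model vocabulary ID -> (level, code)."""
    offset = model_id - base
    if not 0 <= offset < levels * codebook_size:
        raise ValueError(f"model ID {model_id} is not a SID token")
    return divmod(offset, codebook_size)

print(sid_to_model_id(1, 5))                   # model ID of <s_b_5> under the assumed base
print(model_id_to_sid(sid_to_model_id(2, 7)))  # round trip back to (level, code)
```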

In [None]:
MODEL_NAME = "OpenOneRec/OneRec-1.7B"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True
)
model.eval()

device = next(model.parameters()).device
print(f"Model loaded on {device}")

# Dynamically detect SID token offsets from the tokenizer
s_a_base = tokenizer.convert_tokens_to_ids("<s_a_0>")
s_b_base = tokenizer.convert_tokens_to_ids("<s_b_0>")
s_c_base = tokenizer.convert_tokens_to_ids("<s_c_0>")
sid_begin_id = tokenizer.convert_tokens_to_ids("<|sid_begin|>")
sid_end_id = tokenizer.convert_tokens_to_ids("<|sid_end|>")

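CODEBOOK_SIZE = 8192
# A token missing from the vocabulary can come back as None, so check before doing offset arithmetic
assert None not in (s_a_base, s_b_base, s_c_base, sid_begin_id, sid_end_id), "SID tokens missing from tokenizer"
assert s_b_base == s_a_base + CODEBOOK_SIZE, f"SID levels not contiguous: a={s_a_base}, b={s_b_base}"
assert s_c_base == s_b_base + CODEBOOK_SIZE, f"SID levels not contiguous: b={s_b_base}, c={s_c_base}"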

print(f"SID offsets — Level a: {s_a_base}, Level b: {s_b_base}, Level c: {s_c_base}")
print(f"Special tokens — <|sid_begin|>: {sid_begin_id}, <|sid_end|>: {sid_end_id}")

## 3. Generate Synthetic SIDs & Build STATIC Index

Since the real item catalog is not publicly available, we generate ~50K random 3-level SIDs with `codebook_size=8192` and build the STATIC index with `d=2` dense layers (optimal for L=3: two dense layers cover levels 0–1, CSR handles level 2).

In [None]:
N_ITEMS = 50_000
L = 3  # 3-level SID structure

np.random.seed(42)
sids = np.random.randint(0, CODEBOOK_SIZE, size=(N_ITEMS, L), dtype=np.int32)
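sids = np.unique(sids, axis=0)  # drop duplicate rows; the result is already in lexicographic row order
# Explicit lexicographic sort with level a as the primary key (np.lexsort treats the LAST key as primary).
# Redundant after np.unique, but it makes the ordering requirement visible.
sids = sids[np.lexsort([sids[:, i] for i in range(L - 1, -1, -1)])]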
print(f"Generated {len(sids)} unique SIDs of length {L}")

# Build the STATIC index (hybrid dense + CSR)
packed_csr, indptr, lmb, start_mask, dense_mask, dense_states = build_static_index(
    sids, vocab_size=CODEBOOK_SIZE, d=2
)

# Move index tensors to GPU
packed_csr_t = torch.tensor(packed_csr, dtype=torch.long, device=device)
indptr_t = torch.tensor(indptr, dtype=torch.long, device=device)
start_mask_t = torch.tensor(start_mask, dtype=torch.bool, device=device)
dense_mask_t = torch.tensor(dense_mask, dtype=torch.bool, device=device)
dense_states_t = torch.tensor(dense_states, dtype=torch.long, device=device)

print(f"STATIC index built. Max branch factors per level: {lmb}")
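
To illustrate what "dense layers for levels 0 and 1, CSR for level 2" means, here is a toy pure-Python sketch of such a hybrid index. It is not the layout that `build_static_index` produces (the names `build_toy_index`, `allowed_next` and the `-1` sentinel are assumptions made for this example); it only shows the idea that the first two levels are answered by direct table lookups while the last level is stored sparsely, one CSR row per valid `(a, b)` prefix.

```python
def build_toy_index(sids, vocab_size):
    """Build a toy dense + CSR index for 3-level SIDs (illustrative layout only)."""
    start_mask = [False] * vocab_size
    # dense_state[a][b] is the CSR row of prefix (a, b), or -1 if no item has it.
    dense_state = [[-1] * vocab_size for _ in range(vocab_size)]
    indptr, cols = [0], []
    for a, b, c in sorted(set(map(tuple, sids))):
        start_mask[a] = True
        if dense_state[a][b] == -1:
            dense_state[a][b] = len(indptr) - 1  # open a new CSR row
            indptr.append(len(cols))
        cols.append(c)
        indptr[-1] = len(cols)  # the open row now ends after this entry
    return start_mask, dense_state, indptr, cols

def allowed_next(index, prefix):
    """Return the valid next codes for a prefix of length 0, 1 or 2."""
    start_mask, dense_state, indptr, cols = index
    if len(prefix) == 0:
        return [a for a, ok in enumerate(start_mask) if ok]
    if len(prefix) == 1:
        return [b for b, row in enumerate(dense_state[prefix[0]]) if row != -1]
    row = dense_state[prefix[0]][prefix[1]]
    return [] if row == -1 else cols[indptr[row]:indptr[row + 1]]

toy_sids = [(0, 1, 2), (0, 1, 3), (0, 2, 0), (3, 0, 1)]
index = build_toy_index(toy_sids, vocab_size=4)
print(allowed_next(index, ()))      # valid level-a codes
print(allowed_next(index, (0,)))    # valid level-b codes after a=0
print(allowed_next(index, (0, 1)))  # valid level-c codes after (a, b) = (0, 1)
```

In this toy layout the dense table has `vocab_size ** 2` entries regardless of catalog size, while the CSR arrays grow with the number of items.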

## 4. Define the OneRec–STATIC Wrapper

STATIC's `sparse_transition_torch` expects a model that maps `(B, 1)` input tokens in SID-space `[0, 8192)` to `(B, 1, 8192)` logits. This wrapper bridges the gap:

1. **Prefill** — On the first call, run the full prompt (including `<|sid_begin|>`) through the model and cache KV states. Return level-a logits.
2. **Autoregressive** — On subsequent calls, map SID-space tokens to model vocabulary IDs, feed through the model with KV cache, and return next-level logits.
3. **KV cache expansion** — When batch size jumps from B to B×beam_size, `repeat_interleave` the cache.

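> **Note on KV cache approximation:** Beam reordering between levels does not update the KV cache, so after a reorder a beam may attend to the cached level-a entry of a different beam. For L=3, this affects only one step: the first expansion copies a single prompt cache to every beam, so only the level-c logits can see a mismatched entry. Constraint enforcement is still 100% correct; only the model's logit *quality* for the final level is slightly degraded.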
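
The two cache operations discussed above can be sketched on toy tensors. `expand_for_beams` mirrors the `repeat_interleave` expansion used by the wrapper; `reorder_by_parent` shows how the approximation could be removed if the beam search exposed, for each surviving beam, the index of its parent beam. That `parent_idx` tensor is a hypothetical input here, and this notebook does not assume `sparse_transition_torch` provides it.

```python
import torch

def expand_for_beams(kv, beam_size):
    """(B, H, T, D) -> (B * beam_size, H, T, D), keeping each query's beams adjacent."""
    return kv.repeat_interleave(beam_size, dim=0)

def reorder_by_parent(kv, parent_idx):
    """Gather cache rows so that row i holds the history of beam i's parent."""
    return kv.index_select(0, parent_idx)

# Toy prompt cache: 2 queries, 1 head, 1 cached position, 1 feature.
prompt_kv = torch.tensor([10.0, 20.0]).view(2, 1, 1, 1)
expanded = expand_for_beams(prompt_kv, beam_size=2)
print(expanded.flatten().tolist())   # each query's row is duplicated in place

# After one decoding step every beam row is distinct (values 0..3 stand in for that).
step_kv = torch.arange(4.0).view(4, 1, 1, 1)
parent_idx = torch.tensor([1, 1, 2, 3])  # hypothetical: new beams 0 and 1 both descend from old beam 1
print(reorder_by_parent(step_kv, parent_idx).flatten().tolist())
```

`repeat_interleave` (rather than tiling the whole batch) matters because it keeps the beams of one query in consecutive rows, which is the layout a flattened `(B * beam_size)` batch dimension implies.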

In [None]:
class OneRecSTATICWrapper(nn.Module):
    """Bridges STATIC's model interface with a HuggingFace causal LM.

    STATIC operates in local SID-space [0, codebook_size). This wrapper:
    1. Prefills the prompt (incl. <|sid_begin|>) on the first call
    2. Maps SID-space tokens to model vocab IDs on subsequent calls
    3. Slices full-vocabulary logits to return only the relevant SID level
    4. Manages KV cache expansion when batch size grows for beam search
    """

    def __init__(self, hf_model, prompt_ids, sid_offsets, codebook_size=8192):
        super().__init__()
        self.hf_model = hf_model
        self.prompt_ids = prompt_ids      # (1, seq_len) including <|sid_begin|>
        self.sid_offsets = sid_offsets     # [s_a_base, s_b_base, s_c_base]
        self.codebook_size = codebook_size
        self.past_key_values = None
        self.call_idx = 0

    def _slice_sid_logits(self, full_logits, level):
        """Extract logits for the given SID level from full vocabulary logits."""
        off = self.sid_offsets[level]
        return full_logits[:, :, off:off + self.codebook_size]

    def _expand_kv_cache(self, target_B):
        """Expand KV cache batch dimension via repeat_interleave."""
        if self.past_key_values is None:
            return
        # Handle transformers DynamicCache (list-based) vs. legacy tuple-of-tuples KV cache
        is_dynamic = hasattr(self.past_key_values, 'key_cache')
        if is_dynamic:
            cur_B = self.past_key_values.key_cache[0].shape[0]
        else:
            cur_B = self.past_key_values[0][0].shape[0]
        if cur_B == target_B:
            return
        # Guard against silent truncation by the integer division below
        assert target_B % cur_B == 0, f"Cannot expand KV cache from batch {cur_B} to {target_B}"
        factor = target_B // cur_B
        if is_dynamic:
            for i in range(len(self.past_key_values.key_cache)):
                self.past_key_values.key_cache[i] = (
                    self.past_key_values.key_cache[i].repeat_interleave(factor, dim=0)
                )
                self.past_key_values.value_cache[i] = (
                    self.past_key_values.value_cache[i].repeat_interleave(factor, dim=0)
                )
        else:
            self.past_key_values = tuple(
                tuple(t.repeat_interleave(factor, dim=0) for t in layer)
                for layer in self.past_key_values
            )

    @torch.no_grad()  # inference only, so skip autograd bookkeeping
    def forward(self, input_ids):
        """
        Args:
            input_ids: (B, 1) in SID-space [0, codebook_size).
        Returns:
            (B, 1, codebook_size) logits for the next SID level.
        """
        B = input_ids.shape[0]

        if self.call_idx == 0:
            # --- Prefill: run the full prompt through the model ---
            out = self.hf_model(self.prompt_ids, use_cache=True)
            self.past_key_values = out.past_key_values
            logits = out.logits[:, -1:, :]  # (1, 1, V_full)
            self.call_idx += 1
            return self._slice_sid_logits(logits, level=0)

        # --- Autoregressive steps ---
        self._expand_kv_cache(B)

        # Map SID-space token -> model vocabulary ID
        level_in = self.call_idx - 1
        model_ids = input_ids + self.sid_offsets[level_in]

        out = self.hf_model(
            model_ids, past_key_values=self.past_key_values, use_cache=True
        )
        self.past_key_values = out.past_key_values

        level_out = self.call_idx
        self.call_idx += 1
        return self._slice_sid_logits(out.logits, level=level_out)

## 5. Run Constrained Decoding

Instead of letting the model generate its chain-of-thought, we hard-code a short `<think>...</think>` block in the prompt and append `<|sid_begin|>`, so the very next token the model predicts is a level-a SID token.

In [None]:
# Build prompt: skip CoT, go straight to SID decoding
prompt_text = (
    "User: Recommend me a sci-fi movie.\n"
    "Assistant: <think>The user wants a sci-fi movie. "
    "I will retrieve relevant items from the catalog.</think>\n"
)
prompt_ids = tokenizer.encode(prompt_text, return_tensors="pt").to(device)
sid_begin_tensor = torch.tensor([[sid_begin_id]], device=device)
prompt_ids = torch.cat([prompt_ids, sid_begin_tensor], dim=1)
print(f"Prompt length: {prompt_ids.shape[1]} tokens")

# Instantiate the wrapper
wrapper = OneRecSTATICWrapper(
    hf_model=model,
    prompt_ids=prompt_ids,
    sid_offsets=[s_a_base, s_b_base, s_c_base],
    codebook_size=CODEBOOK_SIZE,
)

BEAM_SIZE = 10

# Run STATIC constrained beam search
print(f"Running constrained beam search (beam_size={BEAM_SIZE}, L={L})...")
outputs = sparse_transition_torch(
    model=wrapper,
    batch_size=1,
    beam_size=BEAM_SIZE,
    tokens_per_beam=20,
    start_token=0,
    max_sample_len=L,
    vocab_size=CODEBOOK_SIZE,
    max_branch_factors=lmb,
    packed_csr=packed_csr_t,
    csr_indptr=indptr_t,
    start_mask=start_mask_t,
    dense_mask=dense_mask_t,
    dense_states=dense_states_t,
    device=device,
    d_dense=2,
)

print(f"\nDecoded {outputs.shape[1]} beams of length {outputs.shape[2]}:")
print(outputs[0].cpu().numpy())

## 6. Verification

Every decoded beam must be a member of the synthetic SID catalog, i.e. constraint satisfaction must be 100%.

In [None]:
decoded_np = outputs[0].cpu().numpy()  # (beam_size, L)
valid_set = {tuple(row) for row in sids}

valid_count = sum(1 for sid in decoded_np if tuple(sid) in valid_set)
print(f"Verification: {valid_count}/{len(decoded_np)} decoded SIDs are in the valid set.")
assert valid_count == len(decoded_np), "Constraint violation detected!"
print("All constraints satisfied.")
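
The check above reports only a count. A slightly more informative variant, sketched below in pure Python, also lists which beams are invalid and how many beams are duplicates. The helpers `pack_sid` and `check_beams` are hypothetical names introduced for this example; packing each SID into a single integer key is one possible way to make membership tests independent of the array dtype.

```python
def pack_sid(sid, codebook_size=8192):
    """Pack a multi-level SID into one integer key (base-`codebook_size` digits)."""
    key = 0
    for code in sid:
        key = key * codebook_size + int(code)
    return key

def check_beams(decoded, catalog, codebook_size=8192):
    """Return (list of decoded SIDs not in the catalog, number of duplicate beams)."""
    catalog_keys = {pack_sid(s, codebook_size) for s in catalog}
    beam_keys = [pack_sid(s, codebook_size) for s in decoded]
    invalid = [tuple(s) for s, k in zip(decoded, beam_keys) if k not in catalog_keys]
    duplicates = len(beam_keys) - len(set(beam_keys))
    return invalid, duplicates

# Toy data: the third beam is not in the catalog and the second repeats the first.
toy_catalog = [(0, 1, 2), (3, 0, 1)]
toy_decoded = [(0, 1, 2), (0, 1, 2), (3, 0, 2)]
print(check_beams(toy_decoded, toy_catalog))
```

With the notebook's variables, the equivalent call would be `check_beams(decoded_np, sids)`; both accept any iterable of length-3 integer rows.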