In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
import os

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../maxtext")))
os.environ["SKIP_JAX_PRECOMPILE"] = "1"

import functools
from etils import epath


import transformers
import numpy as np

import jax
from flax import nnx
from flax import linen as nn

import MaxText as mt
from MaxText import pyconfig
from MaxText.integration.tunix.tunix_adaptor import TunixMaxTextLlama

from tunix.rl.rollout.vllm_rollout import VllmRollout
from tunix.rl.rollout.base_rollout import RolloutConfig

from tunix.rl.rollout import base_rollout
from tunix.models.llama3 import model as llama3_lib

from vllm import LLM
import orbax.checkpoint as ocp

2025-08-15 23:25:23.155136: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1755300323.169961  774458 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1755300323.173777  774458 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1755300323.185425  774458 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1755300323.185438  774458 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1755300323.185440  774458 computation_placer.cc:177] computation placer alr

INFO 08-15 23:25:27 [__init__.py:241] Automatically detected platform tpu.
INFO 08-15 23:25:27 [__init__.py:16] TPU info: node_name=wenxindong-v6e-8 | tpu_type=v6e-8 | worker_id=0 | num_chips=8 | num_cores_per_chip=1
INFO 08-15 23:25:27 [__init__.py:29] Running vLLM without Pathways. Module pathwaysutils is not imported.


In [3]:
MODEL = "meta-llama/Llama-3.1-8B"
TOTAL_TPU_TO_USE = 8
# Positional arguments for jax.make_mesh: mesh shape (1, TOTAL_TPU_TO_USE) with
# axis names ("fsdp", "tp") -- i.e. every selected device lies on the "tp" axis.
MESH = [(1, TOTAL_TPU_TO_USE), ("fsdp", "tp")]


# Hugging Face tokenizer for MODEL; handed to VllmRollout further down.
model_tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL)
# Device mesh over the first TOTAL_TPU_TO_USE devices. `model_mesh` is just an
# alias of `mesh`; NOTE(review): nothing else in this notebook reads it.
mesh = model_mesh = jax.make_mesh(
    *MESH, devices=jax.devices()[:TOTAL_TPU_TO_USE]
)

In [4]:
def _restore_maxtext_params(model, sharded_state, load_parameters_path):
  """Restore the `params` subtree of an Orbax checkpoint into `model` in place.

  Args:
    model: NNX model whose state is updated with the restored arrays.
    sharded_state: NNX state of `model`. Its arrays serve as the restore target
      passed to Orbax (via `construct_restore_args` and `item=`).
    load_parameters_path: Checkpoint location, wrapped in `epath.Path`.

  Raises:
    ValueError: if building the checkpointer, restoring, or updating the model
      fails. The underlying exception is chained as `__cause__`.
  """
  # Unwrap nnx.Variable leaves to their raw arrays so the restore target is a
  # plain pytree of arrays.
  target_for_restore = jax.tree.map(
      lambda v: v.value,
      sharded_state,
      is_leaf=lambda n: isinstance(n, nnx.Variable),
  )

  try:
    ckptr = ocp.Checkpointer(
        ocp.PyTreeCheckpointHandler(
            restore_concurrent_gb=None,
            save_concurrent_gb=None,
            use_ocdbt=True,
            use_zarr3=True,
        )
    )
    # This is a memory optimization. We don't want to restore the entire checkpoint - only the params.
    # Rather than pass the entire abstract state, which could unnecessarily restore opt_state and such and waste
    # memory, we instead specify here that we are just restoring the params field of the checkpoint
    # (which itself may be a dictionary containing a key named 'params').
    restore_args = ocp.checkpoint_utils.construct_restore_args(
        target_for_restore
    )
    restored = ckptr.restore(
        epath.Path(load_parameters_path),
        item={"params": {"params": target_for_restore}},
        transforms={},
        restore_args={"params": {"params": restore_args}},
    )
    checkpoint = restored["params"]["params"]

    if checkpoint:
      nnx.update(model, checkpoint)

  except Exception as e:
    # Explicit chaining marks the Orbax/NNX error as the direct cause instead of
    # "another exception occurred while handling".
    raise ValueError(f"Checkpointing failed: {e}") from e


def get_ref_maxtext_model(config):
  """Build a sharded MaxText model, optionally restore weights, wrap it for Tunix.

  Args:
    config: MaxText config (as returned by `pyconfig.initialize`). This function
      reads `config.logical_axis_rules` and `config.load_parameters_path`, and
      passes the whole config to `mt.from_pretrained`.

  Returns:
    A tuple `(tunix_model, mesh, model_config)`:
      tunix_model: `TunixMaxTextLlama` wrapping the model, with `.config` set to
        `model_config`.
      mesh: the mesh taken from the abstract model (`abstract_model.mesh`).
      model_config: `llama3_lib.ModelConfig.llama3_1_8b()`. NOTE: hard-coded;
        it does not follow `config.model_name`.

  Raises:
    ValueError: if `config.load_parameters_path` is set and restoring fails.
  """

  def create_model(config):
    return mt.from_pretrained(config, rngs=nnx.Rngs(params=0, dropout=1))

  # Trace construction without allocating parameters to get the graph
  # definition and the partition spec of every state leaf.
  abstract_model = nnx.eval_shape(create_model, config=config)
  graphdef, abstract_state = nnx.split(abstract_model)
  print("The abstract NNX state (all leaves are abstract arrays):")
  nnx.display(abstract_state)
  specs = nnx.get_partition_spec(abstract_state)
  mesh = abstract_model.mesh

  # JIT a function that creates the model state with proper sharding from the start.
  # By providing out_shardings, we instruct JAX to produce sharded output directly,
  # avoiding a large intermediate allocation on a single device.
  with nn.logical_axis_rules(config.logical_axis_rules):
    out_shardings = nn.logical_to_mesh_sharding(specs, mesh)

  @functools.partial(jax.jit, out_shardings=out_shardings)
  def create_sharded_state():
    # This will be JIT-compiled. JAX knows the output sharding and can
    # initialize the parameters directly on the target devices in a sharded way.
    model = create_model(config)
    return nnx.state(model)

  with mesh:
    # Create the model with sharded parameters.
    sharded_state = create_sharded_state()
    model = nnx.merge(graphdef, sharded_state)

    if config.load_parameters_path:
      _restore_maxtext_params(model, sharded_state, config.load_parameters_path)

    tunix_model = TunixMaxTextLlama(
        base_model=model,
        use_attention_mask=False,  # trust Tunix loss masking
    )

    model_config = llama3_lib.ModelConfig.llama3_1_8b()
    tunix_model.config = model_config

  return tunix_model, mesh, model_config


from MaxText.integration.tunix.tunix_adaptor import TunixMaxTextLlama

# Build the MaxText config. The list presumably mirrors sys.argv (argv[0] a
# placeholder, argv[1] the base YAML config); the keyword arguments below
# override individual config keys.
config_ref = pyconfig.initialize(
    [
        "",
        "../../maxtext/MaxText/configs/base.yml",
    ],  # TODO(@mazumdera): an earlier note asked "why decode.py?", but no decode.py is passed here -- resolve or drop.
    base_output_directory="gs://dummy_output_dir",  # This is not used in Tunix.
    run_name="test-tunix-maxtext-llama3.1-8b",
    tokenizer_type="tiktoken",
    tokenizer_path="assets/tokenizer_llama3.tiktoken",
    # Orbax checkpoint whose params are restored by get_ref_maxtext_model().
    load_parameters_path="gs://maxtext-model-checkpoints/llama3.1-8b/2025-01-23-19-04/scanned/0/items",
    per_device_batch_size=1,
    max_prefill_predict_length=4,
    max_target_length=16,
    steps=10,
    async_checkpointing="false",
    model_name="llama3.1-8b",
    checkpoint_period=5,
    skip_jax_distributed_system="true",
    weight_dtype="bfloat16",
    attention="dot_product",
    remat_policy="custom",
    decoder_layer_input="offload",
    query_proj="offload",
    key_proj="offload",
    value_proj="offload",
    opt_type="sgd",
)

# Build the Tunix-wrapped MaxText model. The mesh it returns is discarded; the
# notebook-level `mesh` from the earlier cell is used for vLLM instead.
maxtext_model, _, model_config = get_ref_maxtext_model(config_ref)
nnx.display(maxtext_model)

Updating keys from env and command line: ['run_name', 'model_name', 'load_parameters_path', 'async_checkpointing', 'checkpoint_period', 'weight_dtype', 'remat_policy', 'decoder_layer_input', 'query_proj', 'key_proj', 'value_proj', 'attention', 'base_output_directory', 'tokenizer_path', 'tokenizer_type', 'per_device_batch_size', 'steps', 'skip_jax_distributed_system', 'max_target_length', 'max_prefill_predict_length', 'opt_type']
Running Model: llama3.1-8b
Updating following parameters in config

base_emb_dim: 4096
base_num_query_heads: 32
base_num_kv_heads: 8
base_num_decoder_layers: 32
base_mlp_dim: 14336
head_dim: 128
mlp_activations: ['silu', 'linear']
vocab_size: 128256
enable_dropout: False
logits_via_embedding: False
normalization_layer_epsilon: 1e-05
rope_max_timescale: 500000
decoder_block: llama2
Updating keys from model: ['base_emb_dim', 'base_num_query_heads', 'base_num_kv_heads', 'base_num_decoder_layers', 'base_mlp_dim', 'head_dim', 'mlp_activations', 'vocab_size', 'enable

Num_devices: 8, shape (1, 1, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1)


I0815 23:26:16.059053  775862 google_auth_provider.cc:149] Using credentials at /home/wenxindong_google_com/.config/gcloud/application_default_credentials.json
I0815 23:26:16.059135  775862 google_auth_provider.cc:156] Using OAuth2 AuthProvider


In [5]:
def create_maxtext_to_vllm_mappings():
  """Describe how MaxText (scanned) parameters map onto vLLM (unscanned) ones.

  Returns:
    A fresh dict keyed by MaxText parameter path. Each value is a pair
    `(vllm_path, sharding_spec)`:
      vllm_path: destination parameter name; per-layer entries use `*` in
        place of the layer index.
      sharding_spec: one entry per dimension of the MaxText array. Per-layer
        entries carry a 'layer' axis, and 'model' marks the dimension sharded
        for tensor parallelism (vocab, hidden or head dims).
  """
  # (maxtext_path, vllm_path, sharding_spec), in insertion order.
  table = (
      # --- Outside the decoder-layer stack ---
      # Token embeddings: vocab dimension sharded for TP.
      ('base.token_embedder.embedding',
       'embed.embedding',
       ('model', None)),
      # Final layer norm: replicated.
      ('base.decoder.decoder_norm.scale',
       'model.norm.scale',
       (None,)),
      # LM head (logits projection): vocab dimension sharded for TP.
      ('base.decoder.logits_dense.kernel',
       'lm_head',
       (None, 'model')),
      # --- Scanned decoder layers ---
      # MLP: gate/up shard their output dim (4096 -> 14336), down shards its
      # input dim (14336 -> 4096).
      ('base.decoder.layers.mlp.wi_0.kernel',
       'model.layers.*.mlp.gate_proj.kernel',
       (None, 'layer', 'model')),
      ('base.decoder.layers.mlp.wi_1.kernel',
       'model.layers.*.mlp.up_proj.kernel',
       (None, 'layer', 'model')),
      ('base.decoder.layers.mlp.wo.kernel',
       'model.layers.*.mlp.down_proj.kernel',
       ('model', 'layer', None)),
      # Per-layer norms: not sharded beyond the layer axis.
      ('base.decoder.layers.pre_self_attention_layer_norm.scale',
       'model.layers.*.input_layernorm.scale',
       (None, 'layer')),
      ('base.decoder.layers.post_self_attention_layer_norm.scale',
       'model.layers.*.post_attention_layernorm.scale',
       (None, 'layer')),
      # Attention projections: head dimension sharded for TP.
      # TODO(review): an earlier note marked the q_proj spec "NOT MATCH";
      # re-verify it against the vLLM layout.
      ('base.decoder.layers.self_attention.query.kernel',
       'model.layers.*.self_attn.q_proj.kernel',
       (None, 'layer', 'model', None)),
      ('base.decoder.layers.self_attention.key.kernel',
       'model.layers.*.self_attn.k_proj.kernel',
       (None, 'layer', 'model', None)),
      ('base.decoder.layers.self_attention.value.kernel',
       'model.layers.*.self_attn.v_proj.kernel',
       (None, 'layer', 'model', None)),
      # o_proj shards its input (heads) dimension.
      ('base.decoder.layers.self_attention.out.kernel',
       'model.layers.*.self_attn.o_proj.kernel',
       ('model', 'layer', None, None)),
  )
  return {src: (dst, spec) for src, dst, spec in table}

In [6]:
# Per-parameter transpose instructions handed to Tunix via
# `to_hf_transpose_keys`. Left empty -- presumably meaning no extra transposes
# are requested for this MaxText -> vLLM transfer; confirm against Tunix.
transpose_keys = {}


def reorder_rope(arr):
  """Permute the last axis of `arr`: even-indexed entries first, then odd ones.

  A last axis [x0, x1, x2, x3, ...] becomes [x0, x2, ..., x1, x3, ...].
  NOTE(review): presumably this converts MaxText's interleaved RoPE pairing to
  the split-half layout expected on the vLLM/HF side -- confirm for both models.
  """
  return jax.numpy.concatenate((arr[..., 0::2], arr[..., 1::2]), axis=-1)


def transform_query_kernel(arr, expected_head_dim=128):
  """Prepare a MaxText query kernel for vLLM: scale by sqrt(head_dim), reorder RoPE.

  The sqrt(head_dim) factor presumably undoes a 1/sqrt(head_dim) query scaling
  folded into MaxText checkpoints -- confirm against the checkpoint converter.

  Args:
    arr: Query kernel whose last axis is the head dimension.
    expected_head_dim: Required size of the last axis (default 128, matching
      llama3.1-8b). Pass None to skip the check.

  Returns:
    The scaled, RoPE-reordered kernel with the same dtype as `arr`.

  Raises:
    ValueError: if the last axis does not equal `expected_head_dim`.
  """
  head_dim = arr.shape[-1]
  # An explicit check, unlike `assert`, is not stripped under `python -O`.
  if expected_head_dim is not None and head_dim != expected_head_dim:
    raise ValueError(
        f"Expected head_dim {expected_head_dim}, got {head_dim} "
        f"(kernel shape {arr.shape})."
    )
  depth_scale = np.float32(np.sqrt(head_dim))
  # NumPy scalars are not weakly typed in JAX, so bfloat16 * np.float32 is
  # promoted to float32. Cast back so the hook does not silently change the
  # kernel's dtype (and double its memory footprint).
  scaled = (arr * depth_scale).astype(arr.dtype)
  return reorder_rope(scaled)


def transform_key_kernel(arr):
  """Prepare a MaxText key kernel for vLLM: RoPE reordering only, no scaling."""
  return reorder_rope(arr)


# Per-parameter transform hooks, keyed by the same MaxText paths used in
# create_maxtext_to_vllm_mappings(); presumably applied by Tunix to those
# parameters before they are transferred to vLLM.
hook_fns = {
    'base.decoder.layers.self_attention.query.kernel': transform_query_kernel,
    'base.decoder.layers.self_attention.key.kernel': transform_key_kernel,
}

# Attach the mapping hooks Tunix reads from the model. All four accept (and
# ignore) positional arguments. Previously `to_hf_mappings` alone took none, so
# any caller passing arguments would have raised TypeError on it only.
maxtext_model.to_hf_mappings = lambda *args: create_maxtext_to_vllm_mappings()
maxtext_model.to_hf_transpose_keys = lambda *args: transpose_keys
maxtext_model.lora_to_hf_mappings = lambda *args: None  # No LoRA
maxtext_model.to_hf_hook_fns = lambda *args: hook_fns

In [7]:
# Sampling / KV-cache settings for rollouts.
# NOTE(review): `cache_config` is not referenced by any later cell in this
# notebook. VllmRollout below is given cache_config_or_size=1024, and the
# generate() call builds its own RolloutConfig (temperature=0.1, not
# TEMPERATURE). Wire these up or remove them.
TOTAL_GENERATION_STEPS = 64
MAX_PROMPT_LENGTH = 64
TEMPERATURE = 0.9
TOP_P = 1.0
TOP_K = None
cache_config = base_rollout.RolloutConfig(
    max_tokens_to_generate=TOTAL_GENERATION_STEPS,
    max_prompt_length=MAX_PROMPT_LENGTH,
    # Prompt + generation budget plus 256 slots of headroom.
    kv_cache_size=MAX_PROMPT_LENGTH + TOTAL_GENERATION_STEPS + 256,
    temperature=TEMPERATURE,
    top_p=TOP_P,
    top_k=TOP_K,
)

In [None]:
# Build the vLLM rollout engine from the Tunix-wrapped MaxText model. Weights
# are mapped through the to_hf_* hooks attached to `maxtext_model` above.
vllm_rollout = VllmRollout(
    model=maxtext_model,
    tokenizer=model_tokenizer,
    # Integer cache size; presumably the vLLM max model length -- confirm.
    # NOTE(review): magic number, independent of `cache_config` defined above.
    cache_config_or_size=1024,
    mesh=mesh,
    lora_config=None,
    model_version=MODEL,
)

In [24]:
# Smoke test: generate a single completion for one prompt with low temperature.
output = vllm_rollout.generate(
    ["What is the capital of France?"],
    rollout_config=RolloutConfig(
        n=1, max_tokens_to_generate=64, temperature=0.1
    ),
)

Adding requests:   0%|          | 0/1 [00:00<?, ?it/s]

Processed prompts:   0%|          | 0/1 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

In [25]:
# Display the RolloutOutput (generated text, tokens, left-padded prompt tokens, logprobs).
output

RolloutOutput(text=['<|begin_of_text|>://www.google.com/url?q=https://www.1stwebdesigner.com/what-is-the-capital-of-france/&sa=U&ved=2ahUKEwjYjv7YjY7jAhVJjYQKHZQfBqoQFjAAeg'], logits=None, tokens=Array([[128000,   1129,   2185,   5831,    916,  58354,  44882,  53099,
          1129,   2185,     13,     16,    267,   2984,  25894,    261,
           916,     14,  12840,  31530,  10826,  98231,   8838,  51478,
           685,  84472,   9258,     28,     52,      5,   2111,     28,
            17,   1494,     52,   3472,  68054,     56,     73,     85,
            22,     56,     73,     56,     22,     73,  25797,     53,
            41,     73,     56,     48,     42,  62859,     48,     69,
            33,     80,     78,     48,     37,     73,   6157,    797]],      dtype=int32), left_padded_prompt_tokens=Array([[128001, 128001, 128001, 128001, 128001, 128001, 128001, 128000,
          3923,    374,    279,   6864,    315,   9822,     30, 128001]],      dtype=int32), logprobs=[[0.0, 

In [None]:
## The code below is unused and kept for reference only; do not run it.

In [None]:
# # vLLM model
# golden_llm = LLM(
#     MODEL,
#     max_model_len=1024,
#     tensor_parallel_size=8,
#     gpu_memory_utilization=0.3,
# )

# Compare weights

# _golden_state_flatten = (
#     golden_llm.llm_engine.model_executor.driver_worker.model_runner.state.flat_state()
# )
# _maxtext_state_flatten = nnx.state(llama3_1_8b).flat_state()

# golden_state_flatten = {
#     '.'.join(str(key) for key in keys): v for keys, v in _golden_state_flatten
# }
# maxtext_state_flatten = {
#     '.'.join(str(key) for key in keys): v for keys, v in _maxtext_state_flatten
# }

# from tunix.generate.vllm_sampler import VllmSampler, MappingConfig

# sampler = VllmSampler(
#     mesh=mesh,
#     tokenizer=model_tokenizer,
#     max_model_len=1024,
#     model_version="meta-llama/Llama-3.1-8b",
#     mapping_config=MappingConfig({}, {}, {}, None),
#     hbm_utilization=0.3,
# )