27 changes: 27 additions & 0 deletions examples/models/qwen3_5_moe/export.py
@@ -554,6 +554,30 @@ def export_and_lower(model, config, args):
_export_cuda(model, config, args)


def _strip_sampler_from_forward(model):
"""Bind ``model.forward`` to a minimal ``(tokens, input_pos) -> logits``
variant for non-CUDA export.

The default ``Qwen35MoE.forward`` carries an optional temperature input and
a sampling branch used only by the on-device CUDA sampler; non-CUDA
backends sample on the host so that branch is dead code at trace time.
Even when statically eliminated, the extra parameter and branch perturb
the program ``torch.export`` produces enough to shift kernel selection in
the lowered MLX/Metal graph and slow execution by 10-30%. Eager callers
and the CUDA export path are unaffected.
"""
import types

def _clean_forward(self, tokens, input_pos):
x = self.embed_tokens(tokens)
for layer in self.layers:
x = layer(x, input_pos)
x = self.norm(x)
return self.lm_head(x)

model.forward = types.MethodType(_clean_forward, model)


def _export_mlx(model, config, args):
"""Export model to .pte via torch.export + MLX backend."""
import gc
@@ -568,6 +592,8 @@ def _export_mlx(model, config, args):
from executorch.exir.passes import MemoryPlanningPass
from torch.export import Dim, export

_strip_sampler_from_forward(model)

example_tokens = torch.tensor([[0, 1]], dtype=torch.long)
example_input_pos = torch.tensor([0, 1], dtype=torch.long)
seq_dim = Dim("seq_len", min=1, max=config.max_seq_len - 1)
@@ -650,6 +676,7 @@ def _export_metal(model, config, args):

inductor_config.coordinate_descent_tuning = False
inductor_config.aot_inductor.compile_wrapper_opt_level = "O0"
_strip_sampler_from_forward(model)

# --- Decode method (T=1, static shape) ---
print("Exporting decode method...")
48 changes: 42 additions & 6 deletions examples/models/qwen3_5_moe/main.cpp
@@ -25,6 +25,8 @@

#ifdef EXECUTORCH_BUILD_CUDA
#include <cuda_runtime.h>
#else
#include <executorch/extension/llm/sampler/util.h>
#endif

DEFINE_string(model_path, "", "Model .pte file path.");
@@ -37,7 +39,10 @@ DEFINE_string(
"Path to file containing prompt text (overrides --prompt).");
DEFINE_double(temperature, 0.8, "Sampling temperature (0 = greedy).");
DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate.");
DEFINE_bool(cuda_graph, false, "Enable CUDA graph for decode method.");
DEFINE_bool(
cuda_graph,
false,
"Enable CUDA graph for decode method. CUDA only.");

namespace llm = ::executorch::extension::llm;
using ::executorch::extension::from_blob;
@@ -48,10 +53,18 @@ using ::executorch::runtime::EValue;

using SizesType = executorch::aten::SizesType;

// Read a sampled token from the model output tensor [B, 1].
// The model performs Gumbel-max sampling on-device and returns a single
// float token ID. This function copies it from GPU and casts to uint64.
// Convert a model output tensor to the next sampled token id.
//
// On the CUDA build, the model fuses the sampler in (see sampler.py /
// Qwen35MoE.forward) and returns a single sampled token id as a [B, 1]
// float tensor; we just copy that scalar back from device.
//
// On non-CUDA builds (Metal / MLX / CPU), the model returns raw logits
// of shape [B, T, V] in the model dtype (typically bf16). We sample on
// CPU via the shared `llm::logits_to_token` helper, which accepts a
// temperature (0 = greedy / argmax).
static uint64_t read_token(const executorch::aten::Tensor& output) {
#ifdef EXECUTORCH_BUILD_CUDA
const void* ptr = output.const_data_ptr();

cudaPointerAttributes attrs;
@@ -73,6 +86,13 @@ static uint64_t read_token(const executorch::aten::Tensor& output) {
memcpy(&val, ptr, sizeof(float));
}
return static_cast<uint64_t>(val);
#else
// logits_to_token handles 2D / 3D logits and Float / Half / BFloat16 /
// UInt16 dtypes. Negative temperatures are clamped to 0 (greedy).
const float temp =
FLAGS_temperature <= 0.0 ? 0.0f : static_cast<float>(FLAGS_temperature);
return static_cast<uint64_t>(llm::logits_to_token(output, temp));
#endif
}

int main(int argc, char** argv) {
@@ -133,16 +153,23 @@ int main(int argc, char** argv) {
}
auto metadata = metadata_result.get();

#ifdef EXECUTORCH_BUILD_CUDA
// Set CUDA graph option if requested (must be before load_method)
if (FLAGS_cuda_graph) {
executorch::runtime::BackendOptions<2> cuda_opts;
cuda_opts.set_option("enable_cuda_graph_for_method", "decode");
executorch::runtime::set_option("CudaBackend", cuda_opts.view());
printf("CUDA graph enabled for decode method\n");
}
#else
if (FLAGS_cuda_graph) {
ET_LOG(Info, "--cuda_graph ignored on non-CUDA build");
}
#endif

printf("Loading methods...\n");

#ifdef EXECUTORCH_BUILD_CUDA
// Enable cross-method per-FQN weight sharing in the CUDA backend so that
// prefill and decode (which share KV cache and other mutable buffers /
// weights) avoid duplicate GPU allocations. This is critical for fitting
@@ -170,6 +197,7 @@ int main(int argc, char** argv) {
return 1;
}
}
#endif

auto err = module->load_method("prefill");
if (err != Error::Ok) {
@@ -224,12 +252,16 @@ int main(int argc, char** argv) {
// ---------------------------------------------------------------
auto S = [](int64_t v) -> SizesType { return static_cast<SizesType>(v); };

// Use a very small temperature for greedy to avoid division by zero
// while keeping the Gumbel noise negligible relative to logit differences.
#ifdef EXECUTORCH_BUILD_CUDA
Contributor:

Can we make this more generic? Detect if the .pte was exported with the sampler built in and route appropriately?

Then make fuse-sampler an export arg that is on for CUDA and off for MLX/Metal for now?

Contributor Author:

I would like to keep the current status; I don't think fusing the sampler into the model's forward method is good practice. This is a temporary solution until device-side sampling support lands; once we have it in the near future, all models should return logits and use a sampler tool to do sampling.
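
For context, a minimal sketch of the detection-based routing the reviewer describes (hypothetical, not part of this PR), assuming the runner can query a method's declared input count via `Module::method_meta()`: a sampler-fused .pte takes a third (temperature) input, while a logits-only .pte takes two.

```cpp
// Hypothetical sketch only: detect at runtime whether the loaded .pte was
// exported with the sampler fused in, by counting the inputs the "decode"
// method declares. Assumes Module::method_meta() / MethodMeta::num_inputs()
// are available in this build of the ExecuTorch runtime.
static bool pte_has_fused_sampler(executorch::extension::Module& module) {
  auto meta = module.method_meta("decode");
  if (meta.error() != executorch::runtime::Error::Ok) {
    return false;  // Assume raw logits and sample on the host.
  }
  // (tokens, input_pos, temperature) -> fused sampler;
  // (tokens, input_pos) -> raw logits.
  return meta.get().num_inputs() == 3;
}
```

With such a flag, the runner could append the temperature tensor and choose the read_token path at runtime instead of via compile-time #ifdefs.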

// CUDA build: model fuses the sampler in. Pass a temperature tensor as
// a third input. Use a very small temperature for greedy to avoid
// division by zero while keeping the Gumbel noise negligible relative
// to logit differences.
float temp_val =
FLAGS_temperature <= 0.0 ? 1e-6f : static_cast<float>(FLAGS_temperature);
auto temp_tensor =
from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float);
#endif

// ---------------------------------------------------------------
// Prefill
@@ -260,7 +292,9 @@ int main(int argc, char** argv) {
std::vector<EValue> prefill_inputs;
prefill_inputs.push_back(tokens_tensor);
prefill_inputs.push_back(pos_tensor);
#ifdef EXECUTORCH_BUILD_CUDA
prefill_inputs.push_back(temp_tensor);
#endif

auto prefill_result = module->execute(run_method, prefill_inputs);
if (prefill_result.error() != Error::Ok) {
@@ -308,7 +342,9 @@ int main(int argc, char** argv) {
std::vector<EValue> decode_inputs;
decode_inputs.push_back(EValue(decode_tokens));
decode_inputs.push_back(EValue(decode_pos));
#ifdef EXECUTORCH_BUILD_CUDA
decode_inputs.push_back(EValue(temp_tensor));
#endif

auto decode_result = module->execute("decode", decode_inputs);
if (decode_result.error() != Error::Ok) {
13 changes: 1 addition & 12 deletions examples/models/qwen3_5_moe/model.py
@@ -21,7 +21,6 @@

import torch
import torch.nn as nn

from executorch.examples.models.qwen3_5_moe.sampler import sample
from torch.nn import functional as F

@@ -186,7 +185,6 @@ def _apply_rotary(x, cos, sin):


class KVCache(nn.Module):

def __init__(self, n_kv_heads, head_dim, max_seq_len):
super().__init__()
self.register_buffer(
@@ -207,7 +205,6 @@ def update(self, input_pos, k_val, v_val):


class FullAttention(nn.Module):

def __init__(self, config):
super().__init__()
self.n_heads = config.num_attention_heads
@@ -318,7 +315,6 @@ def forward(self, x, input_pos):


class GatedDeltaNet(nn.Module):

def __init__(self, config):
super().__init__()
self.num_k_heads = config.linear_num_key_heads
@@ -540,7 +536,6 @@ def forward(self, x):


class SparseMoE(nn.Module):

def __init__(self, config):
super().__init__()
self.top_k = config.num_experts_per_tok
@@ -574,7 +569,6 @@ def forward(self, x):


class Block(nn.Module):

def __init__(self, config, layer_idx):
super().__init__()
self.layer_type = config.layer_types[layer_idx]
@@ -599,7 +593,6 @@ def forward(self, x, input_pos):


class Qwen35MoE(nn.Module):

def __init__(self, config):
super().__init__()
self.config = config
@@ -620,12 +613,8 @@ def forward(
for layer in self.layers:
x = layer(x, input_pos)
x = self.norm(x)
# When no sampling is requested, return the full ``[B, T, V]``
# logits so callers (eval, custom samplers) can inspect every
# position. Otherwise apply the prefill optimization and only
# materialize ``[B, V]`` for the last token.
if temperature is None:
return self.lm_head(x).float() # [B, T, V] float32
return self.lm_head(x) # [B, T, V] in model dtype
logits = self.lm_head(x[:, -1, :]).float() # [B, V] float32
# GPU-side Gumbel-max sampling: argmax(logits/T + gumbel_noise) is
# equivalent to drawing from softmax(logits/T) but stays entirely
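
For reference, the identity this comment relies on is the standard Gumbel-max trick: with i.i.d. u_i ~ Uniform(0,1) and g_i = -log(-log u_i),

$$
\operatorname*{arg\,max}_i \left(\frac{z_i}{T} + g_i\right) \;\sim\; \operatorname{Categorical}\!\big(\operatorname{softmax}(z/T)\big),
$$

so drawing from the temperature-scaled softmax reduces to an argmax over noised logits that can stay on the GPU. The same identity motivates the tiny-temperature trick in main.cpp: as T approaches 0, z_i/T dominates the O(1) Gumbel noise and the argmax collapses to greedy decoding.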