Merged
5 changes: 0 additions & 5 deletions extension/llm/runner/text_decoder_runner.h
@@ -69,10 +69,6 @@ class ET_EXPERIMENTAL TextDecoderRunner {
    return method_name_;
  }

-  inline void stop() {
-    should_stop_ = true;
-  }
-
/**
Copilot AI Apr 2, 2026
Removing TextDecoderRunner::stop() is an API-breaking change for a public header (it’s exported and also has a torch::executor alias below). If downstream code may be calling this, consider keeping the method (even as a deprecated no-op) or providing a migration path rather than deleting it outright.

Suggested change
-  /**
+  /**
+   * Deprecated compatibility shim for older callers. TextDecoderRunner no
+   * longer requires explicit stop behavior, so this method is now a no-op.
+   */
+  [[deprecated(
+      "TextDecoderRunner::stop() is deprecated and is now a no-op; remove "
+      "calls to this method.")]] virtual void stop() {}
+  /**

* Sample the next token from the logits tensor.
* @param logits_tensor The logits tensor.
@@ -98,7 +94,6 @@ class ET_EXPERIMENTAL TextDecoderRunner {
  Module* module_;
  IOManager* io_manager_;
  std::string method_name_;
-  bool should_stop_{false};
};

} // namespace llm
1 change: 0 additions & 1 deletion extension/llm/runner/text_llm_runner.cpp
@@ -108,7 +108,6 @@ Error TextLLMRunner::generate(
  // return a response token.

  stats_->inference_start_ms = time_in_ms();
-  shouldStop_ = false;

  // Capture remaining KV cache capacity before prefill (pos_ will change)
  int64_t max_context_len = metadata_.at(kMaxContextLen) - pos_;
2 changes: 0 additions & 2 deletions extension/llm/runner/text_llm_runner.h
@@ -161,8 +161,6 @@ class ET_EXPERIMENTAL TextLLMRunner : public IRunner {
  void stop() override;

 private:
-  bool shouldStop_{false};
-
  // Components
  std::unique_ptr<::tokenizers::Tokenizer> tokenizer_;
  std::unordered_map<std::string, int64_t> metadata_;
10 changes: 6 additions & 4 deletions extension/llm/runner/text_token_generator.h
@@ -9,6 +9,8 @@
// Generate tokens in a loop.
#pragma once

+#include <atomic>
+
#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/tensor/tensor.h>
@@ -83,7 +85,7 @@ class ET_EXPERIMENTAL TextTokenGenerator {
    auto tokens_managed = from_blob(
        token_data.data(), token_shape, executorch::aten::ScalarType::Long);

-    should_stop_ = false;
+    should_stop_.store(false, std::memory_order_relaxed);
Copilot AI Apr 2, 2026
should_stop_ is reset to false inside generate(). If stop() is called from another thread very early (e.g., immediately after generation starts), the subsequent store(false, ...) here can overwrite that stop request, making cancellation unreliable for that run. Consider resetting the flag before the operation becomes externally stoppable (or track cancellation via a generation id / use exchange with a protocol that can’t lose a concurrent stop request).

Suggested change
-    should_stop_.store(false, std::memory_order_relaxed);
+    // Clear any stale stop request from a previous run without losing a
+    // concurrent early stop for this run. If a stop was already requested,
+    // honor it immediately for this generation call.
+    if (should_stop_.exchange(false, std::memory_order_relaxed)) {
+      return 0;
+    }


// Generate our tokens
while (pos < start_pos + max_new_tokens) {
@@ -124,7 +126,7 @@ }
}
token_callback(std::move(*decode_result));

-    if (should_stop_) {
+    if (should_stop_.load(std::memory_order_relaxed)) {
break;
}

@@ -142,7 +144,7 @@
   * Stop the generation loop.
   */
  inline void stop() {
-    should_stop_ = true;
+    should_stop_.store(true, std::memory_order_relaxed);
  }
Comment on lines 146 to 148
Copilot AI Apr 2, 2026

There are existing unit tests for the runner/token generation path (e.g., test_text_llm_runner.cpp), but none appear to cover calling stop() concurrently with generate() to validate cancellation behavior and prevent regressions of this race fix. Adding a focused test (potentially in Python bindings where the GIL is released) would better exercise the cross-thread stop path.


/**
@@ -176,7 +178,7 @@ class ET_EXPERIMENTAL TextTokenGenerator {
  bool ignore_eos_ = false;

  // state machine
-  bool should_stop_ = false;
+  std::atomic<bool> should_stop_{false};

// stats
Stats* stats_;