diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index e04f0fc4f9a..bb12ddac8cf 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -248,7 +248,10 @@ llama_context::llama_context(
     LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());

-    const size_t max_nodes = this->graph_max_nodes();
+    const uint32_t n_seqs = cparams.n_seq_max;
+    const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+    const size_t max_nodes = this->graph_max_nodes(n_tokens);

     LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);

@@ -300,9 +303,6 @@ llama_context::llama_context(
         cross.v_embd.clear();

-        const uint32_t n_seqs = cparams.n_seq_max;
-        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
         // avoid reserving graphs with zero outputs - assume one output per sequence
         n_outputs = n_seqs;

@@ -1386,9 +1386,9 @@ void llama_context::output_reorder() {
 // graph
 //

-uint32_t llama_context::graph_max_nodes() const {
+uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) {
     if (model.arch == LLM_ARCH_QWEN3NEXT) {
-        return std::max(8192u, 32u*model.n_tensors());
+        return std::max(n_tokens * 40, 32u * model.n_tensors());
     }
     return std::max(1024u, 8u*model.n_tensors());
 }
diff --git a/src/llama-context.h b/src/llama-context.h
index 20cbd789554..97410913b14 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -197,7 +197,7 @@ struct llama_context {
     //

 public:
-    uint32_t graph_max_nodes() const;
+    uint32_t graph_max_nodes(uint32_t n_tokens);

     // can reuse the llm_graph_result instance of the context (for example to update a memory module)
     llm_graph_result * get_gf_res_reserve() const;
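
For reference, a minimal standalone sketch of what the new QWEN3NEXT node budget computes. This is not part of the patch; the values of `n_ctx`, `n_ubatch`, and the tensor count are illustrative assumptions chosen only to show how the budget now scales with the micro-batch size instead of the previous fixed `8192u` floor.

```cpp
// Standalone sketch of the new node-budget formula (not part of the patch).
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Mirrors the updated graph_max_nodes() path for LLM_ARCH_QWEN3NEXT:
// the budget is the larger of 40 nodes per ubatch token and 32 per tensor.
static uint32_t graph_max_nodes_qwen3next(uint32_t n_tokens, uint32_t n_tensors) {
    return std::max(n_tokens * 40, 32u * n_tensors);
}

int main() {
    const uint32_t n_ctx     = 4096; // assumed context size
    const uint32_t n_ubatch  = 512;  // assumed micro-batch size
    const uint32_t n_tensors = 1000; // assumed model tensor count

    // Mirrors the constructor: reserve for the smaller of n_ctx and n_ubatch.
    const uint32_t n_tokens = std::min(n_ctx, n_ubatch);

    // 512 * 40 = 20480 vs 32 * 1000 = 32000 -> budget = 32000 here;
    // with n_ubatch = 2048 the token term (81920) would win instead.
    printf("max_nodes = %u\n", graph_max_nodes_qwen3next(n_tokens, n_tensors));
}
```

In effect, small micro-batches still get a floor proportional to the model's tensor count, while large micro-batches grow the reserved graph linearly with `n_tokens` rather than being capped by a hard-coded constant.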