From eab28ed318bc16cd8f967422758fd8ed7c7d50ae Mon Sep 17 00:00:00 2001 From: bluebread Date: Sat, 15 Nov 2025 17:28:18 +0000 Subject: [PATCH 1/3] mtmd: add DeepSeek-OCR LM support with standard attention --- convert_hf_to_gguf.py | 14 ++++++++------ gguf-py/gguf/gguf_writer.py | 2 +- src/llama-arch.cpp | 2 ++ src/llama-model.cpp | 30 ++++++++++++++++++++++++++++++ src/models/deepseek2.cpp | 29 ++++++++++++++++++++++++++++- 5 files changed, 69 insertions(+), 8 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 77fc77e823400..c8a48c01bfb8e 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1494,12 +1494,9 @@ def __init__(self, *args, **kwargs): # FIXME: DeepseekOCRVisionModel specific hack if self.block_count is None: if isinstance(self, DeepseekOCRVisionModel): - print(self.hparams) clip_block_count = self.hparams['layers'] if clip_block_count is not None: self.block_count = clip_block_count - if sam_block_count is not None: - self.block_count = sam_block_count if self.block_count is None else self.block_count + sam_block_count if self.block_count is None: raise KeyError(f"could not find block count using any of: {self.n_block_keys}") self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count) @@ -7095,10 +7092,15 @@ def set_vocab(self): raise NotImplementedError(f"Deepseek pre-tokenizer {tokpre!r} is not supported yet!") def set_gguf_parameters(self): + is_ocr = (self.hparams["num_hidden_layers"] == 12) - # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group) - self.hparams["num_key_value_heads"] = 1 - + if is_ocr: + self.hparams['rope_theta'] = self.hparams.get('rope_theta', 10000.0) + self.hparams['rms_norm_eps'] = self.hparams.get('rms_norm_eps', 1e-6) + else: + # note: deepseek2 using MLA converts into MQA (ie: GQA with 1 group) + self.hparams["num_key_value_heads"] = 1 + super().set_gguf_parameters() hparams = self.hparams kv_lora_rank = hparams["q_lora_rank"] if hparams["q_lora_rank"] is not None else 512 diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index fca498a859dd7..34ecb5e396fec 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -813,7 +813,7 @@ def add_layer_norm_eps(self, value: float) -> None: self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value) def add_layer_norm_rms_eps(self, value: float) -> None: - self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value) + self.add_float64(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value) def add_group_norm_eps(self, value: float) -> None: self.add_float32(Keys.Attention.GROUPNORM_EPS.format(arch=self.arch), value) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index b7642b568dffb..ac3ab5cfa7779 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1446,6 +1446,8 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, { LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" }, { LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" }, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 829f1e3c14f82..a21a3ce619ad5 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4550,6 +4550,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_DEEPSEEK2: { const bool 
is_lite = (hparams.n_layer == 27); + const bool is_ocr = (name.find("ocr") != std::string::npos || name.find("OCR") != std::string::npos); const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); @@ -4575,6 +4576,35 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; + if (is_ocr) { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); // TODO + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); // TODO + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + + continue; + } + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); if (!is_lite) { layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp index 68f72f72bb643..e649286cecfaa 100644 --- a/src/models/deepseek2.cpp +++ b/src/models/deepseek2.cpp @@ -5,6 +5,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { bool is_lite = (hparams.n_layer == 27); + bool is_ocr = (model.name.find("ocr") != std::string::npos || model.name.find("OCR") != std::string::npos); const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); @@ -44,7 +45,33 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr cb(cur, "attn_norm", il); // self_attention - { + if (is_ocr) { + ggml_tensor * Qcur = NULL; + ggml_tensor * Kcur = NULL; + ggml_tensor * Vcur = NULL; + + Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Qcur, "q", il); + cb(Kcur, "k", il); + cb(Vcur, 
"v", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head, n_tokens); + + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "q_pe", il); + cb(Kcur, "k_pe", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + } + else { ggml_tensor * q = NULL; if (!is_lite) { q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); From 76305878d52cb20de142369d348212cee5eb5436 Mon Sep 17 00:00:00 2001 From: bluebread Date: Sun, 16 Nov 2025 08:45:08 +0000 Subject: [PATCH 2/3] mtmd: successfully runs DeepSeek-OCR LM in llama-cli --- convert_hf_to_gguf.py | 15 ++++++++------- gguf-py/gguf/gguf_writer.py | 2 +- src/llama-model.cpp | 17 +++++++++++------ src/models/deepseek2.cpp | 15 +++++++++------ 4 files changed, 29 insertions(+), 20 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index c8a48c01bfb8e..6d07b9acdb75b 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -7112,13 +7112,16 @@ def set_gguf_parameters(self): self.gguf_writer.add_vocab_size(hparams["vocab_size"]) if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None: self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"]) - self.gguf_writer.add_kv_lora_rank(kv_lora_rank) + if "kv_lora_rank" in hparams and hparams["kv_lora_rank"] is not None: + self.gguf_writer.add_kv_lora_rank(kv_lora_rank) # note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA - self.gguf_writer.add_key_length(kv_lora_rank + hparams["qk_rope_head_dim"]) - self.gguf_writer.add_value_length(kv_lora_rank) - self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) - self.gguf_writer.add_value_length_mla(hparams["v_head_dim"]) + if not is_ocr: + self.gguf_writer.add_key_length(kv_lora_rank + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length(kv_lora_rank) + self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_value_length_mla(hparams["v_head_dim"]) + self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) @@ -7133,8 +7136,6 @@ def set_gguf_parameters(self): else: raise ValueError(f"Unsupported scoring_func value: {scoring_func}") - self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) - rope_scaling = self.hparams.get("rope_scaling") or {} if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 34ecb5e396fec..fca498a859dd7 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -813,7 +813,7 @@ def add_layer_norm_eps(self, value: float) -> None: self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value) def add_layer_norm_rms_eps(self, value: float) -> None: - self.add_float64(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value) + 
self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value) def add_group_norm_eps(self, value: float) -> None: self.add_float32(Keys.Attention.GROUPNORM_EPS.format(arch=self.arch), value) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a21a3ce619ad5..79639c515ebe4 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1562,12 +1562,16 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_DEEPSEEK2: { bool is_lite = (hparams.n_layer == 27); + bool is_ocr = (name.find("ocr") != std::string::npos || name.find("OCR") != std::string::npos); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); - if (!is_lite) { + if (!is_lite && !is_ocr) { ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); } - ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + if (!is_ocr) { + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + } ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla, false); ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); @@ -1583,6 +1587,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false); switch (hparams.n_layer) { + case 12: type = LLM_TYPE_3B; break; case 27: type = LLM_TYPE_16B; break; case 60: type = LLM_TYPE_236B; break; case 61: type = LLM_TYPE_671B; break; @@ -4578,10 +4583,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (is_ocr) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); // TODO - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); // TODO - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); if (i < (int) hparams.n_layer_dense_lead) { diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp index e649286cecfaa..375f3594541bb 100644 --- a/src/models/deepseek2.cpp +++ b/src/models/deepseek2.cpp @@ -46,6 +46,9 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr // self_attention if (is_ocr) { + const int n_embed_head = hparams.n_embd / hparams.n_head(); + GGML_ASSERT(n_embed_head == n_embd_head_k && n_embed_head == n_embd_head_v); + ggml_tensor * Qcur = NULL; ggml_tensor * Kcur = NULL; ggml_tensor * Vcur = NULL; @@ -57,13 +60,13 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr cb(Kcur, "k", il); cb(Vcur, "v", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head, n_tokens); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embed_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, 
Kcur, n_embed_head, n_head, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embed_head, n_head, n_tokens); - Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); - Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); + GGML_ASSERT(fabs(freq_base - 10000.0) < 1e-4); + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_embed_head, rope_type, 0, freq_base, 1, 0, 1, 0, 0); + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_embed_head, rope_type, 0, freq_base, 1, 0, 1, 0, 0); cb(Qcur, "q_pe", il); cb(Kcur, "k_pe", il); From 2de3436705a853c815daf5c2bab5dcae18ee47c1 Mon Sep 17 00:00:00 2001 From: bluebread Date: Mon, 17 Nov 2025 08:44:29 +0000 Subject: [PATCH 3/3] mtmd: Fix RoPE type for DeepSeek-OCR LM. --- examples/eval-callback/eval-callback.cpp | 18 +++++++++--------- src/models/deepseek2.cpp | 5 +++-- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index cefa39a57c886..ed181a1ab4500 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -74,19 +74,19 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne } } for (int64_t i3 = 0; i3 < ne[3]; i3++) { - LOG(" [\n"); + LOG(" [\n"); for (int64_t i2 = 0; i2 < ne[2]; i2++) { if (i2 == n && ne[2] > 2*n) { - LOG(" ..., \n"); + LOG(" ..., \n"); i2 = ne[2] - n; } - LOG(" [\n"); + LOG(" [\n"); for (int64_t i1 = 0; i1 < ne[1]; i1++) { if (i1 == n && ne[1] > 2*n) { - LOG(" ..., \n"); + LOG(" ..., \n"); i1 = ne[1] - n; } - LOG(" ["); + LOG(" ["); for (int64_t i0 = 0; i0 < ne[0]; i0++) { if (i0 == n && ne[0] > 2*n) { LOG("..., "); @@ -98,10 +98,10 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne } LOG("],\n"); } - LOG(" ],\n"); + LOG(" ],\n"); } - LOG(" ]\n"); - LOG(" sum = %f\n", sum); + LOG(" ]\n"); + LOG(" sum = %f\n", sum); } // TODO: make this abort configurable/optional? @@ -136,7 +136,7 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) { snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, ggml_ne_string(src1).c_str()); } - LOG("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__, + LOG("%s: %16s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__, t->name, ggml_type_name(t->type), ggml_op_desc(t), src0->name, ggml_ne_string(src0).c_str(), src1 ? 
src1_str : "", diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp index 375f3594541bb..bc1b2127acd96 100644 --- a/src/models/deepseek2.cpp +++ b/src/models/deepseek2.cpp @@ -47,6 +47,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr // self_attention if (is_ocr) { const int n_embed_head = hparams.n_embd / hparams.n_head(); + const int ocr_rope_type = GGML_ROPE_TYPE_NEOX; GGML_ASSERT(n_embed_head == n_embd_head_k && n_embed_head == n_embd_head_v); ggml_tensor * Qcur = NULL; @@ -65,8 +66,8 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr Vcur = ggml_reshape_3d(ctx0, Vcur, n_embed_head, n_head, n_tokens); GGML_ASSERT(fabs(freq_base - 10000.0) < 1e-4); - Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_embed_head, rope_type, 0, freq_base, 1, 0, 1, 0, 0); - Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_embed_head, rope_type, 0, freq_base, 1, 0, 1, 0, 0); + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_embed_head, ocr_rope_type, 0, freq_base, 1, 0, 1, 0, 0); + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_embed_head, ocr_rope_type, 0, freq_base, 1, 0, 1, 0, 0); cb(Qcur, "q_pe", il); cb(Kcur, "k_pe", il);