From 6c0715befcab53eab7fb03cd82437c715729297e Mon Sep 17 00:00:00 2001 From: bluebread Date: Tue, 18 Nov 2025 06:19:38 +0000 Subject: [PATCH 1/6] fix: update callback for ffn_moe_weighted and add callback for attn_out in deepseek2 model --- src/llama-graph.cpp | 2 +- src/models/deepseek2.cpp | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index b199e94628fff..4daf3f230b575 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1106,7 +1106,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( if (!weight_before_ffn) { experts = ggml_mul(ctx0, experts, weights); - cb(cur, "ffn_moe_weighted", il); + cb(experts, "ffn_moe_weighted", il); } ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr }; diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp index bc1b2127acd96..f4a40d7d6e805 100644 --- a/src/models/deepseek2.cpp +++ b/src/models/deepseek2.cpp @@ -74,6 +74,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); } else { ggml_tensor * q = NULL; From 8bce66d5f2a76e4e638f07b40769fdc8a248ad7d Mon Sep 17 00:00:00 2001 From: bluebread Date: Fri, 21 Nov 2025 15:28:37 +0000 Subject: [PATCH 2/6] clip: fixed warnings --- tools/mtmd/clip.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 40b60cbfd5da6..eb3d461dac15a 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -667,9 +667,9 @@ struct clip_graph { constexpr int _depth = 12; constexpr int enc_n_heads = 12; constexpr int enc_d_heads = enc_n_embd / enc_n_heads; - constexpr int _prompt_n_embd = 256; + // constexpr int _prompt_n_embd = 256; constexpr int enc_patch_size = 16; - constexpr int _window_size = 14; + // constexpr int _window_size = 14; const int enc_n_patches = enc_image_size / enc_patch_size; // 64 @@ -834,7 +834,7 @@ struct clip_graph { ggml_tensor * global_features_1 = build_sam_enc(inp_raw, std::max(img.nx, img.ny)); - ggml_tensor * global_features_2 = build_dp_ocr_clip(inp_raw, global_features_1); + ggml_tensor * global_features_2 = build_dp_ocr_clip(global_features_1); // torch global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1) global_features_1 = ggml_permute(ctx0, global_features_1,2,1,0,3); @@ -1532,7 +1532,7 @@ struct clip_graph { return gf; } - ggml_tensor * build_dp_ocr_clip(ggml_tensor * inpL, ggml_tensor * patch_embeds) { + ggml_tensor * build_dp_ocr_clip(ggml_tensor * patch_embeds) { GGML_ASSERT(model.class_embedding != nullptr); GGML_ASSERT(model.position_embeddings != nullptr); @@ -2466,6 +2466,8 @@ struct clip_graph { return inpL; } + // Implementation based on approach suggested by Acly + // See: https://github.com/ggml-org/llama.cpp/pull/17383#issuecomment-3554227091 static ggml_tensor* window_partition(ggml_context* ctx, ggml_tensor* x, int window) { auto [c, w, h, b] = x->ne; // same as @@ -2486,6 +2488,8 @@ struct clip_graph { return x; } + // Implementation based on approach suggested by Acly + // See: https://github.com/ggml-org/llama.cpp/pull/17383#issuecomment-3554227091 static ggml_tensor* window_unpartition(ggml_context* m, ggml_tensor* x, int w, int h, int window) { int64_t c = x->ne[0]; // same as @@ -4881,7 +4885,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str const int 
min_num = 2; const int max_num = 9; const int image_size = params.image_size; // typically 640 - const bool use_thumbnail = true; // mimic python's use_thumbnail + // const bool use_thumbnail = true; // mimic python's use_thumbnail // original image size const int orig_w = original_size.width; From 5e6cf3c6a838af32c3debce73425d199477b7669 Mon Sep 17 00:00:00 2001 From: bluebread Date: Fri, 21 Nov 2025 15:36:45 +0000 Subject: [PATCH 3/6] Merge branch 'sf/deepseek-ocr' of github.com:sfallah/llama.cpp into sf/deepseek-ocr --- tools/mtmd/clip.cpp | 109 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 103 insertions(+), 6 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index eb3d461dac15a..45cc2328c8d55 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -734,8 +734,8 @@ struct clip_graph { struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf(enc_d_heads)); - struct ggml_tensor * rw = ggml_get_rel_pos(ctx0, layer.rel_pos_w, W, W); - struct ggml_tensor * rh = ggml_get_rel_pos(ctx0, layer.rel_pos_h, H, H); + struct ggml_tensor * rw = get_rel_pos(ctx0, layer.rel_pos_w, W, W); + struct ggml_tensor * rh = get_rel_pos(ctx0, layer.rel_pos_h, H, H); struct ggml_tensor * q_r = ggml_reshape_4d(ctx0, Qcur, enc_d_heads, W, H, B * enc_n_heads); @@ -745,7 +745,7 @@ struct clip_graph { 2, 1, 3)); struct ggml_tensor * rel_h = ggml_mul_mat(ctx0, rh, q_r); - struct ggml_tensor * attn = ggml_add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h); + struct ggml_tensor * attn = add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h, W); struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, attn); @@ -837,9 +837,9 @@ struct clip_graph { ggml_tensor * global_features_2 = build_dp_ocr_clip(global_features_1); // torch global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1) - global_features_1 = ggml_permute(ctx0, global_features_1,2,1,0,3); - global_features_1 = ggml_cont(ctx0, global_features_1); + global_features_1 = ggml_cont(ctx0,ggml_permute(ctx0, global_features_1,2,1,0,3)); global_features_1 = ggml_reshape_2d(ctx0, global_features_1, n_embd, n_patches); + // remove CLS token global_features_2 = ggml_view_2d(ctx0, global_features_2, n_embd, n_patches, @@ -850,6 +850,7 @@ struct clip_graph { global_features = ggml_cont(ctx0, global_features); global_features = ggml_mul_mat(ctx0, model.fc_w, global_features); global_features = ggml_add(ctx0, global_features, model.fc_b); + global_features = build_global_local_features(ctx0,global_features); ggml_build_forward_expand(gf, global_features); return gf; @@ -869,7 +870,6 @@ struct clip_graph { t = ggml_cont(ctx0, ggml_permute(ctx0, t, 2, 1, 0, 3)); // (h, w, n_dim) ggml_tensor * nl = ggml_cont(ctx0,ggml_permute(ctx0, model.image_newline, 2, 1, 0, 3)); nl = ggml_repeat_4d(ctx0, nl, 64, 1, 1280, 1); // n_pos rows - nl = ggml_cont(ctx0, nl); // 2) image_newline: [n_dim] -> [1, 1, n_dim] -> repeat to [h, 1, n_dim] @@ -2466,6 +2466,103 @@ struct clip_graph { return inpL; } + // attn: [k_h*k_w, q_h*q_w] +// rel_h: [q_h, q_w, k_h] +// rel_w: [q_h, q_w, k_w] + +static ggml_tensor * add_rel_pos_inplace( + ggml_context * ctx, + ggml_tensor * attn, + ggml_tensor * rel_w, + ggml_tensor * rel_h, + int q_size +) { + + ggml_tensor *attn_4d = + ggml_reshape_4d(ctx, attn, q_size,q_size, attn->ne[1], attn->ne[2]); + + ggml_tensor *rel_h_4d = + ggml_reshape_4d(ctx, rel_h, 1, q_size, attn->ne[1], attn->ne[2]); + + ggml_tensor *rel_h_rep = ggml_repeat(ctx, rel_h_4d, 
attn_4d); // now same shape as attn_4d
+
+ ggml_tensor *rel_w_4d =
+ ggml_reshape_4d(ctx, rel_w, q_size, 1, attn->ne[1], attn->ne[2]);
+
+ ggml_tensor *rel_w_rep = ggml_repeat(ctx, rel_w_4d, attn_4d); // now same shape as attn_4d
+
+ ggml_tensor * result = ggml_add(ctx, attn_4d, ggml_add(ctx, rel_h_rep, rel_w_rep));
+ result = ggml_reshape_3d(ctx, result, attn->ne[0], attn->ne[1], attn->ne[2]);
+
+
+ return result;
+}
+
+
+static ggml_tensor * get_rel_pos(
+ ggml_context * ctx,
+ ggml_tensor * rel_pos, // [L, C]
+ int q_size,
+ int k_size
+) {
+
+ const auto dtype = rel_pos->type;
+
+ const int64_t L = rel_pos->ne[0]; // length
+ const int64_t C = rel_pos->ne[1]; // channels
+
+ // -------------------------------------------------
+ // 1) q_idx ← arange(0..q_size-1) [q_size]
+ // 2) k_idx ← arange(0..k_size-1) [k_size]
+ // -------------------------------------------------
+
+
+ ggml_tensor * q_coord = ggml_cast(ctx,
+ ggml_arange(ctx, 0.0f, static_cast<float>(q_size), 1.0f),
+ GGML_TYPE_F32); // [q_size]
+ ggml_tensor * k_coord = ggml_cast(ctx,
+ ggml_arange(ctx, 0.0f, static_cast<float>(k_size), 1.0f),
+ GGML_TYPE_F32); // [k_size]
+
+ ggml_tensor * rel = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, q_size, k_size);
+ q_coord = ggml_cont(ctx,ggml_repeat(ctx, q_coord, rel)); // [q_size, k_size]
+
+ // broadcast reshape:
+ k_coord = ggml_reshape_2d(ctx, k_coord, 1, k_size); // [1, k_size]
+ k_coord = ggml_cont(ctx,ggml_repeat(ctx, k_coord, rel)); // [q_size, k_size]
+
+ // -------------------------------------------------
+ // relative_coords = q - k + (k_size - 1) // SAME as PyTorch when no scaling
+ // -------------------------------------------------
+ rel = ggml_sub(ctx, k_coord, q_coord); // [q_size, k_size]
+
+ rel = ggml_scale_bias(ctx, rel, 1.0f, static_cast<float>(k_size) - 1.0f); // [q_size, k_size]
+
+ // -------------------------------------------------
+ // clamp to [0, L-1] and cast to int32 (for ggml_get_rows)
+ // -------------------------------------------------
+
+ ggml_tensor * rel_clamped = ggml_clamp(ctx, rel, 0, static_cast<float>(L - 1));
+
+ ggml_tensor * idx_2d = ggml_cast(ctx, rel_clamped, GGML_TYPE_I32); // [q_size, k_size]
+
+ // flatten to 1D for ggml_get_rows
+ const int64_t qk = static_cast<int64_t>(q_size) * static_cast<int64_t>(k_size);
+ ggml_tensor * idx_flat = ggml_reshape_1d(ctx, idx_2d, qk); // [qk]
+
+ // -------------------------------------------------
+ // Gather from rel_pos → [qk, C]
+ // -------------------------------------------------
+ ggml_tensor * gathered = ggml_get_rows(ctx, rel_pos, idx_flat); // [qk, C]
+
+ // reshape to final output → [q_size, k_size, C]
+ ggml_tensor * out = ggml_reshape_3d(ctx, gathered,rel_pos->ne[0],
+ q_size,
+ k_size);
+
+ return out; // [q_size, k_size, C]
+}
 // Implementation based on approach suggested by Acly
 // See: https://github.com/ggml-org/llama.cpp/pull/17383#issuecomment-3554227091
 static ggml_tensor* window_partition(ggml_context* ctx, ggml_tensor* x, int window) {

From 7e9fbeccc5c28a8464ace5e4e22dfef213cb3c66 Mon Sep 17 00:00:00 2001
From: bluebread
Date: Fri, 21 Nov 2025 17:12:12 +0000
Subject: [PATCH 4/6] mtmd: fix get_rel_pos

---
 tools/mtmd/clip.cpp | 174 +++++++++++++++++++++++---------------------
 1 file changed, 90 insertions(+), 84 deletions(-)

diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index 45cc2328c8d55..a4bf717d0b083 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -2467,101 +2467,107 @@ struct clip_graph {
 }
 
 // attn: [k_h*k_w, q_h*q_w]
-// rel_h: [q_h, q_w, k_h]
-// rel_w: [q_h, q_w, k_w]
-
-static ggml_tensor * add_rel_pos_inplace(
- ggml_context * ctx,
- ggml_tensor * attn,
- ggml_tensor * rel_w,
- ggml_tensor * rel_h,
- int q_size
-) {
-
- ggml_tensor *attn_4d =
- ggml_reshape_4d(ctx, attn, q_size,q_size, attn->ne[1], attn->ne[2]);
-
- ggml_tensor *rel_h_4d =
- ggml_reshape_4d(ctx, rel_h, 1, q_size, attn->ne[1], attn->ne[2]);
-
- ggml_tensor *rel_h_rep = ggml_repeat(ctx, rel_h_4d, attn_4d); // now same shape as attn_4d
-
- ggml_tensor *rel_w_4d =
- ggml_reshape_4d(ctx, rel_w, q_size, 1, attn->ne[1], attn->ne[2]);
-
- ggml_tensor *rel_w_rep = ggml_repeat(ctx, rel_w_4d, attn_4d); // now same shape as attn_4d
-
- ggml_tensor * result = ggml_add(ctx, attn_4d, ggml_add(ctx, rel_h_rep, rel_w_rep));
- result = ggml_reshape_3d(ctx, result, attn->ne[0], attn->ne[1], attn->ne[2]);
-
-
- return result;
-}
-
-
-static ggml_tensor * get_rel_pos(
- ggml_context * ctx,
- ggml_tensor * rel_pos, // [L, C]
- int q_size,
- int k_size
-) {
-
- const auto dtype = rel_pos->type;
-
- const int64_t L = rel_pos->ne[0]; // length
- const int64_t C = rel_pos->ne[1]; // channels
-
- // -------------------------------------------------
- // 1) q_idx ← arange(0..q_size-1) [q_size]
- // 2) k_idx ← arange(0..k_size-1) [k_size]
- // -------------------------------------------------
-
-
- ggml_tensor * q_coord = ggml_cast(ctx,
- ggml_arange(ctx, 0.0f, static_cast<float>(q_size), 1.0f),
- GGML_TYPE_F32); // [q_size]
- ggml_tensor * k_coord = ggml_cast(ctx,
- ggml_arange(ctx, 0.0f, static_cast<float>(k_size), 1.0f),
- GGML_TYPE_F32); // [k_size]
+ // rel_h: [q_h, q_w, k_h]
+ // rel_w: [q_h, q_w, k_w]
+
+ static ggml_tensor * add_rel_pos_inplace(
+ ggml_context * ctx,
+ ggml_tensor * attn,
+ ggml_tensor * rel_w,
+ ggml_tensor * rel_h,
+ int q_size
+ ) {
 
- ggml_tensor * rel = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, q_size, k_size);
- q_coord = ggml_cont(ctx,ggml_repeat(ctx, q_coord, rel)); // [q_size, k_size]
+ ggml_tensor *attn_4d =
+ ggml_reshape_4d(ctx, attn, q_size,q_size, attn->ne[1], attn->ne[2]);
 
- // broadcast reshape:
- k_coord = ggml_reshape_2d(ctx, k_coord, 1, k_size); // [1, k_size]
- k_coord = ggml_cont(ctx,ggml_repeat(ctx, k_coord, rel)); // [q_size, k_size]
+ ggml_tensor *rel_h_4d =
+ ggml_reshape_4d(ctx, rel_h, 1, q_size, attn->ne[1], attn->ne[2]);
 
- // -------------------------------------------------
- // relative_coords = q - k + (k_size - 1) // SAME as PyTorch when no scaling
- // -------------------------------------------------
- rel = ggml_sub(ctx, k_coord, q_coord); // [q_size, k_size]
+ ggml_tensor *rel_h_rep = ggml_repeat(ctx, rel_h_4d, attn_4d); // now same shape as attn_4d
 
- rel = ggml_scale_bias(ctx, rel, 1.0f, static_cast<float>(k_size) - 1.0f); // [q_size, k_size]
+ ggml_tensor *rel_w_4d =
+ ggml_reshape_4d(ctx, rel_w, q_size, 1, attn->ne[1], attn->ne[2]);
 
- // -------------------------------------------------
- // clamp to [0, L-1] and cast to int32 (for ggml_get_rows)
- // -------------------------------------------------
+ ggml_tensor *rel_w_rep = ggml_repeat(ctx, rel_w_4d, attn_4d); // now same shape as attn_4d
 
- ggml_tensor * rel_clamped = ggml_clamp(ctx, rel, 0, static_cast<float>(L - 1));
+ ggml_tensor * result = ggml_add(ctx, attn_4d, ggml_add(ctx, rel_h_rep, rel_w_rep));
+ result = ggml_reshape_3d(ctx, result, attn->ne[0], attn->ne[1], attn->ne[2]);
 
- ggml_tensor * idx_2d = ggml_cast(ctx, rel_clamped, GGML_TYPE_I32); // [q_size, k_size]
+
- // flatten to 1D for ggml_get_rows
- const int64_t qk = static_cast<int64_t>(q_size) * static_cast<int64_t>(k_size);
- ggml_tensor * idx_flat = ggml_reshape_1d(ctx, idx_2d, qk); // [qk]
+ return result;
+ }
 
- // -------------------------------------------------
- // Gather from rel_pos → [qk, C]
- // -------------------------------------------------
- ggml_tensor * gathered = ggml_get_rows(ctx, rel_pos, idx_flat); // [qk, C]
-
- // reshape to final output → [q_size, k_size, C]
- ggml_tensor * out = ggml_reshape_3d(ctx, gathered,rel_pos->ne[0],
- q_size,
- k_size);
-
- return out; // [q_size, k_size, C]
-}
+
+ static ggml_tensor * get_rel_pos(
+ ggml_context * ctx,
+ ggml_tensor * rel_pos, // [L, C]
+ int q_size,
+ int k_size
+ ) {
+ const int64_t C = rel_pos->ne[0]; // channels
+ const int64_t L = rel_pos->ne[1]; // length
+
+ GGML_ASSERT(2*std::max(q_size, k_size) - 1 == L);
+
+ // -------------------------------------------------
+ // 1) q_idx ← arange(0..q_size-1) [q_size]
+ // 2) k_idx ← arange(0..k_size-1) [k_size]
+ // -------------------------------------------------
+
+ // ggml_arange always returns FP32 tensor
+ ggml_tensor * q_coord = ggml_arange(ctx, 0.0f, static_cast<float>(q_size), 1.0f); // [q_size]
+ ggml_tensor * k_coord = ggml_arange(ctx, 0.0f, static_cast<float>(k_size), 1.0f); // [k_size]
+ ggml_tensor * rel = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, k_size, q_size);
+
+ // broadcast reshape:
+ q_coord = ggml_cont(ctx,
+ ggml_repeat(ctx,
+ ggml_reshape_2d(ctx, q_coord, 1, q_size), // [q_size, 1]
+ rel
+ )
+ ); // [q_size, k_size]
+ k_coord = ggml_cont(ctx, ggml_repeat(ctx, k_coord, rel)); // [q_size, k_size]
+
+ // This wouldn't be triggered in DeepSeek-OCR. Just for compatibility with
+ // the original implementation.
+ if (q_size != k_size) {
+ q_coord = ggml_scale_inplace(ctx, q_coord, std::max((float)k_size/q_size, 1.0f));
+ k_coord = ggml_scale_inplace(ctx, k_coord, std::max((float)q_size/k_size, 1.0f));
+ }
+
+ // -------------------------------------------------
+ // relative_coords = q - k + (k_size - 1) // SAME as PyTorch when no scaling
+ // -------------------------------------------------
+
+ rel = ggml_sub(ctx, q_coord, k_coord); // [q_size, k_size]
+ rel = ggml_scale_bias(ctx, rel, 1.0f, static_cast<float>(k_size) - 1.0f); // [q_size, k_size]
+ // Clamp to [0, L-1] range for valid indexing
+ rel = ggml_clamp(ctx, rel, 0.0f, static_cast<float>(rel_pos->ne[1] - 1));
+
+ // -------------------------------------------------
+ // cast to int32 (for ggml_get_rows)
+ // -------------------------------------------------
+
+ ggml_tensor * idx_2d = ggml_cast(ctx, rel, GGML_TYPE_I32); // [q_size, k_size]
+
+ // -------------------------------------------------
+ // Gather from rel_pos → [qk, C]
+ // -------------------------------------------------
+
+ // flatten to 1D for ggml_get_rows
+ int qk = q_size * k_size;
+ ggml_tensor * idx_flat = ggml_reshape_1d(ctx, idx_2d, qk); // [qk]
+ ggml_tensor * gathered = ggml_get_rows(ctx, rel_pos, idx_flat); // [qk, C]
+
+ // -------------------------------------------------
+ // reshape to final output → [q_size, k_size, C]
+ // -------------------------------------------------
+
+ ggml_tensor * out = ggml_reshape_3d(ctx, gathered, C, k_size, q_size); // [C, k_size, q_size]
+
+ return out; // [q_size, k_size, C]
+ }
 // Implementation based on approach suggested by Acly
 // See: https://github.com/ggml-org/llama.cpp/pull/17383#issuecomment-3554227091

From 7b8d735c901666d91f211f380ca2edc625fd72c1 Mon Sep 17 00:00:00 2001
From: bluebread
Date: Fri, 21 Nov 2025 18:04:01 +0000
Subject: [PATCH 5/6] mtmd: fixed the wrong scale factor for get_rel_pos

---
 tools/mtmd/clip.cpp | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index a4bf717d0b083..f291894b6eab7 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -2529,11 +2529,14 @@ struct clip_graph {
 ); // [q_size, k_size]
 k_coord = ggml_cont(ctx, ggml_repeat(ctx, k_coord, rel)); // [q_size, k_size]
 
+ float q_scale = std::max((float)k_size/q_size, 1.0f);
+ float k_scale = std::max((float)q_size/k_size, 1.0f);
+
 // This wouldn't be triggered in DeepSeek-OCR. Just for compatibility with
 // the original implementation.
 if (q_size != k_size) {
- q_coord = ggml_scale_inplace(ctx, q_coord, std::max((float)k_size/q_size, 1.0f));
- k_coord = ggml_scale_inplace(ctx, k_coord, std::max((float)q_size/k_size, 1.0f));
+ q_coord = ggml_scale_inplace(ctx, q_coord, q_scale);
+ k_coord = ggml_scale_inplace(ctx, k_coord, k_scale);
 }
 
@@ -2541,7 +2544,7 @@ struct clip_graph {
 // -------------------------------------------------
 
 rel = ggml_sub(ctx, q_coord, k_coord); // [q_size, k_size]
- rel = ggml_scale_bias(ctx, rel, 1.0f, static_cast<float>(k_size) - 1.0f); // [q_size, k_size]
+ rel = ggml_scale_bias(ctx, rel, 1.0f, (k_size - 1.0f)*k_scale); // [q_size, k_size]
 // Clamp to [0, L-1] range for valid indexing
 rel = ggml_clamp(ctx, rel, 0.0f, static_cast<float>(rel_pos->ne[1] - 1));
 

From effe66958e25d860ffb12715e00ff313d821b248 Mon Sep 17 00:00:00 2001
From: bluebread
Date: Sat, 22 Nov 2025 02:09:37 +0000
Subject: [PATCH 6/6] mtmd: minor changes

---
 tools/mtmd/clip.cpp | 36 +++++++++++++++++++++---------------
 1 file changed, 21 insertions(+), 15 deletions(-)

diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index f291894b6eab7..23d86f9575176 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -739,13 +739,14 @@ struct clip_graph {
 
 struct ggml_tensor * q_r = ggml_reshape_4d(ctx0, Qcur, enc_d_heads, W, H, B * enc_n_heads);
 
- struct ggml_tensor * rel_w = ggml_cont(
- ctx0,
- ggml_permute(ctx0, ggml_mul_mat(ctx0, rw, ggml_cont(ctx0, ggml_permute(ctx0, q_r, 0, 2, 1, 3))), 0,
- 2, 1, 3));
+ struct ggml_tensor * rel_w = ggml_cont(ctx0,ggml_permute(ctx0,
+ ggml_mul_mat(ctx0,
+ rw,
+ ggml_cont(ctx0, ggml_permute(ctx0, q_r, 0, 2, 1, 3))),
+ 0, 2, 1, 3));
 struct ggml_tensor * rel_h = ggml_mul_mat(ctx0, rh, q_r);
 
- struct ggml_tensor * attn = add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h, W);
+ struct ggml_tensor * attn = add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h);
 
 struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, attn);
 
@@ -2466,7 +2467,7 @@ struct clip_graph {
 return inpL;
 }
 
- // attn: [k_h*k_w, q_h*q_w]
+ // attn: [q_h*q_w, k_h*k_w]
 // rel_h: [q_h, q_w, k_h]
 // rel_w: [q_h, q_w, k_w]
 
@@ -2474,24 +2475,29 @@ struct clip_graph {
 static ggml_tensor * add_rel_pos_inplace(
 ggml_context * ctx,
 ggml_tensor * attn,
 ggml_tensor * rel_w,
- ggml_tensor * rel_h,
- int q_size
+ ggml_tensor * rel_h
 ) {
+ const int k_w = rel_w->ne[0];
+ const int k_h = rel_h->ne[0];
+ const int q_w = rel_h->ne[1];
+ const int q_h = rel_h->ne[2];
+
+ GGML_ASSERT(q_w == rel_w->ne[1]);
+ GGML_ASSERT(q_h == rel_w->ne[2]);
+ GGML_ASSERT(attn->ne[0] == k_h*k_w);
+ GGML_ASSERT(attn->ne[1] == q_h*q_w);
 
- ggml_tensor *attn_4d =
- ggml_reshape_4d(ctx, attn, q_size,q_size, attn->ne[1], attn->ne[2]);
+ ggml_tensor *attn_4d = ggml_reshape_4d(ctx, attn, k_w, k_h, attn->ne[1], attn->ne[2]);
 
- ggml_tensor *rel_h_4d =
- ggml_reshape_4d(ctx, rel_h, 1, q_size, attn->ne[1], attn->ne[2]);
+ ggml_tensor *rel_h_4d = ggml_reshape_4d(ctx, rel_h, 1, k_h, attn->ne[1], attn->ne[2]);
 
 ggml_tensor *rel_h_rep = ggml_repeat(ctx, rel_h_4d, attn_4d); // now same shape as attn_4d
 
- ggml_tensor *rel_w_4d =
- ggml_reshape_4d(ctx, rel_w, q_size, 1, attn->ne[1], attn->ne[2]);
+ ggml_tensor *rel_w_4d = ggml_reshape_4d(ctx, rel_w, k_w, 1, attn->ne[1], attn->ne[2]);
 
 ggml_tensor *rel_w_rep = ggml_repeat(ctx, rel_w_4d, attn_4d); // now same shape as attn_4d
 
- ggml_tensor * result = ggml_add(ctx, attn_4d, ggml_add(ctx, rel_h_rep, rel_w_rep));
+ ggml_tensor * result = ggml_add_inplace(ctx, attn_4d, ggml_add_inplace(ctx, rel_h_rep, rel_w_rep));
 result = ggml_reshape_3d(ctx, result, attn->ne[0], attn->ne[1], attn->ne[2]);
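
For reference, the relative-position lookup that patches 4 and 5 converge on follows the SAM-style decomposed scheme: for query index q and key index k, a learned table of length L = 2*max(q_size, k_size) - 1 is indexed by q*q_scale - k*k_scale + (k_size - 1)*k_scale, clamped to [0, L-1], and the gathered rel_h/rel_w terms are then broadcast-added onto the attention logits. The standalone sketch below checks that index arithmetic on plain integers; it is an illustration only, mirroring the PyTorch reference semantics, and the helper name rel_pos_index is hypothetical, not part of the patch.

// Minimal sketch of the decomposed relative-position index math
// (assumed to match the PyTorch reference that get_rel_pos mirrors).
#include <algorithm>
#include <cassert>
#include <cstdio>

static int rel_pos_index(int q, int k, int q_size, int k_size) {
    const int   L       = 2 * std::max(q_size, k_size) - 1;
    // Scales compensate for mismatched q/k resolutions; both are 1.0
    // when q_size == k_size (the only case DeepSeek-OCR exercises).
    const float q_scale = std::max((float) k_size / q_size, 1.0f);
    const float k_scale = std::max((float) q_size / k_size, 1.0f);
    const float rel     = q * q_scale - k * k_scale + (k_size - 1) * k_scale;
    // Clamp and truncate, analogous to ggml_clamp + the F32 -> I32 cast.
    return std::clamp((int) rel, 0, L - 1);
}

int main() {
    const int q_size = 4, k_size = 4;
    for (int q = 0; q < q_size; ++q) {
        for (int k = 0; k < k_size; ++k) {
            // Equal sizes reduce to q - k + (k_size - 1): one shared
            // table row per relative offset.
            assert(rel_pos_index(q, k, q_size, k_size) == q - k + (k_size - 1));
            printf("%d ", rel_pos_index(q, k, q_size, k_size));
        }
        printf("\n");
    }
    return 0;
}

With q_size == k_size both scales drop out and the index reduces to q - k + (k_size - 1), which is what the graph builds via ggml_sub followed by ggml_scale_bias before the ggml_get_rows gather; add_rel_pos_inplace then reshapes rel_h to [1, k_h, ...] and rel_w to [k_w, 1, ...] so a broadcast ggml_repeat plus add applies both terms to every attention logit.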