From 5381b9cf635c88961845b0776041bcbca561e2f4 Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Thu, 4 Dec 2025 23:13:29 +0100 Subject: [PATCH 1/3] using common build_attn in sam --- tools/mtmd/clip.cpp | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index d1bed23d030..b9bcfafa1c1 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -2789,15 +2789,12 @@ struct clip_graph { Q = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 0*cur->nb[1]); Q = ggml_reshape_4d(ctx0, ggml_cont(ctx0, Q), d_heads, n_heads, W*H, B); - Q = ggml_cont (ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); // [B, n_heads, H*W, d_heads] K = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 1*cur->nb[1]); K = ggml_reshape_4d(ctx0, ggml_cont(ctx0, K), d_heads, n_heads, W*H, B); - K = ggml_cont (ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); // [B, n_heads, H*W, d_heads] V = ggml_view_3d (ctx0, cur, n_embd, W*H, B, cur->nb[2], cur->nb[3], 2*cur->nb[1]); V = ggml_reshape_4d(ctx0, ggml_cont(ctx0, V), d_heads, n_heads, W*H, B); - V = ggml_cont (ctx0, ggml_permute(ctx0, V, 0, 2, 1, 3)); // [B, n_heads, H*W, d_heads] ggml_tensor * mask; ggml_tensor * rw; @@ -2806,7 +2803,7 @@ struct clip_graph { rw = get_rel_pos(ctx0, layer.rel_pos_w, W, W); // [W, W, C] rh = get_rel_pos(ctx0, layer.rel_pos_h, H, H); // [H, H, C] - qr = ggml_reshape_4d(ctx0, Q, d_heads, W, H, B*n_heads); + qr = ggml_reshape_4d(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)), d_heads, W, H, B*n_heads); const int WH_pad = GGML_PAD(W*H, GGML_KQ_MASK_PAD) - W*H; @@ -2822,11 +2819,18 @@ struct clip_graph { mask = ggml_cast (ctx0, mask, GGML_TYPE_F16); float scale = 1.0f / sqrtf((float)d_heads); - cur = ggml_flash_attn_ext(ctx0, Q, K, V, mask, scale, 0.0f, 0.0f); // [B, H*W, n_heads, d_heads] + cur = build_attn( + layer.o_w, + layer.o_b, + Q, + K, + V, + mask, + scale, + il + ); // [B, H*W, n_embd] cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), n_embd, W, H, B); - cur = ggml_mul_mat(ctx0, layer.o_w, cur); - cur = ggml_add_inplace(ctx0, cur, layer.o_b); } if (hparams.is_global_attn(il) == false) { From 076138a428512a977935539d570b0e90ae4d990e Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Thu, 4 Dec 2025 23:45:59 +0100 Subject: [PATCH 2/3] corrected code-branch when flash-attn disabled enabling usage of --flash-attn option --- tools/mtmd/clip.cpp | 10 ++++------ tools/mtmd/mtmd.cpp | 2 +- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index b9bcfafa1c1..2cd72b8872b 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -2590,10 +2590,7 @@ struct clip_graph { } else { ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3); v = ggml_cont(ctx0, v); - - const auto n_tokens = q->ne[1]; - const auto n_head = q->ne[2]; - + ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); // F32 may not needed for vision encoders? // ggml_mul_mat_set_prec(kq, GGML_PREC_F32); @@ -2601,8 +2598,9 @@ struct clip_graph { kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f); ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); - cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3); - cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); + cur = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3)); + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]); + } cb(cur, "kqv_out", il); diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index 2c20af099b9..791ac771668 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -175,7 +175,7 @@ struct mtmd_context { clip_context_params ctx_clip_params { /* use_gpu */ ctx_params.use_gpu, - /* flash_attn_type */ CLIP_FLASH_ATTN_TYPE_AUTO, + /* flash_attn_type */ mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type), /* image_min_tokens */ ctx_params.image_min_tokens, /* image_max_tokens */ ctx_params.image_max_tokens, /* warmup */ ctx_params.warmup, From f5bd310a5ee228ce7cd9487d13e316e0d437b591 Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Fri, 5 Dec 2025 09:30:58 +0100 Subject: [PATCH 3/3] minor formatting and style --- tools/mtmd/clip.cpp | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 2cd72b8872b..af03a8fe2e2 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -2598,8 +2598,8 @@ struct clip_graph { kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f); ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); - cur = ggml_cont(ctx0, ggml_permute(ctx0, kqv, 0, 2, 1, 3)); - cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]); + cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur), cur->ne[0] * cur->ne[1], cur->ne[2] * cur->ne[3]); } @@ -2801,7 +2801,8 @@ struct clip_graph { rw = get_rel_pos(ctx0, layer.rel_pos_w, W, W); // [W, W, C] rh = get_rel_pos(ctx0, layer.rel_pos_h, H, H); // [H, H, C] - qr = ggml_reshape_4d(ctx0, ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)), d_heads, W, H, B*n_heads); + qr = ggml_permute(ctx0, Q, 0, 2, 1, 3); + qr = ggml_reshape_4d(ctx0, ggml_cont(ctx0, qr), d_heads, W, H, B * n_heads); const int WH_pad = GGML_PAD(W*H, GGML_KQ_MASK_PAD) - W*H; @@ -2818,16 +2819,8 @@ struct clip_graph { float scale = 1.0f / sqrtf((float)d_heads); - cur = build_attn( - layer.o_w, - layer.o_b, - Q, - K, - V, - mask, - scale, - il - ); // [B, H*W, n_embd] + cur = build_attn(layer.o_w, layer.o_b, Q, K, V, mask, scale, + il); // [B, H*W, n_embd] cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), n_embd, W, H, B); }