diff --git a/common/arg.cpp b/common/arg.cpp
index d2b81c331ca..458ca407952 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1824,6 +1824,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.image_max_tokens = value;
         }
     ).set_examples(mmproj_examples).set_env("LLAMA_ARG_IMAGE_MAX_TOKENS"));
+    add_opt(common_arg(
+        {"--dsocr-mode"}, "MODE",
+        "DeepSeek-OCR resolution mode, one of:\n"
+        "- auto (default): automatically select resolution\n"
+        "- tiny, small, base, large: native resolution\n"
+        "- gundam, gundam-master: dynamic resolution",
+        [](common_params & params, const std::string & value) {
+            if (value == "auto" || value == "tiny" || value == "small" || value == "base" ||
+                value == "large" || value == "gundam" || value == "gundam-master") {
+                params.dsocr_mode = value;
+            } else {
+                throw std::invalid_argument("invalid value");
+            }
+        }
+    ).set_examples(mmproj_examples).set_env("LLAMA_ARG_DSOCR_MODE"));
     if (llama_supports_rpc()) {
         add_opt(common_arg(
             {"--rpc"}, "SERVERS",
diff --git a/common/common.h b/common/common.h
index 2f23d0baa83..82d3989f10a 100644
--- a/common/common.h
+++ b/common/common.h
@@ -433,6 +433,7 @@ struct common_params {
     std::vector<std::string> image; // path to image file(s)
     int image_min_tokens = -1;
     int image_max_tokens = -1;
+    std::string dsocr_mode = "auto"; // DeepSeek-OCR resolution mode: auto, tiny, small, base, large, gundam, gundam-master

     // finetune
     struct lr_opt lr;
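Note: a hypothetical invocation of the new flag, for reference only (model and image paths are placeholders, not part of this change):

    llama-mtmd-cli -m model.gguf --mmproj mmproj-deepseek-ocr.gguf \
        --image page.png -p "Convert the document to markdown." --dsocr-mode base

The same selection can be made via the LLAMA_ARG_DSOCR_MODE environment variable declared above.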
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 24a0465e4dc..1d8bfa0d4a6 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -6013,12 +6013,14 @@ def get_vision_config(self) -> dict[str, Any]:

     def tensor_force_quant(self, name, new_name, bid, n_dims):
+        # TODO: improve numerical stability. maybe delete later.
+        return gguf.GGMLQuantizationType.F32
         # related to https://github.com/ggml-org/llama.cpp/issues/13025
-        if "input_projection" in name:
-            return gguf.GGMLQuantizationType.F16
-        if ".embeddings." in name:
-            return gguf.GGMLQuantizationType.F32
-        return super().tensor_force_quant(name, new_name, bid, n_dims)
+        # if "input_projection" in name:
+        #     return gguf.GGMLQuantizationType.F16
+        # if ".embeddings." in name:
+        #     return gguf.GGMLQuantizationType.F32
+        # return super().tensor_force_quant(name, new_name, bid, n_dims)

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # Only process vision-related tensors, skip language model tensors
diff --git a/ggml/src/ggml-cuda/upscale.cu b/ggml/src/ggml-cuda/upscale.cu
index 687c669304d..944d00a2adc 100644
--- a/ggml/src/ggml-cuda/upscale.cu
+++ b/ggml/src/ggml-cuda/upscale.cu
@@ -214,5 +214,7 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         upscale_f32_bicubic_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                                  src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
                                  sf0, sf1, sf2, sf3, pixel_offset, stream);
+    } else {
+        GGML_ABORT("fatal error");
     }
 }
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index b99345a2e93..583195953b5 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -5204,6 +5204,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
     GGML_ASSERT(q->ne[3] == v->ne[3]);

     if (mask) {
+        GGML_ASSERT(mask->type == GGML_TYPE_F16);
         GGML_ASSERT(ggml_is_contiguous(mask));
         GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
                 "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
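Note: with the assertion above, any graph that passes a custom mask to ggml_flash_attn_ext must pad and cast it to F16 first; the SAM attention rewrite in clip.cpp further down does exactly this. A minimal caller-side sketch, assuming mask is an F32 tensor of n_q rows built in the same context:

    // pad the mask rows to GGML_KQ_MASK_PAD, then cast as required by ggml_flash_attn_ext
    mask = ggml_pad (ctx0, mask, 0, GGML_PAD(n_q, GGML_KQ_MASK_PAD) - n_q, 0, 0);
    mask = ggml_cast(ctx0, mask, GGML_TYPE_F16);
    cur  = ggml_flash_attn_ext(ctx0, Q, K, V, mask, scale, 0.0f, 0.0f);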
diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h
index 20f68430784..a486ee13840 100644
--- a/tools/mtmd/clip-impl.h
+++ b/tools/mtmd/clip-impl.h
@@ -5,6 +5,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -442,6 +443,33 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
 //
 // debugging
 //
+
+static std::string to_ne_string(const ggml_tensor * t) {
+    std::string str;
+    for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+        str += std::to_string(t->ne[i]);
+        if (i + 1 < GGML_MAX_DIMS) {
+            str += ", ";
+        }
+    }
+    return str;
+}
+
+static void print_tensor_info(ggml_tensor * t) {
+    const struct ggml_tensor * src0 = t->src[0];
+    const struct ggml_tensor * src1 = t->src[1];
+
+    char src1_str[128] = {0};
+    if (src1) {
+        snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, to_ne_string(src1).c_str());
+    }
+
+    printf("%s: %s = %s(%s{%s}, %s)\n",
+        t->name, ggml_type_name(t->type), ggml_op_desc(t),
+        src0->name, to_ne_string(src0).c_str(),
+        src1 ? src1_str : "");
+}
+
 static void print_tensor_shape(ggml_tensor * t) {
     printf("%s.shape = [", t->name);
     for (int i = 0; i < ggml_n_dims(t); ++i) {
@@ -453,12 +481,50 @@ static void print_tensor_shape(ggml_tensor * t) {
     printf("]\n");
 }

+static void print_tensor_sum(ggml_tensor * t, uint8_t * data, int64_t n) {
+    (void) n; // unused parameter
+    ggml_type type = t->type;
+    int64_t * ne = t->ne;
+    size_t * nb = t->nb;
+    double sum = 0.0;
+    for (int64_t i3 = 0; i3 < ne[3]; i3++) {
+        for (int64_t i2 = 0; i2 < ne[2]; i2++) {
+            for (int64_t i1 = 0; i1 < ne[1]; i1++) {
+                for (int64_t i0 = 0; i0 < ne[0]; i0++) {
+                    size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
+                    float v;
+                    if (type == GGML_TYPE_F16) {
+                        v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
+                    } else if (type == GGML_TYPE_F32) {
+                        v = *(float *) &data[i];
+                    } else if (type == GGML_TYPE_I32) {
+                        v = (float) *(int32_t *) &data[i];
+                    } else if (type == GGML_TYPE_I16) {
+                        v = (float) *(int16_t *) &data[i];
+                    } else if (type == GGML_TYPE_I8) {
+                        v = (float) *(int8_t *) &data[i];
+                    } else {
+                        GGML_ABORT("fatal error");
+                    }
+                    sum += v;
+                }
+            }
+        }
+    }
+    printf("%s.sum = %.6f\n", t->name, sum);
+}
+
 static void print_tensor_data(ggml_tensor * t, uint8_t * data, int64_t n) {
     ggml_type type = t->type;
     int64_t * ne = t->ne;
     size_t * nb = t->nb;
+    printf("%s.data: [\n", t->name);
     for (int64_t i3 = 0; i3 < ne[3]; i3++) {
-        printf("%s.data: [\n", t->name);
+        if (i3 == n && ne[3] > 2*n) {
+            printf(" ..., \n");
+            i3 = ne[3] - n;
+        }
+        printf(" [\n");
         for (int64_t i2 = 0; i2 < ne[2]; i2++) {
             if (i2 == n && ne[2] > 2*n) {
                 printf(" ..., \n");
@@ -500,6 +566,120 @@ static void print_tensor_data(ggml_tensor * t, uint8_t * data, int64_t n) {
         }
         printf(" ]\n");
     }
+    printf(" ]\n");
+}
+
+static void save_tensor_to_file(const struct ggml_tensor * tensor, const uint8_t * data_ptr) {
+    char filename[512];
+    snprintf(filename, sizeof(filename), "%s_cpp.txt", tensor->name);
+
+    FILE * f = fopen(filename, "w");
+    if (!f) {
+        fprintf(stderr, "Failed to open %s\n", filename);
+        return;
+    }
+
+    // Check tensor size and warn if too large
+    int64_t total_elements = ggml_nelements(tensor);
+    fprintf(stderr, "Saving tensor %s (%lld elements) to %s\n",
+        tensor->name, (long long) total_elements, filename);
+
+    if (total_elements > 10000000) { // 10M elements
+        fprintf(stderr, "Warning: tensor is very large (%lld elements), this may take time\n",
+            (long long) total_elements);
+    }
+
+    const uint8_t * data = (data_ptr) ? data_ptr : (uint8_t *) tensor->data;
+    ggml_type type = tensor->type;
+    const int64_t * ne = tensor->ne;
+    const size_t * nb = tensor->nb;
+
+    // Use a buffer to reduce I/O calls
+    const size_t BUF_SIZE = 8192;
+    char * buf = (char *) malloc(BUF_SIZE);
+    if (!buf) {
+        fprintf(stderr, "Failed to allocate buffer\n");
+        fclose(f);
+        return;
+    }
+    size_t buf_pos = 0;
+
+    // Helper lambda to flush buffer
+    auto flush_buf = [&]() {
+        if (buf_pos > 0) {
+            fwrite(buf, 1, buf_pos, f);
+            buf_pos = 0;
+        }
+    };
+
+    // Helper to append to buffer
+    auto append = [&](const char * str, size_t len) {
+        if (buf_pos + len >= BUF_SIZE) {
+            flush_buf();
+        }
+        if (len >= BUF_SIZE) {
+            // String too large for buffer, write directly
+            fwrite(str, 1, len, f);
+        } else {
+            memcpy(buf + buf_pos, str, len);
+            buf_pos += len;
+        }
+    };
+
+    auto append_str = [&](const char * str) {
+        append(str, strlen(str));
+    };
+
+    char num_buf[32];
+
+    // Write header once for all batches
+    append_str(tensor->name);
+    append_str(".data: [\n");
+
+    for (int64_t i3 = 0; i3 < ne[3]; i3++) {
+        append_str(" [\n"); // Start of batch
+        for (int64_t i2 = 0; i2 < ne[2]; i2++) {
+            append_str(" [\n");
+            for (int64_t i1 = 0; i1 < ne[1]; i1++) {
+                append_str(" [");
+                for (int64_t i0 = 0; i0 < ne[0]; i0++) {
+                    size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
+                    float v;
+                    if (type == GGML_TYPE_F16) {
+                        v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
+                    } else if (type == GGML_TYPE_F32) {
+                        v = *(float *) &data[i];
+                    } else if (type == GGML_TYPE_I32) {
+                        v = (float) *(int32_t *) &data[i];
+                    } else if (type == GGML_TYPE_I16) {
+                        v = (float) *(int16_t *) &data[i];
+                    } else if (type == GGML_TYPE_I8) {
+                        v = (float) *(int8_t *) &data[i];
+                    } else {
+                        GGML_ABORT("fatal error");
+                    }
+                    int len = snprintf(num_buf, sizeof(num_buf), "%8.4f", v);
+                    append(num_buf, len);
+                    if (i0 < ne[0] - 1) append_str(", ");
+                }
+                append_str("],\n");
+            }
+            append_str(" ],\n");
+        }
+        append_str(" ]"); // End of batch
+        if (i3 < ne[3] - 1) {
+            append_str(",\n"); // Comma between batches
+        } else {
+            append_str("\n");
+        }
+    }
+
+    append_str("]\n"); // Close the top-level array
+
+    flush_buf();
+    free(buf);
+    fclose(f);
+    fprintf(stderr, "Tensor saved successfully\n");
+}

 //
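These helpers pair with the clip_image_batch_encode() change in clip.cpp below: tensors whose names appear in the (currently empty) patterns list are dumped to <name>_cpp.txt via save_tensor_to_file() instead of being printed. A minimal sketch of how a debugging session might populate it, using the callback names that this patch itself adds:

    std::vector<std::string> patterns = {
        "sam_output",   // SAM encoder output (cb added in build_sam_enc)
        "dsocr_output", // projector output (cb added in build_deepseek_ocr)
    };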
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index 3553ca3ad5c..a590c067269 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -193,8 +193,6 @@ struct clip_hparams {
     int32_t attn_window_size = 0;
     int32_t n_wa_pattern = 0;

-    bool crop_mode = false;
-
     // audio
     int32_t n_mel_bins = 0; // whisper preprocessor
     int32_t proj_stack_factor = 0; // ultravox
@@ -208,6 +206,9 @@
     int32_t custom_image_min_tokens = -1;
     int32_t custom_image_max_tokens = -1;

+    // DeepSeek-OCR resolution mode
+    enum clip_dsocr_mode dsocr_mode = clip_dsocr_mode::CLIP_DSOCR_MODE_AUTO;
+
     void set_limit_image_tokens(int n_tokens_min, int n_tokens_max) {
         const int cur_merge = n_merge == 0 ? 1 : n_merge;
         const int patch_area = patch_size * patch_size * cur_merge * cur_merge;
@@ -512,6 +513,7 @@ struct clip_ctx {
         if (ctx_params.image_max_tokens > 0) {
             model.hparams.custom_image_max_tokens = ctx_params.image_max_tokens;
         }
+        model.hparams.dsocr_mode = ctx_params.dsocr_mode;

         backend_ptrs.push_back(backend_cpu);
         backend_buft.push_back(ggml_backend_get_default_buffer_type(backend_cpu));
@@ -680,16 +682,19 @@ struct clip_graph {
         const auto tgt_size = inpL->ne[1];
         const auto str_size = model.pos_embed->ne[1];
         if (str_size != tgt_size) {
+            ggml_tensor * old_pos_embed = nullptr;
+            old_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, model.pos_embed, 2, 0, 1, 3));
+            // TODO: ggml_interpolate doesn't support bicubic mode for CUDA backend
             ggml_tensor * new_pos_embed = ggml_interpolate(
                 ctx0,
-                model.pos_embed,
+                old_pos_embed,
                 tgt_size,
                 tgt_size,
                 enc_n_embd,
                 1,
                 ggml_scale_mode::GGML_SCALE_MODE_BICUBIC
             );
-            new_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, new_pos_embed, 2,1,0,3));
+            new_pos_embed = ggml_cont(ctx0, ggml_permute(ctx0, new_pos_embed, 1, 2, 0, 3));
             cur = ggml_add(ctx0, inpL, new_pos_embed);
         } else {
             cur = ggml_add(ctx0, inpL, model.pos_embed);
@@ -698,10 +703,10 @@ struct clip_graph {
         // loop over layers
         for (int il = 0; il < _depth; il++) {
             auto & layer = model.sam_layers[il];
+            ggml_tensor * shortcut = cur;

             // layernorm1
             cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il);
-            cb(cur, "enc_layer_inp_normed", il);

             const int64_t w0 = cur->ne[1];
             const int64_t h0 = cur->ne[2];
@@ -710,7 +715,7 @@
                 // local attention layer - apply window partition
                 // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L169-L172
                 //cur = ggml_win_part(ctx0, cur, 14);
-                cur = window_partition(ctx0, cur, 14);
+                cur = window_partition(ctx0, cur, 14); // TODO: make this configurable
             }

             const int64_t W = cur->ne[1];
@@ -718,110 +723,93 @@
             // self-attention
             {
+                const int B = cur->ne[3];
+
                 cur = ggml_mul_mat(ctx0, layer.qkv_w, cur);
                 cur = ggml_add(ctx0, cur, layer.qkv_b);

-                const int B = cur->ne[3];
-
-                cur = ggml_reshape_4d(ctx0, cur, enc_n_embd, 3, W * H, B);
-                cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 3, 1, 2));
-
-                ggml_tensor * Qcur =
-                    ggml_view_3d(ctx0, cur, enc_n_embd, W * H, B, cur->nb[1], cur->nb[2], 0);
-                Qcur = ggml_reshape_4d(ctx0, Qcur, enc_d_heads, enc_n_heads, W * H, B);
-                Qcur = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3));
-                Qcur = ggml_reshape_3d(ctx0, Qcur, enc_d_heads, W * H, B * enc_n_heads);
-
-                ggml_tensor * Kcur =
-                    ggml_view_3d(ctx0, cur, enc_n_embd, W * H, B, cur->nb[1], cur->nb[2], cur->nb[3]);
-                Kcur = ggml_reshape_4d(ctx0, Kcur, enc_d_heads, enc_n_heads, W * H, B);
-                Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3));
-                Kcur = ggml_reshape_3d(ctx0, Kcur, enc_d_heads, W * H, B * enc_n_heads);
-
-                ggml_tensor * Vcur =
-                    ggml_view_3d(ctx0, cur, enc_n_embd, W * H, B, cur->nb[1], cur->nb[2], 2 * cur->nb[3]);
-                Vcur = ggml_reshape_4d(ctx0, Vcur, enc_d_heads, enc_n_heads, W * H, B);
-                Vcur = ggml_cont(ctx0, ggml_permute(ctx0, Vcur, 1, 2, 0, 3)); // transposed
-                Vcur = ggml_reshape_3d(ctx0, Vcur, W * H, enc_d_heads, B * enc_n_heads);
-
-                cb(Qcur, "Qcur", il);
-                cb(Kcur, "Kcur", il);
-                cb(Vcur, "Vcur", il);
-
-
-
-                struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcur, Qcur);
-
-                struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, 1.0f / sqrtf(enc_d_heads));
-
-                struct ggml_tensor * rw = get_rel_pos(ctx0, layer.rel_pos_w, W, W);
-                struct ggml_tensor * rh = get_rel_pos(ctx0, layer.rel_pos_h, H, H);
-
-                struct ggml_tensor * q_r = ggml_reshape_4d(ctx0, Qcur, enc_d_heads, W, H, B * enc_n_heads);
-
-                struct ggml_tensor * rel_w = ggml_cont(ctx0,ggml_permute(ctx0,
-                    ggml_mul_mat(ctx0,
-                        rw,
-                        ggml_cont(ctx0, ggml_permute(ctx0, q_r, 0, 2, 1, 3))),
-                    0, 2, 1, 3));
-                struct ggml_tensor * rel_h = ggml_mul_mat(ctx0, rh, q_r);
-
-                struct ggml_tensor * attn = add_rel_pos_inplace(ctx0, KQ_scaled, rel_w, rel_h);
-
-                struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, attn);
-
-                struct ggml_tensor * KQV = ggml_mul_mat(ctx0, Vcur, KQ_soft_max);
-
-                cur = ggml_reshape_4d(
-                    ctx0,
-                    ggml_cont(ctx0, ggml_permute(ctx0, ggml_reshape_4d(ctx0, KQV, enc_d_heads, W * H, enc_n_heads, B),
-                        0, 2, 1, 3)),
-                    enc_n_embd, W, H, B);
-
+                cur = ggml_cont(ctx0, cur); // Ensure tensor is contiguous before reshape
+                cur = ggml_reshape_4d(ctx0, cur, enc_n_embd, 3, W*H, B);
+
+                ggml_tensor * Q;
+                ggml_tensor * K;
+                ggml_tensor * V;
+
+                Q = ggml_view_3d   (ctx0, cur, enc_n_embd, W*H, B, cur->nb[2], cur->nb[3], 0*cur->nb[1]);
+                Q = ggml_reshape_4d(ctx0, ggml_cont(ctx0, Q), enc_d_heads, enc_n_heads, W*H, B);
+                Q = ggml_cont      (ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3)); // [B, enc_n_heads, H*W, enc_d_heads]
+
+                K = ggml_view_3d   (ctx0, cur, enc_n_embd, W*H, B, cur->nb[2], cur->nb[3], 1*cur->nb[1]);
+                K = ggml_reshape_4d(ctx0, ggml_cont(ctx0, K), enc_d_heads, enc_n_heads, W*H, B);
+                K = ggml_cont      (ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3)); // [B, enc_n_heads, H*W, enc_d_heads]
+
+                V = ggml_view_3d   (ctx0, cur, enc_n_embd, W*H, B, cur->nb[2], cur->nb[3], 2*cur->nb[1]);
+                V = ggml_reshape_4d(ctx0, ggml_cont(ctx0, V), enc_d_heads, enc_n_heads, W*H, B);
+                V = ggml_cont      (ctx0, ggml_permute(ctx0, V, 0, 2, 1, 3)); // [B, enc_n_heads, H*W, enc_d_heads]
+
+                ggml_tensor * mask;
+                ggml_tensor * rw;
+                ggml_tensor * rh;
+                ggml_tensor * qr;
+
+                rw = get_rel_pos(ctx0, layer.rel_pos_w, W, W); // [W, W, C]
+                rh = get_rel_pos(ctx0, layer.rel_pos_h, H, H); // [H, H, C]
+                qr = ggml_reshape_4d(ctx0, Q, enc_d_heads, W, H, B*enc_n_heads);
+
+                const int WH_pad = GGML_PAD(W*H, GGML_KQ_MASK_PAD) - W*H;
+
+                rw = ggml_mul_mat   (ctx0, rw, ggml_cont(ctx0, ggml_permute(ctx0, qr, 0, 2, 1, 3))); // [B*enc_n_heads, W, H, W]
+                rw = ggml_cont      (ctx0, ggml_permute(ctx0, rw, 0, 2, 1, 3)); // [B*enc_n_heads, H, W, W]
+                rw = ggml_reshape_4d(ctx0, rw, W, 1, W*H, enc_n_heads*B);
+                rw = ggml_repeat_4d (ctx0, rw, W, H, W*H, enc_n_heads*B);
+                rh = ggml_mul_mat   (ctx0, rh, qr); // [B*enc_n_heads, H, W, H]
+                rh = ggml_reshape_4d(ctx0, rh, 1, H, W*H, enc_n_heads*B);
+                mask = ggml_add       (ctx0, rw, rh); // [B*enc_n_heads, H*W, H, W]
+                mask = ggml_reshape_4d(ctx0, mask, W*H, W*H, enc_n_heads, B);
+                mask = ggml_pad       (ctx0, mask, 0, WH_pad, 0, 0);
+                mask = ggml_cast      (ctx0, mask, GGML_TYPE_F16);
+
+                float scale = 1.0f / sqrtf((float) enc_d_heads);
+                cur = ggml_flash_attn_ext(ctx0, Q, K, V, mask, scale, 0.0f, 0.0f); // [B, H*W, enc_n_heads, enc_d_heads]
+
+                cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), enc_n_embd, W, H, B);
                 cur = ggml_mul_mat(ctx0, layer.o_w, cur);
                 cur = ggml_add_inplace(ctx0, cur, layer.o_b);
             }

             if (hparams.is_global_attn(il) == false) {
                 // local attention layer - reverse window partition
-                cur = window_unpartition(ctx0, cur, w0, h0, 14);
+                cur = window_unpartition(ctx0, cur, w0, h0, 14); // TODO: make window size configurable
             }

             // re-add the layer input, e.g., residual
-            cur = ggml_add(ctx0, cur, inpL);
+            cur = ggml_add(ctx0, cur, shortcut);

             ggml_tensor * inpFF = cur;
-
-            cb(inpFF, "ffn_inp", il);
-
             // layernorm2
             cur = build_norm(inpFF, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il);
-            cb(cur, "ffn_inp_normed", il);

             // ffn
             cur = build_ffn(cur,
                 layer.ff_up_w, layer.ff_up_b,
                 nullptr, nullptr,
                 layer.ff_down_w, layer.ff_down_b,
                 hparams.ffn_op, il);
-            cb(cur, "ffn_out", il);
-
-            // residual 2
             cur = ggml_add(ctx0, cur, inpFF);
-            cb(cur, "layer_out", il);
+            cb(cur, "sam_layer_out", il);
         }

-        cur = ggml_cont(ctx0, ggml_permute(ctx0, inpL, 2, 0, 1, 3));
-
-        cur = ggml_conv_2d_sk_p0(ctx0, model.neck_0_w, cur);
-
-        cur = sam_layer_norm_2d(ctx0, cur, 256, model.neck_1_w, model.neck_1_b, hparams.eps);
+        cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 2, 0, 1, 3));

-        cur = ggml_conv_2d_s1_ph(ctx0, model.neck_2_w, cur);
+        const int out_chans = model.neck_0_w->ne[3];

-        cur = sam_layer_norm_2d(ctx0, cur, 256, model.neck_3_w, model.neck_3_b, hparams.eps);
+        cur = ggml_conv_2d(ctx0, model.neck_0_w, cur, 1, 1, 0, 0, 1, 1);
+        cur = sam_layer_norm_2d(ctx0, cur, out_chans, model.neck_1_w, model.neck_1_b, hparams.eps);
+        cur = ggml_conv_2d(ctx0, model.neck_2_w, cur, 1, 1, 1, 1, 1, 1);
+        cur = sam_layer_norm_2d(ctx0, cur, out_chans, model.neck_3_w, model.neck_3_b, hparams.eps);

-        cur = ggml_conv_2d(ctx0, model.net_2, cur, 2,2,1,1, 1,1);
-        cur = ggml_conv_2d(ctx0, model.net_3, cur, 2,2,1,1, 1,1);
+        cur = ggml_conv_2d(ctx0, model.net_2, cur, 2, 2, 1, 1, 1, 1);
+        cur = ggml_conv_2d(ctx0, model.net_3, cur, 2, 2, 1, 1, 1, 1);
+        cb(cur, "sam_output", -1);

         ggml_build_forward_expand(gf, cur);
         return cur;
@@ -850,34 +838,32 @@ struct clip_graph {
     ggml_cgraph * build_deepseek_ocr() {
         //patch embedding
         ggml_tensor * inp_raw = build_inp_raw();
-
         ggml_tensor * global_features_1 = build_sam_enc(inp_raw, std::max(img.nx, img.ny));
         ggml_tensor * global_features_2 = build_dp_ocr_clip(global_features_1);
-
+
         // FIXME remove n_patches is hardcoded
-
+
         // torch global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
-        global_features_1 = ggml_cont(ctx0,ggml_permute(ctx0, global_features_1,2,1,0,3));
+        global_features_1 = ggml_cont(ctx0,ggml_permute(ctx0, global_features_1, 1, 2, 0, 3));

         int clip_n_patches = global_features_1->ne[1] * global_features_1->ne[2];
-
+
         // flatten 2nd and 3rd dims
         global_features_1 = ggml_reshape_2d(ctx0, global_features_1, global_features_1->ne[0], clip_n_patches);
-
+
         // remove CLS token
-        global_features_2 = ggml_view_2d(ctx0, global_features_2,
-            n_embd, clip_n_patches,
-            ggml_row_size(global_features_2->type, n_embd), 0);
-
-        ggml_tensor * global_features = ggml_concat(ctx0, global_features_2, global_features_1, 1);
+        global_features_2 = ggml_view_2d(ctx0, global_features_2, n_embd, clip_n_patches,
+                                         global_features_2->nb[1], global_features_2->nb[1]);
+
+        ggml_tensor * global_features = ggml_concat(ctx0, global_features_2, global_features_1, 0);

         global_features = ggml_reshape_2d(ctx0, global_features, 2* n_embd,clip_n_patches);
         global_features = ggml_cont(ctx0, global_features);
         global_features = ggml_mul_mat(ctx0, model.fc_w, global_features);
         global_features = ggml_add(ctx0, global_features, model.fc_b);

         global_features = build_global_local_features(ctx0,global_features);

-        global_features = ggml_cont(ctx0, ggml_permute(ctx0, global_features, 1, 0, 2, 3));
+
+        cb(global_features, "dsocr_output", -1);
+
         ggml_build_forward_expand(gf, global_features);

         return gf;
     }
@@ -891,30 +877,23 @@ struct clip_graph {
         GGML_ASSERT(model.image_newline != nullptr);
         GGML_ASSERT(model.view_seperator != nullptr);

-        // 1) global_features: [n_dim, h*w] -> [n_dim, w, h] -> [h, w, n_dim]
         const auto h = static_cast<int>(std::sqrt(static_cast<float>(global_features->ne[1])));
         const auto w = h;
         const auto n_dim = global_features->ne[0];
-        ggml_tensor * t = ggml_reshape_4d(ctx0, global_features, n_dim, h, w, 1); // (n_dim, w, h)
-        t = ggml_cont(ctx0, ggml_permute(ctx0, t, 2, 1, 0, 3)); // (h, w, n_dim)
-        ggml_tensor * nl = ggml_cont(ctx0,ggml_permute(ctx0, model.image_newline, 2, 1, 0, 3));
-        nl = ggml_repeat_4d(ctx0, nl, h, 1, n_dim, 1); // n_pos rows
-
-
-        // 2) image_newline: [n_dim] -> [1, 1, n_dim] -> repeat to [h, 1, n_dim]
-        t = ggml_concat(ctx0, t, nl, 1); // (h, w+1, n_dim)
-
-        t = ggml_reshape_2d(ctx0, t, n_dim, h* (h + 1)); // (n_dim, h*(w+1))
-
-        // 5) append view_separator as an extra "token":
-        //    view_separator: [n_dim] -> [n_dim, 1]
-        ggml_tensor * vs = ggml_reshape_2d(ctx0, model.view_seperator, n_dim, 1); // (n_dim, 1)
+        ggml_tensor * cur;
+        ggml_tensor * imgnl;
+        ggml_tensor * vs;

-        // concat along token dimension (dim=1):
-        t = ggml_concat(ctx0, t, vs, 1); // (n_dim, h*(w+1) + 1)
+        cur = ggml_reshape_3d(ctx0, global_features, n_dim, w, h);
+        imgnl = ggml_repeat_4d(ctx0, model.image_newline, n_dim, 1, h, 1);
+        cur = ggml_reshape_2d(ctx0, ggml_concat(ctx0, cur, imgnl, 1), n_dim, (w+1)*h);
+        cb(cur, "insert_imgnl", -1);
+        vs = ggml_reshape_2d(ctx0, model.view_seperator, n_dim, 1); // (n_dim, 1)
+        cur = ggml_concat(ctx0, cur, vs, 1); // (n_dim, h*(w+1) + 1)
+        cb(cur, "insert_vs", -1);

-        return t;
+        return cur;
     }
@@ -1569,8 +1548,8 @@ struct clip_graph {
         ggml_tensor * inp = ggml_cpy(ctx0, patch_embeds, ggml_dup_tensor(ctx0, patch_embeds));

-        inp = ggml_cont(ctx0,ggml_permute(ctx0, inp,2,1,0,3));
-        inp = ggml_reshape_2d(ctx0, inp, n_embd, inp->ne[1]*inp->ne[2]*inp->ne[3]);
+        inp = ggml_reshape_2d(ctx0, inp, inp->ne[0]*inp->ne[1], inp->ne[2]);
+        inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));

         ggml_tensor * new_pos_embd = ggml_cpy(ctx0, model.position_embeddings, ggml_dup_tensor(ctx0, model.position_embeddings));
@@ -1601,7 +1580,7 @@ struct clip_graph {

         // add CLS token
-        inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
+        inp = ggml_concat(ctx0, model.class_embedding, inp, 1);

         //TODO : check norm type for dp-ocr-clip
         norm_type norm_t = NORM_TYPE_NORMAL;
@@ -1610,9 +1589,8 @@ struct clip_graph {
         ggml_tensor * positions = ggml_cast(ctx0, ggml_arange(ctx0, 0, n_pos, 1), GGML_TYPE_I32);
         ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, new_pos_embd, positions);
-
-        ggml_tensor * cur = build_vit(inp, n_pos, norm_t, hparams.ffn_op, learned_pos_embd,
-            nullptr); // shape [1024, 16, 16]
+        ggml_tensor * cur = build_vit(inp, n_pos, norm_t, ffn_op_type::FFN_GELU_QUICK,
+                                      learned_pos_embd, nullptr); // shape [1024, 16, 16]

         ggml_build_forward_expand(gf, cur);
@@ -2576,19 +2554,27 @@ struct clip_graph {

         if (max_rel_dist != L) {
             // Linear interpolation
-            const auto scale = L / static_cast<float>(max_rel_dist);
-            ggml_tensor * indices = ggml_arange(ctx, 0.0f, static_cast<float>(max_rel_dist), 1.0f);
-            indices = ggml_scale_inplace(ctx, indices, scale);
-            ggml_tensor * indices_floor= ggml_cast(ctx, ggml_floor(ctx, indices), GGML_TYPE_I32);
-            ggml_tensor * indices_ceil = ggml_cast(ctx, ggml_ceil(ctx, indices), GGML_TYPE_I32);
-            ggml_tensor * weights = ggml_sub(ctx, indices, indices_floor);
-            ggml_tensor * ws1 = ggml_scale_bias(ctx, weights, -1.0f, 1.0f);
-            rel_pos_resized = ggml_cont(ctx , ggml_permute(ctx, rel_pos_resized, 1, 0, 2, 3)); // [C, L] for ggml_get_rows
-            ggml_tensor * rs1 = ggml_cont(ctx, ggml_permute(ctx, ggml_get_rows(ctx, rel_pos_resized, indices_floor), 1, 0, 2, 3)); // lower rows
-            rs1 = ggml_mul(ctx, rs1, ws1); // lower rows
-            ggml_tensor * rs2 = ggml_cont(ctx, ggml_permute(ctx, ggml_get_rows(ctx, rel_pos_resized, indices_ceil), 1, 0, 2, 3)); // upper rows
-            rs2 = ggml_mul(ctx, rs2, weights); // upper rows
-            rel_pos_resized = ggml_add(ctx,rs1, rs2);
+            int64_t ne0 = rel_pos_resized->ne[0];
+            int64_t ne1 = rel_pos_resized->ne[1];
+            int64_t ne2 = rel_pos_resized->ne[2];
+            int64_t ne3 = rel_pos_resized->ne[3];
+
+            rel_pos_resized = ggml_reshape_3d(
+                ctx,
+                ggml_cont(ctx, ggml_permute(ctx, rel_pos_resized, 1, 0, 2, 3)),
+                ne1, 1, ne0*ne2*ne3
+            );
+            rel_pos_resized = ggml_reshape_4d(
+                ctx,
+                ggml_interpolate(
+                    ctx,
+                    rel_pos_resized,
+                    max_rel_dist, 1, ne0*ne2*ne3, 1,
+                    ggml_scale_mode::GGML_SCALE_MODE_BILINEAR
+                ),
+                max_rel_dist, ne0, ne2, ne3
+            );
+            rel_pos_resized = ggml_cont(ctx, ggml_permute(ctx, rel_pos_resized, 1, 0, 2, 3));
         }

         // -------------------------------------------------
@@ -2627,7 +2613,7 @@ struct clip_graph {
         rel = ggml_sub(ctx, q_coord, k_coord); // [q_size, k_size]
         rel = ggml_scale_bias(ctx, rel, 1.0f, (k_size - 1.0f)*k_scale); // [q_size, k_size]
         // Clamp to [0, L-1] range for valid indexing
-        rel = ggml_clamp(ctx, rel, 0.0f, static_cast<float>(rel_pos->ne[1] - 1));
+        rel = ggml_clamp(ctx, rel, 0.0f, static_cast<float>(rel_pos_resized->ne[1] - 1));

         // -------------------------------------------------
         // clamp to [0, L-1] and cast to int32 (for ggml_get_rows)
@@ -2641,7 +2627,7 @@ struct clip_graph {
         // flatten to 1D for ggml_get_rows
         int qk = q_size * k_size;
         ggml_tensor * idx_flat = ggml_reshape_1d(ctx, idx_2d, qk); // [qk]
-        ggml_tensor * gathered = ggml_get_rows(ctx, rel_pos, idx_flat); // [qk, C]
+        ggml_tensor * gathered = ggml_get_rows(ctx, rel_pos_resized, idx_flat); // [qk, C]

         // -------------------------------------------------
         // Gather from rel_pos → [qk, C]
@@ -2671,7 +2657,7 @@ struct clip_graph {
         }
         x = ggml_reshape_4d(ctx, x, c * window, npw, window, nph * b);
         x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3));
-        x = ggml_reshape_4d(ctx, x, c, window ,window, npw * nph * b);
+        x = ggml_reshape_4d(ctx, x, c, window, window, npw * nph * b);

         return x;
     }
@@ -3419,7 +3405,6 @@ struct clip_model_loader {
                     hparams.patch_size = 16;
                     hparams.image_size = 1024;
                     hparams.warmup_image_size = 1024;
-                    hparams.crop_mode = false;
                 } break;
             default:
                 break;
@@ -5070,9 +5055,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
                 }
             } break;
         case PROJECTOR_TYPE_DEEPSEEKOCR:
-        if (!params.crop_mode) {
-            /* Native Resolution (Tiny/Small/Base/Large) */
-
+        {
             const int native_resolutions[] = {
                 512 /* tiny */, 640 /* small */, 1024 /* base */, 1280 /* large */
             };
@@ -5080,29 +5063,49 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
             const int orig_w = original_size.width;
             const int orig_h = original_size.height;
             const int orig_area = orig_h * orig_w;
-
-            // mode selection logic (find most suitable resolution)
+            std::array<uint8_t, 3> color;
+
+            for (int i = 0; i < 3; i++) {
+                color[i] = (int)(255 * params.image_mean[i]);
+            }
+
             int mode_i = 0;
-            int min_diff = orig_area;
-
-            for (int i = 0; i < 4; i++) {
-                int r = native_resolutions[i];
-                if (std::abs(orig_area - r*r) < min_diff) {
-                    mode_i = i;
-                    min_diff = std::abs(orig_area - r*r);
+
+            if (params.dsocr_mode == clip_dsocr_mode::CLIP_DSOCR_MODE_TINY) {
+                mode_i = 0;
+            } else if (params.dsocr_mode == clip_dsocr_mode::CLIP_DSOCR_MODE_SMALL) {
+                mode_i = 1;
+            } else if (params.dsocr_mode == clip_dsocr_mode::CLIP_DSOCR_MODE_BASE) {
+                mode_i = 2;
+            } else if (params.dsocr_mode == clip_dsocr_mode::CLIP_DSOCR_MODE_LARGE) {
+                mode_i = 3;
+            } else if (params.dsocr_mode == clip_dsocr_mode::CLIP_DSOCR_MODE_GUNDAM) {
+                mode_i = 4;
+            } else if (params.dsocr_mode == clip_dsocr_mode::CLIP_DSOCR_MODE_GUNDAM_MASTER) {
+                mode_i = 5;
+            } else {
+                if (params.dsocr_mode != clip_dsocr_mode::CLIP_DSOCR_MODE_AUTO) {
+                    LOG_WRN("%s: unknown dsocr_mode, using auto mode\n", __func__);
+                }
+                int min_diff = orig_area;
+                for (int i = 0; i < 4; i++) {
+                    int r = native_resolutions[i];
+                    if (std::abs(orig_area - r*r) < min_diff) {
+                        mode_i = i;
+                        min_diff = std::abs(orig_area - r*r);
+                    }
                 }
             }

-            const int image_size = native_resolutions[mode_i];
-
             if (mode_i < 2) {
-                // TINY/SMALL MODE: Direct resize (no slicing)
+                /* Native Resolution (Tiny/Small) */
+                const int image_size = native_resolutions[mode_i];
+
                 // Just resize the image to image_size × image_size
-
                 clip_image_u8_ptr resized_img(clip_image_u8_init());
                 img_tool::resize(*img, *resized_img, clip_image_size{image_size, image_size},
-                    img_tool::RESIZE_ALGO_BICUBIC); // Match PIL default
+                    img_tool::RESIZE_ALGO_BICUBIC, true, color); // Match PIL default

                 clip_image_f32_ptr res(clip_image_f32_init());
                 normalize_image_u8_to_f32(*resized_img, *res, params.image_mean, params.image_std);
@@ -5111,10 +5114,11 @@
                 res_imgs->grid_x = 1;
                 res_imgs->grid_y = 1;
             }
-            else {
-                // BASE/LARGE MODE: Resize with aspect ratio + padding
+            else if (mode_i < 4) {
+                /* Native Resolution (Base/Large) */
+                const int image_size = native_resolutions[mode_i];
+
                 // Resize maintaining aspect ratio, then pad to square
-
                 float scale = std::min(
                     static_cast<float>(image_size) / orig_w,
                     static_cast<float>(image_size) / orig_h
                 );
@@ -5124,14 +5128,14 @@

                 clip_image_u8_ptr scaled_img(clip_image_u8_init());
                 img_tool::resize(*img, *scaled_img, clip_image_size{new_w, new_h},
-                    img_tool::RESIZE_ALGO_BICUBIC);
+                    img_tool::RESIZE_ALGO_BICUBIC, true, color);

                 // Use mean color for padding
                 unsigned char pad_r = static_cast<unsigned char>(params.image_mean[0] * 255.0f);
                 unsigned char pad_g = static_cast<unsigned char>(params.image_mean[1] * 255.0f);
                 unsigned char pad_b = static_cast<unsigned char>(params.image_mean[2] * 255.0f);

-                // Step 2: Pad to image_size × image_size (center padding)
+                // Pad to image_size × image_size (center padding)
                 clip_image_u8_ptr padded_img(clip_image_u8_init());
                 padded_img->nx = image_size;
                 padded_img->ny = image_size;
@@ -5159,7 +5163,7 @@
                     }
                 }

-                // Step 3: Normalize and output
+                // Normalize and output
                 clip_image_f32_ptr res(clip_image_f32_init());
                 normalize_image_u8_to_f32(*padded_img, *res, params.image_mean, params.image_std);
                 res_imgs->entries.push_back(std::move(res));
@@ -5167,68 +5171,69 @@
                 res_imgs->grid_x = 1;
                 res_imgs->grid_y = 1;
             }
-        }
-        else {
-            /* Dynamic Resolution (Gundam/Gundam-M) */
-
-            // configurable, or read from params
-            const int min_num = 2;
-            const int max_num = 9;
-            const int image_size = params.image_size; // typically 640
-            // const bool use_thumbnail = true; // mimic python's use_thumbnail
-
-            // original image size
-            const int orig_w = original_size.width;
-            const int orig_h = original_size.height;
-
-            // 1) build candidate grids (cols, rows)
-            auto target_ratios = ds_build_target_ratios(min_num, max_num);
-
-            // 2) pick the grid that best matches the original aspect ratio
-            const float aspect_ratio = static_cast<float>(orig_w) / static_cast<float>(orig_h);
-            auto best = ds_find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_w, orig_h, image_size);
-            const int grid_cols = best.first;  // how many tiles horizontally
-            const int grid_rows = best.second; // how many tiles vertically
-
-            // 3) compute the target (forced) size — python did:
-            //    target_width  = image_size * cols
-            //    target_height = image_size * rows
-            const clip_image_size refined_size{ image_size * grid_cols, image_size * grid_rows };
-
-            // 4) prepare slice instructions, same style as the idefics3 branch
-            llava_uhd::slice_instructions instructions;
-            instructions.overview_size = clip_image_size{ image_size, image_size }; // for thumbnail/global
-            instructions.refined_size  = refined_size;
-            instructions.grid_size     = clip_image_size{ grid_cols, grid_rows };
-
-            // in deepseek python they always produce *full* 640x640 blocks,
-            // so we can do a simple double loop over rows/cols:
-            for (int r = 0; r < grid_rows; ++r) {
-                for (int c = 0; c < grid_cols; ++c) {
-                    const int x = c * image_size;
-                    const int y = r * image_size;
-
-                    instructions.slices.push_back(llava_uhd::slice_coordinates{
-                        /* x    */ x,
-                        /* y    */ y,
-                        /* size */ clip_image_size{ image_size, image_size }
-                    });
+            else {
+                GGML_ABORT("DeepSeek-OCR: Gundam/Gundam-Master haven't been tested yet.\n");
+                /* Dynamic Resolution (Gundam/Gundam-Master) */
+
+                // configurable, or read from params
+                const int min_num = 2;
+                const int max_num = 9;
+                const int image_size = params.image_size; // typically 640
+                // const bool use_thumbnail = true; // mimic python's use_thumbnail
+
+                // original image size
+                const int orig_w = original_size.width;
+                const int orig_h = original_size.height;
+
+                // 1) build candidate grids (cols, rows)
+                auto target_ratios = ds_build_target_ratios(min_num, max_num);
+
+                // 2) pick the grid that best matches the original aspect ratio
+                const float aspect_ratio = static_cast<float>(orig_w) / static_cast<float>(orig_h);
+                auto best = ds_find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_w, orig_h, image_size);
+                const int grid_cols = best.first;  // how many tiles horizontally
+                const int grid_rows = best.second; // how many tiles vertically
+
+                // 3) compute the target (forced) size — python did:
+                //    target_width  = image_size * cols
+                //    target_height = image_size * rows
+                const clip_image_size refined_size{ image_size * grid_cols, image_size * grid_rows };
+
+                // 4) prepare slice instructions, same style as the idefics3 branch
+                llava_uhd::slice_instructions instructions;
+                instructions.overview_size = clip_image_size{ image_size, image_size }; // for thumbnail/global
+                instructions.refined_size  = refined_size;
+                instructions.grid_size     = clip_image_size{ grid_cols, grid_rows };
+
+                // in deepseek python they always produce *full* 640x640 blocks,
+                // so we can do a simple double loop over rows/cols:
+                for (int r = 0; r < grid_rows; ++r) {
+                    for (int c = 0; c < grid_cols; ++c) {
+                        const int x = c * image_size;
+                        const int y = r * image_size;
+
+                        instructions.slices.push_back(llava_uhd::slice_coordinates{
+                            /* x    */ x,
+                            /* y    */ y,
+                            /* size */ clip_image_size{ image_size, image_size }
+                        });
+                    }
                 }
+
+                // 5) run the actual slicing (this should: resize to refined_size, then crop every slice)
+                auto imgs = llava_uhd::slice_image(img, instructions);
+
+                // 7) cast & normalize like the idefics3 branch
+                for (size_t i = 0; i < imgs.size(); ++i) {
+                    clip_image_f32_ptr res(clip_image_f32_init());
+                    normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
+                    res_imgs->entries.push_back(std::move(res));
+                }
+
+                // keep the grid info — the model may need to know how to reassemble / attend
+                res_imgs->grid_x = grid_cols;
+                res_imgs->grid_y = grid_rows;
+            }
             }
-
-            // 5) run the actual slicing (this should: resize to refined_size, then crop every slice)
-            auto imgs = llava_uhd::slice_image(img, instructions);
-
-            // 7) cast & normalize like the idefics3 branch
-            for (size_t i = 0; i < imgs.size(); ++i) {
-                clip_image_f32_ptr res(clip_image_f32_init());
-                normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
-                res_imgs->entries.push_back(std::move(res));
-            }
-
-            // keep the grid info — the model may need to know how to reassemble / attend
-            res_imgs->grid_x = grid_cols;
-            res_imgs->grid_y = grid_rows;
         }
         break;
@@ -5415,12 +5420,15 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
             } break;
         case PROJECTOR_TYPE_DEEPSEEKOCR:
             {
-                int x_patch = img->nx / (params.patch_size);
-
-                n_patches += x_patch + 1;
-                n_patches = 1280;
-
-
+                // SAM encoder applies two stride-2 convolutions (net_2 and net_3)
+                // which reduces spatial dimensions by 4x in each direction (16x total)
+                // E.g., 64x64 -> 16x16 patches
+                n_patches /= 16;
+
+                // build_global_local_features adds image newlines and view separator
+                // Formula: h*(w+1) + 1 where h = w = sqrt(n_patches)
+                int h = static_cast<int>(std::sqrt(static_cast<float>(n_patches)));
+                n_patches = h * (h + 1) + 1;
             } break;
         default:
             GGML_ABORT("unsupported projector type");
@@ -5803,8 +5811,27 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         for (ggml_tensor * t : ctx->debug_print_tensors) {
             std::vector<uint8_t> data(ggml_nbytes(t));
             ggml_backend_tensor_get(t, data.data(), 0, ggml_nbytes(t));
+            print_tensor_info(t);
             print_tensor_shape(t);
-            print_tensor_data(t, data.data(), 3);
+            print_tensor_sum(t, data.data(), 3);
+            std::string tname_s = std::string(t->name);
+
+            bool is_stored = false;
+            std::vector<std::string> patterns = {
+                /* Add tensor names here to dump (e.g. "sam_output") */
+            };
+
+            for (auto & p : patterns) {
+                if (tname_s == p) {
+                    save_tensor_to_file(t, data.data());
+                    is_stored = true;
+                    break;
+                }
+            }
+
+            if (!is_stored) {
+                print_tensor_data(t, data.data(), 3);
+            }
         }
     }
diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h
index eb96b389cfb..c0b191dcf30 100644
--- a/tools/mtmd/clip.h
+++ b/tools/mtmd/clip.h
@@ -29,11 +29,22 @@ enum clip_flash_attn_type {
     CLIP_FLASH_ATTN_TYPE_ENABLED = 1,
 };

+enum clip_dsocr_mode {
+    CLIP_DSOCR_MODE_AUTO,
+    CLIP_DSOCR_MODE_TINY,
+    CLIP_DSOCR_MODE_SMALL,
+    CLIP_DSOCR_MODE_BASE,
+    CLIP_DSOCR_MODE_LARGE,
+    CLIP_DSOCR_MODE_GUNDAM,
+    CLIP_DSOCR_MODE_GUNDAM_MASTER,
+};
+
 struct clip_context_params {
     bool use_gpu;
     enum clip_flash_attn_type flash_attn_type;
     int image_min_tokens;
     int image_max_tokens;
+    enum clip_dsocr_mode dsocr_mode;
 };

 struct clip_init_result {
diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp
index bd52341e357..f30ec1bcbf4 100644
--- a/tools/mtmd/mtmd-cli.cpp
+++ b/tools/mtmd/mtmd-cli.cpp
@@ -138,6 +138,7 @@ struct mtmd_cli_context {
         mparams.flash_attn_type = params.flash_attn_type;
         mparams.image_min_tokens = params.image_min_tokens;
         mparams.image_max_tokens = params.image_max_tokens;
+        mparams.dsocr_mode = params.dsocr_mode.c_str();
         ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));
         if (!ctx_vision.get()) {
             LOG_ERR("Failed to load vision model from %s\n", clip_path);
diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
index f6ac40ba911..0c360f13741 100644
--- a/tools/mtmd/mtmd.cpp
+++ b/tools/mtmd/mtmd.cpp
@@ -110,6 +110,7 @@ mtmd_context_params mtmd_context_params_default() {
         /* flash_attn_type */ LLAMA_FLASH_ATTN_TYPE_AUTO,
         /* image_min_tokens */ -1,
         /* image_max_tokens */ -1,
+        /* dsocr_mode */ "auto",
     };
     return params;
 }
@@ -172,11 +173,32 @@ struct mtmd_context {
             throw std::runtime_error("media_marker must not be empty");
         }

+        enum clip_dsocr_mode dsocr_mode;
+
+        if (std::string(ctx_params.dsocr_mode) == "auto") {
+            dsocr_mode = clip_dsocr_mode::CLIP_DSOCR_MODE_AUTO;
+        } else if (std::string(ctx_params.dsocr_mode) == "tiny") {
+            dsocr_mode = clip_dsocr_mode::CLIP_DSOCR_MODE_TINY;
+        } else if (std::string(ctx_params.dsocr_mode) == "small") {
+            dsocr_mode = clip_dsocr_mode::CLIP_DSOCR_MODE_SMALL;
+        } else if (std::string(ctx_params.dsocr_mode) == "base") {
+            dsocr_mode = clip_dsocr_mode::CLIP_DSOCR_MODE_BASE;
+        } else if (std::string(ctx_params.dsocr_mode) == "large") {
+            dsocr_mode = clip_dsocr_mode::CLIP_DSOCR_MODE_LARGE;
+        } else if (std::string(ctx_params.dsocr_mode) == "gundam") {
+            dsocr_mode = clip_dsocr_mode::CLIP_DSOCR_MODE_GUNDAM;
+        } else if (std::string(ctx_params.dsocr_mode) == "gundam-master") {
+            dsocr_mode = clip_dsocr_mode::CLIP_DSOCR_MODE_GUNDAM_MASTER;
+        } else {
+            throw std::invalid_argument("invalid value");
+        }
+
         clip_context_params ctx_clip_params {
             /* use_gpu */ ctx_params.use_gpu,
             /* flash_attn_type */ CLIP_FLASH_ATTN_TYPE_AUTO,
             /* image_min_tokens */ ctx_params.image_min_tokens,
             /* image_max_tokens */ ctx_params.image_max_tokens,
+            /* dsocr_mode */ dsocr_mode,
         };
         auto res = clip_init(mmproj_fname, ctx_clip_params);
diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h
index bc4c9a57bda..3dc34ae3b77 100644
--- a/tools/mtmd/mtmd.h
+++ b/tools/mtmd/mtmd.h
@@ -86,6 +86,9 @@ struct mtmd_context_params {
     // limit number of image tokens, only for vision models with dynamic resolution
     int image_min_tokens; // minimum number of tokens for image input (default: read from metadata)
     int image_max_tokens; // maximum number of tokens for image input (default: read from metadata)
+
+    // DeepSeek-OCR resolution mode
+    const char * dsocr_mode; // one of: auto, tiny, small, base, large, gundam, gundam-master
 };

 MTMD_API const char * mtmd_default_marker(void);
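Two closing notes, both sketches rather than part of the patch:

Token budget: with the new clip_n_output_tokens() logic (and assuming n_patches enters the switch as (image_size / patch_size)^2), a base-mode 1024x1024 input with patch_size 16 gives 64*64 = 4096 patches, reduced 16x by the two stride-2 convolutions to 256, so h = 16 and the result is 16*(16+1) + 1 = 273 tokens; the same formula gives 73, 111 and 421 tokens for tiny, small and large.

Library users can select the mode through the mtmd C API; the projector file name below is a placeholder:

    mtmd_context_params mparams = mtmd_context_params_default();
    mparams.dsocr_mode = "base"; // any of: auto, tiny, small, base, large, gundam, gundam-master
    mtmd_context * ctx_vision = mtmd_init_from_file("mmproj-deepseek-ocr.gguf", model, mparams);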