Skip to content

Conversation

@bluebread
Copy link

@bluebread bluebread commented Nov 15, 2025

Make sure to read the contributing guidelines before submitting a PR

Implemented DeepSeek3B-MoE-A570M (the LM component of DeepSeek-OCR) but haven't tested it through.

Todo

  • Core LM implementation
  • Verify that the configuration and architecture match the original code.
  • Testing with llama-cli (without vision projector)

@bluebread bluebread changed the base branch from master to sf/deepseek-ocr November 15, 2025 17:46
@bluebread bluebread marked this pull request as ready for review November 16, 2025 08:46
@bluebread
Copy link
Author

@sfallah I've got DeepSeek3B-MoE-A570M running with llama-cli. It generates responses, but sometimes it just outputs nonsense. This doesn't happen with the original model on text-only prompt. Something is still off though I have double-checked the configuration and architecture. Probably because of the tokenizer.

(deepseek-ocr) root@13ca65024005:~/llama.cpp# ./bin/llama-cli -m ~/DeepSeek-OCR/DeepSeek-OCR-64x550M-F16.gguf  -p "Convert the document to markdown. "  > log.txt 2>&1
register_backend: registered backend CPU (1 devices)
register_device: registered device CPU (AMD Ryzen 9 9950X 16-Core Processor)
load_backend: failed to find ggml_backend_init in /root/llama.cpp/build/bin/libggml-cpu.so
build: 7001 (eab28ed3) with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu (debug)
main: llama backend init
main: load the model and apply lora adapter, if any
llama_model_loader: loaded meta data with 40 key-value pairs and 155 tensors from /root/DeepSeek-OCR/DeepSeek-OCR-64x550M-F16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = deepseek2
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = DeepSeek OCR
llama_model_loader: - kv   3:                       general.organization str              = Deepseek Ai
llama_model_loader: - kv   4:                         general.size_label str              = 64x550M
llama_model_loader: - kv   5:                            general.license str              = mit
llama_model_loader: - kv   6:                               general.tags arr[str,5]       = ["deepseek", "vision-language", "ocr"...
llama_model_loader: - kv   7:                          general.languages arr[str,1]       = ["multilingual"]
llama_model_loader: - kv   8:                      deepseek2.block_count u32              = 12
llama_model_loader: - kv   9:                   deepseek2.context_length u32              = 8192
llama_model_loader: - kv  10:                 deepseek2.embedding_length u32              = 1280
llama_model_loader: - kv  11:              deepseek2.feed_forward_length u32              = 6848
llama_model_loader: - kv  12:             deepseek2.attention.head_count u32              = 10
llama_model_loader: - kv  13:          deepseek2.attention.head_count_kv u32              = 10
llama_model_loader: - kv  14:                   deepseek2.rope.freq_base f32              = 10000.000000
llama_model_loader: - kv  15: deepseek2.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  16:                deepseek2.expert_used_count u32              = 6
llama_model_loader: - kv  17:               deepseek2.expert_group_count u32              = 1
llama_model_loader: - kv  18:          deepseek2.expert_group_used_count u32              = 1
llama_model_loader: - kv  19:                          general.file_type u32              = 1
llama_model_loader: - kv  20:        deepseek2.leading_dense_block_count u32              = 1
llama_model_loader: - kv  21:                       deepseek2.vocab_size u32              = 129280
llama_model_loader: - kv  22:       deepseek2.expert_feed_forward_length u32              = 896
llama_model_loader: - kv  23:                     deepseek2.expert_count u32              = 64
llama_model_loader: - kv  24:              deepseek2.expert_shared_count u32              = 2
llama_model_loader: - kv  25:             deepseek2.expert_weights_scale f32              = 1.000000
llama_model_loader: - kv  26:              deepseek2.expert_weights_norm bool             = false
llama_model_loader: - kv  27:               deepseek2.expert_gating_func u32              = 1
llama_model_loader: - kv  28:               general.quantization_version u32              = 2
llama_model_loader: - kv  29:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  30:                         tokenizer.ggml.pre str              = deepseek-v3
llama_model_loader: - kv  31:                      tokenizer.ggml.tokens arr[str,129280]  = ["<|begin▁of▁sentence|>", "<�...
llama_model_loader: - kv  32:                  tokenizer.ggml.token_type arr[i32,129280]  = [3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  33:                      tokenizer.ggml.merges arr[str,127741]  = ["Ġ t", "Ġ a", "i n", "Ġ Ġ", "h e...
llama_model_loader: - kv  34:                tokenizer.ggml.bos_token_id u32              = 0
llama_model_loader: - kv  35:                tokenizer.ggml.eos_token_id u32              = 1
llama_model_loader: - kv  36:            tokenizer.ggml.padding_token_id u32              = 2
llama_model_loader: - kv  37:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  38:               tokenizer.ggml.add_sep_token bool             = false
llama_model_loader: - kv  39:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - type  f32:   36 tensors
llama_model_loader: - type  f16:  119 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = F16
print_info: file size   = 5.47 GiB (16.01 BPW) 
load: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect
load: printing all EOG tokens:
load:   - 1 ('<|end▁of▁sentence|>')
load: special tokens cache size = 830
load: token to piece cache size = 0.8224 MB
print_info: arch             = deepseek2
print_info: vocab_only       = 0
print_info: n_ctx_train      = 8192
print_info: n_embd           = 1280
print_info: n_embd_inp       = 1280
print_info: n_layer          = 12
print_info: n_head           = 10
print_info: n_head_kv        = 10
print_info: n_rot            = 128
print_info: n_swa            = 0
print_info: is_swa_any       = 0
print_info: n_embd_head_k    = 128
print_info: n_embd_head_v    = 128
print_info: n_gqa            = 1
print_info: n_embd_k_gqa     = 1280
print_info: n_embd_v_gqa     = 1280
print_info: f_norm_eps       = 0.0e+00
print_info: f_norm_rms_eps   = 1.0e-06
print_info: f_clamp_kqv      = 0.0e+00
print_info: f_max_alibi_bias = 0.0e+00
print_info: f_logit_scale    = 0.0e+00
print_info: f_attn_scale     = 0.0e+00
print_info: n_ff             = 6848
print_info: n_expert         = 64
print_info: n_expert_used    = 6
print_info: n_expert_groups  = 1
print_info: n_group_used     = 1
print_info: causal attn      = 1
print_info: pooling type     = 0
print_info: rope type        = 0
print_info: rope scaling     = linear
print_info: freq_base_train  = 10000.0
print_info: freq_scale_train = 1
print_info: n_ctx_orig_yarn  = 8192
print_info: rope_finetuned   = unknown
print_info: model type       = 3B
print_info: model params     = 2.93 B
print_info: general.name     = DeepSeek OCR
print_info: n_layer_dense_lead   = 1
print_info: n_lora_q             = 0
print_info: n_lora_kv            = 0
print_info: n_embd_head_k_mla    = 0
print_info: n_embd_head_v_mla    = 0
print_info: n_ff_exp             = 896
print_info: n_expert_shared      = 2
print_info: expert_weights_scale = 1.0
print_info: expert_weights_norm  = 0
print_info: expert_gating_func   = softmax
print_info: rope_yarn_log_mul    = 0.0000
print_info: vocab type       = BPE
print_info: n_vocab          = 129280
print_info: n_merges         = 127741
print_info: BOS token        = 0 '<|begin▁of▁sentence|>'
print_info: EOS token        = 1 '<|end▁of▁sentence|>'
print_info: EOT token        = 1 '<|end▁of▁sentence|>'
print_info: PAD token        = 2 '<|▁pad▁|>'
print_info: LF token         = 201 'Ċ'
print_info: FIM PRE token    = 128801 '<|fim▁begin|>'
print_info: FIM SUF token    = 128800 '<|fim▁hole|>'
print_info: FIM MID token    = 128802 '<|fim▁end|>'
print_info: EOG token        = 1 '<|end▁of▁sentence|>'
print_info: max token length = 256
load_tensors: loading model tensors, this can take a while... (mmap = true)
load_tensors:   CPU_Mapped model buffer size =  5599.34 MiB
...........................................
llama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 4096
llama_context: n_ctx_seq     = 4096
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 512
llama_context: causal_attn   = 1
llama_context: flash_attn    = auto
llama_context: kv_unified    = false
llama_context: freq_base     = 10000.0
llama_context: freq_scale    = 1
llama_context: n_ctx_seq (4096) < n_ctx_train (8192) -- the full capacity of the model will not be utilized
llama_context:        CPU  output buffer size =     0.49 MiB
llama_kv_cache:        CPU KV buffer size =   240.00 MiB
llama_kv_cache: size =  240.00 MiB (  4096 cells,  12 layers,  1/1 seqs), K (f16):  120.00 MiB, V (f16):  120.00 MiB
llama_context: Flash Attention was auto, set to enabled
llama_context:        CPU compute buffer size =   257.50 MiB
llama_context: graph nodes  = 654
llama_context: graph splits = 1
common_init_from_params: added <|end▁of▁sentence|> logit bias = -inf
common_init_from_params: setting dry_penalty_last_n to ctx_size = 4096
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
main: llama threadpool init, n_threads = 16

system_info: n_threads = 16 (n_threads_batch = 16) / 32 | CPU : SSE3 = 1 | SSSE3 = 1 | AVX = 1 | AVX_VNNI = 1 | AVX2 = 1 | F16C = 1 | FMA = 1 | BMI2 = 1 | AVX512 = 1 | AVX512_VBMI = 1 | AVX512_VNNI = 1 | AVX512_BF16 = 1 | LLAMAFILE = 1 | OPENMP = 1 | REPACK = 1 | 

sampler seed: 3908055410
sampler params: 
	repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
	dry_multiplier = 0.000, dry_base = 1.750, dry_allowed_length = 2, dry_penalty_last_n = 4096
	top_k = 40, top_p = 0.950, min_p = 0.050, xtc_probability = 0.000, xtc_threshold = 0.100, typical_p = 1.000, top_n_sigma = -1.000, temp = 0.800
	mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampler chain: logits -> logit-bias -> penalties -> dry -> top-n-sigma -> top-k -> typical -> top-p -> min-p -> xtc -> temp-ext -> dist 
generate: n_ctx = 4096, n_batch = 2048, n_predict = -1, n_keep = 1

Convert the document to markdown. 2. The following question asks for a conversion to markdown. [end of text]


llama_perf_sampler_print:    sampling time =       4.10 ms /    23 runs   (    0.18 ms per token,  5605.65 tokens per second)
llama_perf_context_print:        load time =     890.72 ms
llama_perf_context_print: prompt eval time =      99.39 ms /     9 tokens (   11.04 ms per token,    90.55 tokens per second)
llama_perf_context_print:        eval time =     294.62 ms /    13 runs   (   22.66 ms per token,    44.12 tokens per second)
llama_perf_context_print:       total time =     407.44 ms /    22 tokens
llama_perf_context_print:    graphs reused =         12
llama_memory_breakdown_print: | memory breakdown [MiB] | total   free    self   model   context   compute    unaccounted |
llama_memory_breakdown_print: |   - Host               |                 6096 =  5599 +     240 +     257                |

@bluebread bluebread marked this pull request as draft November 16, 2025 09:27
@bluebread bluebread marked this pull request as ready for review November 17, 2025 08:44
@bluebread
Copy link
Author

@sfallah Fixed the bug. LM is ready now. Back to working on the vision model. Let me know if you need me to focus on any particular part.

@sfallah
Copy link
Owner

sfallah commented Nov 17, 2025

@bluebread
great!
I thought you could maybe focus on window_parttion and _unpartition, making it backend agnostic.

The original pytorch impl is actually simple:
https://huggingface.co/deepseek-ai/DeepSeek-OCR/blob/main/deepencoder.py#L850

And the same thing with get_rel_pos
https://huggingface.co/deepseek-ai/DeepSeek-OCR/blob/main/deepencoder.py#L899

FYI: I have fixed some issues that I will push before merging.

@sfallah sfallah merged commit b32bb5e into sfallah:sf/deepseek-ocr Nov 17, 2025
@bluebread bluebread changed the title [Draft] Implement DeepSeek3B-MoE-A570M (LM component) Implement DeepSeek3B-MoE-A570M (LM component) Nov 17, 2025
@bluebread
Copy link
Author

@sfallah I've implemented these operations in CUDA backend and opened a PR ggml-org#17383 to the main repository. You can get this feature from the op-dsocr-clean branch.

@sfallah
Copy link
Owner

sfallah commented Nov 19, 2025

@bluebread
cool, very good!
I was/am following a different path myself that I think is still worth pursuing.
window_partition and _unpartition can with some tricks be rewritten so they use only 4D tensors.
I got this idea inspired by something similar that is already implemented for Qwen-2.5-VL
ggml-org#12402
https://github.com/ggml-org/llama.cpp/blob/master/tools/mtmd/clip.cpp#L765

I am still investigating/experimenting with this.
The idea is to calculate the window indices for each input image and set it as input to the graph.
That way we can save the effort implementing the ops for all major backends, and I have a feeling, that the performance will be in same range like the backend ops.

@bluebread
Copy link
Author

@sfallah nice! good idea to work around. where are we at with the vision model? does it runs yet? fyi you can copy-paste the code from examples/eval-callback.cpp and set cb_eval parameter to verify the model runs as expected.

@sfallah
Copy link
Owner

sfallah commented Nov 20, 2025

@bluebread
sorry for belated reply.
The Visual Model doesn't work yet.
It is coded, and the model loading and warmup phase works, but the test fails because of a segfault error that I couldn't debug yet. There are some complications, on my side.
I have replaced the ggml_win_part, and _unpart and will focus on replacing ggml_get_rel_pos and ggml_add_rel_pos before I go on debugging, that will be more time efficient for me.

FYI: I have been using https://github.com/ggml-org/ggml/blob/master/examples/sam/sam.cpp for testing replacement of ggml_win_part . So I can test the SAM changes isolated this way.
Most of the code in my llama.cpp branch for SAM is also from this ggml example.

@sfallah
Copy link
Owner

sfallah commented Nov 21, 2025

@bluebread
I have implemented a first preliminary version of get_rel_pos and add_rel_pos functions that used standard ggml ops, without the cpu-only ones.
Can you please have a look at it.

https://github.com/sfallah/llama.cpp/blob/sf/deepseek-ocr/tools/mtmd/clip.cpp#L2473

The functions are working.
As before I tested them in the ggml sam example. I will push my ggml fork to github too.

But the clip.cpp still has an issue, so the latest commit still doesn't run.
Disclaimer: I am using ChatGPT to generate snippets that I than debug.
I will tidy up the code as soon I get visual model to run, today!

@bluebread
Copy link
Author

@sfallah No problem. I'll take a look.

@sfallah
Copy link
Owner

sfallah commented Nov 21, 2025

@bluebread
last status for today.
Image encoding works, i.e. it runs through and he embedding-output dims look good.
But the end-to-end test fails the decoding fails, i.e. the llama_decode or something before that fails.
I will fix that early tomorrow.
BTW: it runs on Metal (my mac) right now I will also test it on CUDA tomorrow.

@bluebread
Copy link
Author

bluebread commented Nov 22, 2025

@sfallah It looks like the image encoder is still failing? Am I missing anything? The .gguf files are generated by the current converting script. I would appreciate it if you could provide your test settings. Edit: I am testing on sfallah/sf/deepseek-ocr branch.

(deepseek-ocr) root@13ca65024005:~/llama.cpp# ./build/bin/llama-mtmd-cli -m /root/DeepSeek-OCR/DeepSeek-OCR-64x550M-F16.gguf --mmproj /root/DeepSeek-OCR/mmproj-DeepSeek-OCR-F16.gguf --image /root/DeepSeek-OCR-vLLM/treewisdom.png -p "Free OCR" --chat-template deepseek
ggml_cuda_init: failed to initialize CUDA: no CUDA-capable device is detected
load_backend: loaded CUDA backend from /root/llama.cpp/build/bin/libggml-cuda.so
load_backend: loaded CPU backend from /root/llama.cpp/build/bin/libggml-cpu.so
build: 7019 (1268dc3fd) with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu
llama_model_loader: loaded meta data with 40 key-value pairs and 155 tensors from /root/DeepSeek-OCR/DeepSeek-OCR-64x550M-F16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = deepseek2
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = DeepSeek OCR
llama_model_loader: - kv   3:                       general.organization str              = Deepseek Ai
llama_model_loader: - kv   4:                         general.size_label str              = 64x550M
llama_model_loader: - kv   5:                            general.license str              = mit
llama_model_loader: - kv   6:                               general.tags arr[str,5]       = ["deepseek", "vision-language", "ocr"...
llama_model_loader: - kv   7:                          general.languages arr[str,1]       = ["multilingual"]
llama_model_loader: - kv   8:                      deepseek2.block_count u32              = 12
llama_model_loader: - kv   9:                   deepseek2.context_length u32              = 8192
llama_model_loader: - kv  10:                 deepseek2.embedding_length u32              = 1280
llama_model_loader: - kv  11:              deepseek2.feed_forward_length u32              = 6848
llama_model_loader: - kv  12:             deepseek2.attention.head_count u32              = 10
llama_model_loader: - kv  13:          deepseek2.attention.head_count_kv u32              = 10
llama_model_loader: - kv  14:                   deepseek2.rope.freq_base f32              = 10000.000000
llama_model_loader: - kv  15: deepseek2.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  16:                deepseek2.expert_used_count u32              = 6
llama_model_loader: - kv  17:               deepseek2.expert_group_count u32              = 1
llama_model_loader: - kv  18:          deepseek2.expert_group_used_count u32              = 1
llama_model_loader: - kv  19:                          general.file_type u32              = 1
llama_model_loader: - kv  20:        deepseek2.leading_dense_block_count u32              = 1
llama_model_loader: - kv  21:                       deepseek2.vocab_size u32              = 129280
llama_model_loader: - kv  22:       deepseek2.expert_feed_forward_length u32              = 896
llama_model_loader: - kv  23:                     deepseek2.expert_count u32              = 64
llama_model_loader: - kv  24:              deepseek2.expert_shared_count u32              = 2
llama_model_loader: - kv  25:             deepseek2.expert_weights_scale f32              = 1.000000
llama_model_loader: - kv  26:              deepseek2.expert_weights_norm bool             = false
llama_model_loader: - kv  27:               deepseek2.expert_gating_func u32              = 1
llama_model_loader: - kv  28:               general.quantization_version u32              = 2
llama_model_loader: - kv  29:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  30:                         tokenizer.ggml.pre str              = deepseek-v3
llama_model_loader: - kv  31:                      tokenizer.ggml.tokens arr[str,129280]  = ["<|begin▁of▁sentence|>", "<�...
llama_model_loader: - kv  32:                  tokenizer.ggml.token_type arr[i32,129280]  = [3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  33:                      tokenizer.ggml.merges arr[str,127741]  = ["Ġ t", "Ġ a", "i n", "Ġ Ġ", "h e...
llama_model_loader: - kv  34:                tokenizer.ggml.bos_token_id u32              = 0
llama_model_loader: - kv  35:                tokenizer.ggml.eos_token_id u32              = 1
llama_model_loader: - kv  36:            tokenizer.ggml.padding_token_id u32              = 2
llama_model_loader: - kv  37:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  38:               tokenizer.ggml.add_sep_token bool             = false
llama_model_loader: - kv  39:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - type  f32:   36 tensors
llama_model_loader: - type  f16:  119 tensors
print_info: file format = GGUF V3 (latest)
print_info: file type   = F16
print_info: file size   = 5.47 GiB (16.01 BPW) 
load: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect
load: printing all EOG tokens:
load:   - 1 ('<|end▁of▁sentence|>')
load: special tokens cache size = 830
load: token to piece cache size = 0.8224 MB
print_info: arch             = deepseek2
print_info: vocab_only       = 0
print_info: n_ctx_train      = 8192
print_info: n_embd           = 1280
print_info: n_embd_inp       = 1280
print_info: n_layer          = 12
print_info: n_head           = 10
print_info: n_head_kv        = 10
print_info: n_rot            = 128
print_info: n_swa            = 0
print_info: is_swa_any       = 0
print_info: n_embd_head_k    = 128
print_info: n_embd_head_v    = 128
print_info: n_gqa            = 1
print_info: n_embd_k_gqa     = 1280
print_info: n_embd_v_gqa     = 1280
print_info: f_norm_eps       = 0.0e+00
print_info: f_norm_rms_eps   = 1.0e-06
print_info: f_clamp_kqv      = 0.0e+00
print_info: f_max_alibi_bias = 0.0e+00
print_info: f_logit_scale    = 0.0e+00
print_info: f_attn_scale     = 0.0e+00
print_info: n_ff             = 6848
print_info: n_expert         = 64
print_info: n_expert_used    = 6
print_info: n_expert_groups  = 1
print_info: n_group_used     = 1
print_info: causal attn      = 1
print_info: pooling type     = 0
print_info: rope type        = 0
print_info: rope scaling     = linear
print_info: freq_base_train  = 10000.0
print_info: freq_scale_train = 1
print_info: n_ctx_orig_yarn  = 8192
print_info: rope_finetuned   = unknown
print_info: model type       = 3B
print_info: model params     = 2.93 B
print_info: general.name     = DeepSeek OCR
print_info: n_layer_dense_lead   = 1
print_info: n_lora_q             = 0
print_info: n_lora_kv            = 0
print_info: n_embd_head_k_mla    = 0
print_info: n_embd_head_v_mla    = 0
print_info: n_ff_exp             = 896
print_info: n_expert_shared      = 2
print_info: expert_weights_scale = 1.0
print_info: expert_weights_norm  = 0
print_info: expert_gating_func   = softmax
print_info: rope_yarn_log_mul    = 0.0000
print_info: vocab type       = BPE
print_info: n_vocab          = 129280
print_info: n_merges         = 127741
print_info: BOS token        = 0 '<|begin▁of▁sentence|>'
print_info: EOS token        = 1 '<|end▁of▁sentence|>'
print_info: EOT token        = 1 '<|end▁of▁sentence|>'
print_info: PAD token        = 2 '<|▁pad▁|>'
print_info: LF token         = 201 'Ċ'
print_info: FIM PRE token    = 128801 '<|fim▁begin|>'
print_info: FIM SUF token    = 128800 '<|fim▁hole|>'
print_info: FIM MID token    = 128802 '<|fim▁end|>'
print_info: EOG token        = 1 '<|end▁of▁sentence|>'
print_info: max token length = 256
load_tensors: loading model tensors, this can take a while... (mmap = true)
load_tensors:   CPU_Mapped model buffer size =  5599.34 MiB
...........................................
llama_context: constructing llama_context
llama_context: n_seq_max     = 1
llama_context: n_ctx         = 4096
llama_context: n_ctx_seq     = 4096
llama_context: n_batch       = 2048
llama_context: n_ubatch      = 512
llama_context: causal_attn   = 1
llama_context: flash_attn    = auto
llama_context: kv_unified    = false
llama_context: freq_base     = 10000.0
llama_context: freq_scale    = 1
llama_context: n_ctx_seq (4096) < n_ctx_train (8192) -- the full capacity of the model will not be utilized
llama_context:        CPU  output buffer size =     0.49 MiB
llama_kv_cache:        CPU KV buffer size =   240.00 MiB
llama_kv_cache: size =  240.00 MiB (  4096 cells,  12 layers,  1/1 seqs), K (f16):  120.00 MiB, V (f16):  120.00 MiB
llama_context: Flash Attention was auto, set to enabled
llama_context:        CPU compute buffer size =   257.50 MiB
llama_context: graph nodes  = 654
llama_context: graph splits = 1
common_init_from_params: added <|end▁of▁sentence|> logit bias = -inf
common_init_from_params: setting dry_penalty_last_n to ctx_size = 4096
common_init_from_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
Failed to infer a tool call example (possible template bug)
mtmd_cli_context: chat template example:
You are a helpful assistant### Instruction:
Hello
### Response:
Hi there
<|EOT|>
### Instruction:
How are you?
### Response:

clip_model_loader: model name:   DeepSeek OCR
clip_model_loader: description:  
clip_model_loader: GGUF version: 3
clip_model_loader: alignment:    32
clip_model_loader: n_tensors:    572
clip_model_loader: n_kv:         26

clip_model_loader: has vision encoder
clip_ctx: CLIP using CPU backend
load_hparams: projector:          deepseekocr
load_hparams: n_embd:             1024
load_hparams: n_head:             16
load_hparams: n_ff:               64
load_hparams: n_layer:            24
load_hparams: ffn_op:             gelu
load_hparams: projection_dim:     1280

--- vision hparams ---
load_hparams: image_size:         1024
load_hparams: patch_size:         16
load_hparams: has_llava_proj:     0
load_hparams: minicpmv_version:   0
load_hparams: n_merge:            0
load_hparams: n_wa_pattern:       0

load_hparams: model size:         774.24 MiB
load_hparams: metadata size:      0.20 MiB
alloc_compute_meta: warmup with image size = 1024 x 1024
alloc_compute_meta:        CPU compute buffer size =   156.04 MiB
alloc_compute_meta: graph splits = 1, nodes = 836
warmup: flash attention is enabled
main: loading model: /root/DeepSeek-OCR/DeepSeek-OCR-64x550M-F16.gguf
encoding image slice...
clip_image_batch_encode: expected output 4098 tokens, got 4161
/root/llama.cpp/tools/mtmd/clip.cpp:5522: Invalid number of output tokens
[New LWP 125037]
[New LWP 125039]
[New LWP 125040]
[New LWP 125041]
[New LWP 125042]
[New LWP 125043]
[New LWP 125044]
[New LWP 125045]
[New LWP 125046]
[New LWP 125047]
[New LWP 125048]
[New LWP 125049]
[New LWP 125050]
[New LWP 125051]
[New LWP 125052]
[New LWP 125053]
[New LWP 125054]
[Thread debugging using libthread_db enabled]
Using host libthread_db library "/usr/lib/x86_64-linux-gnu/libthread_db.so.1".
0x00007f7902f2542f in __GI___wait4 (pid=125079, stat_loc=0x0, options=0, usage=0x0) at ../sysdeps/unix/sysv/linux/wait4.c:30
30      ../sysdeps/unix/sysv/linux/wait4.c: No such file or directory.
#0  0x00007f7902f2542f in __GI___wait4 (pid=125079, stat_loc=0x0, options=0, usage=0x0) at ../sysdeps/unix/sysv/linux/wait4.c:30
30      in ../sysdeps/unix/sysv/linux/wait4.c
#1  0x00007f79033b186b in ggml_print_backtrace () from /root/llama.cpp/build/bin/libggml-base.so
#2  0x00007f79033b1a02 in ggml_abort () from /root/llama.cpp/build/bin/libggml-base.so
#3  0x00007f7903755156 in clip_image_batch_encode(clip_ctx*, int, clip_image_f32_batch const*, float*) () from /root/llama.cpp/build/bin/libmtmd.so
#4  0x00007f790375576a in clip_image_encode(clip_ctx*, int, clip_image_f32*, float*) () from /root/llama.cpp/build/bin/libmtmd.so
#5  0x00007f79037422b4 in mtmd_encode () from /root/llama.cpp/build/bin/libmtmd.so
#6  0x00007f79037d02cc in mtmd_helper_eval_chunk_single () from /root/llama.cpp/build/bin/libmtmd.so
#7  0x00007f79037d0659 in mtmd_helper_eval_chunks () from /root/llama.cpp/build/bin/libmtmd.so
#8  0x00005598edd2d586 in eval_message(mtmd_cli_context&, common_chat_msg&) ()
#9  0x00005598edd2ad31 in main ()
[Inferior 1 (process 125036) detached]
Aborted (core dumped)
(deepseek-ocr) root@13ca65024005:~/llama.cpp#

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants