quantrpeter/run-raw-model-using-raw-method-java


run-raw-model-using-raw-method-java

A from-scratch CPU implementation of the DeepSeek-R1-Distill-Qwen-1.5B (Qwen2 architecture) transformer in plain Java, with no third-party inference library in the hot path. We don't pull in llama.cpp, ggml, DJL, ONNX Runtime, ND4J, TornadoVM, Vector API helpers, or anything similar — just the JDK (java.nio for the file mmap, java.util.concurrent.ForkJoinPool for the matmul outer loop).

Companion projects implement the same idea in other languages.

The forward pass, tokenizer and GGUF parser here all match the C implementation in behavior; running both with the same prompt and --temp 0 produces token-for-token identical output (verified on DeepSeek-R1-Distill-Qwen-1.5B, F16).

prompt
   │
   ▼
[Tokenizer.java]  byte-level BPE  (regex pre-tokenize → merge greedily)
   │   ids
   ▼
[Qwen2.java]   for each token:
     embed →┐
            ▼
   ┌──────────────────────────────────────────────┐
   │  for L in 0..27:                             │
   │    rmsnorm                                    │
   │    Q,K,V proj  (matmul + bias)                │
   │    rope (rotate-halves, theta=10 000)         │
   │    GQA attention (12 q-heads, 2 kv-heads)     │
   │    softmax(Q·K^T / √d)                        │
   │    output proj + residual                     │
   │    rmsnorm                                    │
   │    SwiGLU FFN (silu(gate) × up → down)        │
   │    residual                                   │
   └──────────────────────────────────────────────┘
            │
   final rmsnorm + LM head (matmul over n_vocab)
            │  logits
            ▼
[Qwen2.java]   sample (greedy / temp + top-k + top-p)
            │  next id
            ▼
[Tokenizer.java]  detokenize → bytes → stdout

Files

.
├── Makefile          # mvn package / java -jar wrappers
├── pom.xml           # Maven build (no runtime dependencies)
├── README.md
└── src/main/java/io/github/runrawqwen2/
    ├── Gguf.java        # mmap a .gguf, parse header / KVs / tensor table
    ├── Tokenizer.java   # GPT-2/Qwen2 byte-level BPE encode + decode
    ├── Qwen2.java       # Qwen2 model loader + forward pass + sampler
    └── Main.java        # CLI driver (parse args → load → prefill → generate)

Prerequisites

  • JDK 17 or newer. Tested on OpenJDK 21 and 25 on macOS aarch64.
  • Maven 3.6+. brew install maven on macOS.
  • A GGUF file produced from the original Hugging Face checkpoint with --outtype f16 or f32. We don't implement Q*-style quantization (Q4_K, Q6_K, …) so a quantized GGUF will be rejected with a clear error message.
  • ~4 GB of free RAM for the F16 weights + KV cache.

The simplest source of a usable GGUF is the sibling project, which has a make convert target:

cd ../run-raw-model
make convert OUTTYPE=f16    # produces ../run-raw-model/R1-Distill-1.5B.gguf

The Makefile here defaults to that exact path.

Quick start

make build                                     # mvn package
make run                                       # java -jar ... defaults
make run PROMPT="Why is the sky blue?" TEMP=0  # greedy sampling

Or call the jar directly:

java -Xmx4g -jar target/run-raw-qwen2-0.1.0.jar \
    -m ../run-raw-model/R1-Distill-1.5B.gguf \
    -p "The capital of France is" \
    -n 128 \
    --temp 0

Each run prints model metadata to stderr, then the prompt and generated text to stdout, and finally a tok/s line.

CLI reference

Usage: java -jar run-raw-qwen2.jar -m <model.gguf> [options]
  -m   <path>     path to GGUF model (required)
  -p   <text>     prompt (default: "Hello, my name is")
  -n   <int>      number of tokens to generate (default: 128)
  -c   <int>      context size (default: 2048)
  --temp <float>  temperature, <=0 = greedy (default: 0.8)
  --top-k <int>   top-k (default: 40)
  --top-p <float> top-p (default: 0.95)
  --seed <long>   RNG seed (default: time-based)
  --chat          wrap prompt with the Qwen2 chat template
                  (<|im_start|>user\n…<|im_end|>\n<|im_start|>assistant\n)

How it works (slightly more detail)

1. GGUF parsing (Gguf.java)

GGUF is a self-describing single-file format produced by convert_hf_to_gguf.py. We FileChannel.map() the header section once to walk the descriptor table, then map each tensor's data range as its own MappedByteBuffer. Tensor data is never copied; we hold lightweight typed views (ShortBuffer, FloatBuffer) into the mappings for the lifetime of the run.

The parser walks three sections:

  • Header: 4-byte magic "GGUF", version (must be 3), tensor count, KV count.
  • KV metadata: an array of (string key, type, value) triples. Strings, ints, floats, bools, plus arrays of any of those. Used both for hyperparameters (qwen2.embedding_length, qwen2.block_count, …) and for the tokenizer (tokenizer.ggml.tokens, tokenizer.ggml.merges, tokenizer.ggml.bos_token_id, …).
  • Tensor descriptors: name, ndims, dims, dtype, byte offset. Tensor data starts at the next general.alignment-aligned offset past the descriptor table (default alignment = 32 bytes).
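As an illustration of the header step, here is a minimal sketch that reads the four fixed header fields from a little-endian buffer and computes the aligned start of the tensor data. The class and method names are hypothetical and not taken from Gguf.java; the field widths assume the GGUF v3 layout (32-bit version, 64-bit counts).

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;

/** Hypothetical sketch: the fixed-size part of a GGUF v3 header. */
final class GgufHeaderSketch {
    final int version;
    final long tensorCount;
    final long kvCount;

    private GgufHeaderSketch(int version, long tensorCount, long kvCount) {
        this.version = version;
        this.tensorCount = tensorCount;
        this.kvCount = kvCount;
    }

    /** Reads magic, version, tensor count and KV count starting at buf's position. */
    static GgufHeaderSketch parse(ByteBuffer buf) {
        // duplicate() resets the byte order, so set little-endian explicitly.
        ByteBuffer b = buf.duplicate().order(ByteOrder.LITTLE_ENDIAN);
        byte[] magic = new byte[4];
        b.get(magic);
        if (!"GGUF".equals(new String(magic, StandardCharsets.US_ASCII))) {
            throw new IllegalArgumentException("not a GGUF file");
        }
        int version = b.getInt();
        if (version != 3) {
            throw new IllegalArgumentException("unsupported GGUF version " + version);
        }
        long tensorCount = b.getLong();
        long kvCount = b.getLong();
        return new GgufHeaderSketch(version, tensorCount, kvCount);
    }

    /** Rounds offset up to the next multiple of alignment (32 by default). */
    static long alignUp(long offset, long alignment) {
        return (offset + alignment - 1) / alignment * alignment;
    }
}
```

A real parser would continue from here into the KV array and the tensor descriptors, then apply alignUp to the position just past the last descriptor.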

We support F32, F16 and BF16 as tensor dtypes — anything else is rejected at load time with an error.

A single MappedByteBuffer in Java tops out at 2 GiB (the position / limit / capacity API is int-typed), so we map each tensor individually rather than the whole 3+ GiB file at once. For Qwen2-1.5B the largest single tensor (token_embd.weight) is ~445 MiB in F16 (151 936 × 1 536 × 2 bytes), well within the limit.
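Per-tensor mapping can be sketched as follows. This is an assumption-laden illustration, not the code in Gguf.java: the helper name and its parameters (dataStart for the aligned start of the data section, tensorOffset for the descriptor's byte offset) are hypothetical.

```java
import java.io.IOException;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;

/** Hypothetical sketch: map one tensor's byte range read-only. */
final class TensorMapSketch {
    static MappedByteBuffer mapTensor(Path file, long dataStart, long tensorOffset,
                                      long byteLength) throws IOException {
        if (byteLength > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("tensor exceeds the 2 GiB mapping limit");
        }
        try (FileChannel ch = FileChannel.open(file, StandardOpenOption.READ)) {
            // The mapping stays valid after the channel is closed.
            MappedByteBuffer m = ch.map(FileChannel.MapMode.READ_ONLY,
                                        dataStart + tensorOffset, byteLength);
            m.order(ByteOrder.LITTLE_ENDIAN); // GGUF data is little-endian
            return m;
        }
    }
}
```

Typed views such as m.asShortBuffer() or m.asFloatBuffer() can then be taken over the mapping without copying the data.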

2. Tokenizer (Tokenizer.java)

Qwen2 uses GPT-2-family byte-level BPE. Encoding has three stages:

  1. Byte-to-Unicode remap. Each input byte 0..255 becomes a single printable Unicode codepoint. Printable bytes (ASCII '!'..'~' and most of the Latin-1 range) map to themselves; the remaining 68 bytes, including space and the control characters, are shifted to U+0100..U+0143. This keeps BPE merges working purely in printable space.

  2. Pre-tokenization. Split the remapped text into chunks using the GPT-4 regex (the same one Qwen-Tokenizer ships with):

    (?i:'s|'t|'re|'ve|'m|'ll|'d) | [^\r\n\p{L}\p{N}]?\p{L}+
    | \p{N}{1,3} | ?[^\s\p{L}\p{N}]+[\r\n]* | \s*[\r\n]+
    | \s+(?!\S) | \s+
    

    We re-implement this by hand (no java.util.regex) by classifying each codepoint as letter / digit / space / newline / other and greedily matching alternatives top-to-bottom.

  3. BPE merges. Inside each chunk we keep a doubly-linked list of "symbols" (initially one per codepoint), repeatedly find the adjacent pair with the lowest merge rank, and combine it. Final symbols are looked up in the vocab to get token ids.
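The merge loop in stage 3 can be sketched as below. For brevity this version keeps the symbols in an ArrayList and rescans it on every step instead of using the doubly-linked list described above, so it is slower but follows the same lowest-rank-first rule. The class name is hypothetical, and the ranks map is assumed to be keyed by "left right" strings with a single space between the halves.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch: greedy lowest-rank-first BPE merging of one chunk. */
final class BpeMergeSketch {
    static List<String> merge(String chunk, Map<String, Integer> ranks) {
        // Start with one symbol per codepoint.
        List<String> syms = new ArrayList<>();
        chunk.codePoints().forEach(cp -> syms.add(new String(Character.toChars(cp))));

        while (syms.size() > 1) {
            int best = -1;
            int bestRank = Integer.MAX_VALUE;
            // Find the adjacent pair with the lowest merge rank.
            for (int i = 0; i + 1 < syms.size(); i++) {
                Integer r = ranks.get(syms.get(i) + " " + syms.get(i + 1));
                if (r != null && r < bestRank) {
                    bestRank = r;
                    best = i;
                }
            }
            if (best < 0) break; // no mergeable pair left
            syms.set(best, syms.get(best) + syms.get(best + 1));
            syms.remove(best + 1);
        }
        return syms;
    }
}
```

Each returned symbol would then be looked up in the vocab to obtain its token id.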

Special / control tokens (<|im_start|>, <|endoftext|>, …) are detected by a longest-prefix scan over the special-token byte sequences before falling back into BPE, and are emitted verbatim during decoding instead of being byte-decoded.
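Byte-decoding on the way out inverts the stage 1 remap, so both directions hinge on the same 256-entry table. The sketch below builds that table following the layout of GPT-2's bytes_to_unicode; the class and method names are hypothetical, and it is assumed (not confirmed by this README) that Tokenizer.java uses the same assignment order.

```java
/** Hypothetical sketch: GPT-2 style byte-to-Unicode table. */
final class ByteUnicodeSketch {
    /** table[b] is the codepoint that stands in for byte value b. */
    static int[] byteToUnicode() {
        int[] table = new int[256];
        int next = 0; // number of non-printable bytes assigned so far
        for (int b = 0; b < 256; b++) {
            boolean printable = (b >= '!' && b <= '~')
                    || (b >= 0xA1 && b <= 0xAC)
                    || (b >= 0xAE && b <= 0xFF);
            // Printable bytes keep their value; the rest get 256, 257, ...
            table[b] = printable ? b : 256 + next++;
        }
        return table;
    }

    /** Remaps raw UTF-8 bytes into the printable alphabet used by the merges. */
    static String remap(byte[] utf8) {
        int[] table = byteToUnicode();
        StringBuilder sb = new StringBuilder();
        for (byte x : utf8) {
            sb.appendCodePoint(table[x & 0xFF]);
        }
        return sb.toString();
    }
}
```

Under this layout a space becomes U+0120 (Ġ) and a newline becomes U+010A (Ċ), which is why those characters show up in GPT-2-family vocab files.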

3. Forward pass (Qwen2.java)

For each generated token we:

  1. Embed by indexing row id of token_embd.weight (shape [n_embd × n_vocab]).
  2. For each of the 28 transformer blocks:
    • RMSNorm: y = x * w / sqrt(mean(x²) + ε), ε = 1e-6.
    • Q/K/V projections (matmul + bias). Qwen2 has biases on Q,K,V and not on O. Q has 12 heads × 128 dim = 1536. K and V have 2 heads × 128 = 256 (grouped-query attention, group size 6). K and V are written straight into the per-layer KV cache slot for the current position.
    • RoPE on Q and the freshly-computed K column. NeoX-style rotated halves with θ = 10 000.
    • Attention: for each query head h, find the matching KV head h / 6, compute softmax(Q·Kᵀ / √128) over all positions 0..pos, then weight the V cache by those scores.
    • Output projection + residual into x.
    • RMSNorm again.
    • SwiGLU FFN: down( silu(gate(x)) * up(x) ), intermediate size 8 960. Residual back into x.
  3. Final RMSNorm.
  4. LM head: logits = x · output.weight (or token_embd.weight if embeddings are tied; for this checkpoint they're not).
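Three of the per-block operations above are small enough to sketch directly. The code below is illustrative only: the names are hypothetical, it allocates fresh arrays for clarity, and it assumes the NeoX convention in which element i of a head is rotated together with element i + head_dim/2.

```java
/** Hypothetical sketch of RMSNorm, NeoX-style RoPE and the SwiGLU gate. */
final class BlockOpsSketch {
    /** y[i] = x[i] * w[i] / sqrt(mean(x^2) + eps) */
    static float[] rmsNorm(float[] x, float[] w, float eps) {
        double ss = 0;
        for (float v : x) ss += (double) v * v;
        float scale = (float) (1.0 / Math.sqrt(ss / x.length + eps));
        float[] y = new float[x.length];
        for (int i = 0; i < x.length; i++) y[i] = x[i] * scale * w[i];
        return y;
    }

    /** Rotates one head in place; pair (i, i + d/2) turns by pos * theta^(-2i/d). */
    static void ropeNeox(float[] head, int pos, double theta) {
        int half = head.length / 2;
        for (int i = 0; i < half; i++) {
            double angle = pos * Math.pow(theta, -2.0 * i / head.length);
            float c = (float) Math.cos(angle);
            float s = (float) Math.sin(angle);
            float a = head[i];
            float b = head[i + half];
            head[i] = a * c - b * s;
            head[i + half] = a * s + b * c;
        }
    }

    /** Elementwise silu(gate) * up; the down projection would follow. */
    static float[] swigluHidden(float[] gate, float[] up) {
        float[] h = new float[gate.length];
        for (int i = 0; i < h.length; i++) {
            float g = gate[i];
            float silu = g / (1f + (float) Math.exp(-g));
            h[i] = silu * up[i];
        }
        return h;
    }
}
```

For this model the calls would use eps = 1e-6f, theta = 10000.0 and 128-element heads; query head h reads the KV head at index h / 6.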

All matmuls dequantize F16/BF16 weights to F32 on the fly (a 65 536-entry lookup table for F16, a single shift for BF16). The matmul outer loop runs as ForkJoinPool.submit(() -> IntStream.range(...).parallel().forEach(...)), so on a multi-core CPU throughput scales close to linearly with cores until you hit memory bandwidth.
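A minimal sketch of the dequantize-on-the-fly matrix-vector product follows. It is not the code in Qwen2.java: the class name, the row-major short[] weight layout and the method signatures are assumptions made for illustration, and the conversion is written by hand so that it runs on JDK 17.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.stream.IntStream;

/** Hypothetical sketch: F16/BF16 dequantization and a row-parallel matvec. */
final class DequantMatvecSketch {
    /** One float per possible 16-bit pattern, filled once at class load. */
    static final float[] F16_TABLE = new float[1 << 16];
    static {
        for (int i = 0; i < F16_TABLE.length; i++) F16_TABLE[i] = halfToFloat(i);
    }

    /** Converts IEEE 754 half-precision bits (0..65535) to float. */
    static float halfToFloat(int bits) {
        int sign = (bits & 0x8000) << 16;
        int exp = (bits >>> 10) & 0x1F;
        int man = bits & 0x3FF;
        if (exp == 0x1F) { // infinity or NaN
            return Float.intBitsToFloat(sign | 0x7F800000 | (man << 13));
        }
        if (exp == 0) { // zero or subnormal: man * 2^-24
            float v = man * 0x1p-24f;
            return sign != 0 ? -v : v;
        }
        // Normal number: rebias the exponent from 15 to 127.
        return Float.intBitsToFloat(sign | ((exp + 112) << 23) | (man << 13));
    }

    /** BF16 is the top half of a float, so a single shift restores it. */
    static float bf16ToFloat(int bits) {
        return Float.intBitsToFloat((bits & 0xFFFF) << 16);
    }

    /** out[r] = sum over c of w[r * cols + c] * x[c], rows split across the pool. */
    static float[] matvecF16(short[] w, int rows, int cols, float[] x, ForkJoinPool pool) {
        float[] out = new float[rows];
        pool.submit(() -> IntStream.range(0, rows).parallel().forEach(r -> {
            float acc = 0f;
            int base = r * cols;
            for (int c = 0; c < cols; c++) {
                acc += F16_TABLE[w[base + c] & 0xFFFF] * x[c];
            }
            out[r] = acc;
        })).join();
        return out;
    }
}
```

Submitting the parallel stream from inside a pool task is what makes it run on that pool rather than the common pool; this is widely relied on but not a documented guarantee of the stream API.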

The KV cache holds one K and one V buffer, each [n_layer × n_ctx × n_kv_heads × head_dim] floats: 2 × 28 × 2048 × 2 × 128 × 4 bytes ≈ 117 MB at n_ctx = 2048.
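As a quick check of that figure, assuming separate K and V buffers of the stated shape stored as 4-byte floats (the class name is hypothetical):

```java
/** Hypothetical sketch: KV cache footprint in bytes. */
final class KvCacheSizeSketch {
    static long bytes(int nLayer, int nCtx, int nKvHeads, int headDim) {
        long floatsPerBuffer = (long) nLayer * nCtx * nKvHeads * headDim;
        return 2L * floatsPerBuffer * Float.BYTES; // one buffer for K, one for V
    }
}
```

With the model's values, bytes(28, 2048, 2, 128) gives 117 440 512 bytes, and the footprint grows linearly with the -c context size.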

4. Sampling (Qwen2.java)

temperature ≤ 0 is greedy argmax. Otherwise we apply temperature, softmax, sort by descending probability, truncate to top-k, then truncate further to the smallest prefix whose cumulative mass ≥ top-p, renormalize, and draw with a self-contained xorshift64 RNG.
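The pipeline can be sketched as follows. This is a simplified illustration rather than the sampler in Qwen2.java: the names are hypothetical, the xorshift64 shift constants (13, 7, 17) are an assumed common variant, and the sort uses boxed indices for readability.

```java
import java.util.Arrays;

/** Hypothetical sketch: temperature + top-k + top-p sampling. */
final class SamplerSketch {
    private long state;

    SamplerSketch(long seed) {
        state = seed != 0 ? seed : 0x9E3779B97F4A7C15L; // xorshift state must be non-zero
    }

    /** One xorshift64 step; returns a float in [0, 1) from the top 24 bits. */
    float nextFloat() {
        state ^= state << 13;
        state ^= state >>> 7;
        state ^= state << 17;
        return (state >>> 40) * (1.0f / (1 << 24));
    }

    int sample(float[] logits, float temp, int topK, float topP) {
        int n = logits.length;
        int argmax = 0;
        for (int i = 1; i < n; i++) if (logits[i] > logits[argmax]) argmax = i;
        if (temp <= 0f) return argmax; // greedy

        // Temperature-scaled softmax, shifted by the max for stability.
        float[] p = new float[n];
        double sum = 0;
        for (int i = 0; i < n; i++) {
            p[i] = (float) Math.exp((logits[i] - logits[argmax]) / temp);
            sum += p[i];
        }
        for (int i = 0; i < n; i++) p[i] /= sum;

        Integer[] ids = new Integer[n];
        for (int i = 0; i < n; i++) ids[i] = i;
        Arrays.sort(ids, (a, b) -> Float.compare(p[b], p[a])); // descending probability

        // Top-k, then the smallest prefix whose cumulative mass reaches top-p.
        int keep = Math.min(topK > 0 ? topK : n, n);
        float cum = 0f;
        for (int i = 0; i < keep; i++) {
            cum += p[ids[i]];
            if (cum >= topP) { keep = i + 1; break; }
        }

        // Scaling the draw by cum is equivalent to renormalizing the kept prefix.
        float r = nextFloat() * cum;
        float acc = 0f;
        for (int i = 0; i < keep; i++) {
            acc += p[ids[i]];
            if (r < acc) return ids[i];
        }
        return ids[keep - 1];
    }
}
```

Two samplers built with the same seed produce the same token sequence, which is the behavior the --seed flag exposes.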

The descending sort is done with a primitive Arrays.sort(long[]): each (probability, id) pair packs into one long (high 32 bits = -floatToRawIntBits(prob), low 32 bits = id), so an ordinary ascending long sort yields descending probability order with no autoboxing.
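The packing trick can be sketched on its own. The helper below is hypothetical and assumes every probability is non-negative and not NaN, which holds for softmax output; under that assumption the raw float bits order the same way as the values.

```java
import java.util.Arrays;

/** Hypothetical sketch: descending sort of (prob, id) pairs via packed longs. */
final class PackedSortSketch {
    static int[] idsByDescendingProb(float[] probs) {
        long[] packed = new long[probs.length];
        for (int id = 0; id < probs.length; id++) {
            // Negated bits in the high word: larger probability means smaller key.
            long key = -(long) Float.floatToRawIntBits(probs[id]);
            packed[id] = (key << 32) | (id & 0xFFFFFFFFL);
        }
        Arrays.sort(packed); // primitive sort, no boxing
        int[] ids = new int[packed.length];
        for (int i = 0; i < ids.length; i++) {
            ids[i] = (int) packed[i]; // the low 32 bits carry the token id
        }
        return ids;
    }
}
```

Because the id sits in the low word, ties in probability come out in ascending id order, so the result is deterministic.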

Limitations vs llama.cpp

This project is intentionally minimalist:

  • No quantized weights. Only F32, F16, BF16 tensors are read. Re-convert with --outtype f16 if your GGUF has Q4_K/Q6_K etc.
  • CPU only. No Metal, CUDA, JNI BLAS or Vector API kernels. The parallelism comes from a single shared ForkJoinPool over the matmul output dimension. Expect roughly the same throughput as the C reference per core; real llama.cpp will still be much faster.
  • Single token at a time. Prefill walks the prompt one token per forward call. A batched matmul would be much faster but obscures the algorithm.
  • Qwen2 only. The forward pass is hard-coded for the Qwen2 architecture (RoPE, SwiGLU, RMSNorm, GQA, biased Q/K/V). Other architectures (LLaMA, Mistral, Phi, Gemma, …) are similar but each needs its own analogue.
  • Approximate pre-tokenizer. We treat every non-ASCII codepoint as \p{L}. For prompts written in non-Latin scripts (CJK, Arabic, Cyrillic, …) tokenization should still match the reference; it diverges only on edge cases involving non-ASCII punctuation or digits, which are rare in practice.
  • 2 GiB per-tensor cap. Java's MappedByteBuffer indexes with an int, so any single tensor must fit in 2 GiB. Fine for Qwen2-1.5B, but a 70B-class checkpoint with a 4-billion-element output projection would need to be split.

License

Public domain / 0-BSD — do whatever you want with the code. The model weights remain under their original Hugging Face license.
