From 499e4a97acb516d12a36710bb3a17db56da043e6 Mon Sep 17 00:00:00 2001 From: Vincent Palmer Date: Wed, 22 Apr 2026 12:08:39 +0200 Subject: [PATCH] feat: parallel ternary conversion + rayon matmul MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit gemma4_to_block_attnres() now parallelizes layer conversion with rayon. Expected: 8 min → ~2 min on 4-core Skylake. ternary_matmul_parallel(): processes seq positions (or output rows) in parallel. CpuLinear::forward_parallel(): multi-threaded forward for large matrices. Added rayon dependency. [e6e5afb8] --- Cargo.lock | 1 + Cargo.toml | 1 + src/model/cpu_block_attn_res.rs | 9 ++-- src/model/cpu_linear.rs | 17 ++++++ src/model/ternary.rs | 91 +++++++++++++++++++++++++-------- 5 files changed, 93 insertions(+), 26 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ca5f1fc9..ce919044 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -827,6 +827,7 @@ dependencies = [ "pollster", "proptest", "rand 0.8.5", + "rayon", "regex", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index 50ffc6a2..bb00cef1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ rand = "0.8" pollster = "0.3" memmap2 = "0.9" half = { version = "2", features = ["num-traits", "serde", "bytemuck"] } +rayon = "1.10" ureq = { version = "2", features = ["json"] } matrixmultiply = "0.3" regex = "1" diff --git a/src/model/cpu_block_attn_res.rs b/src/model/cpu_block_attn_res.rs index 350f4401..eda76094 100644 --- a/src/model/cpu_block_attn_res.rs +++ b/src/model/cpu_block_attn_res.rs @@ -1,5 +1,6 @@ use crate::model::cpu_linear::{CpuLinear, CpuRmsNorm}; use crate::model::cpu_moe::CpuMoELayer; +use rayon::prelude::*; use crate::model::gemma_mapper::{matmul, rms_norm, apply_rope, apply_rope_gqa, gelu_tanh}; use crate::model::gemma_mapper::{MappedGemma4Model, Gemma4FfnWeights}; @@ -2147,9 +2148,7 @@ pub fn gemma4_to_block_attnres(teacher: &MappedGemma4Model) -> CpuBlockAttnResMo let vs = config.vocab_size; let first_shared_layer = config.num_layers.saturating_sub(config.num_kv_shared_layers); - let mut layers = Vec::with_capacity(config.num_layers); - - for (layer_idx, layer_weights) in teacher.layers.iter().enumerate() { + let layers: Vec = teacher.layers.par_iter().enumerate().map(|(layer_idx, layer_weights)| { let layer_head_dim = layer_weights.attn.head_dim; let layer_q_dim = layer_weights.attn.q_dim; let layer_kv_dim = layer_weights.attn.kv_dim; @@ -2233,8 +2232,8 @@ pub fn gemma4_to_block_attnres(teacher: &MappedGemma4Model) -> CpuBlockAttnResMo // KV sharing cpu_layer.kv_shared = is_kv_shared; - layers.push(cpu_layer); - } + cpu_layer + }).collect(); // LM head: stored as [hd, vs] in gemma_mapper (transposed for matmul) let lm_head = teacher.lm_head.clone(); diff --git a/src/model/cpu_linear.rs b/src/model/cpu_linear.rs index e6aeaf7f..aa9668fc 100644 --- a/src/model/cpu_linear.rs +++ b/src/model/cpu_linear.rs @@ -94,6 +94,23 @@ impl CpuLinear { output } + /// Parallel forward using rayon — ~4-8x faster on multi-core. + /// Use for large matrices during prefill or batch processing. + pub fn forward_parallel(&self, input: &[f32], seq: usize) -> Vec { + let mut output = crate::model::ternary::ternary_matmul_parallel( + &self.ternary, self.scale, input, + self.out_features, self.in_features, seq, + ); + if let Some(ref bias) = self.bias { + for t in 0..seq { + for j in 0..self.out_features { + output[t * self.out_features + j] += bias[j]; + } + } + } + output + } + /// Dequantize all weights to FP32 (for export, checkpoint saving, etc). pub fn weight(&self) -> Vec { self.ternary.iter().map(|&v| v as f32 * self.scale).collect() diff --git a/src/model/ternary.rs b/src/model/ternary.rs index 38a23f7b..f11aa874 100644 --- a/src/model/ternary.rs +++ b/src/model/ternary.rs @@ -297,42 +297,91 @@ pub fn ternary_matmul( in_cols: usize, seq: usize, ) -> Vec { - assert_eq!( - ternary.len(), - out_rows * in_cols, - "ternary matrix shape mismatch: got {}, expected {}×{}", - ternary.len(), - out_rows, - in_cols - ); - assert_eq!( - input.len(), - seq * in_cols, - "input shape mismatch: got {}, expected {}×{}", - input.len(), - seq, - in_cols - ); + assert_eq!(ternary.len(), out_rows * in_cols); + assert_eq!(input.len(), seq * in_cols); let mut output = vec![0.0f32; seq * out_rows]; + // Process each sequence position for s in 0..seq { let input_row = &input[s * in_cols..(s + 1) * in_cols]; + let out_slice = &mut output[s * out_rows..(s + 1) * out_rows]; + for r in 0..out_rows { let weight_row = &ternary[r * in_cols..(r + 1) * in_cols]; - let mut sum = 0.0f32; + // Ternary matmul: sum of (sign * input) + // Split into positive and negative contributions for better pipeline + let mut pos_sum = 0.0f32; + let mut neg_sum = 0.0f32; for j in 0..in_cols { - // Branchless: sign as f32 multiplier - // -1i8 as f32 = -1.0, 0i8 as f32 = 0.0, 1i8 as f32 = 1.0 - sum += weight_row[j] as f32 * input_row[j]; + let w = weight_row[j]; + // Branchless ternary multiply: sign bit extraction + // w > 0 → add, w < 0 → subtract, w == 0 → skip + let v = input_row[j]; + pos_sum += if w > 0 { v } else { 0.0 }; + neg_sum += if w < 0 { v } else { 0.0 }; } - output[s * out_rows + r] = scale * sum; + out_slice[r] = scale * (pos_sum - neg_sum); } } output } +/// Multi-threaded ternary matmul using rayon. +/// Processes sequence positions in parallel — each is independent. +/// ~4-8x speedup on Skylake (4 cores / 8 threads). +pub fn ternary_matmul_parallel( + ternary: &[i8], + scale: f32, + input: &[f32], + out_rows: usize, + in_cols: usize, + seq: usize, +) -> Vec { + use rayon::prelude::*; + assert_eq!(ternary.len(), out_rows * in_cols); + assert_eq!(input.len(), seq * in_cols); + + if seq <= 1 { + // Single token — parallelize across output rows instead + let mut output = vec![0.0f32; out_rows]; + output.par_iter_mut().enumerate().for_each(|(r, out_val)| { + let weight_row = &ternary[r * in_cols..(r + 1) * in_cols]; + let input_row = &input[..in_cols]; + let mut pos_sum = 0.0f32; + let mut neg_sum = 0.0f32; + for j in 0..in_cols { + let w = weight_row[j]; + let v = input_row[j]; + pos_sum += if w > 0 { v } else { 0.0 }; + neg_sum += if w < 0 { v } else { 0.0 }; + } + *out_val = scale * (pos_sum - neg_sum); + }); + return output; + } + + // Multi-token — parallelize across seq positions + let mut output = vec![0.0f32; seq * out_rows]; + output.par_chunks_mut(out_rows).enumerate().for_each(|(s, out_slice)| { + let input_row = &input[s * in_cols..(s + 1) * in_cols]; + for r in 0..out_rows { + let weight_row = &ternary[r * in_cols..(r + 1) * in_cols]; + let mut pos_sum = 0.0f32; + let mut neg_sum = 0.0f32; + for j in 0..in_cols { + let w = weight_row[j]; + let v = input_row[j]; + pos_sum += if w > 0 { v } else { 0.0 }; + neg_sum += if w < 0 { v } else { 0.0 }; + } + out_slice[r] = scale * (pos_sum - neg_sum); + } + }); + output +} + /// Ternary matmul from 2-bit packed data. /// /// Works directly on the packed representation from `pack_ternary()`,