Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ rand = "0.8"
pollster = "0.3"
memmap2 = "0.9"
half = { version = "2", features = ["num-traits", "serde", "bytemuck"] }
rayon = "1.10"
ureq = { version = "2", features = ["json"] }
matrixmultiply = "0.3"
regex = "1"
Expand Down
9 changes: 4 additions & 5 deletions src/model/cpu_block_attn_res.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::model::cpu_linear::{CpuLinear, CpuRmsNorm};
use crate::model::cpu_moe::CpuMoELayer;
use rayon::prelude::*;
use crate::model::gemma_mapper::{matmul, rms_norm, apply_rope, apply_rope_gqa, gelu_tanh};
use crate::model::gemma_mapper::{MappedGemma4Model, Gemma4FfnWeights};

Expand Down Expand Up @@ -2147,9 +2148,7 @@ pub fn gemma4_to_block_attnres(teacher: &MappedGemma4Model) -> CpuBlockAttnResMo
let vs = config.vocab_size;
let first_shared_layer = config.num_layers.saturating_sub(config.num_kv_shared_layers);

let mut layers = Vec::with_capacity(config.num_layers);

for (layer_idx, layer_weights) in teacher.layers.iter().enumerate() {
let layers: Vec<CpuBlockAttnResLayer> = teacher.layers.par_iter().enumerate().map(|(layer_idx, layer_weights)| {
let layer_head_dim = layer_weights.attn.head_dim;
let layer_q_dim = layer_weights.attn.q_dim;
let layer_kv_dim = layer_weights.attn.kv_dim;
Expand Down Expand Up @@ -2233,8 +2232,8 @@ pub fn gemma4_to_block_attnres(teacher: &MappedGemma4Model) -> CpuBlockAttnResMo
// KV sharing
cpu_layer.kv_shared = is_kv_shared;

layers.push(cpu_layer);
}
cpu_layer
}).collect();

// LM head: stored as [hd, vs] in gemma_mapper (transposed for matmul)
let lm_head = teacher.lm_head.clone();
Expand Down
17 changes: 17 additions & 0 deletions src/model/cpu_linear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,23 @@ impl CpuLinear {
output
}

/// Rayon-parallel variant of `forward` — intended for large matrices
/// (prefill / batch processing) where the threading overhead pays off;
/// roughly 4-8x faster on typical multi-core desktops.
pub fn forward_parallel(&self, input: &[f32], seq: usize) -> Vec<f32> {
    // Multi-threaded ternary matmul produces the un-biased activations.
    let mut output = crate::model::ternary::ternary_matmul_parallel(
        &self.ternary,
        self.scale,
        input,
        self.out_features,
        self.in_features,
        seq,
    );
    // Add the bias to every sequence position, when one is configured.
    if let Some(bias) = self.bias.as_ref() {
        for row in output.chunks_exact_mut(self.out_features) {
            for (j, out_val) in row.iter_mut().enumerate() {
                *out_val += bias[j];
            }
        }
    }
    output
}

/// Dequantize all weights to FP32 (for export, checkpoint saving, etc).
pub fn weight(&self) -> Vec<f32> {
self.ternary.iter().map(|&v| v as f32 * self.scale).collect()
Expand Down
91 changes: 70 additions & 21 deletions src/model/ternary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,42 +297,91 @@ pub fn ternary_matmul(
in_cols: usize,
seq: usize,
) -> Vec<f32> {
assert_eq!(
ternary.len(),
out_rows * in_cols,
"ternary matrix shape mismatch: got {}, expected {}×{}",
ternary.len(),
out_rows,
in_cols
);
assert_eq!(
input.len(),
seq * in_cols,
"input shape mismatch: got {}, expected {}×{}",
input.len(),
seq,
in_cols
);
assert_eq!(ternary.len(), out_rows * in_cols);
assert_eq!(input.len(), seq * in_cols);

let mut output = vec![0.0f32; seq * out_rows];

// Process each sequence position
for s in 0..seq {
let input_row = &input[s * in_cols..(s + 1) * in_cols];
let out_slice = &mut output[s * out_rows..(s + 1) * out_rows];

for r in 0..out_rows {
let weight_row = &ternary[r * in_cols..(r + 1) * in_cols];
let mut sum = 0.0f32;
// Ternary matmul: sum of (sign * input)
// Split into positive and negative contributions for better pipeline
let mut pos_sum = 0.0f32;
let mut neg_sum = 0.0f32;
for j in 0..in_cols {
// Branchless: sign as f32 multiplier
// -1i8 as f32 = -1.0, 0i8 as f32 = 0.0, 1i8 as f32 = 1.0
sum += weight_row[j] as f32 * input_row[j];
let w = weight_row[j];
// Branchless ternary multiply: sign bit extraction
// w > 0 → add, w < 0 → subtract, w == 0 → skip
let v = input_row[j];
pos_sum += if w > 0 { v } else { 0.0 };
neg_sum += if w < 0 { v } else { 0.0 };
}
output[s * out_rows + r] = scale * sum;
out_slice[r] = scale * (pos_sum - neg_sum);
}
}

output
}

/// Multi-threaded ternary matmul using rayon.
///
/// Computes `output[s * out_rows + r] = scale * Σ_j sign(ternary[r * in_cols + j]) * input[s * in_cols + j]`
/// for a row-major `out_rows × in_cols` ternary (−1/0/+1) weight matrix.
///
/// Parallelization strategy:
/// * `seq == 0` — returns an empty vector immediately (previously this fell
///   into the single-token path and panicked slicing `input[..in_cols]`
///   out of an empty input).
/// * `seq == 1` (single-token decode) — parallelize across output rows,
///   since there is only one sequence position.
/// * `seq > 1` (prefill / batch) — parallelize across sequence positions,
///   which are fully independent.
///
/// # Panics
/// Panics if `ternary.len() != out_rows * in_cols` or
/// `input.len() != seq * in_cols`.
pub fn ternary_matmul_parallel(
    ternary: &[i8],
    scale: f32,
    input: &[f32],
    out_rows: usize,
    in_cols: usize,
    seq: usize,
) -> Vec<f32> {
    use rayon::prelude::*;
    assert_eq!(ternary.len(), out_rows * in_cols);
    assert_eq!(input.len(), seq * in_cols);

    // Inner kernel shared by both paths: weights are ternary, so only the
    // sign matters. Positive and negative contributions are accumulated
    // separately for better instruction pipelining.
    #[inline]
    fn dot_ternary(weight_row: &[i8], input_row: &[f32], scale: f32) -> f32 {
        let mut pos_sum = 0.0f32;
        let mut neg_sum = 0.0f32;
        for (&w, &v) in weight_row.iter().zip(input_row.iter()) {
            pos_sum += if w > 0 { v } else { 0.0 };
            neg_sum += if w < 0 { v } else { 0.0 };
        }
        scale * (pos_sum - neg_sum)
    }

    // Nothing to compute for an empty sequence.
    if seq == 0 {
        return Vec::new();
    }

    if seq == 1 {
        // Single token — parallelize across output rows instead.
        let input_row = &input[..in_cols];
        let mut output = vec![0.0f32; out_rows];
        output.par_iter_mut().enumerate().for_each(|(r, out_val)| {
            *out_val = dot_ternary(&ternary[r * in_cols..(r + 1) * in_cols], input_row, scale);
        });
        return output;
    }

    // Multi-token — parallelize across seq positions; each output row chunk
    // is an independent slice, so no synchronization is needed.
    let mut output = vec![0.0f32; seq * out_rows];
    output.par_chunks_mut(out_rows).enumerate().for_each(|(s, out_slice)| {
        let input_row = &input[s * in_cols..(s + 1) * in_cols];
        for (r, out_val) in out_slice.iter_mut().enumerate() {
            *out_val = dot_ternary(&ternary[r * in_cols..(r + 1) * in_cols], input_row, scale);
        }
    });
    output
}

/// Ternary matmul from 2-bit packed data.
///
/// Works directly on the packed representation from `pack_ternary()`,
Expand Down
Loading