23 changes: 12 additions & 11 deletions src/main.rs
@@ -2897,13 +2897,15 @@ async fn cmd_evaluate(
/// dL/dB[o][r] = scaling * Σ_t d_y[t][o] · (Σ_d A[r][d] · x[t][d])
///
/// `input`: [seq * in_f] — input to the projection
/// `d_output`: [seq * hd] — gradient signal (d_hidden or truncated)
/// Compute LoRA gradients for a single adapter.
/// `input`: [seq * in_f] — activations from forward pass
/// `d_output`: [seq * out_f] — gradient signal (may be truncated from d_hidden)
fn compute_lora_grad(
lora_layer: &ferrisres::training::lora::LoraLayer,
input: &[f32],
d_output: &[f32],
seq: usize,
hd: usize,
_hd: usize,
) -> (Vec<f32>, Vec<f32>) {
let rank = lora_layer.rank();
let in_f = lora_layer.in_features();
@@ -2915,33 +2917,32 @@ fn compute_lora_grad(
let mut ga = vec![0.0f32; rank * in_f];
let mut gb = vec![0.0f32; out_f * rank];

let actual_out = out_f.min(hd);
let actual_in = in_f.min(hd);
let actual_seq = seq.min(input.len() / in_f.max(1)).min(d_output.len() / actual_out.max(1));
// d_output has stride out_f (possibly truncated from hd)
let actual_seq = seq.min(input.len() / in_f.max(1)).min(d_output.len() / out_f.max(1));

for r in 0..rank {
for d in 0..actual_in {
for d in 0..in_f {
let mut grad = 0.0f32;
for t in 0..actual_seq {
let mut btdy = 0.0f32;
for o in 0..actual_out {
btdy += b[o * rank + r] * d_output[t * hd + o];
for o in 0..out_f {
btdy += b[o * rank + r] * d_output[t * out_f + o];
}
grad += btdy * input[t * in_f + d];
}
ga[r * in_f + d] = scaling * grad;
}
}

for o in 0..actual_out {
for o in 0..out_f {
for r in 0..rank {
let mut grad = 0.0f32;
for t in 0..actual_seq {
let mut ax_r = 0.0f32;
for d in 0..actual_in {
for d in 0..in_f {
ax_r += a[r * in_f + d] * input[t * in_f + d];
}
grad += d_output[t * hd + o] * ax_r;
grad += d_output[t * out_f + o] * ax_r;
}
gb[o * rank + r] = scaling * grad;
}
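For reference, the companion gradient that the `ga` loop computes, written in the same notation as the dL/dB formula in the doc comment (read directly off the loop: `btdy` is the inner sum over `o`):

dL/dA[r][d] = scaling * Σ_t (Σ_o B[o][r] · d_y[t][o]) · x[t][d]

Both gradient loops run over `actual_seq` rather than `seq`, so a truncated `d_output` simply drops trailing timesteps from the sums instead of reading out of bounds.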
161 changes: 94 additions & 67 deletions src/model/cpu_block_attn_res.rs
@@ -990,43 +990,74 @@ impl CpuBlockAttnResModel {
}
/// Embedding, norms, attention scores, RoPE stay on CPU (cheap ops).
/// Q/K/V/O projections, FFN gate/up/down, MoE experts, LM head go to GPU.
/// GPU ternary matmul helper: uses WGSL kernel with packed ternary weights.
/// Falls back to CPU if GPU fails or buffer too large.
fn gpu_ternary_mm(
gpu: &crate::model::gpu_forward::GpuMatmulAccelerator,
packed: &[u32],
scales: &[f32],
input: &[f32],
seq: usize,
in_cols: usize,
out_rows: usize,
) -> Vec<f32> {
match gpu.gpu_ternary_matmul(packed, scales, input, seq, in_cols, out_rows) {
Ok(result) => result,
Err(_) => {
// CPU fallback: unpack ternary and matmul
let mut result = vec![0.0f32; seq * out_rows];
let cols_packed = (in_cols + 15) / 16;
for t in 0..seq {
for row in 0..out_rows {
let mut sum = 0.0f32;
for c in 0..cols_packed {
let word = packed[row * cols_packed + c];
for bit in 0..16u32 {
let col = c * 16 + bit as usize;
if col >= in_cols { break; }
let code = (word >> (bit * 2)) & 3;
let val = input[t * in_cols + col];
match code {
0 => sum -= val,
2 => sum += val,
_ => {}
}
}
}
result[t * out_rows + row] = sum * scales[row];
}
}
result
}
}
}

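As a minimal sketch of the 2-bit encoding the CPU fallback above decodes (code 0 → −1, code 2 → +1, anything else → 0; sixteen codes per u32, LSB-first) — `pack_words` and `unpack_val` here are hypothetical round-trip helpers for illustration, not the signatures of `crate::model::ternary`:

// Hypothetical helpers matching the fallback's decode:
// code 0 => weight -1 (sum -= val), code 2 => weight +1 (sum += val).
fn pack_words(vals: &[i8]) -> Vec<u32> {
    let mut words = vec![0u32; (vals.len() + 15) / 16];
    for (i, &v) in vals.iter().enumerate() {
        let code = (v + 1) as u32; // -1 -> 0, 0 -> 1, +1 -> 2
        words[i / 16] |= code << ((i % 16) * 2);
    }
    words
}

fn unpack_val(words: &[u32], i: usize) -> i8 {
    (((words[i / 16] >> ((i % 16) * 2)) & 3) as i8) - 1
}

Unpacking with `unpack_val` reproduces exactly the per-column sign convention the fallback applies before multiplying by the per-row scale.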
/// GPU forward using ternary matmul kernels.
///
/// Key difference from old forward_gpu: uses packed ternary weights directly
/// on GPU instead of dequantizing to FP32 every call.
pub fn forward_gpu(
&self,
token_ids: &[u32],
gpu: &crate::model::gpu_forward::GpuMatmulAccelerator,
dispatch: &crate::device::DispatchPlan,
_dispatch: &crate::device::DispatchPlan,
) -> Vec<f32> {
let seq = token_ids.len();
let hd = self.hidden_dim;
let vs = self.vocab_size;

// Helper: GPU matmul with CPU fallback
let gpu_mm = |a: &[f32], b: &[f32], m: usize, k: usize, n: usize| -> Vec<f32> {
if matches!(dispatch.should_gpu_matmul(m as u64, k as u64, n as u64), crate::device::OpTarget::Gpu) {
match gpu.gpu_matmul_cpu_b(a, b, m, k, n) {
Ok(r) => return r,
Err(e) => {
tracing::debug!(event = "gpu_matmul_fallback", error = ?e, m, k, n, "GPU matmul failed, falling back to CPU");
}
}
}
matmul(a, b, m, k, n)
};

// 1. Embedding (CPU — just a lookup)
let mut hidden = vec![0.0f32; seq * hd];
for (t, &tid) in token_ids.iter().enumerate() {
let id = tid as usize;
if id * hd + hd <= self.embed_tokens.len() {
for d in 0..hd {
hidden[t * hd + d] = self.embed_tokens[id * hd + d];
}
for d in 0..hd { hidden[t * hd + d] = self.embed_tokens[id * hd + d]; }
}
}
let scale = (hd as f32).sqrt();
for h in hidden.iter_mut() { *h *= scale; }

// 2. Pre-compute PLE inputs
// 2. PLE
let ple_dim = self.hidden_size_per_layer_input;
let ple_precomputed = self.precompute_ple(&hidden, token_ids, seq, hd, ple_dim);

@@ -1040,7 +1071,6 @@ impl CpuBlockAttnResModel {
block_reps.push(partial_sum.clone());

for (layer_idx, layer) in self.layers.iter().enumerate() {
// PLE slice
let ple_slice = ple_precomputed.as_ref().map(|pre| {
let mut slice = vec![0.0f32; seq * ple_dim];
for t in 0..seq {
@@ -1051,24 +1081,21 @@ impl CpuBlockAttnResModel {
slice
});

// KV sharing
let (kv, should_store) = if layer_idx >= first_shared_layer && first_shared_layer > 0 && layer.kv_shared {
let kv_source = Self::kv_shared_source_layer(layer_idx, first_shared_layer, &self.layers);
if let Some((sk, sv)) = shared_kv.get(&kv_source) {
(Some((sk.as_slice(), sv.as_slice())), false)
} else { (None, false) }
} else { (None, true) };

// === Attention (matmuls on GPU, rest CPU) ===
// === Attention ===
let residual = hidden.clone();
let normed = layer.attn_norm.forward(&hidden);
let lora_m = self.lora_manager.as_ref();

// Q projection
let q_dim = layer.q_proj.out_features();
let kv_dim_out = layer.v_proj.out_features();
let q_w = layer.q_proj.weight();
let mut q = gpu_mm(&normed, &q_w, seq, hd, q_dim);
// Q projection via GPU ternary matmul
let (q_packed, q_scales) = layer.q_proj.gpu_packed();
let mut q = Self::gpu_ternary_mm(gpu, &q_packed, &q_scales, &normed, seq, hd, layer.q_proj.out_features());
if let Some(ref lora_m) = lora_m {
if let Some(lora_out) = lora_m.forward(layer_idx, "q_proj", &normed, seq) {
for (i, v) in lora_out.iter().enumerate() { q[i] += v; }
@@ -1077,17 +1104,17 @@ impl CpuBlockAttnResModel {
q = crate::model::gemma_mapper::per_head_rms_norm(&q, layer.q_norm.weight(), seq, layer.num_heads, layer.head_dim);
apply_rope(&mut q, seq, layer.num_heads, layer.head_dim, 0, layer.rope_theta, layer.partial_rotary_factor);

// K/V
// K/V via GPU ternary matmul
let (k, v) = match kv {
Some((sk, sv)) => (sk.to_vec(), sv.to_vec()),
None => {
let k_w = layer.k_proj.weight();
let v_w = layer.v_proj.weight();
let mut k = gpu_mm(&normed, &k_w, seq, hd, layer.k_proj.out_features());
let mut v_raw = gpu_mm(&normed, &v_w, seq, hd, layer.v_proj.out_features());
let (k_packed, k_scales) = layer.k_proj.gpu_packed();
let (v_packed, v_scales) = layer.v_proj.gpu_packed();
let mut k = Self::gpu_ternary_mm(gpu, &k_packed, &k_scales, &normed, seq, hd, layer.k_proj.out_features());
let mut v_raw = Self::gpu_ternary_mm(gpu, &v_packed, &v_scales, &normed, seq, hd, layer.v_proj.out_features());
if let Some(ref lora_m) = lora_m {
if let Some(lora_out) = lora_m.forward(layer_idx, "v_proj", &normed, seq) {
for (i, l) in lora_out.iter().enumerate() { v_raw[i] += l; }
for (i, l) in v_raw.iter_mut().enumerate() { *l += lora_out[i]; }
}
}
k = crate::model::gemma_mapper::per_head_rms_norm(&k, layer.k_norm.weight(), seq, layer.num_kv_heads, layer.head_dim);
@@ -1097,12 +1124,19 @@ impl CpuBlockAttnResModel {
}
};

// Attention scores (CPU — O(seq²) but small dimensions)
let attn_out = layer.cpu_attention(&q, &k, &v, seq, layer.num_heads, layer.num_kv_heads, layer.head_dim, q_dim, kv_dim_out);
// Attention (CPU — small matrices)
let attn_raw = layer.cpu_attention_raw(&q, &k, &v, seq, layer.num_heads, layer.num_kv_heads, layer.head_dim, layer.q_proj.out_features(), layer.v_proj.out_features());
// O projection via GPU ternary matmul
let (o_packed, o_scales) = layer.out_proj.gpu_packed();
let mut attn_out = Self::gpu_ternary_mm(gpu, &o_packed, &o_scales, &attn_raw, seq, layer.q_proj.out_features(), layer.out_proj.out_features());
if let Some(ref lora_m) = lora_m {
if let Some(lora_out) = lora_m.forward(layer_idx, "o_proj", &attn_raw, seq) {
for (i, l) in attn_out.iter_mut().enumerate() { *l += lora_out[i]; }
}
}
let attn_out = layer.post_attn_norm.forward(&attn_out);
for i in 0..hidden.len() { hidden[i] = residual[i] + attn_out[i]; }

// Store K/V
if should_store && first_shared_layer > 0 && !k.is_empty() {
shared_kv.insert(layer_idx, (k, v));
}
@@ -1112,7 +1146,7 @@ impl CpuBlockAttnResModel {
let normed2 = layer.pre_ffn_norm.forward(&hidden);

if let Some(ref moe) = layer.moe {
// MoE routing (CPU — small matmul + softmax)
// Router (CPU — small)
let mut router_out = vec![0.0f32; seq * moe.num_experts];
for t in 0..seq {
for e in 0..moe.num_experts {
@@ -1121,69 +1155,65 @@ impl CpuBlockAttnResModel {
router_out[t * moe.num_experts + e] = dot;
}
}
// Softmax + top-k
let mut expert_outputs = vec![0.0f32; seq * hd];
for t in 0..seq {
let r_off = t * moe.num_experts;
let max_r = (0..moe.num_experts).map(|e| router_out[r_off + e]).fold(f32::NEG_INFINITY, f32::max);
let mut sum_e = 0.0f32;
for e in 0..moe.num_experts { router_out[r_off + e] = (router_out[r_off + e] - max_r).exp(); sum_e += router_out[r_off + e]; }
for e in 0..moe.num_experts { router_out[r_off + e] /= sum_e; }
// Find top-k
let mut indices: Vec<usize> = (0..moe.num_experts).collect();
indices.sort_by(|&a, &b| router_out[r_off + b].partial_cmp(&router_out[r_off + a]).unwrap());
let top_k_sum: f32 = indices[..moe.top_k].iter().map(|&e| router_out[r_off + e]).sum();
for &ei in &indices[..moe.top_k] {
let w = router_out[r_off + ei] / top_k_sum;
let input_t = &normed2[t * hd..(t + 1) * hd];
let gate_w = moe.expert_gate[ei].to_fp32();
let up_w = moe.expert_up[ei].to_fp32();
let down_w = moe.expert_down[ei].to_fp32();
let gated = gpu_mm(input_t, &gate_w, 1, hd, moe.intermediate_dim);
let upped = gpu_mm(input_t, &up_w, 1, hd, moe.intermediate_dim);
let gated: Vec<f32> = if moe.use_gelu { gated.iter().map(|&x| gelu_tanh(x)).collect() } else { gated.iter().map(|&x| x * (1.0 / (1.0 + (-x).exp()))).collect() };
// Expert via GPU ternary matmul
let (g_packed, g_scales) = moe.expert_gate[ei].gpu_packed();
let (u_packed, u_scales) = moe.expert_up[ei].gpu_packed();
let (d_packed, d_scales) = moe.expert_down[ei].gpu_packed();
let gated = Self::gpu_ternary_mm(gpu, &g_packed, &g_scales, input_t, 1, hd, moe.intermediate_dim);
let upped = Self::gpu_ternary_mm(gpu, &u_packed, &u_scales, input_t, 1, hd, moe.intermediate_dim);
let gated: Vec<f32> = if moe.use_gelu { gated.iter().map(|&x| crate::model::gemma_mapper::gelu_tanh(x)).collect() } else { gated.iter().map(|&x| x / (1.0 + (-x).exp())).collect() };
let combined: Vec<f32> = gated.iter().zip(upped.iter()).map(|(&g, &u)| g * u).collect();
let down = gpu_mm(&combined, &down_w, 1, moe.intermediate_dim, hd);
let down = Self::gpu_ternary_mm(gpu, &d_packed, &d_scales, &combined, 1, moe.intermediate_dim, hd);
for d in 0..hd { expert_outputs[t * hd + d] += w * down[d]; }
}
}
let normed_ffn = layer.post_ffn_norm.forward(&expert_outputs);
for i in 0..hidden.len() { hidden[i] = ffn_residual[i] + normed_ffn[i] * layer.layer_scalar; }
} else {
// Dense FFN
} else if let (Some(gate), Some(up), Some(down)) = (&layer.ffn_gate, &layer.ffn_up, &layer.ffn_down) {
let id = layer.intermediate_dim;
let (gate_w, up_w, down_w) = match (&layer.ffn_gate, &layer.ffn_up, &layer.ffn_down) {
(Some(g), Some(u), Some(d)) => (g.weight(), u.weight(), d.weight()),
_ => { continue; }
};
let gated = gpu_mm(&normed2, &gate_w, seq, hd, id);
let upped = gpu_mm(&normed2, &up_w, seq, hd, id);
let gated: Vec<f32> = if layer.use_gelu { gated.iter().map(|&x| gelu_tanh(x)).collect() } else { gated.iter().map(|&x| x * (1.0 / (1.0 + (-x).exp()))).collect() };
let (g_packed, g_scales) = gate.gpu_packed();
let (u_packed, u_scales) = up.gpu_packed();
let (d_packed, d_scales) = down.gpu_packed();
let gated = Self::gpu_ternary_mm(gpu, &g_packed, &g_scales, &normed2, seq, hd, id);
let upped = Self::gpu_ternary_mm(gpu, &u_packed, &u_scales, &normed2, seq, hd, id);
let gated: Vec<f32> = if layer.use_gelu { gated.iter().map(|&x| crate::model::gemma_mapper::gelu_tanh(x)).collect() } else { gated.iter().map(|&x| x / (1.0 + (-x).exp())).collect() };
let combined: Vec<f32> = gated.iter().zip(upped.iter()).map(|(&g, &u)| g * u).collect();
let ffn_out = gpu_mm(&combined, &down_w, seq, id, hd);
let ffn_out = Self::gpu_ternary_mm(gpu, &d_packed, &d_scales, &combined, seq, id, hd);
let normed_ffn = layer.post_ffn_norm.forward(&ffn_out);
for i in 0..hidden.len() { hidden[i] = ffn_residual[i] + normed_ffn[i] * layer.layer_scalar; }
}

// PLE residual
if let Some(ref ple_s) = ple_slice {
// PLE injection (gate → GELU → element-wise multiply → project → norm → residual)
if let (Some(ref gate), Some(ref proj), Some(ref norm)) =
(&layer.ple_input_gate, &layer.ple_projection, &layer.ple_post_norm)
{
let ple_dim = ple_s.len() / seq;
let gate_out = gate.forward(&hidden, seq);
let gate_gelu: Vec<f32> = gate_out.iter().map(|&x| gelu_tanh(x)).collect();
let (pg_packed, pg_scales) = gate.gpu_packed();
let (pp_packed, pp_scales) = proj.gpu_packed();
let gate_out = Self::gpu_ternary_mm(gpu, &pg_packed, &pg_scales, &hidden, seq, hd, ple_dim);
let gate_gelu: Vec<f32> = gate_out.iter().map(|&x| crate::model::gemma_mapper::gelu_tanh(x)).collect();
let mut gated = vec![0.0f32; seq * ple_dim];
for i in 0..seq * ple_dim { gated[i] = gate_gelu[i] * ple_s[i]; }
let proj_w = proj.weight();
let proj_out = gpu_mm(&gated, &proj_w, seq, ple_dim, hd);
let proj_out = Self::gpu_ternary_mm(gpu, &pp_packed, &pp_scales, &gated, seq, ple_dim, hd);
let ple_final = norm.forward(&proj_out);
for i in 0..hidden.len() { hidden[i] += ple_final[i]; }
}
}

// Block boundary
for t in 0..seq { for d in 0..hd { partial_sum[d] += hidden[t * hd + d]; } }
if self.is_block_boundary(layer_idx) {
for d in 0..hd { partial_sum[d] /= ((seq) * (self.block_config.layers_per_block)) as f32; }
@@ -1194,12 +1224,9 @@ impl CpuBlockAttnResModel {
}
}

// 4. Final norm + LM head
hidden = rms_norm(&hidden, &self.final_norm, hd, 1e-6);
let logits = gpu_mm(&hidden, &self.lm_head, seq, hd, vs);

// 5. Logit softcapping
let mut logits = logits;
// 4. Final norm + LM head (CPU — too large for GPU buffer)
hidden = crate::model::gemma_mapper::rms_norm(&hidden, &self.final_norm, hd, 1e-6);
let mut logits = matmul(&hidden, &self.lm_head, seq, hd, vs);
if let Some(cap) = self.final_logit_softcapping {
for l in logits.iter_mut() { *l = (*l / cap).tanh() * cap; }
}
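The softcapping step above squashes logits smoothly into (−cap, cap): logits' = cap · tanh(logits / cap). As a worked example, with cap = 30 a raw logit of 50 maps to 30 · tanh(50/30) ≈ 27.9, while small logits pass through almost unchanged since tanh(x) ≈ x near zero.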
9 changes: 9 additions & 0 deletions src/model/cpu_linear.rs
@@ -158,6 +158,15 @@ impl CpuLinear {
self.packed = new_packed;
self.scale = new_scale;
}

/// GPU-ready packed weights: u32 format (16 values/u32) and per-row scales.
/// Since all rows share one absmean scale, scales = [scale; out_features].
pub fn gpu_packed(&self) -> (Vec<u32>, Vec<f32>) {
let total = self.in_features * self.out_features;
let packed_u32 = crate::model::ternary::pack_ternary_u32(&self.packed, total);
let scales = vec![self.scale; self.out_features];
(packed_u32, scales)
}
}
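A sketch of the intended call site, mirroring how forward_gpu uses this above (names like `layer.q_proj`, `normed`, and `hd` are taken from that context; treat this as illustrative rather than a fixed API):

// Pack once per projection, then dispatch with per-row scales.
// Note gpu_packed() rebuilds the Vec<u32> on every call, so hot paths
// may want to cache the packed form rather than repack per batch.
let (packed, scales) = layer.q_proj.gpu_packed();
let q = Self::gpu_ternary_mm(gpu, &packed, &scales, &normed, seq, hd,
                             layer.q_proj.out_features());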

/// CPU-only RMS normalization. Stores weights as `Vec<f32>`.