47 changes: 43 additions & 4 deletions src/main.rs
@@ -684,8 +684,41 @@ async fn cmd_infer(
info!(event = "lora_merged", "LoRA merged into ternary base weights, adapters removed");
}

// Try GPU acceleration
let gpu_accel: Option<ferrisres::model::gpu_forward::GpuMatmulAccelerator> =
match ferrisres::model::gpu_forward::GpuMatmulAccelerator::new() {
Ok(accel) => {
info!(event = "gpu_available", "GPU available for student inference");
Some(accel)
}
Err(e) => {
info!(event = "gpu_unavailable", error = ?e, "No GPU, using CPU-only student inference");
None
}
};

let gpu_dispatch = gpu_accel.as_ref().map(|gpu| {
let profile = gpu.profile();
let vram = gpu.estimated_vram_bytes();
let max_buf = gpu.max_buffer_bytes();
ferrisres::device::DispatchPlan::new(
profile, 0, vram, max_buf,
student_model.hidden_dim as u64,
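// The map_or fallbacks below (8 heads, 256 head_dim, 6144 intermediate, 1 KV head) only apply if the model has no layers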
student_model.layers.first().map_or(8, |l| l.num_heads as u64),
student_model.layers.first().map_or(256, |l| l.head_dim as u64),
student_model.layers.first().map_or(6144, |l| l.intermediate_dim as u64),
student_model.vocab_size as u64,
1, // decode: seq=1
student_model.layers.first().map_or(1, |l| l.num_kv_heads as u64),
)
});

if let Some(ref plan) = gpu_dispatch {
info!(event = "student_dispatch", "\n{}", plan.summary());
}

// Autoregressive generation on student model
-    let gen_tokens = generate_student(&student_model, &tokens, max_tokens, temperature as f32, None);
+    let gen_tokens = generate_student(&student_model, &tokens, max_tokens, temperature as f32, None, gpu_accel.as_ref(), gpu_dispatch.as_ref());

let output_text = decoder(&gen_tokens);
let display_output = if let Some(ref mut layer) = armor_layer {
@@ -2707,6 +2740,8 @@ fn generate_student(
max_new_tokens: usize,
temperature: f32,
eos_token_id: Option<u32>,
+    gpu: Option<&ferrisres::model::gpu_forward::GpuMatmulAccelerator>,
+    _dispatch: Option<&ferrisres::device::DispatchPlan>,
) -> Vec<u32> {
let mut all_tokens: Vec<u32> = prompt_tokens.to_vec();
let vs = model.vocab_size;
@@ -2724,11 +2759,11 @@

info!(event = "student_gen_start", prompt_len = prompt_tokens.len(), max_new_tokens, "Starting student model generation (KV-cached)");

-    // === Prefill phase: process all prompt tokens at once ===
+    // === Prefill phase: process all prompt tokens at once (CPU, populates KV cache) ===
let t0 = std::time::Instant::now();
let logits = model.forward_prefill(prompt_tokens, &mut cache);
let prefill_ms = t0.elapsed().as_millis();
-    info!(event = "student_prefill", prompt_len = prompt_tokens.len(), prefill_ms, cache_mb = cache.total_memory_bytes() / (1024*1024), "Prefill complete");
+    info!(event = "student_prefill", prompt_len = prompt_tokens.len(), prefill_ms, cache_mb = cache.total_memory_bytes() / (1024*1024), gpu = gpu.is_some(), "Prefill complete");

// Take logits for last position
let last_logits = &logits[(prompt_tokens.len() - 1) * vs..prompt_tokens.len() * vs];
@@ -2741,7 +2776,11 @@
for step in 1..max_new_tokens {
let t0 = std::time::Instant::now();

-    let logits = model.forward_decode(all_tokens[all_tokens.len() - 1], &mut cache);
+    let logits = if let Some(gpu) = gpu {
+        model.forward_decode_gpu(all_tokens[all_tokens.len() - 1], &mut cache, gpu)
+    } else {
+        model.forward_decode(all_tokens[all_tokens.len() - 1], &mut cache)
+    };
let forward_ms = t0.elapsed().as_millis();

let last_logits = &logits; // decode returns only last token's logits
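For context, the prefill/decode contract that generate_student builds on looks roughly like the sketch below. This is a compilable illustration with a hypothetical KvDecode trait, and greedy argmax stands in for the temperature sampling elided from this diff; the real model and cache types live in ferrisres and differ in detail.

/// Hypothetical trait capturing the prefill/decode split (illustration only).
trait KvDecode {
    type Cache: Default;
    /// Process the whole prompt once, filling the KV cache; returns logits
    /// for every prompt position (len * vocab values).
    fn forward_prefill(&self, tokens: &[u32], cache: &mut Self::Cache) -> Vec<f32>;
    /// Process one new token against the cached history; returns only the
    /// last position's logits (vocab values).
    fn forward_decode(&self, token: u32, cache: &mut Self::Cache) -> Vec<f32>;
}

fn generate<M: KvDecode>(model: &M, prompt: &[u32], max_new: usize, vocab: usize) -> Vec<u32> {
    let mut cache = M::Cache::default();
    let mut tokens = prompt.to_vec();
    // Prefill once over the whole prompt, then sample from the last position.
    let logits = model.forward_prefill(prompt, &mut cache);
    let mut next = argmax(&logits[(prompt.len() - 1) * vocab..prompt.len() * vocab]);
    tokens.push(next);
    // Decode one token at a time, reusing the cached K/V history.
    for _ in 1..max_new {
        next = argmax(&model.forward_decode(next, &mut cache));
        tokens.push(next);
    }
    tokens
}

fn argmax(xs: &[f32]) -> u32 {
    xs.iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i as u32)
        .unwrap_or(0)
}
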
166 changes: 166 additions & 0 deletions src/model/cpu_block_attn_res.rs
@@ -1884,6 +1884,172 @@ impl CpuBlockAttnResModel {

model
}

/// GPU-accelerated decode: single token with KV cache.
/// Matmuls go through gpu_ternary_mm; the KV cache stays on CPU.
pub fn forward_decode_gpu(
&self,
new_token_id: u32,
cache: &mut crate::inference::student_kv_cache::ModelKVCache,
gpu: &crate::model::gpu_forward::GpuMatmulAccelerator,
) -> Vec<f32> {
let hd = self.hidden_dim;
let vs = self.vocab_size;
let pos = cache.seq_len();

cache.cached_token_ids.push(new_token_id);

// 1. Embedding
let mut hidden = vec![0.0f32; hd];
let id = new_token_id as usize;
if id * hd + hd <= self.embed_tokens.len() {
for d in 0..hd { hidden[d] = self.embed_tokens[id * hd + d]; }
}
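// Scale embeddings by sqrt(hidden_dim), as Gemma-family models do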
let scale = (hd as f32).sqrt();
for h in hidden.iter_mut() { *h *= scale; }

let ple_dim = self.hidden_size_per_layer_input;

for (layer_idx, layer) in self.layers.iter().enumerate() {
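// Per-layer embedding (PLE) input for this token and layer, sliced
// from the cached prefix if one was provided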
let ple_input = cache.ple_prefix.as_ref().and_then(|pre| {
let all_tokens = &cache.cached_token_ids;
let t = all_tokens.len() - 1;
let idx = t * self.num_layers * ple_dim + layer_idx * ple_dim;
if idx + ple_dim <= pre.len() {
Some(pre[idx..idx + ple_dim].to_vec())
} else {
None
}
});

let kv_dim = layer.num_kv_heads * layer.head_dim;
let q_dim = layer.num_heads * layer.head_dim;

let residual = hidden.clone();
let normed = layer.attn_norm.forward_single(&hidden);

// Q via GPU ternary matmul
let (q_packed, q_scales) = layer.q_proj.gpu_packed();
let mut q = Self::gpu_ternary_mm(gpu, &q_packed, &q_scales, &normed, 1, hd, layer.q_proj.out_features());
q = crate::model::gemma_mapper::per_head_rms_norm(&q, layer.q_norm.weight(), 1, layer.num_heads, layer.head_dim);
apply_rope(&mut q, 1, layer.num_heads, layer.head_dim, pos, layer.rope_theta, layer.partial_rotary_factor);

// K/V via GPU ternary matmul
let (new_k, new_v) = {
let (k_packed, k_scales) = layer.k_proj.gpu_packed();
let (v_packed, v_scales) = layer.v_proj.gpu_packed();
let mut k = Self::gpu_ternary_mm(gpu, &k_packed, &k_scales, &normed, 1, hd, layer.k_proj.out_features());
let v_raw = Self::gpu_ternary_mm(gpu, &v_packed, &v_scales, &normed, 1, hd, layer.v_proj.out_features());
k = crate::model::gemma_mapper::per_head_rms_norm(&k, layer.k_norm.weight(), 1, layer.num_kv_heads, layer.head_dim);
let v = crate::model::gemma_mapper::per_head_rms_norm_no_scale(&v_raw, 1, layer.num_kv_heads, layer.head_dim);
apply_rope_gqa(&mut k, 1, layer.num_kv_heads, layer.head_dim, pos, layer.rope_theta, layer.partial_rotary_factor);
(k, v)
};

// Append this token's K/V to the cache, then attend over the full history
cache.layers[layer_idx].append(&new_k, &new_v);

let (full_k, full_v) = cache.layers[layer_idx].get();

// Attention (CPU — small matrices)
let attn_out = layer.cpu_attention_packed_parallel(&q, full_k, full_v, 1, layer.num_heads, layer.num_kv_heads, layer.head_dim, q_dim, kv_dim);

// O projection via GPU ternary matmul
let (o_packed, o_scales) = layer.out_proj.gpu_packed();
let attn_out = Self::gpu_ternary_mm(gpu, &o_packed, &o_scales, &attn_out, 1, layer.out_proj.in_features(), layer.out_proj.out_features());
let attn_out = layer.post_attn_norm.forward_single(&attn_out);
for i in 0..hd { hidden[i] = residual[i] + attn_out[i]; }

// FFN via GPU
let ffn_residual = hidden.clone();
let normed2 = layer.pre_ffn_norm.forward_single(&hidden);

if let Some(ref moe) = layer.moe {
// Router (CPU — tiny)
let mut router_out = vec![0.0f32; moe.num_experts];
for e in 0..moe.num_experts {
let mut dot = 0.0f32;
for d in 0..hd { dot += normed2[d] * moe.gate_weights[e * hd + d]; }
router_out[e] = dot;
}
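// Softmax over router logits (max-subtracted for stability),
// then renormalize the weights of the top_k experts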
let max_r = router_out.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let mut sum_e = 0.0f32;
for e in 0..moe.num_experts { router_out[e] = (router_out[e] - max_r).exp(); sum_e += router_out[e]; }
for e in 0..moe.num_experts { router_out[e] /= sum_e; }
let mut indices: Vec<usize> = (0..moe.num_experts).collect();
indices.sort_by(|&a, &b| router_out[b].partial_cmp(&router_out[a]).unwrap());
let top_k_sum: f32 = indices[..moe.top_k].iter().map(|&e| router_out[e]).sum();

let mut expert_out = vec![0.0f32; hd];
for &ei in &indices[..moe.top_k] {
let w = router_out[ei] / top_k_sum;
let (g_packed, g_scales) = moe.expert_gate[ei].gpu_packed();
let (u_packed, u_scales) = moe.expert_up[ei].gpu_packed();
let (d_packed, d_scales) = moe.expert_down[ei].gpu_packed();
let gated = Self::gpu_ternary_mm(gpu, &g_packed, &g_scales, &normed2, 1, hd, moe.intermediate_dim);
let upped = Self::gpu_ternary_mm(gpu, &u_packed, &u_scales, &normed2, 1, hd, moe.intermediate_dim);
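// Gate activation: tanh-approximated GELU or SiLU (x * sigmoid(x))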
let gated: Vec<f32> = if moe.use_gelu { gated.iter().map(|&x| crate::model::gemma_mapper::gelu_tanh(x)).collect() } else { gated.iter().map(|&x| x / (1.0 + (-x).exp())).collect() };
let combined: Vec<f32> = gated.iter().zip(upped.iter()).map(|(&g, &u)| g * u).collect();
let down = Self::gpu_ternary_mm(gpu, &d_packed, &d_scales, &combined, 1, moe.intermediate_dim, hd);
for d in 0..hd { expert_out[d] += w * down[d]; }
}
let normed_ffn = layer.post_ffn_norm.forward_single(&expert_out);
for i in 0..hd { hidden[i] = ffn_residual[i] + normed_ffn[i] * layer.layer_scalar; }
} else if let (Some(gate), Some(up), Some(down)) = (&layer.ffn_gate, &layer.ffn_up, &layer.ffn_down) {
let id = layer.intermediate_dim;
let (g_packed, g_scales) = gate.gpu_packed();
let (u_packed, u_scales) = up.gpu_packed();
let (d_packed, d_scales) = down.gpu_packed();
let gated = Self::gpu_ternary_mm(gpu, &g_packed, &g_scales, &normed2, 1, hd, id);
let upped = Self::gpu_ternary_mm(gpu, &u_packed, &u_scales, &normed2, 1, hd, id);
let gated: Vec<f32> = if layer.use_gelu { gated.iter().map(|&x| crate::model::gemma_mapper::gelu_tanh(x)).collect() } else { gated.iter().map(|&x| x / (1.0 + (-x).exp())).collect() };
let combined: Vec<f32> = gated.iter().zip(upped.iter()).map(|(&g, &u)| g * u).collect();
let ffn_out = Self::gpu_ternary_mm(gpu, &d_packed, &d_scales, &combined, 1, id, hd);
let normed_ffn = layer.post_ffn_norm.forward_single(&ffn_out);
for i in 0..hd { hidden[i] = ffn_residual[i] + normed_ffn[i] * layer.layer_scalar; }
}

// PLE (CPU — small)
if let Some(ref ple_s) = ple_input {
if let (Some(ref gate), Some(ref proj), Some(ref norm)) =
(&layer.ple_input_gate, &layer.ple_projection, &layer.ple_post_norm)
{
let ple_d = ple_s.len();
let gate_w = gate.weight();
let proj_w = proj.weight();
let gate_out = matmul(&hidden, &gate_w, 1, hd, ple_d);
let gate_gelu: Vec<f32> = gate_out.iter().map(|&x| crate::model::gemma_mapper::gelu_tanh(x)).collect();
let mut gated = vec![0.0f32; ple_d];
for i in 0..ple_d { gated[i] = gate_gelu[i] * ple_s[i]; }
let proj_out = matmul(&gated, &proj_w, 1, ple_d, hd);
let ple_final = norm.forward_single(&proj_out);
for i in 0..hd { hidden[i] += ple_final[i]; }
}
}

// Block tracking: accumulate hidden states across layers; at each block
// boundary, average them into a block representation and feed global
// context back in via inter-block attention
for d in 0..hd { cache.partial_sum[d] += hidden[d]; }
if self.is_block_boundary(layer_idx) {
cache.block_token_count += 1;
let mut block_rep = cache.partial_sum.clone();
for d in 0..hd { block_rep[d] /= (cache.block_token_count * self.block_config.layers_per_block) as f32; }
cache.block_reps.push(block_rep);
cache.partial_sum = vec![0.0f32; hd];
cache.block_token_count = 0;
let inter_out = self.inter_block_attention_single(&hidden, &cache.block_reps);
for d in 0..hd { hidden[d] += inter_out[d]; }
}
}

// Final norm + LM head (CPU — too large for GPU buffer on iGPU)
hidden = crate::model::gemma_mapper::rms_norm(&hidden, &self.final_norm, hd, 1e-6);
let mut logits = matmul(&hidden, &self.lm_head, 1, hd, vs);
if let Some(cap) = self.final_logit_softcapping {
for l in logits.iter_mut() { *l = (*l / cap).tanh() * cap; }
}
logits
}
}

// ---------------------------------------------------------------------------
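For reference, the operation gpu_ternary_mm offloads can be sketched on the CPU as below. This assumes unpacked {-1, 0, +1} weights with one f32 scale per output row; the actual 2-bit layout produced by gpu_packed() is defined elsewhere in the crate and is not shown in this diff.

/// CPU reference for a ternary matmul: out[row] = scales[row] * sum_k w[row][k] * x[k],
/// with w[row][k] in {-1, 0, +1}. Sketch only; the GPU path consumes packed weights.
fn ternary_mm_ref(weights: &[i8], scales: &[f32], input: &[f32], k: usize, n: usize) -> Vec<f32> {
    assert_eq!(weights.len(), n * k);
    assert_eq!(scales.len(), n);
    assert_eq!(input.len(), k);
    let mut out = vec![0.0f32; n];
    for row in 0..n {
        let mut acc = 0.0f32;
        for col in 0..k {
            // Ternary weights reduce the dot product to adds and subtracts.
            match weights[row * k + col] {
                1 => acc += input[col],
                -1 => acc -= input[col],
                _ => {}
            }
        }
        // One scale per output row restores the magnitude lost to quantization.
        out[row] = acc * scales[row];
    }
    out
}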