From 2adf025bda13a7e579c63f940d36fd21423b4b13 Mon Sep 17 00:00:00 2001
From: Thomas Santerre
Date: Thu, 18 Apr 2024 15:14:36 -0400
Subject: [PATCH] refactor: rewrite several sections to better match the 1.58 paper & add rope (#3)

* lots of random changes to align with the original paper
* more fixes
* more fixes
* unused import
* fix test
* clippy fixes
---
 Cargo.toml             |   2 -
 src/bit_attention.rs   | 202 +++++++++++++++++++----------------------
 src/bit_ffn.rs         |  52 ++++-------
 src/bit_linear.rs      | 161 ++++++--------------------------
 src/bit_transformer.rs |  45 +--------
 src/config.rs          |   2 +-
 src/inference.rs       |   4 +-
 src/rms_norm.rs        |   8 +-
 src/training.rs        |   6 +-
 src/utils_tensor.rs    |  34 -------
 10 files changed, 157 insertions(+), 359 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 9e2a3cf..8c3240a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -3,8 +3,6 @@ name = "bitnet-rs"
 version = "0.1.0"
 edition = "2021"
 
-# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
-
 [dependencies]
 accelerate-src = { version = "0.3.2", optional = true }
 anyhow = "1.0.80"
diff --git a/src/bit_attention.rs b/src/bit_attention.rs
index e1f3bde..c28b2c6 100644
--- a/src/bit_attention.rs
+++ b/src/bit_attention.rs
@@ -2,10 +2,10 @@ use crate::{
     bit_linear::{Bitlinear, BitlinearCfg},
     utils_tensor::scaled_dot_product_attention,
 };
-use anyhow::{anyhow, Result};
-use candle_core::{Tensor, D};
+use anyhow::{anyhow, Ok, Result};
+use candle_core::{DType, Device, Tensor};
 use candle_einops::einops;
-use candle_nn::{layer_norm, LayerNormConfig, Module, VarBuilder};
+use candle_nn::{rotary_emb::rope_i, VarBuilder};
 use tracing::instrument;
 
 #[derive(Debug, Clone, Copy)]
@@ -14,21 +14,42 @@ pub struct BitAttentionCfg {
     pub n_heads: usize,
     pub n_kv_heads: usize,
     pub dropout: f32,
-    pub bias: bool,
-    pub layer_norm_enabled: bool,
-    pub eps: f32,
+    pub eps: f64,
+    pub max_seq_len: usize,
 }
 
 #[derive(Debug)]
 pub struct BitAttention {
-    qkv_proj: Bitlinear,
+    q_proj: Bitlinear,
+    k_proj: Bitlinear,
+    v_proj: Bitlinear,
     o_proj: Bitlinear,
-    norm: Option<LayerNorm>,
     dropout: f32,
-    dim: usize,
-    head_dim: usize,
     n_heads: usize,
     n_kv_heads: usize,
+    cos: Tensor,
+    sin: Tensor,
+    kv_cache: Option<(Tensor, Tensor)>,
+}
+
+fn precompute_freqs_cis(
+    head_dim: usize,
+    freq_base: f32,
+    max_seq_len: usize,
+    device: &Device,
+) -> Result<(Tensor, Tensor)> {
+    let theta: Vec<_> = (0..head_dim)
+        .step_by(2)
+        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
+        .collect();
+    let theta = Tensor::new(theta.as_slice(), device)?;
+    let idx_theta = Tensor::arange(0, max_seq_len as u32, device)?
+        .to_dtype(DType::F32)?
+        .reshape((max_seq_len, 1))?
+        .matmul(&theta.reshape((1, theta.elem_count()))?)?;
+    let cos = idx_theta.cos()?;
+    let sin = idx_theta.sin()?;
+    Ok((cos, sin))
 }
 
 impl BitAttention {
@@ -65,71 +86,94 @@ impl BitAttention {
             return Err(anyhow!("head_dim must be less than or equal to 128"));
         }
 
-        let total_head_dim = (cfg.n_heads + (2 * cfg.n_kv_heads)) * head_dim;
-        let qkv_proj = Bitlinear::load(
+        let q_proj = Bitlinear::load(
             BitlinearCfg {
                 in_features: cfg.dim,
-                out_features: total_head_dim,
-                num_groups: 1,
-                b: 8,
+                out_features: cfg.n_heads * head_dim,
                 eps: cfg.eps,
-                bias: cfg.bias,
             },
-            vb.pp("qkv_proj"),
+            vb.pp("q_proj"),
+        )?;
+        let k_proj = Bitlinear::load(
+            BitlinearCfg {
+                in_features: cfg.dim,
+                out_features: cfg.n_kv_heads * head_dim,
+                eps: cfg.eps,
+            },
+            vb.pp("k_proj"),
+        )?;
+        let v_proj = Bitlinear::load(
+            BitlinearCfg {
+                in_features: cfg.dim,
+                out_features: cfg.n_kv_heads * head_dim,
+                eps: cfg.eps,
+            },
+            vb.pp("v_proj"),
         )?;
-
-        let norm = match cfg.layer_norm_enabled {
-            true => {
-                let config = LayerNormConfig {
-                    eps: cfg.eps.into(),
-                    ..LayerNormConfig::default()
-                };
-                Some(layer_norm(
-                    head_dim * cfg.n_kv_heads,
-                    config,
-                    vb.pp("layer_norm"),
-                )?)
-            }
-            false => None,
-        };
         let o_proj = Bitlinear::load(
             BitlinearCfg {
                 in_features: cfg.dim,
                 out_features: cfg.dim,
-                num_groups: 1,
-                b: 8,
                 eps: cfg.eps,
-                bias: true,
             },
             vb.pp("o_proj"),
         )?;
+        let (cos, sin) = precompute_freqs_cis(head_dim, 10000., cfg.max_seq_len, vb.device())?;
+
         Ok(BitAttention {
-            qkv_proj,
+            q_proj,
+            k_proj,
+            v_proj,
             o_proj,
-            norm,
-            dim: cfg.dim,
-            head_dim,
             n_heads: cfg.n_heads,
             n_kv_heads: cfg.n_kv_heads,
             dropout: cfg.dropout,
+            cos,
+            sin,
+            kv_cache: None,
         })
     }
 
     #[instrument]
-    pub fn forward(&self, x: &Tensor, is_causal: bool) -> Result<Tensor> {
-        let qkv = self.qkv_proj.forward(x)?;
+    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+        let dtype = x.dtype();
+        let x = x.to_dtype(DType::F32)?;
+        let (_b_sz, _n_head, seq_len, _n_embd) = x.dims4()?;
+        let cos = self.cos.narrow(0, index_pos, seq_len)?;
+        let sin = self.sin.narrow(0, index_pos, seq_len)?;
+        let x = rope_i(&x.contiguous()?, &cos, &sin)?;
+        let x = x.to_dtype(dtype)?;
+        Ok(x)
+    }
 
-        let kv_size = self.n_kv_heads * self.head_dim;
-        let q = qkv.narrow(D::Minus1, 0, self.dim)?;
-        let k = qkv.narrow(D::Minus1, self.dim, kv_size)?;
-        let v = qkv.narrow(D::Minus1, self.dim + kv_size, kv_size)?;
+    #[instrument]
+    pub fn forward(&mut self, x: &Tensor, is_causal: bool, index_pos: usize) -> Result<Tensor> {
+        let q = self.q_proj.forward(x)?;
+        let k = self.k_proj.forward(x)?;
+        let v = self.v_proj.forward(x)?;
 
         let q = einops!("b n ({self.n_heads} d) -> b n {self.n_heads} d", q);
         let k = einops!("b n ({self.n_kv_heads} d) -> b n {self.n_kv_heads} d", k);
         let v = einops!("b n ({self.n_kv_heads} d) -> b n {self.n_kv_heads} d", v);
 
+        let q = self.apply_rotary_emb(&q, index_pos)?;
+        let k = self.apply_rotary_emb(&k, index_pos)?;
+        let (k, v) = match &self.kv_cache {
+            None => (k, v),
+            Some((k_cache, v_cache)) => {
+                if index_pos == 0 {
+                    (k, v)
+                } else {
+                    let k = Tensor::cat(&[k_cache, &k], 2)?;
+                    let v = Tensor::cat(&[v_cache, &v], 2)?;
+                    (k, v)
+                }
+            }
+        };
+        self.kv_cache = Some((k.clone(), v.clone()));
+
+        let scale = (q.dims4()?.3 as f64).sqrt();
         let x = scaled_dot_product_attention(
             &q,
             &k,
             &v,
             None,
             Some(self.dropout),
             Some(is_causal),
             Some(scale),
         )?;
+        let x = einops!("b n h d -> b n (h d)", x);
 
-        let x = match self.norm {
-            Some(ref norm) => norm.forward(&x)?,
-            None => x,
-        };
         let x = self.o_proj.forward(&x)?;
         Ok(x)
     }
@@ -156,7 +197,7 @@ mod bit_attention_tests {
        bit_attention::{BitAttention, BitAttentionCfg},
        utils_tensor::device,
    };
-    use candle_core::{DType, Tensor};
+    use candle_core::Tensor;
    use candle_nn::VarBuilder;
 
    #[test]
@@ -165,74 +206,19 @@ mod bit_attention_tests {
        let vb = VarBuilder::zeros(candle_core::DType::F32, &device);
 
        let input_tensor = Tensor::randn(0.0f32, 1.0f32, (2, 8, 64), &device)?;
-        let bit_attention = BitAttention::load(
-            BitAttentionCfg {
-                dim: 64,
-                n_heads: 8,
-                n_kv_heads: 8,
-                bias: true,
-                dropout: 0.1,
-                layer_norm_enabled: true,
-                eps: 1e-6,
-            },
-            vb,
-        )?;
-
-        let output_tensor = bit_attention.forward(&input_tensor, true).unwrap();
-
-        assert_eq!(output_tensor.shape().dims(), &[2, 8, 64]);
-
-        Ok(())
-    }
-
-    #[test]
-    fn forward_produces_expected_shape_f16() -> anyhow::Result<()> {
-        let device = device(true).unwrap();
-        let vb = VarBuilder::zeros(candle_core::DType::F16, &device);
-
-        let input_tensor =
-            Tensor::randn(0.0f32, 1.0f32, (2, 8, 64), &device)?.to_dtype(DType::F16)?;
-        let bit_attention = BitAttention::load(
-            BitAttentionCfg {
-                dim: 64,
-                n_heads: 8,
-                n_kv_heads: 8,
-                bias: true,
-                dropout: 0.1,
-                layer_norm_enabled: true,
-                eps: 1e-6,
-            },
-            vb,
-        )?;
-
-        let output_tensor = bit_attention.forward(&input_tensor, true).unwrap();
-
-        assert_eq!(output_tensor.shape().dims(), &[2, 8, 64]);
-
-        Ok(())
-    }
-
-    #[test]
-    fn forward_produces_expected_shape_bf16() -> anyhow::Result<()> {
-        let device = device(true).unwrap();
-        let vb = VarBuilder::zeros(candle_core::DType::BF16, &device);
-
-        let input_tensor =
-            Tensor::randn(0.0f32, 1.0f32, (2, 8, 64), &device)?.to_dtype(DType::BF16)?;
-        let bit_attention = BitAttention::load(
+        let mut bit_attention = BitAttention::load(
            BitAttentionCfg {
                dim: 64,
                n_heads: 8,
                n_kv_heads: 8,
-                bias: true,
                dropout: 0.1,
-                layer_norm_enabled: true,
                eps: 1e-6,
+                max_seq_len: 64,
            },
            vb,
        )?;
 
-        let output_tensor = bit_attention.forward(&input_tensor, true).unwrap();
+        let output_tensor = bit_attention.forward(&input_tensor, true, 0).unwrap();
 
        assert_eq!(output_tensor.shape().dims(), &[2, 8, 64]);
diff --git a/src/bit_ffn.rs b/src/bit_ffn.rs
index bf152e1..55f3d0c 100644
--- a/src/bit_ffn.rs
+++ b/src/bit_ffn.rs
@@ -1,7 +1,8 @@
 use crate::bit_dropout::{Dropout, DropoutCfg};
 use crate::bit_linear::{Bitlinear, BitlinearCfg};
+use crate::rms_norm::RmsNorm;
 use candle_core::{Module, Tensor};
-use candle_nn::{layer_norm, Activation, LayerNorm, LayerNormConfig, VarBuilder};
+use candle_nn::{Activation, VarBuilder};
 use tracing::instrument;
 
 pub struct BitFeedForwardCfg {
     pub dim: usize,
     pub ff_mult: usize,
     pub dropout: f32,
     pub train: bool,
-    pub eps: f32,
+    pub eps: f64,
 }
 
 #[derive(Debug)]
 pub struct BitFeedForward {
-    glu_linear: Bitlinear,
+    proj_in: Bitlinear,
     activation: Activation,
-    norm: LayerNorm,
+    post_act_norm: RmsNorm,
     dropout: Dropout,
-    linear: Bitlinear,
+    proj_out: Bitlinear,
 }
 
 impl BitFeedForward {
@@ -26,9 +27,8 @@ impl BitFeedForward {
         // Setup internal parameters
         let inner_dim = cfg.dim * cfg.ff_mult;
 
-        // GELU is used as activation function
-        // The original implementation has the option for SiLU, look into adding that at some point
-        let activation = Activation::Gelu;
+        // Use swiglu from 1.58 paper
+        let activation = Activation::Swiglu;
 
         // Dropout layer, if train is passed then this is skipped
         let dropout = Dropout::load(DropoutCfg {
             p: cfg.dropout,
             is_training: cfg.train,
         })?;
 
-        // Layer normalization
-        let norm = layer_norm(
-            inner_dim,
-            LayerNormConfig {
-                eps: cfg.eps.into(),
-                ..LayerNormConfig::default()
-            },
-            vb.pp("norm"),
-        )?;
+        // Post activation normalization
+        let post_act_norm = RmsNorm::load(cfg.eps, inner_dim, vb.pp("norm"))?;
 
-        let glu_linear = Bitlinear::load(
+        // Input linear layer
+        let proj_in = Bitlinear::load(
             BitlinearCfg {
                 in_features: cfg.dim,
-                out_features: inner_dim,
-                num_groups: 1,
-                b: 8,
+                out_features: inner_dim * 2,
                 eps: cfg.eps,
-                bias: true,
             },
             vb.pp("proj"),
         )?;
 
         // Linear layer
-        let linear = Bitlinear::load(
+        let proj_out = Bitlinear::load(
             BitlinearCfg {
                 in_features: inner_dim,
                 out_features: cfg.dim,
-                num_groups: 1,
-                b: 8,
                 eps: cfg.eps,
-                bias: true,
             },
             vb.pp("linear"),
         )?;
 
         // Return the layer as a sequential module
         Ok(Self {
-            glu_linear,
+            proj_in,
             activation,
-            norm,
+            post_act_norm,
             dropout,
-            linear,
+            proj_out,
         })
     }
 
     #[instrument]
     pub fn forward(&self, x: &Tensor) -> anyhow::Result<Tensor> {
-        let x = self.glu_linear.forward(x)?;
+        let x = self.proj_in.forward(x)?;
         let x = self.activation.forward(&x)?;
-        let x = self.norm.forward(&x)?;
+        let x = self.post_act_norm.forward(&x)?;
         let x = self.dropout.forward(&x)?;
-        let x = self.linear.forward(&x)?;
+        let x = self.proj_out.forward(&x)?;
         Ok(x)
     }
 }
diff --git a/src/bit_linear.rs b/src/bit_linear.rs
index 4e6ccd3..a8cbb54 100644
--- a/src/bit_linear.rs
+++ b/src/bit_linear.rs
@@ -1,28 +1,20 @@
-use crate::utils_tensor::sign;
 use anyhow::Ok;
 use candle_core::{Tensor, D};
-use candle_nn::{layer_norm, Init, LayerNorm, LayerNormConfig, Module, VarBuilder};
+use candle_nn::{Init, Module, VarBuilder};
-use candle_transformers::models::with_tracing::Linear;
+use candle_transformers::models::with_tracing::{Linear, RmsNorm};
 use tracing::instrument;
 
 #[derive(Debug, Clone, Copy)]
 pub struct BitlinearCfg {
     pub in_features: usize,
     pub out_features: usize,
-    pub num_groups: usize,
-    pub b: i32,
-    pub eps: f32,
-    pub bias: bool,
+    pub eps: f64,
 }
 
 #[derive(Debug)]
 pub struct Bitlinear {
-    num_groups: usize,
     weight: Tensor,
-    bias: Option<Tensor>,
-    layer_norm: LayerNorm,
-    eps: f64,
-    q_b: f64,
+    layer_norm: RmsNorm,
 }
 
 impl Bitlinear {
@@ -35,134 +27,42 @@ impl Bitlinear {
                 stdev: 1.0,
             },
         )?;
-        let bias = match cfg.bias {
-            true => Some(vb.get_with_hints(cfg.out_features, "bias", Init::Const(0.0))?),
-            false => None,
-        };
-        let layer_norm = layer_norm(
-            cfg.in_features,
-            LayerNormConfig {
-                eps: cfg.eps.into(),
-                ..LayerNormConfig::default()
-            },
-            vb.pp("layer_norm"),
-        )?;
-        let q_b = 2f64.powi(cfg.b - 1);
-        Ok(Self {
-            num_groups: cfg.num_groups,
-            weight,
-            layer_norm,
-            bias,
-            eps: cfg.eps as f64,
-            q_b,
-        })
-    }
-
-    #[instrument]
-    fn ste(&self, x: &Tensor) -> candle_core::Result<Tensor> {
-        let binarized_x = sign(x)?;
-        (binarized_x - x)?.detach() + x
+        let layer_norm = RmsNorm::new(cfg.in_features, cfg.eps, vb.pp("rms_norm"))?;
+        Ok(Self { weight, layer_norm })
     }
 
     #[instrument]
-    fn binarize_weights_groupwise(&self) -> anyhow::Result<(Tensor, Tensor)> {
-        /*
-         * Note:
-         * The original code uses slice assignment on a zeroed tensor to create the final tensor
-         * We instead push the chunks into a vector then combine at the very end to avoid the need to call
-         * slice_assign.
-         */
-        let mut binarized_weight_groups: Vec<Tensor> = Vec::with_capacity(self.num_groups);
-        let mut beta_groups: Vec<Tensor> = Vec::with_capacity(self.num_groups);
-        let group_size = self.weight.dims()[0] / self.num_groups;
-        for i in 0..self.num_groups {
-            let start_idx = i * group_size;
-            let weights_group = self.weight.narrow(0, start_idx, group_size)?;
-            let alpha_g = weights_group.mean_all()?;
-            let beta = weights_group.abs()?.mean(D::Minus1)?.mean(D::Minus1)?;
-            beta_groups.push(beta);
-            let binarized_weights = self.ste(&(weights_group.broadcast_sub(&alpha_g)?))?;
-            binarized_weight_groups.push(binarized_weights);
+    pub fn forward(&self, x: &Tensor) -> anyhow::Result<Tensor> {
+        fn activation_quant(x: &Tensor) -> anyhow::Result<Tensor> {
+            let scale = (127.0
+                / x.abs()?
+                    .max(D::Minus1)?
+                    .max(D::Minus1)?
+                    .clamp(1e-5, f32::INFINITY)?)?;
+            let y = x
+                .broadcast_mul(&scale.unsqueeze(D::Minus1)?.unsqueeze(D::Minus1)?)?
+                .clamp(-128.0, 127.0)?;
+            Ok(y)
         }
-        let binarized_weights = Tensor::cat(&binarized_weight_groups, D::Minus1)?;
-        let beta = Tensor::cat(&beta_groups, D::Minus1)?;
-        Ok((binarized_weights, beta))
-    }
-
-    #[instrument]
-    fn dequantize_activations(
-        &self,
-        x: &Tensor,
-        beta: &Tensor,
-        gamma: &Tensor,
-    ) -> anyhow::Result<Tensor> {
-        Ok((x
-            .broadcast_mul(
-                &gamma
-                    .unsqueeze(D::Minus1)
-                    .unwrap()
-                    .unsqueeze(D::Minus1)
-                    .unwrap(),
-            )?
-            .broadcast_mul(beta)?
-            / self.q_b)?)
-    }
-
-    #[instrument]
-    fn quantize_activations(&self, x: &Tensor) -> anyhow::Result<(Tensor, Tensor)> {
-        let mut quantized_x_groups: Vec<Tensor> = Vec::with_capacity(self.num_groups);
-        let mut gamma_groups: Vec<Tensor> = Vec::with_capacity(self.num_groups);
-        let group_size = x.dims()[0] / self.num_groups;
-        for i in 0..self.num_groups {
-            let start_idx = i * group_size;
-            let activation_group = x.narrow(0, start_idx, group_size).unwrap();
-            let gamma = activation_group.abs()?.max(D::Minus1)?.max(D::Minus1)?;
-            let clamp_min = -self.q_b + self.eps;
-            let clamp_max = self.q_b - self.eps;
-            let x = (activation_group * self.q_b).unwrap();
-            let x = x
-                .broadcast_div(
-                    &(gamma.clone() + self.eps)
-                        .unwrap()
-                        .unsqueeze(D::Minus1)
-                        .unwrap()
-                        .unsqueeze(D::Minus1)
-                        .unwrap(),
-                )
-                .unwrap();
-            let quantized_x = x.clamp(clamp_min, clamp_max).unwrap();
-            quantized_x_groups.push(quantized_x);
-            gamma_groups.push(gamma.clone());
+        fn weight_quant(x: &Tensor) -> anyhow::Result<Tensor> {
+            let scale = x.abs()?.mean_all()?;
+            let e = x.mean_all()?;
+            let u = x.broadcast_sub(&e)?.sign()?.broadcast_mul(&scale)?;
+            Ok(u)
         }
-        let quantized_x = Tensor::cat(&quantized_x_groups, 0)?;
-        let gamma = Tensor::cat(&gamma_groups, 0)?;
-        Ok((quantized_x, gamma))
-    }
 
     #[instrument]
-    pub fn forward(&self, x: &Tensor) -> anyhow::Result<Tensor> {
-        // normalize input
-        let x = self.layer_norm.forward(x)?;
+        let weight = self.weight.clone();
 
-        // binarize weights and quantize activations
-        let (binarized_weights, beta) = self.binarize_weights_groupwise()?;
+        let x_norm = self.layer_norm.forward(x)?;
 
-        // quantize activations
-        let (x_quantized, gamma) = self.quantize_activations(&x)?;
+        let x_quant = (x_norm.clone() + (activation_quant(&x_norm)? - x_norm)?.detach())?;
 
-        // perform linear transformation
-        let output = match &self.bias {
-            Some(bias) => {
-                Linear::from_weights(binarized_weights, Some(bias.clone())).forward(&x_quantized)?
-            }
-            None => Linear::from_weights(binarized_weights, None).forward(&x_quantized)?,
-        };
+        let w_quant = (weight.clone() + (weight_quant(&weight)? - weight)?.detach())?;
 
-        // dequantize activations
-        let output = self.dequantize_activations(&output, &beta, &gamma)?;
+        let y = Linear::from_weights(w_quant, None).forward(&x_quant)?;
 
-        Ok(output)
+        Ok(y)
     }
 }
 
@@ -183,16 +83,13 @@ mod bitlinear_tests {
             BitlinearCfg {
                 in_features,
                 out_features,
-                num_groups: 1,
-                b: 8,
                 eps: 1e-6,
-                bias: true,
             },
             vb,
         )?;
         let input: Tensor = Tensor::randn(0.0f32, 1.0f32, (1, 64), &device.clone())?;
         let output = bl.forward(&input)?;
-        assert_eq!(output.shape().dims2()?, (64, 64));
+        assert_eq!(output.shape().dims2()?, (1, 64));
         Ok(())
     }
 }
diff --git a/src/bit_transformer.rs b/src/bit_transformer.rs
index c045787..951ca37 100644
--- a/src/bit_transformer.rs
+++ b/src/bit_transformer.rs
@@ -29,9 +29,8 @@ impl BitTransformer {
                     n_heads: cfg.heads,
                     n_kv_heads: 8,
                     dropout: 0.1,
-                    layer_norm_enabled: true,
-                    bias: true,
                     eps: cfg.eps,
+                    max_seq_len: cfg.seq_len,
                 },
                 vb.pp(&format!("attn.{i}")),
             )
@@ -63,14 +62,14 @@ impl BitTransformer {
     }
 
     #[instrument]
-    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+    pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         // Run the embedding layer
         let x_embed = self.embedding.forward(x)?;
 
         // Fold each block forward
         let mut x = x_embed.clone();
-        for (attn, ffn) in &self.blocks {
-            x = attn.forward(&x, true)?;
+        for (attn, ffn) in self.blocks.iter_mut() {
+            x = attn.forward(&x, true, index_pos)?;
             x = x.add(&x_embed)?;
             x = ffn.forward(&x)?;
             x = x.add(&x)?;
@@ -82,39 +81,3 @@ impl BitTransformer {
         Ok(x)
     }
 }
-
-#[cfg(test)]
-mod bitnet_transformer_tests {
-    use crate::{config::Config, utils_tensor::device};
-
-    use super::BitTransformer;
-    use anyhow::Result;
-    use candle_core::{DType, Tensor};
-    use candle_nn::VarBuilder;
-
-    #[test]
-    fn it_applies_forward_pass() -> Result<()> {
-        let device = &device(true)?;
-        let vb = VarBuilder::zeros(DType::F32, device);
-        let t = BitTransformer::load(
-            Config {
-                dim: 8 * 8,
-                depth: 8,
-                vocab_size: 32000,
-                heads: 8,
-                ff_mult: 10,
-                eps: 1e-6,
-                ff_dropout: 0.1,
-                seq_len: 10,
-            },
-            vb,
-            true,
-        )?;
-        let x = Tensor::ones((1, 128), DType::U32, device)?;
-        let x = t.forward(&x)?;
-
-        assert_eq!(x.shape().dims(), &[1, 128, 32000]);
-
-        Ok(())
-    }
-}
diff --git a/src/config.rs b/src/config.rs
index 6ca7e91..16bb6e8 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -5,7 +5,7 @@ pub struct Config {
     pub(crate) vocab_size: usize,
     pub(crate) heads: usize,
     pub(crate) ff_mult: usize,
-    pub(crate) eps: f32,
+    pub(crate) eps: f64,
     pub(crate) ff_dropout: f32,
     pub(crate) seq_len: usize,
 }
diff --git a/src/inference.rs b/src/inference.rs
index 19cb0e2..443f98c 100644
--- a/src/inference.rs
+++ b/src/inference.rs
@@ -27,7 +27,7 @@ pub fn run(args: &InferenceCmd, common_args: &Args) -> Result<()> {
     let vb = VarBuilder::from_tensors(safetensors, DType::F32, &device);
     let config = Config::default();
-    let model = BitTransformer::load(config, vb, false)?;
+    let mut model = BitTransformer::load(config, vb, false)?;
 
     println!("starting the inference loop");
     let mut logits_processor = LogitsProcessor::new(rng.gen(), args.temperature, Some(args.top_p));
@@ -52,7 +52,7 @@ pub fn run(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         let context_size = if index > 0 { 1 } else { tokens.len() };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
-        let logits = model.forward(&input)?;
+        let logits = model.forward(&input, index)?;
         let logits = logits.i((0, logits.dim(1)? - 1))?;
         let logits = if args.repeat_penalty == 1. || tokens.is_empty() {
             logits
diff --git a/src/rms_norm.rs b/src/rms_norm.rs
index 7bd2b78..d9d686d 100644
--- a/src/rms_norm.rs
+++ b/src/rms_norm.rs
@@ -8,8 +8,8 @@ pub struct RmsNorm {
 }
 
 impl RmsNorm {
-    pub fn load(rms_norm_eps: f32, size: usize, vb: VarBuilder) -> Result<Self> {
-        let inner = candle_nn::rms_norm(size, rms_norm_eps.into(), vb)?;
+    pub fn load(rms_norm_eps: f64, size: usize, vb: VarBuilder) -> Result<Self> {
+        let inner = candle_nn::rms_norm(size, rms_norm_eps, vb)?;
         Ok(Self { inner })
     }
 }
@@ -31,14 +31,14 @@ mod rmsnorm_tests {
     #[test]
     fn it_loads() -> Result<()> {
         let vb = VarBuilder::zeros(DType::F64, &Device::Cpu);
-        RmsNorm::load(1e-6f32, 512, vb)?;
+        RmsNorm::load(1e-6, 512, vb)?;
         Ok(())
     }
 
     #[test]
     fn it_applies_forward_pass() -> Result<()> {
         let vb = VarBuilder::zeros(DType::F32, &Device::Cpu);
-        let rmsnorm = RmsNorm::load(1e-6f32, 512, vb)?;
+        let rmsnorm = RmsNorm::load(1e-6, 512, vb)?;
         let input = Tensor::ones((1, 512), DType::F32, &Device::Cpu)?;
         let output = rmsnorm.forward(&input).unwrap();
         assert_eq!(output.shape().dims(), &[1, 512]);
diff --git a/src/training.rs b/src/training.rs
index 933b9b9..12bc519 100644
--- a/src/training.rs
+++ b/src/training.rs
@@ -30,7 +30,7 @@ fn valid_loss(
         let span = span!(tracing::Level::TRACE, "validate-loss-iter");
         let _enter = span.enter();
         let (inp, tgt) = inp_tgt?;
-        let logits = model.forward(&inp)?;
+        let logits = model.forward(&inp, 0)?;
         let loss = cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
         sum_ce += match loss.dtype() {
             DType::F32 => f64::from(loss.to_vec0::<f32>()?),
@@ -106,7 +106,7 @@ pub fn run(args: &TrainingCmd, common_args: &Args) -> Result<()> {
        let _enter = span.enter();
 
        let (inp, tgt) = batch?;
-        let logits = model.forward(&inp)?;
+        let logits = model.forward(&inp, 0)?;
        let loss = cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
        training_loss = match dtype {
            candle_core::DType::F32 => f64::from(loss.to_vec0::<f32>()?),
@@ -116,7 +116,7 @@ pub fn run(args: &TrainingCmd, common_args: &Args) -> Result<()> {
        };
        opt.backward_step(&loss)?;
 
-        if batch_index > 0 && batch_index % 10 == 0 {
+        if batch_index > 0 && batch_index % 100 == 0 {
            validation_loss =
                valid_loss(args.seq_len, args.batch_size, &dataset, &mut model, &device)?;
            if batch_index % 10000 == 0 {
diff --git a/src/utils_tensor.rs b/src/utils_tensor.rs
index 60dc858..7985e3e 100644
--- a/src/utils_tensor.rs
+++ b/src/utils_tensor.rs
@@ -4,40 +4,6 @@ use candle_core::{DType, Device, Shape, Tensor, WithDType, D};
 use candle_nn::ops::{self};
 use tracing::instrument;
 
-// Transform the input values of the tensor to it's signs, -1, 0 or 1
-#[instrument]
-pub fn sign(x: &Tensor) -> candle_core::Result<Tensor> {
-    let zeros = x.eq(&Tensor::zeros(x.shape(), x.dtype(), x.device())?)?;
-    let abs_x = x.abs()?.add(&zeros.to_dtype(x.dtype())?)?;
-    let sign_x = (x / abs_x)?;
-    Ok(sign_x)
-}
-
-#[cfg(test)]
-mod sign_tests {
-    use crate::utils_tensor::sign;
-    use candle_core::{Device, Result, Tensor};
-
-    #[test]
-    fn it_works() -> Result<()> {
-        let input = vec![-3f32, -2f32, -1f32, 0f32, 1f32, 2f32, 3f32];
-        let input_size = input.len();
-        let tensor = Tensor::from_vec(input, (input_size,), &Device::Cpu)?;
-        let output = sign(&tensor).unwrap();
-
-        let expected_shape = [input_size];
-        assert_eq!(output.shape().dims(), &expected_shape);
-
-        let expected_output = [-1f32, -1f32, -1f32, 0f32, 1f32, 1f32, 1f32];
-        let output = output.squeeze(0)?;
-        let output = output.to_vec1::<f32>()?;
-
-        assert_eq!(output, expected_output);
-
-        Ok(())
-    }
-}
-
 // Get the device to use for the tensor operations, only really used for tests
 // Originally from: https://github.com/huggingface/candle/blob/314630638d8f6886c07d73211d6c35f8cf05d56a/candle-examples/src/lib.rs#L9
 pub fn device(cpu: bool) -> Result<Device> {