Commit

refactor: rewrite several sections to better match the 1.58 paper & add rope (#3)

* lots of random changes to align with the original paper

* more fixes

* more fixes

* unused import

* fix test

* clippy fixes
tomsanbear committed Apr 18, 2024
1 parent 4628f85 commit 2adf025
Showing 10 changed files with 157 additions and 359 deletions.
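
The src/bit_attention.rs diff below replaces the fused qkv projection with separate q/k/v projections, adds rotary position embeddings via candle's rope_i helper, and introduces a kv cache. As a dependency-free illustration of the interleaved rotation that precompute_freqs_cis and rope_i compute together (this sketch is not part of the commit; the function name is made up for illustration):

// Illustrative only: interleaved RoPE applied to one head vector of length d.
// Each even/odd pair (x[i], x[i+1]) is rotated by the angle pos / base^(i/d),
// which is what the cos/sin tables from precompute_freqs_cis feed into rope_i.
fn rope_interleaved(x: &mut [f32], pos: usize, base: f32) {
    let d = x.len();
    for i in (0..d).step_by(2) {
        let theta = 1.0 / base.powf(i as f32 / d as f32);
        let (sin, cos) = ((pos as f32) * theta).sin_cos();
        let (x0, x1) = (x[i], x[i + 1]);
        x[i] = x0 * cos - x1 * sin;
        x[i + 1] = x0 * sin + x1 * cos;
    }
}

fn main() {
    // Rotate a toy 8-dimensional query vector as if it sat at position 3.
    let mut q = vec![1.0f32; 8];
    rope_interleaved(&mut q, 3, 10_000.0);
    println!("{q:?}");
}
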
2 changes: 0 additions & 2 deletions Cargo.toml
@@ -3,8 +3,6 @@ name = "bitnet-rs"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
accelerate-src = { version = "0.3.2", optional = true }
anyhow = "1.0.80"
202 changes: 94 additions & 108 deletions src/bit_attention.rs
@@ -2,10 +2,10 @@ use crate::{
bit_linear::{Bitlinear, BitlinearCfg},
utils_tensor::scaled_dot_product_attention,
};
use anyhow::{anyhow, Result};
use candle_core::{Tensor, D};
use anyhow::{anyhow, Ok, Result};
use candle_core::{DType, Device, Tensor};
use candle_einops::einops;
use candle_nn::{layer_norm, LayerNormConfig, Module, VarBuilder};
use candle_nn::{rotary_emb::rope_i, VarBuilder};
use tracing::instrument;

#[derive(Debug, Clone, Copy)]
@@ -14,21 +14,42 @@ pub struct BitAttentionCfg {
pub n_heads: usize,
pub n_kv_heads: usize,
pub dropout: f32,
pub bias: bool,
pub layer_norm_enabled: bool,
pub eps: f32,
pub eps: f64,
pub max_seq_len: usize,
}

#[derive(Debug)]
pub struct BitAttention {
qkv_proj: Bitlinear,
q_proj: Bitlinear,
k_proj: Bitlinear,
v_proj: Bitlinear,
o_proj: Bitlinear,
norm: Option<candle_nn::LayerNorm>,
dropout: f32,
dim: usize,
head_dim: usize,
n_heads: usize,
n_kv_heads: usize,
cos: Tensor,
sin: Tensor,
kv_cache: Option<(Tensor, Tensor)>,
}

fn precompute_freqs_cis(
head_dim: usize,
freq_base: f32,
max_seq_len: usize,
device: &Device,
) -> Result<(Tensor, Tensor)> {
let theta: Vec<_> = (0..head_dim)
.step_by(2)
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), device)?;
let idx_theta = Tensor::arange(0, max_seq_len as u32, device)?
.to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
let cos = idx_theta.cos()?;
let sin = idx_theta.sin()?;
Ok((cos, sin))
}

impl BitAttention {
@@ -65,71 +86,94 @@ impl BitAttention {
return Err(anyhow!("head_dim must be less than or equal to 128"));
}

let total_head_dim = (cfg.n_heads + (2 * cfg.n_kv_heads)) * head_dim;
let qkv_proj = Bitlinear::load(
let q_proj = Bitlinear::load(
BitlinearCfg {
in_features: cfg.dim,
out_features: total_head_dim,
num_groups: 1,
b: 8,
out_features: cfg.n_heads * head_dim,
eps: cfg.eps,
bias: cfg.bias,
},
vb.pp("qkv_proj"),
vb.pp("q_proj"),
)?;
let k_proj = Bitlinear::load(
BitlinearCfg {
in_features: cfg.dim,
out_features: cfg.n_kv_heads * head_dim,
eps: cfg.eps,
},
vb.pp("k_proj"),
)?;
let v_proj = Bitlinear::load(
BitlinearCfg {
in_features: cfg.dim,
out_features: cfg.n_kv_heads * head_dim,
eps: cfg.eps,
},
vb.pp("v_proj"),
)?;

let norm = match cfg.layer_norm_enabled {
true => {
let config = LayerNormConfig {
eps: cfg.eps.into(),
..LayerNormConfig::default()
};
Some(layer_norm(
head_dim * cfg.n_kv_heads,
config,
vb.pp("layer_norm"),
)?)
}
false => None,
};

let o_proj = Bitlinear::load(
BitlinearCfg {
in_features: cfg.dim,
out_features: cfg.dim,
num_groups: 1,
b: 8,
eps: cfg.eps,
bias: true,
},
vb.pp("o_proj"),
)?;

let (cos, sin) = precompute_freqs_cis(head_dim, 10000., cfg.max_seq_len, vb.device())?;

Ok(BitAttention {
qkv_proj,
q_proj,
k_proj,
v_proj,
o_proj,
norm,
dim: cfg.dim,
head_dim,
n_heads: cfg.n_heads,
n_kv_heads: cfg.n_kv_heads,
dropout: cfg.dropout,
cos,
sin,
kv_cache: None,
})
}

#[instrument]
pub fn forward(&self, x: &Tensor, is_causal: bool) -> Result<Tensor> {
let qkv = self.qkv_proj.forward(x)?;
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let dtype = x.dtype();
let x = x.to_dtype(DType::F32)?;
let (_b_sz, _n_head, seq_len, _n_embd) = x.dims4()?;
let cos = self.cos.narrow(0, index_pos, seq_len)?;
let sin = self.sin.narrow(0, index_pos, seq_len)?;
let x = rope_i(&x.contiguous()?, &cos, &sin)?;
let x = x.to_dtype(dtype)?;
Ok(x)
}

let kv_size = self.n_kv_heads * self.head_dim;
let q = qkv.narrow(D::Minus1, 0, self.dim)?;
let k = qkv.narrow(D::Minus1, self.dim, kv_size)?;
let v = qkv.narrow(D::Minus1, self.dim + kv_size, kv_size)?;
#[instrument]
pub fn forward(&mut self, x: &Tensor, is_causal: bool, index_pos: usize) -> Result<Tensor> {
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
let v = self.v_proj.forward(x)?;

let q = einops!("b n ({self.n_heads} d) -> b n {self.n_heads} d", q);
let k = einops!("b n ({self.n_kv_heads} d) -> b n {self.n_kv_heads} d", k);
let v = einops!("b n ({self.n_kv_heads} d) -> b n {self.n_kv_heads} d", v);

let q = self.apply_rotary_emb(&q, index_pos)?;
let k = self.apply_rotary_emb(&k, index_pos)?;
let (k, v) = match &self.kv_cache {
None => (k, v),
Some((k_cache, v_cache)) => {
if index_pos == 0 {
(k, v)
} else {
let k = Tensor::cat(&[k_cache, &k], 2)?;
let v = Tensor::cat(&[v_cache, &v], 2)?;
(k, v)
}
}
};
self.kv_cache = Some((k.clone(), v.clone()));

let scale = (q.dims4()?.3 as f64).sqrt();
let x = scaled_dot_product_attention(
&q,
@@ -140,11 +184,8 @@ impl BitAttention {
Some(is_causal),
Some(scale),
)?;

let x = einops!("b n h d -> b n (h d)", x);
let x = match self.norm {
Some(ref norm) => norm.forward(&x)?,
None => x,
};
let x = self.o_proj.forward(&x)?;
Ok(x)
}
@@ -156,7 +197,7 @@ mod bit_attention_tests {
bit_attention::{BitAttention, BitAttentionCfg},
utils_tensor::device,
};
use candle_core::{DType, Tensor};
use candle_core::Tensor;
use candle_nn::VarBuilder;

#[test]
@@ -165,74 +206,19 @@
let vb = VarBuilder::zeros(candle_core::DType::F32, &device);

let input_tensor = Tensor::randn(0.0f32, 1.0f32, (2, 8, 64), &device)?;
let bit_attention = BitAttention::load(
BitAttentionCfg {
dim: 64,
n_heads: 8,
n_kv_heads: 8,
bias: true,
dropout: 0.1,
layer_norm_enabled: true,
eps: 1e-6,
},
vb,
)?;

let output_tensor = bit_attention.forward(&input_tensor, true).unwrap();

assert_eq!(output_tensor.shape().dims(), &[2, 8, 64]);

Ok(())
}

#[test]
fn forward_produces_expected_shape_f16() -> anyhow::Result<()> {
let device = device(true).unwrap();
let vb = VarBuilder::zeros(candle_core::DType::F16, &device);

let input_tensor =
Tensor::randn(0.0f32, 1.0f32, (2, 8, 64), &device)?.to_dtype(DType::F16)?;
let bit_attention = BitAttention::load(
BitAttentionCfg {
dim: 64,
n_heads: 8,
n_kv_heads: 8,
bias: true,
dropout: 0.1,
layer_norm_enabled: true,
eps: 1e-6,
},
vb,
)?;

let output_tensor = bit_attention.forward(&input_tensor, true).unwrap();

assert_eq!(output_tensor.shape().dims(), &[2, 8, 64]);

Ok(())
}

#[test]
fn forward_produces_expected_shape_bf16() -> anyhow::Result<()> {
let device = device(true).unwrap();
let vb = VarBuilder::zeros(candle_core::DType::BF16, &device);

let input_tensor =
Tensor::randn(0.0f32, 1.0f32, (2, 8, 64), &device)?.to_dtype(DType::BF16)?;
let bit_attention = BitAttention::load(
let mut bit_attention = BitAttention::load(
BitAttentionCfg {
dim: 64,
n_heads: 8,
n_kv_heads: 8,
bias: true,
dropout: 0.1,
layer_norm_enabled: true,
eps: 1e-6,
max_seq_len: 64,
},
vb,
)?;

let output_tensor = bit_attention.forward(&input_tensor, true).unwrap();
let output_tensor = bit_attention.forward(&input_tensor, true, 0).unwrap();

assert_eq!(output_tensor.shape().dims(), &[2, 8, 64]);

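Since forward now takes index_pos and keeps a kv cache, here is a hedged usage sketch of incremental decoding, assuming the crate publicly exports the bit_attention module and that the causal-mask path handles single-token steps (neither is shown in this diff):

use bitnet_rs::bit_attention::{BitAttention, BitAttentionCfg};
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;

fn decode_sketch() -> anyhow::Result<()> {
    let device = Device::Cpu;
    let vb = VarBuilder::zeros(DType::F32, &device);
    let mut attn = BitAttention::load(
        BitAttentionCfg {
            dim: 64,
            n_heads: 8,
            n_kv_heads: 8,
            dropout: 0.0,
            eps: 1e-6,
            max_seq_len: 64,
        },
        vb,
    )?;

    // Prompt pass: 8 tokens starting at position 0; their k/v land in the cache.
    let prompt = Tensor::randn(0.0f32, 1.0f32, (1, 8, 64), &device)?;
    let _ = attn.forward(&prompt, true, 0)?;

    // Decode pass: one new token at index_pos = 8, so the rotary angles and the
    // cached keys/values line up with the positions already processed.
    // Output keeps the (batch, seq, dim) layout of the input, as in the test above.
    let next = Tensor::randn(0.0f32, 1.0f32, (1, 1, 64), &device)?;
    let _y = attn.forward(&next, true, 8)?;
    Ok(())
}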
