
Commit

integrate candle einsum library
tomsanbear committed Mar 12, 2024
1 parent 35c6ba9 commit 6819c17
Showing 2 changed files with 10 additions and 17 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
@@ -28,6 +28,7 @@ intel-mkl-src = {version="0.8.1", optional = true}
cudarc = {version="0.10.0", optional = true}
candle-onnx = {version="0.4.1", optional = true}
anyhow = "1.0.80"
+candle-einops = "0.1.0"

[build-dependencies]
bindgen_cuda = { version = "0.1.1", optional = true }
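
For context on the new dependency (this sketch is not part of the commit itself): a minimal, hypothetical example of the candle-einops rearrangement macro, using a dummy tensor and the illustrative helper name rearrange_sketch. The pattern is the same "b n h d -> b h n d" call introduced in src/utils_tensor.rs below, which replaces an explicit permute([0, 2, 1, 3]).

use candle_core::{DType, Device, Result, Tensor};
use candle_einops::einops;

fn rearrange_sketch() -> Result<()> {
    let device = Device::Cpu;
    // Illustrative [batch, seq_len, heads, head_dim] tensor.
    let x = Tensor::zeros((2, 16, 8, 64), DType::F32, &device)?;
    // Move the head axis ahead of the sequence axis; the same reordering
    // as x.permute([0, 2, 1, 3])?.
    let y = einops!("b n h d -> b h n d", &x);
    assert_eq!(y.dims(), &[2, 8, 16, 64]);
    Ok(())
}

Compared with a bare permute plus a comment, the pattern string carries the axis names at the call site, which is what the change below takes advantage of.
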
26 changes: 9 additions & 17 deletions src/utils_tensor.rs
@@ -1,6 +1,7 @@
use anyhow::{anyhow, Result};
use candle_core::utils::{cuda_is_available, metal_is_available};
use candle_core::{DType, Device, Tensor, D};
+use candle_einops::einops;
use candle_nn::ops::softmax;
use candle_nn::Dropout;
use tracing::{event, Level};
@@ -104,15 +105,9 @@ pub fn scaled_dot_product_gqa(
return Err(anyhow!("Input tensors must have 4 dimensions"));
};

-// Move sequence length dimension to axis 2, this makes it faster in torch
-// "b n h d -> b h n d"
-let query = query.permute([0, 2, 1, 3])?;
-
-// "b s h d -> b h s d"
-let key = key.permute([0, 2, 1, 3])?;
-
-// "b s h d -> b h s d"
-let value = value.permute([0, 2, 1, 3])?;
+let query = einops!("b n h d -> b h n d", &query);
+let key = einops!("b s h d -> b h s d", &key);
+let value = einops!("b s h d -> b h s d", &value);

// Extract the dimensions
let (bq, hq, _nq, dq) = query.dims4()?;
@@ -149,14 +144,11 @@ pub fn scaled_dot_product_gqa(
true => {
// query = rearrange(query, "b (h g) n d -> b g h n d", g=num_head_groups)
// similarity = einsum(query, key, "b g h n d, b h s d -> b h n s")
-let (batch_size, heads, seq_len, depth) = query.dims4()?;
-let heads = heads / num_head_groups; // Calculate the number of heads per group.
-
-// Reshape query to [batch, num_head_groups, heads, seq_len, depth]
-let query_reshaped =
-    query.reshape((batch_size, num_head_groups, heads, seq_len, depth))?;
-
-let query_for_matmul = query_reshaped.sum(1)?;
+let query = einops!(
+    "b (h {num_head_groups}) n d -> b {num_head_groups} h n d",
+    &query
+);
+let query_for_matmul = query.sum(1)?;

// Transpose the last two dimensions of key to align them for matmul.
let key_transposed = key.transpose(D::Minus2, D::Minus1)?; // [batch, heads, depth, seq_len]
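
Similarly (again, not part of the diff): a sketch of the grouped-query rearrangement from the last hunk, with made-up sizes of 8 query heads and num_head_groups = 2, and the illustrative helper name grouped_query_sketch. It assumes the {num_head_groups} interpolation pins that axis to the runtime value, matching the reference comment "b (h g) n d -> b g h n d", so the group axis lands at dimension 1 and can be collapsed with sum the way the new code does.

use candle_core::{DType, Device, Result, Tensor};
use candle_einops::einops;

fn grouped_query_sketch() -> Result<()> {
    let device = Device::Cpu;
    let num_head_groups = 2usize;
    // Illustrative query of shape [batch, heads, seq_len, head_dim],
    // i.e. already rearranged by "b n h d -> b h n d".
    let query = Tensor::zeros((1, 8, 16, 64), DType::F32, &device)?;
    // Split the 8-head axis and move the group axis to dimension 1:
    // [1, 8, 16, 64] -> [1, 2, 4, 16, 64].
    let grouped = einops!(
        "b (h {num_head_groups}) n d -> b {num_head_groups} h n d",
        &query
    );
    assert_eq!(grouped.dims(), &[1, 2, 4, 16, 64]);
    // Collapse the group axis before the attention matmul, as the new code does.
    let query_for_matmul = grouped.sum(1)?;
    assert_eq!(query_for_matmul.dims(), &[1, 4, 16, 64]);
    Ok(())
}
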
