Skip to content

Commit

Permalink
fixing the falsh attn mask
Browse files Browse the repository at this point in the history
  • Loading branch information
HamidShojanazeri committed Oct 9, 2023
1 parent 504c734 commit 16ad186
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion examples/large_models/tp_llama/llama2.py
Expand Up @@ -206,7 +206,8 @@ def forward(
keys = repeat_kv(keys, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
values = repeat_kv(values, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
#calling PT SDPA to enable using Flash Attention 2 and Xformer memory efficient kernels.
output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, is_causal=True)
output = torch.nn.functional.scaled_dot_product_attention(xq.transpose(1,2), keys.transpose(1,2), values.transpose(1,2), attn_mask=mask, dropout_p=0.0, is_causal=False)

output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
return self.wo(output)

Expand Down

0 comments on commit 16ad186

Please sign in to comment.