@@ -463,11 +463,13 @@ def scaled_dot_product_attention_decomposition(
 ) -> torch.Tensor:
     L, S = query.size(-2), key.size(-2)
     device = query.device
-    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=device)
+
+    if is_causal or attn_mask is not None:
+        attn_bias = torch.zeros((L, S), dtype=query.dtype, device=device)
 
     if is_causal:
         assert attn_mask is None, "attn_mask must be None when is_causal=True"
-        temp_mask = torch.ones(L, S, dtype=torch.bool, device=device).tril(diagonal=0)
+        temp_mask = torch.ones((L, S), dtype=torch.bool, device=device).tril(diagonal=0)
         attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf"))
 
     if attn_mask is not None:
@@ -480,7 +482,7 @@ def scaled_dot_product_attention_decomposition(
     key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
     value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
 
-    attn_weight = query @ key.transpose(-2, -1)
+    attn_weight = torch.matmul(query, key.transpose(-2, -1))
 
     if scale is None:
         scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int)).to(
@@ -490,9 +492,12 @@ def scaled_dot_product_attention_decomposition(
     else:
         attn_weight = attn_weight * scale
 
-    attn_weight = attn_weight + attn_bias
+    if is_causal or attn_mask is not None:
+        # Only add attn_bias when we have to; otherwise the add hurts performance even when the bias is all zeros.
+        attn_weight = attn_weight + attn_bias
+
     attn_weight = torch.softmax(attn_weight, dim=-1)
-    return attn_weight @ value
+    return torch.matmul(attn_weight, value)
 
 
 @register_torch_trt_decomposition(
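For reference, here is a minimal, self-contained sketch of the decomposition as it reads after this change, assembled from the hunks above. The standalone name `sdpa_decomposed`, the simplified scale handling, and the float-only `attn_mask` branch are assumptions for illustration (the parts outside the shown hunks are not reproduced verbatim); the check against `torch.nn.functional.scaled_dot_product_attention` is only a quick sanity test, not part of the patch.

```python
import math

import torch
import torch.nn.functional as F


def sdpa_decomposed(query, key, value, attn_mask=None, is_causal=False, scale=None):
    # Sketch of the decomposed SDPA after this change; details outside the diff are simplified.
    L, S = query.size(-2), key.size(-2)
    device = query.device

    # attn_bias is materialized only when it can actually be non-zero.
    if is_causal or attn_mask is not None:
        attn_bias = torch.zeros((L, S), dtype=query.dtype, device=device)

    if is_causal:
        assert attn_mask is None, "attn_mask must be None when is_causal=True"
        temp_mask = torch.ones((L, S), dtype=torch.bool, device=device).tril(diagonal=0)
        attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf"))

    if attn_mask is not None:
        # Simplified: assumes an additive floating-point mask.
        attn_bias = attn_bias + attn_mask

    # Grouped-query attention: broadcast key/value heads up to the query head count.
    key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
    value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

    attn_weight = torch.matmul(query, key.transpose(-2, -1))
    attn_weight = attn_weight * (1.0 / math.sqrt(query.size(-1)) if scale is None else scale)

    if is_causal or attn_mask is not None:
        # Skip the add entirely in the mask-free path; adding an all-zero bias still costs an op.
        attn_weight = attn_weight + attn_bias

    attn_weight = torch.softmax(attn_weight, dim=-1)
    return torch.matmul(attn_weight, value)


if __name__ == "__main__":
    q = torch.randn(2, 8, 16, 64)
    k = torch.randn(2, 8, 16, 64)
    v = torch.randn(2, 8, 16, 64)
    for causal in (False, True):
        ref = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
        out = sdpa_decomposed(q, k, v, is_causal=causal)
        assert torch.allclose(ref, out, atol=1e-5), "decomposition diverged from reference"
    print("decomposed SDPA matches the PyTorch reference in both paths")
```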