From abd83ce18046e84cf6a1d67b85d2b0d02b89c895 Mon Sep 17 00:00:00 2001
From: Mikayla Gawarecki
Date: Mon, 11 Sep 2023 18:43:53 -0700
Subject: [PATCH] Small fix in SDPA docstring codeblock (#109086)

Fix https://github.com/pytorch/pytorch/issues/109072

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109086
Approved by: https://github.com/drisspg
---
 torch/nn/functional.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 2e1f6be0a6bcb..d9ff7b5604ff9 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -4922,7 +4922,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
 
         if attn_mask is not None:
             if attn_mask.dtype == torch.bool:
-                attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
+                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
             else:
                 attn_bias += attn_mask
         attn_weight = query @ key.transpose(-2, -1) * scale_factor
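
For context, below is a minimal runnable sketch of the docstring's reference implementation with this fix applied. Everything outside the patched hunk (the scale factor, the attn_bias setup, the causal branch, and the softmax/dropout tail) is assumed from the surrounding docstring rather than taken from this diff, so treat those lines as illustrative.

import math
import torch

def sdpa_reference(query, key, value, attn_mask=None, dropout_p=0.0,
                   is_causal=False, scale=None):
    # L and S are the query and key sequence lengths.
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # The patched line: -inf goes into the additive *bias* wherever
            # the boolean mask is False. The pre-fix code called masked_fill_
            # on attn_mask itself, which never touches attn_bias (and filling
            # a bool tensor with -inf is not a valid operation anyway).
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value

With dropout_p=0.0, this should agree with torch.nn.functional.scaled_dot_product_attention on the same inputs up to numerical tolerance, which is the equivalence the docstring codeblock is meant to document.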