Skip to content

Commit

Permalink
casting Optional as Tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
ynonaolga committed Nov 15, 2022
1 parent 1d2bc5c commit 06cb1a2
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions torch/nn/functional.py
Expand Up @@ -5023,16 +5023,19 @@ def multi_head_attention_forward(
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

merged_mask, mask_type = _merge_masks(attn_mask, key_padding_mask, query, num_heads)
in_proj_weight_ = cast(torch.Tensor, in_proj_weight)
in_proj_bias_ = cast(torch.Tensor, in_proj_bias)
out_proj_bias_ = cast(torch.Tensor, out_proj_bias)
attn_output, attn_output_weights = torch._native_multi_head_attention(
query,
key,
value,
embed_dim_to_check,
num_heads,
cast(torch.Tensor, in_proj_weight),
cast(torch.Tensor, in_proj_bias),
in_proj_weight_,
in_proj_bias_,
out_proj_weight,
cast(torch.Tensor, out_proj_bias),
out_proj_bias_,
merged_mask,
need_weights,
average_attn_weights,
Expand Down

0 comments on commit 06cb1a2

Please sign in to comment.