add null check in_proj, out_proj
ynonaolga committed Nov 14, 2022
1 parent 1ada67e commit b42f855
Showing 3 changed files with 21 additions and 26 deletions.
test/test_jit.py (2 additions, 2 deletions)
@@ -15106,8 +15106,8 @@ def test_functional_multi_head_attn_fast_path_None(self):
         with torch.no_grad():
             fn_out = torch.nn.functional.multi_head_attention_forward(query, key, value,
                 embed_size, nhead,
-                None,
-                None,
+                multi_head_attn_nn.in_proj_weight,
+                multi_head_attn_nn.in_proj_bias,
                 multi_head_attn_nn.bias_k,
                 multi_head_attn_nn.bias_v,
                 multi_head_attn_nn.add_zero_attn,
torch/nn/functional.py (18 additions, 23 deletions)
@@ -2524,8 +2524,6 @@ def group_norm(
     """
     if has_torch_function_variadic(input, weight, bias):
         return handle_torch_function(group_norm, (input, weight, bias,), input, num_groups, weight=weight, bias=bias, eps=eps)
-    if input.dim() < 2:
-        raise RuntimeError(f"Expected at least 2 dimensions for input tensor but received {input.dim()}")
     _verify_batch_size([input.size(0) * input.size(1) // num_groups, num_groups] + list(input.size()[2:]))
     return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled)
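
For context, here is a minimal sketch of the `group_norm` call path touched above (this example is not part of the commit; the shapes are illustrative). With the explicit `input.dim() < 2` check removed, shape validation is left to `_verify_batch_size` and `torch.group_norm` themselves:

```python
import torch
import torch.nn.functional as F

# Illustrative (N, C, H, W) input; 6 channels normalized in 3 groups.
x = torch.randn(2, 6, 4, 4)
out = F.group_norm(x, num_groups=3)  # weight and bias default to None
print(out.shape)  # torch.Size([2, 6, 4, 4])
```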

@@ -4993,30 +4991,20 @@ def multi_head_attention_forward(
         :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
         head of shape :math:`(num_heads, L, S)` when input is unbatched or :math:`(N, num_heads, L, S)`.
     """
-    with torch.no_grad():
-        in_proj_weight_ = torch.nn.init.xavier_uniform_(torch.empty((3 * embed_dim_to_check, embed_dim_to_check),\
-            device=query.device, dtype=query.dtype))\
-            if in_proj_weight is None or in_proj_weight.dim() != 2 else in_proj_weight
-    with torch.no_grad():
-        in_proj_bias_ = torch.zeros(3 * embed_dim_to_check, device=query.device, dtype=query.dtype) \
-            if in_proj_bias is None else in_proj_bias

-    with torch.no_grad():
-        out_proj_bias_ = torch.zeros(embed_dim_to_check, device=query.device, dtype=query.dtype) \
-            if out_proj_bias is None else out_proj_bias
     why_not_fast_path = _can_use_fastpath(query,
                                           key,
                                           value,
                                           embed_dim_to_check,
                                           num_heads,
-                                          in_proj_weight_,
-                                          in_proj_bias_,
+                                          in_proj_weight,
+                                          in_proj_bias,
                                           bias_k,
                                           bias_v,
                                           add_zero_attn,
                                           dropout_p,
                                           out_proj_weight,
-                                          out_proj_bias_,
+                                          out_proj_bias,
                                           training,
                                           key_padding_mask,
                                           attn_mask)
@@ -5038,10 +5026,10 @@
             value,
             embed_dim_to_check,
             num_heads,
-            in_proj_weight_,
-            in_proj_bias_,
+            in_proj_weight,
+            in_proj_bias,
             out_proj_weight,
-            out_proj_bias_,
+            out_proj_bias,
             merged_mask,
             need_weights,
             average_attn_weights,
@@ -5116,14 +5104,14 @@ def _can_use_fastpath(
     value: Tensor,
     embed_dim_to_check: int,
     num_heads: int,
-    in_proj_weight: Tensor,
-    in_proj_bias: Tensor,
+    in_proj_weight: Optional[Tensor],
+    in_proj_bias: Optional[Tensor],
     bias_k: Optional[Tensor],
     bias_v: Optional[Tensor],
     add_zero_attn: bool,
     dropout_p: float,
     out_proj_weight: Tensor,
-    out_proj_bias: Tensor,
+    out_proj_bias: Optional[Tensor],
     training: bool = True,
     key_padding_mask: Optional[Tensor] = None,
     attn_mask: Optional[Tensor] = None,
@@ -5183,13 +5171,20 @@ def _can_use_fastpath(
         out_proj_weight,
         out_proj_bias,
     )
+
     # We have to use list comprehensions below because TorchScript does not support
     # generator expressions.
     if torch.overrides.has_torch_function(tensor_args):
         why_not_fast_path = "some Tensor argument has_torch_function"
-    elif not all([(x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]):
+    elif in_proj_weight is None:
+        why_not_fast_path = "in projection weight is None"
+    elif in_proj_bias is None:
+        why_not_fast_path = "in projection bias is None"
+    elif out_proj_bias is None:
+        why_not_fast_path = "out projection bias is None"
+    elif not all([x is not None and (x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]):
         why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
-    elif torch.is_grad_enabled() and any([x.requires_grad for x in tensor_args]):
+    elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]):
         why_not_fast_path = ("grad is enabled and at least one of query or the "
                              "input/output projection weights or biases requires_grad")
     return why_not_fast_path
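
A hedged usage sketch of what the checks above enable (this example is mine, not from the commit; the shapes and weights are illustrative): passing `None` projections to the functional API no longer requires fabricating dummy tensors under `torch.no_grad()`; `_can_use_fastpath` now declines with a reason and the regular Python path runs. Supplying separate q/k/v projection weights via `use_separate_proj_weight=True` is the standard way to call it when `in_proj_weight` is `None`:

```python
import torch
import torch.nn.functional as F

embed_dim, num_heads = 8, 2
q = k = v = torch.randn(5, 1, embed_dim)  # (L, N, E)

# in_proj_weight, in_proj_bias, and out_proj_bias are None: the fast path is
# declined by _can_use_fastpath and the regular Python path runs instead.
out, attn = F.multi_head_attention_forward(
    q, k, v, embed_dim, num_heads,
    None, None,                          # in_proj_weight, in_proj_bias
    None, None,                          # bias_k, bias_v
    False, 0.0,                          # add_zero_attn, dropout_p
    torch.randn(embed_dim, embed_dim),   # out_proj_weight
    None,                                # out_proj_bias
    use_separate_proj_weight=True,
    q_proj_weight=torch.randn(embed_dim, embed_dim),
    k_proj_weight=torch.randn(embed_dim, embed_dim),
    v_proj_weight=torch.randn(embed_dim, embed_dim),
)
```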
torch/nn/modules/activation.py (1 addition, 1 deletion)
@@ -1060,7 +1060,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: O
     .. note::
         `batch_first` argument is ignored for unbatched inputs.
     """
-    def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
+    def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
                 need_weights: bool = True, attn_mask: Optional[Tensor] = None,
                 average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
         r"""
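Finally, a minimal sketch of the module path that reaches the `forward` shown above (again mine, with illustrative values; a default-constructed `nn.MultiheadAttention` has real projection weights, so it can still take the fast path):

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)  # batch_first=False
q = k = v = torch.randn(5, 1, 8)  # (L, N, E)
out, attn_weights = mha(q, k, v, need_weights=True, average_attn_weights=True)
print(out.shape, attn_weights.shape)  # torch.Size([5, 1, 8]) torch.Size([1, 5, 5])
```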
