Commit

Small fixes
voznesenskym committed Jan 27, 2023
1 parent a6d27fd commit ff18ac2
Showing 2 changed files with 24 additions and 11 deletions.
29 changes: 21 additions & 8 deletions torch/_refs/__init__.py
@@ -948,8 +948,22 @@ def _ref(
    return inner


+def shape_preserving_add_custom(*args, **kwargs):
+    if isinstance(args[1], (int, float, SymInt, SymFloat)):
+        return args[0]
+    if isinstance(args[0], (int, float, SymInt, SymFloat)):
+        return args[1]
+    if args[0].shape == args[1].shape:
+        return FakeTensor(
+            args[0].fake_mode,
+            torch.empty(args[0].shape, device="meta"),
+            device=args[0].device,
+        )
+    return None
+
+
# Add has its own implementation because it has an alpha argument
-@register_decomposition(aten.add, shape_preserving=shape_preserving_default)
+@register_decomposition(aten.add, shape_preserving=shape_preserving_add_custom)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
@@ -1180,7 +1194,6 @@ def float_power(
@_make_elementwise_binary_reference(
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    supports_two_python_scalars=True,
-    shape_preserving=shape_preserving_default,
)
def floor_divide(
    a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
@@ -1600,7 +1613,7 @@ def rsub(
# TODO: add docstring
# TODO: consider refactoring this with add impl
# sub has its own implementation because it has an alpha argument
-@register_decomposition(aten.sub, shape_preserving=shape_preserving_default)
+@register_decomposition(aten.sub, shape_preserving=shape_preserving_add_custom)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
@@ -5220,7 +5233,7 @@ def exponential(self, rate=1, generator=None):
abs_ = _make_inplace(abs)
acos_ = _make_inplace(acos)
acosh_ = _make_inplace(acosh)
-add_ = _make_inplace(add, shape_preserving=shape_preserving_default)
+add_ = _make_inplace(add, shape_preserving=shape_preserving_add_custom)
addcmul_ = _make_inplace(addcmul, shape_preserving=shape_preserving_mul_custom)
addcdiv_ = _make_inplace(addcdiv)
asin_ = _make_inplace(asin)
@@ -5235,12 +5248,12 @@ def exponential(self, rate=1, generator=None):
bitwise_right_shift_ = _make_inplace(bitwise_right_shift)
bitwise_xor_ = _make_inplace(bitwise_xor)
ceil_ = _make_inplace(ceil, shape_preserving=shape_preserving_default)
-clamp_ = _make_inplace(clamp, shape_preserving=shape_preserving_default)
+clamp_ = _make_inplace(clamp)
clamp_min_ = _make_inplace(clamp_min)
clamp_max_ = _make_inplace(clamp_max)
conj_physical_ = _make_inplace(conj_physical)
copysign_ = _make_inplace(copysign)
-cos_ = _make_inplace(cos, shape_preserving=shape_preserving_default)
+cos_ = _make_inplace(cos)
cosh_ = _make_inplace(cosh)
cumsum_ = _make_inplace(cumsum)
digamma_ = _make_inplace(digamma)
@@ -5291,12 +5304,12 @@ def exponential(self, rate=1, generator=None):
sgn_ = _make_inplace(sgn)
sigmoid_ = _make_inplace(sigmoid)
sign_ = _make_inplace(sign)
-sin_ = _make_inplace(sin, shape_preserving=shape_preserving_default)
+sin_ = _make_inplace(sin)
sinc_ = _make_inplace(sinc)
sinh_ = _make_inplace(sinh)
sqrt_ = _make_inplace(sqrt)
square_ = _make_inplace(square)
-sub_ = _make_inplace(sub, shape_preserving=shape_preserving_default)
+sub_ = _make_inplace(sub, shape_preserving=shape_preserving_add_custom)
tan_ = _make_inplace(tan)
tanh_ = _make_inplace(tanh)
tril_ = _make_inplace(tril)
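
In this commit, add and sub (and their in-place variants) swap the generic shape_preserving_default hook for shape_preserving_add_custom: a Python-scalar operand means the other operand already carries the output shape, two tensors of identical shape yield a fresh meta-backed FakeTensor of that shape, and any other case returns None so the caller falls back to the full decomposition. A minimal sketch of exercising the hook, assuming this branch exposes it from torch._refs and that FakeTensorMode is available from torch._subclasses.fake_tensor:

import torch
from torch._refs import shape_preserving_add_custom  # added in this commit
from torch._subclasses.fake_tensor import FakeTensorMode

mode = FakeTensorMode()
a = mode.from_tensor(torch.empty(4, 8))
b = mode.from_tensor(torch.empty(4, 8))
c = mode.from_tensor(torch.empty(1, 8))

# Tensor + scalar: the tensor operand already has the output shape.
assert shape_preserving_add_custom(a, 1.0) is a

# Two tensors of the same shape: a fresh meta-backed FakeTensor of that shape.
out = shape_preserving_add_custom(a, b)
assert out is not None and out.shape == a.shape

# Mismatched (broadcasting) shapes: None, so the caller runs the real
# decomposition to learn the output shape.
assert shape_preserving_add_custom(a, c) is None
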
6 changes: 3 additions & 3 deletions torch/_refs/nn/functional/__init__.py
@@ -6,7 +6,7 @@
import torch._prims as prims
import torch._prims_common as utils
import torch._refs as refs
-from torch._decomp import register_decomposition
+from torch._decomp import register_decomposition, shape_preserving_default
from torch._decomp.decompositions import Reduction
from torch._prims_common import (
    check,
@@ -926,7 +926,7 @@ def _triplet_margin_with_distance_loss(
    return _apply_loss_reduction(loss, reduction)


-@register_decomposition(aten.hardtanh)
+@register_decomposition(aten.hardtanh, shape_preserving=shape_preserving_default)
@inplace_wrapper
@out_wrapper()
@elementwise_unary_scalar_wrapper
@@ -959,7 +959,7 @@ def hardtanh(
    return torch.clamp(a, min_val, max_val)  # type: ignore[arg-type]


-@register_decomposition(aten.gelu)
+@register_decomposition(aten.gelu, shape_preserving=shape_preserving_default)
@out_wrapper()
@elementwise_unary_scalar_wrapper
@elementwise_type_promotion_wrapper(
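
shape_preserving_default, imported from torch._decomp in torch/_refs/nn/functional/__init__.py above and now attached to hardtanh and gelu, is not shown in this diff. By analogy with the custom add/sub hook, a plausible sketch, stated here as an assumption rather than the actual definition, is a hook that mirrors the first tensor argument, which matches what an elementwise unary op produces:

import torch
from torch._subclasses.fake_tensor import FakeTensor

def shape_preserving_default_sketch(*args, **kwargs):
    # Hypothetical stand-in for shape_preserving_default: the output simply
    # mirrors the shape and device of the first tensor argument, as for
    # elementwise unary ops such as hardtanh and gelu.
    a = args[0]
    return FakeTensor(
        a.fake_mode,
        torch.empty(a.shape, device="meta"),
        device=a.device,
    )
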
