From ed74c083c31f8a817ab1b8c99afc242221260391 Mon Sep 17 00:00:00 2001
From: Ryan Soklaski
Date: Tue, 28 Jul 2020 15:30:53 -0400
Subject: [PATCH] dispatch special cases for x ** 1 and x ** 2

---
 src/mygrad/tensor_base.py                   | 10 +++
 tests/tensor_base/test_graph_tracking.py    |  2 +-
 tests/tensor_base/test_pow_special_cases.py | 72 +++++++++++++++++++++
 3 files changed, 83 insertions(+), 1 deletion(-)
 create mode 100644 tests/tensor_base/test_pow_special_cases.py

diff --git a/src/mygrad/tensor_base.py b/src/mygrad/tensor_base.py
index 10dae665..d2075e5d 100644
--- a/src/mygrad/tensor_base.py
+++ b/src/mygrad/tensor_base.py
@@ -5,6 +5,7 @@
 """
 from functools import wraps
+from numbers import Number
 from typing import Optional, Set, Type, Union
 
 import numpy as np
@@ -19,6 +20,7 @@
     Negative,
     Positive,
     Power,
+    Square,
     Subtract,
 )
 from mygrad.operation_base import BroadcastableOp, Operation
@@ -672,6 +674,14 @@ def __rmatmul__(self, other):
         return self._op(MatMul, other, self)
 
     def __pow__(self, other):
+        if isinstance(other, Number) or (
+            isinstance(other, np.ndarray) and other.ndim == 0
+        ):
+            if other == 1:
+                return self._op(Positive, self)
+            elif other == 2:
+                return self._op(Square, self)
+
         return self._op(Power, self, other)
 
     def __rpow__(self, other):
diff --git a/tests/tensor_base/test_graph_tracking.py b/tests/tensor_base/test_graph_tracking.py
index 1178ff8d..6a5d5d53 100644
--- a/tests/tensor_base/test_graph_tracking.py
+++ b/tests/tensor_base/test_graph_tracking.py
@@ -16,7 +16,7 @@ def test_op_tracks_graph():
     h = z - f
     assert h.creator.graph == {h.creator} | f.creator.graph
 
-    i = ((h + 3) ** 2) / 5
+    i = ((h + 3) ** 4) / 5
     assert h.creator.graph < i.creator.graph
     assert (
         len(i.creator.graph - h.creator.graph) == 3
diff --git a/tests/tensor_base/test_pow_special_cases.py b/tests/tensor_base/test_pow_special_cases.py
new file mode 100644
index 00000000..d301ab45
--- /dev/null
+++ b/tests/tensor_base/test_pow_special_cases.py
@@ -0,0 +1,72 @@
+from functools import partial
+
+import hypothesis.strategies as st
+import numpy as np
+import pytest
+from hypothesis import given
+
+import mygrad as mg
+from mygrad.math.arithmetic.ops import Positive, Square
+
+from ..wrappers.uber import backprop_test_factory, fwdprop_test_factory
+
+
+def custom_pow(x, p, constant=False):
+    out = x ** p
+    if isinstance(out, mg.Tensor):
+        out._constant = constant
+    return out
+
+
+def any_scalar(*args, p):
+    return st.sampled_from([int(p), float(p), np.array(p)])
+
+
+@pytest.mark.parametrize("power, op", [(1, Positive), (2, Square)])
+def test_pow_uses_special_case(power, op):
+    @given(exp=st.sampled_from([int(power), float(power), np.array(power)]))
+    def wrapped_func(exp):
+        out = mg.arange(2) ** exp
+        assert isinstance(out.creator, op)
+
+    wrapped_func()
+
+
+@fwdprop_test_factory(
+    mygrad_func=custom_pow,
+    true_func=custom_pow,
+    num_arrays=1,
+    kwargs={"p": partial(any_scalar, p=1)},
+)
+def test_pow_1_fwd():
+    pass
+
+
+@backprop_test_factory(
+    mygrad_func=custom_pow,
+    true_func=custom_pow,
+    num_arrays=1,
+    kwargs={"p": partial(any_scalar, p=1)},
+)
+def test_pow_1_bkwd():
+    pass
+
+
+@fwdprop_test_factory(
+    mygrad_func=custom_pow,
+    true_func=custom_pow,
+    num_arrays=1,
+    kwargs={"p": partial(any_scalar, p=2)},
+)
+def test_pow_2_fwd():
+    pass
+
+
+@backprop_test_factory(
+    mygrad_func=custom_pow,
+    true_func=custom_pow,
+    num_arrays=1,
+    kwargs={"p": partial(any_scalar, p=2)},
+)
+def test_pow_2_bkwd():
+    pass
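
Reviewer note, not part of the patch: the snippet below is a minimal sketch of
the behavior this change is intended to produce, written against the public
API exercised by the new tests (mg.arange, Tensor.creator, backward). It
assumes that Power is importable from mygrad.math.arithmetic.ops alongside
Positive and Square; treat it as an illustration, not a definitive reference.

    # Illustrative sketch only (not part of the patch). Assumes the mygrad
    # API used by the tests above and that Power lives in
    # mygrad.math.arithmetic.ops next to Positive and Square.
    import mygrad as mg
    from mygrad.math.arithmetic.ops import Positive, Power, Square

    x = mg.arange(3.0)  # Tensor([0., 1., 2.])

    # Scalar exponents of 1 and 2 now dispatch to cheaper special-case ops;
    # every other exponent still routes through the general Power op.
    assert isinstance((x ** 1).creator, Positive)
    assert isinstance((x ** 2).creator, Square)
    assert isinstance((x ** 3).creator, Power)

    # Gradients should be unaffected by the dispatch: d(x**2)/dx == 2*x
    w = mg.arange(3.0)
    y = (w ** 2).sum()
    y.backward()
    assert (w.grad == 2 * w.data).all()

Note that the special-case dispatch in __pow__ only triggers for scalar
exponents (a Number or a 0-d ndarray) equal to 1 or 2; array-valued exponents
and all other values fall through to the general Power op unchanged.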