diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 818a964117..6adc16ec59 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -943,6 +943,12 @@ def __rsub__(self, other): def __rmul__(self, other): return mul(other, self) + def __rtruediv__(self, other): + return true_div(other, self) + + def __rfloordiv__(self, other): + return int_div(other, self) + def __rmod__(self, other): return mod(other, self) diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py index 474d08c49d..31e08fd39b 100644 --- a/pytensor/tensor/variable.py +++ b/pytensor/tensor/variable.py @@ -13,7 +13,6 @@ from pytensor.graph.utils import MetaType from pytensor.scalar import ( ComplexError, - IntegerDivisionError, ) from pytensor.tensor import _get_vector_length from pytensor.tensor.exceptions import AdvancedIndexingError @@ -138,18 +137,6 @@ def __mul__(self, other): except (NotImplementedError, TypeError): return NotImplemented - def __div__(self, other): - # See explanation in __add__ for the error caught - # and the return value in that case - try: - return pt.math.div_proxy(self, other) - except IntegerDivisionError: - # This is to raise the exception that occurs when trying to divide - # two integer arrays (currently forbidden). - raise - except (NotImplementedError, TypeError): - return NotImplemented - def __pow__(self, other): # See explanation in __add__ for the error caught # and the return value in that case @@ -210,9 +197,6 @@ def __rsub__(self, other): def __rmul__(self, other): return pt.math.mul(other, self) - def __rdiv__(self, other): - return pt.math.div_proxy(other, self) - def __rmod__(self, other): return pt.math.mod(other, self) diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index e5882b1dba..db6a66036d 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -308,9 +308,6 @@ def __sub__(self, other): def __mul__(self, other): return px.math.mul(self, other) - def __div__(self, other): - return px.math.div(self, other) - def __pow__(self, other): return px.math.pow(self, other) @@ -341,9 +338,6 @@ def __rsub__(self, other): def __rmul__(self, other): return px.math.mul(other, self) - def __rdiv__(self, other): - return px.math.div_proxy(other, self) - def __rmod__(self, other): return px.math.mod(other, self) diff --git a/tests/scalar/test_basic.py b/tests/scalar/test_basic.py index 138d99abc6..b6afdab9e8 100644 --- a/tests/scalar/test_basic.py +++ b/tests/scalar/test_basic.py @@ -10,7 +10,9 @@ EQ, ComplexError, Composite, + IntDiv, ScalarType, + TrueDiv, add, and_, arccos, @@ -531,3 +533,19 @@ def test_scalar_hash_default_output_type_preference(): del old_eq.output_types_preference # mimic old Op assert new_eq == old_eq assert hash(new_eq) == hash(old_eq) + + +def test_rtruediv(): + x = ScalarType(dtype="float64")() + y = 1.0 / x + assert isinstance(y.owner.op, TrueDiv) + assert isinstance(y.type, ScalarType) + assert y.eval({x: 2.0}) == 0.5 + + +def test_rfloordiv(): + x = ScalarType(dtype="float64")() + y = 5.0 // x + assert isinstance(y.owner.op, IntDiv) + assert isinstance(y.type, ScalarType) + assert y.eval({x: 2.0}) == 2.0