Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pytensor/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,12 @@ def __rsub__(self, other):
def __rmul__(self, other):
return mul(other, self)

def __rtruediv__(self, other):
return true_div(other, self)

def __rfloordiv__(self, other):
return int_div(other, self)

def __rmod__(self, other):
return mod(other, self)

Expand Down
16 changes: 0 additions & 16 deletions pytensor/tensor/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from pytensor.graph.utils import MetaType
from pytensor.scalar import (
ComplexError,
IntegerDivisionError,
)
from pytensor.tensor import _get_vector_length
from pytensor.tensor.exceptions import AdvancedIndexingError
Expand Down Expand Up @@ -138,18 +137,6 @@ def __mul__(self, other):
except (NotImplementedError, TypeError):
return NotImplemented

def __div__(self, other):
# See explanation in __add__ for the error caught
# and the return value in that case
try:
return pt.math.div_proxy(self, other)
except IntegerDivisionError:
# This is to raise the exception that occurs when trying to divide
# two integer arrays (currently forbidden).
raise
except (NotImplementedError, TypeError):
return NotImplemented

def __pow__(self, other):
# See explanation in __add__ for the error caught
# and the return value in that case
Expand Down Expand Up @@ -210,9 +197,6 @@ def __rsub__(self, other):
def __rmul__(self, other):
return pt.math.mul(other, self)

def __rdiv__(self, other):
return pt.math.div_proxy(other, self)

def __rmod__(self, other):
return pt.math.mod(other, self)

Expand Down
6 changes: 0 additions & 6 deletions pytensor/xtensor/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,6 @@ def __sub__(self, other):
def __mul__(self, other):
return px.math.mul(self, other)

def __div__(self, other):
return px.math.div(self, other)

def __pow__(self, other):
return px.math.pow(self, other)

Expand Down Expand Up @@ -341,9 +338,6 @@ def __rsub__(self, other):
def __rmul__(self, other):
return px.math.mul(other, self)

def __rdiv__(self, other):
return px.math.div_proxy(other, self)

def __rmod__(self, other):
return px.math.mod(other, self)

Expand Down
18 changes: 18 additions & 0 deletions tests/scalar/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
EQ,
ComplexError,
Composite,
IntDiv,
ScalarType,
TrueDiv,
add,
and_,
arccos,
Expand Down Expand Up @@ -531,3 +533,19 @@ def test_scalar_hash_default_output_type_preference():
del old_eq.output_types_preference # mimic old Op
assert new_eq == old_eq
assert hash(new_eq) == hash(old_eq)


def test_rtruediv():
x = ScalarType(dtype="float64")()
y = 1.0 / x
assert isinstance(y.owner.op, TrueDiv)
assert isinstance(y.type, ScalarType)
assert y.eval({x: 2.0}) == 0.5


def test_rfloordiv():
x = ScalarType(dtype="float64")()
y = 5.0 // x
assert isinstance(y.owner.op, IntDiv)
assert isinstance(y.type, ScalarType)
assert y.eval({x: 2.0}) == 2.0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't cast to int here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't testing the dtype, 2 == 2.0 is true in Python. I trust our old code was doing the correct thing already

Loading