Skip to content

Commit

Permalink
Implement ufunc_outer like add.outer for binary Elemwise operat…
Browse files Browse the repository at this point in the history
…ions
  • Loading branch information
Dhruvanshu-Joshi committed May 7, 2024
1 parent 35ae5db commit b79d232
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
10 changes: 10 additions & 0 deletions pytensor/tensor/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1170,6 +1170,16 @@ def c_code_cache_version_apply(self, node):
else:
return ()

def outer(self, x, y):
from pytensor.tensor.basic import expand_dims

if self.scalar_op.nin not in (-1, 2):
raise NotImplementedError("outer is only available for binary operators")

x_ = expand_dims(x, tuple(range(-y.ndim, 0)))
y_ = expand_dims(y, tuple(range(x.ndim)))
return self(x_, y_)


class CAReduce(COp):
"""Reduces a scalar operation along specified axes.
Expand Down
21 changes: 21 additions & 0 deletions tests/tensor/test_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

import pytensor
import pytensor.scalar as ps
import pytensor.tensor as pt
import tests.unittest_tools as utt
from pytensor.compile.function import function
from pytensor.compile.mode import Mode
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Variable
Expand Down Expand Up @@ -893,6 +895,25 @@ def test_invalid_static_shape(self):
):
x + y

@pytest.mark.parametrize(
"shape_x, shape_y, op, np_op",
[
((3, 5), (7, 1, 3), pt.add, np.add),
((2, 3), (1, 4), pt.mul, np.multiply),
],
)
def test_outer(self, shape_x, shape_y, op, np_op):
x = tensor(dtype=np.float64, shape=shape_x)
y = tensor(dtype=np.float64, shape=shape_y)

z = op.outer(x, y)

f = function([x, y], z)
x1 = np.ones(shape_x)
y1 = np.ones(shape_y)

np.testing.assert_array_equal(f(x1, y1), np_op.outer(x1, y1))


def test_not_implemented_elemwise_grad():
# Regression test for unimplemented gradient in an Elemwise Op.
Expand Down

0 comments on commit b79d232

Please sign in to comment.