-
Notifications
You must be signed in to change notification settings - Fork 137
Open
Description
We have a `local_0_dot_x` rewrite that removes useless dots with zero'd inputs. We don't seem to have anything equivalent for dots with ones, as reported in #637 (comment):
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

# Minimal reproducer: a dot with a constant [[1.]] matrix survives compilation
# instead of being rewritten away.
x = pt.col('x')  # was `tn.col`, but `tn` is never imported (NameError)
f = x @ [[1.]]
with pytensor.config.change_flags(optimizer_verbose=True):
    # BlasOpt is excluded only to keep the printed graph simple; with it
    # enabled the dot is still not removed, just replaced by a Blas Op.
    fn = pytensor.function([x], f, mode=get_default_mode().excluding("BlasOpt"))
pytensor.dprint(fn)
dot [id A] 0
├─ x [id B]
└─ [[1.]] [id C]
I excluded BlasOpt only to get a simpler graph. With BlasOpt enabled, the dot is still not rewritten away; it is just replaced by the more complex Blas Op.
pytensor/pytensor/tensor/rewriting/math.py
Lines 155 to 190 in d3dd34e
@register_canonicalize | |
@register_stabilize | |
@node_rewriter([Dot]) | |
def local_0_dot_x(fgraph, node): | |
if not isinstance(node.op, Dot): | |
return False | |
x = node.inputs[0] | |
y = node.inputs[1] | |
replace = False | |
try: | |
if get_underlying_scalar_constant_value(x, only_process_constants=True) == 0: | |
replace = True | |
except NotScalarConstantError: | |
pass | |
try: | |
if get_underlying_scalar_constant_value(y, only_process_constants=True) == 0: | |
replace = True | |
except NotScalarConstantError: | |
pass | |
if replace: | |
constant_zero = constant(0, dtype=node.outputs[0].type.dtype) | |
if x.ndim == 2 and y.ndim == 2: | |
constant_zero = assert_op(constant_zero, eq(x.shape[1], y.shape[0])) | |
return [alloc(constant_zero, x.shape[0], y.shape[1])] | |
elif x.ndim == 1 and y.ndim == 2: | |
constant_zero = assert_op(constant_zero, eq(x.shape[0], y.shape[0])) | |
return [alloc(constant_zero, y.shape[1])] | |
elif x.ndim == 2 and y.ndim == 1: | |
constant_zero = assert_op(constant_zero, eq(x.shape[1], y.shape[0])) | |
return [alloc(constant_zero, x.shape[0])] | |
elif x.ndim == 1 and y.ndim == 1: | |
constant_zero = assert_op(constant_zero, eq(x.shape[0], y.shape[0])) | |
return [constant_zero] |