Simplify dots with 1 #638

Open

ricardoV94 opened this issue Feb 8, 2024 · 3 comments

ricardoV94 (Member) commented Feb 8, 2024

Description

We have a local_0_dot_x rewrite that removes useless dots with zeroed inputs. We don't seem to have anything for dots with ones, as reported in #637 (comment):

import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

x = pt.col('x')
f = x @ [[1.]]
with pytensor.config.change_flags(optimizer_verbose=True):
    fn = pytensor.function([x], f, mode=get_default_mode().excluding("BlasOpt"))

pytensor.dprint(fn)
dot [id A] 0
 ├─ x [id B]
 └─ [[1.]] [id C]

I excluded BlasOpt just to get a simpler graph; with it enabled the dot is still not rewritten away, it is just replaced by the more complex Blas Op. For reference, here is the existing local_0_dot_x:

@register_canonicalize
@register_stabilize
@node_rewriter([Dot])
def local_0_dot_x(fgraph, node):
    if not isinstance(node.op, Dot):
        return False

    x = node.inputs[0]
    y = node.inputs[1]
    replace = False
    try:
        if get_underlying_scalar_constant_value(x, only_process_constants=True) == 0:
            replace = True
    except NotScalarConstantError:
        pass

    try:
        if get_underlying_scalar_constant_value(y, only_process_constants=True) == 0:
            replace = True
    except NotScalarConstantError:
        pass

    if replace:
        constant_zero = constant(0, dtype=node.outputs[0].type.dtype)
        if x.ndim == 2 and y.ndim == 2:
            constant_zero = assert_op(constant_zero, eq(x.shape[1], y.shape[0]))
            return [alloc(constant_zero, x.shape[0], y.shape[1])]
        elif x.ndim == 1 and y.ndim == 2:
            constant_zero = assert_op(constant_zero, eq(x.shape[0], y.shape[0]))
            return [alloc(constant_zero, y.shape[1])]
        elif x.ndim == 2 and y.ndim == 1:
            constant_zero = assert_op(constant_zero, eq(x.shape[1], y.shape[0]))
            return [alloc(constant_zero, x.shape[0])]
        elif x.ndim == 1 and y.ndim == 1:
            constant_zero = assert_op(constant_zero, eq(x.shape[0], y.shape[0]))
            return [constant_zero]

Dhruvanshu-Joshi (Contributor) commented:
Looks like an interesting issue. We'd just have to replace the 0 with 1 in local_0_dot_x, right?
Here's what I have in mind:

@register_canonicalize
@register_stabilize
@node_rewriter([Dot])
def local_1_dot_x(fgraph, node):
    if not isinstance(node.op, Dot):
        return False

    x = node.inputs[0]
    y = node.inputs[1]
    replace = False
    try:
        if get_underlying_scalar_constant_value(x, only_process_constants=True) == 1:
            replace = True
            var = y
    except NotScalarConstantError:
        pass

    try:
        if get_underlying_scalar_constant_value(y, only_process_constants=True) == 1:
            replace = True
            var = x
    except NotScalarConstantError:
        pass

    if replace:
        constant_value = constant(get_underlying_scalar_constant_value(var, only_process_constants=True), dtype=node.outputs[0].type.dtype)
        if x.ndim == 2 and y.ndim == 2:
            constant_value = assert_op(constant_value, eq(x.shape[1], y.shape[0]))
            return [alloc(constant_value, x.shape[0], y.shape[1])]
        elif x.ndim == 1 and y.ndim == 2:
            constant_value = assert_op(constant_value, eq(x.shape[0], y.shape[0]))
            return [alloc(constant_value, y.shape[1])]
        elif x.ndim == 2 and y.ndim == 1:
            constant_value = assert_op(constant_value, eq(x.shape[1], y.shape[0]))
            return [alloc(constant_value, x.shape[0])]
        elif x.ndim == 1 and y.ndim == 1:
            constant_value = assert_op(constant_value, eq(x.shape[0], y.shape[0]))
            return [constant_value]

However, I think using the constant value might be wrong here. Would I have to replace with the entire var instead? If so, is this the correct way forward?

var = assert_op(var, eq(...))
alloc(var, shape)

ricardoV94 (Member, Author) commented:
No, the rule is slightly different for ones: it consists of summing over the left matrix, and you also have to reason about broadcasting.

I suggest playing with numpy to get a feel of what it should do.
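For example, a minimal numpy sketch of the rule (the arrays here are just illustrative):

import numpy as np

x = np.array([[1., 2.], [3., 4.]])  # left operand, shape (2, 2)

# Dotting with a column of ones sums the rows of the left operand:
assert np.allclose(x @ np.ones((2, 1)), x.sum(axis=1, keepdims=True))

# With a wider ones matrix, the row sums are broadcast across the columns:
assert np.allclose(
    x @ np.ones((2, 3)),
    np.broadcast_to(x.sum(axis=1, keepdims=True), (2, 3)),
)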

Dhruvanshu-Joshi (Contributor) commented:
Okay. Just so I get this correctly: for a given graph, say

Sub [id A]
 ├─ dot [id B]
 │  ├─ dot [id C]
 │  │  ├─ Transpose{axes=[1, 0]} [id D] 'A.T'
 │  │  │  └─ A [id E]
 │  │  └─ Neg [id F]
 │  │     └─ x [id G]
 │  └─ [[1.]] [id H]
 └─ dot [id I]
    ├─ A [id E]
    └─ dot [id J]
       ├─ x [id G]
       └─ [[1.]] [id H]

we want the output of the rewrite to be:

Sub [id A]
 ├─ dot [id B]
 │  ├─ Transpose{axes=[1, 0]} [id C] 'A.T'
 │  │  └─ A [id D]
 │  └─ Neg [id E]
 │     └─ x [id F]
 └─ dot [id G]
    ├─ A [id D]
    └─ x [id F]

Is this correct? And if so, how do the summing of the left matrix and the broadcasting come into the picture here?
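Checking with numpy (assuming x is a column vector, as in the graph above):

import numpy as np

x = np.array([[1.], [2.], [3.]])  # column vector, shape (3, 1)

# x @ [[1.]] sums x's single column and broadcasts it back to one column,
# which is just x itself, so dropping the dot is a no-op in this case:
assert np.allclose(x @ np.ones((1, 1)), x)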
