
Fix bug in TensorVariable.__rmatmul__ #465

Merged: 2 commits merged into pymc-devs:main from fix_rmatmul on Oct 4, 2023

Conversation

@ricardoV94 (Member) commented Oct 3, 2023

Introduced in #452
Closes #464

CC @tomicapretto

@ricardoV94 added the bug (Something isn't working), graph objects, and linalg (Linear algebra) labels on Oct 3, 2023
@jessegrabowski (Member) commented Oct 4, 2023

I'm not positive __rmatmul__ is the issue. I put prints into all of __matmul__, __rmatmul__, __dot__, and __rdot__ (I brought these methods back -- I was getting errors in the tests without them?). Only __dot__ was called by the test you flagged as not working.

Digging a bit into the failure case:

    X_val = np.arange(2 * 3).reshape((2, 3))
    y_val = np.arange(3)
    res = X_val.dot(y)  # y is the symbolic vector from the test
    exp_res = dot(X_val, y)

exp_res looks like this:

[[Mul.0 Mul.0 Mul.0]
 [Mul.0 Mul.0 Mul.0]]

The dprint for each cell looks like this (this is cell 0,0)

Mul [id A]
 ├─ ExpandDims{axis=0} [id B]
 │  └─ 0 [id C]
 └─ y [id D]

Evidently it's broadcasting the multiplication of y into every cell of X:

evaled_res = np.array([[cell.eval({y:y_val}) for cell in row] for row in res])
evaled_res
[[[ 0.  0.  0.]
  [ 0.  1.  2.]
  [ 0.  2.  4.]]

 [[ 0.  3.  6.]
  [ 0.  4.  8.]
  [ 0.  5. 10.]]]

evaled_res.shape
(2, 3, 3)

It seems like it's treating X like a batch?
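
(For illustration, the grid above can be reproduced with a plain Python stand-in object, under the assumption that numpy coerces the unknown right-hand operand to a 0-d object scalar, at which point ndarray.dot degenerates to an elementwise multiply. Sym below is hypothetical, not PyTensor code.)

import numpy as np

class Sym:
    """Hypothetical stand-in for a TensorVariable; records the products numpy builds."""
    def __rmul__(self, other):
        return f"{other} * y"

out = np.arange(2 * 3).reshape((2, 3)).dot(Sym())
print(out.shape)   # (2, 3): one deferred product per cell, like the Mul.0 grid above
print(out[0, 0])   # '0 * y'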

@ricardoV94 (Member, Author) commented Oct 4, 2023

To be clear, numpy.array().dot(TensorVariable) doesn't work (and never did, AFAICT), and I couldn't see anything we could do about it without delving into the __array_function__ stuff. Numpy converts that into the weird rmul-by-zero you describe, and I couldn't trace where that comes from.

I added an explicit check on a test so we "know it's broken".

Note that @ is not a dot operation but a matmul operation. Aesara/PyTensor had this confused and was using the wrong operation. When I fixed it to use matmul in #452, I flipped the operands of __rmatmul__ by mistake.
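
(For context on why the distinction matters: dot and matmul only agree for 1-D and 2-D operands. A minimal numpy illustration, which pt.dot and pt.matmul are intended to mirror:)

import numpy as np

a = np.ones((2, 3, 4))
b = np.ones((2, 4, 6))

assert np.matmul(a, b).shape == (2, 3, 6)    # leading axes are treated as a batch
assert np.dot(a, b).shape == (2, 3, 2, 6)    # dot pairs every matrix of a with every matrix of b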

@ricardoV94 (Member, Author) commented Oct 4, 2023

About __dot__ and __rdot__: where are those used? I couldn't find a reference to these being "real" dunder methods anywhere. They were being called by the @ operator before, but again, that was a mistake, because @ should be a matmul, not a dot.

@@ -98,10 +100,15 @@ def test_infix_matmul_method():
assert equal_computations([res], [exp_res])

X_val = np.arange(2 * 3).reshape((2, 3))
res = as_tensor(X_val) @ y
res = X_val @ y
@ricardoV94 (Member, Author) commented on the diff above:

This is the test relevant to the reported bug

@jessegrabowski (Member) replied:

I was running test_dot_method, and it failed with an error that __dot__ was not implemented

@ricardoV94 (Member, Author) replied Oct 4, 2023:

Right, but if that's the only test that used it, that's not a reason to keep the mysterious dunder method. I couldn't see it being used anywhere else, nor could I see how to trigger it naturally.

@jessegrabowski (Member) commented Oct 4, 2023
I will believe you that np.array @ TensorVariable never worked, but it seems a bit sus. Shouldn't this have come up before if true? PyMC users mix numpy/pytensor objects all the time.

EDIT: Actually I'm all confused now. @tomicapretto reported that Array @ Tensor (__matmul__) doesn't work, but pt.dot(Array, Tensor) does, and you're saying the Array.dot(Tensor) method also doesn't work?

@ricardoV94 (Member, Author) commented Oct 4, 2023

> I will believe you that np.array @ TensorVariable never worked, but it seems a bit sus. Shouldn't this have come up before if true? PyMC users mix numpy/pytensor objects all the time.

@ worked!!! But it used to use pt.dot under the hood, not pt.matmul. I fixed it to use pt.matmul in #452 but messed up the order of arguments in __rmatmul__, which is called when you do np.array @ TensorVariable.

> EDIT: Actually I'm all confused now. @tomicapretto reported that Array @ Tensor (__matmul__) doesn't work, but pt.dot(Array, Tensor) does, and you're saying the Array.dot(Tensor) method also doesn't work?

pt.dot(np.array, TensorVariable) works. What doesn't work is np.array.dot(TensorVariable). It didn't work in Aesara either:

[screenshot: the same np.array(...).dot(TensorVariable) failure raised by Aesara]
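
(For reference, a minimal sketch of what the operand-order fix amounts to; illustrative only, not the literal variable.py source.)

from pytensor.tensor import matmul

class TensorVariableSketch:
    # In the reflected operator, self sits on the *right* of @.
    def __matmul__(self, other):
        return matmul(self, other)   # TensorVariable @ other

    def __rmatmul__(self, other):
        # other @ TensorVariable, e.g. np.array(...) @ pt.tensor(...)
        return matmul(other, self)   # the bug was matmul(self, other)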

@ricardoV94 (Member, Author) commented Oct 4, 2023

To recap:

  • Calling pytensor functions directly, like pt.dot or pt.matmul, always works because they convert the inputs to TensorVariables
  • Calling @ always works (after this PR), as it ends up in a pt.matmul call
  • Calling np.array(...).dot(pt.tensor(...)) doesn't work and never did, AFAICT. This is one example of why I don't like mixing numpy and pytensor objects: some things work, like np.exp(pt.tensor(...)), but not others. The cases that work are achieved via the old __array_priority__ protocol. Supposedly the __array_function__ protocol is the modern way of overloading numpy expressions when you have custom types, but I never had the time to check how to implement that. I suspect it would allow us to overload np.array().dot(...) correctly (see the sketch after this list).
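
(A hedged sketch of how that protocol works in general; WrappedTensor is hypothetical, and whether the protocol also covers the ndarray.dot method, as opposed to np.dot, is exactly the open question here.)

import numpy as np

class WrappedTensor:
    """Hypothetical type opting into the __array_function__ protocol."""
    def __array_function__(self, func, types, args, kwargs):
        if func is np.dot:
            # A real implementation would build a symbolic pt.dot graph here.
            return ("symbolic dot", args)
        return NotImplemented

x = np.arange(6).reshape(2, 3)
print(np.dot(x, WrappedTensor()))   # dispatched to WrappedTensor.__array_function__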

@ricardoV94 (Member, Author) commented Oct 4, 2023

Okay so the __dot__ methods are used in:

def dot(l, r):
    """Return a symbolic dot product.

    This is designed to work with both sparse and dense tensors types.
    """
    if not isinstance(l, Variable):
        l = as_tensor_variable(l)
    if not isinstance(r, Variable):
        r = as_tensor_variable(r)
    try:
        res = l.__dot__(r)
        if res is NotImplemented:
            raise NotImplementedError
    except (NotImplementedError, AttributeError, TypeError):
        res = r.__rdot__(l)
        if res is NotImplemented:
            raise NotImplementedError()
    return res

Seems a bit odd, and I am sure other stuff like pt.exp doesn't try to handle sparse Variables.
In any case, I reverted their removal.

@codecov-commenter commented
Codecov Report

Merging #465 (686725c) into main (326cb2e) will increase coverage by 0.00%.
The diff coverage is 100.00%.


@@           Coverage Diff           @@
##             main     #465   +/-   ##
=======================================
  Coverage   80.65%   80.65%           
=======================================
  Files         160      160           
  Lines       46022    46022           
  Branches    11265    11265           
=======================================
+ Hits        37120    37121    +1     
+ Misses       6669     6668    -1     
  Partials     2233     2233           
Files                           Coverage Δ
pytensor/tensor/variable.py     87.62% <100.00%> (+0.19%) ⬆️

@jessegrabowski (Member) left a comment:

Thanks for being patient with my questions. This is good to go. I agree the infixes for dot are a strange design choice, but it's a detail for another day.

@ricardoV94 ricardoV94 merged commit 931297f into pymc-devs:main Oct 4, 2023
53 checks passed
@ricardoV94 ricardoV94 deleted the fix_rmatmul branch October 4, 2023 12:04
Labels: bug (Something isn't working), graph objects, linalg (Linear algebra)
Development: merging this pull request may close the issue "Incorrect shape inference for blockwise dot"
3 participants