Added Kronecker product of tensors (torch.kron) #45358
Conversation
Tests pass. The implementation is based on tensordot.
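The outer-product/permute/reshape idea behind a tensordot-based kron can be sketched for the two-matrix case as follows. This is a hypothetical NumPy illustration of the strategy, not the PR's actual `at::tensordot` code:

```python
import numpy as np

def kron2d(a, b):
    """Kronecker product of two matrices via outer product + transpose + reshape.

    Illustrative sketch of the tensordot-based strategy (2-D case only).
    """
    m, n = a.shape
    p, q = b.shape
    # tensordot with axes=0 is the outer product: out[i, j, k, l] = a[i, j] * b[k, l]
    out = np.tensordot(a, b, axes=0)
    # Interleave (rows of a, rows of b) and (cols of a, cols of b), then merge each pair.
    return out.transpose(0, 2, 1, 3).reshape(m * p, n * q)

a = np.array([[1., 2.], [3., 4.]])
print(kron2d(a, np.eye(2)))  # matches np.kron(a, np.eye(2))
```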
💊 Dr. CI: As of commit a1b3255, there are no CI failures yet.
Codecov Report
@@            Coverage Diff            @@
##           master   #45358    +/-   ##
=======================================
  Coverage   60.81%   60.81%
=======================================
  Files        2748     2748
  Lines      254027   254070     +43
=======================================
+ Hits       154488   154522     +34
- Misses      99539    99548      +9
Overall, looks good to me. Can we get some benchmarks too, e.g. a comparison to np.kron?
Here is the code for the benchmark.
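The benchmark code itself isn't reproduced above. As an illustration of the kind of comparison being discussed, here is a hypothetical micro-benchmark in pure NumPy pitting `np.kron` against a tensordot-plus-reshape implementation (mirroring the strategy this PR uses via `at::tensordot`); sizes and repeat counts are arbitrary:

```python
import timeit
import numpy as np

def kron_td(a, b):
    # tensordot + transpose + reshape, mirroring the PR's strategy (2-D case)
    m, n = a.shape
    p, q = b.shape
    return np.tensordot(a, b, axes=0).transpose(0, 2, 1, 3).reshape(m * p, n * q)

for n in (4, 16, 64):
    a, b = np.random.rand(n, n), np.random.rand(n, n)
    t_np = timeit.timeit(lambda: np.kron(a, b), number=20)
    t_td = timeit.timeit(lambda: kron_td(a, b), number=20)
    print(f"n={n:2d}  np.kron {t_np:.4f}s  tensordot {t_td:.4f}s")
```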
Any idea why some cases are not as fast?
Well, the implementation is different. NumPy's implementation is based on
Alright, I've realized that the previous timings were in debug mode 😄
Hi, this is really exciting to see. I was hoping to use the Kronecker product with complex tensors, but I couldn't discern whether that would be supported by this. I look forward to using it!
Updated
torch/_torch_docs.py
Outdated
Computes the Kronecker product, denoted by :math:`\otimes`, of :attr:`input` and :attr:`other`.
If :attr:`input` is a :math:`(m \times n)` tensor and :attr:`other` is a |
We didn't discuss this previously, but is the Kronecker product defined if either A or B isn't a matrix? Should we add a check for that?
I think the Kronecker product is defined mathematically only for matrices. We can think of vectors as m×1 matrices and scalars as 1×1 matrices; then everything works.
Vectors are tested here (as shape (4,)).
As for n-dimensional arrays with n>2, NumPy extends the definition as described in the notes section to "blocks of the second tensor scaled by the first tensor".
So kron does not support batching.
Sometimes it's said that for matrices the Kronecker product == the tensor outer product, but this is not true for tensors in general. For the example from Wikipedia about the tensor product, kron would give a tensor with dimensions (3*1, 5*10, 7*100).
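NumPy's n-dimensional extension multiplies the input shapes elementwise, which is easy to check directly (illustrative snippet):

```python
import numpy as np

# For n-d inputs, np.kron's output shape is the elementwise product
# of the two input shapes.
a = np.ones((3, 5, 7))
b = np.ones((1, 10, 100))
print(np.kron(a, b).shape)  # (3, 50, 700)
```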
Being like NumPy seems OK (our goal is to be compatible, after all). Would you add a note about how input and other are treated if they're not matrices? The current docs deal with them as if they must be matrices. Something like:
- NOTE
- The Kronecker product is typically defined only for two matrices
- When either is a scalar or vector it's unsqueezed as...
- When either input is a tensor with 3+ dimensions then...
ALTERNATIVELY you could expand the description of the function to describe the matrix case, say that scalars and vectors are unsqueezed to be matrices, and THEN define the "general" case. That seems like a more challenging but better approach.
What are your thoughts, @IvanYashchuk?
What about presenting the general definition first in the main description and then mentioning what it does for matrices:
Computes the Kronecker product of input and other.
For general n-dimensional tensors this function computes:
... math expression
The number of dimensions of input and other is assumed to be the same and if necessary the smaller tensor is unsqueezed as the larger one.
If input is a (m \times n) matrix and other is a (p \times q) matrix, the result will be a (p*m \times q*n) block matrix:
... kron definition for matrices from wiki
Scalar and vector inputs are unsqueezed to be matrices.
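The proposed unsqueezing behavior already matches NumPy; a quick illustrative check:

```python
import numpy as np

M = np.array([[1, 2], [3, 4]])

# A scalar is treated as a 1x1 matrix, so kron(3, M) just scales M.
print(np.kron(3, M))  # [[ 3  6] [ 9 12]]

# A length-2 vector is treated as a 1x2 matrix: result shape (1*2, 2*2).
print(np.kron(np.array([1, 10]), M).shape)  # (2, 4)
```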
That approach sounds good. Instead of
"The number of dimensions of input and other is assumed to be the same and if necessary the smaller tensor is unsqueezed as the larger one."
I think it can say "If one tensor has fewer dimensions than the other it is unsqueezed until it has the same number of dimensions."
That change would make the explicit reference to scalar and vector inputs at the end redundant, so it can be removed.
The last caveat is that this needs to define the dot operator in both equations above.
Okay, the dot was for normal multiplication of scalars. Is an asterisk (*) preferred?
Dot's OK if the documentation defines it (the dot operator is used for so many mathematical operations it's highly ambiguous), an asterisk would also be fine and probably doesn't require definition (it's typically used for elementwise multiplication and scalar multiplication).
Awesome work, @IvanYashchuk. Made one last small docs comment for your review.
It'll be great to have torch.kron available. I was just talking to a PyTorch user the other day about how he's using matrices with a "Kronecker structure" as a form of structured sparsity.
Do me a favor, though, and when the updates are ready for review be sure to re-request review or say "This is ready for another review." With the number of PRs I'm tracking it's hard to understand when updates are being made vs. something is ready for review again.
Just ping me when you'd like this merged.
Hi @IvanYashchuk! Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours needs attention. You currently have a record in our system, but we do not have a signature on file. In order for us to review and merge your code, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!
Hi @mruberry, I think now this PR should be ready for merging.
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Tests seem to fail. Is it related to
I would ignore the "Facebook internal" build signal. It's complete nonsense.
This PR adds a function for calculating the Kronecker product of tensors. The implementation is based on at::tensordot with permutations and reshapes. Tests pass.
TODO:
- common_methods_invocations.py
Ref. #42666