New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Permute cp factors #380
Permute cp factors #380
Conversation
Codecov Report
@@ Coverage Diff @@
## main #380 +/- ##
==========================================
- Coverage 88.55% 88.46% -0.10%
==========================================
Files 107 107
Lines 6238 6284 +46
==========================================
+ Hits 5524 5559 +35
- Misses 714 725 +11
📣 Codecov can now indicate which changes are the most critical in Pull Requests. Learn more |
68a26cd
to
c80e5ac
Compare
c80e5ac
to
af059db
Compare
af059db
to
079107f
Compare
Thanks, this is great @caglayantuna and @cohenjer. I am really tempted to either i) skip the tensorflow test for now, like we do e.g. in cross approximation or ii) use the experimental tensorflow numpy API like @aarmey suggested. I don't think it's worth defining |
Yes, please, let's skip the tensorflow test for now. The overlap of folks using this with TF will be nil. Given the constant frustration with TF I think it's reasonable to require the experimental interface eventually. @cyrillustan and @murphymadeleine21, you'll be interested in this. |
131b501
to
730037a
Compare
730037a
to
a1fa0dc
Compare
I updated this PR and skipped tensorflow test. To import |
Awesome! This is super useful and we should advertise it to users - @caglayantuna, let's also add it to the API in the doc and perhaps add a few lines in the user-guide? I'll merge this and we can add through another PR? |
Permute cp factors
This pull request adds
cp_permute_factors
function tocp_tensor.py
and relevant test totest_cp_tensor
.This function has 2 inputs which are ref_cp_tensor (one cp tensor) and tensors_to_permute (one cp tensor or list of cp tensor). It returns permuted tensor or tensors and permutation list for each permuted tensor.
tl.gather
Since permuting factors require an operator to swap columns, I added
tl.gather
function to Tensorly to make this function work for tensorflow backend. While using [:, indices] is enough for all the other backends, index operation is not possible with tensorflow unfortunately. I have usedtake
function for numpy, jax, mxnet andindice_select
for pytorch backend to have same behaviour for all backends.Note
I don't see any error with Jax backend in my computer but it fails in github test.