
Lazy conjugated tensor is not numpy interoperable #59945

Closed
pearu opened this issue Jun 14, 2021 · 20 comments
Labels
high priority
module: bc-breaking (Related to a BC-breaking change)
module: complex (Related to complex number support in PyTorch)
module: correctness (silent) (issue that returns an incorrect result silently)
module: numpy (Related to numpy support, and also numpy compatibility of our operators)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@pearu
Collaborator

pearu commented Jun 14, 2021

🐛 Bug

The bug is best illustrated by the following examples (using the current master):

>>> import torch
>>> torch.tensor([1+2j]).conj().numpy()
array([1.+2.j], dtype=complex64)
>>> import numpy
>>> numpy.array(torch.tensor([1+2j]).conj())
array([1.+2.j], dtype=complex64)

Expected behavior

Since the numpy view of a lazy conjugated tensor is impossible, the above examples should raise an exception rather than ignore the set conjugate bit.

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411 @Varal7 @dylanbespalko @mruberry @lezcano @nikitaved @rgommers @heitorschueroff

pearu added the module: complex, module: numpy, and module: correctness (silent) labels on Jun 14, 2021
@ezyang
Contributor

ezyang commented Jun 14, 2021

Whoops, yup, we need to raise an error here.

anjali411 added the triaged label and removed the triage review label on Jun 14, 2021
@anjali411
Contributor

anjali411 commented Jun 14, 2021

@pearu thanks for filing the issue! yeah we should raise an error here (same thing for __array__ too) and probably also add OpInfos for these ops ...

@rgommers
Collaborator

Raising an error is BC-breaking. Why wouldn't you just use .resolve_conj inside .numpy() and .__array__? This is code that used to work and is now made lazy, so just doing the >O(1) work later seems justified. In particular for .numpy(), I don't see it as guaranteeing an O(1) operation; it should just work for CPU tensors that are not in an autograd graph.
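
For illustration, a minimal Python-level sketch of this suggestion, assuming a build with conjugate views (where Tensor.is_conj() and Tensor.resolve_conj() exist); to_numpy_resolving_conj is a hypothetical helper, and the real change would live inside Tensor.numpy() / Tensor.__array__ themselves:

import torch

def to_numpy_resolving_conj(t):
    # Hypothetical wrapper (not a PyTorch API): if the conjugate bit is set,
    # materialize the conjugation first (this copies), then call .numpy().
    if t.is_conj():
        t = t.resolve_conj()
    return t.numpy()

a = torch.tensor([1 + 2j])
print(to_numpy_resolving_conj(a.conj()))  # array([1.-2.j], dtype=complex64)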

@pearu
Collaborator Author

pearu commented Jun 14, 2021

Raising an error is BC-breaking. Why wouldn't you just use .resolve_conj inside .numpy() and .__array__?

Note that the .numpy() result should share the storage of the tensor (see https://pytorch.org/docs/stable/tensors.html#torch.Tensor.numpy), while .resolve_conj() would require a copy, which would be BC-breaking as well.

@rgommers
Collaborator

I don't think so - before a.conj() would also make a copy, so a.conj().numpy() would not share storage with a.

@pearu
Collaborator Author

pearu commented Jun 14, 2021

I don't think so - before a.conj() would also make a copy, so a.conj().numpy() would not share storage with a.

a.conj() is the result of the conjugate operation, and it is not expected that a.conj().numpy() would share the storage of a (unless a is real, in which case conj() would be a no-op anyway).
Instead, it is expected that b.numpy() would share the storage with b, where b = a.conj(). With the new lazy conjugate feature, b shares the storage with a, so the storage-sharing promise of .numpy() cannot be fulfilled one way or the other.

Btw, a.conj().conj().numpy() would share the storage with a (assuming that a does not have the conjugate bit set).
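
To make the aliasing concrete, a small check (behaviour as observed with conjugate views enabled; illustrative only):

import torch

a = torch.tensor([1 + 2j])
b = a.conj()                          # lazy: only the conjugate bit is set
print(b.is_conj())                    # True
print(b.data_ptr() == a.data_ptr())   # True: b aliases a's storage

c = b.conj()                          # flips the bit back; still aliases a
print(c.is_conj())                    # False
print(c.data_ptr() == a.data_ptr())   # True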

anjali411 added the module: bc-breaking label on Jun 14, 2021
@anjali411
Contributor

@rgommers yeah this would be bc-breaking, but I think, as @pearu mentioned above, it might be best to ask the user to resolve conjugation before calling these view operations on tensors with the conjugate bit set to 1, similar to what we do for torch.view_as_real

Tensor view_as_real(const Tensor& self) {
  TORCH_CHECK(!self.is_conj(), "view_as_real doesn't work on unresolved conjugated tensors. To resolve the conjugate tensor so you can view it as real, use self.resolve_conj(); however, be warned that the resulting tensor will NOT alias the original.");
  return native::_view_as_real_physical(self);
}

@rgommers
Collaborator

rgommers commented Jun 14, 2021

Btw, a.conj().conj().numpy() would share the storage with a (assuming that a does not have the conjugate bit set).

I don't think this is right either, and if it is, it is going to lead to subtle bugs. As far as I understood @anjali411's presentation, this is why a conjugate bit is not allowed to flip back from 1 to 0.

If there are any examples like a.conj().conj(), or examples that mix views and in-place ops, that now give silently different results, something seems quite wrong (right?). My impression was that making .conj() lazy is a performance optimization that doesn't change any semantics.

@rgommers
Collaborator

With the new lazy conjugate feature, b shares the storage with a,

Only in terms of "they point at the same memory at this point in time" I hope, not semantically. I'd expect this:

b = a.conj()
a[:10] = -99  # does this modify b now?? (it shouldn't!)

to not change b. This is why I asked about the bi-directional lenses framework - you want to know which views to update when, so that the semantics of code like the above do not change between 1.9.0 and 1.10.0.

Or maybe I misunderstood, and the introduction of conjugate views is BC-breaking for this kind of code?

@vadimkantorov
Contributor

Maybe numpy(force=True) from another PR could also be used in this case (when >O(1) work is expected and desired)

@pearu
Collaborator Author

pearu commented Jun 14, 2021

Only in terms of "they point at the same memory at this point in time" I hope, not semantically. I'd expect this:

b = a.conj()
a[:10] = -99  # does this modify b now?? (it shouldn't!)

to not change b.

It does, and apparently also in a correct way:

>>> a = torch.tensor([1+2j])
>>> b = a.conj()
>>> b[0] = 3+4j
>>> a
tensor([3.-4.j])

IIUC, the conjugate bit is just an additional attribute of a tensor that certain operation implementations may take advantage of to avoid materializing intermediate conjugate results. Flipping the conjugate bit twice corresponds to double conjugation, which is an identity operation, so I don't see why .conj().conj() would not be allowed.

It seems that the conjugate bit works reliably for operations whose execution goes through the PyTorch dispatching mechanism.

Only operations that directly access the tensor memory, such as numpy interoperability operations, or operations that re-interpret the tensor memory, such as the view() method (see related issue #59946), would require that the conjugate bit is not set to ensure expected results.

On the other hand, while a.conj().numpy() gives an unexpected result, if one remembers that b.numpy() ignores the conjugate bit (simply because a numpy ndarray has no similar feature), the result becomes expected: the resulting numpy array shares the storage with a and must be interpreted together with the a.is_conj() result. Such an interpretation of the numpy() result may be cumbersome and complicate programs, so raising a runtime exception, with a message stating that one should materialize the lazy conjugated tensor before exposing the tensor storage outside the PyTorch framework, may be the better option.
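
A small illustration of the distinction drawn above, assuming a build where b.numpy() still ignores the bit (the pre-fix behaviour reported in this issue):

import torch

a = torch.tensor([1 + 2j])
b = a.conj()

# Dispatch-based operations respect the conjugate bit:
print(b * 1)          # tensor([1.-2.j])

# Raw-memory access ignores it (the behaviour this issue reports):
print(b.numpy())      # array([1.+2.j], dtype=complex64) on the build reported above
print(b.is_conj())    # True: the caller would have to account for this themselves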

@rgommers
Collaborator

rgommers commented Jun 15, 2021

Okay, then ignore everything I wrote above. I had not even considered that a.conj() changes its return value from a copy to an actual view by design, rather than just being a lazy copy - this seems like the worst kind of bc-breaking change (a silent change in numerical results for idiomatic code). If that was the intent, I would have expected new syntax instead of reusing .conj().

gh-59943 may have the same root cause.

@pearu
Collaborator Author

pearu commented Jun 15, 2021

@rgommers , I agree, the conjugate bit feature would deserve a new method, say, .conj_lazy(), to avoid BC issues. It would mean that existing code using .conj() would need to be revisited, replacing conj with conj_lazy, provided that

  • the corresponding change has a positive effect on performance/memory savings (for instance, when the conjugate needs to be materialized locally anyway, there would be no need to use conj_lazy)
  • and the twiddling of the conjugate bit is localized, or at least there is proof that the conjugate bit twiddling affects only dispatch-based operations, all this to minimize BC issues with numpy or numba interoperability.

I also think that a.conj_lazy().numpy() should return a numpy ndarray that shares the storage with a, provided that the .conj_lazy() and .numpy() docs clearly state that .numpy() ignores the conjugate bit and that, when tensor storage is exposed via .numpy(), one must take into account the state of a.is_conj() as well.

@ezyang
Contributor

ezyang commented Jun 15, 2021

@rgommers We argued a bit about whether or not to make this BC breaking change, and you were there too :P The most relevant discussion is in #43270 (comment)

It took a while for us to actually land conjugate views on master, so the BC breakage is a little worse than it might have been if we had managed to land this last year, but our guess is that the corpus of complex-number-using PyTorch code is still small enough that x.conj().inplace_op_() is unlikely to occur (and remember, this will mutate x if x is a real tensor! So it's not that incredible that people will avoid this idiom). But I suppose it's not too late to add some VC-based mechanism for detecting when mutations have propagated through conjugations and warn people about it.
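
For context, a sketch of the idiom being discussed, assuming a build with conjugate views (outputs are what I would expect, not verified against this exact master):

import torch

x = torch.tensor([1.0, 2.0])   # real tensor: conj() returns x itself
x.conj().mul_(2)               # so this in-place op mutates x
print(x)                       # tensor([2., 4.])

z = torch.tensor([1 + 2j])     # complex tensor under conjugate views
z.conj().mul_(2)               # the in-place op writes back through the conjugate view
print(z)                       # tensor([2.+4.j])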

I also think that a.conj_lazy().numpy() should return a numpy ndarray that shares the storage with a, provided that the .conj_lazy() and .numpy() docs clearly state that .numpy() ignores the conjugate bit and that, when tensor storage is exposed via .numpy(), one must take into account the state of a.is_conj() as well.

I disagree. It's more important for numpy() to give a semantically equivalent output to the input, and that means we must error.

@vadimkantorov
Contributor

About numpy() semantics: I propose that by default it should throw an error if it cannot represent the array faithfully without a copy, and that it should do a best-effort copy if some explicit force=True or copy=True argument is passed, as in #59790
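
Roughly how this would look from the caller's side (a sketch; the force= keyword is the one proposed in #59790 and not an existing argument at this point):

import torch

a = torch.tensor([1 + 2j]).conj()

# Default: refuse to hand out memory that cannot be represented faithfully.
# a.numpy()            # would raise under the proposal in this issue

# Explicit opt-in to a best-effort copy (signature as proposed in #59790):
# a.numpy(force=True)  # would behave like a.resolve_conj().numpy()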

@pearu
Collaborator Author

pearu commented Jun 15, 2021

I disagree. It's more important for numpy() to give a semantically equivalent output to the input, and that means we must error.

OK. No objections as one can do:

x = (a.conj() if a.is_conj() else a).numpy()

if one really wants to access the memory via a numpy array.

@martinPasen

About numpy() semantics: I propose that by default it should throw an error if it cannot represent the array faithfully without a copy, and that it should do a best-effort copy if some explicit force=True or copy=True argument is passed, as in #59790

I am working on #59790 and this issue is closely connected. Should I implement it in that PR, or is it supposed to be done separately after #59790 is merged?

@rgommers
Collaborator

rgommers commented Jun 16, 2021

@rgommers We argued a bit about whether or not to make this BC breaking change, and you were there too :P The most relevant discussion is in #43270 (comment)

Thanks for that link. That's like a year ago - I had completely forgotten about it :)

It took a while for us to actually land conjugate views on master, so the BC breakage is a little worse than it might have been if we had managed to land this last year, but our guess is that the corpus of complex-number-using PyTorch code is still small enough

Okay. Maybe I'm too scarred by working on NumPy for too long, and this is all not too bad. For me, making a change after complex became stable simply didn't compute. Certainly PyTorch users are more forgiving about BC-breaking changes. In the end this does result in a better design. So if there aren't too many complaints from people using the nightlies, then it will be fine to leave it as is, I guess.

This comment is also still relevant: #43270 (comment).

Actually, since we're exploring conjugate view, let me suggest we make an even stronger claim: we should tell users to flat out NOT mutate the output of this function, ever. If you avoid doing so, you will write code that is forwards compatible with any possible semantics of conj.
...
I despair about actually making sure people don't do this with just a doc update. We really need some way to mark views as non-mutable.

This was one of the pain points for the array API standard design as well. Since libraries are not consistent with each other about view-copy semantics, we say the same thing in the docs: "don't mutate arrays that can be views if you want code to be portable". It's very hard to guarantee though.

@ezyang
Contributor

ezyang commented Jun 16, 2021

I despair about actually making sure people don't do this with just a doc update. We really need some way to mark views as non-mutable.

I used to think that this was not so easy to do because of the base-update problem: normally, tensors are mutable, but if you take out an immutable view of a tensor and then mutate the base tensor, the view gets updated too; this leads to the unsavory conclusion that taking out an immutable view must make the base tensor immutable as well.

However, I recently realized that version counters neatly solve this problem. Instead of making the base immutable, you just record the base's version counter at the time the view was taken out. If the base mutates and you then try to use the view, you get an error message saying that the base has been updated and so the view is no longer valid. So we could easily implement this for x.conj(), and then x.conj_view() is for when you actually want the aliasing semantics.

Just need to implement it!
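
A rough Python-level sketch of that idea (using the internal Tensor._version counter purely for illustration; guarded_conj_view is a hypothetical helper, not a proposed API):

import torch

def guarded_conj_view(base):
    # Take a conjugate view and remember the base's version counter.
    view = base.conj()
    version_at_creation = base._version

    def use(fn):
        # Refuse to use the view if the base was mutated after the view was taken.
        if base._version != version_at_creation:
            raise RuntimeError("base was mutated; this conjugate view is no longer valid")
        return fn(view)

    return use

a = torch.tensor([1 + 2j])
use_b = guarded_conj_view(a)
print(use_b(lambda b: b + 0))   # fine: base untouched so far
a[0] = 3 + 4j                   # mutating the base bumps its version counter
# use_b(lambda b: b + 0)        # would now raise RuntimeError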

anjali411 added commits that referenced this issue between Jul 20 and Jul 22, 2021

… conjugate or negative bit set"

Resolves #59945 and #59946

bc-breaking note: Unlike before, complex_tensor.conj().numpy(), complex_float_tensor.conj().view(torch.float64), and complex_float_tensor.conj().imag.view(torch.int32) no longer return a view but instead error out

Differential Revision: [D29819288](https://our.internmc.facebook.com/intern/diff/D29819288)

[ghstack-poisoned]
@lucascolley

lucascolley commented Aug 7, 2024

This raising an error is biting us in SciPy, where np.asarray(...(xp.conj(x))) is not an uncommon pattern. While the array API standard does not require general interoperability with any library's asarray, it seems that the same problem will be present for __dlpack__, which is in the standard.

The workaround is to just use torch.conj_physical every time someone calls xp.conj(a_torch_tensor), but that is not ideal given that most occurrences of xp.conj are not followed by a call to np.asarray or __dlpack__.
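
A minimal sketch of that workaround (conj_workaround is a hypothetical helper name, not SciPy code; only the torch branch is specific to this issue):

import numpy as np
import torch

def conj_workaround(xp, x):
    # For torch tensors, materialize the conjugation so the result stays
    # numpy-/__dlpack__-interoperable; other backends keep using xp.conj.
    if isinstance(x, torch.Tensor):
        return torch.conj_physical(x)
    return xp.conj(x)

x = torch.tensor([1 + 2j])
print(np.asarray(conj_workaround(torch, x)))   # array([1.-2.j], dtype=complex64)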

x-ref data-apis/array-api-compat#173 (comment)
