Spurious "NumPy array is not writeable" warning on torch.tensor(np_array) #47160

Closed
jpetkau opened this issue Oct 31, 2020 · 6 comments
Labels
module: numpy Related to numpy support, and also numpy compatibility of our operators module: tensor creation triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Comments

@jpetkau

jpetkau commented Oct 31, 2020

🐛 Bug

Calling torch.tensor(a) on a non-writeable NumPy array a produces this warning:

UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor.

But since torch.tensor(a) always makes a copy, the warning makes no sense. Avoiding it requires copying the NumPy array before converting it to a tensor.

A partial workaround is to do torch.from_numpy(a.copy()).
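Why the workaround is quiet can be checked on the NumPy side alone: ndarray.copy() returns a new array that owns its memory and is writeable, so the array handed to torch.from_numpy no longer trips the read-only check. A minimal sketch, NumPy only; the torch call is left as a comment so the snippet runs without PyTorch:

```python
import numpy as np

a = np.arange(5.0)
a.flags.writeable = False          # read-only source, as in the report

b = a.copy()                       # ndarray.copy() allocates a fresh buffer,
assert b.flags.writeable           # the fresh buffer is writeable,
assert not np.shares_memory(a, b)  # and it is independent of a.

# torch.from_numpy(b) would then alias a writeable array, so the
# "not writeable" check has nothing to warn about.
```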

To Reproduce

Steps to reproduce the behavior:

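import torch, numpy
a = numpy.arange(5.)
a.flags.writeable = False  # make the source array read-only
t = torch.tensor(a)        # torch.tensor copies the data, yet the warning above is emitted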

Expected behavior

The warning should only be shown when actually aliasing a read-only buffer with a mutable Tensor, not when copying from a read-only buffer.
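The expected behavior amounts to one rule: warn only when the result would share memory with a read-only source. Below is a hypothetical sketch of that rule, not PyTorch's implementation: make_buffer_sketch is an invented name, plain NumPy arrays stand in for tensors, copy=True stands in for torch.tensor(a) and copy=False for torch.from_numpy(a).

```python
import warnings

import numpy as np


def make_buffer_sketch(arr: np.ndarray, copy: bool) -> np.ndarray:
    """Hypothetical model of the expected warning rule (not PyTorch code).

    copy=True models torch.tensor(arr), which always copies;
    copy=False models torch.from_numpy(arr), which shares memory.
    """
    if copy:
        # The copy owns its memory, so writes to it can never reach the
        # read-only source: no warning is warranted.
        return np.array(arr, copy=True)
    if not arr.flags.writeable:
        # In the real from_numpy case the result would be a mutable tensor
        # aliasing this read-only buffer; only then is a warning useful.
        warnings.warn("The given NumPy array is not writeable", UserWarning, stacklevel=2)
    # The returned object shares memory with arr.
    return arr


a = np.arange(5.0)
a.flags.writeable = False
copied = make_buffer_sketch(a, copy=True)    # silent
aliased = make_buffer_sketch(a, copy=False)  # emits the UserWarning
```

Keeping the decision on the copy flag, rather than on the source array alone, is what removes the false alarm while preserving the warning for genuine aliasing.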

Environment

Verified on Windows with PyTorch 1.5.1, git_version = '3c31d73c875d9a4a6ea8a843b9a0d1b19fbe36f3' (pip version)
and on Linux with PyTorch 1.8.0a0fb.

Additional context

The warning was introduced in PR #33615. #44027 (read-only tensors) would make this irrelevant.

This warning is important for from_numpy, because it is easy to accidentally create a mutable view onto a Python str or bytes object and overwrite it. When the warning also fires as a trivial false alarm like this one, real problems are more likely to be overlooked.
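The bytes hazard is easy to see with np.frombuffer, which wraps a Python bytes object without copying. NumPy marks such a view read-only and refuses writes; per the warning text quoted in the report, a tensor built from it with torch.from_numpy would not. The sketch is NumPy only, and try_write is a hypothetical helper introduced for illustration:

```python
import numpy as np


def try_write(arr: np.ndarray) -> bool:
    """Hypothetical helper: attempt arr[0] = 0 and report whether it succeeded."""
    try:
        arr[0] = 0
    except ValueError:
        # NumPy raises "assignment destination is read-only" here.
        return False
    return True


payload = b"hello"                             # immutable Python bytes
view = np.frombuffer(payload, dtype=np.uint8)  # zero-copy view of payload

print(view.flags.writeable)  # False: bytes only exposes a read-only buffer
print(try_write(view))       # False: NumPy refuses the write
print(payload)               # b'hello', unchanged
```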

Possibly related: torch.tensor(nparray, pin_memory=True) doesn't work either, probably for a similar reason: it constructs a temporary view before making the copy. If you want the tensor in pinned memory, there appears to be no way to avoid the warning without introducing an extra, otherwise unnecessary copy.

cc @mruberry @rgommers @heitorschueroff

@ngimel
Collaborator

ngimel commented Nov 1, 2020

Thank you for your bug report; the warning is indeed spurious in this case.

@ngimel ngimel added module: numpy Related to numpy support, and also numpy compatibility of our operators module: tensor creation triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module labels Nov 1, 2020
@ArtistBanda
Contributor

Hello! I am new to open source development and would like to help resolve this issue. A bit of guidance would help a lot.

@mruberry
Collaborator

mruberry commented Nov 2, 2020

You probably want to suppress the warning from here:

auto tensor = tensor_from_numpy(data);

when copy_numpy is true. Then add tests for the callsites to verify the behavior is correct.
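A call-site test along these lines could assert that no "not writeable" warning is emitted. Below is a minimal, stdlib-only sketch; assert_no_writeable_warning is a hypothetical name and is not taken from PyTorch's test suite, which has its own utilities.

```python
import warnings


def assert_no_writeable_warning(fn, *args, **kwargs):
    """Hypothetical test helper: call fn and fail if it warns about a
    non-writeable array. Returns fn's result so the caller can check it."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # make sure repeated warnings are recorded
        result = fn(*args, **kwargs)
    offending = [str(w.message) for w in caught if "not writeable" in str(w.message)]
    if offending:
        # Raise explicitly so the check survives `python -O`.
        raise AssertionError(f"unexpected warning(s): {offending}")
    return result
```

With PyTorch installed, a test might then call assert_no_writeable_warning(torch.tensor, a) for a read-only a, alongside a case checking that torch.from_numpy(a) still warns.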

@ArtistBanda
Contributor

Sure, I'll try to do that.

@desaixie

Can I ignore this warning if it is coming from torch.from_numpy(batch).to('cuda'), where batch is a slice of a numpy ndarray and therefore not writeable? I guess it shouldn't matter since to('cuda') is making a copy, but I want to make sure.

@mruberry
Collaborator

Yes, I think you can ignore it
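If you do decide to ignore it, one option is to silence only this message and only around the conversion, using the standard warnings module, rather than disabling UserWarning globally. A sketch under that assumption; call_ignoring_not_writeable is a hypothetical name, and the filter matches the start of the message quoted in the report:

```python
import warnings

# warnings.filterwarnings treats this as a regex matched against the start
# of the warning message (case-insensitively).
NOT_WRITEABLE = r"The given NumPy array is not writeable"


def call_ignoring_not_writeable(fn, *args, **kwargs):
    """Hypothetical wrapper: run fn with only the 'not writeable' UserWarning
    suppressed; any other warning still propagates normally."""
    with warnings.catch_warnings():  # restores the previous filters on exit
        warnings.filterwarnings("ignore", message=NOT_WRITEABLE, category=UserWarning)
        return fn(*args, **kwargs)
```

Usage would look like t = call_ignoring_not_writeable(torch.from_numpy, batch).to('cuda'). Note that catch_warnings modifies process-global state and is not thread-safe.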


5 participants