This repository was archived by the owner on Nov 10, 2022. It is now read-only.

Conversation

george-qi (Contributor) commented Jun 17, 2022

Stack from ghstack (oldest at bottom):

Differential Revision: D37663285

[ghstack-poisoned]
george-qi added a commit that referenced this pull request Jun 17, 2022
ghstack-source-id: 3c7a7d0
Pull Request resolved: #65
george-qi (Contributor, Author) commented Jun 17, 2022

For some reason, to_dense is failing when we specifically use a MaskedTensor with requires_grad=True. I discovered this while trying to debug a case where to_sparse_csr was also not working, for what I think is a very similar reason, despite its implementation being identical to that of to_sparse (which converts to torch.sparse_coo).

Below is a minimal reproducible example:

import torch
from maskedtensor import masked_tensor

t = torch.randn(5, 3)
mask = torch.randint(0, 2, (5, 3)).bool()

m_csr = mask.to_sparse_csr()
t_csr = t.sparse_mask(m_csr)

m_coo = mask.to_sparse_coo()
t_coo = t.sparse_mask(m_coo)

mt1 = masked_tensor(t_csr, m_csr, requires_grad=False)
mt2 = masked_tensor(t_csr, m_csr, requires_grad=True)

mt3 = masked_tensor(t_coo, m_coo, requires_grad=False)
mt4 = masked_tensor(t_coo, m_coo, requires_grad=True)

mt1.to_dense()  # passes
mt2.to_dense()  # fails
mt3.to_dense()  # passes
mt4.to_dense()  # passes

mt2.to_dense() fails with the error:

> python test/test.py
/fsx/users/georgeqi/maskedtensor/test/test.py:7: UserWarning: Sparse CSR tensor support is in beta state. If you miss a functionality in the sparse tensor support, please submit a feature request to https://github.com/pytorch/pytorch/issues. (Triggered internally at  ../aten/src/ATen/SparseCsrTensorImpl.cpp:66.)
  m_csr = mask.to_sparse_csr()
terminate called after throwing an instance of 'c10::Error'
  what():  layout_impl is only implemented for TensorImpl subclasses.
Exception raised from layout_impl at ../c10/core/TensorImpl.h:777 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x3e (0x7f59b6eb8f0e in /data/home/georgeqi/miniconda/envs/pytorch_env/lib/python3.9/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x60 (0x7f59b6e9388f in /data/home/georgeqi/miniconda/envs/pytorch_env/lib/python3.9/site-packages/torch/lib/libc10.so)
frame #2: <unknown function> + 0x2fe13 (0x7f59b6ea4e13 in /data/home/georgeqi/miniconda/envs/pytorch_env/lib/python3.9/site-packages/torch/lib/libc10.so)
frame #3: torch::autograd::AccumulateGrad::AccumulateGrad(at::Tensor) + 0x34c (0x7f59bae2fe3c in /data/home/georgeqi/miniconda/envs/pytorch_env/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #4: torch::autograd::impl::grad_accumulator(at::Tensor const&) + 0x175 (0x7f59bae624c5 in /data/home/georgeqi/miniconda/envs/pytorch_env/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #5: torch::autograd::impl::gradient_edge(at::Tensor const&) + 0x6b (0x7f59bae6269b in /data/home/georgeqi/miniconda/envs/pytorch_env/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #6: std::vector<torch::autograd::Edge, std::allocator<torch::autograd::Edge> > torch::autograd::collect_next_edges<std::vector<at::Tensor, std::allocator<at::Tensor> >&>(std::vector<at::Tensor, std::allocator<at::Tensor> >&) + 0x55 (0x7f59d0c2d455 in /data/home/georgeqi/miniconda/envs/pytorch_env/lib/python3.9/site-packages/torch/lib/libtorch_python.so)
frame #7: std::pair<UnpackedInput, InputFlags> unpack_input<false>(_object*) + 0x737 (0x7f59d0c2ef67 in /data/home/georgeqi/miniconda/envs/pytorch_env/lib/python3.9/site-packages/torch/lib/libtorch_python.so)
frame #8: THPFunction_apply(_object*, _object*) + 0x55 (0x7f59d0c292b5 in /data/home/georgeqi/miniconda/envs/pytorch_env/lib/python3.9/site-packages/torch/lib/libtorch_python.so)
<omitting python frames>
frame #17: torch::handle_torch_function_no_python_arg_parser(c10::ArrayRef<pybind11::handle>, _object*, _object*, char const*, _object*, char const*, torch::TorchFunctionName) + 0x53d (0x7f59d0f21d5d in /data/home/georgeqi/miniconda/envs/pytorch_env/lib/python3.9/site-packages/torch/lib/libtorch_python.so)
frame #18: torch::handle_torch_function(torch::PythonArgs&, _object*, _object*, _object*, _object*, char const*, char const*) + 0x20c (0x7f59d0f2335c in /data/home/georgeqi/miniconda/envs/pytorch_env/lib/python3.9/site-packages/torch/lib/libtorch_python.so)
frame #19: <unknown function> + 0x43ed7a (0x7f59d09e7d7a in /data/home/georgeqi/miniconda/envs/pytorch_env/lib/python3.9/site-packages/torch/lib/libtorch_python.so)
frame #31: __libc_start_main + 0xe7 (0x7f59d1cf7c87 in /lib/x86_64-linux-gnu/libc.so.6)

[1]    7736 abort (core dumped)  python test/test.py

I suspect there may be something fishy going on with sparse_csr in core.

george-qi added a commit that referenced this pull request Jun 17, 2022
ghstack-source-id: 378d1b3
Pull Request resolved: #65
george-qi requested a review from cpuhrsch on June 17, 2022 at 16:09
cpuhrsch (Contributor) commented:

@george-qi - It's possible that the failure you describe is due to core. Can you isolate it further into a standalone snippet?
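
A rough sketch of what such a standalone snippet might look like, assuming the trigger is autograd constructing an AccumulateGrad node for a sparse CSR leaf tensor (the PassThrough function and the requires_grad_ call on a CSR tensor below are illustrative assumptions, not confirmed behavior, and the snippet is untested):

import torch

# Hypothetical isolation attempt that never touches maskedtensor.
# Assumption: the failure originates in core when autograd builds an
# AccumulateGrad node for a sparse CSR leaf, mirroring the frames in the
# trace above (collect_next_edges -> grad_accumulator -> AccumulateGrad).
t = torch.randn(5, 3)
mask = torch.randint(0, 2, (5, 3)).bool()

csr_leaf = t.sparse_mask(mask.to_sparse_csr())
csr_leaf.requires_grad_(True)  # assumes a CSR tensor accepts requires_grad here

class PassThrough(torch.autograd.Function):
    # Identity Function used only to force autograd to collect edges
    # for its inputs, as THPFunction_apply does in the trace.
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

# If the suspicion is right, this call should hit the same
# "layout_impl is only implemented for TensorImpl subclasses" check.
PassThrough.apply(csr_leaf)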

george-qi added a commit that referenced this pull request Jul 6, 2022
ghstack-source-id: 4151dec
Pull Request resolved: #65
george-qi added a commit that referenced this pull request Jul 6, 2022
ghstack-source-id: ba29606
Pull Request resolved: #65
george-qi added a commit that referenced this pull request Jul 6, 2022
ghstack-source-id: 200cc16
Pull Request resolved: #65
george-qi (Contributor, Author) commented:

@george-qi has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

george-qi mentioned this pull request on Jul 7, 2022
facebook-github-bot deleted the gh/george-qi/33/head branch on July 11, 2022 at 14:17