Labels: needs reproduction, oncall: jit
Description
🐛 Bug
To Reproduce
Steps to reproduce the behavior:
- Create a computation graph containing a sum operation together with another operation that runs on the GPU (a sketch of such a graph follows).
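For reference, here is a hedged sketch of what such a graph and its driver might look like. This is my reconstruction, not the reporter's code: the script mirrors the one in "Additional context" below but uses the commented-out line there (i.e. without the .type_as(mask) cast), and the tensor shapes, values, and the two-call pattern are assumptions.

    import torch

    @torch.jit.script
    def batch_sum(data, mask, dims):
        # Apply the mask, then sum over the batched dimensions.
        data = data * mask.type_as(data)
        for _ in range(dims.size(0)):
            data = data.sum(1)
        # torch.ones(...) allocates on the CPU here, even though `data` is on the GPU.
        mask = torch.ones([data.size(0)], dtype=torch.uint8)
        dims = dims[:0]  # empty tensor
        return data, mask, dims

    # Assumed driver: all inputs start on the GPU. After the first call the returned
    # mask is a CPU uint8 tensor, so the second call feeds a CPUByteType mask back
    # into the CUDA graph, which is presumably where the mismatch below comes from.
    data = torch.rand(4, 3, 2, device="cuda")
    mask = torch.ones(4, 3, 2, dtype=torch.uint8, device="cuda")
    dims = torch.tensor([1, 1])

    data, mask, dims = batch_sum(data, mask, dims)
    data, mask, dims = batch_sum(data, mask, dims)

Executing a graph along these lines then fails with: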
TensorIterator expected type CUDAByteType but got CPUByteType (check_type_conversions at ../aten/src/ATen/native/TensorIterator.cpp:547)
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x6a (0x7f860917d8fa in /usr/local/lib/python3.5/dist-packages/torch/lib/libc10.so)
frame #1: at::TensorIterator::check_type_conversions() + 0x24d (0x7f8668082b2d in /usr/local/lib/python3.5/dist-packages/torch/lib/libcaffe2.so)
frame #2: at::TensorIterator::Builder::build() + 0x2ef (0x7f8668087d3f in /usr/local/lib/python3.5/dist-packages/torch/lib/libcaffe2.so)
frame #3: at::TensorIterator::binary_op(at::Tensor&, at::Tensor const&, at::Tensor const&) + 0x309 (0x7f8668089589 in /usr/local/lib/python3.5/dist-packages/torch/lib/libcaffe2.so)
frame #4: at::native::mul_out(at::Tensor&, at::Tensor const&, at::Tensor const&) + 0x95 (0x7f8667f284e5 in /usr/local/lib/python3.5/dist-packages/torch/lib/libcaffe2.so)
frame #5: at::native::mul(at::Tensor const&, at::Tensor const&) + 0x43 (0x7f8667f29553 in /usr/local/lib/python3.5/dist-packages/torch/lib/libcaffe2.so)
frame #6: at::TypeDefault::mul(at::Tensor const&, at::Tensor const&) const + 0x4d (0x7f86682a601d in /usr/local/lib/python3.5/dist-packages/torch/lib/libcaffe2.so)
frame #7: torch::autograd::VariableType::mul(at::Tensor const&, at::Tensor const&) const + 0x302 (0x7f8609808842 in /usr/local/lib/python3.5/dist-packages/torch/lib/libtorch.so.1)
frame #8: <unknown function> + 0x54df7b (0x7f86098d3f7b in /usr/local/lib/python3.5/dist-packages/torch/lib/libtorch.so.1)
frame #9: <unknown function> + 0x5c0e43 (0x7f8609946e43 in /usr/local/lib/python3.5/dist-packages/torch/lib/libtorch.so.1)
frame #10: torch::jit::InterpreterState::run(std::vector<c10::IValue, std::allocator<c10::IValue> >&) + 0x31 (0x7f8609942651 in /usr/local/lib/python3.5/dist-packages/torch/lib/libtorch.so.1)
frame #11: torch::jit::GraphExecutor::run(std::vector<c10::IValue, std::allocator<c10::IValue> >&) + 0x1c4 (0x7f8609924a74 in /usr/local/lib/python3.5/dist-packages/torch/lib/libtorch.so.1)
frame #12: ./MainRandomizeRnnDistance() [0x47fd12]
frame #13: ./MainRandomizeRnnDistance() [0x45da0d]
frame #14: __libc_start_main + 0xf0 (0x7f85ffef3830 in /lib/x86_64-linux-gnu/libc.so.6)
frame #15: ./MainRandomizeRnnDistance() [0x459fe9]
Expected behavior
The graph runs without raising an error.
Environment
Collecting environment information...
PyTorch version: 1.0.0a0+5510196
Is debug build: No
CUDA used to build PyTorch: 9.0.176
OS: Ubuntu 16.04.5 LTS
GCC version: (GCC) 8.2.0
CMake version: version 3.12.0
Python version: 3.5
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration: GPU 0: GeForce GTX 1070
Nvidia driver version: 410.78
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.3.1
/usr/lib/x86_64-linux-gnu/libcudnn_static_v7.a
Versions of relevant libraries:
[pip] Could not collect
[conda] Could not collect
Additional context
I think this can be solved by the following change (the commented-out line shows the version without the cast):
import torch

@torch.jit.script
def batch_sum(data, mask, dims):
    data = data * mask.type_as(data)
    for _ in range(dims.size(0)):
        data = data.sum(1)
    # Recreate the mask with the same dtype/device as the incoming mask.
    mask = torch.ones([data.size(0)], dtype=torch.uint8).type_as(mask)
    # mask = torch.ones([data.size(0)], dtype=torch.uint8)
    dims = dims[:0]  # empty tensor
    return data, mask, dims
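A short note on why the cast helps (this is my reading of the failure, not something confirmed in the report): inside the scripted function, torch.ones(..., dtype=torch.uint8) always allocates the new mask on the CPU, so without the cast the mask fed back into the GPU graph on the next run is a CPUByteType tensor, matching the CUDAByteType/CPUByteType mismatch in the trace above. .type_as(mask) copies both the dtype and the device of the incoming mask, for example:

    import torch

    mask = torch.ones(2, dtype=torch.uint8, device="cuda")
    torch.ones([3], dtype=torch.uint8).device                 # device(type='cpu')
    torch.ones([3], dtype=torch.uint8).type_as(mask).device   # device(type='cuda', index=0)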