nn.parallel.gather does not accept scalars (0-dim tensors) of v0.4 #6983

Closed

L0SG opened this issue Apr 26, 2018 · 8 comments
Assignees: ailzhang
Labels: todo — Not as important as medium or high priority tasks, but we will work on these.

@L0SG (Contributor) commented Apr 26, 2018

Issue description

Since v0.4, loss functions return a scalar (0-dim tensor), and manually gathering these scalar losses raises an error, as in the example below.

Unsqueezing the scalar losses back into 1-dim vectors (as in previous versions) works, but is this the intended behavior of nn.parallel.gather?

This multi-GPU code pattern is the one used in the Annotated Transformer implementation.

Code example

import torch.nn as nn
import torch

# GPUs to use
devices = [0, 1, 2, 3]

# toy feed-forward net
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 5)
        self.fc3 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

# define random data
random_input = torch.randn((4, 10))
random_target = torch.randn((4))

net = Net().cuda()

# replicate nets, scatter inputs and parallel_apply
replicas = nn.parallel.replicate(net, devices)
random_input_scatter = nn.parallel.scatter(random_input, devices)
replicas = replicas[:len(random_input_scatter)]
outputs = nn.parallel.parallel_apply(replicas, random_input_scatter)

# replicate losses, scatter targets, zip output-target pairs
criterion = nn.MSELoss()
criterion = nn.parallel.replicate(criterion, devices)
random_target_scatter = nn.parallel.scatter(random_target, devices)
output_target_pairs = list(zip(outputs, random_target_scatter))

# this results in scalar (0-dim tensor) losses with v0.4
loss = nn.parallel.parallel_apply(criterion, output_target_pairs)
# gathering 0-dim tensors raises an error at line 54 of torch/nn/parallel/_functions.py:
#   ctx.input_sizes = tuple(map(lambda i: i.size(ctx.dim), inputs))
loss_gather = nn.parallel.gather(loss, target_device=devices[0])

# unsqueezing the scalar losses into vectors (as in previous versions) works as intended
"""
for idx in range(len(loss)):
    loss[idx] = loss[idx].unsqueeze(0)
loss_gather = nn.parallel.gather(loss, target_device=devices[0])
"""

Traceback (most recent call last):
  File "gather_bug.py", line 43, in <module>
    loss_gather = nn.parallel.gather(loss, target_device=devices[0])
  File "/home/tkdrlf9202/anaconda3/envs/tkdrlf9202_p36/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 68, in gather
    return gather_map(outputs)
  File "/home/tkdrlf9202/anaconda3/envs/tkdrlf9202_p36/lib/python3.6/site-packages/torch/nn/parallel/scatter_gather.py", line 55, in gather_map
    return Gather.apply(target_device, dim, *outputs)
  File "/home/tkdrlf9202/anaconda3/envs/tkdrlf9202_p36/lib/python3.6/site-packages/torch/nn/parallel/_functions.py", line 54, in forward
    ctx.input_sizes = tuple(map(lambda i: i.size(ctx.dim), inputs))
  File "/home/tkdrlf9202/anaconda3/envs/tkdrlf9202_p36/lib/python3.6/site-packages/torch/nn/parallel/_functions.py", line 54, in <lambda>
    ctx.input_sizes = tuple(map(lambda i: i.size(ctx.dim), inputs))
RuntimeError: dimension specified as 0 but tensor has no dimensions

System Info

PyTorch version: 0.4.0
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 16.04.4 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.9) 5.4.0 20160609
CMake version: version 3.5.1

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration:
GPU 0: TITAN Xp
GPU 1: TITAN Xp
GPU 2: TITAN Xp
GPU 3: TITAN Xp

Nvidia driver version: 390.30
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.6.0.21
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_static.a
/usr/local/MATLAB/R2017b/bin/glnxa64/libcudnn.so.5.1.5
/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudnn.so.6.0.21
/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudnn_static.a

Versions of relevant libraries:
[pip] msgpack-numpy (0.4.1)
[pip] numpy (1.13.3)
[pip] torch (0.4.0)
[pip] torchfile (0.1.0)
[pip] torchnet (0.0.1)
[pip] torchtext (0.2.3)
[pip] torchvision (0.2.1)
[conda] cuda90 1.0 h6433d27_0 pytorch
[conda] pytorch 0.4.0 py36_cuda9.0.176_cudnn7.1.2_1 [cuda90] pytorch
[conda] torchfile 0.1.0
[conda] torchnet 0.0.1
[conda] torchtext 0.2.3
[conda] torchvision 0.2.1 py36_1 pytorch

  • PyTorch or Caffe2: PyTorch
  • How you installed PyTorch (conda, pip, source): conda
  • Build command you used (if compiling from source): NA
  • OS: ubuntu 16.04
  • PyTorch version: 0.4
  • Python version: 3.6
  • CUDA/cuDNN version: 9.0.176/7.0.5
  • GPU models and configuration: 4x Titan Xp
  • GCC version (if compiling from source): N/A
  • CMake version: 3.5.1
  • Versions of any other relevant libraries: N/A
@apaszke (Contributor) commented Apr 26, 2018

That's a bug, we should handle this. Thanks for the report!

@zou3519 added the bug label on Apr 26, 2018
@gchanan (Contributor) commented Apr 26, 2018

This has the same problem as torch.cat: for non-scalar tensors x and y with shapes (..., X, ...) and (..., Y, ...), gather gives a result with shape (..., X + Y, ...), i.e. the result has the same dimensionality as the inputs. But this doesn't work with scalars because you can't have a 0-dimensional tensor with multiple elements.

With torch.cat we just gave an error message (which is what NumPy does); if someone really wants a 1-dimensional result, they can unsqueeze, which is clear from the error message. Should we do the same here? The alternative is to just give a 1-dimensional result, but that's somewhat non-obvious and inconsistent with torch.cat.

Thoughts?
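
For illustration, a minimal sketch of the torch.cat behavior described above (arbitrary values, not from the original report):

import torch

# cat on 1-dim tensors concatenates along dim 0: (2,) and (3,) -> (5,)
a, b = torch.zeros(2), torch.zeros(3)
print(torch.cat([a, b]).shape)  # torch.Size([5])

# cat on 0-dim tensors raises an error, since a 0-dim tensor cannot
# hold multiple elements; unsqueezing to 1-dim first works
s, t = torch.tensor(1.0), torch.tensor(2.0)
# torch.cat([s, t])  # RuntimeError
print(torch.cat([s.unsqueeze(0), t.unsqueeze(0)]))  # tensor([1., 2.])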

@gchanan (Contributor) commented Apr 26, 2018

Another possibility would be to change the API so that dp.gather returns a list/tuple; from there you could just add the results (works for everything) or concatenate them (works for everything except scalars), but that would push the problem onto other APIs instead of the data-parallel ones.
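
For example, if gather returned a list of per-device losses, summing would work even for 0-dim tensors (a sketch with made-up values):

import torch

per_device_losses = [torch.tensor(0.5), torch.tensor(0.7)]  # hypothetical gather output
total = sum(per_device_losses)  # tensor(1.2000); addition works regardless of dimensionality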

@gchanan (Contributor) commented Apr 26, 2018

Yet another possibility: we could use torch.stack semantics, which are well defined for scalars.
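
For reference, torch.stack adds a new leading dimension, so 0-dim inputs are handled naturally (made-up values):

import torch

losses = [torch.tensor(0.5), torch.tensor(0.7)]
stacked = torch.stack(losses)  # tensor([0.5000, 0.7000]), shape (2,)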

@gchanan (Contributor) commented Apr 26, 2018

CC @apaszke @teng-li @ailzhang

@apaszke (Contributor) commented Apr 30, 2018

Hmm, good point. I don't think we should be adding a flag like this. It would be much simpler to check that no inputs are scalars and raise a readable error if any are.
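
A minimal sketch of such a check (hypothetical helper, not the actual patch):

def _check_no_scalars(inputs):
    # hypothetical pre-check: fail with a readable message instead of
    # erroring inside i.size(dim) on 0-dim tensors
    for i in inputs:
        if i.dim() == 0:
            raise ValueError(
                "gather got a 0-dimensional tensor; unsqueeze it to 1 "
                "dimension first, e.g. loss.unsqueeze(0)")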

@ailzhang self-assigned this on May 14, 2018
@zou3519 added the todo label on May 14, 2018
@ssnl (Collaborator) commented Jun 1, 2018

@ailzhang this might be fixed by #7973

@ssnl (Collaborator) commented Jun 1, 2018

Indeed fixed by #7973. However, your script should be changed a bit. In particular, random_input_scatter = nn.parallel.scatter(random_input, devices) should be

random_input_scatter = nn.parallel.scatter([random_input], devices)

because otherwise each device sees a bare tensor in parallel_apply, and the call below then unpacks its first dimension into separate arguments:

output = module(*input, **kwargs)

Accordingly, your target should also be changed to:

random_target = torch.randn((4, 1))

Without these changes, it only works when the batch size equals the number of devices.

However, considering how subtle this is, I think it makes sense to let parallel_apply accept a list of tensors as inputs. I'll send out a PR to update that.
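
Putting both changes together, the relevant lines of the original script would read as follows (a sketch following the suggestions above; untested):

# give the target a trailing dimension so MSELoss sees matching (N, 1) shapes
random_target = torch.randn((4, 1))

# wrap the input in a list so each replica receives its chunk as a single
# positional argument instead of having the first dimension unpacked
random_input_scatter = nn.parallel.scatter([random_input], devices)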
