
inception_v3 of vision 0.3.0 does not fit in DataParallel of torch 1.1.0 #1048

Open

QizhongYao opened this issue Jun 25, 2019 · 9 comments

@QizhongYao

Environment:
Python 3.5
torch 1.1.0
torchvision 0.3.0

Reproducible example:
import torch
import torchvision
model = torchvision.models.inception_v3().cuda()
model = torch.nn.DataParallel(model, [0, 1])
x = torch.rand((8, 3, 299, 299)).cuda()
model.forward(x)

Error:

Traceback (most recent call last):
File "", line 1, in
File "env/lib/python3.5/site-packages/torch/nn/modules/module.py", line 493, in call
result = self.forward(*input, **kwargs)
File "env/lib/python3.5/site-packages/torch/nn/parallel/data_parallel.py", line 153, in forward
return self.gather(outputs, self.output_device)
File "/env/lib/python3.5/site-packages/torch/nn/parallel/data_parallel.py", line 165, in gather
return gather(outputs, output_device, dim=self.dim)
File "/env/lib/python3.5/site-packages/torch/nn/parallel/scatter_gather.py", line 67, in gather
return gather_map(outputs)
File "env/lib/python3.5/site-packages/torch/nn/parallel/scatter_gather.py", line 62, in gather_map
return type(out)(map(gather_map, zip(*outputs)))
TypeError: __new__() missing 1 required positional argument: 'aux_logits'

I guess the error occurs because the output of inception_v3 was changed from tuple to namedtuple.
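
For anyone curious, a minimal sketch of why the gather step fails on a namedtuple (a standalone illustration, not torchvision's actual class):

from collections import namedtuple

# Stand-in for torchvision's Inception output type (illustrative only)
InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])

outputs = [InceptionOutputs(1, 2), InceptionOutputs(3, 4)]  # one result per GPU
out = outputs[0]

try:
    # gather_map rebuilds the container like this; for a namedtuple the single
    # map object counts as one positional argument, so 'aux_logits' is missing.
    type(out)(map(list, zip(*outputs)))
except TypeError as e:
    print(e)  # __new__() missing 1 required positional argument: 'aux_logits'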

@fmassa
Member

fmassa commented Jun 25, 2019

Yes, that's probably the reason.

I believe we have three options:

  1. remove the namedtuple and use a plain tuple, as before, basically reverting some of the changes in "make auxiliary heads in pretrained models optional" #828
  2. fix PyTorch DataParallel to support namedtuples https://github.com/pytorch/pytorch/blob/c8b5f1d2f8f31781e664917f132af31a9abf9cbd/torch/nn/parallel/scatter_gather.py#L5-L31
  3. encourage the use of DistributedDataParallel instead, and do nothing.

I'd vote for option number 2.

cc'ing @TheCodez and @Separius, who commented on / sent the aforementioned PR initially. What are your thoughts here?
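
For reference, a rough sketch of what option 2 might look like inside gather_map in torch/nn/parallel/scatter_gather.py (Gather, target_device, and dim are already in scope in that file; this is just an illustration, not the actual patch):

    def gather_map(outputs):
        out = outputs[0]
        if isinstance(out, torch.Tensor):
            return Gather.apply(target_device, dim, *outputs)
        if out is None:
            return None
        # Namedtuples must be rebuilt from unpacked fields, because calling
        # type(out)(iterable) passes a single positional argument.
        if isinstance(out, tuple) and hasattr(out, '_fields'):
            return type(out)(*map(gather_map, zip(*outputs)))
        return type(out)(map(gather_map, zip(*outputs)))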

@TheCodez
Contributor

@fmassa I agree, option 2 would be the best way to avoid problems in the future.

@Separius
Contributor

@fmassa yeah, the second option makes the most sense.

@YongWookHa

The problem still seems to be there.
I've come up with a little trick to work around this unsupported-namedtuple problem.

It's a kind of mixed solution of @fmassa's options 1 and 2.
It doesn't change inception_v3 in torchvision.models, but converts the namedtuple to a dict in the parallel (gather) part.

Change the gather function in the scatter_gather.py file to the following:

def gather(outputs, target_device, dim=0):
    r"""
    Gathers tensors from different GPUs on a specified device
      (-1 means the CPU).
    """
    def gather_map(outputs):
        def isnamedtupleinstance(x):
            t = type(x)
            b = t.__bases__
            if len(b) != 1 or b[0] != tuple: return False
            f = getattr(t, '_fields', None)
            if not isinstance(f, tuple): return False
            return all(type(n)==str for n in f)
            
        out = outputs[0]
        if isinstance(out, torch.Tensor):
            return Gather.apply(target_device, dim, *outputs)
        if out is None:
            return None
            
        if isnamedtupleinstance(out):
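            # Convert namedtuple outputs (e.g. torchvision's Inception output)
            # to plain dicts so the dict branch below can gather them field by field.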
            outputs = [dict(out._asdict()) for out in outputs]
            out = outputs[0]

        if isinstance(out, dict):
            if not all((len(out) == len(d) for d in outputs)):
                raise ValueError('All dicts must have the same number of keys')
            return type(out)(((k, gather_map([d[k] for d in outputs]))
                              for k in out))           
        
        return type(out)(map(gather_map, zip(*outputs)))

    # Recursive function calls like this create reference cycles.
    # Setting the function to None clears the refcycle.
    try:
        res = gather_map(outputs)
    finally:
        gather_map = None
    return res

You can then get the result of the inception_v3 model as below.

outputs, aux_outputs = self.model(imgs).values()

Don't forget to add .values() at the end.

I know this is not the best solution, but I hope it helps someone for now.

@soumendukrg

soumendukrg commented Nov 15, 2019

I tried out your solution @YongWookHa; however, now I am getting an error when calculating the loss function.
Error:

File "/home/min/a/ghosh37/distiller/distiller/apputils/image_classifier.py", line 588, in train loss = criterion(output, target) File "/home/min/a/ghosh37/distiller/env/lib64/python3.6/site-packages/torch/nn/modules/module.py", line 541, in __call__ result = self.forward(*input, **kwargs) File "/home/min/a/ghosh37/distiller/env/lib64/python3.6/site-packages/torch/nn/modules/loss.py", line 916, in forward ignore_index=self.ignore_index, reduction=self.reduction) File "/home/min/a/ghosh37/distiller/env/lib64/python3.6/site-packages/torch/nn/functional.py", line 2009, in cross_entropy return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction) File "/home/min/a/ghosh37/distiller/env/lib64/python3.6/site-packages/torch/nn/functional.py", line 1317, in log_softmax ret = input.log_softmax(dim) AttributeError: 'dict_values' object has no attribute 'log_softmax'

EDIT: I figured out the problem. Was an issue with dict.

@YongWookHa

> I tried out your solution @YongWookHa; however, now I am getting an error when calculating the loss function:
> AttributeError: 'dict_values' object has no attribute 'log_softmax'
>
> EDIT: I figured out the problem. Was an issue with dict.

I think you forgot to add .values() when you get outputs from your inception model.
So, have you solved the problem?

@soumendukrg

Yes, I did add .values(), but I was assigning it to a single output variable instead of output, aux_output, so when the loss function was computed on a dict instead of a tensor, I got the error.

Thanks, your method saved me hours of training time. Earlier, I could train Inception only on a single GPU; now, by modifying the PyTorch file with your code, I am able to train on more than one GPU.
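
For clarity, a small self-contained illustration of that mistake (using a plain dict of random tensors as a stand-in for the gathered Inception output):

import torch

# Stand-in for what the patched gather returns: a dict with main and aux logits
result = {'logits': torch.randn(8, 1000), 'aux_logits': torch.randn(8, 1000)}
target = torch.randint(0, 1000, (8,))
criterion = torch.nn.CrossEntropyLoss()

bad = result.values()            # WRONG: a dict_values view, not a tensor
# criterion(bad, target)         # AttributeError: 'dict_values' object has no attribute 'log_softmax'

outputs, aux_outputs = result.values()   # RIGHT: unpack both tensors
loss = criterion(outputs, target) + 0.4 * criterion(aux_outputs, target)
print(loss.item())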

nzmora added a commit to IntelLabs/distiller that referenced this issue Apr 27, 2020
* Merge pytorch 1.3 commits

This PR is a fix for issue #422.

1. ImageNet models usually use input size [batch, 3, 224, 224], but all Inception models require an input image size of [batch, 3, 299, 299].

2. Inception models have auxiliary branches which contribute to the loss only during training.  The reported classification loss only considers the main classification loss.

3. Inception_V3 normalizes the input inside the network itself.  More details can be found in the comments of @soumendukrg's PR #425.

NOTE: Training using Inception_V3 is only possible on a single GPU as of now. This issue talks about this problem. I have checked and this problem persists in torch 1.3.0:
[inception_v3 of vision 0.3.0 does not fit in DataParallel of torch 1.1.0 #1048](pytorch/vision#1048)

Co-authored-by: Neta Zmora <neta.zmora@intel.com>
@sanka4rea

sanka4rea commented Aug 6, 2020

I tried out your solution @YongWookHa, but got an error as shown below:

train Loss: 0.9664 Acc: 0.5738

Traceback (most recent call last):
File "/home/xxx/anaconda3/envs/torch0721/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3343, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "", line 153, in
num_epochs=25, is_inception=True)
File "", line 91, in train_model
outputs, aux_outputs = model(inputs).values()
RuntimeError: Could not run 'aten::values' with arguments from the 'CUDA' backend. 'aten::values' is only available for these backends: [SparseCPU, SparseCUDA, Autograd, Profiler, Tracer].

Could you please give me some suggestions?

Edit: fixed. As there is no need to use the aux classifiers for inference, I changed the code to:

if phase == 'train':
    outputs, aux_outputs = model(inputs).values()
    loss1 = criterion(outputs, labels)
    loss2 = criterion(aux_outputs, labels)
    loss = loss1 + 0.4 * loss2
else:
    outputs = model(inputs)
    loss = criterion(outputs, labels)

Thanks!

OZA15015 pushed a commit to OZA15015/pruning that referenced this issue Sep 6, 2020
@QiangZiBro

I used apex.amp with inception_v3 and got the same problem:

  • APEX 0.1
  • torch 1.13
  • torchvision 0.13
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/apex/amp/_initialize.py", line 198, in new_fwd
    return applier(output, output_caster)
  File "/opt/conda/lib/python3.8/site-packages/apex/amp/_initialize.py", line 51, in applier
    return type(value)(applier(v, fn) for v in value)
TypeError: __new__() missing 1 required positional argument: 'aux_logits'

To solve this problem, I replaced the namedtuple with a function that returns a plain tuple, and it works:

torchvision.models.inception.InceptionOutputs = lambda a, b: (a, b)
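
A minimal sketch of applying that monkey-patch before building the model (assuming a torchvision version that exposes the namedtuple as torchvision.models.inception.InceptionOutputs; older releases used a different, private name):

import torch
import torchvision

# Replace the namedtuple constructor with a plain-tuple factory, so that
# DataParallel's gather (and apex) can rebuild the output positionally.
torchvision.models.inception.InceptionOutputs = lambda a, b: (a, b)

model = torchvision.models.inception_v3().cuda()
model = torch.nn.DataParallel(model, device_ids=[0, 1])
model.train()

x = torch.rand(8, 3, 299, 299).cuda()
outputs, aux_outputs = model(x)  # plain tuple in training mode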

michaelbeale-IL pushed a commit to IntelLabs/distiller that referenced this issue Apr 24, 2023
fangvv pushed a commit to fangvv/distiller that referenced this issue May 23, 2023