
assert not step_t.is_cuda, "If capturable=False, state_steps should not be CUDA tensors. #80809

Closed
yqi19 opened this issue Jul 3, 2022 · 15 comments


@yqi19

yqi19 commented Jul 3, 2022

🐛 Describe the bug

Hi, congratulations on your amazing work.
When I try to resume training by loading a checkpoint, even though my GPUs are all working fine, I get this:

2022-07-03 06:06:18 - LOGS    - Exception occurred that interrupted the training. If capturable=False, state_steps should not be CUDA tensors.
If capturable=False, state_steps should not be CUDA tensors.

Traceback (most recent call last):                                                                           
  File "/home/yu/projects/mobilevit/ml-cvnets/engine/training_engine.py", line 682, in run
    train_loss, train_ckpt_metric = self.train_epoch(epoch)
  File "/home/yu/projects/mobilevit/ml-cvnets/engine/training_engine.py", line 353, in train_epoch
    self.gradient_scalar.step(optimizer=self.optimizer)
  File "/home/yu/anaconda3/envs/mobilevit/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 338, in step
    retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs)
  File "/home/yu/anaconda3/envs/mobilevit/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 285, in _may
be_opt_step
    retval = optimizer.step(*args, **kwargs)
  File "/home/yu/anaconda3/envs/mobilevit/lib/python3.8/site-packages/torch/optim/optimizer.py", line 109, in wrapper
    return func(*args, **kwargs)
  File "/home/yu/anaconda3/envs/mobilevit/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorat
e_context
    return func(*args, **kwargs)
  File "/home/yu/anaconda3/envs/mobilevit/lib/python3.8/site-packages/torch/optim/adamw.py", line 161, in step
    adamw(params_with_grad,
  File "/home/yu/anaconda3/envs/mobilevit/lib/python3.8/site-packages/torch/optim/adamw.py", line 218, in adamw
    func(params,
  File "/home/yu/anaconda3/envs/mobilevit/lib/python3.8/site-packages/torch/optim/adamw.py", line 259, in _single_tenso
r_adamw
    assert not step_t.is_cuda, "If capturable=False, state_steps should not be CUDA tensors."

Versions

PyTorch version: 1.12.0+cu102
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 16.04.7 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~16.04) 7.5.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.23

Python version: 3.9.12 (main, Jun  1 2022, 11:38:51)  [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-4.4.0-210-generic-x86_64-with-glibc2.23
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: 
GPU 0: NVIDIA TITAN Xp
GPU 1: NVIDIA TITAN Xp
GPU 2: NVIDIA TITAN Xp
GPU 3: NVIDIA TITAN Xp

Nvidia driver version: 465.19.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.0
[pip3] pytorchvideo==0.1.5
[pip3] torch==1.12.0
[pip3] torchvision==0.13.0
[conda] numpy                     1.23.0                   pypi_0    pypi
[conda] pytorchvideo              0.1.5                    pypi_0    pypi
[conda] torch                     1.12.0                   pypi_0    pypi
[conda] torchvision               0.13.0                   pypi_0    pypi
@jaried

jaried commented Jul 3, 2022

I'm also getting this problem.

import tianshou, gym, torch, numpy, sys
print(tianshou.__version__, gym.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
0.4.8 0.21.0 1.12.0+cu113 1.20.1 3.8.8 (default, Apr 13 2021, 15:08:03) [MSC v.1916 64 bit (AMD64)] win32

see:
thu-ml/tianshou#681

I set all my optimizers to the following, and they train normally. But what is the underlying problem? Does this setting have any effect on training?

optim.param_groups[0]['capturable'] = True
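
For reference, that line only sets the flag on the first param group. Here is a minimal sketch (the checkpoint path and dictionary key are placeholder assumptions) that applies it to every group after restoring the optimizer state:

import torch

model = torch.nn.Linear(10, 2).cuda()
optim = torch.optim.AdamW(model.parameters())

# Restore the optimizer state from a checkpoint (path and key are hypothetical).
checkpoint = torch.load('checkpoint.pt')
optim.load_state_dict(checkpoint['optimizer'])

# Set capturable on every param group, not just group 0, so the
# capturable=False assert cannot fire for any of them.
for group in optim.param_groups:
    group['capturable'] = True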

@L0SG
Contributor

L0SG commented Jul 4, 2022

Hi, I'm also facing the same issue when I try to load a checkpoint and resume model training on the latest PyTorch (1.12).

It seems to be related to a newly introduced parameter (capturable) for the Adam and AdamW optimizers. There are currently two workarounds:

  1. Forcing capturable = True after loading the checkpoint (as suggested above): optim.param_groups[0]['capturable'] = True. This seems to slow down model training by approx. 10% (YMMV depending on the setup).

  2. Reverting PyTorch back to a previous version (I have been using 1.11.0).

I'm wondering whether enforcing capturable = True may incur unwanted side effects.

@jaried

jaried commented Jul 4, 2022

> Hi, I'm also facing the same issue when I try to load a checkpoint and resume model training on the latest PyTorch (1.12).
>
> It seems to be related to a newly introduced parameter (capturable) for the Adam and AdamW optimizers. There are currently two workarounds:
>
>   1. Forcing capturable = True after loading the checkpoint (as suggested above): optim.param_groups[0]['capturable'] = True. This seems to slow down model training by approx. 10% (YMMV depending on the setup).
>   2. Reverting PyTorch back to a previous version (I have been using 1.11.0).
>
> I'm wondering whether enforcing capturable = True may incur unwanted side effects.

I'm also wondering whether forcing capturable=True would have unwanted side effects. I will also go back to torch 1.11. Thank you for your answer.

@amrosado

amrosado commented Jul 4, 2022

I'm also having this same error with pytorch=1.12 and needed to downgrade to pytorch=1.11.

@yqi19
Author

yqi19 commented Jul 4, 2022

> Hi, I'm also facing the same issue when I try to load a checkpoint and resume model training on the latest PyTorch (1.12).
>
> It seems to be related to a newly introduced parameter (capturable) for the Adam and AdamW optimizers. There are currently two workarounds:
>
>   1. Forcing capturable = True after loading the checkpoint (as suggested above): optim.param_groups[0]['capturable'] = True. This seems to slow down model training by approx. 10% (YMMV depending on the setup).
>   2. Reverting PyTorch back to a previous version (I have been using 1.11.0).
>
> I'm wondering whether enforcing capturable = True may incur unwanted side effects.

Thanks guys, I successfully resolved this!

@yqi19 yqi19 closed this as completed Jul 4, 2022
@yqi19 yqi19 reopened this Jul 4, 2022
@yqi19 yqi19 closed this as completed Jul 4, 2022
@1lint

1lint commented Jul 4, 2022

I also had this issue; my workaround was to comment out lines 202-204 in pytorch_lightning/trainer/connectors/checkpoint_connector.py:

# if self.trainer.state.fn == TrainerFn.FITTING:
#     # restore optimizers and schedulers state
#     self.restore_optimizers_and_schedulers()

To find the file, you can run the following (inside a Jupyter notebook):

import pytorch_lightning.trainer.connectors.checkpoint_connector as module_to_edit
!code {module_to_edit.__file__}

Another option is to manually load the checkpoint without the optimizer state. For example, to load just the saved model weights:

import torch

checkpoint = torch.load('/path/to/last.ckpt')
lightning_module.load_state_dict(checkpoint['state_dict'])

@amrosado

amrosado commented Jul 4, 2022

Personally, I feel this issue should remain open. I think it is an inconsistency between stable PyTorch versions, and I would appreciate being able to run my code base on future PyTorch versions.

@albanD
Collaborator

albanD commented Jul 5, 2022

Hi,

We're sorry to have introduced this regression. We will fix it in the upcoming 1.12.1 minor release.
If you want the fix earlier, you can follow the official instructions to get the nightly build of PyTorch!

pytorchmergebot pushed a commit that referenced this issue Jul 6, 2022
facebook-github-bot pushed a commit that referenced this issue Jul 8, 2022
Summary:
Finish fixing #80809

Pull Request resolved: #80881
Approved by: https://github.com/jbschlosser

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/9d20af50608b146fe1c3296210a05cd8e4c60af2

Reviewed By: mehtanirav

Differential Revision: D37687409

Pulled By: albanD

fbshipit-source-id: 4b899f76cbcb582cded8649e1166df90e73d78e9
@Xact-sniper

Xact-sniper commented Jul 16, 2022

I know this is closed, but I've encountered this issue multiple times on a few Colab-alikes, and this post is the first result that comes up. For anyone finding this in the future: instead of setting capturable = True, you can instead call .cpu() on the tensors with key "step" in the optimizer's state dictionary.

In my case, I found this cobbled-together bit of code to be sufficient:

import torch

def nested_dict_iter(dict_obj, indent=0):
    # Recursively pretty-print a (possibly nested) state dict, moving any
    # "step" counter tensors back to the CPU along the way.
    for key, value in dict_obj.items():
        if isinstance(value, dict):
            print(' ' * indent, key, ':', '{')
            nested_dict_iter(value, indent + 4)
            print(' ' * indent, '}')
        elif isinstance(value, list):
            nested_dict_iter(dict(zip(['list_' + str(i) for i in range(len(value))], value)), indent + 4)
        else:
            #############
            # relevant portion
            if isinstance(key, str) and 'step' in key and torch.is_tensor(value):
                try:
                    tst = value.cpu()
                    assert torch.all(tst == value)
                    dict_obj[key] = tst  # only replace once the CPU copy is verified
                except Exception:
                    pass
            print(' ' * indent, key, ':', value)

def iter_nested_dict(dict_obj):
    print('{')
    nested_dict_iter(dict_obj, 4)
    print('}')
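
A more targeted variant of the same idea (a sketch added for illustration, not verbatim from the thread): walk the optimizer's per-parameter state directly and move only the "step" counters back to the CPU.

import torch

def steps_to_cpu(optimizer):
    # optimizer.state maps each parameter to its per-parameter state dict;
    # for capturable=False, only the "step" counters must live on the CPU.
    for param_state in optimizer.state.values():
        step = param_state.get('step')
        if torch.is_tensor(step) and step.is_cuda:
            param_state['step'] = step.cpu()

Call steps_to_cpu(optim) once after optim.load_state_dict(...) and before the first optim.step().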

@franchesoni
franchesoni commented

> Hi,
>
> We're sorry to have introduced this regression. We will fix it in the upcoming 1.12.1 minor release. If you want the fix earlier, you can follow the official instructions to get the nightly build of PyTorch!

@albanD Could you explain what capturable is, what side effects it has (if any), and when to use it (or not)?

@albanD
Collaborator

albanD commented Jul 20, 2022

Hi,

This is meant to be used in conjunction with CUDA graphs. In particular, all ops must happen on the GPU for a CUDA graph to be able to "capture" all of them.
Passing the capturable flag ensures that this is the case, so that you can capture a whole forward/backward/optimizer step in a single CUDA graph.
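
To make that concrete, here is a minimal sketch of capturing a full training step in a single CUDA graph (the model, shapes, and warm-up count are placeholder assumptions; the pattern follows the torch.cuda.graph API):

import torch

model = torch.nn.Linear(16, 1).cuda()
opt = torch.optim.AdamW(model.parameters(), capturable=True)  # optimizer state stays on the GPU

static_input = torch.randn(8, 16, device='cuda')

# Warm up on a side stream before capture, as the CUDA graphs docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        opt.zero_grad(set_to_none=True)
        model(static_input).sum().backward()
        opt.step()
torch.cuda.current_stream().wait_stream(s)

# Capture one whole forward/backward/optimizer step into a single graph.
g = torch.cuda.CUDAGraph()
opt.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    loss = model(static_input).sum()
    loss.backward()
    opt.step()

# To train: copy new data into static_input, then g.replay() re-runs the captured step.

Without capturable=True, the optimizer's step counters live on the CPU, so the step could not be captured as GPU-only work.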

atalman pushed a commit to atalman/pytorch that referenced this issue Jul 21, 2022
atalman added a commit that referenced this issue Jul 21, 2022
Finish fixing #80809
Pull Request resolved: #80881
Approved by: https://github.com/jbschlosser

Co-authored-by: albanD <desmaison.alban@gmail.com>
@cliffordkleinsr

cliffordkleinsr commented Aug 2, 2022

I was training an ESRGAN, and my solution after a kernel timeout was to reload the model state and downgrade PyTorch to 1.11 with cu113.
If you are using Colab, run:
!pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html

If you are using the CUDA 11.6 binaries with PyTorch 1.12.0, then in a command prompt run:

pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116

@dongsiwen
dongsiwen commented

File "/root/autodl-tmp/DietNeRF-master/dietnerf/run_nerf.py", line 5, in <module>
    import clip_utils
ModuleNotFoundError: No module named 'clip_utils'

@zhilyzhang
zhilyzhang commented

> Hi,
>
> We're sorry to have introduced this regression. We will fix it in the upcoming 1.12.1 minor release. If you want the fix earlier, you can follow the official instructions to get the nightly build of PyTorch!

It works with the 1.12.1 release. Thank you.

@linminhtoo
linminhtoo commented

Hi all, without adding optim.param_groups[0]['capturable'] = True, I get the "If capturable=True" assertion error, and when I add this line I get the "If capturable=False" assertion error.

It is really puzzling. Any idea what's happening? I'm on torch 1.13.0+cu117 and I tried torch 2.0.0+cu117; both give the same problem. The optimizer was trained on a machine with torch 1.10.0; could this be the root cause? It's really difficult for me to install torch 1.10.0 on my current machine.
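
One way to narrow this down (a debugging sketch, not a fix): print where each "step" tensor lives and how capturable is set per param group, so you can see which assert should fire.

import torch

def inspect_optimizer(optim):
    # Report the capturable flag for every param group...
    for i, group in enumerate(optim.param_groups):
        print(f"group {i}: capturable={group.get('capturable', False)}")
    # ...and the device of every per-parameter "step" counter.
    for state in optim.state.values():
        step = state.get('step')
        if torch.is_tensor(step):
            print(f"step tensor on device: {step.device}")

If the devices turn out to be mixed (some step tensors on CPU, some on CUDA), either assert can fire depending on the parameter, which could explain seeing both errors.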

alexanderwerning added a commit to fgnt/padertorch that referenced this issue Nov 7, 2023
Resuming from a checkpoint in torch==1.12.0 is broken; this was fixed in torch==1.12.1. This workaround allows loading checkpoints with version 1.12.0 as well. In pytorch/pytorch#80809 a 10% slowdown was reported, which I did not observe.
alexanderwerning added a commit to alexanderwerning/padertorch that referenced this issue Feb 12, 2024