Cannot access data pointer of Tensor that doesn't have storage #2652

Closed
rabeehk opened this issue Nov 27, 2020 · 16 comments
@rabeehk

rabeehk commented Nov 27, 2020

Hi,
I am running code on PyTorch XLA 1.7 with Python 3.7, and I am getting the following error. It happens on the line that computes the loss; the same code runs fine on GPU. To give more context, I am using the seq2seq model from the HuggingFace repo, but I modified their code to add adapter layers, then set requires_grad = False on all model parameters and kept only some adapter layers as trainable parameters to finetune. Thank you for the help.

92c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-70d903b9912b3e3e.arrow
  0%|          | 0/160 [00:00<?, ?it/s], 20966.66ex/s]
Exception in device=TPU:3: Cannot access data pointer of Tensor that doesn't have storage
Traceback (most recent call last):
  File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 324, in _start_fn
    fn(gindex, *args)
  File "/home/rabeeh/internship/seq2seq/finetune_t5_trainer.py", line 227, in _mp_fn
    main()
  File "/home/rabeeh/internship/seq2seq/finetune_t5_trainer.py", line 164, in main
    model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
  File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/transformers/trainer.py", line 775, in train
    tr_loss += self.training_step(model, inputs)
  File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/transformers/trainer.py", line 1126, in training_step
    loss.backward()
  File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/autograd/__init__.py", line 132, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
[The same traceback was raised on TPU:4, TPU:6, TPU:0, TPU:7, TPU:5, TPU:1, and TPU:2.]
Downloading and preparing dataset glue/mrpc (download: 1.43 MiB, generated: 1.43 MiB, post-processed: Unknown size, total: 2.85 MiB) to /home/rabeeh/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4...
Dataset glue downloaded and prepared to /home/rabeeh/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4. Subsequent calls will reuse this data.
  0%|          | 0/160 [01:10<?, ?it/s]Traceback (most recent call last):
  File "/home/rabeeh//internship/seq2seq/xla_spawn.py", line 106, in <module>
    main()
  File "/home/rabeeh//internship/seq2seq/xla_spawn.py", line 90, in main
    xmp.spawn(mod._mp_fn, args=(), nprocs=args.num_cores)
  File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 395, in spawn
    start_method=start_method)
  File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 157, in start_processes
    while not context.join():
  File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 112, in join
    (error_index, exitcode)
Exception: process 4 terminated with exit code 17

I am happy to provide the code to reproduce the error. To explain more,
I am defining a model like this:

class Model(nn.Module):
    def __init__(self, config, adapter_config):
        super().__init__()
        # ... self.all_parameters: the rest of the model's parameters ...
        self.add_adapters = adapter_config is not None
        if self.add_adapters:
            self.adapter = Adapter(config.d_model, adapter_config)

and then in the main loop I set requires_grad = False on every parameter of all_parameters except the ones inside the Adapter (a sketch of this step is shown below).
I call model.to(device) and then compute the loss, so the Adapter parameters are the only ones that can receive gradients.
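
A minimal sketch of that freezing step, assuming the adapter parameters can be identified by "adapter" appearing in their names and that model and inputs come from the surrounding training code (the real selection logic may differ):

import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Freeze everything except the adapter parameters.
for name, param in model.named_parameters():
    param.requires_grad = "adapter" in name

model.to(device)
outputs = model(**inputs)  # assumes inputs include labels so a loss is returned
loss = outputs[0]          # HuggingFace-style convention: first element is the loss
loss.backward()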

Thank you.

@rabeehk
Author

rabeehk commented Nov 28, 2020

Hi,
I do not think the code itself is the issue, since it works when I get a local machine and use the 1.7 docker image. On my job scheduler I submit the jobs using the conda environment for PyTorch XLA 1.7; could you update this environment? It might have some bugs.
I also tried the nightly version, and it freezes during training.
Thank you.

@tmabraham

I am also having a similar issue when trying to write and run a Kaggle Kernel to train an EfficientNet. I also get the error at the loss.backward() step, and I am unsure what is causing it.

cc: @dlibenzi @zcain117

@JackCaoG
Collaborator

PyTorch/XLA tensors do not have storage and will not try to access it. This might be a case where the default kernel was used for one of the ops, and that implementation tried to access the storage. @ailzhang Could you take a look?
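
For what it's worth, the storage-free nature of XLA tensors can be seen directly; as far as I understand, asking an XLA tensor for its data pointer raises the same error, because there is no underlying storage to point into (a small illustrative sketch, not a repro of this specific bug):

import torch
import torch_xla.core.xla_model as xm

t = torch.zeros(2, 2, device=xm.xla_device())
t.data_ptr()  # expected to raise: "Cannot access data pointer of Tensor that doesn't have storage"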

@tmabraham

@JackCaoG I was working on a regular image classification problem with a standard cross-entropy loss. I am not using any custom function (unlike #1667). If I use an older version of PyTorch XLA, it seems to work fine. Could there have been some recent changes that have led to this issue?

@JackCaoG
Collaborator

@tmabraham My guess is that PyTorch changed one of its default kernels for some op so that it accesses the data pointer of the tensor. We see this kind of error come up occasionally, and we usually ask the PyTorch folks to fix the default kernel so that it does not access storage, since it should not assume every backend has storage.
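
One way to narrow down which op is hitting a default (fallback) kernel is the torch_xla metrics report, which lists aten:: counters for ops that were not lowered to XLA (a debugging sketch, not a diagnosis of this particular bug):

import torch_xla.debug.metrics as met

# After running a training step or two, ops that fell back to the
# default CPU implementation show up as "aten::<op_name>" counters.
print(met.metrics_report())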

@ailzhang
Contributor

ailzhang commented Dec 1, 2020

Hi @tmabraham, would you mind providing a minimal repro script? That would help us locate the bug quicker, thanks a lot!

@ailzhang ailzhang self-assigned this Dec 2, 2020
@tmabraham

tmabraham commented Dec 3, 2020

@ailzhang Upon further investigation, I think this bug is specific to EfficientNet (at least the one implemented in timm).

Here is a Kaggle Kernel where I replaced the Colab example with timm's EfficientNet model.

Since this is a different scenario from the seq2seq model that the original poster had issues with, should I open a new issue?
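
Roughly, the repro boils down to something like the following (a sketch assuming timm's efficientnet_b0 on a single XLA device, not the actual Kaggle Kernel):

import timm
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = timm.create_model("efficientnet_b0", num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()

images = torch.randn(8, 3, 224, 224, device=device)
labels = torch.randint(0, 10, (8,), device=device)

loss = criterion(model(images), labels)
loss.backward()  # this is where the storage error shows up
xm.mark_step()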

@tmabraham

@ailzhang Just wanted to follow up regarding whether I should make a separate issue for the EfficientNet bug.

@eedalong
Contributor

@rabeehk I ran into this problem when training an ASR model: there was a CUDA op that is not supported in XLA, so when execution reached that op, the GPU could not access the data through the tensor's pointer. What I did was move the tensor to the cuda device before that op and move the op's output back to the XLA device afterwards, and that fixed the problem. A rough sketch of this workaround is below.
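
Something along these lines, where unsupported_cuda_op stands in for whatever op is not lowered to XLA (hypothetical name, and it assumes a GPU is available alongside the XLA device):

import torch_xla.core.xla_model as xm

xla_device = xm.xla_device()

def run_on_cuda_then_back(x_xla):
    # Move the input off the XLA device, run the op XLA cannot handle,
    # then move the result back so the rest of the graph stays on XLA.
    x_cuda = x_xla.to("cuda")
    y_cuda = unsupported_cuda_op(x_cuda)  # hypothetical unsupported op
    return y_cuda.to(xla_device)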

@rabeehk
Author

rabeehk commented Dec 11, 2020 via email

@tmabraham

@ailzhang Sorry for repeatedly tagging you, but I am hoping this issue can get resolved, because I would love to be able to train EfficientNet models from the timm package with PyTorch XLA.

@ailzhang
Contributor

Hi @tmabraham, sorry for the late reply! I took a look, and pytorch/pytorch#49439 should fix the issue. It will take some time to review and land the fix, but I'll let you know once it's ready for you to try out in a new nightly!

@tmabraham

Sounds great, thanks for the update!

@ailzhang
Contributor

@tmabraham The fix landed yesterday, so I believe today's nightly should work. Would you mind giving it a try? Thanks!

@tmabraham

@ailzhang Thanks for letting me know! It looks like it is running now, but the loss and accuracy are way off and it runs quite slowly. Maybe that should be a different issue, though.

@ailzhang
Contributor

@tmabraham Yeah, feel free to open a new issue with the perf report! I'm going to close this issue for now since it's fixed. Thanks for the report!
