
nn.Embedding with max_norm shows unstable behavior and sometimes causes a runtime error #26596

Closed
dschaehi opened this issue Sep 21, 2019 · 10 comments
Assignees
Labels
high priority · module: nn (Related to torch.nn) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@dschaehi

dschaehi commented Sep 21, 2019

🐛 Bug

An nn.Embedding object with max_norm set to True causes a RuntimeError that is hard to track.

To Reproduce

The following code causes a RuntimeError. The error can be avoided by removing the max_norm feature or by swapping Line a and Line b in the code.

import torch
import torch.nn as nn

n, d, m = 3, 5, 7
batch_size = 11

embedding = nn.Embedding(n, d, max_norm=True)
W = torch.randn((m, d), requires_grad=True)
optimizer = torch.optim.Adam(list(embedding.parameters()) + [W], lr=1e-3)

optimizer.zero_grad()
idx = torch.tensor([1, 2])

a = embedding.weight @ W.t()  # Line a 
b = embedding(idx) @ W.t()    # Line b

out = (a.unsqueeze(0) + b.unsqueeze(1))
loss = out.sigmoid().prod()
loss.backward()
optimizer.step()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-472-103ef18503d8> in <module>
     17 out = (a.unsqueeze(0) + b.unsqueeze(1))
     18 loss = out.sigmoid().prod()
---> 19 loss.backward()
     20 optimizer.step()

~/miniconda3/envs/kg/lib/python3.7/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    116                 products. Defaults to ``False``.
    117         """
--> 118         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    119 
    120     def register_hook(self, hook):

~/miniconda3/envs/kg/lib/python3.7/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     91     Variable._execution_engine.run_backward(
     92         tensors, grad_tensors, retain_graph, create_graph,
---> 93         allow_unreachable=True)  # allow_unreachable flag
     94 
     95 

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [3, 5]] is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

Expected behavior

There shouldn't be any error when running the code above.
Strangely, there is no RuntimeError when Line a and Line b are swapped. This is something that has to be investigated.

Environment

PyTorch version: 1.2.0
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 18.04.3 LTS
GCC version: (Homebrew gcc 5.5.0_4) 5.5.0
CMake version: Could not collect

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.1.243
GPU models and configuration: GPU 0: GeForce GTX 1080 Ti
Nvidia driver version: 430.26
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.3

Versions of relevant libraries:
[pip] botorch==0.1.3
[pip] gpytorch==0.3.5
[pip] numpy==1.17.2
[pip] torch==1.2.0
[pip] torchvision==0.4.0a0+6b959ee
[conda] blas 1.0 mkl
[conda] botorch 0.1.3 pypi_0 pypi
[conda] gpytorch 0.3.5 pypi_0 pypi
[conda] libblas 3.8.0 12_mkl conda-forge
[conda] libcblas 3.8.0 12_mkl conda-forge
[conda] liblapack 3.8.0 12_mkl conda-forge
[conda] mkl 2019.4 243
[conda] pytorch 1.2.0 py3.7_cuda10.0.130_cudnn7.6.2_0 pytorch
[conda] torchvision 0.4.0 py37_cu100 pytorch

Additional context

cc @ezyang @gchanan @zou3519 @jlin27 @albanD @mruberry

@VitalyFedyunin added the module: docs (Related to our documentation, both in docs/ and docblocks), module: nn (Related to torch.nn), and triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) labels Sep 24, 2019
@VitalyFedyunin
Contributor

Per documentation (of functional.embedding):

max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is renormalized to have norm max_norm. Note: this will modify weight in-place.

So we need to update https://pytorch.org/docs/stable/nn.html#embedding accordingly.
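
A minimal sketch of the documented behavior (the weight scaling is only there to make the clamping visible; the sizes are arbitrary):

import torch
import torch.nn as nn

emb = nn.Embedding(3, 5, max_norm=1.0)
with torch.no_grad():
    emb.weight.mul_(10.0)          # blow up the row norms so the clamping is visible

print(emb.weight.norm(dim=1))      # all rows well above 1.0
emb(torch.tensor([0, 1]))          # the forward pass renormalizes the looked-up rows in place
print(emb.weight.norm(dim=1))      # rows 0 and 1 are now clamped to 1.0, row 2 is untouched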

@dschaehi
Author

dschaehi commented Sep 24, 2019

I see. Could you explain why swapping Line a and Line b then does not lead to the same error?

By the way, setting max_norm to True allows for backpropagation in the normal case; otherwise it wouldn't make sense to use max_norm at all in practice. Perhaps what matters in my example above is when the normalization step happens, i.e., whether it happens immediately after backpropagation or during backpropagation.

@VitalyFedyunin
Contributor

Because when you call embedding(idx), it actually calls F.embedding, which (with max_norm provided) modifies embedding.weight in place.
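
A quick way to see this, using the internal tensor version counter that the "is at version 2; expected version 1" message in the traceback refers to (a sketch, relying on the undocumented _version attribute):

import torch
import torch.nn as nn

emb = nn.Embedding(3, 5, max_norm=1.0)
v_before = emb.weight._version     # internal counter bumped by in-place ops
emb(torch.tensor([1, 2]))          # with max_norm set, this renormalizes rows of emb.weight in place
v_after = emb.weight._version
print(v_before, v_after)           # the counter goes up, so any earlier graph node that saved
                                   # emb.weight for backward now fails autograd's version check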

@dschaehi
Author

dschaehi commented Sep 25, 2019

I am wondering whether one could introduce an embedding normalization function that is a subclass of Function, which applies torch.embedding_renorm_ in the forward pass, but leaves the gradients unchanged in the backward pass. This way one could avoid the error above?
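
Something along these lines, perhaps (a rough sketch of the idea, not something PyTorch provides; it ignores padding_idx, scale_grad_by_freq, and sparse):

import torch

class RenormEmbeddingLookup(torch.autograd.Function):
    """Sketch: renormalize the looked-up rows in place, then do a plain lookup."""

    @staticmethod
    def forward(ctx, weight, indices, max_norm, norm_type):
        # Autograd is disabled inside Function.forward, so the in-place renorm
        # is not recorded in the graph.
        torch.embedding_renorm_(weight, indices, max_norm, norm_type)
        ctx.save_for_backward(indices)
        ctx.weight_shape = weight.shape
        return weight[indices]

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # Pass the gradient straight back to the selected rows of weight,
        # leaving the renormalization itself out of the gradient.
        grad_weight = grad_output.new_zeros(ctx.weight_shape)
        grad_weight.index_add_(0, indices, grad_output)
        return grad_weight, None, None, None

# usage (1-D indices): out = RenormEmbeddingLookup.apply(embedding.weight, idx, 1.0, 2.0)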

@nirkra

nirkra commented Aug 16, 2020

This is a super critical bug: when used with DDP, it causes embeddings to be out of alignment, which is detrimental for training. Furthermore, the warning doesn't exist in nn.Embedding (only in F.embedding), so unless someone reads the code and realizes it calls F.embedding underneath, this must be hurting hundreds of projects that rely on embeddings combined with DDP.

@VitalyFedyunin removed the module: docs (Related to our documentation, both in docs/ and docblocks) and triage review labels Aug 28, 2020
@VitalyFedyunin
Contributor

I will bump the priority on it, as it seems to be not only a docs issue, and many people are hitting it.

@kurtamohler self-assigned this Sep 21, 2020
@kurtamohler
Collaborator

kurtamohler commented Sep 21, 2020

To me, it seems like we have at least two options to solve this issue:

  1. Keep the current behavior and update torch.nn.Embedding's documentation to mention the in-place modification. Then the user will just have to make sure to clone the weight tensor before using it for another calculation that occurs before the torch.nn.Embedding.forward() call. Perhaps the documentation should have a clear warning that specifically suggests cloning to avoid this error. Note that in the repro script above, replacing the line a = embedding.weight @ W.t() with a = embedding.weight.clone() @ W.t() fixes the error.

  2. Change torch.nn.Embedding.forward() to avoid the in-place modification by first cloning the weight tensor before giving it to torch.nn.functional.embedding, then setting self.weight to the new clone after the embedding call. The code snippet below shows what I mean. This fixes the error in the repro script and gives the correct result. However, this would be a BC-breaking change (although maybe this type of BC break is ok?). Currently, if the user provides a tensor weight_ argument to torch.nn.Embedding, we don't create a clone of that tensor, so when it's modified in place, the user's reference to that weight tensor will still be pointing to the modified tensor. This behavior seems to be purposeful, since we have a test that checks this behavior explicitly (test_embedding_from_pretrained_options in test_nn.py). So if we do decide to create a clone of the weight tensor in the forward method, this behavior would change: the weight tensor that the user provided to torch.nn.Embedding would no longer be modified when the forward function is called.

        weight = self.weight if (self.max_norm is None) else Parameter(self.weight.clone())
        result = F.embedding(
            input, weight, self.padding_idx, self.max_norm,
            self.norm_type, self.scale_grad_by_freq, self.sparse)
        if self.max_norm is not None:
            self.weight = weight
        return result

Even though option 2 might seem more ideal, maybe we should go with option 1, to avoid a BC break and to keep torch.nn.Embedding's and torch.nn.functional.embedding's behavior consistent. @VitalyFedyunin, does this make sense?

@VitalyFedyunin
Contributor

Both options make sense, but the first variant is much easier, especially taking into account that there is a C++ version of nn, and we would need to keep both in sync as well as consider the perf impact.

@tonygracious

tonygracious commented Jun 19, 2023

I am facing the same issue in PyTorch 1.13.1.

@kurtamohler
Collaborator

@tonygracious, as the documentation mentions, if you're doing operations on embedding.weight which you want to be differentiable, you just have to clone it first with embedding.weight.clone().
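
Applied to the repro script at the top, only Line a changes:

a = embedding.weight.clone() @ W.t()  # Line a: clone, so the in-place renorm inside
                                      # embedding(idx) doesn't invalidate this node
b = embedding(idx) @ W.t()            # Line b: unchanged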
