
[TorchScript] kind_.is_prim() INTERNAL ASSERT FAILED #54133

Open
n-gao opened this issue Mar 17, 2021 · 3 comments
Labels: oncall: jit (Add this issue/PR to JIT oncall triage queue)

n-gao commented Mar 17, 2021

🐛 Bug

Computing second-order gradients in TorchScript works a single time, but fails when used again.

To Reproduce

import torch
import torch.nn as nn

class Laplacian(nn.Module):
    def __init__(
            self,
            model: nn.Module):
        super().__init__()
        self.model = model
        
    def forward(self, x: torch.Tensor):
        batch_size, n_features = x.shape
        # We have to split the input to compute individual grads
        x_splitted = x.requires_grad_(True).split(1, -1)
        x = torch.cat(x_splitted, -1)
        # Forward pass
        y = self.model(x)
        # First order gradients
        dy_dx = torch.autograd.grad(
            [y.sum()], [x], create_graph=True, allow_unused=False)[0]
        if dy_dx is not None:
            dy_dx = dy_dx.squeeze()
            # Second order gradients
            d2y_dx2 = torch.zeros((batch_size, 1), device=x.device)
            for x_i in x_splitted:
                val = torch.autograd.grad([dy_dx.sum()],
                                          [x_i], retain_graph=True)[0]
                if val is not None:
                    d2y_dx2 += val
            return -0.5 * ((dy_dx**2).sum(-1) + d2y_dx2.sum(-1))
        return torch.zeros(())

class Model(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x):
        return torch.sin(x).sum(-1)

One can initialize the model

scripted_laplacian = torch.jit.script(Laplacian(Model()))

And the first call works just fine

scripted_laplacian(torch.randn(10, 3))
tensor([-1.0881, -0.8601, -1.7254, -0.7840, -0.6763, -0.7716, -0.3451, -0.8867,
       -0.1406,  0.0031], grad_fn=<MulBackward1>)

However, if I now call the model a second time, it fails

scripted_laplacian(torch.randn(10, 3))
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-5-62893995fdbc> in <module>
----> 1 scripted_laplacian(torch.randn(10, 3))

~/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

RuntimeError: kind_.is_prim() INTERNAL ASSERT FAILED at "/pytorch/torch/csrc/jit/ir/ir.cpp":1098, please report a bug to PyTorch. Only prim ops are allowed to not have a registered operator but aten::cat doesn't have one either. We don't know if this op has side effects.

Expected behavior

The second call should return correctly like the first one.
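For comparison, the unscripted module can be called repeatedly without error, which suggests the failure is specific to the scripted path. A minimal self-contained sketch restating the classes from the report (the claim here is only about eager-mode execution):

```python
import torch
import torch.nn as nn

# Same classes as in the repro above, run in eager mode (no torch.jit.script).
class Model(nn.Module):
    def forward(self, x):
        return torch.sin(x).sum(-1)

class Laplacian(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor):
        batch_size, n_features = x.shape
        # Split the input to compute individual grads, then re-cat
        x_splitted = x.requires_grad_(True).split(1, -1)
        x = torch.cat(x_splitted, -1)
        y = self.model(x)
        # First-order gradients, keeping the graph for double backward
        dy_dx = torch.autograd.grad(
            [y.sum()], [x], create_graph=True, allow_unused=False)[0]
        if dy_dx is not None:
            dy_dx = dy_dx.squeeze()
            d2y_dx2 = torch.zeros((batch_size, 1), device=x.device)
            for x_i in x_splitted:
                val = torch.autograd.grad([dy_dx.sum()],
                                          [x_i], retain_graph=True)[0]
                if val is not None:
                    d2y_dx2 += val
            return -0.5 * ((dy_dx**2).sum(-1) + d2y_dx2.sum(-1))
        return torch.zeros(())

laplacian = Laplacian(Model())          # eager module, not scripted
out1 = laplacian(torch.randn(10, 3))    # first call
out2 = laplacian(torch.randn(10, 3))    # second call: no error in eager mode
```

In eager mode both calls return a tensor of shape (10,), so the repeated-call failure appears to be introduced by scripting, not by the autograd logic itself.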

Environment

  • PyTorch Version (e.g., 1.0): 1.8
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): pip
  • Python version: 3.8
  • CUDA/cuDNN version: 10.2
  • GPU models and configuration: GTX 1080

Additional context

cc @gmagogsfm

@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Mar 17, 2021
@github-actions github-actions bot added this to Need triage in JIT Triage Mar 17, 2021

n-gao commented Mar 17, 2021

This seems to be related to #24243.
If I replace

x_splitted = x.requires_grad_(True).split(1, -1)

with

x_splitted = list(x.requires_grad_(True).split(1, -1))

it throws another error (again, only on the second call!):

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-29-62893995fdbc> in <module>
----> 1 scripted_laplacian(torch.randn(10, 3))

~/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "<ipython-input-25-7a172cd7c8d6>", line 19, in forward
        y = self.model(x)
        # First order gradients
        dy_dx = torch.autograd.grad(
                ~~~~~~~~~~~~~~~~~~~ <--- HERE
            [y.sum()], [x], create_graph=True)[0]
        if dy_dx is not None:
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.
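As a side note, this second message itself points at allow_unused=True. A minimal sketch of just the second-order loop with that flag set (the helper name second_order_sum is illustrative, not from the report; whether this merely masks the underlying TorchScript issue rather than fixing it is unclear):

```python
import torch

def second_order_sum(dy_dx, x_splitted, batch_size):
    # Accumulate second-order grads per split tensor. With allow_unused=True,
    # an input that autograd considers unused yields None instead of raising,
    # so we can simply skip it.
    d2y_dx2 = torch.zeros((batch_size, 1), device=dy_dx.device)
    for x_i in x_splitted:
        val = torch.autograd.grad([dy_dx.sum()], [x_i],
                                  retain_graph=True, allow_unused=True)[0]
        if val is not None:
            d2y_dx2 += val
    return d2y_dx2
```

In eager mode every x_i genuinely participates in the graph, so None should not occur there; if the scripted graph reports an input as unused, that would itself be worth investigating.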

eellison (Contributor) commented

CC @Krovatkin, you were looking into this error: RuntimeError: kind_.is_prim() INTERNAL ASSERT FAILED at "/pytorch/torch/csrc/jit/ir/ir.cpp":1098, please report a bug to PyTorch. Only prim ops are allowed to not have a registered operator but aten::cat doesn't have one either. We don't know if this op has side effects.

RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.

May be a different issue.

@SplitInfinity SplitInfinity added this to To do in JIT Performance via automation Mar 23, 2021
@SplitInfinity SplitInfinity removed this from Need triage in JIT Triage Mar 23, 2021
@eellison eellison assigned eellison and Krovatkin and unassigned eellison Mar 23, 2021
@eellison eellison moved this from To do to In Progress in JIT Performance Mar 23, 2021
davidberard98 (Contributor) commented

cc @eellison @Krovatkin: is this fixed? I can't repro.

Projects: JIT Performance (In Progress)
No branches or pull requests
5 participants