
Runtime Error thrown when using Optimizer in a Pytorch Function: element 0 of tensors does not require grad and does not have a grad_fn #8847

Closed
sbarratt opened this issue Jun 25, 2018 · 2 comments
Labels: todo (Not as important as medium or high priority tasks, but we will work on these.)

@sbarratt

Issue description

When I try to run an optimizer on a separate variable inside a custom autograd Function, PyTorch throws the error `element 0 of tensors does not require grad and does not have a grad_fn`.

Code example

import numpy as np

from torch.autograd import Function, Variable
import torch

N, n = 100, 2
X = np.random.randn(N, n)
y = np.random.randint(0,2,size=N)
Xv = Variable(torch.Tensor(X), requires_grad=True)
yv = Variable(torch.Tensor(y), requires_grad=True)

def solve_logistic_regression(X, y, lamb):
    N, n = X.shape
    theta = Variable(torch.ones(n), requires_grad=True)
    optimizer = torch.optim.LBFGS([theta], lr=.8)
    def closure():
        optimizer.zero_grad()
        pi = 1./(1.+torch.exp(-X.mm(theta.unsqueeze(-1))))
        loss = 1./N*torch.nn.BCELoss()(pi.squeeze(), y) + lamb/2*torch.norm(theta[:-1])**2
        print (loss.item())
        loss.backward()
        return loss
    optimizer.step(closure)
    return theta
    
class LogisticRegression(Function):
    @staticmethod
    def forward(ctx, X, y, lamb):
        theta = solve_logistic_regression(X, y, lamb)
        return 0
        
    @staticmethod
    def backward(ctx, grad_output):
        return None, None, None
lr = LogisticRegression.apply

# this works
solve_logistic_regression(Xv.detach(),yv.detach(),.1)

# this doesn't work
lr(Xv.detach(),yv.detach(),.1)

Error

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-26-618cf3a429a3> in <module>()
----> 1 lr(Xv.detach(),yv.detach(),.1)

<ipython-input-24-5b3d43f3448b> in forward(ctx, X, y, lamb)
     16     @staticmethod
     17     def forward(ctx, X, y, lamb):
---> 18         theta = solve_logistic_regression(X, y, lamb)
     19         return 0
     20 

<ipython-input-24-5b3d43f3448b> in solve_logistic_regression(X, y, lamb)
     10         loss.backward()
     11         return loss
---> 12     optimizer.step(closure)
     13     return theta
     14 

~/anaconda/lib/python3.6/site-packages/torch/optim/lbfgs.py in step(self, closure)
    101 
    102         # evaluate initial f(x) and df/dx
--> 103         orig_loss = closure()
    104         loss = float(orig_loss)
    105         current_evals = 1

<ipython-input-24-5b3d43f3448b> in closure()
      8         loss = 1./N*torch.nn.BCELoss()(pi.squeeze(), y) + lamb/2*torch.norm(theta[:-1])**2
      9         print (loss.item())
---> 10         loss.backward()
     11         return loss
     12     optimizer.step(closure)

~/anaconda/lib/python3.6/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
     91                 products. Defaults to ``False``.
     92         """
---> 93         torch.autograd.backward(self, gradient, retain_graph, create_graph)
     94 
     95     def register_hook(self, hook):

~/anaconda/lib/python3.6/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     87     Variable._execution_engine.run_backward(
     88         tensors, grad_tensors, retain_graph, create_graph,
---> 89         allow_unreachable=True)  # allow_unreachable flag
     90 
     91 

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

System Info

PyTorch version: 0.4.0
Is debug build: No
CUDA used to build PyTorch: None

OS: Mac OSX 10.12.6
GCC version: Could not collect
CMake version: version 3.6.3

Python version: 3.6
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA

Versions of relevant libraries:
[pip3] numpy (1.14.4)
[pip3] numpydoc (0.6.0)
[pip3] torch (0.4.0)
[pip3] torchaudio (0.1)
[pip3] torchvision (0.2.0)
[conda] pytorch                   0.3.0           py36_cuda0.0_cudnn0.0h57b1bc9_4    pytorch
[conda] torch                     0.4.0                     <pip>
[conda] torch                     0.2.0+cd9b272             <pip>
[conda] torchaudio                0.1                       <pip>
[conda] torchvision               0.2.0            py36hf5eb7ec_1    pytorch
@t-vi (Collaborator) commented Jun 29, 2018

I think all you need to do is wrap the call to the function (inside the autograd.Function's forward):

        with torch.enable_grad():
            theta = solve_logistic_regression(X, y, lamb)

...oh, and return something more reasonable than 0.
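For context: in PyTorch 0.4, Function.forward is executed with gradient tracking disabled, which is why the inner loss.backward() fails. Below is a minimal sketch (not part of the original comment) of how the corrected Function might look when both suggestions are applied, reusing solve_logistic_regression from the example above:

import torch
from torch.autograd import Function

class LogisticRegression(Function):
    @staticmethod
    def forward(ctx, X, y, lamb):
        # forward() runs with autograd disabled, so re-enable it for the inner optimization
        with torch.enable_grad():
            theta = solve_logistic_regression(X, y, lamb)
        # return the fitted parameters (detached) instead of the integer 0
        return theta.detach()

    @staticmethod
    def backward(ctx, grad_output):
        # gradients w.r.t. X, y and lamb are left unimplemented, as in the original example
        return None, None, None

lr = LogisticRegression.apply
lr(Xv.detach(), yv.detach(), .1)  # should now run without the RuntimeError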

@fmassa (Member) commented Jul 3, 2018

Closing following @t-vi's comment.

@fmassa closed this as completed Jul 3, 2018