[jit] requires_grad in JIT constructor/factories #19393

Open
suo opened this issue Apr 18, 2019 · 9 comments
Labels
high priority oncall: jit Add this issue/PR to JIT oncall triage queue triage review

@suo
Member

suo commented Apr 18, 2019

Something like:

        t = torch.ones(2, requires_grad=True)

fails with:

RuntimeError:
arguments for call are not valid:

  for operator aten::ones(int[] size, *, Tensor out) -> Tensor:
  argument out not provided.
  @torch.jit.script
  def foo():
      t = torch.ones(2, requires_grad=True)
          ~~~~~~~~~~ <--- HERE

  for operator aten::ones(int[] size, *, int? dtype=<default>, int? layout=<default>, Device? device=<default>) -> Tensor:
  keyword argument requires_grad unknown
for call at:
@torch.jit.script
def foo():
    t = torch.ones(2, requires_grad=True)
        ~~~~~~~~~~ <--- HERE

We also can't do:

@torch.jit.script
def foo():
    t = torch.ones(2)
    t.requires_grad_(True)

I assume this had to do with lack of mutability annotations before? @zdevito are there any other difficulties you could see?

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @gmagogsfm @suo

@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Apr 18, 2019
@zdevito
Contributor

zdevito commented Apr 18, 2019

I believe it was due to lack of mutability of tensors, and the fact that you will have to update the requires_grad propagation pass appropriately given alias annotations.

@suo suo self-assigned this Apr 18, 2019
@eellison
Contributor

I did the initial implementation of torch.tensor, so I can do that.

@ezyang
Contributor

ezyang commented Aug 19, 2020

Another reason for this is the schema for factories doesn't actually include requires_grad; it should! This would also entail changing the calling convention for functions.

@pritamdamania87
Contributor

@suo Are there any updates on this issue? It is currently blocking #28786, which has been a long-standing issue on the RPC side.

@suo suo added this to Need triage in JIT Triage via automation Jan 14, 2021
@SplitInfinity SplitInfinity moved this from Need triage to HIGH PRIORITY in JIT Triage Jan 19, 2021
@gmagogsfm
Contributor

Is this still a problem? I tried the following code, and it seems to work fine when using requires_grad_:

import torch

@torch.jit.script
def fn():
    t1 = torch.ones(2)
    t2 = torch.ones(2)
    t1.requires_grad_(True)
    t2.requires_grad_(False)
    return (t1, t2)


print(fn())

Output:

(tensor([1., 1.], requires_grad=True), tensor([1., 1.]))

@pritamdamania Could you comment on whether this is still blocking #28786?

@pritamdamania87
Contributor

@gmagogsfm This is still an issue and the code to repro is:

import torch

@torch.jit.script
def foo():
    t = torch.ones(2, requires_grad=True)

foo()

Note that this fails only when requires_grad is passed as an argument to torch.ones. Your example above sets requires_grad separately, which is why it works.
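
To check how a particular PyTorch build treats the two forms discussed in this thread, a small probe along these lines may help. It is only a sketch: it assumes torch.jit.CompilationUnit accepts TorchScript source as a string and raises RuntimeError when compilation fails, and the names FACTORY_KWARG, IN_PLACE, and compiles are made up for illustration.

```python
import torch

# Form from the repro: requires_grad passed to the factory call.
FACTORY_KWARG = """
def foo():
    t = torch.ones(2, requires_grad=True)
    return t
"""

# Workaround form: create the tensor, then set the flag in place.
IN_PLACE = """
def foo():
    t = torch.ones(2)
    t.requires_grad_(True)
    return t
"""


def compiles(src: str) -> bool:
    """Return True if the TorchScript source compiles on this build.

    Hypothetical helper for illustration; assumes compile errors
    surface as RuntimeError, as in the traceback reported above.
    """
    try:
        torch.jit.CompilationUnit(src)
        return True
    except RuntimeError:
        return False


if __name__ == "__main__":
    print("factory kwarg compiles:", compiles(FACTORY_KWARG))
    print("in-place setter compiles:", compiles(IN_PLACE))
```

The result for the first form depends on the installed version, so no expected output is shown.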

@gmagogsfm
Contributor

Sorry for not being clear enough. What I am really offering is a workaround that seems easy to use. Is it sufficient to unblock your issue?

@pritamdamania87
Contributor

@gmagogsfm The workaround isn't a great user experience, especially when we use the RPC framework to invoke an op remotely:

rpc.remote(dst, torch.ones, args=((3, 3),), kwargs={"requires_grad":True})

The workaround would involve multiple RPCs to the same node: one to create the tensor and another to set the requires_grad field. An alternative would be to write a custom user function that does both and invoke it on the remote end with a single RPC, but this isn't a great user experience either, since the user has to write a custom function each time they want to create a Tensor remotely.
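
The custom-function alternative mentioned above could look roughly like the sketch below. ones_with_grad is a hypothetical helper, not a PyTorch API; it wraps the two-step workaround from earlier in the thread so that a single call both creates the tensor and sets the flag.

```python
from typing import List

import torch


@torch.jit.script
def ones_with_grad(size: List[int], requires_grad: bool = False) -> torch.Tensor:
    # Hypothetical helper (not part of PyTorch). The tensor is created
    # first and the flag is set in place afterwards, because the scripted
    # factory call rejects requires_grad as a keyword argument per this issue.
    t = torch.ones(size)
    t.requires_grad_(requires_grad)
    return t


if __name__ == "__main__":
    t = ones_with_grad([3, 3], True)
    print(t.requires_grad, t.shape)
```

Assuming rpc.remote accepts a user-defined callable in place of torch.ones, a call such as rpc.remote(dst, ones_with_grad, args=([3, 3], True)) would then need only one round trip. The user still has to define the helper themselves, which is the usability cost described above.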
