
[JIT] Tracing a script function/module where not all args are Tensors #14455

Open · nikhilmishra000 opened this issue Nov 28, 2018 · 14 comments
Labels: oncall: jit (Add this issue/PR to JIT oncall triage queue)

@nikhilmishra000 commented Nov 28, 2018

🐛 Bug

Suppose I write a script function/module where one argument is an int.

Then tracing a larger model that uses this script function will fail.

To Reproduce

import torch

@torch.jit.script
def foo(x, y: int):
    return x + y

def bar(x):
    return foo(x, 4)

x = torch.zeros(3)
foo(x, 4)                                # this works for any `x` and `y`
bar(x)                                   # this works too
traced_bar = torch.jit.trace(bar, (x,))  # this errors

The tracing fails with ValueError: Auto nesting doesn't know how to process an input object of type int. Accepted types: Tensors, or lists/tuples of them

Environment

PyTorch version: 1.0.0a0+60e7d04
Is debug build: No
CUDA used to build PyTorch: 9.1.85

OS: Ubuntu 16.04.5 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.10) 5.4.0 20160609
CMake version: version 3.11.1

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.1.85
GPU models and configuration:
GPU 0: GeForce GTX 1080 Ti
GPU 1: GeForce GTX 1080 Ti
GPU 2: GeForce GTX 1080 Ti
GPU 3: GeForce GTX 1080 Ti

Nvidia driver version: 390.77
cuDNN version: Probably one of the following:
/usr/local/cuda-9.1/targets/x86_64-linux/lib/libcudnn.so.7.1.3
/usr/local/cuda-9.1/targets/x86_64-linux/lib/libcudnn_static.a

Versions of relevant libraries:
[pip] Could not collect
[conda] cuda91 1.0 h4c16780_0 pytorch
[conda] magma-cuda91 2.3.0 1 pytorch
[conda] torch 1.0.0a0+60e7d04
[conda] torchvision 0.2.1 py36_1 pytorch

@pytorchbot added the oncall: jit label Nov 28, 2018
@t-vi (Collaborator) commented Nov 29, 2018

The snippet you show (edit: the original snippet, which had no @torch.jit.script above) does not error for me.
What does error is traced_foo = torch.jit.trace(foo, (x, 4)), but this is a fundamental limitation of tracing: you can't trace arbitrary Python objects, including numbers, and they will be constant in the traced function. So the typical workaround is to trace a wrapper, as you do with bar.
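
To make the constant-baking point concrete, here is a minimal sketch (not from the original thread): close over the non-Tensor argument in a wrapper so the trace only sees Tensor inputs, with the number recorded as a constant.

import torch

def make_scaler(k):
    # k is an ordinary Python number; tracing records its value as a constant
    def scale(x):
        return x * k
    return scale

x = torch.ones(3)
traced_scale_by_2 = torch.jit.trace(make_scaler(2), (x,))
print(traced_scale_by_2(torch.ones(3)))  # tensor([2., 2., 2.]); the 2 is baked in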

@beichen2012 commented

> The snippet you show does not error for me.
> What does error is traced_foo = torch.jit.trace(foo, (x, 4)), but this is a fundamental limitation of tracing: you can't trace arbitrary Python objects, including numbers, and they will be constant in the traced function. So the typical workaround is to trace a wrapper, as you do with bar.

I trace a wrapper, but the inner function also needs a number (int, float, ...).
Then the error occurs too...

@t-vi (Collaborator) commented Dec 10, 2018

For functions, that should not happen. For modules, it is a different story...

@nikhilmishra000 (Author) commented Dec 10, 2018

Apologies -- there should have been a @torch.jit.script decorator on foo() in the original post (something weird happened with the formatting). I've updated the comment above.

@beichen2012 commented

> Apologies -- there should have been a @torch.jit.script decorator on foo() in the original post (something weird happened with the formatting). I've updated the comment above.

I get the same error...

@t-vi (Collaborator) commented Dec 10, 2018

Ah, right. That is because @torch.jit.script'ed functions are internally treated as modules.

Technically, I think it would be easy enough to allow more things to be passed to ScriptModules that are already compiled. It would seem reasonable to me for precisely your use case.
I don't know the origin of the design decision not to allow them, but I imagine it might be related to not wanting modules that are currently being scripted to take those. (As background: Python numbers are constants, so the signature of the traced module would not match that of the module before tracing, which would be bad.)

I'd venture it's a feature request, not a bug, but I'll happily admit that I think it's a very reasonable one, and one that I have wished worked before.

@zdevito (or some other JIT expert) Would it sound reasonable to you to allow already compiled modules to be called with arbitrary arguments, or at least to allow the common cases?

I used to have a good workaround using defaults for the @torch.jit.script'ed function's arguments, but I can't seem to get it to work. (Which could be me being particularly confused before ☕, or it could be that some JIT patching broke it.)
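
For reference, a sketch of the kind of defaults-based workaround described above (my own reconstruction, which per the comment may no longer work):

import torch

@torch.jit.script
def foo(x, y: int = 4):
    return x + y

def bar(x):
    return foo(x)  # rely on the default so only Tensor args are passed at trace time

traced = torch.jit.trace(bar, (torch.zeros(3),))  # reportedly broken again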

@apaszke (Contributor) commented Dec 10, 2018

So to clarify, the problem is that:

  • There's a scripted function that takes an int argument
  • This function is called during tracing, and the tracer complains that it cannot be called because it has an int argument?

I'm not super sold either way, but raising an error is definitely the more conservative approach: it avoids giving people a false sense of security that their int is not a constant just because it's scripted! One workaround would be to do something like this:

import torch
from torch import Tensor

@torch.jit.script
def my_fn(x: Tensor, y: int):
    ...

@torch.jit.script
def my_fn_5(x: Tensor):
    return my_fn(x, 5)

and use my_fn_5 in the trace.
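
A quick usage sketch of this workaround (my_fn given a placeholder body of my own, since the comment elides it):

import torch
from torch import Tensor

@torch.jit.script
def my_fn(x: Tensor, y: int):
    return x * y  # placeholder body for illustration

@torch.jit.script
def my_fn_5(x: Tensor):
    return my_fn(x, 5)  # the int is fixed inside TorchScript, invisible to the tracer

def model(x):
    return my_fn_5(x)

# only Tensor arguments cross the tracing boundary
traced = torch.jit.trace(model, (torch.zeros(3),))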

@t-vi (Collaborator) commented Dec 10, 2018

Personally, I think it's awkward for scripted functions to behave substantially differently here from what built-in functions would. (And if you have a few of these workarounds in your model, it adds up quickly.)

@apaszke (Contributor) commented Dec 10, 2018

Fair enough. That convinces me.

@ferrine commented Feb 21, 2019

I faced this issue in a similar context, trying to trace some operations involving script+trace.
My use case is like this:

import torch

@torch.jit.script
def dist(x, y, keepdim: bool = False):
    return torch.norm(x - y, p=2, dim=-1, keepdim=keepdim)

This seems trivial, but the keepdim argument is important in my application. I would expect this TorchScript function to work while tracing, but it does not, and it gives errors like the one above, complaining that not all arguments are Tensors.

@ferrine commented Feb 21, 2019

What workaround should I use here?
I currently have these options:

  1. Never use optional arguments and support only broadcasted behavior.
     That contradicts the PyTorch API for reduction functions like sum; I do not like that...

  2. Replace this with Python runtime dispatch (see the sketch after this list).
     That is extremely awkward:

     def function(x, a, b, signed=False, keepdim=False):
         ...

     will result in 4 (!) function definitions and quite complex Python dispatch.

  3. Remove TorchScript.
     A lot of efficiency drop :(
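
To make option 2 concrete, here is a hedged sketch of such runtime dispatch for a single keepdim flag (names are illustrative, not from the thread); with two flags like signed and keepdim this doubles to the four definitions mentioned above.

import torch

@torch.jit.script
def _dist(x, y):
    return torch.norm(x - y, p=2, dim=-1)

@torch.jit.script
def _dist_keepdim(x, y):
    return torch.norm(x - y, p=2, dim=-1, keepdim=True)

def dist(x, y, keepdim=False):
    # plain-Python dispatch on the flag; each scripted variant takes only
    # Tensor arguments, so it can be called during tracing
    if keepdim:
        return _dist_keepdim(x, y)
    return _dist(x, y)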

@t-vi (Collaborator) commented Feb 21, 2019

  4. Fix PyTorch. 😄

@ferrine commented Feb 26, 2019

Any suggestions from PyTorch devs, maybe? I'm really stuck on what to do with this code (example).

@t-vi (Collaborator) commented Feb 26, 2019

If you need it fast, you might consider wrapping it in C++ and using a custom op. Or, if it's not user-facing, you could wrap the bools into Tensors. It is something I'd latently want to fix and would expect not to be too hard, but I don't really know when I would get to it.
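
A minimal sketch of the "wrap the bools into tensors" idea (my own illustration; it assumes TorchScript can convert a single-element Tensor to bool inside the scripted function):

import torch

@torch.jit.script
def dist(x, y, keepdim_flag):
    # keepdim_flag is a single-element Tensor, so the tracer accepts it as an input
    if bool(keepdim_flag):
        return torch.norm(x - y, p=2, dim=-1, keepdim=True)
    return torch.norm(x - y, p=2, dim=-1)

out = dist(torch.zeros(3, 4), torch.ones(3, 4), torch.tensor(1))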

ferrine added a commit to geoopt/geoopt that referenced this issue Mar 2, 2019
ferrine added a commit to geoopt/geoopt that referenced this issue Mar 2, 2019
ferrine added a commit to geoopt/geoopt that referenced this issue Mar 2, 2019
ferrine added a commit to geoopt/geoopt that referenced this issue Mar 2, 2019
ferrine added a commit to geoopt/geoopt that referenced this issue Mar 3, 2019
@suo removed the jit-triaged label Mar 11, 2019
ferrine added a commit to geoopt/geoopt that referenced this issue Mar 17, 2019
ferrine added a commit to geoopt/geoopt that referenced this issue Mar 31, 2019
andbloch pushed a commit to andbloch/geoopt that referenced this issue Dec 29, 2019
8 participants