[JIT] Tracing a script function/module where not all args are Tensors #14455
Comments
The snippet you show (edit: the original snippet that had no
Trace a wrapper, but the inner function also needs a number (int, float, ...).
For functions, that should not happen. For modules, that is a different story...
Apologies -- there should have been a
I met the same error...
Ah, right. Technically, I think it would be easy enough to allow more things to be passed to ScriptModules that are already compiled. That would seem reasonable to me, precisely for your use case. I'd venture it's a feature request rather than a bug, but I'll happily admit that I think it's a very reasonable one, and one I have wished would work before. @zdevito (or some other JIT expert): would it sound reasonable to you to allow already-compiled modules to be called with arbitrary arguments, or at least to allow the common cases? I used to have a good workaround using defaults for the
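The "defaults" workaround mentioned here is cut off; a plausible sketch (with hypothetical function names) is to give the non-Tensor argument a default value, so the traced wrapper never has to pass a non-Tensor across the trace boundary:

```python
import torch
from torch import Tensor

@torch.jit.script
def scaled_sum(x: Tensor, factor: int = 2):
    # the non-Tensor argument has a default, so callers that only
    # pass Tensors (e.g. code being traced) can still invoke it
    return (x * factor).sum()

def wrapper(x: Tensor):
    return scaled_sum(x)  # only Tensor arguments cross the trace boundary

traced = torch.jit.trace(wrapper, torch.ones(3))
```

This only helps when the fixed default value is actually the one you need, which is why it is a workaround rather than a fix.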
So to clarify, the problem is that:
I'm not super sold either way, but raising an error is definitely the more conservative approach; it would let us avoid giving people a false sense of security. For example, given

```python
@torch.jit.script
def my_fn(x: Tensor, y: int):
    ...
```

you can define a wrapper that fixes the non-Tensor argument:

```python
@torch.jit.script
def my_fn_5(x: Tensor):
    return my_fn(x, 5)
```

and use `my_fn_5` when tracing.
Personally, I think it's awkward for scripted functions to behave substantially differently here from what built-in functions would do. (And if you have a few of these workarounds in your model, it adds up quickly.)
Fair enough. That convinces me.
I faced this issue in a similar context, trying to trace some operations involving script + trace:

```python
@torch.jit.script
def dist(x, y, keepdim: bool = False):
    return torch.norm(x - y, p=2, dim=-1, keepdim=keepdim)
```

This seems trivial, but the `keepdim` argument is important in my application. I would expect this script function to work while tracing, but it does not; it gives errors like the one above, complaining that not all arguments are Tensors.
What workaround should I use there?
Any suggestions from PyTorch devs, maybe? I'm really stuck on what to do with this code (see the example above).
If you need it fast, you might consider wrapping it in C++ and using a custom op. Or, if it's not user-facing, you could wrap the bools into Tensors. It is something I'd like to fix eventually and would expect not to be too hard, but I don't really know when I would get to it.
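A sketch of the "wrap the bools into Tensors" workaround, applied to a hypothetical variant of the `dist` function above: the flag is passed as a 0-dim Tensor so that every argument crossing the trace boundary is a Tensor, and the script function recovers the bool internally.

```python
import torch
from torch import Tensor

@torch.jit.script
def dist_wrapped(x: Tensor, y: Tensor, keepdim_t: Tensor):
    # recover the bool flag from a 0-dim tensor inside the script
    keepdim = bool(keepdim_t.item())
    return torch.norm(x - y, p=2, dim=-1, keepdim=keepdim)

def model(x: Tensor, y: Tensor):
    # the flag tensor is baked in as a constant during tracing
    return dist_wrapped(x, y, torch.tensor(True))

x, y = torch.randn(4, 3), torch.randn(4, 3)
traced = torch.jit.trace(model, (x, y))
```

Note the usual tracing caveat applies: the flag value becomes a constant of the trace, so a separate trace is needed per flag value.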
(Linked commit history from a downstream project referenced this issue; notably: "remove torch script (it gave minor improvements), delay to pytorch/pytorch#14455 resolution".)
🐛 Bug
Suppose I write a script function/module where one argument is an int.
Then tracing a larger model that uses this script function will fail.
To Reproduce
The tracing fails with:
ValueError: Auto nesting doesn't know how to process an input object of type int. Accepted types: Tensors, or lists/tuples of them
Environment
PyTorch version: 1.0.0a0+60e7d04
Is debug build: No
CUDA used to build PyTorch: 9.1.85
OS: Ubuntu 16.04.5 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.10) 5.4.0 20160609
CMake version: version 3.11.1
Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.1.85
GPU models and configuration:
GPU 0: GeForce GTX 1080 Ti
GPU 1: GeForce GTX 1080 Ti
GPU 2: GeForce GTX 1080 Ti
GPU 3: GeForce GTX 1080 Ti
Nvidia driver version: 390.77
cuDNN version: Probably one of the following:
/usr/local/cuda-9.1/targets/x86_64-linux/lib/libcudnn.so.7.1.3
/usr/local/cuda-9.1/targets/x86_64-linux/lib/libcudnn_static.a
Versions of relevant libraries:
[pip] Could not collect
[conda] cuda91 1.0 h4c16780_0 pytorch
[conda] magma-cuda91 2.3.0 1 pytorch
[conda] torch 1.0.0a0+60e7d04
[conda] torchvision 0.2.1 py36_1 pytorch