[feature request] Add support for '.' characters in TORCH_EXTENSION_NAME #13

Closed
nkolot opened this issue Jun 21, 2018 · 8 comments

@nkolot

nkolot commented Jun 21, 2018

I was trying to build a CUDAExtension whose name contains dots. The reason is that I want to install it as a submodule of the package I am building. For example, I have a top-level package foo and I want my CUDA extension to be foo.bar.
There is also a related discussion here.

This is currently not possible with CUDAExtension and I can't think of a workaround right now. The obvious reason is that the macros don't allow '.' characters in the name. I am attaching the output of the build process. A quick way to reproduce this is to change the module name from lltm_cuda to foo.cuda in your CUDA lltm example.
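For illustration, a setup.py along these lines is enough to trigger it (module and file names here are just placeholders):

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='foo',
    ext_modules=[
        # The dotted extension name is what ends up in TORCH_EXTENSION_NAME
        # and breaks the PYBIND11_MODULE macro expansion.
        CUDAExtension('foo.bar', ['bar_cuda.cpp', 'bar_cuda_kernel.cu']),
    ],
    cmdclass={'build_ext': BuildExtension},
)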

In file included from /usr/local/lib/python3.5/dist-packages/torch/lib/include/pybind11/pytypes.h:12:0,
                 from /usr/local/lib/python3.5/dist-packages/torch/lib/include/pybind11/cast.h:13,
                 from /usr/local/lib/python3.5/dist-packages/torch/lib/include/pybind11/attr.h:13,
                 from /usr/local/lib/python3.5/dist-packages/torch/lib/include/pybind11/pybind11.h:43,
                 from /usr/local/lib/python3.5/dist-packages/torch/lib/include/torch/torch.h:6,
                 from neural_renderer/cuda/load_textures_cuda.cpp:1:
<command-line>:0:37: error: expected initializer before '.' token
/usr/local/lib/python3.5/dist-packages/torch/lib/include/pybind11/detail/common.h:212:47: note: in definition of macro 'PYBIND11_CONCAT'
 #define PYBIND11_CONCAT(first, second) first##second
                                               ^
neural_renderer/cuda/load_textures_cuda.cpp:33:1: note: in expansion of macro 'PYBIND11_MODULE'
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 ^
neural_renderer/cuda/load_textures_cuda.cpp:33:17: note: in expansion of macro 'TORCH_EXTENSION_NAME'
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
                 ^
<command-line>:0:37: error: expected initializer before '.' token
/usr/local/lib/python3.5/dist-packages/torch/lib/include/pybind11/detail/common.h:171:51: note: in definition of macro 'PYBIND11_PLUGIN_IMPL'
     extern "C" PYBIND11_EXPORT PyObject *PyInit_##name()
                                                   ^
neural_renderer/cuda/load_textures_cuda.cpp:33:1: note: in expansion of macro 'PYBIND11_MODULE'
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 ^
neural_renderer/cuda/load_textures_cuda.cpp:33:17: note: in expansion of macro 'TORCH_EXTENSION_NAME'
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
                 ^
error: command 'x86_64-linux-gnu-gcc' failed with exit status 1
  • OS: Ubuntu 16.04
  • PyTorch version: 0.4.0
  • How you installed PyTorch: source
  • Python version: 3.5.2
  • CUDA/cuDNN version: 9.0
@nkolot
Author

nkolot commented Jun 21, 2018

I actually found a workaround for now. If you want to name your module "foo.bar", you can use the top-level module name foo as the extension name and fold the submodule name into the name passed to def(), as shown here:

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("bar.myfunction", &myfunction_cuda, "MY_FUNCTION");
}
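If I'm reading pybind11 right, the dotted string just becomes a literal attribute name on the top-level module rather than a real submodule, so on the Python side it has to be fetched with getattr (sketch, placeholder arguments):

import foo

# 'bar.myfunction' is the verbatim attribute name created by the def() above,
# so plain foo.bar.myfunction attribute access does not work here.
fn = getattr(foo, 'bar.myfunction')
fn(...)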

@soumith
Member

soumith commented Jun 21, 2018

cc: @goldsborough

@goldsborough
Contributor

goldsborough commented Jun 21, 2018

Hi @nkolot, this is a known issue with PYBIND11_MODULE. I thought this was fixed by @fmassa in pytorch/pytorch#6986, where we fixed/hacked it by only making the last part of the dotted extension the actual extension name. I believe setuptools creates the directory structure to still make the extension importable with the full path. Maybe change TORCH_EXTENSION_NAME to just bar? Could you try that? @fmassa do you know why this doesn't work after your fix?

@fmassa
Member

fmassa commented Jun 21, 2018

@nkolot are you sure you installed PyTorch from source? The fix @goldsborough mentioned is not present in PyTorch 0.4 (which is what we have for the binaries).
Also, can you check that your installation actually contains that fix?

@nkolot
Author

nkolot commented Jun 21, 2018

Yes, PyTorch is compiled from source.
I can confirm that the fix @goldsborough mentioned is working.
So in my setup.py file I have

ext_modules = [CUDAExtension('foo.bar', ['bar_cuda.cpp', 'bar_cuda_kernel.cu']),]

and in bar_cuda.cpp

PYBIND11_MODULE(bar, m) {
    m.def("myfunc", &my_func, "MY_FUNC");
}

It even works with deeper nesting, e.g. foo.bar.foobar, as long as I put only the last part in PYBIND11_MODULE.
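Roughly, usage then looks like this (sketch with the names above; the argument is a placeholder):

import foo.bar  # setuptools puts the built extension under the foo/ package

foo.bar.myfunc(...)  # the function bound with m.def("myfunc", ...) above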

@goldsborough
Contributor

goldsborough commented Jun 21, 2018

Great, thanks for confirming. I think we can close this then.

@fmassa
Member

fmassa commented Jun 21, 2018

@nkolot I still didn't quite understand... is pytorch/pytorch#6986 present in your installation? That patch is supposed to be doing what @goldsborough mentioned.

@nkolot
Author

nkolot commented Jun 21, 2018

It seems that I compiled PyTorch before the fix.

In [1]: import torch
In [2]: torch.__version__
Out[2]: '0.4.0a0+200fb22'
In [3]: from torch.utils.cpp_extension import BuildExtension
In [4]: import inspect
In [5]: print(inspect.getsource(BuildExtension._define_torch_extension_name))
    def _define_torch_extension_name(self, extension):
        define = '-DTORCH_EXTENSION_NAME={}'.format(extension.name)
        if isinstance(extension.extra_compile_args, dict):
            for args in extension.extra_compile_args.values():
                args.append(define)
        else:
            extension.extra_compile_args.append(define)
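
For comparison, the patched version should roughly be doing this (my sketch of the idea behind pytorch/pytorch#6986, not the exact code):

    def _define_torch_extension_name(self, extension):
        # A dotted name like 'foo.bar' is not a valid identifier for the
        # PYBIND11_MODULE macro, so only the last component is used.
        names = extension.name.split('.')
        name = names[-1]
        define = '-DTORCH_EXTENSION_NAME={}'.format(name)
        if isinstance(extension.extra_compile_args, dict):
            for args in extension.extra_compile_args.values():
                args.append(define)
        else:
            extension.extra_compile_args.append(define)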
