
Support callables in scripted functions #17126

Open
slayton58 opened this issue Feb 14, 2019 · 3 comments
Assignees
Labels
jit-backlog oncall: jit Add this issue/PR to JIT oncall triage queue

Comments

@slayton58
Contributor

slayton58 commented Feb 14, 2019

🚀 Feature

Allow callables (notably other scripted functions or ScriptModules) to be passed to scripted functions. This currently fails with `cannot call a value`.

Motivation

I'm attempting to tease out fusions from larger codebases (in my case maskrcnn_benchmark) where some blocks of code may not be scriptable, but subsets of those blocks may be. As a concrete example, I have a module containing a loop over sets of inputs and modules, in which z = module(x) + y appears. module is itself a ScriptModule, and I'd like to add the + to a fused group that's already being generated within that ScriptModule. However, several things prevent the outermost module from being scripted. Pulling the relevant code out into a separate function like:

@torch.jit.script
def impl(m, x, y):
    return m(x) + y

should allow the desired fusion, but this is currently not supported.

Pitch

Allow callables (at least ScriptModules) to be passed as arguments to scripted functions and called.

Alternatives

My specific use-case disappears as the scripting supports a certain level of python, but I would expect there to be other use-cases.

Additional context

This is proxy code I wrote to make sure the error I was seeing wasn't coming from the larger application. It fails with `cannot call a value`.

import torch

class Bias(torch.jit.ScriptModule):
    def __init__(self, num_channels):
        super(Bias, self).__init__()

        self.bias = torch.nn.Parameter(torch.zeros(num_channels))

    @torch.jit.script_method
    def forward(self, x):
        return x + self.bias.reshape(1, -1, 1, 1)

@torch.jit.script
def fwd_impl(b, x, y):
    return b(x) + y

# Note: Not-scriptable in real app
class OuterModule(torch.nn.Module):
    def __init__(self, n):
        super(OuterModule, self).__init__()
  
        self.bias = Bias(n)

    def forward(self, x, y):
        return fwd_impl(self.bias, x, y)

n, c, h, w = 32, 4, 16, 16
x = torch.randn(n, c, h, w).cuda()
y = torch.randn(n, c, h, w).cuda()
m = OuterModule(c).cuda()

with torch.no_grad():
    z = m(x, y)

cc @suo

@eellison
Contributor

Thanks for the detailed feature request! This is on our radar but not sure it will happen anytime soon.

@eellison eellison added jit-triaged oncall: jit Add this issue/PR to JIT oncall triage queue labels Feb 14, 2019
@suo
Member

suo commented Feb 15, 2019

We are currently building support for user-defined types in TorchScript. So soon you will be able to implement a method on your class and pass it around in script functions.
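For reference, later PyTorch releases added `torch.jit.interface`, which covers this use case by letting a scripted function accept any module matching a declared signature. A minimal sketch, assuming a PyTorch version with interface support (the names `AddBias`, `Bias`, and `fwd_impl` are illustrative, not from this issue):

```python
import torch

# Declare the callable's signature as an interface. Any ScriptModule
# whose forward matches it can be passed where an AddBias is expected.
@torch.jit.interface
class AddBias(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass

@torch.jit.script
def fwd_impl(b: AddBias, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # The callable arrives as an interface-typed argument and can be invoked.
    return b.forward(x) + y

class Bias(torch.nn.Module):
    def __init__(self, num_channels: int):
        super().__init__()
        self.bias = torch.nn.Parameter(torch.zeros(num_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.bias.reshape(1, -1, 1, 1)

x = torch.randn(2, 4, 8, 8)
y = torch.randn(2, 4, 8, 8)
z = fwd_impl(torch.jit.script(Bias(4)), x, y)
```

This runs end to end on CPU; whether the `+` actually lands in a fused group still depends on the backend and device.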

@suo suo self-assigned this Feb 19, 2019
@suo suo removed the jit-triaged label Mar 11, 2019
@suo suo added the jit-backlog label Oct 3, 2019
@gkorland

gkorland commented Oct 5, 2020

@suo what is the status of the support for user-defined types?
