🚀 Feature
Allow callables (notably other scripted functions or ScriptModules) to be passed to scripted functions. Currently this fails with "cannot call a value".
Motivation
I'm attempting to tease out fusions from larger codebases (in my case maskrcnn_benchmark) where some blocks of code may not be scriptable, but subsets of those blocks may be. As a concrete example, I have a module which contains a loop over sets of inputs and modules, in which z = module(x) + y appears. module is itself a ScriptModule, and I'd like to be able to add the + to a fused group that's already being generated within that ScriptModule. However, several things prevent the outermost module from being scripted. Pulling the relevant code out into a separate function like:
    @torch.jit.script
    def impl(m, x, y):
        return m(x) + y

should allow the desired fusion, but this is currently not supported.
Pitch
Allow callables (at least ScriptModules) to be passed as arguments to scripted functions and called.
Alternatives
My specific use case would disappear once scripting supports enough Python to script the outermost module, but I would expect there to be other use cases.
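One possible workaround, sketched here rather than taken from the issue: instead of passing the ScriptModule as an argument, wrap the scriptable subset in a small module that owns it as a submodule, and script only that wrapper. This assumes a PyTorch version in which torch.jit.script accepts nn.Module instances and compiles their submodules recursively, which is not the ScriptModule/script_method API used in the snippets in this issue. BiasAdd is a hypothetical name, and whether the fuser actually groups the + with the ops inside Bias is not verified here.

```python
import torch


class Bias(torch.nn.Module):
    def __init__(self, num_channels: int):
        super().__init__()
        self.bias = torch.nn.Parameter(torch.zeros(num_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.bias.reshape(1, -1, 1, 1)


# Hypothetical wrapper: holds the module as an attribute instead of
# receiving it as an argument, so the call and the add live in one graph.
class BiasAdd(torch.nn.Module):
    def __init__(self, num_channels: int):
        super().__init__()
        self.bias = Bias(num_channels)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return self.bias(x) + y


# The outer module stays a plain (unscripted) nn.Module.
class OuterModule(torch.nn.Module):
    def __init__(self, n: int):
        super().__init__()
        # Assumption: torch.jit.script compiles BiasAdd and its submodule.
        self.bias_add = torch.jit.script(BiasAdd(n))

    def forward(self, x, y):
        return self.bias_add(x, y)


n, c, h, w = 2, 4, 8, 8
x = torch.randn(n, c, h, w)
y = torch.randn(n, c, h, w)
m = OuterModule(c)
with torch.no_grad():
    z = m(x, y)
```

The trade-off is one wrapper module per call pattern, which is more boilerplate than a free function taking a callable, so it does not replace the feature requested here.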
Additional context
This is a minimal repro I wrote to make sure that the error I was seeing wasn't coming from the larger application. It fails with "cannot call a value".
    import torch

    class Bias(torch.jit.ScriptModule):
        def __init__(self, num_channels):
            super(Bias, self).__init__()
            self.bias = torch.nn.Parameter(torch.zeros(num_channels))

        @torch.jit.script_method
        def forward(self, x):
            return x + self.bias.reshape(1, -1, 1, 1)

    @torch.jit.script
    def fwd_impl(b, x, y):
        return b(x) + y

    # Note: Not-scriptable in real app
    class OuterModule(torch.nn.Module):
        def __init__(self, n):
            super(OuterModule, self).__init__()
            self.bias = Bias(n)

        def forward(self, x, y):
            return fwd_impl(self.bias, x, y)

    n, c, h, w = 32, 4, 16, 16
    x = torch.randn(n, c, h, w).cuda()
    y = torch.randn(n, c, h, w).cuda()
    m = OuterModule(c).cuda()
    with torch.no_grad():
        z = m(x, y)
We are currently building support for user-defined types in TorchScript, so you will soon be able to implement a method on your class and pass it around in script functions.
cc @suo