
Trace more than one function in a module #19070

Closed
zdevito opened this issue Apr 9, 2019 · 4 comments
Assignees: Krovatkin
Labels: feature (A request for a proper, new feature), oncall: jit (JIT oncall triage queue)

Comments

@zdevito (Contributor) commented Apr 9, 2019

🚀 Feature

Motivation

Pitch

Users are running into issues when they trace a non-forward method on a module: they get errors because the weights used in the trace are either treated as constants, or they are autograd-recording tensors that the tracer refuses to handle. The trace API already knows how to trace a forward method while correctly capturing weights. It would be easy to extend the API so that, in the general case, you can trace multiple methods of a module to create a single ScriptModule with multiple methods.

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3)

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):  # I want to trace this thing
        return (weight * self.conv.weight).sum()

n = Net()

# example inputs (shapes are illustrative; the forward input must match Conv2d(3, 3, 3))
example_forward_input = torch.rand(1, 3, 8, 8)
example_weight = torch.rand(1, 1)

traced_forward = torch.jit.trace(n, example_forward_input)

traced_weight_kernel_sum = torch.jit.trace(n.weighted_kernel_sum, example_weight)
# current: this fails because the weights require grad, or they get silently captured as constants

# proposed generic API:
fully_traced = torch.jit.trace(n, { 'forward' : example_forward_input, 'weighted_kernel_sum': example_weight})
# fully_traced has both forward and weighted_kernel_sum present

# syntax sugar for the old behavior:
m = torch.jit.trace(n, example_forward_input)  # --> torch.jit.trace(n, {'forward': example_forward_input})
# syntax sugar for tracing a single method:
m = torch.jit.trace(n.weighted_kernel_sum, example_weight)  # --> torch.jit.trace(n.weighted_kernel_sum.__self__, {n.weighted_kernel_sum.__name__: example_weight})
@eellison added the oncall: jit, feature, and triage review labels on Apr 9, 2019
@Krovatkin self-assigned this on Apr 17, 2019
@zdevito added this to the 1.1 milestone on Apr 17, 2019
@zdevito (Contributor, Author) commented Apr 17, 2019

Adding 1.1 milestone. The issues causing this were introduced in 1.0, revealing a lot of cases where people were inadvertently (but successfully) capturing parameters as constants. Given the number of reports, we need to make sure this is fixed for the 1.1 release.

@soumith (Member) commented May 1, 2019

Has this been fixed?

@Krovatkin (Contributor) commented

@soumith yes, fixed in #19905
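
For reference, current PyTorch exposes this capability as torch.jit.trace_module, which takes a module plus a dict mapping method names to example inputs and returns a single ScriptModule containing all of those methods. A minimal self-contained sketch (the Net class follows the example above; the input shapes are illustrative only):

import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3)

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):
        return (weight * self.conv.weight).sum()


n = Net()
example_forward_input = torch.rand(1, 3, 8, 8)
example_weight = torch.rand(1, 1)

# Trace both methods into one ScriptModule; the conv parameters are captured
# as module attributes rather than baked in as constants.
fully_traced = torch.jit.trace_module(
    n, {'forward': example_forward_input, 'weighted_kernel_sum': example_weight}
)

out = fully_traced(example_forward_input)             # traced forward
s = fully_traced.weighted_kernel_sum(example_weight)  # traced non-forward method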

@liyuanyaun commented
But how do you call the traced module's functions from libtorch?
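
One common route, sketched here under the assumption that the module was traced with torch.jit.trace_module as in the snippet above (the file name net.pt is illustrative): save the multi-method ScriptModule from Python, load it in libtorch with torch::jit::load, and invoke the non-forward method by name. The C++ side is shown only as comments because the exact types and signatures depend on the libtorch version you build against:

# Python side: serialize the ScriptModule produced by torch.jit.trace_module above.
fully_traced.save("net.pt")

# libtorch (C++) side, roughly:
#   torch::jit::script::Module module = torch::jit::load("net.pt");
#   auto out = module.forward({torch::rand({1, 3, 8, 8})});
#   auto s = module.get_method("weighted_kernel_sum")({torch::rand({1, 1})});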
