
Documentation misleading for torch.jit.trace: not all torch.jit.trace will return a TracedModule object #21857

@haoransh

Description


📚 Documentation

The first code snippet at https://pytorch.org/docs/stable/jit.html#torch.jit.ScriptModule shows an example of tracing the following function:

def foo(x, y):
    return 2 * x + y

Immediately after this snippet, the documentation says "Tracing a function will construct a ScriptModule." However, this is not true. For a plain Python function like the one above, torch.jit.trace returns a torch._C.Function object instead of a ScriptModule.

This is easy to demonstrate with the following snippet:

import torch
import torchvision
def foo(x, y):
    return 2 * x + y
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))
print(type(traced_foo))
print(isinstance(traced_foo, torch.jit.ScriptModule))
print(isinstance(traced_foo, torch.jit.TracedModule))

traced_net = torch.jit.trace(torchvision.models.resnet18(),
                             torch.rand(1, 3, 224, 224))
print(type(traced_net))
print(isinstance(traced_net, torch.jit.TracedModule))
print(isinstance(traced_net, torch.jit.ScriptModule))

The output is

<class 'torch._C.Function'>
False
False
<class 'torch.jit.TopLevelTracedModule'>
True
True

It is not intuitive that torch.jit.trace returns objects of different classes depending on whether it is given a plain function or a module. I think it would be better if it also returned a torch.jit.TracedModule for a plain Python function, but that is clearly not the case in the current version of PyTorch.
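Until the behavior or the documentation changes, one possible workaround, offered only as a sketch, is to wrap the function in a small nn.Module and trace the module instead. The wrapper class name FooModule is hypothetical, and the printed type names differ between PyTorch versions (for example, a traced function is reported as Function in 1.1.0 and as ScriptFunction in later releases), so treat the printed names as illustrative rather than fixed.

```python
import torch
import torch.nn as nn


def foo(x, y):
    return 2 * x + y


# Hypothetical wrapper: a module whose forward() simply delegates to foo.
class FooModule(nn.Module):
    def forward(self, x, y):
        return foo(x, y)


example_inputs = (torch.rand(3), torch.rand(3))

# Tracing the bare function yields a function object, not a module.
traced_fn = torch.jit.trace(foo, example_inputs)

# Tracing the wrapper yields a traced module, which is a ScriptModule subclass.
traced_mod = torch.jit.trace(FooModule(), example_inputs)

print(type(traced_fn).__name__, isinstance(traced_fn, torch.jit.ScriptModule))
print(type(traced_mod).__name__, isinstance(traced_mod, torch.jit.ScriptModule))

# Both traced objects compute the same result on the same inputs.
x, y = torch.ones(3), torch.zeros(3)
print(torch.equal(traced_fn(x, y), traced_mod(x, y)))
```

The wrapper adds no parameters or logic, so the traced graph is the same computation; the only difference is that the result is a module, which can be saved with torch.jit.save and nested inside other modules.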

BTW, my PyTorch version is 1.1.0.

Labels: oncall: jit, triaged
