
[FX] intermediate types of empty lists/dicts not preserved during torch.fx tracing #49935

Open
esqu1 opened this issue Dec 29, 2020 · 5 comments
Assignees
Labels
oncall: jit · TSRootCause:DefaultTypes · TSUsability
Projects

Comments


esqu1 commented Dec 29, 2020

🐛 Bug

On occasion we may want to pass, for example, an empty list to a leaf function. In TorchScript, such empty lists are assumed to have type List[torch.Tensor], following TorchScript's Tensor type defaulting behavior. To get around this defaulting, we can either:

  1. Annotate the type of the variable: var1: List[str] = []
  2. Use torch.jit.annotate to notify the TorchScript compiler: var1 = torch.jit.annotate(List[str], [])

However, during torch.fx symbolic tracing, neither of these methods annotates the variables in the resulting GraphModule. Thus, TorchScript will still assume that these are List[Tensor] and may fail to compile the resulting module.
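A minimal, torch-free sketch of why the annotation is lost (the `annotate`, `Proxy`, and `traced_identity` names here are illustrative stand-ins, not the real FX internals): `torch.jit.annotate(T, value)` is a runtime no-op that simply returns `value`, so a tracer that records only the values flowing through calls never sees the type:

```python
class Proxy:
    """Stand-in for an FX Proxy: represents a traced tensor input."""
    def __init__(self, name):
        self.name = name

def annotate(typ, value):
    # torch.jit.annotate is a runtime no-op: it just returns `value`,
    # so by the time the tracer sees the argument, the type is gone.
    return value

recorded_args = []

def traced_identity(x, t):
    recorded_args.append(x)  # the tracer records only the plain value
    return x

out = traced_identity(annotate("List[str]", []), Proxy("x"))
print(recorded_args)  # [[]]  -- no trace of List[str] survives
```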

To Reproduce

Suppose that my_ops::identity is a custom op registered with the following implementation:

c10::List<std::string> identity(c10::List<std::string> x, at::Tensor& t) {
  return x;
}

TORCH_LIBRARY(my_ops, m) {
  m.def("identity", &identity);
}

Then use this in a module and trace it:

import torch
from typing import List
from torch.fx import symbolic_trace

class TestModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.ops.my_ops.identity(
            torch.jit.annotate(List[str], []),
            x
        )
        return out

graph_module = symbolic_trace(TestModule())
print(graph_module)

yielding:

import torch
def forward(self, x : torch.Tensor):
    identity = torch.ops.my_ops.identity([], x)
    return identity

Finally, run the graph module through TorchScript via torch.jit.script(graph_module).

Expected behavior

TorchScript should ideally compile fine.

Actual behavior

The TorchScript compiler complains about the empty list:

my_ops::identity(str[] x, Tensor t) -> str[]:
Expected a value of type 'List[str]' for argument 'x' but instead found type 'List[Tensor]'.
Empty lists default to List[Tensor]. Add a variable annotation to the assignment to create an empty list of another type (torch.jit.annotate(List[T, []]) where T is the type of elements in the list for Python 2)
:
import torch
def forward(self, x : torch.Tensor):
    identity = torch.ops.my_ops.identity([], x)
               ~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    return identity


cc @gmagogsfm
@ansley ansley self-assigned this Jan 5, 2021
@ansley ansley added the oncall: jit label and removed the fx label Jan 12, 2021
@github-actions github-actions bot added this to Need triage in JIT Triage Jan 12, 2021

ansley commented Jan 12, 2021

After doing a fair bit of research, it looks like this isn't something we can change from the FX side. However, there is a possible fix on the frontend: we need to make it so that the frontend doesn't automatically type empty lists as List[Tensor] and empty dicts as Dict[str, Tensor]. Let's discuss whether this can (or should) happen now, as it would help unblock some FX users. I believe @penguinwu mentioned this as a project as well.

gmagogsfm (Contributor) commented

Could you fill in some details on why this isn't viable by changing the FX side? I think James mentioned something about preserving annotations; is that infeasible?


ansley commented Jan 19, 2021

@gmagogsfm It's a lot harder than you'd think to preserve existing annotations in this case. In Python, functions, methods, modules, and class objects store some of their annotations, which can be retrieved via __annotations__ or typing.get_type_hints. However, functions/methods only store the annotations for their args and return value. Storing local variable annotations in the function scope was rejected because it's such an expensive operation (link).
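This is easy to check in plain Python: only parameter and return annotations survive on the function object, while local variable annotations (per PEP 526) are not stored anywhere retrievable:

```python
from typing import List

def f(x: int) -> str:
    y: List[str] = []  # local variable annotation: discarded at runtime
    return str(x)

# Only the parameter and return annotations are kept on the function object;
# there is no entry for the local variable 'y'.
print(f.__annotations__)
```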

The first thing I thought of doing was walking back up the stack, getting the calling frame, and examining the context of the code responsible for that frame. I ran into a lot of problems with this, though. It required me to make some uncomfortable assumptions about the code.
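For reference, the kind of frame inspection being described looks roughly like this (the `caller_info` and `some_traced_function` names are hypothetical); the fragile part is going from the frame back to reliable source-level type information:

```python
import inspect

def caller_info():
    # Walk one frame up the stack to the frame of whoever called us.
    frame = inspect.currentframe().f_back
    # We can recover the calling function's name and line number...
    # ...but mapping that back to the annotation on a particular
    # assignment requires re-parsing source, which is where the
    # uncomfortable assumptions come in.
    return frame.f_code.co_name, frame.f_lineno

def some_traced_function():
    return caller_info()

name, lineno = some_traced_function()
print(name)  # some_traced_function
```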

I eventually came up with a solution that involved AST rewrites and a custom tracer. (I can explain my design more if you're interested.) Unfortunately, it had an awful time complexity. James and Zach discussed the issue, and we eventually came to the conclusion that this is not a feature that we should pursue from the FX side.
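As a rough illustration of the AST side (not the actual design discussed above), one can locate `torch.jit.annotate` calls in source and recover the annotated type; doing this for every traced function and threading the result through a custom tracer is what blows up the complexity:

```python
import ast

src = """
def forward(self, x):
    out = torch.ops.my_ops.identity(torch.jit.annotate(List[str], []), x)
    return out
"""

class AnnotateFinder(ast.NodeVisitor):
    """Collect the type expression from each *.annotate(...) call."""
    def __init__(self):
        self.found = []

    def visit_Call(self, node):
        if isinstance(node.func, ast.Attribute) and node.func.attr == "annotate":
            # First argument of torch.jit.annotate is the type expression.
            self.found.append(ast.unparse(node.args[0]))
        self.generic_visit(node)

finder = AnnotateFinder()
finder.visit(ast.parse(src))
print(finder.found)  # ['List[str]']
```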

SplitInfinity commented

A potential solution for this is to improve JIT type inference to make smarter decisions about types of lists and dicts.

gmagogsfm (Contributor) commented

If I remember correctly, @ansley is working on this. Should it be moved out of "in discussion"?

@ansley ansley moved this from In discussion to In progress in JIT Triage Feb 24, 2021