Force JIT to do type inference even when mypy annotated #39670

Open
ezyang opened this issue Jun 8, 2020 · 3 comments
Labels
module: typing Related to mypy type annotations oncall: jit Add this issue/PR to JIT oncall triage queue triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

@ezyang
Contributor

ezyang commented Jun 8, 2020

In #38211 I discovered a very interesting aspect of TorchScript's type system, which is that it is staged. Suppose you have a constructor like this:

        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

In a normal type system, you would have to declare weight/bias as Optional[Tensor], since their type is not known at compile time (it depends on the runtime value of affine). In TorchScript, however, the type of a module is not ascribed until after the module has been constructed, at which point the value of affine is known. So if you don't annotate weight/bias, TorchScript will correctly infer either None or Tensor as their type.
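To make the staging concrete, here is a pure-Python sketch of inferring attribute types from an already-constructed object. It does not use TorchScript: `Tensor` is a stand-in class, and `infer_attribute_types` is a hypothetical helper that simply records `type(value)` for each attribute, i.e. roughly the information that exists once `__init__` has run and the branch on `affine` has been taken.

```python
from typing import Dict


class Tensor:
    """Stand-in for torch.Tensor so the sketch runs without PyTorch."""


class Norm:
    def __init__(self, affine: bool) -> None:
        self.affine = affine
        if affine:
            self.weight = Tensor()
            self.bias = Tensor()
        else:
            self.weight = None
            self.bias = None


def infer_attribute_types(obj: object) -> Dict[str, type]:
    # Hypothetical: each attribute's type is read off the value it holds
    # after construction, so every instance gets its own answer.
    return {name: type(value) for name, value in vars(obj).items()}


print(infer_attribute_types(Norm(affine=True))["weight"].__name__)   # Tensor
print(infer_attribute_types(Norm(affine=False))["weight"].__name__)  # NoneType
```

Neither instance ever needs Optional: the `affine=True` module sees only Tensor and the `affine=False` module sees only None.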

Here's the problem: if you annotate weight/bias for mypy, TorchScript uses that annotation directly and no longer carries out its own type inference. This in turn can cause code that previously typechecked in TorchScript to stop typechecking, which means that adding the most general mypy annotations to a module can break TorchScript.

There should be some way to force TorchScript to do type inference, even if there is an annotation. The relevant code is pretty easy to change, so the big question is what the syntax for triggering this type inference should be.
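A pure-Python sketch of the behavior described above and of one possible opt-out. This is not TorchScript's actual code: `resolve_attribute_type` is a hypothetical resolver that, as described in this issue, uses a class annotation when one exists and otherwise falls back to the runtime type, and `__force_type_inference__` is a made-up class attribute standing in for whatever syntax is eventually chosen.

```python
from typing import Any, Optional, get_type_hints


class Tensor:
    """Stand-in for torch.Tensor so the sketch runs without PyTorch."""


class Norm:
    # Annotation written for mypy: the most general type across both branches.
    weight: Optional[Tensor]

    def __init__(self, affine: bool) -> None:
        self.weight = Tensor() if affine else None


class NormForcedInference(Norm):
    # Hypothetical opt-out: attributes listed here ignore their annotations.
    __force_type_inference__ = {"weight"}


def resolve_attribute_type(obj: object, name: str) -> Any:
    forced = getattr(type(obj), "__force_type_inference__", set())
    hints = get_type_hints(type(obj))
    if name in hints and name not in forced:
        # Annotation wins: the constructed value is never looked at.
        return hints[name]
    # Otherwise infer from the value the attribute holds after __init__.
    return type(getattr(obj, name))


print(resolve_attribute_type(Norm(affine=True), "weight"))
print(resolve_attribute_type(NormForcedInference(affine=True), "weight"))
```

With the annotation in place, the `affine=True` instance resolves to `Optional[Tensor]` even though its weight is always a tensor. Listing the attribute in the opt-out set restores the per-instance answer while keeping the annotation that mypy needs.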

cc @gmagogsfm @ezyang @malfet @rgommers @xuzhao9 @gramster @suo

@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Jun 8, 2020
@suo suo added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Jun 9, 2020
@walterddr
Contributor

I was wondering what the solution to this is. I'm not sure if it's related, but I tried to annotate the three conditionally registered parameters/buffers in #44535, and test_jit.py complained in a similar way.

@ezyang ezyang added the module: typing Related to mypy type annotations label Sep 14, 2020
@github-actions github-actions bot added this to Need triage in JIT Triage Sep 14, 2020
@wanchaol wanchaol moved this from Need triage to In discussion in JIT Triage Sep 21, 2020
@rgommers
Collaborator

the big question is what the syntax for triggering this type inference should be.

Let me make a suggestion to start the discussion. Two options come to mind:

(1) An object that can be applied to an individual annotation. Let's call it NoJit for now. Usage:

def forward(self, input: Tensor) -> NoJit[Tensor]:
    ...

Implementation along the lines of:

import typing
from typing import _FinalTypingBase  # private base class in Python 3.6's typing module (not present in 3.7+)

class _NoJit(_FinalTypingBase, _root=True):
    def __new__(cls, parameters=None, origin=None, *args, _root=False):
        # see an example in typing module for what to do here
        # will also need some other methods, such as `__eq__`, `__hash__`, `__repr__`
        ...

    def __getitem__(self, parameters):
        if typing.TYPE_CHECKING:
            return parameters
        else:
            # mark type as ignored
            ...

NoJit = _NoJit(_root=True)
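For comparison, a sketch of option (1) on current Python that avoids private typing internals. `typing.Annotated` (Python 3.9+, PEP 593) lets an annotation carry extra metadata that mypy ignores, so `Annotated[Tensor, NoJit]` still means `Tensor` to mypy. `NoJit` here is a hypothetical marker object, `jit_should_ignore` is a hypothetical check a compiler could run, and a free function stands in for the `forward` method. This is a different mechanism from the `_FinalTypingBase` subclass above, not an implementation of it.

```python
from typing import Annotated, get_args, get_origin, get_type_hints


class Tensor:
    """Stand-in for torch.Tensor so the sketch runs without PyTorch."""


class _NoJitMarker:
    """Hypothetical marker; mypy treats Annotated[T, marker] as plain T."""

    def __repr__(self) -> str:
        return "NoJit"


NoJit = _NoJitMarker()


def forward(input: Tensor) -> Annotated[Tensor, NoJit]:
    return input


def jit_should_ignore(func, name: str = "return") -> bool:
    # include_extras=True keeps the Annotated wrapper instead of stripping it.
    hint = get_type_hints(func, include_extras=True).get(name)
    # get_args(Annotated[T, m1, m2]) is (T, m1, m2); skip T, search the metadata.
    return get_origin(hint) is Annotated and NoJit in get_args(hint)[1:]
```

Here `jit_should_ignore(forward)` is True while `jit_should_ignore(forward, "input")` is False, and a plain `get_type_hints(forward)` strips the metadata and reports the return type as `Tensor`.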

(2) A type comment like #tscript: ignore. Usage:

def forward(self, input: Tensor
    ) -> Tensor:  #tscript: ignore
    ...

The advantages of (2) are that it's similar to #type: ignore, which makes mypy ignore something, and that it's probably easier to implement (disclaimer: that's based on my very limited experience fixing one bug in how the JIT deals with #type: ignore). The downside of (2) is that if you want to ignore only some of the annotations on a line, you have to break the line so that the comment applies only to the part you want, which can look ugly.

This construct shouldn't be needed very often, so the fact that (2) is a little uglier doesn't bother me. Assuming it is indeed the easier option to implement, maybe go with that?
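Comments like the one in option (2) can be located with the standard `tokenize` module, which reports each comment as a separate token together with its line number, so no full parser is needed. The `# tscript: ignore` spelling comes from the proposal, and `tscript_ignored_lines` is a hypothetical helper, not an existing PyTorch feature.

```python
import io
import re
import tokenize
from typing import Set

# Hypothetical comment syntax from the proposal; optional spaces are tolerated.
_TSCRIPT_IGNORE = re.compile(r"#\s*tscript:\s*ignore\b")


def tscript_ignored_lines(source: str) -> Set[int]:
    """Return the 1-based line numbers that carry a `# tscript: ignore` comment."""
    lines = set()
    for tok in tokenize.generate_tokens(io.StringIO(source).readline):
        if tok.type == tokenize.COMMENT and _TSCRIPT_IGNORE.match(tok.string):
            lines.add(tok.start[0])
    return lines


SOURCE = '''\
def forward(self, input: Tensor
    ) -> Tensor:  # tscript: ignore
    return input
'''

print(sorted(tscript_ignored_lines(SOURCE)))  # [2]
```

Because the match is per line, splitting the signature as in the usage example is what lets the comment cover the return annotation on line 2 while `input: Tensor` on line 1 stays unmarked.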

@albertz
Contributor

albertz commented Dec 30, 2021

Why not just do this?

import typing
from typing import Optional

from torch import Tensor
from torch.nn import Module

class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm"""

    ...
    if typing.TYPE_CHECKING:
        weight: Optional[Tensor]
        bias: Optional[Tensor]

    ...

Or does TorchScript also set TYPE_CHECKING?
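At runtime `typing.TYPE_CHECKING` is `False`, so annotations inside the `if` block are never executed and never reach the class's runtime annotations, while mypy still reads them. The sketch below checks only that interpreter-level fact, using a stand-in `Tensor` and a hypothetical `NormBase`; it does not show whether TorchScript consults anything besides the runtime annotations.

```python
import typing
from typing import Optional


class Tensor:
    """Stand-in for torch.Tensor so the sketch runs without PyTorch."""


class NormBase:
    if typing.TYPE_CHECKING:
        # Seen by mypy, but this branch never runs when the module is imported.
        weight: Optional[Tensor]
        bias: Optional[Tensor]

    plain: int  # an unconditional annotation, for contrast


hints = typing.get_type_hints(NormBase)
print(sorted(hints))
```

Only `plain` shows up in the runtime hints; `weight` and `bias` exist purely for the static checker.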

Projects
JIT Triage
  
In discussion