[WIP] Aligning functional/pil types with functional/tensor #4323

Open · wants to merge 26 commits into base: main
Conversation

@oke-aditya (Contributor)

Tries to solve #4282

I think it won't be possible until Union is fully supported by TorchScript.
I will update this as I make progress.

@oke-aditya (Contributor, Author)

I'm trying to use Union typing to align the annotations, but JIT still doesn't seem to be happy.

@oke-aditya (Contributor, Author)

Hi @datumbox @pmeier, how do we go ahead with this?

@oke-aditya oke-aditya marked this pull request as ready for review September 22, 2021 19:34
@pmeier (Collaborator) commented Sep 23, 2021

It seems torchscript cannot deal with Tuple[float, ...]. To equalize the annotations we would thus need to remove it. Thoughts? cc @datumbox
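
A minimal repro of what I'm seeing (the exact error type and message depend on the torch version, hence the broad except):

from typing import Tuple

from torch import jit


def foo(fill: Tuple[float, ...]) -> float:
    # Variadic tuple annotation: fine for mypy, but not in
    # TorchScript's supported-type table.
    return fill[0]


try:
    jit.script(foo)
except Exception as error:
    print(f"Scripting of foo failed with\n{error}")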

@oke-aditya (Contributor, Author) commented Sep 23, 2021

Per the docs, TorchScript can deal with Tuple[T0, T1, ..., TN], i.e. a tuple containing subtypes T0, T1, etc. (e.g. Tuple[Tensor, Tensor]).
I think the tuple has to have a fixed, known length. Should I refactor accordingly? See the sketch after the link below.

https://pytorch.org/docs/master/jit_language_reference.html#supported-type
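
For example, a fixed-arity annotation scripts fine (a sketch with a hypothetical three-channel fill):

from typing import Tuple

from torch import jit


@jit.script
def fill_mean(fill: Tuple[float, float, float]) -> float:
    # Fixed-length tuples are listed among the supported types.
    return (fill[0] + fill[1] + fill[2]) / 3.0


print(fill_mean((0.5, 0.5, 0.5)))  # 0.5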

@pmeier (Collaborator) commented Sep 23, 2021

The problem is that the length of fill needs to match the number of channels of the input. Since this is not fixed, we cannot type it with a fixed number of items.
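
So the natural fallback is List[float], which keeps the length dynamic and moves the channel check to runtime. A sketch (names are hypothetical, not the actual torchvision code):

from typing import List

import torch
from torch import jit


@jit.script
def _check_fill(fill: List[float], img: torch.Tensor) -> List[float]:
    # The invariant "one fill value per channel" can no longer be
    # expressed in the annotation, so validate it at runtime instead.
    if len(fill) != img.size(0):
        raise ValueError("length of fill must match the number of channels")
    return fill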

@datumbox (Contributor)

It seems torchscript cannot deal with Tuple[float, ...]. To equalize the annotations we would thus need to remove it. Thoughts? cc @datumbox

Ah, the JIT limitations... Using Tuple[T, ...] is useful when we have default arguments because tuples are immutable. Replacing them with lists requires extra code/care. But if they are not supported by JIT then we have no option but to avoid using them. :(
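
The extra care being the classic mutable-default pitfall, sketched here in plain Python (function names hypothetical):

def pad_bad(fill=[0.0]):  # a single list object shared by every call
    fill.append(1.0)
    return fill


def pad_good(fill=(0.0,)):  # immutable default, so no shared state
    return list(fill) + [1.0]


print(pad_bad())   # [0.0, 1.0]
print(pad_bad())   # [0.0, 1.0, 1.0] -- state leaked from the first call
print(pad_good())  # [0.0, 1.0]
print(pad_good())  # [0.0, 1.0]

In scripted code the common dodge is an Optional[List[float]] = None default that is expanded inside the function.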

@pmeier (Collaborator) commented Sep 27, 2021

Hm, it seems we are missing something about the Union support:

from typing import Union, List

import torch
from torch import jit


def try_script(fn):
    name = fn.__name__
    try:
        jit.script(fn)
    except RuntimeError as error:
        print(f"Scripting of {name} failed with\n{error}")
    else:
        print(f"Scripting of {name} was successful!")
    print("#" * 80)


def foo(data: float) -> torch.Tensor:
    return torch.tensor(data)


def bar(data: List[float]) -> torch.Tensor:
    return torch.tensor(data)


def baz(data: Union[float, List[float]]) -> torch.Tensor:
    return torch.tensor(data)


try_script(foo)
try_script(bar)
try_script(baz)

Output:

Scripting of foo was successful!
################################################################################
Scripting of bar was successful!
################################################################################
Scripting of baz failed with

Arguments for call are not valid.
The following variants are available:
  
  aten::tensor.float(float t, *, int? dtype=None, Device? device=None, bool requires_grad=False) -> (Tensor):
  Expected a value of type 'float' for argument 't' but instead found type 'Union[List[float], float]'.
  
  aten::tensor.int(int t, *, int? dtype=None, Device? device=None, bool requires_grad=False) -> (Tensor):
  Expected a value of type 'int' for argument 't' but instead found type 'Union[List[float], float]'.
  
  aten::tensor.bool(bool t, *, int? dtype=None, Device? device=None, bool requires_grad=False) -> (Tensor):
  Expected a value of type 'bool' for argument 't' but instead found type 'Union[List[float], float]'.
  
  aten::tensor.complex(complex t, *, int? dtype=None, Device? device=None, bool requires_grad=False) -> (Tensor):
  Expected a value of type 'complex' for argument 't' but instead found type 'Union[List[float], float]'.
  
  aten::tensor(t[] data, *, int? dtype=None, Device? device=None, bool requires_grad=False) -> (Tensor):
  Could not match type Union[List[float], float] to List[t] in argument 'data': Cannot match List[t] to Union[List[float], float].

The original call is:
  File "/home/philip/git/pytorch/torchvision/main.py", line 25
def baz(data: Union[float, List[float]]) -> torch.Tensor:
    return torch.tensor(data)
           ~~~~~~~~~~~~ <--- HERE

################################################################################

Thoughts? cc @ansley

@ansley commented Sep 27, 2021

This is a limitation of our statically-typed language: T1 <: Union[T1, T2] and T2 <: Union[T1, T2], but we can't refine a value typed to be Union[T1, T2] into, say, T1 without either a) additional type inference information, or b) runtime values. Because our compiler frontend uses schema matching to uniquely identify a function overload (what you see above under "The following variants are available") during IR emission, we can't necessarily analyze the whole program to determine which variant we should use.

Type refinement can be triggered with an assert statement like assert isinstance(x, T1) or a conditional like if isinstance(x, T1).

...Which all sounds reasonable enough, but, stupidly, you end up writing code like this to be type safe:

import torch
from typing import Union, List

@torch.jit.script
def fn(data: Union[float, List[float]]) -> torch.Tensor:
    if isinstance(data, float):
        # Matches to the `torch.tensor` overload that takes a float arg
        return torch.tensor(data)
    else:
        # Type inferred to be the `torch.tensor` overload that takes a
        # `List[float]` since it's the one remaining Union type we
        # haven't used in all in-scope branches. This is the best
        # type inference I could safely do
        return torch.tensor(data)
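
Both branches then dispatch to a valid overload (expected output shown as comments):

print(fn(2.0))         # tensor(2.)
print(fn([1.0, 2.0]))  # tensor([1., 2.])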

I would like to do some sort of follow-up where we take values typed as Union or Any and insert automatic type refinement branches, but it would be a huge project. I've been idly thinking on how to resolve the perf issues associated with adding these extra branches that the user might not need (probably the hardest problem outside of the type inference itself), but it's not p0 since there is this (very, very stupid-looking) workaround.

@pmeier (Collaborator) commented Sep 28, 2021

@ansley Thanks a lot for the additional information. Just so I've got it right: although a function like torch.tensor can take a Union of inputs and always returns the same type, we cannot invoke it with a parameter of said Union, but have to invoke it for each constituent type individually?

@ansley commented Sep 28, 2021

@pmeier Yeah, in the worst case (i.e. when there isn't additional type information), we rely on the user to annotate the code so that we can uniquely identify a function signature at compile time. We can't delay figuring out which function to call until runtime.

Maybe I'm just being pedantic, but I genuinely hope it will help make the problem clearer for me to say: torch.tensor can't take a Union of inputs. Instead, there are multiple overloads of torch.tensor, each of which takes a unique type. The set of all those types is indeed a Union of types, but that doesn't mean we have a torch.tensor function that takes that Union.

You can see examples of the times we can automatically refine types in the test suite for Union here.

@oke-aditya mentioned this pull request Oct 27, 2022