[JIT] Fix function schema subtype checking (pytorch#47706)
Summary:
Pull Request resolved: pytorch#47706

**Summary**
This commit fixes `FunctionSchema::isSubtypeOf` so that the subtyping rule it
implements for `FunctionSchema` instances is contravariant in argument
types and covariant in return type. At present, the rule is covariant in
argument types and contravariant in return type, which is not correct.

A brief but not rigorous explanation follows. Suppose there are two
`FunctionSchema`s, `M = (x: T) -> R` and `N = (x: U) -> S`. For `M <= N`
to be true (i.e. that `M` is a subtype of `N`), it must be true that `U
<= T` and `R <= S`. This generalizes to functions with multiple
arguments.
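The rule above can be illustrated with a small sketch in plain Python (the checker and the `Animal`/`Dog` classes are illustrative, not part of PyTorch), modeling `<=` on types with `issubclass`:

```python
# Minimal sketch of the function subtyping rule, not PyTorch's implementation.
# M = (x: T) -> R is a subtype of N = (x: U) -> S iff U <= T (arguments are
# contravariant) and R <= S (returns are covariant).

def is_function_subtype(m_args, m_ret, n_args, n_ret):
    """True if (m_args) -> m_ret is a subtype of (n_args) -> n_ret."""
    if len(m_args) != len(n_args):
        return False
    # Contravariant in arguments: each argument type of N must be a
    # subtype of the corresponding argument type of M.
    args_ok = all(issubclass(u, t) for t, u in zip(m_args, n_args))
    # Covariant in returns: M's return type must be a subtype of N's.
    ret_ok = issubclass(m_ret, n_ret)
    return args_ok and ret_ok

class Animal: ...
class Dog(Animal): ...

# (Animal) -> Dog can safely stand in for (Dog) -> Animal ...
print(is_function_subtype([Animal], Dog, [Dog], Animal))  # True
# ... but not the other way around.
print(is_function_subtype([Dog], Animal, [Animal], Dog))  # False
```

Intuitively, a function that accepts a more general argument and returns a more specific result can always be used where the more specific function is expected.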

**Test Plan**
This commit extends `TestModuleInterface.test_module_interface_subtype`
with two new test cases that test the contravariance of argument types
and covariance of return types in determining whether a `Module`
implements an interface type.

**Fixes**
This commit closes pytorch#47631.

Test Plan: Imported from OSS

Reviewed By: nikithamalgifb

Differential Revision: D24934099

Pulled By: SplitInfinity

fbshipit-source-id: bd07e7b47d2a3a56d676f2f572de09fb18ececd8
Meghan Lele authored and tugsbayasgalan committed Nov 16, 2020
1 parent 805d894 commit 3f09a64
Showing 2 changed files with 42 additions and 4 deletions.
6 changes: 3 additions & 3 deletions aten/src/ATen/core/function_schema_inl.h
```diff
@@ -288,12 +288,12 @@ inline bool FunctionSchema::isSubtypeOf(
     bool as_method,
     std::ostream* why_not) const {
   size_t start = as_method ? 1 : 0;
-  // functions are covariant in arguments but contravariant in returns
+  // functions are contravariant in arguments but covariant in returns
   return isSubtypeOfList(
-      ArrayRef<Argument>(arguments()).slice(start),
       ArrayRef<Argument>(rhs.arguments()).slice(start),
+      ArrayRef<Argument>(arguments()).slice(start),
       why_not) &&
-      isSubtypeOfList(rhs.returns(), returns(), why_not);
+      isSubtypeOfList(returns(), rhs.returns(), why_not);
 }
 
 } // namespace c10
```
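The direction of the two corrected checks can be mirrored in a short Python sketch (names are illustrative, not the C++ API; `<=` on element types is again modeled with `issubclass`):

```python
# Illustrative model of the corrected check, not the C++ implementation.

def is_subtype_of_list(subs, supers):
    # Each element of `subs` must be a subtype of the matching element of
    # `supers`; plain Python classes stand in for JIT types here.
    return len(subs) == len(supers) and all(
        issubclass(a, b) for a, b in zip(subs, supers))

def schema_is_subtype_of(self_args, self_rets, rhs_args, rhs_rets):
    # Checking `self <= rhs` flips the argument lists (contravariance)
    # but keeps the return lists in order (covariance), matching the
    # corrected call order in FunctionSchema::isSubtypeOf.
    return (is_subtype_of_list(rhs_args, self_args)
            and is_subtype_of_list(self_rets, rhs_rets))

class Animal: ...
class Dog(Animal): ...

# (Animal) -> Dog  <=  (Dog) -> Animal
print(schema_is_subtype_of([Animal], [Dog], [Dog], [Animal]))  # True
```

The previous code passed the lists in the opposite order on both calls, which is exactly the unsound covariant-arguments/contravariant-returns rule described in the summary.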
40 changes: 39 additions & 1 deletion test/jit/test_module_interface.py
```diff
@@ -1,7 +1,7 @@
 # flake8: noqa
 # TODO: enable linting check for this file
 
-from typing import List
+from typing import List, Any
 import torch
 import torch.nn as nn
 import os
@@ -202,6 +202,44 @@ def forward(self, x):
         with self.assertRaisesRegex(RuntimeError, "is not compatible with interface"):
             as_module_interface(scripted_wrong_mod)
 
+        # Check that interface implementations can be contravariant in argument types and covariant in return type.
+        global TensorToAny
+        @torch.jit.interface
+        class TensorToAny(nn.Module):
+            def forward(self, input: torch.Tensor) -> Any:
+                pass
+
+        @torch.jit.script
+        def as_tensor_to_any(x: TensorToAny) -> TensorToAny:
+            return x
+
+        global AnyToAny
+        @torch.jit.interface
+        class AnyToAny(nn.Module):
+            def forward(self, input: Any) -> Any:
+                pass
+
+        @torch.jit.script
+        def as_any_to_any(x: AnyToAny) -> AnyToAny:
+            return x
+
+        class TensorToAnyImplA(nn.Module):
+            def forward(self, input: Any) -> Any:
+                return input
+
+        class TensorToAnyImplB(nn.Module):
+            def forward(self, input: Any) -> torch.Tensor:
+                return torch.tensor([1])
+
+        class AnyToAnyImpl(nn.Module):
+            def forward(self, input: Any) -> torch.Tensor:
+                return torch.tensor([1])
+
+        as_tensor_to_any(torch.jit.script(TensorToAnyImplA()))
+        as_tensor_to_any(torch.jit.script(TensorToAnyImplB()))
+        as_any_to_any(torch.jit.script(AnyToAnyImpl()))
+
+
     def test_module_interface_inheritance(self):
         with self.assertRaisesRegex(RuntimeError, "does not support inheritance yet. Please directly"):
             @torch.jit.interface
```
