# gen_schema_utils.py
from typing import Any, Optional, Union

from torchgen.model import (
    Annotation,
    Argument,
    Arguments,
    BaseOperatorName,
    BaseTy,
    BaseType,
    CustomClassType,
    FunctionSchema,
    ListType,
    OperatorName,
    Return,
)


# Note: These utilities aren't actually used in torchgen; they exist to generate a schema
# from real arguments. For example, they are used to generate a HigherOrderOperator's schema,
# since schemas can vary across different instances of the same HOP.
class TypeGen:
    convert_to_base_ty = {
        int: BaseTy.int,
        float: BaseTy.float,
        str: BaseTy.str,
        bool: BaseTy.bool,
    }

    @staticmethod
    def from_example(obj: Any) -> Union[BaseType, ListType, CustomClassType]:
        import torch

        if isinstance(obj, torch.fx.GraphModule):
            return BaseType(BaseTy.GraphModule)
        elif isinstance(obj, torch.Tensor):
            return BaseType(BaseTy.Tensor)
        elif isinstance(obj, torch.SymInt):
            return BaseType(BaseTy.SymInt)
        elif isinstance(obj, torch.SymBool):
            return BaseType(BaseTy.SymBool)
        elif isinstance(obj, torch.ScriptObject):
            return CustomClassType(obj._type().name())  # type: ignore[attr-defined]
        elif isinstance(obj, (list, tuple)):
            assert len(obj) > 0
            all_base_tys = [TypeGen.from_example(x) for x in obj]
            if len(set(all_base_tys)) > 1:
                raise RuntimeError(
                    f"Cannot generate schema for a sequence of args of heterogeneous types: {all_base_tys}. "
                    "Consider unpacking the argument and giving proper names to them if possible "
                    "instead of using *args."
                )
            return ListType(all_base_tys[0], len(obj))
        tp = type(obj)
        if tp not in TypeGen.convert_to_base_ty:
            raise RuntimeError(f"unsupported type {tp}")
        return BaseType(TypeGen.convert_to_base_ty[tp])


class ReturnGen:
    @staticmethod
    def from_example(
        name: Optional[str], obj: Any, annotation: Optional[Annotation]
    ) -> Return:
        return Return(name, TypeGen.from_example(obj), annotation)


class ArgumentGen:
    @staticmethod
    def from_example(
        name: str, obj: Any, default: Optional[str], annotation: Optional[Annotation]
    ) -> Argument:
        return Argument(
            name, TypeGen.from_example(obj), default=default, annotation=annotation
        )


class FunctionSchemaGen:
    @staticmethod
    def from_example(
        op_name: str,
        example_inputs: tuple[tuple[str, Any], ...],
        example_outputs: tuple[Any, ...],
    ) -> FunctionSchema:
        args = []
        for name, inp in example_inputs:
            args.append(ArgumentGen.from_example(name, inp, None, None))

        # Ignore annotations and other attributes for now; we could add more when needed.
        arguments = Arguments(
            tuple(), None, tuple(args), tuple(), None, tuple(), tuple()
        )
        returns = tuple(
            ReturnGen.from_example(None, out, None) for out in example_outputs
        )
        op_name = OperatorName(BaseOperatorName(op_name, False, False, False), "")
        return FunctionSchema(op_name, arguments, returns)
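

# A usage sketch, not part of the original torchgen utilities: derive a schema from
# concrete example values, much like a HigherOrderOperator instance would. The operator
# name "my_hop" and the example tensors/values below are purely illustrative.
if __name__ == "__main__":
    import torch

    example_schema = FunctionSchemaGen.from_example(
        "my_hop",
        (("x", torch.randn(3)), ("n", 2)),
        (torch.randn(3),),
    )
    # Expected to print something along the lines of:
    #   my_hop(Tensor x, int n) -> Tensor
    print(example_schema)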