Skip to content

added stubs for jit tree views #156504

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,6 @@ python_version = 3.11
# Extension modules without stubs.
#

[mypy-torch._C._jit_tree_views]
ignore_missing_imports = True

[mypy-torch.for_onnx.onnx]
ignore_missing_imports = True

Expand Down
202 changes: 202 additions & 0 deletions torch/_C/_jit_tree_views.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
from typing import Any, Optional

# Defined in torch/csrc/jit/python/python_tree_views.cpp

class SourceRange:
def highlight(self) -> str: ...
@property
def start(self) -> int: ...
@property
def end(self) -> int: ...

class SourceRangeFactory:
def __init__(
self,
text: str,
filename: Any,
file_lineno: int,
leading_whitespace_chars: int,
) -> None: ...
def make_range(self, line: int, start_col: int, end_col: int) -> SourceRange: ...
def make_raw_range(self, start: int, end: int) -> SourceRange: ...
@property
def source(self) -> str: ...

class TreeView:
def range(self) -> SourceRange: ...
def dump(self) -> None: ...

class Ident(TreeView):
def __init__(self, *args: Any, **kwargs: Any) -> None: ...
@property
def name(self) -> str: ...

class Param(TreeView):
def __init__(self, type: Optional[Any], name: Ident, kwarg_only: bool) -> None: ...

class Attribute(TreeView):
def __init__(self, name: Ident, value: Any) -> None: ...

# Literals
def TrueLiteral(range: SourceRange) -> Any: ...
def FalseLiteral(range: SourceRange) -> Any: ...
def NoneLiteral(range: SourceRange) -> Any: ...

# Tree nodes
class Stmt(TreeView):
def __init__(self, thing: TreeView) -> None: ...

class Expr(TreeView): ...

class Def(TreeView):
def __init__(self, name: Ident, decl: Any, body: list[Stmt]) -> None: ...
def decl(self) -> Any: ...
def name(self) -> Ident: ...

class Property(TreeView):
def __init__(
self, r: SourceRange, name: Ident, getter: Def, setter: Optional[Def]
) -> None: ...
def name(self) -> Ident: ...
def getter_name(self) -> str: ...
def setter_name(self) -> Optional[Ident]: ...

class ClassDef(TreeView):
def __init__(
self, name: Ident, body: list[Stmt], props: list[Property], assigns: list[Any]
) -> None: ...

class Decl(TreeView):
def __init__(
self, r: SourceRange, params: list[Param], return_type: Optional[Expr]
) -> None: ...

class Delete(Stmt):
def __init__(self, range: SourceRange, targets: list[Expr]) -> None: ...

class WithItem(Expr):
def __init__(
self, range: SourceRange, target: Expr, var: Optional[Any]
) -> None: ...

class Assign(Stmt):
def __init__(
self, lhs: list[Expr], rhs: Expr, type: Optional[Expr] = None
) -> None: ...

class AugAssign(Stmt):
def __init__(self, lhs: Expr, kind_str: str, rhs: Expr) -> None: ...

class Return(Stmt):
def __init__(self, range: SourceRange, value: Optional[Expr]) -> None: ...

class Raise(Stmt):
def __init__(self, range: SourceRange, expr: Expr) -> None: ...

class Assert(Stmt):
def __init__(self, range: SourceRange, test: Expr, msg: Optional[Expr]) -> None: ...

class Pass(Stmt):
def __init__(self, range: SourceRange) -> None: ...

class Break(Stmt): ...
class Continue(Stmt): ...

class Dots(Expr, TreeView):
def __init__(self, range: SourceRange) -> None: ...

class If(Stmt):
def __init__(
self,
range: SourceRange,
cond: Expr,
true_branch: list[Stmt],
false_branch: list[Stmt],
) -> None: ...

class While(Stmt):
def __init__(self, range: SourceRange, cond: Expr, body: list[Stmt]) -> None: ...

class With(Stmt):
def __init__(
self, range: SourceRange, targets: list[WithItem], body: list[Stmt]
) -> None: ...

class For(Stmt):
def __init__(
self,
range: SourceRange,
targets: list[Expr],
itrs: list[Expr],
body: list[Stmt],
) -> None: ...

class ExprStmt(Stmt):
def __init__(self, expr: Expr) -> None: ...

class Var(Expr):
def __init__(self, name: Ident) -> None: ...
@property
def name(self) -> str: ...

class BinOp(Expr):
def __init__(self, kind: str, lhs: Expr, rhs: Expr) -> None: ...

class UnaryOp(Expr):
def __init__(self, range: SourceRange, kind: str, expr: Expr) -> None: ...

class Const(Expr):
def __init__(self, range: SourceRange, value: str) -> None: ...

class StringLiteral(Expr):
def __init__(self, range: SourceRange, value: str) -> None: ...

class Apply(Expr):
def __init__(
self, expr: Expr, args: list[Expr], kwargs: list[Attribute]
) -> None: ...

class Select(Expr):
def __init__(self, expr: Expr, field: Ident) -> None: ...

class TernaryIf(Expr):
def __init__(self, cond: Expr, true_expr: Expr, false_expr: Expr) -> None: ...

class ListComp(Expr):
def __init__(
self, range: SourceRange, elt: Expr, target: Expr, iter: Expr
) -> None: ...

class DictComp(Expr):
def __init__(
self, range: SourceRange, key: Expr, value: Expr, target: Expr, iter: Expr
) -> None: ...

class ListLiteral(Expr):
def __init__(self, range: SourceRange, args: list[Expr]) -> None: ...

class TupleLiteral(Expr):
def __init__(self, range: SourceRange, args: list[Expr]) -> None: ...

class DictLiteral(Expr):
def __init__(
self, range: SourceRange, keys: list[Expr], values: list[Expr]
) -> None: ...

class Subscript(Expr):
def __init__(self, base: Expr, subscript_exprs: list[Expr]) -> None: ...

class SliceExpr(Expr):
def __init__(
self,
range: SourceRange,
lower: Optional[Expr],
upper: Optional[Expr],
step: Optional[Expr],
) -> None: ...

class Starred(Expr):
def __init__(self, range: SourceRange, expr: Expr) -> None: ...

class EmptyTypeAnnotation(TreeView):
def __init__(self, range: SourceRange) -> None: ...
10 changes: 7 additions & 3 deletions torch/jit/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,11 @@ def build_def(ctx, py_def, type_line, def_name, self_name=None, pdt_arg_types=No
is_method = self_name is not None
if type_line is not None:
type_comment_decl = torch._C.parse_type_comment(type_line)
decl = torch._C.merge_type_from_type_comment(decl, type_comment_decl, is_method)
decl = torch._C.merge_type_from_type_comment(
decl, # type: ignore[arg-type]
type_comment_decl,
is_method, # type: ignore[assignment]
)

return Def(Ident(r, def_name), decl, build_stmts(ctx, body))

Expand Down Expand Up @@ -1055,12 +1059,12 @@ def build_Compare(ctx, expr):
in_expr = BinOp("in", lhs, rhs)
cmp_expr = UnaryOp(r, "not", in_expr)
else:
cmp_expr = BinOp(op_token, lhs, rhs)
cmp_expr = BinOp(op_token, lhs, rhs) # type: ignore[assignment]

if result is None:
result = cmp_expr
else:
result = BinOp("and", result, cmp_expr)
result = BinOp("and", result, cmp_expr) # type: ignore[assignment]
return result

@staticmethod
Expand Down
Loading