Skip to content

Commit

Permalink
Alias _size_N_t to BroadcastingListN[int] (#48297)
Browse files Browse the repository at this point in the history
Summary:
Because they are one and the same

Fixes #47528

Pull Request resolved: #48297

Reviewed By: eellison

Differential Revision: D25116203

Pulled By: malfet

fbshipit-source-id: 7edc2c89daa3f3302822b1f9b83b41b04658c6b7
  • Loading branch information
malfet authored and facebook-github-bot committed Nov 26, 2020
1 parent e7ca62b commit 8b248af
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 1 deletion.
16 changes: 16 additions & 0 deletions test/test_jit_py3.py
Expand Up @@ -797,6 +797,22 @@ def use_module_interface(mod_list: List[OneTwoModule], x: torch.Tensor):
# self.assertTrue(set(['aten::add.Tensor', 'aten::mul.Scalar']).issubset(
# set(torch.jit.export_opnames(scripted_M_mod))))

def test_broadcasting_list(self):
"""
Test BroadcastingList and torch.nn._size_N_t alias
"""
from torch._jit_internal import BroadcastingList2
from torch.nn.common_types import _size_2_t

def sum_i(x: _size_2_t) -> int:
return x[0] + x[1]

def sum_f(x: BroadcastingList2[float]) -> float:
return x[0] + x[1]

self.assertTrue(torch.jit.script(sum_i)(4) == 8)
self.assertTrue(torch.jit.script(sum_f)(4.5) == 9.)


if __name__ == '__main__':
run_tests()
16 changes: 16 additions & 0 deletions torch/csrc/jit/frontend/script_type_parser.cpp
Expand Up @@ -86,6 +86,22 @@ TypePtr ScriptTypeParser::subscriptToType(

c10::optional<std::pair<TypePtr, int32_t>> ScriptTypeParser::parseBroadcastList(
const Expr& expr) const {
// Alias torch.nn._common_types._size_?_t to BroadcastingList?[int]
if (expr.kind() == TK_VAR) {
auto var = Var(expr);
auto& name = var.name().name();
constexpr auto _size_prefix = "_size_";
constexpr auto _size_suffix = "_t";
constexpr auto _size_n_len = 9; // strlen("_size_X_t")
constexpr auto _size_prefix_len = 6; // strlen("_size_");
if (name.find(_size_prefix) == 0 && name.length() == _size_n_len &&
name.find(_size_suffix) == _size_prefix_len + 1 &&
::isdigit(name[_size_prefix_len])) {
int n = name[_size_prefix_len] - '0';
return std::pair<TypePtr, int32_t>(ListType::create(IntType::get()), n);
}
}

if (expr.kind() != TK_SUBSCRIPT)
return c10::nullopt;
auto subscript = Subscript(expr);
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/frontend/source_range.h
Expand Up @@ -14,7 +14,7 @@ struct SourceRange;
// Source represents a code segment. It keeps track of:
// - text : the text of the code segment
// - filename (optional) : if present, represents the name of the file from
// which the code semgemnt originated.
// which the code segment originated.
// - starting_line_no : represents the line in the original file where the
// code segment started.
struct Source {
Expand Down

0 comments on commit 8b248af

Please sign in to comment.