Skip to content

Commit

Permalink
Add shape function for stack op
Browse files Browse the repository at this point in the history
  • Loading branch information
li-plus committed Jan 14, 2023
1 parent 353e9f8 commit b33c68b
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 0 deletions.
149 changes: 149 additions & 0 deletions torch/csrc/jit/runtime/serialized_shape_function_registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2000,6 +2000,154 @@ def conv_forwards(input: List[int],
_11 = result_size
return _11
)=====")
+ std::string(R"=====(def stack(tensors: List[List[int]],
dim: int) -> List[int]:
_0 = "AssertionError: Tensors must have same number of dimensions"
_1 = "AssertionError: Sizes of tensors must match except in dimension"
unsqueezed_tensors = annotate(List[List[int]], [])
for _2 in range(torch.len(tensors)):
tensor = tensors[_2]
_3 = torch.add(torch.len(tensor), 1)
if torch.le(_3, 0):
dim_post_expr = 1
else:
dim_post_expr = _3
min = torch.neg(dim_post_expr)
max = torch.sub(dim_post_expr, 1)
if torch.lt(dim, min):
_4 = True
else:
_4 = torch.gt(dim, max)
if torch.__not__(_4):
pass
else:
ops.prim.RaiseException("AssertionError: ")
if torch.lt(dim, 0):
dim0 = torch.add(dim, dim_post_expr)
else:
dim0 = dim
unsqueezed = annotate(List[int], [])
for _5 in range(torch.len(tensor)):
elem = tensor[_5]
_6 = torch.append(unsqueezed, elem)
torch.insert(unsqueezed, dim0, 1)
_7 = torch.append(unsqueezed_tensors, unsqueezed)
for _8 in range(torch.len(unsqueezed_tensors)):
tensor0 = unsqueezed_tensors[_8]
if torch.gt(torch.len(tensor0), 0):
pass
else:
ops.prim.RaiseException("AssertionError: ")
out_dim: Optional[int] = None
for _9 in range(torch.len(unsqueezed_tensors)):
size = unsqueezed_tensors[_9]
if torch.eq(torch.len(size), 1):
_10 = torch.eq(size[0], 0)
else:
_10 = False
if torch.__not__(_10):
if torch.__is__(out_dim, None):
_11 = torch.len(size)
if torch.le(_11, 0):
dim_post_expr0 = 1
else:
dim_post_expr0 = _11
min0 = torch.neg(dim_post_expr0)
max0 = torch.sub(dim_post_expr0, 1)
if torch.lt(dim, min0):
_12 = True
else:
_12 = torch.gt(dim, max0)
if torch.__not__(_12):
pass
else:
ops.prim.RaiseException("AssertionError: ")
if torch.lt(dim, 0):
dim1 = torch.add(dim, dim_post_expr0)
out_dim2 = dim1
else:
out_dim2 = dim
out_dim1 = out_dim2
else:
out_dim1 = unchecked_cast(int, out_dim)
out_dim0 : Optional[int] = out_dim1
else:
out_dim0 = out_dim
out_dim = out_dim0
if torch.__is__(out_dim, None):
dim2 = dim
else:
dim2 = unchecked_cast(int, out_dim)
_13 = torch.gt(torch.len(unsqueezed_tensors), 0)
if _13:
pass
else:
ops.prim.RaiseException("AssertionError: ")
not_skipped_tensor: Optional[List[int]] = None
for _14 in range(torch.len(unsqueezed_tensors)):
tensor1 = unsqueezed_tensors[_14]
numel = 1
for _15 in range(torch.len(tensor1)):
elem0 = tensor1[_15]
numel = torch.mul(numel, elem0)
if torch.eq(numel, 0):
_16 = torch.eq(torch.len(tensor1), 1)
else:
_16 = False
if torch.__not__(_16):
not_skipped_tensor0 : Optional[List[int]] = tensor1
else:
not_skipped_tensor0 = not_skipped_tensor
not_skipped_tensor = not_skipped_tensor0
_17 = torch.__is__(not_skipped_tensor, None)
if _17:
_18 = [0]
else:
not_skipped_tensor1 = unchecked_cast(List[int], not_skipped_tensor)
cat_dim_size = 0
for i in range(torch.len(unsqueezed_tensors)):
tensor2 = unsqueezed_tensors[i]
numel0 = 1
for _19 in range(torch.len(tensor2)):
elem1 = tensor2[_19]
numel0 = torch.mul(numel0, elem1)
if torch.eq(numel0, 0):
_20 = torch.eq(torch.len(tensor2), 1)
else:
_20 = False
if torch.__not__(_20):
first_dims = torch.len(not_skipped_tensor1)
second_dims = torch.len(tensor2)
_21 = torch.eq(first_dims, second_dims)
if _21:
pass
else:
ops.prim.RaiseException(_0)
_22 = torch.__range_length(0, first_dims, 1)
for _23 in range(_22):
dim3 = torch.__derive_index(_23, 0, 1)
if torch.ne(dim3, dim2):
_24 = torch.eq(not_skipped_tensor1[dim3], tensor2[dim3])
if _24:
pass
else:
ops.prim.RaiseException(_1)
else:
pass
cat_dim_size1 = torch.add(cat_dim_size, tensor2[dim2])
cat_dim_size0 = cat_dim_size1
else:
cat_dim_size0 = cat_dim_size
cat_dim_size = cat_dim_size0
result_size = annotate(List[int], [])
for _25 in range(torch.len(not_skipped_tensor1)):
elem2 = not_skipped_tensor1[_25]
_26 = torch.append(result_size, elem2)
_27 = torch._set_item(result_size, dim2, cat_dim_size)
_18 = result_size
return _18
)=====")
+ std::string(R"=====(def permute(input: List[int],
dims: List[int]) -> List[int]:
Expand Down Expand Up @@ -2955,6 +3103,7 @@ const OperatorMap<std::string>& GetShapeFunctionMappings() {
{"aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int groups=1, int[2] dilation=1) -> Tensor", "conv_transpose2d_input"},
{"aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)", "flatten"},
{"aten::cat(Tensor[] tensors, int dim=0) -> Tensor", "cat"},
{"aten::stack(Tensor[] tensors, int dim=0) -> Tensor", "stack"},
{"aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", "permute"},
{"aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)", "movedim"},
{"aten::view(Tensor(a) self, int[] size) -> Tensor(a)", "view"},
Expand Down
9 changes: 9 additions & 0 deletions torch/jit/_shape_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,14 @@ def cat(tensors: List[List[int]], dim: int):
return result_size


def stack(tensors: List[List[int]], dim: int):
    """Shape function for aten::stack.

    Stacking is equivalent to giving every input a new size-1 axis at
    position `dim` and then concatenating along that axis, so the result
    shape is computed by composing the existing `unsqueeze` and `cat`
    shape functions.
    """
    expanded = [unsqueeze(shape, dim) for shape in tensors]
    return cat(expanded, dim)


def select(self: List[int], dim: int, index: int):
ndim = len(self)
assert ndim != 0
Expand Down Expand Up @@ -1100,6 +1108,7 @@ def add_bounded_compute_mapping(operator_schema: str, lower_bound_func: Callable
add_shape_compute_mapping("aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] output_padding=0, int groups=1, int[2] dilation=1) -> Tensor", conv_transpose2d_input)
add_shape_compute_mapping("aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)", flatten)
add_shape_compute_mapping("aten::cat(Tensor[] tensors, int dim=0) -> Tensor", cat)
add_shape_compute_mapping("aten::stack(Tensor[] tensors, int dim=0) -> Tensor", stack)
add_shape_compute_mapping("aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", permute)
add_shape_compute_mapping("aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)", movedim)
add_shape_compute_mapping("aten::view(Tensor(a) self, int[] size) -> Tensor(a)", view)
Expand Down
1 change: 1 addition & 0 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15429,6 +15429,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
assert_autodiffed=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_jit_shape_analysis=True,
skips=(
# https://github.com/pytorch/pytorch/issues/77046
DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_conj_view'),
Expand Down

0 comments on commit b33c68b

Please sign in to comment.