Skip to content

Commit

Permalink
[pt2] add meta function for linalg.cross
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
nkaretnikov committed Apr 16, 2023
1 parent 7d12ea9 commit 8f3e8b3
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 4 deletions.
2 changes: 0 additions & 2 deletions test/functorch/test_aotdispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2510,7 +2510,6 @@ def forward(self, x):
xfail('cholesky_inverse', ''), # could not find kernel
xfail('cholesky_solve', ''), # could not find kernel
xfail('combinations', ''), # aten.masked_select.default
xfail('cross', ''), # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition
xfail('cumprod', ''), # aten.cumprod.default - couldn't find symbolic meta function/decomposition
xfail('cumulative_trapezoid', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('diff', ''), # aten.zeros_like.default - couldn't find symbolic meta function/decomposition
Expand All @@ -2526,7 +2525,6 @@ def forward(self, x):
xfail('kthvalue', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('linalg.cholesky_ex', ''), # could not find kernel for aten.linalg_solve_triangular.default
xfail('linalg.cond', ''), # Cannot call numel() on tensor with symbolic sizes/strides
xfail('linalg.cross', ''), # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition
xfail('linalg.det', ''), # aten._linalg_det.default - couldn't find symbolic meta function/decomposition
xfail('linalg.det', 'singular'), # aten._linalg_det.default - couldn't find symbolic meta function/deco...
xfail('linalg.eigh', ''), # aten._linalg_eigh.default - couldn't find symbolic meta function/decomposition
Expand Down
2 changes: 0 additions & 2 deletions test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1369,7 +1369,6 @@ def f(a, b, c, d, e):
xfail('linalg.eigvals'),
xfail('cholesky_solve', ''), # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back...
xfail('combinations', ''),
xfail('cross', ''), # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition
xfail('cumulative_trapezoid', ''), # aten.slice.Tensor - couldn't find symbolic meta function/decomposition
xfail('diff', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition
xfail('dsplit', ''), # aten.slice.Tensor - couldn't find symbolic meta function/decomposition
Expand All @@ -1386,7 +1385,6 @@ def f(a, b, c, d, e):
xfail('kron', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('kthvalue', ''), # aten.kthvalue.default - couldn't find symbolic meta function/decomposition
xfail('linalg.cond', ''), # Tensors of type TensorImpl do not have numel
xfail('linalg.cross', ''), # aten.linalg_cross.default - couldn't find symbolic meta function/decomposition
xfail('linalg.eigh', ''), # aten._linalg_eigh.default - couldn't find symbolic meta function/decomposition
xfail('linalg.eigvalsh', ''), # aten._linalg_eigh.default - couldn't find symbolic meta function/decomposition
xfail('linalg.householder_product', ''), # aten.linalg_householder_product.default - couldn't find symbolic meta funct...
Expand Down
20 changes: 20 additions & 0 deletions torch/_meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,26 @@ def meta_take(self, index, *, out=None):
return result


@register_meta([aten.linalg_cross.default, aten.linalg_cross.out])
@out_wrapper()
def linalg_cross(self, other, *, dim=-1):
x_d = self.ndim
y_d = other.ndim
check(
x_d == y_d,
lambda: "linalg.cross: inputs must have the same number of dimensions.",
)
check(
self.size(dim) == 3 and other.size(dim) == 3,
lambda: (
f"linalg.cross: inputs dimension {dim} must have length 3. "
f"Got {self.size(dim)} and {other.size(dim)}"
),
)
out_shape = _broadcast_shapes(self.shape, other.shape)
return self.new_empty(out_shape)


@register_meta(
[aten.cummax.default, aten.cummax.out, aten.cummin.default, aten.cummin.out]
)
Expand Down

0 comments on commit 8f3e8b3

Please sign in to comment.