Skip to content

Commit

Permalink
Update on "Add Opinfo entry for add_alias"
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
rec committed May 24, 2024
2 parents 404efec + 1023563 commit 37023dc
Showing 1 changed file with 0 additions and 55 deletions.
55 changes: 0 additions & 55 deletions torch/_inductor/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1478,61 +1478,6 @@ def reindexer(idx):
return TensorBox(ir.GenericView.create(input, sizes, reindexer))


@register_lowering(aten.diagonal, type_promotion_kind=None)
def diagonal(input, offset: int = 0, dim1: int = 0, dim2: int = 1):
original_shape = input.get_size()
num_dims = len(original_shape)
dim1 = canonicalize_dim(idx=dim1, rank=num_dims)
dim2 = canonicalize_dim(idx=dim2, rank=num_dims)

check(
dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
)

offset_negative = V.graph.sizevars.evaluate_expr(sympy.Lt(offset, 0))
if offset_negative:
diag_size = V.graph.sizevars.evaluate_max(
V.graph.sizevars.evaluate_min(
original_shape[dim1] + offset, original_shape[dim2]
),
0, # type: ignore[arg-type]
)
else:
diag_size = V.graph.sizevars.evaluate_max(
V.graph.sizevars.evaluate_min(
original_shape[dim1], original_shape[dim2] - offset
),
0, # type: ignore[arg-type]
)

base_idx = (0, 0)
if offset_negative:
base_idx = (-offset, 0)
else:
base_idx = (0, offset)

sizes = [s for i, s in enumerate(original_shape) if i not in (dim1, dim2)]
sizes.append(diag_size)

def reindexer(idx):
diag_idx = idx[-1]
original_idx = [0] * len(original_shape)
cur_dim = 0
for d in range(num_dims):
if d == dim1:
original_idx[d] = diag_idx + base_idx[0]
elif d == dim2:
original_idx[d] = diag_idx + base_idx[1]
else:
original_idx[d] = idx[cur_dim]
cur_dim += 1

assert cur_dim == len(original_shape) - 2
return original_idx

return TensorBox(ir.GenericView.create(input, sizes, reindexer))


@register_lowering(aten.diagonal_copy, type_promotion_kind=None)
def diagonal_copy(input, offset: int = 0, dim1: int = 0, dim2: int = 1):
return clone(diagonal(input, offset, dim1, dim2))
Expand Down

0 comments on commit 37023dc

Please sign in to comment.