Remove TopkOp
mory91 authored and ricardoV94 committed Apr 2, 2024
1 parent ef22377 commit f7b0a7a
Showing 4 changed files with 4 additions and 590 deletions.
2 changes: 1 addition & 1 deletion pytensor/tensor/__init__.py
@@ -142,7 +142,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:

# We import as `_shared` instead of `shared` to avoid confusion between
# `pytensor.shared` and `tensor._shared`.
from pytensor.tensor.sort import argsort, argtopk, sort, topk, topk_and_argtopk
from pytensor.tensor.sort import argsort, sort
from pytensor.tensor.subtensor import *
from pytensor.tensor.type import *
from pytensor.tensor.type_other import *
32 changes: 0 additions & 32 deletions pytensor/tensor/rewriting/basic.py
@@ -68,7 +68,6 @@
from pytensor.tensor.extra_ops import broadcast_arrays
from pytensor.tensor.math import Sum, add, eq
from pytensor.tensor.shape import Shape_i, shape_padleft
from pytensor.tensor.sort import TopKOp
from pytensor.tensor.type import DenseTensorType, TensorType
from pytensor.tensor.variable import TensorConstant, TensorVariable
from pytensor.utils import NoDuplicateOptWarningFilter
@@ -1224,35 +1223,4 @@ def local_merge_alloc(fgraph, node):
return [alloc(inputs_inner[0], *dims_outer)]


@register_useless("fast_compile")
@node_rewriter([TopKOp])
def local_useless_topk(fgraph, node):
"""Remove unused `TopKOp` outputs."""
op = node.op
if not isinstance(op, TopKOp):
return
if not (op.return_values and op.return_indices):
return False

x, k = node.inputs
ret_val = bool(fgraph.clients[node.outputs[0]])
ret_idx = bool(fgraph.clients[node.outputs[1]])

if not (ret_val ^ ret_idx):
# both true -> nothing to remove
# both false -> let pruner handle
return False

old_output = node.outputs[ret_idx]
new_output = TopKOp(
axis=op.axis,
sorted=op.sorted,
idx_dtype=op.idx_dtype,
return_values=ret_val,
return_indices=ret_idx,
)(x, k)
copy_stack_trace(node.outputs[0], new_output)
return {old_output: new_output}


register_canonicalize(RemovalNodeRewriter(tensor_copy), name="remove_tensor_copy")
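The removed rewrite above only fires when exactly one of the two `TopKOp` outputs is actually consumed; a minimal standalone restatement of that guard (plain Python, illustrative names only, not part of the diff):

# Sketch of the guard in the removed local_useless_topk rewrite: rewrite only
# when exactly one of the two outputs has clients in the graph.
def should_rewrite(values_used: bool, indices_used: bool) -> bool:
    # both used -> nothing to remove; both unused -> left to the graph pruner
    return values_used ^ indices_used

assert should_rewrite(True, False)       # only the values output survives
assert should_rewrite(False, True)       # only the indices output survives
assert not should_rewrite(True, True)    # both needed, nothing to strip
assert not should_rewrite(False, False)  # dead node, handled by pruning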
273 changes: 2 additions & 271 deletions pytensor/tensor/sort.py
@@ -4,11 +4,9 @@
from pytensor.graph.basic import Apply, Constant
from pytensor.graph.op import Op
from pytensor.misc.safe_asarray import _asarray
from pytensor.tensor.basic import arange, as_tensor_variable, flatten, switch
from pytensor.tensor.basic import arange, as_tensor_variable, switch
from pytensor.tensor.math import eq, ge, mul
from pytensor.tensor.shape import shape
from pytensor.tensor.subtensor import set_subtensor
from pytensor.tensor.type import TensorType, integer_dtypes
from pytensor.tensor.type import TensorType


def _variable_is_none(var):
@@ -304,270 +302,3 @@ def _topk_py_impl(op, x, k, axis, idx_dtype):
else:
zi = np.argpartition(x, -k, axis=axis)[tuple(idx)]
return zi.astype(idx_dtype)
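The Python implementation above relies on NumPy's partial-sort routines; a small standalone sketch of that underlying behaviour (example data is illustrative, not from the diff):

import numpy as np

x = np.array([5, 1, 9, 3, 7])
k = 2

# np.argpartition(x, -k) only guarantees that the last k positions hold the
# indices of the k largest elements; their relative order is unspecified.
top_idx = np.argpartition(x, -k)[-k:]
print(top_idx, x[top_idx])  # indices of 9 and 7, in no particular order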


class TopKOp(Op):
"""Operations related to finding k-largest elements.
Parameters
----------
axis: integer
Defaults to ``-1``.
The axis along which to perform the operation. Must be in the range ``[-ndim, ndim)``, where
``ndim`` is the dimensionality of the input tensor.
idx_dtype: string
Specify output dtype for indices, defaults to ``int64``, must be integer type.
sorted: bool
NOTE: NOT IMPLEMENTED YET
Defaults to ``True``
If True, the result array will be sorted in descending order.
Notes
-----
- The output order is not guaranteed. On the CPU, we use
``np.partition`` and ``np.argpartition``, which only ensure that the
k-th element is in its correct position and that the other
elements are on the correct side of it.
- By default, this Op gives two outputs: values and indices. However,
optimizers may remove an output if it is not needed.
- Computing the gradient requires computing the indices in the
forward pass.
- If the top-k-th value is not unique, we cannot guarantee that the
output indices are deterministically chosen.
See Also
--------
topk
argtopk
topk_and_argtopk
"""

# TODO more params
"""
only_top_kth: bool
Defaults to ``False``
If ``True``, will only find one exact top k-th element on given axis.
"""

# TODO c_code
# TODO add opt, if k==1, use max/min reduce
# also if k is axis size, just copy input tensor
# TODO add opt, to merge argtopk / topk
__props__ = ("axis", "sorted", "return_values", "return_indices", "idx_dtype")

def __init__(
self,
axis=-1,
sorted=True,
idx_dtype="int64",
return_values=True,
return_indices=True,
):
# numpy always uses int64 as output dtype for arg*() routines
# however, we add "idx_dtype" param as memory is more precious on gpu
if not isinstance(axis, int):
raise TypeError(f'"axis" parameter must be integer, got "{type(axis)}"')
if sorted:
raise NotImplementedError(
"The sorted parameter is not yet implemented. Use sorted=False for now."
)
if idx_dtype not in integer_dtypes:
raise TypeError(
f'"idx_dtype" parameter must be an integer dtype, got "{idx_dtype}"'
)

if not (return_indices or return_values):
raise ValueError(
"Neither return_values nor return_indices is True, this isn't allowed"
)

self.axis = axis
self.sorted = sorted
self.return_values = return_values
self.return_indices = return_indices
self.idx_dtype = idx_dtype

def __str__(self):
return "%(op)s{axis=%(axis)d, sorted=%(sorted)s}" % dict(
op=self.__class__.__name__, axis=self.axis, sorted=self.sorted
)

def make_node(self, inp, kth):
inp = as_tensor_variable(inp)
ndim = inp.ndim
if ndim == 0:
raise ValueError("Cannot take scalar as input")
if not -ndim <= self.axis < ndim:
raise IndexError(
'"axis" parameter out of range,'
f" expected integer within [{int(-ndim)}, {int(ndim - 1)}]"
)

kth = as_tensor_variable(kth)
_check_tensor_is_scalar(kth)
outs = []
if self.return_values:
outs.append(
TensorType(dtype=inp.type.dtype, shape=(None,) * inp.type.ndim)()
)
if self.return_indices:
outs.append(
TensorType(dtype=self.idx_dtype, shape=(None,) * inp.type.ndim)()
)
return Apply(self, [inp, kth], outs)

def perform(self, node, inputs, output_storage):
x, k = inputs
axis = self.axis
if not self.return_indices:
pzv = output_storage[0]
pzv[0] = _topk_py_impl(self, x, k, axis, None)
elif self.return_values:
pzv = output_storage[0]
pzi = output_storage[1]
pzv[0], pzi[0] = _topk_py_impl(self, x, k, axis, node.outputs[1].dtype)
else:
pzi = output_storage[0]
pzi[0] = _topk_py_impl(self, x, k, axis, node.outputs[0].dtype)

def infer_shape(self, fgraph, node, inp_shapes):
shp = list(inp_shapes[0])
shp[self.axis] = np.abs(node.inputs[1])
shp = tuple(shp)
return [shp for i in [self.return_values, self.return_indices] if i]

def L_op(self, inputs, outputs, out_grads):
x, k = inputs
k_grad = grad_undefined(self, 1, k, "topk: k is not differentiable")

if not (self.return_indices or self.return_values):
x_grad = grad_undefined(
self,
0,
x,
"topk: cannot get gradient without both indices and values",
)
else:
x_shp = shape(x)
z_grad = out_grads[0]
ndim = x.ndim
axis = self.axis % ndim
grad_indices = [
arange(x_shp[i]).dimshuffle([0] + ["x"] * (ndim - i - 1))
if i != axis
else outputs[-1]
for i in range(ndim)
]
x_grad = x.zeros_like(dtype=z_grad.dtype)
x_grad = set_subtensor(x_grad[tuple(grad_indices)], z_grad)

return [x_grad, k_grad]
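For reference, this is roughly how the removed Op was used before this commit (a sketch of the pre-removal API; it no longer imports after this change, and ``sorted=True`` raised ``NotImplementedError``, so ``sorted=False`` is required):

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.sort import TopKOp  # removed by this commit

x = pt.vector("x")
# Two outputs by default: the k largest values and their indices.
values, indices = TopKOp(axis=-1, sorted=False)(x, 3)
f = pytensor.function([x], [values, indices])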


def topk(x, kth, axis=-1, sorted=True, idx_dtype="int64"):
"""
Returns the k-largest elements along an axis.
Parameters
----------
x: tensor instance
kth: integer constant/variable
Must not be 0. If negative, gives k-smallest elements instead.
axis: integer or ``None``
The axis along which to perform the operation.
If ``None``, the operation is performed on the flattened array.
sorted: bool
NOTE: NOT IMPLEMENTED YET, USE ``False`` FOR NOW.
Defaults to ``True``
If True, the result array will be sorted in descending order.
idx_dtype: string
Specify output dtype used in indices, defaults to ``int64``, must be integer type.
This option is here because indices are needed for gradient.
Returns
-------
Tensor variable with same dtype as `x`.
Notes
-----
- ``sorted=True`` is not supported yet.
"""
if axis is None:
x = flatten(x)
axis = 0
return TopKOp(axis=axis, sorted=sorted, idx_dtype=idx_dtype)(x, kth)[0]
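The ``axis=None`` branch just above flattens the input first; a sketch of a call exercising it under the pre-removal API (hypothetical data and variable names):

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.sort import topk  # removed by this commit

m = pt.matrix("m")
best3 = topk(m, 3, axis=None, sorted=False)  # 3 largest entries of the whole matrix
f = pytensor.function([m], best3)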


def argtopk(x, kth, axis=-1, sorted=True, idx_dtype="int64"):
"""
Returns the indices of k-largest elements along an axis.
Parameters
----------
x: tensor instance
kth: integer constant/variable
Must not be 0. If negative, gives k-smallest elements instead.
sorted: bool
NOTE: NOT IMPLEMENTED YET, USE ``False`` FOR NOW.
Defaults to ``True``
If True, the result array of corresponding indices will be sorted in descending order.
axis: integer or ``None``
The axis along which to perform the operation.
If ``None``, the operation is performed on the flattened array.
idx_dtype: string
Specify output dtype, defaults to ``int64``, must be integer type.
Returns
-------
Tensor variable with dtype specified in `idx_dtype`.
Notes
-----
- ``sorted=True`` is not supported yet.
- If the top-k-th value is not unique, we cannot guarantee the output
indices are deterministically chosen.
"""
if axis is None:
x = flatten(x)
axis = 0
return TopKOp(axis=axis, sorted=sorted, idx_dtype=idx_dtype)(x, kth)[1]


def topk_and_argtopk(x, kth, axis=-1, sorted=True, idx_dtype="int64"):
"""
Returns the results of both topk() and argtopk() in one Op.
See the respective documentation for details.
Returns
-------
tuple: (values, indices)
"""
if axis is None:
x = flatten(x)
axis = 0
return TopKOp(axis=axis, sorted=sorted, idx_dtype=idx_dtype)(x, kth)
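Since this commit removes all three helpers, here is a rough sketch of how a comparable result can still be built from the ``sort``/``argsort`` helpers that remain exported (my own sketch under that assumption, limited to the last axis, not an official replacement):

import pytensor
import pytensor.tensor as pt

def topk_and_argtopk_last_axis(x, k):
    # A full sort does more work than a partial sort, but sort/argsort remain
    # available; slice off the last k entries along the last axis.
    values = pt.sort(x, axis=-1)[..., -k:]
    indices = pt.argsort(x, axis=-1)[..., -k:]
    return values, indices

x = pt.matrix("x")
vals, idx = topk_and_argtopk_last_axis(x, 2)
f = pytensor.function([x], [vals, idx])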