
Break MaxAndArgmax Op into separate TensorMax Op and Argmax Op #731

Open · wants to merge 3 commits into main

Conversation

Dhruvanshu-Joshi (Contributor):

Description

The MaxAndArgmax Op computes both the maximum and the argmax together. With this PR, we aim to have separate Ops for the two operations.
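As a rough sketch of the intended user-facing behavior (the Op class names follow the PR title; everything else here is illustrative), pt.max and pt.argmax should each build their own node instead of both going through the fused MaxAndArgmax:

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
m = pt.max(x)      # would be backed by the new TensorMax Op
am = pt.argmax(x)  # would be backed by the standalone Argmax Op

f = pytensor.function([x], [m, am])
print(f([1.0, 5.0, 3.0]))  # [array(5.), array(1)]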

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

TensorMax.debug = 0

def test_basic(self):
    # dbt: for some reason, Argmax does not work when I pass: n = as_tensor_variable(5.0)
Contributor Author:

For some reason, Argmax does not work when I pass n = as_tensor_variable(5.0). MaxAndArgmax used to work fine.

Member:

Hmm, what does numpy do with scalar arrays? We should do the same as them

Member:

numpy seems to work fine
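
For reference, this is how numpy behaves on a 0-d ("scalar") array; note in particular that np.argmax does not accept a tuple axis at all:

import numpy as np

a = np.array(5.0)          # 0-d scalar array
print(np.max(a))           # 5.0
print(np.argmax(a))        # 0: index into the flattened single-element view
print(np.max(a, axis=()))  # 5.0: an empty axis tuple means "no reduction"
# np.argmax(a, axis=()) raises a TypeError, since argmax only takes an int or None axis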

@ricardoV94 (Member), May 1, 2024:

The/a problem seems to be this line, which forces axis to be a tuple when it should be allowed to be None for the scalar case:

self.axis = tuple(axis)

It should be self.axis = axis instead.

Contributor Author:

I tried this and it does not seem to work. This is the code that causes the problem:

def __call__(self):
    failure = self.run_cthunk(self.cthunk)
    if failure:
        task, taskname, id = self.find_task(failure)
        try:
            trace = task.trace
        except AttributeError:
            trace = ()
        try:
            exc_type, _exc_value, exc_trace = self.error_storage
            if task in self.nodes:
                self.position_of_error = self.nodes.index(task)
            # this can be used to retrieve the location the Op was declared
            exc_value = exc_type(_exc_value)
            exc_value.__thunk_trace__ = trace
        except Exception:
            print(
                (
                    "ERROR retrieving error_storage."
                    "Was the error set in the c code?"
                ),
                end=" ",
                file=sys.stderr,
            )
            print(self.error_storage, file=sys.stderr)
            raise
        raise exc_value.with_traceback(exc_trace)

def __str__(self):
    return f"{type(self).__name__}({self.module})"

The if failure: block never executes for non-scalar inputs.

try:
    from cutils_ext.cutils_ext import *  # noqa

run_cthunk is imported via this, but I cannot find cutils_ext.

Member:

You may be hitting a Windows/installation issue then. Does the test fail in the CI here as well?

Contributor Author:

Yes, the CI error is the same as what I see locally. The difference between a scalar like as_tensor_variable(5.0) and a non-scalar like as_tensor_variable([5.0]) is that in the former case the if failure: block executes, while for non-scalars it never does.


def test_basic(self):
    # dbt: for some reason, Argmax does not work when I pass: n = as_tensor_variable(5.0)
Contributor Author:

Scalars are still a problem. I am working on this and looking at how numpy handles them, as suggested.

@ricardoV94 (Member), May 2, 2024:

It's raising from this C-code check:

"Argmax, bad axis argument");

It seems not to handle the empty axes case that's passed when it's a scalar. Instead of empty axes we can convert to None or zero, which are equivalent for the scalar case. I prefer None because that's the default anyway.

But actually these lines seem to be creating the problem in the first place?

axis = check_and_normalize_axes(a, axis)
if len(axis) == 0:
    axis = list(range(a.type.ndim))

I don't think this should be needed, or be this convoluted (referring to check_and_normalize_axes, which I think is only used here?). We handle axes in other places with far less code. We should use the numpy helper like we do elsewhere, or leave axis = None alone, which both argmax and max support anyway.
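
For context, the numpy helper mentioned here is presumably normalize_axis_tuple; its import path moved across numpy releases, so a version-tolerant import looks like this:

try:
    from numpy.lib.array_utils import normalize_axis_tuple  # numpy >= 2.0
except ImportError:
    from numpy.core.numeric import normalize_axis_tuple  # older numpy

print(normalize_axis_tuple(-1, 3))       # (2,)
print(normalize_axis_tuple((0, -1), 3))  # (0, 2)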

@ricardoV94 (Member), May 2, 2024:

I was not seeing problems before because I was creating Argmax(axis=None)(pt.as_tensor(5.0)) directly, which shows the problem is how the helper is creating the Argmax: basically Argmax(axis=()). Actually, I am not sure what Argmax(axis=()) should do; I think it should return zeros_like(x), since it corresponds to a no-reduction. np.max(x, axis=()) just returns x as well. We should check that our Max does the same, btw, which I think it's not doing.
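
The axis=() semantics described above are easy to confirm against numpy:

import numpy as np

x = np.array([3.0, 7.0, 1.0])
print(np.max(x, axis=()))  # [3. 7. 1.]: reducing over zero axes returns x unchanged
# By the same logic, Argmax(axis=()) would reduce each element over zero axes,
# and the argmax of a single element is always 0, hence zeros_like(x).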

Contributor Author:

Regarding the suggestion above (letting axis fall back to None instead of expanding it with check_and_normalize_axes): I tried this locally and it works with a silly modification:

def max_and_argmax(a, axis=None, keepdims=False):
    """
    Returns maximum elements and their indices obtained by iterating over
    given axis.

    When axis is None (the default value), the max is performed
    over the flattened tensor.

    Parameters
    ----------
    keepdims : bool
        If this is set to True, the axes which are reduced are left in
        the result as dimensions with size one. With this option, the result
        will broadcast correctly against the original tensor.

    """
    # Check axis and convert it to a Python list of integers.
    # Axis will be used as an op param of MaxAndArgmax.
    a = as_tensor_variable(a)
    axis = check_and_normalize_axes(a, axis)
    if len(axis) == 0:
        axis = None
    out = Max(axis)(a)
    argout = Argmax(axis)(a)

    if keepdims:
        out = makeKeepDims(a, out, axis)
        argout = makeKeepDims(a, argout, axis)
    return [out, argout]

Scalars work this way. But now, in the grad function for max, the line axis = as_tensor_variable(self.axis) misbehaves, because self.axis is None and as_tensor_variable(None) is invalid. Maybe doing if self.axis is None: self.axis = tuple(range(x.ndim)) would help, but would it be correct?

Contributor Author:

Regarding the point above about Argmax(axis=()) returning zeros_like(x): do the current changes reflect this? The case of pt.as_tensor(5.0) is handled effectively now, I think.
And the assert

assert v == 5.0

does not give any error, so I assume it is doing what we expect it to do?

Contributor Author:

Even

n = as_tensor_variable(5.0)
v, i = eval_outputs(max_and_argmax(n, axis=()))
assert v == 5.0
assert i == 0
assert i.dtype == "int64"
v = eval_outputs(max_and_argmax(n)[0].shape)
assert len(v) == 0
v = eval_outputs(max_and_argmax(n)[1].shape)
assert len(v) == 0

works fine, so I assume axis=() works as expected?

Member:

I don't know why grad was converting axis to a tensor variable; there is no point, since axis has to be constant. You can avoid that conversion.
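
A minimal sketch of that suggestion (illustrative only, not the PR's actual diff): since self.axis is static Op metadata, grad can consume it as a plain tuple with a None fallback, so as_tensor_variable(None) never comes up:

class FakeMax:
    """Stand-in for the real Op; only the axis handling is shown."""

    def __init__(self, axis=None):
        self.axis = axis  # None or a tuple of ints, never a symbolic tensor

    def grad_axis(self, ndim):
        # Use the constant directly instead of as_tensor_variable(self.axis).
        return self.axis if self.axis is not None else tuple(range(ndim))

print(FakeMax(axis=None).grad_axis(3))  # (0, 1, 2)
print(FakeMax(axis=(1,)).grad_axis(3))  # (1,)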

@@ -1386,6 +1383,12 @@ def test_uint(self):
        n = as_tensor_variable(data)
        assert min(n).dtype == dtype
        i = eval_outputs(min(n))
        # pytensor.dprint(n)
Contributor Author:

The dtype uint64 fails strangely for some reason; the rest work fine. The error says that itype.min = 0, but i comes out equal to 18446744073709551610, which is the second-largest value in the list [0, 3, 18446744073709551610, 18446744073709551615].

The error is:

>           assert i == itype.min
E           assert array(18446744073709551610, dtype=uint64) == 0
E            +  where 0 = iinfo(min=0, max=18446744073709551615, dtype=uint64).min

Member:

Might have to do with the dtype used for internal accumulation.

Contributor Author:

Can you elaborate a little on this?

Member:

I am not sure what the problem is yet without looking further. But CAReduce has two dtypes: the output one and the one used for internal accumulation. I was wondering if the problem comes from the internal accumulation dtype. Also, uints are tricky because they can't represent negative numbers, and I think our implementation of min is something like -max(-x). You may need to investigate the behavior a bit to understand what's going on.
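
A small numpy illustration of why a min-as-negated-max scheme is hazardous for unsigned dtypes (this is not a confirmed diagnosis of the CI value above, just the wraparound mechanism):

import numpy as np

x = np.array([0, 3, 18446744073709551610, 18446744073709551615], dtype=np.uint64)
neg = -x             # unary minus wraps modulo 2**64 for uint64
print(neg)           # [0 18446744073709551613 6 1]
print(-np.max(neg))  # 3, not the true minimum 0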

def maxandargmax(x):
    return x, 0

def max(x):
    return x

else:
    axes = tuple(int(ax) for ax in axis)
Contributor Author:

This causes a problem when axis is None.

Member:

You'll need to convert it into axis = tuple(range(x_ndim)). I assume this is only a problem for Argmax? I think the conversion is already done for Max by default (as it is for all CAReduce)?

We can do the same conversion for Argmax. Are we always creating Argmax for the user in pt.argmax? If so, we can do the conversion there already; see the sketch below.
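
A hypothetical sketch of doing that conversion in the helper (the function name and placement are illustrative):

def normalize_argmax_axis(axis, ndim):
    # Convert axis=None to the all-axes tuple before constructing Argmax,
    # so the Op (and its numba dispatch) always sees a concrete tuple.
    if axis is None:
        return tuple(range(ndim))
    return tuple(int(ax) for ax in axis)

print(normalize_argmax_axis(None, 3))    # (0, 1, 2)
print(normalize_argmax_axis((0, 2), 3))  # (0, 2)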

Contributor Author:

Yeah, this code was redundant. numba_funcify_CAReduce already handles it effectively, and Argmax already behaves as expected.

@Dhruvanshu-Joshi force-pushed the remove_MaxArgmax branch 2 times, most recently from 1d9a484 to a278272, May 13, 2024 17:59
@Dhruvanshu-Joshi (Contributor Author):

The failing tests are because of the uint64 data type, which is highlighted in #770. So, for this to be ready, should I just remove the uint64 test for now and open another issue to add it back once #770 is solved?

@ricardoV94 (Member):


You can mark the test with pytest.mark.xfail; there are a couple of examples in the codebase.
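
For example (the reason string here is illustrative):

import pytest

@pytest.mark.xfail(reason="uint64 min is broken; see #770")
def test_uint64():
    ...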

codecov bot commented May 22, 2024:

Codecov Report

Attention: Patch coverage is 79.22078%, with 16 lines in your changes missing coverage. Please review.

Project coverage is 80.85%. Comparing base (15b90be) to head (6b07a6e).
Report is 11 commits behind head on main.

Additional details and impacted files


@@            Coverage Diff             @@
##             main     #731      +/-   ##
==========================================
+ Coverage   80.83%   80.85%   +0.01%     
==========================================
  Files         162      162              
  Lines       46862    46969     +107     
  Branches    11465    11492      +27     
==========================================
+ Hits        37881    37975      +94     
- Misses       6733     6748      +15     
+ Partials     2248     2246       -2     
Files Coverage Δ
pytensor/compile/function/types.py 79.62% <100.00%> (+0.02%) ⬆️
pytensor/graph/op.py 87.89% <ø> (ø)
pytensor/ifelse.py 51.70% <ø> (ø)
pytensor/link/numba/dispatch/elemwise.py 88.64% <100.00%> (-0.08%) ⬇️
pytensor/tensor/rewriting/uncanonicalize.py 96.63% <100.00%> (+0.42%) ⬆️
pytensor/link/jax/dispatch/nlinalg.py 83.33% <69.23%> (-6.42%) ⬇️
pytensor/tensor/math.py 90.42% <78.18%> (+0.76%) ⬆️

... and 15 files with indirect coverage changes

Successfully merging this pull request may close these issues: Remove MaxAndArgmax Op.