Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

get_scalar_constant_value now raises for non-scalar inputs #248

Merged
merged 7 commits into from Mar 22, 2023

Conversation

shreyas3156
Copy link
Contributor

Motivation for these changes

Closes #226

Implementation details

  • Added a get_scalar_constant() method that raises NotScalarConstantError when non-scalar outputs are passed.
  • Renamed the current get_scalar_constant_value() to get_underlying_scalar_constant()

Checklist

Major / Breaking Changes

  • None

Bugfixes

  • Yes.

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. I think we just need to tweak the docstrings

pytensor/tensor/basic.py Outdated Show resolved Hide resolved
@codecov-commenter
Copy link

codecov-commenter commented Mar 18, 2023

Codecov Report

Merging #248 (1981bb5) into main (4317d0d) will decrease coverage by 0.01%.
The diff coverage is 90.41%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #248      +/-   ##
==========================================
- Coverage   80.44%   80.44%   -0.01%     
==========================================
  Files         170      170              
  Lines       45328    45333       +5     
  Branches    11069    11071       +2     
==========================================
+ Hits        36463    36467       +4     
  Misses       6642     6642              
- Partials     2223     2224       +1     
Impacted Files Coverage Δ
pytensor/tensor/elemwise.py 87.97% <ø> (ø)
pytensor/tensor/exceptions.py 100.00% <ø> (ø)
pytensor/tensor/random/op.py 97.46% <0.00%> (ø)
pytensor/tensor/var.py 87.68% <50.00%> (ø)
pytensor/tensor/shape.py 91.74% <80.00%> (ø)
pytensor/tensor/rewriting/math.py 86.03% <84.61%> (ø)
pytensor/tensor/basic.py 90.48% <86.66%> (-0.04%) ⬇️
pytensor/__init__.py 91.22% <100.00%> (ø)
pytensor/gradient.py 76.79% <100.00%> (ø)
pytensor/link/jax/dispatch/tensor_basic.py 98.07% <100.00%> (ø)
... and 10 more

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some change requests, should be the last mile!

Comment on lines 268 to 272
if isinstance(v, np.ndarray):
data = v.data
if data.ndim != 0:
raise NotScalarConstantError()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should work with PyTensor variables. This might work

Suggested change
if isinstance(v, np.ndarray):
data = v.data
if data.ndim != 0:
raise NotScalarConstantError()
if isinstance(v, (Variable, np.ndarray)):
if v.ndim != 0:
raise NotScalarConstantError()

Please add a check testing this functionality with a PyTensor variable input. That's actually the most common case, not NumPy arrays.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Understood, will add it right away!

@@ -255,7 +255,26 @@ def _obj_is_wrappable_as_tensor(x):
)


def get_scalar_constant_value(
def get_scalar_constant(v, elemwise=True, only_process_constants=False, max_recur=10):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't change the name, even if it's a bit verbose.

Suggested change
def get_scalar_constant(v, elemwise=True, only_process_constants=False, max_recur=10):
def get_scalar_constant_value(v, elemwise=True, only_process_constants=False, max_recur=10):

Needs to be updated everywhere

)


def get_underlying_scalar_constant(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To match the non-underlying

Suggested change
def get_underlying_scalar_constant(
def get_underlying_scalar_constant_value(

Needs to be updated everywhere

@ricardoV94 ricardoV94 changed the title Raise NotScalarConstantError for homogenous non-scalar constants get_scalar_constant_value now raises for non-scalar typed inputs. get_underlying_scalar_constant_value can be used for the old behavior Mar 20, 2023
@ricardoV94 ricardoV94 changed the title get_scalar_constant_value now raises for non-scalar typed inputs. get_underlying_scalar_constant_value can be used for the old behavior get_scalar_constant_value now raises for non-scalar inputs. get_underlying_scalar_constant_value can be used for the old behavior Mar 20, 2023
@ricardoV94 ricardoV94 changed the title get_scalar_constant_value now raises for non-scalar inputs. get_underlying_scalar_constant_value can be used for the old behavior get_scalar_constant_value now raises for non-scalar inputs Mar 20, 2023
@ricardoV94 ricardoV94 added enhancement New feature or request bug Something isn't working and removed enhancement New feature or request labels Mar 20, 2023
@shreyas3156
Copy link
Contributor Author

@ricardoV94 It seems like the benchmark test is failing for a couple of tests in tests/link/jax/test_elemwise.py. Can you please help me understand it and if it is linked to any of the changes I made?

@ricardoV94
Copy link
Member

That's probably just a fluke, nothing to worry about

@ricardoV94 ricardoV94 merged commit fda240f into pymc-devs:main Mar 22, 2023
51 of 52 checks passed
@ricardoV94
Copy link
Member

Thanks @shreyas3156

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

Successfully merging this pull request may close these issues.

get_scalar_constant does not raise for homogenous non-scalar constants
3 participants