New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
get_scalar_constant_value
now raises for non-scalar inputs
#248
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good. I think we just need to tweak the docstrings
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #248 +/- ##
==========================================
- Coverage 80.44% 80.44% -0.01%
==========================================
Files 170 170
Lines 45328 45333 +5
Branches 11069 11071 +2
==========================================
+ Hits 36463 36467 +4
Misses 6642 6642
- Partials 2223 2224 +1
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some change requests, should be the last mile!
pytensor/tensor/basic.py
Outdated
if isinstance(v, np.ndarray): | ||
data = v.data | ||
if data.ndim != 0: | ||
raise NotScalarConstantError() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should work with PyTensor variables. This might work
if isinstance(v, np.ndarray): | |
data = v.data | |
if data.ndim != 0: | |
raise NotScalarConstantError() | |
if isinstance(v, (Variable, np.ndarray)): | |
if v.ndim != 0: | |
raise NotScalarConstantError() |
Please add a check testing this functionality with a PyTensor variable input. That's actually the most common case, not NumPy arrays.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Understood, will add it right away!
pytensor/tensor/basic.py
Outdated
@@ -255,7 +255,26 @@ def _obj_is_wrappable_as_tensor(x): | |||
) | |||
|
|||
|
|||
def get_scalar_constant_value( | |||
def get_scalar_constant(v, elemwise=True, only_process_constants=False, max_recur=10): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We shouldn't change the name, even if it's a bit verbose.
def get_scalar_constant(v, elemwise=True, only_process_constants=False, max_recur=10): | |
def get_scalar_constant_value(v, elemwise=True, only_process_constants=False, max_recur=10): |
Needs to be updated everywhere
pytensor/tensor/basic.py
Outdated
) | ||
|
||
|
||
def get_underlying_scalar_constant( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To match the non-underlying
def get_underlying_scalar_constant( | |
def get_underlying_scalar_constant_value( |
Needs to be updated everywhere
get_scalar_constant_value
now raises for non-scalar typed inputs. get_underlying_scalar_constant_value
can be used for the old behavior
get_scalar_constant_value
now raises for non-scalar typed inputs. get_underlying_scalar_constant_value
can be used for the old behaviorget_scalar_constant_value
now raises for non-scalar inputs. get_underlying_scalar_constant_value
can be used for the old behavior
get_scalar_constant_value
now raises for non-scalar inputs. get_underlying_scalar_constant_value
can be used for the old behaviorget_scalar_constant_value
now raises for non-scalar inputs
303094f
to
1981bb5
Compare
@ricardoV94 It seems like the benchmark test is failing for a couple of tests in |
That's probably just a fluke, nothing to worry about |
Thanks @shreyas3156 |
Motivation for these changes
Closes #226
Implementation details
get_scalar_constant()
method that raises NotScalarConstantError when non-scalar outputs are passed.get_scalar_constant_value()
toget_underlying_scalar_constant()
Checklist
Major / Breaking Changes
Bugfixes