jax printer doesn't properly convert sympy Min and Max during lambdify #26139

Closed
dardeshna opened this issue Jan 29, 2024 · 14 comments · Fixed by #26149

Comments

@dardeshna

jax can't accept tuples as input to amin and amax the same way that numpy does. this results in the following error when lambdifying a sympy expression with min/max:

TypeError: max requires ndarray or scalar arguments, got <class 'tuple'> at position 0.

    def _print_Min(self, expr):
        return '{}(({}), axis=0)'.format(self._module_format(self._module + '.amin'), ','.join(self._print(i) for i in expr.args))

    def _print_Max(self, expr):
        return '{}(({}), axis=0)'.format(self._module_format(self._module + '.amax'), ','.join(self._print(i) for i in expr.args))

i think the fix is simply to wrap the tuple in a jax.numpy.asarray() call (similarly to how and/or are carved out for jax)

@virajvekaria
Contributor

    def _print_Min(self, expr):
        return '{}(({}), axis=0)'.format(self._module_format(self._module + '.amin'), ','.join('jax.numpy.asarray({})'.format(self._print(i)) for i in expr.args))

    def _print_Max(self, expr):
        return '{}(({}), axis=0)'.format(self._module_format(self._module + '.amax'), ','.join('jax.numpy.asarray({})'.format(self._print(i)) for i in expr.args))

So should this work??

@virajvekaria
Contributor

Hey I have made one more change, and I wished to test it. Please can someone help me with that.

@dardeshna
Author

yes although i might suggest something like this to remove the redundant parentheses

thanks for looking into this!

    def _print_Min(self, expr):
        return '{}({}.asarray([{}]), axis=0)'.format(self._module_format(self._module + '.amin'), self._module_format(self._module), ','.join(self._print(i) for i in expr.args))

    def _print_Max(self, expr):
        return '{}({}.asarray([{}]), axis=0)'.format(self._module_format(self._module + '.amax'), self._module_format(self._module), ','.join(self._print(i) for i in expr.args))
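
To see what the two format strings in this thread would emit, here is a minimal stand-alone sketch that uses plain string formatting only (no sympy or jax). The `render` helper, the template constants, and the argument names `x`, `y`, `z` are hypothetical stand-ins for what the printer supplies, and it assumes `_module_format` returns the fully qualified name `jax.numpy...`, which may not match the printer's actual settings.

```python
# Stand-alone sketch: plain string formatting only, no sympy or jax required.
# OLD_TEMPLATE mirrors the tuple-based format string quoted at the top of the thread;
# NEW_TEMPLATE mirrors the asarray-based suggestion above. Names are illustrative.

OLD_TEMPLATE = '{func}(({args}), axis=0)'
NEW_TEMPLATE = '{func}({module}.asarray([{args}]), axis=0)'


def render(template, func_name, printed_args, module='jax.numpy'):
    """Fill a template roughly the way _print_Min/_print_Max would (assumed)."""
    return template.format(
        func=module + '.' + func_name,  # stand-in for self._module_format(...)
        module=module,                  # unused by OLD_TEMPLATE, which is fine
        args=','.join(printed_args),    # stand-in for the printed sub-expressions
    )


old_min = render(OLD_TEMPLATE, 'amin', ['x', 'y'])
new_min = render(NEW_TEMPLATE, 'amin', ['x', 'y'])
new_max = render(NEW_TEMPLATE, 'amax', ['x', 'y', 'z'])

print(old_min)  # jax.numpy.amin((x,y), axis=0)
print(new_min)  # jax.numpy.amin(jax.numpy.asarray([x,y]), axis=0)
print(new_max)  # jax.numpy.amax(jax.numpy.asarray([x,y,z]), axis=0)
```

The old output passes a bare tuple to `amin`, which is the form the TypeError in the report complains about. The new output builds a single array first and reduces along `axis=0`.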

@virajvekaria
Contributor

@dardeshna I'm pretty new to open source, what do I do next? Should I issue a Pull request?

@dardeshna
Author

Honestly me as well — probably open a PR and look through the contributing guidelines for the project to see the workflow used by sympy?

@virajvekaria
Contributor

Okay, I looked into it. I am currently running all the tests and will then probably open a PR. By the way, I really need to thank you, sir; this is literally my first issue solved in open source.

@virajvekaria
Contributor

Ouch, the test run fails with 15 errors.

@virajvekaria
Contributor

One of the tests is being detected as a trojan on my Windows computer. Can someone please help?

@virajvekaria
Contributor

Okay, so it's now three errors only, and one is such that Windows Defender detects it as a trojan, should I allow this file to run by putting it on Defender's whitelist?

@oscarbenjamin
Collaborator

It is hard to answer about the trojan without more information. It sounds like an overzealous virus checker but hard to be sure.

@virajvekaria
Contributor

Also, can someone please tell me what _np is here?

def get_results_with_scipy(objective, constraints, variables):
    if scipy is not None and np is not None:
        from sympy.solvers.inequalities import _np
        nonpos, rep, xx = _np(constraints, [])
        assert not rep  # only testing nonneg variables
        C, _D = linear_eq_to_matrix(objective, *variables)
        A, B = linear_eq_to_matrix(nonpos, *variables)
        assert _D[0] == 0  # scipy only deals with D = 0

This code is from sympy\solvers\tests\test_simplex

@virajvekaria
Contributor

> It is hard to answer about the trojan without more information. It sounds like an overzealous virus checker but hard to be sure.

So how do I proceed with this? I shouldn't commit if there's an error, right? There is an option to allow this code to run by putting it on the whitelist, but the question is whether this code can be trusted.

For reference, this test is sympy\utilities\_compilation\tests\test_compilation.py

@asmeurer
Member

It looks like that _np function was moved in 43d23ed. I guess that code was never executed in the tests CI. It would be better to not use a private function there if possible.

@asmeurer
Member

While it's usually best to try to get everything working before committing, if you can't figure something out, it's better to just commit what you have and make a pull request with it, so that others can help you figure out how to fix it. It's much easier for people to help you if your code is pushed up so that they can see it, check it out, and try it themselves.
