Rust codegen is broken #25173

Open
smups opened this issue May 25, 2023 · 4 comments

@smups

smups commented May 25, 2023

The current Rust printer is broken: it does not output valid Rust code. It sometimes places floats where only integers are allowed, and vice versa.

In Rust, 2*arg[1].powi(2) does not compile because 2 is an integer literal, not a float; 2.0 is a float. 2.0*arg[1].powi(2.0) also does not compile, because powi expects an integer (i32) exponent, not a float.
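A minimal sketch of the forms that do compile, assuming the symbol is an f64 in the float case and an i64 in the integer case (the function names are illustrative only). Rust never converts between numeric types implicitly, so the literal has to match the type of the other operand, and each power method has a fixed exponent type.

```rust
/// Float version: the literal must be written `2.0` so that it is an `f64`,
/// and `powi` takes an `i32` exponent, so its argument stays `2`.
fn twice_square_f64(x: f64) -> f64 {
    2.0 * x.powi(2)
}

/// Integer version: `2` is an integer literal (inferred as `i64` here),
/// and the integer `pow` method takes a `u32` exponent.
fn twice_square_i64(k: i64) -> i64 {
    2 * k.pow(2)
}
```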

Currently, sympy does not do this correctly. For example:
[image: the SymPy expression]
is printed as:

2*m.powi(2)*(φ - φ0)*(4*m.powi(4)*φ.powi(3)*(φ - φ0) + ψ0.powi(2)*(φ.powi(2) - 3)*(φ.powi(2) + 1))(φ*(4*m.powi(4)*φ.powi(2)*(φ - φ0).powi(2) + ψ0.powi(2)*(φ.powi(2) + 1).powi(2))*(2*V0 + m.powi(2)*(φ - φ0).powi(2) - 2*ψ*ψ0))

Python 3.11.3
Sympy 1.11.1
Fedora 38, virtual environment

@bjodah
Member

bjodah commented May 27, 2023

Thank you for reporting this, @smups. I think some of the issues can be fixed in the RustCodePrinter itself, while others might require a preprocessing step (which rust_code could perform by default). I don't have any real experience writing Rust (beyond working my way through its tutorial a few years back). Would you mind providing a few (minimal) examples highlighting the different failure modes, and what you would have expected instead? Those examples will be useful as test cases when someone tries to fix this.

@smups
Author

smups commented May 27, 2023

@bjodah Yeah sure. The main issue that comes to mind is that Rust does not support implicit conversion of numeric types. SymPy sometimes mixes floats and integer types in the output code, for example 2*f where f is supposed to be a float.

I believe it's currently ambiguous what type each parsed sympy symbol should have. I think it might be most straightforward to allow the user to specify the output type of the whole expression (with some sane default like f64) and work back from there, making sure that all the numeric constants are parsed as the appropriate type (either float or int) and that the correct functions are used.

For example, calling rust_code with rust_type="f64" (or something like that) converts all constants to float literals (so 2.0 instead of 2) etc.
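To make the proposed rust_type option concrete, here is a minimal sketch of how a printer could render an integer constant from the expression as a literal of the requested type. format_int_constant is a hypothetical helper, not part of SymPy, and the accepted type names are an assumption.

```rust
/// Hypothetical helper: render an integer constant as a literal of the
/// Rust type requested through the proposed `rust_type` option.
fn format_int_constant(n: i64, rust_type: &str) -> String {
    match rust_type {
        // Float targets need a decimal point; otherwise Rust infers `{integer}`.
        "f32" | "f64" => format!("{}.0", n),
        // Integer targets can use the plain literal.
        _ => n.to_string(),
    }
}
```

With rust_type="f64" the constant 2 would then be printed as 2.0, and with rust_type="i64" it would stay 2.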

One complication (perhaps for a later revision) is that Rust has a power function, powi, that mixes a float base with an integer exponent, and that not all math functions are implemented for all types.

Here's an overview of which mathematics functions are available in rust's std:

Float only:

(arguments and return type are all floats)

  • trig functions: (a)cos(h), (a)sin(h), (a)tan(h), sin_cos and atan2
  • roots: sqrt, cbrt
  • exponents: exp, exp2, exp_m1, powf
  • logarithms: ln, ln_1p, log, log10, log2
  • misc: signum, abs, round, recip, div_euclid, rem_euclid, hypot, floor

Integer only:

(arguments and return type are all integers)

  • exponents: pow
  • logarithms: ilog, ilog10, ilog2
  • misc: signum, abs, div_euclid, rem_euclid, abs_diff
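A short sketch of what the split above means in practice (function names are illustrative only): a float-only method such as sqrt needs an explicit cast when the value is an integer, some names such as rem_euclid exist on both kinds of type, and abs_diff is integer-only.

```rust
/// `sqrt` is only defined on floats, so an integer argument must be cast first.
fn sqrt_of_int(k: i64) -> f64 {
    (k as f64).sqrt()
}

/// `rem_euclid` exists on both integers and floats, with matching semantics
/// (the result is always non-negative for a positive divisor).
fn rem_euclid_both(a: i64, b: i64) -> (i64, f64) {
    (a.rem_euclid(b), (a as f64).rem_euclid(b as f64))
}

/// `abs_diff` is integer-only; for `i64` it returns a `u64`.
fn distance(a: i64, b: i64) -> u64 {
    a.abs_diff(b)
}
```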

Then there's pow(i)...

  • pow(self, exp: u32) -> Self calculates self^exp. It is only implemented for the integer types; f32 and f64 have no pow method. It can be called like 2_i64.pow(2).
  • powi(self, exp: i32) -> Self calculates self^exp where Self is a float and exp is an i32. It is not implemented for integer types. It can be called like 2.0_f64.powi(-2) (a bare 2.0.powi(-2) is rejected because the float type is ambiguous). It is generally faster than powf.
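The three power methods side by side, as a minimal sketch with illustrative function names. Note the different exponent types: u32 for integer pow (no negative exponents), i32 for powi, and f64 for powf.

```rust
/// Integer power: inherent method on integer types only, `u32` exponent.
fn int_pow(base: i64, exp: u32) -> i64 {
    base.pow(exp)
}

/// Float power with an integer exponent: `powi` takes an `i32`,
/// so negative exponents are allowed.
fn float_powi(base: f64, exp: i32) -> f64 {
    base.powi(exp)
}

/// Float power with a float exponent: `powf` is the general case.
fn float_powf(base: f64, exp: f64) -> f64 {
    base.powf(exp)
}
```

A printer would therefore have to pick the method from the types involved: pow for integer bases, powi for float bases with integer exponents, and powf otherwise.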

It might be necessary to do some type conversions because most math functions are only implemented for the floating-point types. In Rust you can do this with value as Type (e.g. k as f64). Some conversions might do surprising things, but all conversions are well-defined: float-to-integer casts map NaN to 0 and saturate values that are too large or too small to the max/min value of the output type, while integer-to-integer casts truncate (wrap around).
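A minimal sketch of the cast behaviour described above (function names are illustrative). Float-to-integer casts truncate toward zero and saturate, while narrowing integer casts wrap around.

```rust
/// Float -> integer casts with `as` truncate toward zero, saturate at the
/// target's bounds and map NaN to 0 (guaranteed since Rust 1.45).
fn float_to_i32(x: f64) -> i32 {
    x as i32
}

/// Integer -> narrower integer casts with `as` wrap around instead of saturating.
fn i32_to_u8(k: i32) -> u8 {
    k as u8
}
```

For example, float_to_i32(f64::NAN) gives 0, float_to_i32(1e300) gives i32::MAX, and i32_to_u8(300) gives 44 (300 mod 256).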

Some examples of expected behaviour:

import sympy

a, b = sympy.symbols("a b")
expr1 = (2*a**2 - b).nsimplify()
expr2 = sympy.sqrt(sympy.sqrt(a) - b)
sympy.printing.rust_code(expr1)
sympy.printing.rust_code(expr2)
2.0*a.powi(2) - b //expr1
(a.sqrt() - b).sqrt() //expr2
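One way to check that the expected float outputs really compile is to wrap them in functions. This is a minimal sketch assuming both symbols map to f64. It uses powi for the integer exponent, since the std float types have no pow method.

```rust
/// Expected Rust output for `expr1 = 2*a**2 - b` with `f64` symbols.
fn expr1(a: f64, b: f64) -> f64 {
    2.0 * a.powi(2) - b
}

/// Expected Rust output for `expr2 = sqrt(sqrt(a) - b)` with `f64` symbols.
fn expr2(a: f64, b: f64) -> f64 {
    (a.sqrt() - b).sqrt()
}
```

Negative radicands give NaN rather than an error, e.g. expr2(1.0, 2.0) is NaN.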

with the (proposed) rust_type option:

sympy.printing.rust_code(expr1, rust_type="i64")
sympy.printing.rust_code(expr2, rust_type="i64")
2*a.pow(2) - b //expr1
((a as f64).sqrt() - b as f64).sqrt().round() as i64 //expr2
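The same check for the proposed i64 outputs, as a minimal sketch assuming both symbols map to i64. In the second function, as binds tighter than binary minus, so - b as f64 parses as - (b as f64).

```rust
/// Expected output for `expr1` with the proposed `rust_type="i64"`.
fn expr1_i64(a: i64, b: i64) -> i64 {
    2 * a.pow(2) - b
}

/// Expected output for `expr2` with `rust_type="i64"`: `sqrt` only exists on
/// floats, so the symbols are cast to `f64`, and the result is rounded and
/// cast back to `i64`.
fn expr2_i64(a: i64, b: i64) -> i64 {
    ((a as f64).sqrt() - b as f64).sqrt().round() as i64
}
```

This also shows one of the surprising conversions mentioned earlier: expr2_i64(4, 5) takes the square root of -3, which is NaN, and the final cast turns that NaN into 0 instead of reporting an error.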

@bjodah
Member

bjodah commented May 28, 2023

Thank you for expanding on this. I think the proper way forward is to add Function subclasses (e.g. powi) in a new module (sympy.codegen.rust). That module would also be the proper destination for utility functions that transform the expression tree to include these Rust-specific functions.

As for float vs. integer, I was hoping that using Symbol('k', integer=True) could be useful to indicate which variables are expected to be integers. The user would then need to specify which floating-point type to use for symbols where integer is not set to True (the default would probably be f64) and which type to use for symbols with integer set to True. The Fortran and C printers use the printer setting type_aliases for this. Do you think that mechanism is appropriate for Rust as well?

@smups
Author

smups commented May 30, 2023

> As for float vs. integer, I was hoping that using Symbol('k', integer=True) could be useful to indicate what variables are expected to be integers. Then the user would need to specify what floating point type to use for symbols where integer is not set to True (default would probably be f64) and what type to use for symbols with Integer set to True. The Fortran and C printer uses the printer setting type_aliases for this. Do you think that mechanism is appropriate for rust as well?

Yes, I think that would be a great solution. So to properly print a SymPy expression as Rust code, we would need:

  • the types of all the symbols (integer, float, complex - throw an error for matrix symbols?)
  • type_aliases to go from sympy symbol types -> rust types (integer -> i64, float -> f64 etc.)
  • a user-defined output type (perhaps derived from the types of the symbols? I.e. if all symbols are integers, then the output of the expression should be an integer as well)
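A minimal sketch of the output-type rule in the last bullet, written in Rust only for illustration. SymKind, TypeAliases and output_type are hypothetical names, not SymPy's API (the real type_aliases setting lives in the Python printers). Complex and matrix symbols are rejected here because Rust's standard library has no complex number type.

```rust
/// Hypothetical classification of SymPy symbols (e.g. `Symbol('k', integer=True)`).
#[derive(Debug, Clone, Copy, PartialEq)]
enum SymKind {
    Integer,
    Real,
    Complex,
    Matrix,
}

/// Hypothetical Rust-side equivalent of the `type_aliases` printer setting.
struct TypeAliases {
    integer: &'static str,
    real: &'static str,
}

impl Default for TypeAliases {
    fn default() -> Self {
        TypeAliases { integer: "i64", real: "f64" }
    }
}

/// Derive the output type: integer only if every symbol is an integer.
fn output_type(symbols: &[SymKind], aliases: &TypeAliases) -> Result<&'static str, String> {
    let mut all_integer = true;
    for &kind in symbols {
        match kind {
            SymKind::Integer => {}
            SymKind::Real => all_integer = false,
            // Rust's standard library has no complex number type.
            SymKind::Complex => return Err("complex symbols are not supported".to_string()),
            SymKind::Matrix => return Err("matrix symbols are not supported".to_string()),
        }
    }
    Ok(if all_integer { aliases.integer } else { aliases.real })
}
```

Under this rule an expression without any symbols would default to the integer type; a real implementation would also need to look at the constants in the expression.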
