Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Function overloading #1027

Merged
merged 43 commits into from Jan 20, 2022
Merged

Function overloading #1027

merged 43 commits into from Jan 20, 2022

Conversation

WardBrian
Copy link
Member

@WardBrian WardBrian commented Nov 8, 2021

This PR enables function overloading of user defined functions.

Current status:

  • Overloading of UDFs
  • Overloading a stan_math function name: Needs the changes in Fix issue with function shadowing #1011
  • Passing overloaded functions to map_rect, ODEs, etc: Should work but needs testing
  • Properly sort to fewest promotions when there are multiple options

Submission Checklist

Release notes

User defined functions can now be overloaded. Multiple definitions of the same function name are allowed if the arguments are different in each definition.

Copyright and Licensing

By submitting this pull request, the copyright holder is agreeing to
license the submitted work under the BSD 3-clause license (https://opensource.org/licenses/BSD-3-Clause)

@WardBrian WardBrian marked this pull request as draft November 8, 2021 16:51
@WardBrian WardBrian added the feature New feature or request label Nov 9, 2021
@spinkney
Copy link
Contributor

Hey Brian, when do you expect this pr to be merged?

@WardBrian
Copy link
Member Author

After the holidays we should be able to merge #1011 and then I’ll take this back up

@spinkney
Copy link
Contributor

👍 thanks! I think I want to wait until announcing the "grand opening" of the helpful repo until this is in. It will be nice to have this multiple dispatch for vectorized versions of functions

@WardBrian
Copy link
Member Author

Code generation for this PR should now be in a good place (thanks @rok-cesnovar for all the C++ help).

There is a very tricky issue remaining on the typechecker level, which has to do with the higher order functions like reduce_sum. At the moment, a call

reduce_sum(foo, ....)

just naively looks up the identifier foo in the function table and then typechecks it. This is a problem when there are multiple values for foo, some of which are valid and others are may not be, and it is tricky to resolve. Currently it uses the most recently defined signature, but obviously this isn't desired behavior. The typechecker gives a pretty uninformative error too, reporting that the wrong number of arguments has been provided.

It's obviously best if I can make it so this 'just works', but if not I need to gracefully disallow overloaded functions as higher-order arguments

@WardBrian
Copy link
Member Author

I've added a test which shows the problem at the moment. If you disable the typechecker and compile the model, you get this .hpp:

https://gist.github.com/WardBrian/9c947f405a641df664e4bb676b2a2114

But current behavior is that the typechecker blocks it.

@rok-cesnovar
Copy link
Member

Example of a reduce_sum model that would run and work with shadowing allowed:

functions {
  real fun(array[] real y_slice, int start, int end, real m) {
    return sum(y_slice) * m;
  }

  real fun(array[] real y_slice, int start, int end) {
    return sum(y_slice);
  }
}
transformed data {
  int N = 100;
  array[N] real data_y = ones_array(N);

  real sum_1 = reduce_sum(fun, data_y, 1);
  print(sum_1);
  real sum_2 = reduce_sum(fun, data_y, 1, 5);
  print(sum_2);
}
parameters {
   real y;
}
transformed parameters {
   array[N] real param_y = ones_array(N);

   real p_sum_1 = reduce_sum(fun, param_y, 1);
   print(y, " - ", p_sum_1);
   real p_sum_2 = reduce_sum(fun, param_y, 1, y);
   print(y, " -- ", p_sum_2);
}
model {
   y ~ std_normal();
}

@WardBrian
Copy link
Member Author

WardBrian commented Jan 11, 2022

Here's what remains:

  1. Write a test for variadic odes with overloaded functions
  2. Write a test using map_rect with overloading functions
  3. Support map_rect having an overloaded function in typechecking
  4. I'd love to clean up how the typechecking is done once it all works, especially if we ever plan on having user-defined higher order functions, we'd need a more general approach for that than we currently have. That might be left for later PRs though.

I realized there is only one signature which map_rect ever supports, so it doesn't necessarily need extra support for overloading.

@WardBrian WardBrian mentioned this pull request Jan 11, 2022
3 tasks
@WardBrian WardBrian changed the title [WIP] Function overloading Function overloading Jan 12, 2022
@WardBrian
Copy link
Member Author

This needs a few more tests, especially for higher-order functions, but otherwise I'm happy with where it is right now.

Leaving a note for posterity:

Variadic functions in the typechecker will search for the appropriate function. This logic could be extended to 'regular' functions, but currently the only other higher order function is map_rect, which only has one (and exactly one) signature it can accept, so the old logic works fine.

This will need to change if/when we allow user-defined higher order variadic functions. We're a long way off from such a day, so I didn't go through the effort here, but the code here would be able to extend to that case when the time comes (I hope).

Copy link
Collaborator

@nhuurre nhuurre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something's wrong with variadic argument checking. This model crashes:

functions {
  real foo(real[] x, int s, int end, real z, int k) {
    return 0.0;
  }
  real foo(real[] x, int s, int end, int z, real k) {
    return 0.0;
  }
}
transformed data {
  real x = reduce_sum(foo, {1,2,3,4,5,6}, 1, 2, 3);
}

src/frontend/SignatureMismatch.ml Outdated Show resolved Hide resolved
src/frontend/Typechecker.ml Outdated Show resolved Hide resolved
src/frontend/Semantic_error.ml Show resolved Hide resolved
@WardBrian
Copy link
Member Author

Something's wrong with variadic argument checking. This model crashes: ...

This was a pretty simple mistake on my part. Whenever there was an ambiguous match, I was relying on the existing error cases of the variadic functions to handle it. My test case just so happened to have at least one signature which didn't match, so it output that one. But, if the only signatures were ambiguous ones, then there was no error to report, hence the crash.

This is fixed and they now output errors for ambiguous matches specifically, like with 'normal' functions.

@WardBrian
Copy link
Member Author

I think I've sucessfully merged the changes from #1091, so this should be mergeable as soon as it's approved without breaking other PRs

@nhuurre
Copy link
Collaborator

nhuurre commented Jan 19, 2022

Almost forgot, one more question: this PR doesn't change anything in Analysis_and_optimization.Optimize so I assume this completely breaks function inlining?

@WardBrian
Copy link
Member Author

Almost forgot, one more question: this PR doesn't change anything in Analysis_and_optimization.Optimize so I assume this completely breaks function inlining?

A good point I hadn't considered. Based on the comments in Optimize.ml, I'm pretty sure function inlining is broken anyway - it says

  (* We only add the first definition for each function to the inline map.
     This will make sure we do not inline recursive functions.
     We also don't want to add any function declaration (as opposed to
     definitions), because that would replace the function call with a Skip.
  *)

But recursive functions are possible without forward declarations, which I think will sneak by this and break it if the comment is to be believed.

At any rate, I can make it no worse than present by changing this so it stops trying to inline a certain function if it encounters a second definition of the same function.

@WardBrian
Copy link
Member Author

@nhuurre I made that small change just to not cause any further issues, and opened #1096 after confirming the existing code doesn't work

Copy link
Collaborator

@nhuurre nhuurre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no worse than present

The standard we all should aspire to.

Copy link
Contributor

@SteveBronder SteveBronder left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly a few comments about style but overall I think the code is good!

src/frontend/SignatureMismatch.ml Outdated Show resolved Hide resolved
src/frontend/SignatureMismatch.ml Outdated Show resolved Hide resolved
src/frontend/SignatureMismatch.ml Outdated Show resolved Hide resolved
src/stan_math_backend/Expression_gen.ml Show resolved Hide resolved
src/stan_math_backend/Expression_gen.ml Outdated Show resolved Hide resolved
test/integration/good/code-gen/cpp.expected Outdated Show resolved Hide resolved
test/integration/good/code-gen/cpp.expected Outdated Show resolved Hide resolved
@WardBrian
Copy link
Member Author

Since the latest tests passed I just wanted to ping @SteveBronder and @nhuurre for a final check before merging. I'll do the merge later today if there are no other comments.

Thank you both for all the review, and @rok-cesnovar for some help getting started on this PR

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants