
Numba speedup for wiring + log potentials #133

Merged: 19 commits merged into vicariousinc:master on Apr 8, 2022


@antoine-dedieu (Contributor) commented Apr 5, 2022

This PR is a continuation of #129 and part of our efforts to speed up adding FactorGroups and compiling the wiring.

As #129 moved most of the wiring computation to the FactorGroup level, we can now use numba to compute these wirings quickly.

As a result:

  • RBM experiment: adding the factors takes 3 s and building run_bp takes 1 s
  • convor experiment: adding the factors takes 2 s and building run_bp takes 1 s

@antoine-dedieu changed the title from "WIP - Numba speedup" to "Numba speedup for wiring + log potentials" on Apr 5, 2022
@codecov-commenter commented Apr 5, 2022

Codecov Report

Merging #133 (04d62bb) into master (1c08295) will not change coverage.
The diff coverage is 100.00%.

@@            Coverage Diff            @@
##            master      #133   +/-   ##
=========================================
  Coverage   100.00%   100.00%           
=========================================
  Files           13        13           
  Lines          917       950   +33     
=========================================
+ Hits           917       950   +33     
Impacted Files                   Coverage Δ
pgmax/groups/logical.py          100.00% <ø> (ø)
pgmax/factors/enumeration.py     100.00% <100.00%> (ø)
pgmax/factors/logical.py         100.00% <100.00%> (ø)
pgmax/fg/graph.py                100.00% <100.00%> (ø)
pgmax/fg/groups.py               100.00% <100.00%> (ø)
pgmax/fg/nodes.py                100.00% <100.00%> (ø)
pgmax/groups/enumeration.py      100.00% <100.00%> (ø)


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 1c08295...04d62bb.

@antoine-dedieu antoine-dedieu marked this pull request as ready for review April 5, 2022 00:49
@StannisZhou (Contributor) left a comment


The current implementation is valid, but one issue is that to write a new factor group, the developer needs to understand various bits of the wiring compilation, which seems unnecessary. Instead, it would be best to implement the wiring compilation once, in the parent FactorGroup class, so that everything is automatically taken care of in new factor groups.

One possible way to do this:

  1. For each type of factor, implement a static method compile_wiring_numba which takes some inputs and does the wiring compilation with numba.
  2. For each type of factor, specify a property wiring_compilation_arguments, similar to the current inference_arguments in the wirings.
  3. In FactorGroup, add an optional use_numba=True flag to compile_wiring. If use_numba is True, we do something like:

         wiring_compilation_arguments = {
             key: getattr(self, key) for key in self.factor_type.wiring_compilation_arguments
         }
         wiring = self.factor_type.compile_wiring_numba(
             vars_to_starts=vars_to_starts,
             **wiring_compilation_arguments
         )

     to get the wiring.
  4. Get rid of all the customized wiring compilation implementations in the various factor groups.
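A minimal sketch of the proposal above, under stated assumptions: the classes HypotheticalFactor and HypotheticalFactorGroup are made up for illustration (they are not PGMax's classes), the "wiring" is a toy list of per-edge start indices, wiring_compilation_arguments is a class-level tuple so it can be read directly from the factor class, and the use_numba flag is left out. The point it shows is that a new factor group only sets factor_type and inherits compile_wiring from the parent.

```python
class HypotheticalFactor:
    # Names of the FactorGroup attributes that wiring compilation needs.
    # A plain class-level tuple can be read off the class without an instance.
    wiring_compilation_arguments = ("variables_for_factors",)

    @staticmethod
    def compile_wiring_numba(vars_to_starts, variables_for_factors):
        # Toy "wiring": the start index of every (factor, variable) edge.
        # A real implementation would hand arrays to a numba-jitted kernel here.
        return [
            vars_to_starts[var]
            for factor_vars in variables_for_factors
            for var in factor_vars
        ]


class FactorGroup:
    factor_type = None  # set by each concrete factor group

    def compile_wiring(self, vars_to_starts):
        # Implemented once in the parent class: gather the arguments the
        # factor type asks for and forward them to its static method.
        wiring_compilation_arguments = {
            key: getattr(self, key)
            for key in self.factor_type.wiring_compilation_arguments
        }
        return self.factor_type.compile_wiring_numba(
            vars_to_starts=vars_to_starts, **wiring_compilation_arguments
        )


class HypotheticalFactorGroup(FactorGroup):
    # No custom wiring code: only the factor type is declared.
    factor_type = HypotheticalFactor

    def __init__(self, variables_for_factors):
        self.variables_for_factors = variables_for_factors


fg = HypotheticalFactorGroup((("a", "b"), ("b", "c")))
print(fg.compile_wiring({"a": 0, "b": 2, "c": 5}))  # [0, 2, 2, 5]
```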

In the process, it would also be good to reduce the number of arguments we need (for example, if we make variables_for_factors a tuple of tuples, we can get rid of factor_sizes and num_factors).
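To illustrate why the extra arguments become redundant, here is a small sketch (the helper name derived_arguments is hypothetical) that recovers num_factors and factor_sizes directly from a tuple-of-tuples variables_for_factors:

```python
import numpy as np


def derived_arguments(variables_for_factors):
    """Recover num_factors and factor_sizes from a tuple of tuples (illustrative helper)."""
    num_factors = len(variables_for_factors)
    # Each inner tuple lists the variables of one factor, so its length is the factor size.
    factor_sizes = np.array(
        [len(factor_vars) for factor_vars in variables_for_factors], dtype=np.int64
    )
    return num_factors, factor_sizes


variables_for_factors = (("a", "b"), ("b", "c", "d"), ("d",))
num_factors, factor_sizes = derived_arguments(variables_for_factors)
print(num_factors, factor_sizes.tolist())  # 3 [2, 3, 1]
```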

tests/factors/test_and.py (outdated, resolved)
pgmax/groups/enumeration.py (outdated, resolved)
pgmax/factors/enumeration.py (outdated, resolved)
pgmax/factors/enumeration.py (outdated, resolved)
pgmax/factors/enumeration.py (outdated, resolved)
pgmax/fg/groups.py (resolved)
pgmax/fg/groups.py (outdated, resolved)
@antoine-dedieu (Contributor, Author) commented:

@StannisZhou I am fine with wiring_compilation_arguments, but I do not think numba should be involved here.
A factor_type.compile_wiring may use an internal _compile_wiring_numba to speed things up, but that should not be specified at the FactorGroup level.
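A sketch of this layering, assuming a hypothetical factor class and a toy wiring (the flat state indices of every edge): the public compile_wiring takes plain Python objects and does the dict lookups, which numba does not handle well, then delegates the numeric loop to a private jitted helper. Numba is treated as optional here so the sketch also runs without it; nothing at the FactorGroup level needs to know numba is involved.

```python
import numpy as np

try:
    from numba import njit  # JIT-compile the kernel when numba is installed
except ImportError:  # fall back to plain Python so the sketch still runs
    def njit(func):
        return func


@njit
def _compile_wiring_numba(edge_var_starts, edge_num_states):
    # Numeric kernel: concatenated flat state indices for every edge.
    total = 0
    for e in range(edge_num_states.shape[0]):
        total += edge_num_states[e]
    out = np.empty(total, dtype=edge_var_starts.dtype)
    pos = 0
    for e in range(edge_num_states.shape[0]):
        for s in range(edge_num_states[e]):
            out[pos] = edge_var_starts[e] + s
            pos += 1
    return out


class HypotheticalEnumFactor:
    @staticmethod
    def compile_wiring(variables_for_factors, vars_to_starts, vars_to_num_states):
        # Python-side preprocessing: flatten dict lookups into arrays.
        edge_vars = [v for factor_vars in variables_for_factors for v in factor_vars]
        starts = np.array([vars_to_starts[v] for v in edge_vars], dtype=np.int64)
        num_states = np.array([vars_to_num_states[v] for v in edge_vars], dtype=np.int64)
        return _compile_wiring_numba(starts, num_states)


wiring = HypotheticalEnumFactor.compile_wiring(
    (("a", "b"), ("b",)), {"a": 0, "b": 2}, {"a": 2, "b": 3}
)
print(wiring.tolist())  # [0, 1, 2, 3, 4, 2, 3, 4]
```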

@StannisZhou (Contributor) commented:

Then we can change the existing compile_wiring function to be functional (i.e. a pure function of its inputs). With wiring_compilation_arguments, that wouldn't make it any harder to use, so I'm fine with that.

@antoine-dedieu (Contributor, Author) commented Apr 6, 2022

@StannisZhou I have pushed a commit that should address your comments.
The two nice things are that
(1) we only define factor_group.compile_wiring once, in the parent class;
(2) we do not need to match the Factor and FactorGroup arguments, because compile_wiring is always called at the FactorGroup level (even for a SingleFactorGroup, which was my concern).

However, there is something odd about having compile_wiring_arguments at the Factor level (note: I had to make it a @staticmethod, because accessing a @property on the class itself returns the property object, which is not iterable), because now
(1) factor.compile_wiring_arguments() uses FactorGroup arguments;
(2) we cannot call factor.compile_wiring().
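The @property limitation mentioned above can be reproduced in a few lines of plain Python. The class names are illustrative only; the key fact is that FactorGroup code holds the factor class rather than an instance, and reading a property through the class gives back the property descriptor itself:

```python
class WithProperty:
    @property
    def wiring_compilation_arguments(self):
        return ("variables_for_factors",)


class WithStaticMethod:
    @staticmethod
    def wiring_compilation_arguments():
        return ("variables_for_factors",)


# FactorGroup code only holds the factor *class*, not an instance.
factor_type = WithProperty
print(type(factor_type.wiring_compilation_arguments).__name__)  # property
try:
    list(factor_type.wiring_compilation_arguments)
except TypeError as err:
    print(err)  # 'property' object is not iterable

# A static method (or a plain class attribute) works without an instance.
print(WithStaticMethod.wiring_compilation_arguments())  # ('variables_for_factors',)
```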

@StannisZhou (Contributor) left a comment


Main comment is to use inspect to get the necessary arguments. Other bits LGTM.
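One way to read this suggestion, sketched with hypothetical classes and a toy wiring: instead of maintaining a separate list of argument names, the parent FactorGroup inspects the signature of the factor type's static compile_wiring (via the standard-library inspect.signature) and pulls each named parameter, other than vars_to_starts, from its own attributes.

```python
import inspect


class HypotheticalFactor:
    @staticmethod
    def compile_wiring(vars_to_starts, variables_for_factors):
        # Toy wiring: the start index of every (factor, variable) edge.
        return [vars_to_starts[v] for factor_vars in variables_for_factors for v in factor_vars]


class FactorGroup:
    factor_type = None  # set by each concrete factor group

    def compile_wiring(self, vars_to_starts):
        # The parameter names of the static method tell us which
        # FactorGroup attributes to forward; no hand-maintained list needed.
        params = inspect.signature(self.factor_type.compile_wiring).parameters
        kwargs = {name: getattr(self, name) for name in params if name != "vars_to_starts"}
        return self.factor_type.compile_wiring(vars_to_starts=vars_to_starts, **kwargs)


class HypotheticalFactorGroup(FactorGroup):
    factor_type = HypotheticalFactor

    def __init__(self, variables_for_factors):
        self.variables_for_factors = variables_for_factors


fg = HypotheticalFactorGroup((("a", "b"), ("b", "c")))
print(fg.compile_wiring({"a": 0, "b": 2, "c": 5}))  # [0, 2, 2, 5]
```

A trade-off of this design is that the parameter names of compile_wiring must exactly match the FactorGroup attribute names.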

pgmax/factors/enumeration.py (outdated, resolved)
pgmax/fg/groups.py (outdated, resolved)
pgmax/fg/groups.py (outdated, resolved)
pgmax/factors/logical.py (outdated, resolved)
pgmax/factors/logical.py (outdated, resolved)
pgmax/fg/nodes.py (outdated, resolved)
pgmax/fg/nodes.py (outdated, resolved)
pgmax/groups/enumeration.py (outdated, resolved)
@StannisZhou (Contributor) left a comment


One more minor comment. LGTM otherwise! And remember to get coverage back to 100% after replacing the assert with a raise.
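The assert-to-raise point can be illustrated with a hypothetical validation helper (check_log_potentials and its shape rule are assumptions, not PGMax code). An assert is skipped when Python runs with -O, while an explicit raise always runs; the new raise branch is also a line that coverage will report as missed until some test triggers it.

```python
import numpy as np


def check_log_potentials(log_potentials, num_configs):
    # Hypothetical check: an explicit raise instead of
    # `assert log_potentials.shape == (num_configs,)`, which `python -O` would skip.
    if log_potentials.shape != (num_configs,):
        raise ValueError(
            f"Expected log_potentials of shape {(num_configs,)}, got {log_potentials.shape}"
        )
    return log_potentials


check_log_potentials(np.zeros(4), num_configs=4)  # passes silently
```

With pytest, a test such as `with pytest.raises(ValueError): check_log_potentials(np.zeros(3), num_configs=4)` would cover the raise branch.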

pgmax/factors/enumeration.py (outdated, resolved)
@StannisZhou StannisZhou linked an issue Apr 8, 2022 that may be closed by this pull request
@antoine-dedieu antoine-dedieu merged commit 8a31c9c into vicariousinc:master Apr 8, 2022
@antoine-dedieu antoine-dedieu deleted the numba_speedup branch April 8, 2022 17:36
@wlehrach (Contributor) commented Apr 11, 2022 on these lines:

    @nb.jit(parallel=False, cache=True, fastmath=True, nopython=True)
    def _compile_var_states_numba(


Is there a reason you make the caller allocate these arrays? In general, it's cleaner and less error-prone to allocate the return arrays inside the numba function rather than mutating a passed-in array. You can get a very small optimization by reusing arrays between calls (so in highly performance-sensitive code it can be useful), but you're not doing that here. You can also take the dtype of the incoming arrays and use it for the outputs.

Development: successfully merging this pull request may close the issue "Optimize and parallelize structure compiling function".

4 participants