refactor multiprocessing and multiple inputs #4087

Open
martinjrobins opened this issue May 14, 2024 · 6 comments

@martinjrobins
Contributor

Description

Refactor the current multiprocessing implementation to push this responsibility onto the solver.

Related issues #3987, #3713, #2645, #2644, #3910

Motivation

At the moment, users can request multiple parallel solves by passing in a list of inputs:

solver.solve(model, t_eval, inputs=[{"a": 1}, {"a": 2}])

For all solvers apart from the Jax solver, this uses the Python multiprocessing library to run the solves in parallel. However, this causes issues on Windows (#3987), and it could probably be much more efficient if done at the solver level rather than up in Python.

Possible Implementation

I would propose that if a list of inputs of size n is detected, the model equations are duplicated n times before being passed to the solver. All the solvers we use have built-in parallel functionality that can be used to run this in parallel, and once we have a decent GPU solver this could be done on the GPU. After the solve is finished, the larger state vector would be split apart and passed to multiple Solution objects.
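
As a rough illustration of the duplicate, solve, split idea, here is a minimal pure-Python sketch. It is not PyBaMM code: the toy model dy/dt = -a * y, the fixed-step RK4 integrator, and the names stacked_rhs, rk4_solve and solve_stacked are all assumptions made for this example. A real solver would integrate the stacked system in parallel.

```python
# Hypothetical sketch: none of these names are PyBaMM API.

def stacked_rhs(t, y, params):
    """RHS of n copies of dy/dt = -a * y, one copy per input set."""
    return [-p["a"] * yi for p, yi in zip(params, y)]


def rk4_solve(rhs, y0, t_eval, params):
    """Fixed-step RK4 over the stacked state; returns one state per time."""
    ys = [list(y0)]
    for t0, t1 in zip(t_eval[:-1], t_eval[1:]):
        h = t1 - t0
        y = ys[-1]
        k1 = rhs(t0, y, params)
        k2 = rhs(t0 + h / 2, [yi + h / 2 * k for yi, k in zip(y, k1)], params)
        k3 = rhs(t0 + h / 2, [yi + h / 2 * k for yi, k in zip(y, k2)], params)
        k4 = rhs(t1, [yi + h * k for yi, k in zip(y, k3)], params)
        ys.append([
            yi + h / 6 * (a + 2 * b + 2 * c + d)
            for yi, a, b, c, d in zip(y, k1, k2, k3, k4)
        ])
    return ys


def solve_stacked(inputs_list, y0, t_eval):
    """Stack one copy of the state per input set, solve once, then split."""
    n = len(inputs_list)
    stacked = rk4_solve(stacked_rhs, [y0] * n, t_eval, inputs_list)
    # Split the larger state vector into one trajectory per input set.
    return [[state[i] for state in stacked] for i in range(n)]


t_eval = [0.01 * k for k in range(101)]
solutions = solve_stacked([{"a": 1.0}, {"a": 2.0}], 1.0, t_eval)
```

Each entry of solutions plays the role of one Solution object: the single solve advances all copies together, and the split at the end recovers the per-input trajectories.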

Additional context

At the moment this snippet captures how multiprocessing is implemented:

            ninputs = len(model_inputs_list)
            if ninputs == 1:
                new_solution = self._integrate(
                    model,
                    t_eval[start_index:end_index],
                    model_inputs_list[0],
                )
                new_solutions = [new_solution]
            else:
                if model.convert_to_format == "jax":
                    # Jax can parallelize over the inputs efficiently
                    new_solutions = self._integrate(
                        model,
                        t_eval[start_index:end_index],
                        model_inputs_list,
                    )
                else:
                    with mp.get_context(self._mp_context).Pool(processes=nproc) as p:
                        new_solutions = p.starmap(
                            self._integrate,
                            zip(
                                [model] * ninputs,
                                [t_eval[start_index:end_index]] * ninputs,
                                model_inputs_list,
                            ),
                        )
                        p.close()
                        p.join()
@martinjrobins
Contributor Author

The main drawback of this approach is that it would be very inefficient if you wanted to use dense matrices for the Jacobian. I'm not sure that would be an issue, as we use sparse matrices by default.
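
To put rough numbers on the dense Jacobian concern: stacking n copies of a model with m states gives a block-diagonal Jacobian of size (n*m) x (n*m), so dense storage grows with n squared while the nonzero blocks only grow with n. The helper below is a hypothetical back-of-the-envelope calculation, not part of any solver.

```python
def jacobian_storage(n_inputs, n_states, nnz_per_model=None):
    """Entries stored for the stacked Jacobian: (dense, block-sparse).

    nnz_per_model is the number of nonzeros in one model's Jacobian;
    it defaults to a fully dense block of n_states * n_states entries.
    """
    if nnz_per_model is None:
        nnz_per_model = n_states * n_states
    size = n_inputs * n_states
    dense_entries = size * size                  # every entry stored
    sparse_entries = n_inputs * nnz_per_model    # diagonal blocks only
    return dense_entries, sparse_entries


dense, sparse = jacobian_storage(n_inputs=10, n_states=100)
```

Even when each per-model block is itself fully dense, the dense stacked matrix stores n_inputs times more entries than the block-sparse one.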

@martinjrobins
Contributor Author

Another downside would be if a user wanted to rapidly change the number of inputs: you would have to set up new equations every time, which would be quite inefficient.

@martinjrobins
Contributor Author

Any comments on this approach are welcome. The above is my current thinking, but I'm happy to consider alternatives before I dive into making changes: @BradyPlanden @rtimms @jsbrittain @valentinsulzer.

@BradyPlanden
Member

Sounds like a good plan. Thanks @martinjrobins.

One area that might require some thought is ensuring that threads are not over-provisioned, e.g. if users have a higher-level Python multiprocessing job interacting with the solver parallelisation. A kwarg that allows the user to set the number of threads used by the solver might do the trick.

@agriyakhetarpal
Member

I would also recommend the same thing. If we were to add support for Emscripten/Pyodide sometime soon, we would need to disable threading or put guards around the threading or multiprocessing imports so that import pybamm and the rest of the functionality work well under a WASM runtime (where a browser cannot start a new thread or run subprocesses).
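
A guard of this kind could look roughly like the sketch below. It relies on sys.platform reporting "emscripten" or "wasi" on WebAssembly builds of CPython (3.11 and later); the helper names multiprocessing_available and run_all are hypothetical and only illustrate falling back to a serial loop.

```python
import sys


def multiprocessing_available(platform=None):
    """Return False on WebAssembly platforms, where subprocesses cannot start."""
    platform = sys.platform if platform is None else platform
    return platform not in ("emscripten", "wasi")


# Only import multiprocessing where it can actually be used.
if multiprocessing_available():
    import multiprocessing as mp
else:
    mp = None


def run_all(func, args_list, nproc=None):
    """Apply func to each argument tuple, in parallel only when possible."""
    if mp is None or nproc in (None, 1) or len(args_list) < 2:
        # Serial fallback: always works, including under WASM.
        return [func(*args) for args in args_list]
    with mp.Pool(processes=nproc) as pool:
        return pool.starmap(func, args_list)
```

With nproc left unset, run_all(pow, [(2, 3), (3, 2)]) takes the serial path, so the same call works whether or not subprocesses are available.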

@martinjrobins
Contributor Author

The solvers would use OpenMP to run in parallel, so this would be controlled by setting the environment variable OMP_NUM_THREADS. I'll make sure to document this for users who are also using multiprocessing.
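
For users who combine their own process pool with solver-level threading, one way to avoid over-provisioning is to divide the available cores between workers and export the result as OMP_NUM_THREADS. The sketch below is an assumption about how a user might do this; threads_per_worker and limit_solver_threads are hypothetical helpers. The variable generally has to be set before the OpenMP runtime starts, so in practice before importing the library that uses it.

```python
import os


def threads_per_worker(n_workers, total_cores=None):
    """Split the available cores evenly between worker processes."""
    if n_workers < 1:
        raise ValueError("n_workers must be at least 1")
    if total_cores is None:
        total_cores = os.cpu_count() or 1
    return max(1, total_cores // n_workers)


def limit_solver_threads(n_workers, total_cores=None):
    """Set OMP_NUM_THREADS so that workers * threads does not exceed the cores."""
    n_threads = threads_per_worker(n_workers, total_cores)
    os.environ["OMP_NUM_THREADS"] = str(n_threads)
    return n_threads


# Example: 4 worker processes sharing 16 cores get 4 OpenMP threads each.
n_threads = limit_solver_threads(n_workers=4, total_cores=16)
```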

3 participants