
worker_state is lost between map calls if the input is too large #46

Closed · ghost opened this issue May 30, 2022 · 9 comments · Fixed by #52
Labels: enhancement (New feature or request)

@ghost commented May 30, 2022

This is closely related to #15.

Since your awesome release v2.3.0 (which fixes #15), I've been using mpire a lot, and I love it :)

But I'm running into a problem very similar to #15.

In the following script, each worker gets several numbers i to deal with, which it keeps in its state. I then retrieve these values in another map call.

from mpire import WorkerPool


N = 12
W = 4


def set_state(w_state, i):
    w_state[i] = 2 * i + 1
    return None


def get_state(w_state, i):
    return w_state[i]


if __name__ == "__main__":
    pool = WorkerPool(n_jobs=W, use_worker_state=True, keep_alive=True)
    s = N // W

    pool.map(set_state, list(range(N)), iterable_len=N, n_splits=s)
    r = pool.map(get_state, list(range(N)), iterable_len=N, n_splits=s)

    print(r)

If I run this script, everything works perfectly and I get the expected output:

[1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23]


Now, if I change N (the number of tasks) to a higher number (like 128) and run the script again, I get the following error:

Exception occurred in Worker-2 with the following arguments:
Arg 0: 32
Traceback (most recent call last):
  File "/root/miniconda3/envs/housing_sb3/lib/python3.8/site-packages/mpire/worker.py", line 322, in _run_safely
    results = func()
  File "/root/miniconda3/envs/housing_sb3/lib/python3.8/site-packages/mpire/worker.py", line 276, in _func
    results = func(args)
  File "/root/miniconda3/envs/housing_sb3/lib/python3.8/site-packages/mpire/worker.py", line 415, in _helper_func_with_idx
    return args[0], self._call_func(func, args[1])
  File "/root/miniconda3/envs/housing_sb3/lib/python3.8/site-packages/mpire/worker.py", line 442, in _call_func
    return func(args)
  File "housing_drl/sb3/swag.py", line 14, in get_state
    return w_state[i]
KeyError: 32

It's the exact same error as in #15, so it seems the worker state is somehow erased?

@sybrenjansen Do you have any idea what the problem is? Did I do something wrong in my script?

@ghost (Author) commented May 30, 2022

I just noticed that if I set the number of workers (W) to a higher number (say, 40), it produces the right output.

I don't know what's going on behind the scenes, but this feels more like a dirty workaround.

I'd like to be able to process my 128 tasks with a specific number of workers (like 4) so I don't saturate my CPUs.

@sybrenjansen (Owner) commented

Let's break this down.

You first set the state using set_state for the different workers. You start with 4 workers and 12 tasks, so the range(12) is distributed over the different workers. This means that worker 0 could have numbers 0, 1, 2, worker 1 has 3, 4, 5, etc.

You're using n_splits=3 (because N // W == 3), which means the data gets split into 3 chunks. This also means that one of the workers won't have anything to do. E.g.:

In [36]: def inspect_state(w_state, i):
    ...:     return w_state
    ...: 

In [37]: pool.map(inspect_state, range(4))
Out[37]: 
[{0: 1, 1: 3, 2: 5, 3: 7},
 {4: 9, 5: 11, 6: 13, 7: 15},
 {8: 17, 9: 19, 10: 21, 11: 23},
 {}]

The important thing to notice here, however, is that each worker got at most 1 chunk and these chunks are in order. Worker 0 got chunk 0, worker 1 got chunk 1, worker 2 got chunk 2.

When you increase N to 128, N // W becomes 32. This means each worker will get multiple chunks to process. Importantly, chunks are passed on to the workers based on whoever happens to be fastest! If no worker has completed anything yet, chunks are passed on to the workers in order. This usually means the first W chunks are passed in order (but this is not guaranteed!). After that, perhaps worker 0 finishes its first chunk the fastest, so that worker gets the next chunk (e.g., chunk 4). Then worker 3 could be done and gets chunk 5. Then perhaps worker 1 is done and gets chunk 6. Perhaps worker 2 is a bit slow and worker 0 finishes its second chunk before it and gets chunk 7, and so on.

In other words, there's no deterministic mapping between workers and the chunks they receive, and your example script assumes exactly that. Between your set_state and get_state calls, the chunks end up being delivered to different workers.
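
To make the chunking concrete, here is a small sketch using mpire's chunk_tasks helper (also used further down in this thread); the exact chunk contents may differ slightly per mpire version, but the point is that with n_splits=N // W there are far more chunks than workers:

from mpire.utils import chunk_tasks

N, W = 128, 4

# With n_splits = N // W = 32 there are 32 chunks for only 4 workers, so
# each worker ends up processing several chunks, in whatever order it
# happens to finish its previous ones.
chunks = list(chunk_tasks(range(N), n_splits=N // W))
print(len(chunks))  # 32
print(chunks[0])    # e.g. (0, 1, 2, 3)
print(chunks[-1])   # e.g. (124, 125, 126, 127)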

I'm not sure what your actual use case is and what you need to store in the worker state, but here are a few ways to remedy this:

  • Make sure the worker state is equal for each worker, so each worker can access all keys (this might be more memory intensive); see the sketch after this list.
  • If it's fine for you, you can set n_splits=W, such that each worker only needs to process 1 chunk. Then the only thing you need to do is make sure no worker finishes before another worker has started. Depending on your task this might already be true, but if you're not sure you can use a multiprocessing.Barrier at the end of the functions you're calling with pool.map. A Barrier makes sure your function only continues once all processes have reached that point.
  • Ask me very nicely to add an additional flag to MPIRE which, when passed, makes sure tasks are always distributed to the workers in order, instead of based on completion time. This is actually not that hard to do.
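
For the first option, a minimal sketch could look like the following. It assumes the worker_init argument of map (which receives the worker state just like the task function); every worker then precomputes the full mapping in its own state, so it no longer matters which worker serves which key:

from mpire import WorkerPool

N = 128
W = 4


def init_state(w_state):
    # Give every worker the full mapping, so any worker can serve any key
    # in a later map call (trades extra memory for robustness).
    for i in range(N):
        w_state[i] = 2 * i + 1


def get_state(w_state, i):
    return w_state[i]


if __name__ == "__main__":
    pool = WorkerPool(n_jobs=W, use_worker_state=True, keep_alive=True)
    r = pool.map(get_state, list(range(N)), iterable_len=N,
                 worker_init=init_state)
    print(r)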

@ghost (Author) commented Jun 1, 2022

@sybrenjansen Hey, thanks for the super detailed answer! That's very helpful.

Solution 1 won't be possible for my use case: I use mpire to parallelize RL, so each worker has its own environment with its own state. Replicating each worker's environment on every other worker would take more memory, but also more time, since you'd need to step through each environment with the appropriate actions.

Solution 2 seems nice! I'll try it and let you know.

Solution 3 would of course be awesome, a cool new feature!
But it seems like I'm the only one wanting this feature, so don't bother; I'll try Solution 2 by myself first!

@ghost (Author) commented Jun 2, 2022

Updated example, working perfectly with multiprocessing.Barrier:

from mpire import WorkerPool
from multiprocessing import Barrier


N = 1024
W = 4
B = Barrier(W)


def set_state(w_state, i):
    w_state[i] = 2 * i + 1
    B.wait()
    return None


def get_state(w_state, i):
    B.wait()
    return w_state[i]


if __name__ == "__main__":
    pool = WorkerPool(n_jobs=W, use_worker_state=True, keep_alive=True)

    pool.map(set_state, list(range(N)), iterable_len=N, n_splits=W)
    r = pool.map(get_state, list(range(N)), iterable_len=N, n_splits=W)

    print(r)

Thanks again @sybrenjansen for the super clear explanation 🙏

ghost closed this as completed Jun 2, 2022
@ghost (Author) commented Jun 2, 2022

Actually, let me re-open this issue...

Barrier is working fine, but only if the number of tasks is "round", i.e. divisible by the number of workers.

If the number of tasks is not "round" (try the above script with N=14, for example), then the workers get different numbers of tasks (with N=14 and W=4, the chunks can end up with sizes like 4, 4, 3, 3). In that case the code hangs: on the last pass the barrier waits for W workers, but only a subset of them ever reaches it.

ghost reopened this Jun 2, 2022
@ghost (Author) commented Jun 2, 2022

OK, so I created a working example using 2 different barriers: one for the general case, and one for the last indexes (which might not involve all workers).

(Note: I tried using a Semaphore, but it's not the appropriate primitive for this use case, so it didn't work.)

I wrapped everything in a new class so it's easier to use:

from mpire import WorkerPool
from mpire.utils import chunk_tasks
from multiprocessing import Barrier


N = 1025
W = 4


class AdaptativeBarrier:
    def __init__(self, n_workers, n_tasks):
        self.general_barrier = Barrier(n_workers)
        self.last_barrier = Barrier(n_tasks % n_workers)

        # Compute the indexes of the last pass
        chunks = chunk_tasks(range(n_tasks), n_splits=n_workers)
        n = n_tasks // n_workers
        self.last_idx = [c[-1] for c in chunks if len(c) > n]

    def wait(self, i):
        if i in self.last_idx:
            return self.last_barrier.wait()
        else:
            return self.general_barrier.wait()


B = AdaptativeBarrier(W, N)


def set_state(w_state, i):
    w_state[i] = 2 * i + 1
    B.wait(i)
    return None


def get_state(w_state, i):
    B.wait(i)
    return w_state[i]


if __name__ == "__main__":
    pool = WorkerPool(n_jobs=W, use_worker_state=True, keep_alive=True)

    pool.map(set_state, list(range(N)), iterable_len=N, n_splits=W)
    r = pool.map(get_state, list(range(N)), iterable_len=N, n_splits=W)

    print(r)

I'll keep this issue open, as I think having an additional flag in mpire would be easier from a user's perspective.

But feel free to close, @sybrenjansen!

@sybrenjansen (Owner) commented

Let's keep it open.

Like I said, it's not that much work and I do think it can be a useful feature for others as well. I think I can find some free time next week or so.

sybrenjansen added the enhancement (New feature or request) label Jun 3, 2022
sybrenjansen self-assigned this Jun 3, 2022
@sybrenjansen (Owner) commented

Had a bit less free time the last few weeks than I thought. This is just a reminder that I haven't forgotten about it; it's still on my to-do list.

@sybrenjansen (Owner) commented

Now available in v2.5.0
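
For reference, here is a sketch of the original script updated for the new behavior; it assumes the new flag is the order_tasks parameter on WorkerPool (check the v2.5.0 changelog for the exact name and semantics). With it, the original failing script should work without any barriers:

from mpire import WorkerPool

N = 128
W = 4


def set_state(w_state, i):
    w_state[i] = 2 * i + 1


def get_state(w_state, i):
    return w_state[i]


if __name__ == "__main__":
    # order_tasks=True is assumed to hand chunks to workers in a fixed
    # round-robin order, so the worker that stored a key during set_state
    # is also the one asked for it during get_state.
    pool = WorkerPool(n_jobs=W, use_worker_state=True, keep_alive=True,
                      order_tasks=True)

    s = N // W
    pool.map(set_state, list(range(N)), iterable_len=N, n_splits=s)
    r = pool.map(get_state, list(range(N)), iterable_len=N, n_splits=s)

    print(r)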
