# Python multiprocessing with shared memory

I've fought the past 22 hours with multiprocessing to get a nice context manager to work with a couple of workers on WINDOWS 11.

Here are the lessons learned.

- The use cases for `Pool.map` are limited to repeated calls on the same data source. I still don't understand how to use it with shared_memory.
- [Shared memory arrays](https://docs.python.org/3/library/multiprocessing.shared_memory.html#module-multiprocessing.shared_memory) are very effective if you can control the work done by the workers. For example, hand out task ranges with something like `Pool.map(f, starmap(tasks))`, as in this [example](https://docs.python.org/3/library/multiprocessing.html#multiprocessing.pool.Pool.starmap).
- There is a really good example [here](https://docs.python.org/3/library/multiprocessing.shared_memory.html#multiprocessing.shared_memory.SharedMemory.size).
- Shared arrays can outlive the workers, so they can create them and return `shm.name` to the main process for later usage. Just don't close the array.
- Multiprocessing on Windows isn't hopeless, though some parts of the documentation are omitted on python.org. I don't blame them.
- Getting tracebacks from the workers to main was easy thanks to the [`traceback` module](https://docs.python.org/3/library/traceback.html#traceback.print_exc).
- Using [`exec`](https://docs.python.org/3/library/functions.html#exec) in the worker allows me to practically do anything.
- Tracking and visualising progress using a `task queue` and a `result queue` was helped by [`tqdm`](https://pypi.org/project/tqdm/)
- Creating a context manager only required `__enter__` and `__exit__`.
- The workers could be spawned at any time using [multiprocessing.Process](https://docs.python.org/3/library/multiprocessing.html#the-process-class).
- Moving tasks and results was easy using the [multiprocessing.Queue](https://docs.python.org/3/library/multiprocessing.html#multiprocessing.Queue), just be aware that the exception handler needs [`queue.Empty`](https://docs.python.org/3/library/queue.html#queue.Empty) from the `queue` module - not from multiprocessing.


Here is the whole solution:

In [None]:
import io
import traceback
import queue
import time
import tqdm
import multiprocessing
from multiprocessing import shared_memory
import numpy as np


class TaskManager(object):
    """Distribute task dicts to a pool of ``Worker`` processes.

    Tasks travel to the workers through a task queue and the outcome of
    each task comes back through a result queue. The manager can be used
    as a context manager: entering starts the workers, leaving stops them
    and clears the task and result registers (so read ``results`` inside
    the ``with`` block).
    """

    def __init__(self, workers: int = 2) -> None:
        """Create the queues and registers; no process is started yet.

        Args:
            workers: number of worker processes that ``start`` spawns.

        Raises:
            ValueError: if ``workers`` is smaller than 1.
        """
        if workers < 1:
            raise ValueError(f"expected at least 1 worker, got {workers}")
        self.workers = workers
        self.tq = multiprocessing.Queue()  # task queue for workers.
        self.rq = multiprocessing.Queue()  # result queue for workers.
        self.pool = []
        self.tasks = {}  # task register for progress tracking
        self.results = {}  # result register for progress tracking

    def add(self, task):
        """Register ``task`` and put it on the task queue.

        Args:
            task: dict with at least a unique ``'id'`` key.

        Raises:
            TypeError: if ``task`` is not a dict.
            KeyError: if ``task`` has no ``'id'`` or the id is already
                registered.
        """
        if not isinstance(task, dict):
            raise TypeError(
                f"expected task to be a dict, got {type(task).__name__}"
            )
        if 'id' not in task:
            raise KeyError("expect task to have id, to preserve order")
        task_id = task['id']
        if task_id in self.tasks:
            raise KeyError(f"task {task_id} already in use.")
        self.tasks[task_id] = task
        self.tq.put(task)

    def __enter__(self):
        """Start the workers and return the manager itself."""
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the workers and forget all tasks and results.

        Returns ``None``, so an exception raised inside the ``with`` block
        still propagates after the workers have been stopped.
        """
        self.stop()
        self.tasks.clear()
        self.results.clear()

    def start(self):
        """Spawn ``self.workers`` worker processes sharing both queues."""
        self.pool = [
            Worker(name=str(i), tq=self.tq, rq=self.rq)
            for i in range(self.workers)
        ]
        # Process.start() returns once the child has been launched, so no
        # extra wait is needed. Polling is_alive() here could only hang:
        # a worker that died during start-up never becomes alive again.
        for p in self.pool:
            p.start()

    def execute(self):
        """Block until every registered task has a result.

        Results are stored in ``self.results`` keyed by task id while a
        tqdm progress bar shows how many tasks are done.
        """
        with tqdm.tqdm(total=len(self.tasks), initial=len(self.results)) as progress:
            while len(self.results) < len(self.tasks):
                try:
                    result = self.rq.get_nowait()
                except queue.Empty:
                    time.sleep(0.01)
                    continue
                self.results[result['id']] = result
                # tqdm.update() takes an increment, not the running total.
                progress.update(1)

    def stop(self):
        """Ask all workers to stop and wait until every one has exited."""
        if self.pool:
            # One sentinel is enough: each worker puts it back on the
            # queue before it exits, so the next worker sees it too.
            self.tq.put("stop")
            # Wait for *every* worker, not just the first one to exit.
            for p in self.pool:
                p.join()
        print("all workers stopped")
        self.pool.clear()
  

class Worker(multiprocessing.Process):
    """Process that runs task scripts from a task queue until told to stop.

    Every task is a dict with an ``'id'`` and a ``'script'`` (Python source
    as a string). For each task one result dict
    ``{'id': ..., 'handled by': <worker name>, 'error': <traceback or ''>}``
    is put on the result queue.
    """

    def __init__(self, name, tq, rq):
        """Set up the process; ``update`` becomes the process target.

        Args:
            name: process name, also reported in every result.
            tq: queue the worker takes tasks (and the stop signal) from.
            rq: queue the worker puts result dicts on.
        """
        super().__init__(group=None, target=self.update, name=name, daemon=False)
        self.exit = multiprocessing.Event()  # set once the stop signal was seen
        self.tq = tq  # workers task queue
        self.rq = rq  # workers result queue
        print(f"Worker-{self.name}: ready")

    def update(self):
        """Handle tasks until the ``"stop"`` signal arrives on the task queue."""
        while True:
            # Block until something arrives instead of polling the queue.
            task = self.tq.get()

            if isinstance(task, str) and task == "stop":
                print(f"Worker-{self.name}: stop signal received.")
                self.tq.put_nowait(task)  # this assures that everyone gets it.
                self.exit.set()
                break

            error = ""
            # Each script runs in its own copy of the module namespace: it
            # can use the module's imports, but names it defines do not
            # leak into the next task. The task dict is exposed as `task`.
            namespace = dict(globals())
            namespace['task'] = task
            try:
                # SECURITY: exec runs arbitrary code. Only put scripts from
                # trusted sources on the task queue.
                exec(task['script'], namespace)
            except Exception:
                # Report the failure to the main process, keep the worker alive.
                error = traceback.format_exc(limit=3)

            self.rq.put({'id': task['id'], 'handled by': self.name, 'error': error})


# The guard is REQUIRED ON WINDOWS: worker processes import this module
# again, and without the guard each of them would rerun the demo below.
if __name__ == "__main__":  # REQUIRED ON WINDOWS.

    # Create shared_memory array for workers to access.
    a = np.array([1, 1, 2, 3, 5, 8])
    shm = shared_memory.SharedMemory(create=True, size=a.nbytes)
    try:
        # b is a view on the shared block; worker writes show up in it.
        b = np.ndarray(a.shape, dtype=a.dtype, buffer=shm.buf)
        b[:] = a[:]

        # First task: overwrite the last element with 888.
        task = {
            'id':1,
            'address': shm.name, 'type': 'shm', 
            'dtype': a.dtype, 'shape': a.shape, 
            'script': f"""# from multiprocssing import shared_memory - is already imported.
existing_shm = shared_memory.SharedMemory(name='{shm.name}')
c = np.ndarray((6,), dtype=np.{a.dtype}, buffer=existing_shm.buf)
c[-1] = 888
existing_shm.close()
"""}

        # Four more tasks (ids 2..5): task i writes 111+i into element i.
        tasks = [task]
        for i in range(4):
            task2 = task.copy()
            task2['id'] = 2+i
            task2['script'] = f"""existing_shm = shared_memory.SharedMemory(name='{shm.name}')
c = np.ndarray((6,), dtype=np.{a.dtype}, buffer=existing_shm.buf)
c[{i}] = 111+{i}  # DIFFERENT!
existing_shm.close()
time.sleep(0.1)  # Added delay to distribute the few tasks amongst the workers.
"""
            tasks.append(task2)

        with TaskManager() as tm:
            for job in tasks:
                tm.add(job)
            tm.execute()

            # Results are cleared when the with-block exits, so print them here.
            for v in tm.results.items():
                print(v)

        # Alternative "low level usage":
        # tm = TaskManager()
        # tm.add(task)
        # tm.start()
        # tm.execute()
        # tm.stop()
        print(b, f"assertion that b[-1] == 888 is {b[-1] == 888}")
        print(b, f"assertion that b[0] == 111 is {b[0] == 111}")
    finally:
        # Release the shared block even if one of the steps above failed.
        shm.close()
        shm.unlink()


Output:
```
Worker-0: ready
Worker-1: ready

(1, {'id': 1, 'handled by': '0', 'error': ''})
(2, {'id': 2, 'handled by': '0', 'error': ''})
(3, {'id': 3, 'handled by': '1', 'error': ''})
(4, {'id': 4, 'handled by': '0', 'error': ''})
(5, {'id': 5, 'handled by': '1', 'error': ''})
Worker-0: stop signal received.
Worker-1: stop signal received.
all workers stopped
[111 112 113 114   5 888] assertion that b[-1] == 888 is True 
[111 112 113 114   5 888] assertion that b[-1] == 888 is False
```

The code above will probably not run in a notebook, but you can copy-paste it into a script and run it. Take note of the imports (I wrote this on Python 3.9.6).

