I see four ways to manage internal state, which come from two independent decisions:

* Storing state:
    * centralized
    * decentralized
* Communication
    * asynchronous
    * synchronous

Here's what each combination of these two decisions means:

1. Asynchronous, centralized (e.g., [Hogwild](https://arxiv.org/abs/1106.5730))
    * Will require holding one model vector on one parameter server
    * Will mean that workers pull/push this model at any time
2. Synchronous, centralized (e.g., mini-batch SGD with one-to-all)
    * Same as (1), but workers must wait until the model is fully updated before they can pull it
3. Synchronous, decentralized (e.g., mini-batch SGD w/ all-reduce)
    * Every worker decides how to communicate
    * Every worker holds onto a model
4. Asynchronous, decentralized (e.g., [Hogwild++](https://ieeexplore.ieee.org/abstract/document/7837887/))
    * Not every worker communicates with every other worker
    * e.g., a bunch of point communications
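
To make the synchronous cases (2) and (3) concrete, here is a minimal plain-Python sketch (no Dask). The function names and the integer "gradients" are hypothetical, chosen only for illustration, and the update simply adds the summed gradients to the model. Under these assumptions both layouts end each step with the same model; they differ only in where the state lives and who talks to whom.

```python
def centralized_step(server_model, worker_grads):
    """One synchronous step with a single parameter server (case 2)."""
    # Workers push gradients to the server; the server applies all of them,
    # and only then do workers pull the updated model.
    new_model = server_model + sum(worker_grads)
    return [new_model for _ in worker_grads]  # what each worker pulls


def decentralized_step(worker_models, worker_grads):
    """One synchronous step where every worker keeps a replica (case 3)."""
    # Every worker ends up with every gradient (its own plus its peers')
    # and applies the sum locally; no server is involved.
    total = sum(worker_grads)
    return [model + total for model in worker_models]


grads = [0, 1, 2, 3]  # hypothetical per-worker gradients
pulled = centralized_step(0, grads)
replicas = decentralized_step([0, 0, 0, 0], grads)
print(pulled, replicas)
```

The asynchronous cases (1) and (4) drop the "wait for everyone" step, so workers can act on a model that is already out of date.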

A good reference for speedups with mini-batch SGD is Paleo: https://talwalkarlab.github.io/paleo/ (choose "strong scaling" instead of "weak scaling"). This shows that for mini-batch SGD **all-reduce is >5x faster than all-to-one.**

* Centralized:
    * pros: large models
    * cons: slower, functions
* Decentralized:
    * pros: faster, classes
    * advantage: decentralized is a superset of centralized (the communication strategy can be customized)
    * cons: large models (every worker holds a copy)
* Sync:
    * pros: simple, no assumptions
    * cons: maybe slower?
* Async:
    * pros: maybe faster
    * cons: not simple, requires assumptions
    
I'm inclined to go with decentralized + sync.

## Related work
* Ray is most similar to Dask. https://ray.readthedocs.io/en/latest/example-parameter-server.html
* Horovod, an MPI wrapper with nice API: https://github.com/uber/horovod
* TensorFlow parameter server

In [1]:
from distributed import Client
client = Client()
client

Client: Scheduler: tcp://127.0.0.1:52000, Dashboard: http://127.0.0.1:8787/status
Cluster: Workers: 8, Cores: 8, Memory: 17.18 GB


## Sketches

### Asynchronous

In [2]:
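class PS:
    """Parameter server: holds the single shared copy of the model."""
    def __init__(self):
        self._store = {'model': 0}

    def get(self, key):
        return self._store[key]

    def set(self, key, value, i=0):
        # NOTE: despite the name, this adds `value` to the stored entry
        # rather than overwriting it. `i` (the worker id) is unused.
        self._store[key] += value

    def store(self):
        return self._store

def update(ps, i=0):
    # Each worker pulls and pushes whenever it likes; nothing here waits
    # for the other workers, so reads may be stale.
    for _ in range(4):
        model = ps.get('model').result()
        new_model = model + 1
        ps.set('model', new_model, i=i)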

In [3]:
futures = client.submit(PS, actor=True)
ps = client.gather(futures)
ps

<Actor: PS, key=PS-e4f5ac0b-4001-49de-bef6-fa43bf2faf71>

In [4]:
futures = [client.submit(update, ps, i=i) for i in range(4)]
client.gather(futures)

[None, None, None, None]

In [5]:
ps.store().result()

{'model': 5759}
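
The value above depends on how the workers' reads and writes happen to interleave, so I would not expect it to be reproducible. As a rough illustration, the sketch below replays the same update rule in plain Python (no Dask, no futures) under two idealized schedules. `LocalPS`, `run_sequential`, and `run_lockstep` are hypothetical names introduced only for this sketch.

```python
class LocalPS:
    """Plain-Python stand-in for the PS actor (same accumulate-on-set rule)."""
    def __init__(self):
        self._store = {"model": 0}

    def get(self, key):
        return self._store[key]

    def set(self, key, value):
        self._store[key] += value


def run_sequential(n_workers=4, n_steps=4):
    """Every read sees the latest write: model -> 2 * model + 1 per update."""
    ps = LocalPS()
    for _ in range(n_workers * n_steps):
        model = ps.get("model")
        ps.set("model", model + 1)
    return ps.get("model")


def run_lockstep(n_workers=4, n_steps=4):
    """In each round, all workers read the same value before any writes."""
    ps = LocalPS()
    for _ in range(n_steps):
        stale = [ps.get("model") for _ in range(n_workers)]
        for model in stale:
            ps.set("model", model + 1)
    return ps.get("model")


print(run_sequential())  # 2**16 - 1
print(run_lockstep())    # 5**4 - 1
```

The recorded 5759 sits between these two idealized schedules (624 and 65535), which is consistent with some, but not all, reads being stale.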

### Synchronous

In [2]:
from distributed import Client
client = Client()

In [3]:
import numpy as np

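class Worker:
    """One replica of the model plus the gradients collected this round."""
    def __init__(self, model, n_models, worker_id):
        self.model = model
        self.grads = []
        self.n_models = n_models
        self.worker_id = worker_id

    def _model(self):
        return self.model

    def compute(self):
        # Placeholder gradient: each worker contributes its own id.
        self.grad = self.worker_id
        self.grads += [self.grad]
        return True

    def send(self, worker):
        # Hand this worker's gradient to a peer actor.
        worker.recv(self.grad)

    def recv(self, grad):
        self.grads += [grad]

    def reduce(self):
        # Expect one gradient from every worker (including this one).
        assert len(self.grads) == self.n_models
        self.model += sum(self.grads)
        self.grads = []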

In [4]:
model = 0
n_models = 4
futures = [client.submit(Worker, model, n_models, i, actor=True) 
           for i in range(n_models)]
workers = client.gather(futures)
workers

[<Actor: Worker, key=Worker-33a35517-bb78-4708-9eb0-5340eb6978b7>,
 <Actor: Worker, key=Worker-aaf61086-c506-45d4-b4ee-f8dfdc27c75c>,
 <Actor: Worker, key=Worker-4af885e4-0652-4dac-abb8-aad458beeaca>,
 <Actor: Worker, key=Worker-d881d528-2334-42ae-a12b-b2be311e3026>]

In [7]:
for k in range(4):
    # calculate
    futures = [worker.compute() for worker in workers]
    client.gather(futures)

    # communicate
    # (updating the model happens internally; each worker knows when it has received every gradient)
    # (this could be an all-reduce implementation if desired)
    futures = []
    for i, w1 in enumerate(workers):
        for j, w2 in enumerate(workers):
            if i == j:
                continue
            else:
                futures += [w1.send(w2)]
    client.gather(futures)

    # update model
    futures = [worker.reduce() for worker in workers]
    client.gather(futures)

    # quick test; make sure all models are the same
    futures = [worker._model() for worker in workers]
    if k == 0:
        print("result of client.gather() are ActorFutures?", client.gather(futures))
        
    models = [m.result() for m in futures]
    print("models =", models)
    assert all(model == models[0] for model in models)

result of client.gather() are ActorFutures? [<ActorFuture>, <ActorFuture>, <ActorFuture>, <ActorFuture>]
models = [54, 54, 54, 54]
models = [60, 60, 60, 60]
models = [66, 66, 66, 66]
models = [72, 72, 72, 72]
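
The loop above has every worker send its gradient to every other worker, i.e. n(n-1) messages per round. The comment mentions that an all-reduce could be used instead. Here is a minimal single-process sketch of one common variant, ring all-reduce, in plain Python. It is not what the cells above do: `ring_all_reduce` is a hypothetical helper, and it assumes each worker's vector has exactly one element per worker, so each element acts as one chunk.

```python
def ring_all_reduce(vectors):
    """Simulate ring all-reduce; returns each worker's final vector."""
    n = len(vectors)
    data = [list(v) for v in vectors]
    assert all(len(v) == n for v in data), "assumes one chunk per worker"

    # Phase 1, reduce-scatter: at each step worker i passes one chunk to
    # worker i + 1, which adds it to its own copy of that chunk.
    for step in range(n - 1):
        # Snapshot all sends first so the step behaves as if simultaneous.
        sends = [((i - step) % n, data[i][(i - step) % n]) for i in range(n)]
        for i, (chunk, value) in enumerate(sends):
            data[(i + 1) % n][chunk] += value
    # Now worker i holds the fully summed chunk (i + 1) % n.

    # Phase 2, all-gather: pass the finished chunks around the ring,
    # overwriting instead of adding.
    for step in range(n - 1):
        sends = [((i + 1 - step) % n, data[i][(i + 1 - step) % n])
                 for i in range(n)]
        for i, (chunk, value) in enumerate(sends):
            data[(i + 1) % n][chunk] = value
    return data


grads = [[1, 2, 3, 4],
         [10, 20, 30, 40],
         [100, 200, 300, 400],
         [1000, 2000, 3000, 4000]]
reduced = ring_all_reduce(grads)
print(reduced[0])
```

In this scheme each worker sends 2(n-1) messages, each carrying 1/n of the vector, and only ever talks to its ring neighbor.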


In [10]:
grads = [worker.compute() for worker in workers]
grads

[<ActorFuture>, <ActorFuture>, <ActorFuture>, <ActorFuture>]

In [11]:
sum(grads)

TypeError: unsupported operand type(s) for +: 'int' and 'ActorFuture'
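
The error comes from summing the future objects themselves instead of the values they resolve to. The same thing happens with standard-library futures, so here is a small sketch using `concurrent.futures` as a stand-in for `ActorFuture` (an assumption made so the example runs without a Dask cluster). The fix is the same pattern as the `m.result()` calls in the loop above: resolve each future first.

```python
from concurrent.futures import ThreadPoolExecutor


def compute(worker_id):
    # Placeholder gradient, mirroring the sketch above.
    return worker_id


with ThreadPoolExecutor(max_workers=4) as pool:
    futures = [pool.submit(compute, i) for i in range(4)]

    # Summing the futures directly fails: 0 + Future is not defined.
    try:
        sum(futures)
        direct_sum_failed = False
    except TypeError:
        direct_sum_failed = True

    # Resolve each future first, then sum the plain values.
    total = sum(f.result() for f in futures)

print(direct_sum_failed, total)
```

Note that in the `Worker` sketch `compute()` returns `True` rather than the gradient, so it would need to return `self.grad` for the summed results to be meaningful.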