parameter server strategy #8
base: master
Conversation
huangrunhui commented on May 10, 2021 (edited)
- ps_strategy
- add ps_strategy to jax example (rough usage sketch below)
- add Typing in ps_strategy.py, allreduce_strategy.py and base_strategy.py
- strategy and jax operator save/load states
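For context, a rough usage sketch of how the new strategy might be driven from the JAX example. The constructor arguments mirror the ones visible in the diff below; the import path, the placeholder operator class, and the `operator_config` contents are assumptions for illustration rather than confirmed DistML API.

```python
# Hypothetical usage sketch -- argument names follow this diff; the import path,
# operator class, and config keys are assumed for illustration only.
import ray
from distml.strategy.ps_strategy import ParameterServerStrategy


class MyJaxTrainingOperator:
    """Placeholder for a user-defined JAX training operator (interface assumed)."""


ray.init()

strategy = ParameterServerStrategy(
    training_operator_cls=MyJaxTrainingOperator,
    operator_config={"batch_size": 128, "lr": 0.01},
    num_workers=2,            # data-parallel workers
    num_ps=1,                 # parameter server actors
    num_cpus_per_server=1)

# Training / checkpointing calls (e.g. train, save/load states) are not shown in
# this diff, so they are omitted here.
```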
distml/strategy/ps_strategy.py
Outdated

class ParameterServerStrategy(BaseStrategy):
    """Strategy that trains a model via collective AllReduce.
change this docstring summary?
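The summary line still describes AllReduce. One possible rewording, offered purely as an illustration and not wording from the PR:

```python
class ParameterServerStrategy(BaseStrategy):
    """Strategy that trains a model via parameter servers.

    Workers compute gradients locally and push them to one or more
    parameter server actors, which apply the updates and serve the
    refreshed parameters back to the workers.
    """
```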
distml/strategy/ps_strategy.py
Outdated

                 training_operator_cls,
                 operator_config=None,
                 initialization_hook=None,
                 num_workers=1,
num_worker
distml/strategy/ps_strategy.py
Outdated

        assert num_ps
        self.num_ps = num_ps
        self.num_workers = num_workers
same here... don't use the plural form
        assert num_ps
        self.num_ps = num_ps
        self.num_workers = num_workers
        self.num_cpus_per_server = num_cpus_per_server
and here and following
distml/strategy/ps_strategy.py
Outdated

        ray.get([server.set_params.remote(this_shard_ref)])

    def _start_workers(self):
        """Create worker(actor), maybe need worker group to manager these workers.
rewrite this docstring
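One possible rewrite of that docstring, only as an illustrative suggestion (not wording from the PR):

```python
    def _start_workers(self):
        """Create the worker and server actors.

        Workers and servers are wrapped in their respective groups so they
        can be started, addressed, and shut down collectively.
        """
```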
""" | ||
# TODO (Hao): infer the per-replica batch size here... | ||
|
||
# so here we get two set of params that will be passed around: |
you can clean this comment as it is redundant with those I left in AllReduceStrategy
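For reference, a minimal sketch of the sharding step that `server.set_params.remote(this_shard_ref)` above is part of: the flat list of parameter arrays is split into `num_ps` shards and each shard is pushed to one server actor. The helper names (`make_shards`, `distribute_params`) are illustrative assumptions, not the actual DistML code.

```python
import ray


def make_shards(flat_params, num_ps):
    """Split a flat list of parameter arrays into num_ps shards (round-robin)."""
    return [flat_params[i::num_ps] for i in range(num_ps)]


def distribute_params(flat_params, servers):
    # One shard per parameter server; each server only ever holds its own shard.
    shards = make_shards(flat_params, len(servers))
    shard_refs = [ray.put(shard) for shard in shards]
    ray.get([server.set_params.remote(ref)
             for server, ref in zip(servers, shard_refs)])
```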
distml/strategy/ps_strategy.py
Outdated

        }

        # Should we make two groups for worker and server?
        self.worker_group = DataParallelGroup(**workergroup_init_args)
This is strange. Is this the same DataParallelGroup as the one in AllReduceStrategy?
If yes -- then fine.
If not -- is there any way we can share the same class? If that is hard, we should at least use a different class name.
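If the two strategies do end up needing a shared abstraction, one minimal way to do it is a generic actor group that both strategies instantiate. The names below (`ReplicaGroup`, `apply_all`) are hypothetical and not part of the DistML code.

```python
import ray


class ReplicaGroup:
    """Generic container that starts and addresses a set of Ray actors."""

    def __init__(self, actor_cls, actor_args, num_cpus_per_actor=1):
        self._actor_cls = ray.remote(num_cpus=num_cpus_per_actor)(actor_cls)
        self._actor_args = actor_args
        self.actors = []

    def start_actors(self, num_actors):
        self.actors = [
            self._actor_cls.remote(**self._actor_args) for _ in range(num_actors)
        ]

    def apply_all(self, method_name, *args):
        """Call `method_name` on every actor and wait for the results."""
        return ray.get(
            [getattr(actor, method_name).remote(*args) for actor in self.actors])


# Both strategies could then instantiate the same class, e.g.:
#   self.worker_group = ReplicaGroup(Worker, workergroup_init_args)
#   self.server_group = ReplicaGroup(ParameterServer, servergroup_init_args)
```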
distml/strategy/ps_strategy.py
Outdated

        self.server_group.start_actors(
            self.num_ps)  # server at the last num_ps processes.

        worker_rets = self.worker_group.test_connection()
Is testing the connection necessary? If not, probably move it to DEBUG mode.
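A minimal sketch of gating that check behind DEBUG logging, as the comment suggests; the `test_connection()` call mirrors the method in the diff, while the logger setup and helper name are assumptions.

```python
import logging

logger = logging.getLogger(__name__)


def _maybe_test_connection(worker_group, server_group):
    # Only pay for the extra round trips when debug logging is enabled.
    if not logger.isEnabledFor(logging.DEBUG):
        return
    worker_rets = worker_group.test_connection()
    server_rets = server_group.test_connection()
    logger.debug("connection check passed for %d workers and %d servers",
                 len(worker_rets), len(server_rets))
```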
    def setup_operator(self):
        # figure out the signature of training_operator_cls later.
        self.training_operator = self.training_operator_cls(
I am not sure whether we should set up the whole operator on the server side. One drawback is that this will take a lot of GPU memory.
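A minimal sketch of the alternative this comment points at: keep only a parameter shard (plain NumPy, on CPU) on each server actor instead of constructing the full training operator there. `set_params` mirrors the method used earlier in this diff; everything else about `PSActor` is a hypothetical illustration.

```python
import numpy as np
import ray


@ray.remote(num_cpus=1)
class PSActor:
    """Holds one shard of the parameters; no model or GPU state lives here."""

    def __init__(self, lr=0.01):
        self.lr = lr
        self.params = None

    def set_params(self, shard):
        self.params = [np.asarray(p) for p in shard]

    def get_params(self):
        return self.params

    def apply_gradients(self, grads):
        # Plain SGD on the CPU-resident shard.
        self.params = [p - self.lr * g for p, g in zip(self.params, grads)]
        return self.params
```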