
parameter server strategy #8

Open · huangrunhui wants to merge 14 commits into master

Conversation

@huangrunhui (Collaborator) commented on May 10, 2021:

  1. Add the parameter server strategy (ps_strategy); a rough sketch of the flow is included below.
  2. Add ps_strategy to the JAX example.
  3. Add typing in ps_strategy.py, allreduce_strategy.py, and base_strategy.py.
  4. Add save/load of states for the strategy and the JAX operator.
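A minimal, runnable sketch of the push/pull flow a parameter-server strategy implements, written against plain Ray actors. The ParameterServer and Worker classes, their methods, and the shard/step logic below are illustrative stand-ins, not this PR's actual ps_strategy classes:

    import numpy as np
    import ray

    ray.init()

    @ray.remote
    class ParameterServer:
        """Holds one shard of the parameters and applies gradient updates to it."""

        def __init__(self, shard, lr=0.1):
            self.shard = shard
            self.lr = lr

        def apply_gradients(self, grad_shard):
            self.shard = self.shard - self.lr * grad_shard

        def get_params(self):
            return self.shard

    @ray.remote
    class Worker:
        """Computes one gradient per parameter shard it is handed."""

        def compute_gradients(self, shards):
            # Placeholder: a real training operator would run a forward/backward
            # pass on a data batch here.
            return [np.ones_like(s) for s in shards]

    num_ps, num_workers = 2, 2
    servers = [ParameterServer.remote(s)
               for s in np.array_split(np.zeros(10), num_ps)]
    workers = [Worker.remote() for _ in range(num_workers)]

    for _ in range(3):  # a few synchronous steps
        shards = ray.get([s.get_params.remote() for s in servers])          # pull
        grads = ray.get([w.compute_gradients.remote(shards) for w in workers])
        for i, server in enumerate(servers):                                # push
            avg = sum(g[i] for g in grads) / len(grads)
            server.apply_gradients.remote(avg)

    print(ray.get([s.get_params.remote() for s in servers]))

Each step, every worker pulls the current shards and computes gradients, and the averaged gradient for each shard is pushed to the server that owns it, which appears to match the flow wired through the worker and server groups in this PR.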



class ParameterServerStrategy(BaseStrategy):
    """Strategy that trains a model via collective AllReduce.

Collaborator: Change this docstring summary? It still says the strategy trains via collective AllReduce.
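One possible replacement summary (a suggestion only, not wording from the PR):

    class ParameterServerStrategy(BaseStrategy):
        """Strategy that trains a model with a parameter-server architecture.

        Workers compute gradients and push them to parameter servers, which
        apply the updates and serve the refreshed parameter shards back.
        """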

training_operator_cls,
operator_config=None,
initialization_hook=None,
num_workers=1,

Collaborator: num_worker (avoid the plural form here).


assert num_ps
self.num_ps = num_ps
self.num_workers = num_workers

Collaborator: Same here, don't use the plural form.

assert num_ps
self.num_ps = num_ps
self.num_workers = num_workers
self.num_cpus_per_server = num_cpus_per_server

Collaborator: And here, and in the following lines.

ray.get([server.set_params.remote(this_shard_ref)])

def _start_workers(self):
    """Create worker(actor), maybe need worker group to manager these workers.

Collaborator: Rewrite this docstring.
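A possible rewording (a suggestion only; the exact phrasing is not from the PR):

    def _start_workers(self):
        """Create the worker and server actors and group them.

        Workers are managed by a group object so they can be started,
        addressed, and shut down collectively.
        """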

"""
# TODO (Hao): infer the per-replica batch size here...

# so here we get two set of params that will be passed around:

Collaborator: You can remove this comment; it is redundant with the ones I left in AllReduceStrategy.

}

# Should we make two groups for worker and server?
self.worker_group = DataParallelGroup(**workergroup_init_args)

Collaborator: This is strange. Is this the same DataParallelGroup as the one in AllReduceStrategy? If yes, then this is fine. If not, is there any way we can share the same class? If that is hard, we should at least use a different class name.
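If the two strategies do not already share one implementation, a single generic group class could serve both; a sketch with made-up names (ActorGroup, actor_cls, actor_args are not identifiers from this PR):

    import ray

    class ActorGroup:
        """Starts and addresses a homogeneous set of Ray actors.

        Both the worker group and the server group could be instances of this
        one class, differing only in the actor class and constructor arguments
        they are given.
        """

        def __init__(self, actor_cls, actor_args, num_cpus_per_actor=1):
            self._remote_cls = ray.remote(num_cpus=num_cpus_per_actor)(actor_cls)
            self._actor_args = actor_args
            self.actors = []

        def start_actors(self, num_actors):
            self.actors = [
                self._remote_cls.remote(**self._actor_args)
                for _ in range(num_actors)
            ]

        def execute(self, method_name, *args):
            """Invoke a method on every actor and block on the results."""
            return ray.get(
                [getattr(actor, method_name).remote(*args) for actor in self.actors])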

self.server_group.start_actors(
    self.num_ps)  # server at the last num_ps processes.

worker_rets = self.worker_group.test_connection()

Collaborator: Is testing the connection necessary? If not, it should probably be moved to DEBUG mode.
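If the check is kept, one way to take it off the normal startup path is to gate it on the log level; a sketch assuming the group objects keep a test_connection() method:

    import logging

    logger = logging.getLogger(__name__)

    def maybe_test_connections(worker_group, server_group):
        """Run the connection checks only when DEBUG logging is enabled."""
        if not logger.isEnabledFor(logging.DEBUG):
            return
        logger.debug("worker connections: %s", worker_group.test_connection())
        logger.debug("server connections: %s", server_group.test_connection())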


def setup_operator(self):
    # figure out the signature of training_operator_cls later.
    self.training_operator = self.training_operator_cls(

Collaborator: I am not sure whether we should set up the whole operator on the server side. One drawback is that this will take a lot of GPU memory.
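One way to sidestep the memory concern is to keep the full training operator on the workers only and give each server nothing but a CPU copy of its parameter shard plus an update rule; a sketch with illustrative names (ShardServer and apply_gradients are not the PR's API):

    import numpy as np

    class ShardServer:
        """Server-side state without the full training operator.

        Holds only a numpy copy of one parameter shard, so no model, optimizer,
        data loader, or GPU memory is needed on the server.
        """

        def __init__(self, shard, lr=0.1):
            self.shard = np.asarray(shard, dtype=np.float32)
            self.lr = lr

        def set_params(self, shard):
            self.shard = np.asarray(shard, dtype=np.float32)

        def apply_gradients(self, grad_shard):
            # Plain SGD on the shard; any optimizer state would also live here.
            self.shard -= self.lr * np.asarray(grad_shard, dtype=np.float32)

        def get_params(self):
            return self.shard

Workers would then build the training operator, compute gradients, and push per-shard gradients to these servers, mirroring the set_params call shown in the diff context above.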
