
Patch to core for RFC: Sparse Domain Isolation for Supporting Large-scale Sparse Weights #41371

Closed
wants to merge 7 commits

Conversation

@rhdong (Member) commented Jul 14, 2020

This is the patch to core for the RFC: Sparse Domain Isolation for Supporting Large-scale Sparse Weights.
Please see tensorflow/community#237

@rhdong rhdong requested a review from annarev as a code owner July 14, 2020 09:17
@google-ml-butler google-ml-butler bot added the size:XL CL Change Size:Extra Large label Jul 14, 2020
@gbaned gbaned self-assigned this Jul 14, 2020
@gbaned gbaned added this to Assigned Reviewer in PR Queue via automation Jul 14, 2020
@rhdong (Member, Author) commented Jul 14, 2020

@yuefengz @tanzhenyu @byronyi @alextp Hi, this is the code patch for the RFC Sparse Domain Isolation for Supporting Large-scale Sparse Weights Training. I hope you have time to help review it, thank you!


@wjwx wjwx left a comment


follow

@FanGhost

follow

@tanzhenyu tanzhenyu requested review from tanzhenyu and removed request for annarev July 15, 2020 15:01
@alextp (Contributor) commented Jul 17, 2020

AFAICT it's possible to have the dynamic embedding ops entirely in a third-party repo, and we just need the trainable interface part of this PR to be in core TF. Is that true?

@tanzhenyu (Contributor)

> AFAICT it's possible to have the dynamic embedding ops entirely in a third-party repo, and we just need the trainable interface part of this PR to be in core TF. Is that true?

That is true -- and I have the same comment. Let's keep what needs to be changed in the lookup table, and leave the rest of it in the new SIG/repo.

@rhdong (Member, Author) commented Jul 18, 2020

@alextp @tanzhenyu Yes, we need the trainable interface part of this PR to be in core, especially the changes to optimizer.py and the TrainableWrapper, which are the key to compatibility with all native optimizers without having to extend them one by one. I believe it would be inappropriate, and very difficult by design, to split them out.
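
For context, a minimal and purely illustrative sketch of that idea: materialize the looked-up rows into a dense trainable variable so that any unmodified native optimizer can compute and apply gradients, then write the updated rows back to the hash table. The DynamicEmbedding class and its lookup/upsert methods below are simplified stand-ins, not the actual API in this patch.

```python
import tensorflow as tf

class DynamicEmbedding:
    """Toy stand-in for a hash-table-backed embedding (illustration only)."""
    def __init__(self, dim):
        self.dim = dim
        self.table = {}  # key -> embedding row

    def lookup(self, ids):
        rows = [self.table.setdefault(int(k), tf.zeros([self.dim])) for k in ids.numpy()]
        return tf.stack(rows)

    def upsert(self, ids, values):
        for k, v in zip(ids.numpy(), tf.unstack(values)):
            self.table[int(k)] = v

# The "TrainableWrapper" idea in miniature: copy the looked-up rows into a dense
# tf.Variable so a stock optimizer can update them, then push the rows back.
emb = DynamicEmbedding(dim=4)
ids = tf.constant([3, 7, 42])
wrapper = tf.Variable(emb.lookup(ids))        # trainable local copy of the sparse rows

opt = tf.keras.optimizers.SGD(0.1)
with tf.GradientTape() as tape:
    loss = tf.reduce_sum(tf.square(wrapper))  # any loss that involves the rows
grads = tape.gradient(loss, [wrapper])
opt.apply_gradients(zip(grads, [wrapper]))    # unmodified native optimizer
emb.upsert(ids, wrapper)                      # write updated rows back (cf. update_op in the diff hunk below)
```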

cached_value=cached_value)

def update_op(self):
return self.params.upsert(self.ids, self.read_value(False))
@yejw5 Jul 18, 2020

If multiple workers update the same embeddings, will only one update take effect? Or, if one worker reads the embeddings for training and then returns gradients after a long delay, could embeddings already trained by others be overwritten by this stale update?

@rhdong (Member, Author) Jul 20, 2020

> If multiple workers update the same embeddings, will only one update take effect? Or, if one worker reads the embeddings for training and then returns gradients after a long delay, could embeddings already trained by others be overwritten by this stale update?

Good question. That's a common problem in asynchronous training. To address it, we read the values again (from the hash tables) when applying gradients to variables and slots: here
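
For illustration, a minimal sketch of that re-read step, assuming a table-like params object with lookup and upsert methods (the names here are simplified; the actual call sites in the patch may differ):

```python
def apply_sparse_gradient(params, ids, grad, lr=0.01):
    """Re-read the latest rows from the hash table right before applying the
    (possibly stale) gradient, so that updates pushed by other workers in the
    meantime are not silently overwritten by an old local copy."""
    latest = params.lookup(ids)      # re-read current values from the hash table
    updated = latest - lr * grad     # plain SGD step on the fresh values
    params.upsert(ids, updated)      # write the result back
    return updated
```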

@yejw5 Jul 20, 2020

And another question, does this support synchronous training?

@rhdong (Member, Author)

> And another question, does this support synchronous training?

Yes, the RFC is compatible with all distribution strategies of TensorFlow, not only the PS-Worker mode.

@liyinhgqw (Contributor) Jul 20, 2020

In synchronous training, how will the local variables on each replica stay in sync with each other and push the gradients to the lookup hash table?
Also, since the local variables introduce one extra round of I/O, will this cause training speed degradation?

@rhdong (Member, Author) Jul 21, 2020

> In synchronous training, how will the local variables on each replica stay in sync with each other and push the gradients to the lookup hash table?
> Also, since the local variables introduce one extra round of I/O, will this cause training speed degradation?

AFAICT, the strategy you mentioned is hybrid parallelism; this paper may be helpful: https://arxiv.org/abs/1909.04823

@evanzhen (Contributor)

Is there any corresponding patch in TF Serving? How do you deal with the problem of the huge memory used by the hash table on a single machine?

@byronyi (Contributor) commented Jul 21, 2020

> Is there any corresponding patch in TF Serving? How do you deal with the problem of the huge memory used by the hash table on a single machine?

@lilao might have better ideas on this for TF Serving.

@rhdong rhdong closed this Jul 21, 2020
PR Queue automation moved this from Assigned Reviewer to Closed/Rejected Jul 21, 2020
@rhdong rhdong reopened this Jul 21, 2020
PR Queue automation moved this from Closed/Rejected to Assigned Reviewer Jul 21, 2020
@rhdong rhdong changed the title Patch to core for RFC: Sparse Domain Isolation for Supporting large-scale Sparse Weight Patch to core for RFC: Sparse Domain Isolation for Supporting Large-scale Sparse Weights Jul 27, 2020
@gbaned gbaned requested review from tanzhenyu and removed request for tanzhenyu July 29, 2020 17:39
@gbaned gbaned added the awaiting review Pull request awaiting review label Jul 29, 2020
@gbaned gbaned requested review from tanzhenyu and removed request for tanzhenyu August 13, 2020 16:20
@fff-2013

This feature will be helpful for our recommendation system; we hope to use it as soon as possible. Following.

@google-cla bot commented Oct 22, 2020

All (the pull request submitter and all commit authors) CLAs are signed, but one or more commits were authored or co-authored by someone other than the pull request submitter.

We need to confirm that all authors are ok with their commits being contributed to this project. Please have them confirm that by leaving a comment that contains only @googlebot I consent. in this pull request.

Note to project maintainer: There may be cases where the author cannot leave a comment, or the comment is not properly detected as consent. In those cases, you can manually confirm consent of the commit author(s), and set the cla label to yes (if enabled on your project).

ℹ️ Googlers: Go here for more info.


@google-cla google-cla bot added cla: yes and removed cla: no labels Oct 28, 2020
@mihaimaruseac (Collaborator)

There are several conflicts. Can you solve them, please?

@gbaned gbaned added stat:awaiting response Status - Awaiting response from author and removed awaiting review Pull request awaiting review labels Oct 29, 2020
@gbaned (Contributor) commented Nov 5, 2020

@rhdong Can you please resolve conflicts? Thanks!

"""

partition_index = self.partition_fn(keys, self.shard_num)
keys_partitions, _ = _partition(keys, partition_index, self.shard_num)
@evanzhen (Contributor) Nov 9, 2020

Suppose we have keys in the range [0, 10000] and shard_num = 8; then partition_0 will get the keys [0, 8, 16, 24, ..., 10000], and MutableHashTableOfTensors will store these keys in a std::unordered_map container.
However, unordered_map uses collision chaining to resolve hash collisions, and slot_index = hash(key) % bucket_count. Suppose bucket_count is huge, say 2^20; then the keys' slot indices will be [0, 8, 16, 24, ..., 10000], which means we only use 1/shard_num of the bucket_count slots in the unordered_map. This can lead to heavy hash collisions and, in the end, to poor performance.
Suggestions (see the sketch after this list):

  1. provide an int/int64 hash function such as murmur hash
  2. map keys one-to-one to another form, e.g. new_keys = keys / shard_num
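
For illustration only, here is a small sketch of suggestion 2 under the example above; mod_partition and remap_keys are hypothetical helpers, not names from this patch. Dividing a shard's keys by shard_num makes them contiguous again before they are inserted into that shard's hash table, so they no longer all land in every shard_num-th bucket.

```python
import tensorflow as tf

def mod_partition(keys, shard_num):
    # The mod-based partitioning discussed above: shard i receives keys with key % shard_num == i.
    return tf.cast(keys % shard_num, tf.int32)

def remap_keys(keys, shard_num):
    # Suggestion 2: within a shard, divide out the stride so the stored keys
    # become dense (0, 1, 2, ...) instead of multiples of shard_num.
    return keys // shard_num

keys = tf.range(0, 10001, delta=8, dtype=tf.int64)  # the keys that land on partition_0 in the example
print(mod_partition(keys, 8)[:5].numpy())           # [0 0 0 0 0] -- all assigned to shard 0
print(remap_keys(keys, 8)[:5].numpy())               # [0 1 2 3 4] -- contiguous after remapping
```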

@rhdong (Member, Author)

Hi @rangjiaheng, thanks for your comments; that is indeed a problem, and I will consider your suggestion. In the meantime, you could customize a partitioner for the Variable to avoid this problem.

Contributor


@rangjiaheng Any update on this PR? Please. Thanks!

@kkimdev kkimdev removed their request for review November 12, 2020 02:29
@tensorflowbutler tensorflowbutler removed the stat:awaiting response Status - Awaiting response from author label Nov 22, 2020
@gbaned (Contributor) commented Nov 24, 2020

@rhdong Can you please resolve conflicts? Thanks!

@gbaned gbaned added the stat:awaiting response Status - Awaiting response from author label Dec 4, 2020
@gbaned (Contributor) commented Dec 14, 2020

@rhdong Any update on this PR? Please. Thanks!

@tensorflowbutler tensorflowbutler added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Dec 31, 2020
@tensorflowbutler (Member)

It has been 15 days with no activity and the awaiting response label was assigned. Is this PR still valid? Assigning the stalled label. Please comment to reassure me that this is still being worked on.

@gbaned (Contributor) commented Jan 25, 2021

I'm going to go ahead and close this PR, because it seems to have stalled. If you're still interested in pursuing this (and responding to my comments), please feel free to reopen!

@gbaned gbaned closed this Jan 25, 2021
PR Queue automation moved this from Assigned Reviewer to Closed/Rejected Jan 25, 2021
@xiaogaozi

Looks like the author will continue the work in the recommenders-addons project.
