-
-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Simplify broadcast logic for control messages #2501
Conversation
@zhuohan123 looks much cleaner! But since most of the tensors share a dtype and are the same shape apart from size of one dimension, I was those could be concatenated and just pass the offsets. So that the number of broadcasts done can be minimized. This wouldn't be quite as nice and generic as what you've done but might not look too bad? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@zhuohan123 Thanks for cleaning this up! It looks much nicer to me. Please take a look at my minor comments.
class TensorMetadata: | ||
"""A simple class to hold tensor metadata.""" | ||
|
||
def __init__(self, tensor): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def __init__(self, tensor): | |
def __init__(self, tensor: torch.Tensor): |
self.size = tensor.size() | ||
|
||
def __repr__(self): | ||
return (f"TensorMetadata(dtype={self.dtype}, size={self.size})") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return (f"TensorMetadata(dtype={self.dtype}, size={self.size})") | |
return f"TensorMetadata(dtype={self.dtype}, size={self.size})" |
def broadcast_tensor_dict(tensor_dict: Dict[Any, Union[torch.Tensor, | ||
Any]] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The tensor_dict
type is a bit weird: Shouldn't it be Optional[Dict]
? Also, how Union[torch.Tensor, Any]
is different from Any
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks I indeed missed an Optional
. I use Union[torch.Tensor, Any]
to explicitly emphasize that we treat torch.Tensor
and other types differently in this function.
@@ -104,3 +109,67 @@ def broadcast_object_list(obj_list, src=0): | |||
# Broadcast. | |||
torch.distributed.broadcast_object_list(obj_list, src=src) | |||
return obj_list | |||
|
|||
|
|||
class TensorMetadata: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can using a class instead of raw data type incur any additional overhead? If so, can we use named tuple instead?
Yeah some of the tensors are |
In this PR, we simplify and unify the previous nasty control message broadcast logic.