[rpc] special case tensor type check when getting RRef #33582

Closed
wants to merge 4 commits into from
16 changes: 15 additions & 1 deletion torch/csrc/distributed/rpc/rref_context.cpp
@@ -175,8 +175,22 @@ c10::intrusive_ptr<RRef> RRefContext::getOrCreateRRef(
   auto& rrefId = rrefForkData.rrefId_;
   auto& forkId = rrefForkData.forkId_;
   if (ownerId == getWorkerId()) {
+    // We have found the rref through the rrefId
     auto ownerRRef = getOwnerRRef(rrefId);
-    TORCH_INTERNAL_ASSERT(ownerRRef->type() == type);
+    // Now double check if the two types are matched
+    //
+    // Why we are special casing the check for tensor type here?
+    // this is because tensor types might get specialized on tensors when
+    // we pass inputs to the function, i.e. TensorType can be filled with
+    // specific shape info, requires_grad info, etc. so the OwnerRRef we
+    // found might already have those infos, but the `type` we passed in
+    // here is a plain TensorType, they are not equal relationship. In RPC
+    // we don't care the difference as we ser/de with just the plain TensorType.
+    if(type == TensorType::get()) {
+      TORCH_INTERNAL_ASSERT(ownerRRef->type()->isSubtypeOf(TensorType::get()));
+    } else {
+      TORCH_INTERNAL_ASSERT(ownerRRef->type() == type);
+    }
     return ownerRRef;
   } else {
     return createUserRRef(ownerId, rrefId, forkId, type);
Contributor

What about the createUserRRef case, when the type does not match exactly?

Contributor Author

IIRC, createUserRRef does not check whether a UserRRef already exists in the RRefContext, so there should be no type-matching problem.

Contributor

I mean: will there be an issue if one user creates a UserRRef whose type is a subtype of TensorType::get(), then shares that UserRRef with another user, who creates a UserRRef here with the plain TensorType::get()? IIUC, the OwnerRRef will have the subtype of TensorType::get().

In that case, some UserRRefs will have slightly different types from others. Will this be an issue?

Contributor Author

A specialized SubTensorType is always a subtype of the plain TensorType. The only reason this fails is the type-equality assertion here, which I didn't see anywhere else. In the case you described, the forked UserRRef will hold the plain TensorType, which is subtype-compatible with the SubTensorType, so it should be safe. In fact, we can only get into that case when we first run the ScriptFunction locally and then call rpc.remote on the same ScriptFunction. But when we run the ScriptFunction remotely, we shouldn't preserve the specialized SubTensorType information (because that information comes from the local run). So I think the fix here should be enough.
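
For readers following the thread, the difference between the strict equality assertion and the relaxed subtype check can be sketched with a toy model. This is only an illustration under stated assumptions: it does not use the real c10 type classes, and `Kind`, `ToyType`, and `relaxedTypeCheck` are hypothetical names invented for the example. A type with no shape stands in for the plain TensorType, and a type with a shape stands in for a specialized tensor type.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical stand-ins for JIT types; not the real c10 classes.
enum class Kind { Tensor, Int };

struct ToyType {
  Kind kind;
  // std::nullopt models the plain (unspecialized) type;
  // a value models a type specialized with shape information.
  std::optional<std::vector<long>> shape;

  bool operator==(const ToyType& other) const {
    return kind == other.kind && shape == other.shape;
  }

  // A type is a subtype of `other` when the kinds agree and `other`
  // is either plain (accepts any shape) or identical to this type.
  bool isSubtypeOf(const ToyType& other) const {
    if (kind != other.kind) {
      return false;
    }
    return !other.shape.has_value() || shape == other.shape;
  }
};

// Mirrors the shape of the check in the diff: when the requested type is
// the plain tensor type, accept any subtype of it; otherwise require
// exact equality.
bool relaxedTypeCheck(const ToyType& ownerType, const ToyType& requested) {
  const ToyType plainTensor{Kind::Tensor, std::nullopt};
  if (requested == plainTensor) {
    return ownerType.isSubtypeOf(plainTensor);
  }
  return ownerType == requested;
}
```

In this model, an owner type specialized with a shape is not equal to the plain tensor type, so a strict equality assertion would fail, while `relaxedTypeCheck` accepts it. A non-tensor owner type is still rejected when the plain tensor type is requested.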

Contributor

Thanks for explaining! Sounds good. Would you please also add this explanation to the code comment?
