Conversation

@yeounoh (Contributor) commented on Nov 22, 2022

This fixes a couple of bugs in AssignIrValue and ExecuteReplicated for sharding, to enable mark_step() with SPMD. Note that this doesn't address sharding propagation through views, which will be handled later.
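For context, here is a minimal sketch of the end-to-end flow this change enables: annotate an XLA tensor with a sharding spec, mutate it in place, and call mark_step(). The xs.Mesh and xs.mark_sharding names, their signatures, and the device count are assumptions about the experimental sharding API and may differ between releases; mark_step() and _get_xla_sharding_spec() are the calls discussed in this thread.

# Sketch only: assumes the experimental xla_sharding module with a
# mark_sharding(tensor, mesh, partition_spec) entry point; the exact module
# path and signatures may differ between torch_xla releases.
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

num_devices = 8  # assumed device count, for illustration only
mesh = xs.Mesh(list(range(num_devices)), (1, num_devices))

xt = torch.randn(2, num_devices).to(xm.xla_device())
xs.mark_sharding(xt, mesh, (0, 1))  # shard dim 1 across the device mesh

xt.add_(1)      # in-place update replaces the tensor's IR value
xm.mark_step()  # with this fix, the sharding spec survives the step

print(torch_xla._XLAC._get_xla_sharding_spec(xt))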

@yeounoh yeounoh added the distributed SPMD and other distributed things. label Nov 22, 2022
@yeounoh yeounoh self-assigned this Nov 22, 2022
if (!ir_value) {
  // The caller passed an empty value (e.g. AssignIrValue(torch::lazy::Value()));
  // re-create an IR value from the current XLA data so the existing sharding
  // spec can be propagated to the new IR.
  ir_value = CreateTensorNode(CurrentXlaData(), /*read_only=*/false);
}
XLA_CHECK(ir_value.node != nullptr) << "Trying to access a null cursor";
Collaborator
Hmm, so this line is to force the sharding on the old IR before it gets replaced? I'm confused because we did not clear the sharding after the new ir_value is assigned.

Contributor Author
The intent is to force the same sharding on the new IR (input), for a common call like AssignIrValue(torch::lazy::Value()). Did I get it reversed?

Collaborator
Oh OK, I read it backwards; you're trying to populate the sharding spec onto the new IR being assigned.

@yeounoh yeounoh force-pushed the spmd_mark_step branch 2 times, most recently from b083799 to 63773a9 on November 23, 2022 at 22:07
self.assertEqual(sharding_spec, torch_xla._XLAC._get_xla_sharding_spec(xt))

xt.add_(1) # inplace update
xm.mark_step() # resets IR value
Collaborator
I don't think you need the mark_step here? After the xt.add_(1), if you print the IR and HLO by

print(torch_xla._XLAC._get_xla_tensors_text([xt]))
print(torch_xla._XLAC._get_xla_tensors_hlo([xt]))

you should see the sharding spec on the output?

Contributor Author
It works either way, before or after the mark_step, and the test checks for the sharding annotation.
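To make "either way" concrete, a small sketch using only the inspection helpers quoted above; xt is assumed to be an XLA tensor that already carries a sharding annotation, as in the test snippet, and the non-empty-string check on _get_xla_sharding_spec is an assumption about its return value.

xt.add_(1)  # in-place update creates a new IR value for xt

# Before mark_step: the pending IR/HLO should show the sharding annotation.
print(torch_xla._XLAC._get_xla_tensors_hlo([xt]))
assert torch_xla._XLAC._get_xla_sharding_spec(xt) != ''

xm.mark_step()  # resets the IR value

# After mark_step: the sharding spec is still attached to the tensor.
assert torch_xla._XLAC._get_xla_sharding_spec(xt) != ''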

@JackCaoG (Collaborator) left a comment
Mostly LGTM; approving to unblock the merge once tests are green. You might need to rebase.
