20 changes: 20 additions & 0 deletions test/spmd/test_xla_virtual_device.py
@@ -79,15 +79,35 @@ def test_outbound_data_metrics(self):

  def test_non_tensor_scalar(self):
    sharding_spec = xs.ShardingSpec(self._get_mesh((1, self.n_devices)), (0, 1))
    # TODO(JackCaoG): currently, execution will only happen if there is at least
    # one tensor on a non-SPMD:0 device.
    t1 = torch.randn(3, 3, device=xm.xla_device())
    # tensor will have device as `SPMD:0` in c++
    xt1 = xm.send_cpu_data_to_device([torch.randn(3, 3)],
                                     xm.xla_device(),
                                     input_sharding=sharding_spec)[0]
    # we will transfer 0.5 as device_data to the `SPMD:0` device; make sure the
    # virtual device can handle this case.
    xt2 = xt1 / 0.5
    xm.mark_step(wait=True)
    self.assertTrue(torch.allclose(xt2.cpu(), xt1.cpu() / 0.5))

  def test_mark_step_on_virtual_device(self):
    xm.mark_step()
    sharding_spec = xs.ShardingSpec(self._get_mesh((1, self.n_devices)), (0, 1))
    # TODO(JackCaoG): currently, execution will only happen if there is at least
    # one tensor on a non-SPMD:0 device.
    t1 = torch.randn(3, 3, device=xm.xla_device())
Collaborator
What's this t1 about?

Collaborator Author
So... here is another weird bug: if both xt1 and xt2 are on SPMD:0, mark_step will see that there is no tensor on the TPU:0 device and will skip the execution. I think it is easier for me to rewrite the whole virtual device instead of fixing these bugs to accommodate SPMD:0 and TPU:0 existing at the same time.
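
For illustration (not part of the original comment), a minimal sketch of the scenario described above, assuming the same `xm`/`xs` imports and `_get_mesh` helper this test file already uses:

```python
sharding_spec = xs.ShardingSpec(self._get_mesh((1, self.n_devices)), (0, 1))
# Both tensors below live only on the virtual SPMD:0 device...
xt1 = xm.send_cpu_data_to_device([torch.randn(3, 3)],
                                 xm.xla_device(),
                                 input_sharding=sharding_spec)[0]
xt2 = xt1 / 0.5
# ...so this mark_step would see no tensor on TPU:0 and skip execution, leaving
# xt2 unmaterialized. Allocating a plain tensor on the XLA device first (the t1
# this thread is attached to) puts a tensor on TPU:0 and forces execution.
xm.mark_step(wait=True)
```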

Collaborator
Good luck... I wonder how I was able to do GPT-2 experiments all this time...

Collaborator Author
I am surprised that code runs at all lol.

    # tensor will have device as `SPMD:0` in c++
    xt1 = xm.send_cpu_data_to_device([torch.randn(3, 3)],
                                     xm.xla_device(),
                                     input_sharding=sharding_spec)[0]
    xt2 = xt1 / 0.5
    xm.mark_step(wait=True)
    # after mark_step, xt2 should be materialized
    self.assertNotIn('aten::div',
                     torch_xla._XLAC._get_xla_tensor_debug_info(xt2))


if __name__ == '__main__':
  test = unittest.main()
8 changes: 7 additions & 1 deletion torch_xla/csrc/xla_graph_executor.cpp
@@ -81,6 +81,7 @@ auto XLAGraphExecutor::DeviceContextArena::Get() -> DeviceContextArena* {
std::vector<XLATensorPtr> XLAGraphExecutor::DeviceContextArena::GetLiveTensors(
    const torch::lazy::BackendDevice* device) {
  std::vector<XLATensorPtr> tensors;
  torch::lazy::BackendDevice virtual_device = GetVirtualDevice();
  auto fn = [&](DeviceContext* devctx) {
    std::lock_guard<std::mutex> lock(devctx->lock);
    for (auto& uid_wptr : devctx->tensors_data) {
@@ -92,6 +93,8 @@ std::vector<XLATensorPtr> XLAGraphExecutor::DeviceContextArena::GetLiveTensors(
    }
  };
  ForAllDeviceContexts(fn, device);
  // TODO(JackCaoG): all tensors should be on spmd:0 in SPMD mode.
  ForAllDeviceContexts(fn, &virtual_device);
  return tensors;
}

@@ -502,7 +505,10 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors(
                                  tsl::profiler::TraceMeLevel::kInfo);
  runtime::util::Unique<torch::lazy::BackendDevice> unique_device;
  for (size_t i = 0; i < tensors.size(); ++i) {
    unique_device.set(tensors[i]->GetDevice());
    // TODO(JackCaoG): all tensors should be on spmd:0 in SPMD mode.
    if (tensors[i]->GetDevice().toString() != "SPMD:0") {
Collaborator
So currently we allow SPMD:0 and TPU:0 to coexist? I wonder what's the deal with the code later on that deals with unique_device?

Collaborator Author
I will move all tensors onto SPMD:0 under the SPMD context; planning to work on that later this week. Currently, tensors that go through

tensor.to(xla_device)

will be on TPU:0, while tensors that go through

# data loader
`xm.send_cpu_data_to_device`

will be on SPMD:0. This is really confusing and makes the logic nasty.
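
For illustration (not part of the original comment), a minimal sketch of the two paths described above, assuming the usual torch_xla imports and a `sharding_spec` like the one in the tests:

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Path 1: moving a host tensor with .to() places it on the real device, e.g. TPU:0.
t = torch.randn(3, 3).to(xm.xla_device())

# Path 2: the data-loader helper with an input_sharding spec places the tensor on
# the virtual device, SPMD:0 (sharding_spec is assumed to exist, as in the tests).
xt = xm.send_cpu_data_to_device([torch.randn(3, 3)],
                                xm.xla_device(),
                                input_sharding=sharding_spec)[0]

# The device a tensor landed on shows up in its debug info.
print(torch_xla._XLAC._get_xla_tensor_debug_info(xt))
```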

Collaborator
Right, but we have code later in this method that uses unique_device, and it looks important: `coll.device = *unique_device;`. How does that work today?

And if SPMD:0 and TPU:0 really coexist in this method, shouldn't unique_device.set crash?

Collaborator Author
Oh, so here is where things become messy. What happens is that we pass TPU:0 as the device to Compile and ScheduleSyncTensorsGraph on xla_graph_executor. ScheduleSyncTensorsGraph will ignore the device, check whether the graph is sharded, and execute on all devices if that's the case:

https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_graph_executor.cpp#L999-L1000

It is a very broken design right now; different functions check different things. I want to unify all of that later.

Collaborator
Err... whatever, I will approve that and then wait for your redesign.

      unique_device.set(tensors[i]->GetDevice());
    }
  }
  SyncTensorCollection coll;
  if (!unique_device) {