Skip to content

Conversation

@jonb377
Copy link
Collaborator

@jonb377 jonb377 commented May 18, 2023

As a follow-up from #5016, we should place the shards on CPU by default. This makes it clear to the user that the shards will not remain up-to-date with the underlying XLAShardedTensor, and it eases interoperability of the shards' data with other tensors.

Expanding on the second point, the shards are backed by PjRtData when returned from _get_local_shards. In SPMD mode, we expect all computation inputs to have PjRtShardedData handles, which makes the on-device shards incompatible with SPMD execution. In order to use the shards in an on-device computation, they will now need to be transferred back to the device, which implies implicit replication.

@jonb377 jonb377 added the distributed SPMD and other distributed things. label May 18, 2023
@jonb377 jonb377 force-pushed the jonbolin-cpu-shard branch from c777c1a to 6811b8e Compare May 19, 2023 20:01
Copy link
Collaborator

@alanwaketan alanwaketan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

# Shards on the devices are materialized/available after the lazy
# execution of the SPMDPartitioned HLO graph. Each XLAShard points
# to torch.Tensor (xla::device_data). The shards represent a snapshot
# to torch.Tensor (xla::device_data). The shards represent a snapshot on CPU
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit. Could we update the comment here, "execution of the partitioned HLO graph. Each XLAShard points to torch.Tensor. The shards represent a snapshot on CPU, detached from the global tensor." ?

Copy link
Contributor

@yeounoh yeounoh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left a minor comment, LGTM.

@jonb377 jonb377 force-pushed the jonbolin-cpu-shard branch 2 times, most recently from 1b605d7 to 9ceadd3 Compare May 20, 2023 00:54
@jonb377 jonb377 force-pushed the jonbolin-cpu-shard branch from 9ceadd3 to d967760 Compare May 20, 2023 01:47
@jonb377 jonb377 merged commit 0464095 into master May 22, 2023
@jonb377 jonb377 deleted the jonbolin-cpu-shard branch May 22, 2023 17:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

distributed SPMD and other distributed things.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants