
Conversation

@jonb377 (Collaborator) commented Jun 7, 2023

This change adds a few utility functions to support distributed checkpointing. The following changes are included:

  • Add a sharding_type property to XLAShardedTensor to expose the ShardingType.
  • Add wrap_if_sharded to convert a torch.Tensor into an XLAShardedTensor when the underlying data is sharded.
  • Remove the devices parameter from _get_local_shard_indices and instead always return the shard indices in the same order as the shards.
  • Clamp the start bound of the index slices to the tensor's size.
  • Add a setter for the unpadded_data property of XLAShard.
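The wrap_if_sharded contract described above can be sketched in plain Python. This is a hypothetical illustration, not the torch_xla implementation: FakeTensor and FakeShardedTensor are stand-ins for torch.Tensor and XLAShardedTensor, and the is_sharded flag stands in for the real check of whether the underlying XLA data is sharded.

```python
class FakeTensor:
    """Stand-in for torch.Tensor, with a flag for whether its data is sharded."""

    def __init__(self, data, is_sharded=False):
        self.data = data
        self.is_sharded = is_sharded


class FakeShardedTensor(FakeTensor):
    """Stand-in for XLAShardedTensor: wraps a tensor whose data is sharded."""

    def __init__(self, tensor):
        super().__init__(tensor.data, is_sharded=True)


def wrap_if_sharded(t):
    # Wrap only when the underlying data is actually sharded; otherwise
    # return the tensor unchanged, mirroring the behavior described above.
    if t.is_sharded and not isinstance(t, FakeShardedTensor):
        return FakeShardedTensor(t)
    return t
```

Returning the input unchanged when it is not sharded lets checkpointing code call wrap_if_sharded unconditionally over a model's tensors.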

@jonb377 jonb377 added the distributed (SPMD and other distributed things) label Jun 7, 2023
@jonb377 jonb377 requested a review from yeounoh June 7, 2023 01:19
@jonb377 jonb377 force-pushed the jonbolin-checkpoint-restore branch from 5872316 to d27834a on June 7, 2023 01:21
@jonb377 (Author) commented Jun 7, 2023

cc @yashs97

@jonb377 jonb377 force-pushed the jonbolin-checkpoint-restore branch from d27834a to 318074b on June 7, 2023 02:10
// Clamp the slice bounds to the tensor shape to accurately reflect
// the shard size without padding.
int start = std::min(n_j * shard_shape[j], tensor_shape[j]);
Collaborator
Did you catch a bug, or is this more of a safeguard?

If n_j * shard_shape[j] is greater than tensor_shape[j], then (n_j + 1) * shard_shape[j] is certainly also greater than tensor_shape[j]. In that scenario, start will equal end, both clamped to tensor_shape[j], and that slice seems meaningless.

Collaborator Author

I would call this a latent bug, but it wasn't breaking anything because torch indexing handles negative-length indices as though they were empty. It just breaks the expectation that stop - start reflects the size of the unpadded shard, which we rely on in distributed checkpointing.

You're right - these index slices will end up empty, but this is the desired outcome when the shard consists entirely of padding.
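The two points above can be illustrated with plain-Python slicing, which behaves like torch indexing here: a slice whose start exceeds its stop silently yields an empty result rather than raising, which is why the unclamped bound was only a latent bug. The numbers below are made up for illustration.

```python
data = list(range(10))  # stand-in for one tensor dimension of size 10

# Unclamped: for a shard that is entirely padding, start runs past the
# end of the dimension, and stop - start no longer reflects the (zero)
# unpadded shard size.
shard_dim, n_j, tensor_dim = 4, 3, 10
start = n_j * shard_dim        # 12, past the end of the dimension
stop = (n_j + 1) * shard_dim   # 16
assert data[start:stop] == []  # indexing silently yields an empty slice

# Clamped, as in this change: start == stop == tensor_dim, so
# stop - start == 0 correctly reports an unpadded shard size of zero.
start = min(n_j * shard_dim, tensor_dim)
stop = min((n_j + 1) * shard_dim, tensor_dim)
assert stop - start == 0
```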

@jonb377 jonb377 force-pushed the jonbolin-checkpoint-restore branch from 318074b to cfcd622 on June 8, 2023 00:24
@alanwaketan alanwaketan (Collaborator) left a comment
LGTM.

    return ShardingType::TUPLE;
  case xla::OpSharding::OTHER:
    // OTHER sharding can indicate either PARTIAL or TILED sharding.
    return sharding.replicate_on_last_tile_dim() ? ShardingType::PARTIAL
                                                 : ShardingType::TILED;
Collaborator
This seems pretty hacky, but I guess we don't have another way around it?

Collaborator Author

My understanding is that we distinguish partial replication as a different sharding type, whereas XLA treats partial and tiled sharding as the same type, OTHER. @yeounoh could you confirm?

Contributor

Yes, this is actually the correct way: the compiler treats TILED and PARTIAL the same, as the OTHER type. The difference between the two is how the tile shards are assigned to different devices.
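The mapping discussed above can be sketched in Python. This is a hypothetical model, not the torch_xla C++ code: FakeOpSharding stands in for xla::OpSharding, with its kind reduced to a string and replicate_on_last_tile_dim to a plain attribute, and the enum mirrors the ShardingType values in question.

```python
from enum import Enum


class ShardingType(Enum):
    REPLICATED = 0
    TILED = 1
    PARTIAL = 2
    TUPLE = 3


class FakeOpSharding:
    """Stand-in for xla::OpSharding."""

    def __init__(self, kind, replicate_on_last_tile_dim=False):
        self.kind = kind
        self.replicate_on_last_tile_dim = replicate_on_last_tile_dim


def sharding_type(sharding):
    if sharding.kind == "TUPLE":
        return ShardingType.TUPLE
    if sharding.kind == "OTHER":
        # OTHER covers both tiled and partially-replicated shardings;
        # the last-tile-dim replication flag is what tells them apart.
        if sharding.replicate_on_last_tile_dim:
            return ShardingType.PARTIAL
        return ShardingType.TILED
    return ShardingType.REPLICATED
```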

@JackCaoG (Collaborator) commented Jun 9, 2023

@jonb377 @alanwaketan is this PR ready to merge?

@jonb377 (Author) commented Jun 9, 2023

Yes, I'll merge after TPU CI finishes.

@jonb377 jonb377 merged commit c7fe0f9 into master Jun 9, 2023
@jonb377 jonb377 deleted the jonbolin-checkpoint-restore branch June 9, 2023 21:37