Support placing CW shards for the same table on different devices for SQEC #2804

Open · wants to merge 1 commit into base: main
7 changes: 4 additions & 3 deletions torchrec/distributed/embedding.py
@@ -127,9 +127,10 @@ def get_device_from_parameter_sharding(
if len(set(device_type_list)) == 1:
return device_type_list[0]
else:
assert (
ps.sharding_type == "row_wise"
), "Only row_wise sharding supports sharding across multiple device types for a table"
assert ps.sharding_type in [
ShardingType.ROW_WISE.value,
ShardingType.COLUMN_WISE.value,
], "Only row_wise or column_wise sharding supports sharding across multiple device types for a table"
return device_type_list
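
For readers skimming the diff, the hunk above is the check that previously limited heterogeneous placements to row_wise tables. Below is a minimal, self-contained sketch of that logic in plain Python, with hypothetical inputs and names (`resolve_device_type`, the literal device-type list); the real helper derives one device type per shard from the table's ParameterSharding placements.

```python
# Simplified mimic of get_device_from_parameter_sharding (hypothetical inputs):
# the real helper extracts one device type per shard placement from the table's
# ParameterSharding; here we start from that extracted list directly.
from typing import List, Union

ROW_WISE = "row_wise"
COLUMN_WISE = "column_wise"


def resolve_device_type(
    sharding_type: str, device_type_list: List[str]
) -> Union[str, List[str]]:
    if len(set(device_type_list)) == 1:
        # Homogeneous placement: collapse to a single device type.
        return device_type_list[0]
    # Heterogeneous placement: before this PR only row_wise was accepted here.
    assert sharding_type in [ROW_WISE, COLUMN_WISE], (
        "Only row_wise or column_wise sharding supports sharding across "
        "multiple device types for a table"
    )
    return device_type_list


print(resolve_device_type(COLUMN_WISE, ["cuda", "cuda"]))  # 'cuda'
print(resolve_device_type(COLUMN_WISE, ["cuda", "cpu"]))   # ['cuda', 'cpu']
```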


2 changes: 1 addition & 1 deletion torchrec/distributed/embedding_lookup.py
@@ -1116,7 +1116,7 @@ def __init__(
self._embedding_lookups_per_rank.append(
MetaInferGroupedEmbeddingsLookup(
grouped_configs=grouped_configs_per_rank[rank],
device=rank_device(device_type, rank),
device=device,
fused_params=fused_params,
shard_index=shard_index,
)
21 changes: 17 additions & 4 deletions torchrec/distributed/quant_embedding.py
@@ -164,9 +164,10 @@ def get_device_from_parameter_sharding(
if len(set(device_type_list)) == 1:
return device_type_list[0]
else:
assert (
ps.sharding_type == "row_wise"
), "Only row_wise sharding supports sharding across multiple device types for a table"
assert ps.sharding_type in [
ShardingType.ROW_WISE.value,
ShardingType.COLUMN_WISE.value,
], "Only row_wise or column_wise sharding supports sharding across multiple device types for a table"
return device_type_list


@@ -209,7 +210,12 @@ def create_infer_embedding_sharding(
if sharding_type == ShardingType.TABLE_WISE.value:
return InferTwSequenceEmbeddingSharding(sharding_infos, env, device)
elif sharding_type == ShardingType.COLUMN_WISE.value:
return InferCwSequenceEmbeddingSharding(sharding_infos, env, device)
return InferCwSequenceEmbeddingSharding(
sharding_infos=sharding_infos,
env=env,
device=device,
device_type_from_sharding_infos=device_type_from_sharding_infos,
)
elif sharding_type == ShardingType.ROW_WISE.value:
return InferRwSequenceEmbeddingSharding(
sharding_infos=sharding_infos,
@@ -231,6 +237,13 @@ def create_infer_embedding_sharding(
device=device,
device_type_from_sharding_infos=device_type_from_sharding_infos,
)
elif sharding_type == ShardingType.COLUMN_WISE.value:
return InferCwSequenceEmbeddingSharding(
sharding_infos=sharding_infos,
env=env,
device=device,
device_type_from_sharding_infos=device_type_from_sharding_infos,
)
elif sharding_type == ShardingType.TABLE_WISE.value:
return InferTwSequenceEmbeddingSharding(sharding_infos, env, device)
else:
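
As a hedged aside on the value threaded through both dispatch blocks above: `device_type_from_sharding_infos` is typed `Optional[Union[str, Tuple[str, ...]]]` elsewhere in this PR, i.e. a single device type for a homogeneous table or one device type per shard for a heterogeneous one, and the column_wise branch now forwards it just like the row_wise branch. A small illustrative snippet with hypothetical names, not TorchRec API:

```python
# Illustrative only: the shape of device_type_from_sharding_infos that the
# column_wise and row_wise branches now both forward to their shardings.
from typing import Optional, Tuple, Union

DeviceTypeInfo = Optional[Union[str, Tuple[str, ...]]]

homogeneous: DeviceTypeInfo = "cuda"             # every shard of the table on cuda
heterogeneous: DeviceTypeInfo = ("cuda", "cpu")  # CW shard 0 on cuda, shard 1 on cpu


def describe(info: DeviceTypeInfo) -> str:
    if info is None or isinstance(info, str):
        return f"single device type: {info}"
    return f"per-shard device types: {list(info)}"


print(describe(homogeneous))    # single device type: cuda
print(describe(heterogeneous))  # per-shard device types: ['cuda', 'cpu']
```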
29 changes: 4 additions & 25 deletions torchrec/distributed/sharding/cw_sequence_sharding.py
@@ -10,7 +10,6 @@
from typing import Any, Dict, List, Optional

import torch
from torchrec.distributed.dist_data import SeqEmbeddingsAllToOne
from torchrec.distributed.embedding_lookup import (
GroupedEmbeddingsLookup,
InferGroupedEmbeddingsLookup,
@@ -26,6 +25,7 @@
)
from torchrec.distributed.sharding.cw_sharding import BaseCwEmbeddingSharding
from torchrec.distributed.sharding.sequence_sharding import (
InferSequenceEmbeddingDist,
InferSequenceShardingContext,
SequenceShardingContext,
)
@@ -111,6 +111,7 @@ def create_lookup(
world_size=self._world_size,
fused_params=fused_params,
device=device if device is not None else self._device,
device_type_from_sharding_infos=self._device_type_from_sharding_infos,
)

def create_output_dist(
@@ -121,31 +122,9 @@ def create_output_dist(
device = device if device is not None else self._device
assert device is not None

dist_out = InferCwSequenceEmbeddingDist(
dist_out = InferSequenceEmbeddingDist(
device,
self._world_size,
self._device_type_from_sharding_infos,
)
return dist_out


class InferCwSequenceEmbeddingDist(
BaseEmbeddingDist[
InferSequenceShardingContext, List[torch.Tensor], List[torch.Tensor]
]
):
def __init__(
self,
device: torch.device,
world_size: int,
) -> None:
super().__init__()
self._dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne(
device=device, world_size=world_size
)

def forward(
self,
local_embs: List[torch.Tensor],
sharding_ctx: Optional[InferSequenceShardingContext] = None,
) -> List[torch.Tensor]:
return self._dist(local_embs)
7 changes: 6 additions & 1 deletion torchrec/distributed/sharding/cw_sharding.py
@@ -7,7 +7,7 @@

# pyre-strict

from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, Union

import torch
import torch.distributed as dist # noqa
@@ -70,6 +70,7 @@ def __init__(
device: Optional[torch.device] = None,
permute_embeddings: bool = False,
qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None,
) -> None:
super().__init__(
sharding_infos,
@@ -81,6 +82,10 @@ def __init__(
if self._permute_embeddings:
self._init_combined_embeddings()

self._device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = (
device_type_from_sharding_infos
)

def _init_combined_embeddings(self) -> None:
"""
Grabs the embedding names and dims from TwEmbeddingSharder.
88 changes: 2 additions & 86 deletions torchrec/distributed/sharding/rw_sequence_sharding.py
@@ -35,19 +35,13 @@
RwSparseFeaturesDist,
)
from torchrec.distributed.sharding.sequence_sharding import (
InferSequenceEmbeddingDist,
InferSequenceShardingContext,
SequenceShardingContext,
)
from torchrec.distributed.types import Awaitable, CommOp, QuantizedCommCodecs
from torchrec.modules.utils import (
_fx_trec_get_feature_length,
_get_batching_hinted_output,
)
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

torch.fx.wrap("_get_batching_hinted_output")
torch.fx.wrap("_fx_trec_get_feature_length")


class RwSequenceEmbeddingDist(
BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor]
@@ -164,84 +158,6 @@ def create_output_dist(
)


class InferRwSequenceEmbeddingDist(
BaseEmbeddingDist[
InferSequenceShardingContext, List[torch.Tensor], List[torch.Tensor]
]
):
def __init__(
self,
device: torch.device,
world_size: int,
device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None,
) -> None:
super().__init__()
self._device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = (
device_type_from_sharding_infos
)
num_cpu_ranks = 0
if self._device_type_from_sharding_infos and isinstance(
self._device_type_from_sharding_infos, tuple
):
for device_type in self._device_type_from_sharding_infos:
if device_type == "cpu":
num_cpu_ranks += 1
elif self._device_type_from_sharding_infos == "cpu":
num_cpu_ranks = world_size

self._device_dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne(
device, world_size - num_cpu_ranks
)

def forward(
self,
local_embs: List[torch.Tensor],
sharding_ctx: Optional[InferSequenceShardingContext] = None,
) -> List[torch.Tensor]:
assert (
self._device_type_from_sharding_infos is not None
), "_device_type_from_sharding_infos should always be set for InferRwSequenceEmbeddingDist"
if isinstance(self._device_type_from_sharding_infos, tuple):
assert sharding_ctx is not None
assert sharding_ctx.embedding_names_per_rank is not None
assert len(self._device_type_from_sharding_infos) == len(
local_embs
), "For heterogeneous sharding, the number of local_embs should be equal to the number of device types"
non_cpu_local_embs = []
# Here looping through local_embs is also compatible with tracing
# given the number of looks up / shards withing ShardedQuantEmbeddingCollection
# are fixed and local_embs is the output of those looks ups. However, still
# using _device_type_from_sharding_infos to iterate on local_embs list as
# that's a better practice.
for i, device_type in enumerate(self._device_type_from_sharding_infos):
if device_type != "cpu":
non_cpu_local_embs.append(
_get_batching_hinted_output(
_fx_trec_get_feature_length(
sharding_ctx.features[i],
# pyre-fixme [16]
sharding_ctx.embedding_names_per_rank[i],
),
local_embs[i],
)
)
non_cpu_local_embs_dist = self._device_dist(non_cpu_local_embs)
index = 0
result = []
for i, device_type in enumerate(self._device_type_from_sharding_infos):
if device_type == "cpu":
result.append(local_embs[i])
else:
result.append(non_cpu_local_embs_dist[index])
index += 1
return result
elif self._device_type_from_sharding_infos == "cpu":
# for cpu sharder, output dist should be a no-op
return local_embs
else:
return self._device_dist(local_embs)


class InferRwSequenceEmbeddingSharding(
BaseRwEmbeddingSharding[
InferSequenceShardingContext,
@@ -297,7 +213,7 @@ def create_output_dist(
) -> BaseEmbeddingDist[
InferSequenceShardingContext, List[torch.Tensor], List[torch.Tensor]
]:
return InferRwSequenceEmbeddingDist(
return InferSequenceEmbeddingDist(
device if device is not None else self._device,
self._world_size,
self._device_type_from_sharding_infos,
97 changes: 95 additions & 2 deletions torchrec/distributed/sharding/sequence_sharding.py
@@ -8,14 +8,27 @@
# pyre-strict

from dataclasses import dataclass
from typing import List, Optional
from typing import List, Optional, Tuple, Union

import torch
from torchrec.distributed.embedding_sharding import EmbeddingShardingContext

from torchrec.distributed.dist_data import SeqEmbeddingsAllToOne
from torchrec.distributed.embedding_sharding import (
BaseEmbeddingDist,
EmbeddingShardingContext,
)
from torchrec.distributed.embedding_types import KJTList

from torchrec.modules.utils import (
_fx_trec_get_feature_length,
_get_batching_hinted_output,
)
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.streamable import Multistreamable

torch.fx.wrap("_get_batching_hinted_output")
torch.fx.wrap("_fx_trec_get_feature_length")


class SequenceShardingContext(EmbeddingShardingContext):
"""
@@ -111,3 +124,83 @@ def record_stream(self, stream: torch.Stream) -> None:
self.bucket_mapping_tensor.record_stream(stream)
if self.bucketized_length is not None:
self.bucketized_length.record_stream(stream)


class InferSequenceEmbeddingDist(
BaseEmbeddingDist[
InferSequenceShardingContext, List[torch.Tensor], List[torch.Tensor]
]
):
def __init__(
self,
device: torch.device,
world_size: int,
device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None,
) -> None:
super().__init__()
self._device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = (
device_type_from_sharding_infos
)
non_cpu_ranks = 0
if self._device_type_from_sharding_infos and isinstance(
self._device_type_from_sharding_infos, tuple
):
for device_type in self._device_type_from_sharding_infos:
if device_type != "cpu":
non_cpu_ranks += 1
elif self._device_type_from_sharding_infos == "cpu":
non_cpu_ranks = 0
else:
non_cpu_ranks = world_size

self._device_dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne(
device, non_cpu_ranks
)

def forward(
self,
local_embs: List[torch.Tensor],
sharding_ctx: Optional[InferSequenceShardingContext] = None,
) -> List[torch.Tensor]:
assert (
self._device_type_from_sharding_infos is not None
), "_device_type_from_sharding_infos should always be set for InferRwSequenceEmbeddingDist"
if isinstance(self._device_type_from_sharding_infos, tuple):
assert sharding_ctx is not None
assert sharding_ctx.embedding_names_per_rank is not None
assert len(self._device_type_from_sharding_infos) == len(
local_embs
), "For heterogeneous sharding, the number of local_embs should be equal to the number of device types"
non_cpu_local_embs = []
# Looping through local_embs here is also compatible with tracing, since the
# number of lookups / shards within ShardedQuantEmbeddingCollection is fixed
# and local_embs is the output of those lookups. Still, we iterate over
# _device_type_from_sharding_infos rather than local_embs directly, as that
# is the better practice.
for i, device_type in enumerate(self._device_type_from_sharding_infos):
if device_type != "cpu":
non_cpu_local_embs.append(
_get_batching_hinted_output(
_fx_trec_get_feature_length(
sharding_ctx.features[i],
# pyre-fixme [16]
sharding_ctx.embedding_names_per_rank[i],
),
local_embs[i],
)
)
non_cpu_local_embs_dist = self._device_dist(non_cpu_local_embs)
index = 0
result = []
for i, device_type in enumerate(self._device_type_from_sharding_infos):
if device_type == "cpu":
result.append(local_embs[i])
else:
result.append(non_cpu_local_embs_dist[index])
index += 1
return result
elif self._device_type_from_sharding_infos == "cpu":
# for cpu sharder, output dist should be a no-op
return local_embs
else:
return self._device_dist(local_embs)
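
To make the heterogeneous (tuple) branch of `forward` above easier to follow, here is a minimal mimic of its split-and-merge using plain Python lists and hypothetical names; the real code additionally applies `_get_batching_hinted_output` to each non-cpu shard and sends them through `SeqEmbeddingsAllToOne`, while cpu shard outputs pass through as a no-op.

```python
# Simplified mimic of InferSequenceEmbeddingDist.forward for the heterogeneous
# case: cpu shard outputs are returned unchanged, non-cpu shard outputs are
# gathered to the target device and spliced back in per-shard order.
from typing import List, Tuple


def merge_heterogeneous_outputs(
    device_types: Tuple[str, ...], local_embs: List[str]
) -> List[str]:
    assert len(device_types) == len(local_embs)
    # Stand-in for SeqEmbeddingsAllToOne over the non-cpu shard outputs only.
    gathered = [
        f"{emb}@target_device"
        for dt, emb in zip(device_types, local_embs)
        if dt != "cpu"
    ]
    result: List[str] = []
    index = 0
    for dt, emb in zip(device_types, local_embs):
        if dt == "cpu":
            result.append(emb)          # no-op for cpu shards
        else:
            result.append(gathered[index])
            index += 1
    return result


print(merge_heterogeneous_outputs(("cuda", "cpu"), ["shard0_out", "shard1_out"]))
# ['shard0_out@target_device', 'shard1_out']
```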