torchrec/distributed/dist_data.py (28 changes: 26 additions & 2 deletions)

@@ -173,9 +173,33 @@ def forward(self, tensors: List[torch.Tensor], cat_dim: int) -> torch.Tensor:
         Here we assume input tensors are:
             [TBE_output_0, ..., TBE_output_(n-1)]
         """
-        B = tensors[0].size(1 - cat_dim)
+        # Handle the empty-shards case (can happen in column-wise sharding).
+        if len(tensors) == 0:
+            # Return an empty tensor if no tensors were provided.
+            return torch.empty(0, 0, dtype=torch.float, device=self.current_device)
+
+        # Check for TorchScript mode first to avoid global-variable access issues.
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
+            # In TorchScript or JIT tracing mode, use all tensors and let FBGEMM handle empties.
+            tensors_to_use = tensors
+        else:
+            if torch.fx._symbolic_trace.is_fx_tracing():
+                # During FX tracing, include all tensors to avoid control-flow issues.
+                tensors_to_use = tensors
+            else:
+                # Normal execution: filter out empty tensors.
+                non_empty_tensors = []
+
+                for t in tensors:
+                    if t.numel() > 0 and t.size(cat_dim) > 0:
+                        non_empty_tensors.append(t)
+
+                tensors_to_use = non_empty_tensors if non_empty_tensors else tensors
+
+        # Use the first tensor to determine the batch size.
+        B = tensors_to_use[0].size(1 - cat_dim)
         return torch.ops.fbgemm.merge_pooled_embeddings(
-            tensors,
+            tensors_to_use,
             B,
             self.current_device,
             cat_dim,
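For intuition, here is a minimal pure-PyTorch analogue of the filtering step above, with torch.cat standing in for torch.ops.fbgemm.merge_pooled_embeddings; the shard shapes are invented for illustration:

import torch

def merge_non_empty(tensors, cat_dim):
    # Keep only shards that contribute at least one column along cat_dim;
    # fall back to the original list if every shard is empty.
    non_empty = [t for t in tensors if t.numel() > 0 and t.size(cat_dim) > 0]
    return torch.cat(non_empty if non_empty else tensors, dim=cat_dim)

B = 4
shards = [torch.randn(B, 128), torch.empty(B, 0), torch.randn(B, 256)]
print(merge_non_empty(shards, cat_dim=1).shape)  # torch.Size([4, 384])

Note that torch.cat itself tolerates zero-width inputs; the filtering matters for the FBGEMM op, which this sketch deliberately does not call.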
torchrec/distributed/quant_embedding_kernel.py (32 changes: 32 additions & 0 deletions)

@@ -371,6 +371,26 @@ def _emb_module_forward(
         lengths_or_offsets: torch.Tensor,
         weights: Optional[torch.Tensor],
     ) -> torch.Tensor:
+        # Check whether the total embedding dimension is 0 (can happen in column-wise sharding).
+        total_D = sum(table.local_cols for table in self._config.embedding_tables)
+
+        if total_D == 0:
+            # For empty shards, return a tensor with the correct batch size but a 0 embedding dimension.
+            # Use tensor operations that are FX-symbolic-tracing compatible.
+            if self.lengths_to_tbe:
+                # For the lengths format, the batch size equals the lengths tensor's size.
+                # Create the [B, 0] tensor using zeros_like and slicing.
+                dummy = torch.zeros_like(lengths_or_offsets, dtype=torch.float)
+                return dummy.unsqueeze(-1)[:, :0]  # [B, 0] tensor
+            else:
+                # For the offsets format, the batch size is one less than the offsets size.
+                # Use tensor slicing to create the batch dimension.
+                batch_tensor = lengths_or_offsets[
+                    :-1
+                ]  # Drop the last element to get the batch size.
+                dummy = torch.zeros_like(batch_tensor, dtype=torch.float)
+                return dummy.unsqueeze(-1)[:, :0]  # [B, 0] tensor
+
         kwargs = {"indices": indices}

         if self.lengths_to_tbe:
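The zeros_like/unsqueeze/slice sequence may look odd, so here is a standalone sketch (with invented offsets values) showing why it yields a [B, 0] float tensor without ever extracting a Python int from a traced value:

import torch

offsets = torch.tensor([0, 2, 5, 5, 9])  # B + 1 entries, so batch size B = 4
batch_tensor = offsets[:-1]              # shape (4,)
dummy = torch.zeros_like(batch_tensor, dtype=torch.float)  # shape (4,), float
empty_embs = dummy.unsqueeze(-1)[:, :0]  # shape (4, 1) sliced down to (4, 0)
print(empty_embs.shape)  # torch.Size([4, 0])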
@@ -600,6 +620,18 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         else:
             values, offsets, _ = _unwrap_kjt(features)

+        # Check whether the total embedding dimension is 0.
+        total_D = sum(table.local_cols for table in self._config.embedding_tables)
+
+        if total_D == 0:
+            # For empty shards, return a tensor with the correct batch size but a 0 embedding dimension.
+            # Use tensor operations that are FX-symbolic-tracing compatible.
+            # For the offsets format, the batch size is one less than the offsets size;
+            # use tensor slicing to create the batch dimension.
+            batch_tensor = offsets[:-1]  # Drop the last element to get the batch size.
+            dummy = torch.zeros_like(batch_tensor, dtype=torch.float)
+            return dummy.unsqueeze(-1)[:, :0]  # [B, 0] tensor
+
         if self._emb_module_registered:
             return self.emb_module(
                 indices=values,
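As a sanity check on the FX-compatibility claim in the comments, a toy module using the same construction traces cleanly and keeps the batch size symbolic; the module name and offsets here are hypothetical, not part of the PR:

import torch
import torch.fx

class EmptyShardForward(torch.nn.Module):  # hypothetical stand-in for the real kernel
    def forward(self, offsets: torch.Tensor) -> torch.Tensor:
        batch_tensor = offsets[:-1]
        dummy = torch.zeros_like(batch_tensor, dtype=torch.float)
        return dummy.unsqueeze(-1)[:, :0]

gm = torch.fx.symbolic_trace(EmptyShardForward())
print(gm(torch.tensor([0, 1, 3, 6])).shape)  # torch.Size([3, 0])
print(gm(torch.tensor([0, 2])).shape)        # torch.Size([1, 0]); batch size is not baked in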
torchrec/distributed/tests/test_infer_shardings.py (4 changes: 2 additions & 2 deletions)

@@ -571,9 +571,9 @@ def test_cw(
     def test_uneven_cw(self, weight_dtype: torch.dtype, device_type: str) -> None:
         num_embeddings = 64
         emb_dim = 512
-        dim_1 = 63
+        dim_1 = 0
         dim_2 = 128
-        dim_3 = 65
+        dim_3 = 128
         dim_4 = 256
         local_size = 4
         world_size = 4
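The updated dims still satisfy the test's invariant: the four column-wise shard widths sum to emb_dim, and the first rank now receives an empty (width-0) shard, which is exactly the case the kernel changes above guard against. A quick check of the arithmetic:

emb_dim = 512
shard_dims = [0, 128, 128, 256]  # dim_1..dim_4, one column-wise shard per rank
assert sum(shard_dims) == emb_dim  # 0 + 128 + 128 + 256 == 512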