Support uneven batch sizes and uneven number of batches across ranks (pytorch#283)

Summary:
X-link: facebookresearch/dlrm#283

Remove the constraint that all ranks must iterate through batches of exactly the same size for exactly the same number of iterations. Now each rank's input batch can be a different size containing a different number of samples, and each rank can run forward passes or train on fewer or more batches than the other ranks.

Differential Revision: D40676549

fbshipit-source-id: 47174289e88d7d13339a9b16325b4275bc0aa628
samiwilf authored and facebook-github-bot committed Oct 25, 2022
1 parent 8edff8f commit 6aa7dfd
Showing 2 changed files with 107 additions and 27 deletions.
61 changes: 36 additions & 25 deletions torchrec/datasets/criteo.py
@@ -271,6 +271,8 @@ def get_file_idx_to_row_range(
lengths: List[int],
rank: int,
world_size: int,
start_row: int = 0,
last_row: Optional[int] = None,
) -> Dict[int, Tuple[int, int]]:
"""
Given a rank, world_size, and the lengths (number of rows) for a list of files,
@@ -296,14 +298,26 @@ def get_file_idx_to_row_range(
# All ..._g variables are globals indices (meaning they range from 0 to
# total_length - 1). All ..._l variables are local indices (meaning they range
# from 0 to lengths[i] - 1 for the ith file).

total_length = sum(lengths)
if last_row is None:
total_length = sum(lengths) - start_row
else:
total_length = last_row - start_row + 1
rows_per_rank = total_length // world_size
remainder = total_length % world_size

# Global indices that rank is responsible for. All ranges (left, right) are
# inclusive.
rank_left_g = rank * rows_per_rank
rank_right_g = (rank + 1) * rows_per_rank - 1
if rank < remainder:
rank_left_g = rank * (rows_per_rank + 1)
rank_right_g = (rank + 1) * (rows_per_rank + 1) - 1
else:
rank_left_g = (
remainder * (rows_per_rank + 1) + (rank - remainder) * rows_per_rank
)
rank_right_g = rank_left_g + rows_per_rank - 1

rank_left_g += start_row
rank_right_g += start_row

output = {}
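
For reference, here is a minimal standalone sketch of the split logic in the hunk above (the helper name split_rows is hypothetical and not part of this change): the first `remainder` ranks each take one extra row, so every row is assigned even when the total row count is not divisible by world_size.

# Illustrative sketch only -- mirrors the remainder-aware split above.
from typing import Tuple

def split_rows(total_length: int, world_size: int, rank: int) -> Tuple[int, int]:
    rows_per_rank = total_length // world_size
    remainder = total_length % world_size
    if rank < remainder:
        # The first `remainder` ranks take rows_per_rank + 1 rows each.
        left = rank * (rows_per_rank + 1)
        right = (rank + 1) * (rows_per_rank + 1) - 1
    else:
        # Later ranks take rows_per_rank rows, offset past the larger shards.
        left = remainder * (rows_per_rank + 1) + (rank - remainder) * rows_per_rank
        right = left + rows_per_rank - 1
    return left, right

# 10 rows across 3 ranks -> [(0, 3), (4, 6), (7, 9)]: shard sizes differ by at most one row.
print([split_rows(10, 3, r) for r in range(3)])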

@@ -734,34 +748,31 @@ def __init__(
}

def _load_data_for_rank(self) -> None:
if self.stage == "train":
file_idx_to_row_range = BinaryCriteoUtils.get_file_idx_to_row_range(
lengths=[
BinaryCriteoUtils.get_shape_from_npy(
path, path_manager_key=self.path_manager_key
)[0]
for path in self.dense_paths
],
rank=self.rank,
world_size=self.world_size,
)
elif self.stage in ["val", "test"]:
start_row, last_row = 0, None
if self.stage in ["val", "test"]:
# Last day's dataset is split into 2 sets: 1st half for "val"; 2nd for "test"
samples_in_file = BinaryCriteoUtils.get_shape_from_npy(
self.dense_paths[0], path_manager_key=self.path_manager_key
)[0]

dataset_start = 0
start_row = 0
dataset_len = int(np.ceil(samples_in_file / 2.0))

if self.stage == "test":
dataset_start = dataset_len
start_row = dataset_len
dataset_len = samples_in_file - dataset_len
segment_len = dataset_len // self.world_size
rank_start_row = dataset_start + self.rank * segment_len

rank_last_row = rank_start_row + segment_len - 1
file_idx_to_row_range = {0: (rank_start_row, rank_last_row)}
last_row = dataset_len - 1

file_idx_to_row_range = BinaryCriteoUtils.get_file_idx_to_row_range(
lengths=[
BinaryCriteoUtils.get_shape_from_npy(
path, path_manager_key=self.path_manager_key
)[0]
for path in self.dense_paths
],
rank=self.rank,
world_size=self.world_size,
start_row=start_row,
last_row=last_row,
)

self.dense_arrs, self.sparse_arrs, self.labels_arrs = [], [], []
for arrs, paths in zip(
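As a quick illustration of the half-split above, assuming a hypothetical last-day file with 7 samples:

# Illustrative numbers only: np.ceil gives "val" the larger half of an odd-sized file.
import numpy as np

samples_in_file = 7
val_len = int(np.ceil(samples_in_file / 2.0))   # 4 -> "val" covers rows 0..3
test_start_row = val_len                        # "test" starts at row 4
test_len = samples_in_file - val_len            # 3 -> "test" covers rows 4..6

Both stages now flow through get_file_idx_to_row_range with these bounds, so the val/test rows no longer need to divide evenly across ranks.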
73 changes: 71 additions & 2 deletions torchrec/models/dlrm.py
@@ -7,8 +7,11 @@

from typing import Dict, List, Optional, Tuple

import torch.fx as fx

fx.wrap("len")
import torch
from torch import nn
from torch import distributed as dist, nn
from torchrec.datasets.utils import Batch
from torchrec.modules.crossnet import LowRankCrossNet
from torchrec.modules.embedding_modules import EmbeddingBagCollection
@@ -555,6 +558,26 @@ def __init__(
layer_sizes=over_arch_layer_sizes,
device=dense_device,
)
try:
if not torch.distributed.is_initialized():
backend = "gloo"
if dense_device is not None and dense_device.type == "cuda":
backend = "nccl"
dist.init_process_group(backend=backend)
world_size = dist.get_world_size()
except ValueError:
world_size = 1
self.batch_sizes_gathered_list: List[torch.Tensor] = [
torch.tensor([0], dtype=torch.int, device=dense_device)
for _ in range(world_size)
]
self.table_size_tensor: torch._tensor.Tensor = torch.tensor(
[0], dtype=torch.int32, device=dense_device
)
self.batch_size_tensor: torch._tensor.Tensor = torch.tensor(
[0], dtype=torch.int32, device=dense_device
)
self.dense_device = dense_device

def forward(
self,
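
The one-element tensors pre-allocated above match the calling convention of torch.distributed.all_gather, which fills a fixed list of same-shaped tensors, one per rank. A rough standalone sketch of that pattern (illustrative only; it assumes the process group is already initialized, e.g. via torchrun):

# Sketch of gathering every rank's local batch size (not the module code itself).
import torch
import torch.distributed as dist

def gather_batch_sizes(local_batch_size: int, device: torch.device) -> torch.Tensor:
    world_size = dist.get_world_size()
    local = torch.tensor([local_batch_size], dtype=torch.int32, device=device)
    gathered = [torch.zeros_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local)  # every rank ends up with every rank's batch size
    return torch.cat(gathered)        # e.g. tensor([512, 512, 480, 512]) on 4 ranks

# The max of the gathered sizes is the batch size every rank pads its sparse input to.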
@@ -569,11 +592,57 @@ def forward(
Returns:
torch.Tensor: logits.
"""
embedded_dense = self.dense_arch(dense_features)
# Embedding Lookup.
# Pad the sparse input if necessary to maintain the same shape among all ranks.
# This rank's local batch size
self.batch_size_tensor[0] = len(dense_features)
if torch.distributed.is_initialized():
dist.all_gather(self.batch_sizes_gathered_list, self.batch_size_tensor)
# Max local batch size among all the ranks
max_batch_size = torch.max(torch.cat(self.batch_sizes_gathered_list))
sparse_features._values = sparse_features._values.reshape(
-1, int(self.batch_size_tensor[0])
)
# Number of embedding tables
self.table_size_tensor[0] = sparse_features._values.shape[0]
if max_batch_size - self.batch_size_tensor[0] > 0:
sparse_features._values = torch.cat(
(
sparse_features._values,
torch.zeros(
(
int(self.table_size_tensor[0]),
int(max_batch_size - self.batch_size_tensor[0]),
),
dtype=torch.int32,
device=self.dense_device,
),
),
dim=-1,
)
sparse_features = KeyedJaggedTensor.from_offsets_sync(
keys=sparse_features._keys,
values=sparse_features._values.reshape(-1),
offsets=torch.arange(
0,
int(self.table_size_tensor[0] * max_batch_size + 1),
dtype=torch.int32,
device=self.dense_device,
),
)
embedded_sparse = self.sparse_arch(sparse_features)
# Remove any padding from the embedding bag output to return it back to its expected size.
if max_batch_size - self.batch_size_tensor[0] > 0:
embedded_sparse = embedded_sparse[: self.batch_size_tensor[0]]

# Bottom MLP
embedded_dense = self.dense_arch(dense_features)

# Interaction Layer
concatenated_dense = self.inter_arch(
dense_features=embedded_dense, sparse_features=embedded_sparse
)
# Top MLP
logits = self.over_arch(concatenated_dense)
return logits

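A rough worked example (editorial sketch, not part of the commit) of the pad-then-truncate step in forward() above, assuming each sparse feature holds exactly one id per sample: 3 embedding tables, a local batch of 2, and a global max batch of 4.

# Illustrative padding of one rank's sparse ids up to the max batch size across ranks.
import torch

values = torch.tensor([[10, 11],
                       [20, 21],
                       [30, 31]], dtype=torch.int32)  # (num_tables, local_batch) = (3, 2)
local_batch, max_batch = values.shape[1], 4

# Zero-pad each table's ids so this rank presents the same batch size as the largest rank.
padded = torch.cat(
    (values, torch.zeros((values.shape[0], max_batch - local_batch), dtype=torch.int32)),
    dim=-1,
)  # shape (3, 4)

# Flattened values with offsets 0..num_tables*max_batch describe one id per (table, sample).
flat_values = padded.reshape(-1)
offsets = torch.arange(0, padded.numel() + 1, dtype=torch.int32)

# After the embedding lookup, only the first local_batch rows are kept
# (embedded_sparse[:local_batch]); rows produced by the zero padding are dropped.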
