Add variable batch size support to TBE training (#1752)
Summary:
Pull Request resolved: #1752

This diff adds variable batch size (VBE, also known as variable length) support to split TBE training on GPU.

**Usage:**

```
# Initialize TBE the same way as before
emb_op = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[...],
    ... # other params
)

# Batch sizes (one for each FEATURE and each RANK).
# Example: num_features = 2, num_ranks = 4
batch_size_per_feature_per_rank = [
    [1,  2, 8, 3],  # batch sizes for [Rank 0, Rank 1, Rank 2, Rank 3] in Feature 0
    [6, 10, 3, 5],  # batch sizes for [Rank 0, Rank 1, Rank 2, Rank 3] in Feature 1
]

# Pass batch_size_per_feature_per_rank to forward.
# !! Make sure to pass batch_size_per_feature_per_rank as a keyword arg because there can be other keyword args in forward. !!
output = emb_op(indices, offsets, batch_size_per_feature_per_rank=batch_size_per_feature_per_rank)
```
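
For reference, a minimal sketch (not part of the diff) mirroring the bookkeeping that the new `forward` code below derives from this argument:

```
from typing import List
import torch

# Example from above: 2 features, 4 ranks.
batch_size_per_feature_per_rank: List[List[int]] = [[1, 2, 8, 3], [6, 10, 3, 5]]

# Total batch size per feature, summed over ranks: [14, 24].
total_batch_size_per_feature = torch.tensor(
    [sum(batch_sizes) for batch_sizes in batch_size_per_feature_per_rank],
    dtype=torch.int32,
)
# max_B is the largest per-feature total batch size: 24.
max_B = int(total_batch_size_per_feature.max().item())
# B_offsets is the prefix sum over features (with a leading 0): [0, 14, 38].
B_offsets = torch.concat(
    [torch.zeros(1, dtype=torch.int32), total_batch_size_per_feature]
).cumsum(dim=0)
```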

**Output format**

{F967393126}
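
Per the offset computation added to `forward` below, the VBE output is a single 1-D tensor laid out rank-major: for each rank, then for each feature, a contiguous segment of `B[f][r] * D[f]` values. A sketch for the example above, assuming embedding dims `D = [4, 8]` purely for illustration:

```
import torch

batch_size_per_feature_per_rank = [[1, 2, 8, 3], [6, 10, 3, 5]]
feature_dims = torch.tensor([4, 8], dtype=torch.int64)  # assumed dims, illustration only

B_feature_rank = torch.tensor(batch_size_per_feature_per_rank, dtype=torch.int64)
# (num_ranks, num_features) segment sizes; rows iterate over ranks, columns over features.
output_sizes_feature_rank = B_feature_rank.transpose(0, 1) * feature_dims.view(1, -1)
output_offsets_feature_rank = torch.concat(
    [torch.zeros(1, dtype=torch.int64), output_sizes_feature_rank.flatten().cumsum(dim=0)]
)
output_size = int(output_offsets_feature_rank[-1].item())  # 248 = 14 * 4 + 24 * 8
```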

**Limitation:**

`T` and `max_B` have to fit together in 32 bits.
- We use the lower `info_B_num_bits` bits to store `b` (bag ID; `b < max_B`).  Supported `max_B` = `2^info_B_num_bits`.
- We use the upper `32 - info_B_num_bits` bits to store `t` (table ID; `t < T`).  Supported `T` = `2^(32 - info_B_num_bits)`.

Note that `info_B_num_bits` is adjusted automatically at runtime based on `max_B` and `T`.  If `b` and `t` cannot both fit into 32 bits, the op aborts.
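
As an illustration of the packing described above (a sketch only; the actual runtime selection of `info_B_num_bits` happens inside the op):

```
def pack_info(t: int, b: int, info_B_num_bits: int) -> int:
    # Lower info_B_num_bits bits hold the bag ID b;
    # upper 32 - info_B_num_bits bits hold the table ID t.
    assert b < (1 << info_B_num_bits)
    assert t < (1 << (32 - info_B_num_bits))
    return (t << info_B_num_bits) | b

def unpack_info(info: int, info_B_num_bits: int) -> tuple:
    return info >> info_B_num_bits, info & ((1 << info_B_num_bits) - 1)
```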

Differential Revision: D42663369

fbshipit-source-id: 9918a51ac0be5da077e37bb9315716380c12b7e0
sryap authored and facebook-github-bot committed May 5, 2023
1 parent 3138ecd commit dcd47df
Showing 6 changed files with 534 additions and 80 deletions.
10 changes: 10 additions & 0 deletions fbgemm_gpu/codegen/lookup_args.py
@@ -10,6 +10,15 @@
import torch


class VBEMetadata(NamedTuple):
    B_offsets: Optional[torch.Tensor]
    output_offsets_feature_rank: Optional[torch.Tensor]
    B_offsets_rank_per_feature: Optional[torch.Tensor]
    max_B_feature_rank: int = -1
    max_B: int = -1
    output_size: int = -1


class CommonArgs(NamedTuple):
    placeholder_autograd_tensor: torch.Tensor
    dev_weights: torch.Tensor
@@ -30,6 +39,7 @@ class CommonArgs(NamedTuple):
    feature_requires_grad: Optional[torch.Tensor]
    lxu_cache_locations: torch.Tensor
    output_dtype: int
    vbe_metadata: VBEMetadata


class OptimizerArgs(NamedTuple):
@@ -148,6 +148,7 @@ def invoke(
            {% endif %}
        )
    else:
        vbe_metadata = common_args.vbe_metadata
        return torch.ops.fbgemm.split_embedding_codegen_lookup_{{ optimizer }}_function(
            # common_args
            {% if not dense %}
@@ -169,6 +170,13 @@ def invoke(
            indice_weights=common_args.indice_weights,
            feature_requires_grad=common_args.feature_requires_grad,
            lxu_cache_locations=common_args.lxu_cache_locations,
            # VBE metadata
            B_offsets=vbe_metadata.B_offsets,
            vbe_output_offsets_feature_rank=vbe_metadata.output_offsets_feature_rank,
            vbe_B_offsets_rank_per_feature=vbe_metadata.B_offsets_rank_per_feature,
            max_B=vbe_metadata.max_B,
            max_B_feature_rank=vbe_metadata.max_B_feature_rank,
            vbe_output_size=vbe_metadata.output_size,
            # optimizer_args
            gradient_clipping = optimizer_args.gradient_clipping,
            max_gradient=optimizer_args.max_gradient,
14 changes: 9 additions & 5 deletions fbgemm_gpu/fbgemm_gpu/split_embedding_utils.py
@@ -44,15 +44,19 @@ def to_device(t: Deviceable, use_cpu: bool) -> Deviceable:
# Merged indices with shape (T, B, L) -> (flattened indices with shape
# (T * B * L), offsets with shape (T * B + 1))
def get_table_batched_offsets_from_dense(
-    merged_indices: torch.Tensor, use_cpu: bool = False
+    merged_indices: torch.Tensor,
+    L: Optional[int] = None,
+    total_B: Optional[int] = None,
+    use_cpu: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
-    (T, B, L) = merged_indices.size()
-    lengths = np.ones((T, B)) * L
-    flat_lengths = lengths.flatten()
+    if L is None and total_B is None:
+        (T, B, L) = merged_indices.size()
+        total_B = T * B
+    lengths = np.ones(total_B) * L
    return (
        to_device(merged_indices.contiguous().view(-1), use_cpu),
        to_device(
-            torch.tensor(([0] + np.cumsum(flat_lengths).tolist())).long(),
+            torch.tensor(([0] + np.cumsum(lengths).tolist())).long(),
            use_cpu,
        ),
    )
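
A hedged usage sketch for the extended helper (the variable-batch call is an assumption based on the new `L`/`total_B` parameters; the first call matches the existing fixed-batch behavior):

```
import torch
from fbgemm_gpu.split_embedding_utils import get_table_batched_offsets_from_dense

# Fixed batch sizes: T, B, L are inferred from the dense (T, B, L) indices tensor.
dense = torch.randint(0, 100, (2, 4, 3))  # T=2, B=4, L=3
indices, offsets = get_table_batched_offsets_from_dense(dense, use_cpu=True)

# Variable batch sizes: pass L and the total number of bags explicitly; only the
# flattened view of the indices is used.
vbe_indices = torch.randint(0, 100, (38, 3))  # 38 bags in total, L=3 indices per bag
indices, offsets = get_table_batched_offsets_from_dense(
    vbe_indices, L=3, total_B=38, use_cpu=True
)  # offsets has 39 entries: [0, 3, 6, ..., 114]
```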
97 changes: 93 additions & 4 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
@@ -327,9 +327,10 @@ def __init__( # noqa C901
        ], "Fused pooled embedding quantization only supported for cuda."

        if device is None:
-            # pyre-fixme[8]: Attribute has type `device`; used as `Union[int, device]`.
            self.current_device: torch.device = (
-                torch.device("cpu") if self.use_cpu else torch.cuda.current_device()
+                torch.device("cpu")
+                if self.use_cpu
+                else torch.device(torch.cuda.current_device())
            )
        elif isinstance(device, torch.device):
            self.current_device = device
@@ -360,8 +361,8 @@ def __init__( # noqa C901
                table_has_feature[t] = True
        assert all(table_has_feature), "Each table must have at least one feature!"

-        D_offsets = [dims[t] for t in self.feature_table_map]
-        D_offsets = [0] + list(accumulate(D_offsets))
+        feature_dims = [dims[t] for t in self.feature_table_map]
+        D_offsets = [0] + list(accumulate(feature_dims))
        self.total_D: int = D_offsets[-1]
        self.max_D: int = max(dims)
        cached_dims = [
@@ -405,6 +406,11 @@ def __init__( # noqa C901
            "bounds_check_warning",
            torch.tensor([0], device=self.current_device, dtype=torch.int64),
        )
        # Required for VBE
        self.register_buffer(
            "feature_dims",
            torch.tensor(feature_dims, device="cpu", dtype=torch.int64),
        )

weight_split = construct_split_state(
embedding_specs,
@@ -728,7 +734,87 @@ def forward(
        offsets: Tensor,
        per_sample_weights: Optional[Tensor] = None,
        feature_requires_grad: Optional[Tensor] = None,
        # 2D tensor of batch size for each rank and feature.
        # Shape (number of features, number of ranks)
        batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
    ) -> Tensor:
        if batch_size_per_feature_per_rank is not None:
            # TODO: Add input check
            zero_tensor = torch.zeros(1, device="cpu", dtype=torch.int32)

            # Create B offsets
            total_batch_size_per_feature = torch.tensor(
                [sum(batch_sizes) for batch_sizes in batch_size_per_feature_per_rank],
                device="cpu",
                dtype=torch.int32,
            )
            max_B = int(total_batch_size_per_feature.max().item())
            Bs = torch.concat([zero_tensor, total_batch_size_per_feature])
            B_offsets = Bs.cumsum(dim=0).to(torch.int)

            # Create output offsets
            B_feature_rank = torch.tensor(
                batch_size_per_feature_per_rank,
                device="cpu",
                dtype=torch.int64,
            )
            max_B_feature_rank = int(B_feature_rank.max().item())
            # D->H only once
            if self.feature_dims.is_cuda:
                self.feature_dims = self.feature_dims.cpu()
            output_sizes_feature_rank = B_feature_rank.transpose(
                0, 1
            ) * self.feature_dims.view(1, -1)
            output_offsets_feature_rank = torch.concat(
                [
                    zero_tensor.to(torch.int64),
                    output_sizes_feature_rank.flatten().cumsum(dim=0),
                ]
            )
            output_size = int(output_offsets_feature_rank[-1].item())

            # TODO: Support INT8 output
            # B_offsets_rank_per_feature is for rank and (b, t) mapping
            B_offsets_rank_per_feature = (
                torch.tensor(
                    [
                        [0] + batch_size_per_feature
                        for batch_size_per_feature in batch_size_per_feature_per_rank
                    ],
                    device="cpu",
                    dtype=torch.int32,
                )
                .cumsum(dim=1)
                .to(torch.int)
            )

            B_offsets = B_offsets.to(self.current_device, non_blocking=True)
            output_offsets_feature_rank = output_offsets_feature_rank.to(
                self.current_device, non_blocking=True
            )
            B_offsets_rank_per_feature = B_offsets_rank_per_feature.to(
                self.current_device, non_blocking=True
            )

            # TODO: Use int32 for B_offsets and int64 for output_offsets_feature_rank
            vbe_metadata = invokers.lookup_args.VBEMetadata(
                B_offsets=B_offsets,
                output_offsets_feature_rank=output_offsets_feature_rank,
                B_offsets_rank_per_feature=B_offsets_rank_per_feature,
                max_B=max_B,
                max_B_feature_rank=max_B_feature_rank,
                output_size=output_size,
            )
        else:
            vbe_metadata = invokers.lookup_args.VBEMetadata(
                B_offsets=None,
                output_offsets_feature_rank=None,
                B_offsets_rank_per_feature=None,
                max_B=-1,
                max_B_feature_rank=-1,
                output_size=-1,
            )

        (indices, offsets) = indices.long(), offsets.long()
        if self.bounds_check_mode_int != BoundsCheckMode.NONE.value:
            torch.ops.fbgemm.bounds_check_indices(
@@ -738,6 +824,8 @@
                self.bounds_check_mode_int,
                self.bounds_check_warning,
                per_sample_weights,
                B_offsets=vbe_metadata.B_offsets,
                max_B=vbe_metadata.max_B,
            )
        self.step += 1
        if len(self.timesteps_prefetched) == 0:
@@ -781,6 +869,7 @@ def forward(
            feature_requires_grad=feature_requires_grad,
            lxu_cache_locations=lxu_cache_locations,
            output_dtype=self.output_dtype,
            vbe_metadata=vbe_metadata,
        )

        if self.optimizer == OptimType.EXACT_SGD:
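
For concreteness, the `B_offsets_rank_per_feature` that the block above builds for the example from the summary (a worked illustration, not part of the diff):

```
import torch

batch_size_per_feature_per_rank = [[1, 2, 8, 3], [6, 10, 3, 5]]
B_offsets_rank_per_feature = (
    torch.tensor(
        [[0] + b for b in batch_size_per_feature_per_rank], dtype=torch.int32
    )
    .cumsum(dim=1)
    .to(torch.int)
)
# tensor([[ 0,  1,  3, 11, 14],
#         [ 0,  6, 16, 19, 24]], dtype=torch.int32)
# Row f gives, within feature f, the bag range [start, end) owned by each rank.
```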
@@ -436,6 +436,14 @@ def forward(
            indice_weights=per_sample_weights,
            feature_requires_grad=feature_requires_grad,
            lxu_cache_locations=lxu_cache_locations,
            vbe_metadata=invokers.lookup_args.VBEMetadata(
                B_offsets=None,
                output_offsets_feature_rank=None,
                B_offsets_rank_per_feature=None,
                max_B=-1,
                max_B_feature_rank=-1,
                output_size=-1,
            ),
        )

        momentum1 = invokers.lookup_args.Momentum(
