
Commit 20efdb3

ge0405 authored and facebook-github-bot committed
Benchmark: compare EBCs' performances (#421)
Summary:
Pull Request resolved: #421

Compare the efficiency of EBC based on the following classes in github/benchmarks:
(1) nn.EmbeddingBag EBC
(2) SplitTableBatched EBC (with optimizer fusion) from D36396172 (34f17b2)

(Todo) In addition to the above two classes, put the following class in fb/benchmarks:
(3) DenseTableBatched EBC (ref: D36394471)

Reviewed By: YLGH

Differential Revision: D36694744

fbshipit-source-id: 9950127727b2df8363b4be6fcdf6a168b962be27
1 parent beca256 commit 20efdb3

File tree

3 files changed: +497 -0 lines changed


benchmarks/README.md

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
# TorchRec Benchmarks for `EmbeddingBag`

We evaluate the performance of two EmbeddingBagCollection modules:

1. `EmbeddingBagCollection` (EBC) ([code](https://github.com/pytorch/torchrec/blob/main/torchrec/modules/embedding_modules.py#L67)): a simple module backed by [torch.nn.EmbeddingBag](https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html).

2. `FusedEmbeddingBagCollection` (Fused EBC) ([code](https://github.com/pytorch/torchrec/blob/main/torchrec/modules/fused_embedding_bag_collection.py#L299)): a module backed by [FBGEMM](https://github.com/pytorch/FBGEMM) kernels, which enable more efficient, high-performance operations on embedding tables. It comes with a fused optimizer and with UVM caching/management, which makes much larger memory available to GPUs.

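For reference, here is a minimal sketch of how the two modules are constructed, mirroring how `ebc_benchmarks.py` in this commit builds them; the two table sizes used here are arbitrary placeholders rather than the benchmark settings:

```
import torch

from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.fused_embedding_modules import FusedEmbeddingBagCollection

# Two small illustrative tables; the benchmark uses the DLRM table sizes listed below.
tables = [
    EmbeddingBagConfig(
        name=f"ebc_{idx}",
        embedding_dim=128,
        num_embeddings=num_embeddings,
        feature_names=[f"ebc_{idx}_feat_1"],
    )
    for idx, num_embeddings in enumerate([1_000, 10_000])
]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# EBC: plain nn.EmbeddingBag-backed collection; the optimizer is applied externally.
ebc = EmbeddingBagCollection(tables=tables, device=device)
optimizer = torch.optim.SGD(ebc.parameters(), lr=0.02)

# Fused EBC: FBGEMM-backed collection with SGD fused into the module itself.
fused_ebc = FusedEmbeddingBagCollection(
    tables=tables,
    optimizer_type=torch.optim.SGD,
    optimizer_kwargs={"lr": 0.02},
    device=device,
)
```
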
## Module architecture and running setup

We chose the embedding tables (sparse arch) of the MLPerf [DLRM](https://github.com/facebookresearch/dlrm/tree/main/torchrec_dlrm) model to compare the performance difference between EBC and Fused EBC. Below are the settings for the embedding tables:

```
num_embeddings_per_feature = [45833188, 36746, 17245, 7413, 20243, 3, 7114, 1441, 62, 29275261, 1572176, 345138, 10, 2209, 11267, 128, 4, 974, 14, 48937457, 11316796, 40094537, 452104, 12606, 104, 35]
embedding_dim_size = 128
```

Other setup includes:

- Optimizer: Stochastic Gradient Descent (SGD)
- Dataset: Random dataset ([code](https://github.com/pytorch/torchrec/blob/main/torchrec/datasets/random.py))
- CUDA 11.7, NCCL 2.11.4
- AWS EC2 instance with 8 x 16 GB NVIDIA Tesla V100 GPUs

## How to run

After installing TorchRec (see "Binary" in the "Installation" section, [link](https://github.com/pytorch/torchrec)), run the following command under the benchmark directory (/torchrec/torchrec/benchmarks):

```
python ebc_benchmarks.py [--mode MODE] [--cpu_only]
```

where `MODE` can be `ebc_comparison_dlrm` (default), `fused_ebc_uvm`, or `ebc_comparison_scaling` to run the different comparisons described below.

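For example, the following invocations cover the three modes (`--cpu_only` is optional and shown here only to illustrate the flag):

```
python ebc_benchmarks.py --mode ebc_comparison_dlrm
python ebc_benchmarks.py --mode fused_ebc_uvm
python ebc_benchmarks.py --mode ebc_comparison_scaling --cpu_only
```
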
## Results

### Methodology

For ease of reading, we abbreviate "DLRM embedding tables" as "DLRM EMB" below. Since one GPU cannot accommodate the full-sized DLRM tables, we reduce the `num_embeddings` (number of rows) of the 5 largest tables to some degree (see the "Note" column in the following tables for the reduction degree). For the metrics, we use the average training time over 100 epochs to represent the performance of each module. `speedup` (defined as `training time using EBC` divided by `training time using Fused EBC`) is also computed to quantify the improvement from EBC to Fused EBC.

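The reduction only affects tables with at least 10 million rows, which are exactly the 5 largest tables in the DLRM list; this follows `get_shrunk_dlrm_num_embeddings` in `ebc_benchmarks.py`. A small illustrative sketch of the rule (the names `LARGE_DLRM_TABLES` and `shrink` are ours, not part of the benchmark code):

```
# The 5 DLRM tables with >= 10M rows; all other tables keep their original size.
LARGE_DLRM_TABLES = [45833188, 29275261, 48937457, 11316796, 40094537]

def shrink(num_emb: int, reduction_degree: int) -> int:
    # Same rule as get_shrunk_dlrm_num_embeddings in ebc_benchmarks.py.
    return num_emb if num_emb < 10_000_000 else int(num_emb / reduction_degree)

for reduction_degree in (128, 64, 32):
    print(reduction_degree, [shrink(n, reduction_degree) for n in LARGE_DLRM_TABLES])
```
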
### 1. Comparison between EBC and Fused EBC on DLRM EMB (`ebc_comparison_dlrm`)

We see that Fused EBC trains much faster than EBC: the speedup from EBC to Fused EBC is 13X, 18X, and 23X when the 5 largest DLRM EMB tables are reduced by 128 times, 64 times, and 32 times, respectively.

| Module | Time to train one epoch | Note |
| ------ | ----------------------- | ---- |
| EBC | 0.267 (+/- 0.002) second | DLRM EMB with sizes of the 5 largest tables reduced by 128 times |
| EBC | 0.332 (+/- 0.002) second | DLRM EMB with sizes of the 5 largest tables reduced by 64 times |
| EBC | 0.462 (+/- 0.002) second | DLRM EMB with sizes of the 5 largest tables reduced by 32 times |
| Fused EBC | 0.019 (+/- 0.001) second | DLRM EMB with sizes of the 5 largest tables reduced by 128 times |
| Fused EBC | 0.019 (+/- 0.001) second | DLRM EMB with sizes of the 5 largest tables reduced by 64 times |
| Fused EBC | 0.019 (+/- 0.009) second | DLRM EMB with sizes of the 5 largest tables reduced by 32 times |

### 2. Full-sized DLRM EMB with UVM/UVM-caching on Fused EBC (`fused_ebc_uvm`)

Here, we demonstrate the advantage of UVM/UVM-caching with Fused EBC. With UVM caching enabled, we can put larger DLRM EMB tables into Fused EBC without a significant sacrifice in efficiency. With UVM enabled, we can allocate the full-sized DLRM EMB in UVM, with expectedly slower training performance because of the extra sync points between host and GPU (see [this example](https://github.com/pytorch/torchrec/blob/main/examples/sharding/uvm.ipynb) for more explanation/usage of UVM).

| Module | Time to train one epoch | Note |
| ------ | ----------------------- | ---- |
| Fused EBC with UVM caching | 0.06 (+/- 0.37) second | DLRM EMB with sizes of the 5 largest tables reduced by 2 times |
| Fused EBC with UVM | 0.62 (+/- 5.34) second | full-sized DLRM EMB |

The above performance comparison is also shown as a bar chart for better visualization.

![EBC_benchmarks_dlrm_emb](EBC_benchmarks_dlrm_emb.png)

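Both configurations use the same `FusedEmbeddingBagCollection`; only the `location` argument changes. A minimal sketch mirroring `get_fused_ebc_uvm_time` in `ebc_benchmarks.py` (`build_fused_ebc` is our illustrative wrapper, and `tables` is assumed to be the list of DLRM `EmbeddingBagConfig`s built as above):

```
import torch
from fbgemm_gpu.split_table_batched_embeddings_ops import EmbeddingLocation
from torchrec.modules.fused_embedding_modules import FusedEmbeddingBagCollection

def build_fused_ebc(tables, device, location):
    # EmbeddingLocation.MANAGED_CACHING -> UVM caching (first row above);
    # EmbeddingLocation.MANAGED -> plain UVM (second row above).
    return FusedEmbeddingBagCollection(
        tables=tables,
        optimizer_type=torch.optim.SGD,
        optimizer_kwargs={"lr": 0.02},
        device=device,
        location=location,
    )

# fused_ebc_cached = build_fused_ebc(tables, torch.device("cuda"), EmbeddingLocation.MANAGED_CACHING)
# fused_ebc_uvm = build_fused_ebc(tables, torch.device("cuda"), EmbeddingLocation.MANAGED)
```
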
### 3. Comparison between EBC and Fused EBC on different-sized embedding tables (`ebc_comparison_scaling`)

Here, we study how scaling the embedding tables affects the performance difference between EBC and Fused EBC. To do so, we vary three parameters, `num_tables`, `embedding_dim`, and `num_embeddings`, and report the `speedup` from EBC to Fused EBC in the tables below (rows are `num_embeddings`, columns are `embedding_dim`). In each table, we observe that `embedding_dim` and `num_embeddings` do not have a significant effect on speedup. However, as `num_tables` increases, the improvement from EBC to Fused EBC grows (speedup increases), suggesting that Fused EBC is especially beneficial when dealing with many embedding tables. The parameter sweep is sketched right after this paragraph.

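The sweep mirrors the `ebc_comparison_scaling` branch of `ebc_benchmarks.py` in this commit; `EmbeddingBagConfig`, `get_ebc_comparison`, and `device` are the names used in that script and are assumed to be in scope:

```
num_tables_list = [10, 100, 1000]
embedding_dim_list = [4, 8, 16, 32, 64, 128]
num_embeddings_list = [4, 8, 16, 32, 64, 128, 256, 1024, 2048, 4096, 8192]

for num_tables in num_tables_list:
    for num_embeddings in num_embeddings_list:
        for embedding_dim in embedding_dim_list:
            tables = [
                EmbeddingBagConfig(
                    name=f"ebc_{idx}",
                    embedding_dim=embedding_dim,
                    num_embeddings=num_embeddings,
                    feature_names=[f"ebc_{idx}_feat_1"],
                )
                for idx in range(num_tables)
            ]
            # get_ebc_comparison trains both EBC and Fused EBC on a random dataset
            # (3 epochs each) and returns the averaged times plus the speedup.
            _, _, _, _, speedup = get_ebc_comparison(tables, device, epochs=3)
```
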
- `num_tables` = 10

| `num_embeddings` (rows) / `embedding_dim` (columns) | **4** | **8** | **16** | **32** | **64** | **128** |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| **4** | 2.87 | 2.79 | 2.79 | 2.79 | 2.76 | 2.8 |
| **8** | 2.71 | 3.11 | 2.97 | 3.02 | 2.99 | 2.95 |
| **16** | 2.98 | 2.97 | 2.98 | 2.97 | 3.0 | 3.05 |
| **32** | 3.01 | 2.95 | 2.99 | 2.98 | 2.98 | 3.01 |
| **64** | 3.0 | 3.02 | 3.0 | 2.97 | 2.96 | 2.97 |
| **128** | 3.03 | 2.96 | 3.02 | 3.0 | 3.02 | 3.05 |
| **256** | 3.01 | 2.95 | 3.0 | 3.03 | 3.05 | 3.02 |
| **1024** | 3.0 | 3.05 | 3.05 | 3.08 | 5.89 | 3.07 |
| **2048** | 2.99 | 3.03 | 3.0 | 3.05 | 3.0 | 3.06 |
| **4096** | 3.0 | 3.03 | 3.05 | 3.02 | 3.07 | 3.05 |
| **8192** | 3.0 | 3.08 | 3.04 | 3.02 | 3.09 | 3.1 |

- `num_tables` = 100

| `num_embeddings` (rows) / `embedding_dim` (columns) | **4** | **8** | **16** | **32** | **64** | **128** |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| **4** | 10.33 | 10.36 | 10.26 | 10.24 | 10.28 | 10.24 |
| **8** | 10.34 | 10.47 | 10.29 | 10.25 | 10.23 | 10.19 |
| **16** | 10.18 | 10.36 | 10.2 | 10.28 | 10.25 | 10.26 |
| **32** | 10.41 | 10.2 | 10.19 | 10.2 | 10.04 | 9.89 |
| **64** | 9.93 | 9.9 | 9.73 | 9.89 | 10.17 | 10.16 |
| **128** | 10.32 | 10.11 | 10.12 | 10.08 | 10.01 | 10.05 |
| **256** | 10.57 | 8.39 | 10.36 | 10.21 | 10.14 | 10.43 |
| **1024** | 10.39 | 9.67 | 8.46 | 10.23 | 10.29 | 10.11 |
| **2048** | 10.0 | 9.74 | 10.0 | 9.67 | 10.08 | 11.87 |
| **4096** | 9.94 | 9.82 | 10.17 | 9.66 | 9.87 | 9.95 |
| **8192** | 9.81 | 10.23 | 10.12 | 10.18 | 10.36 | 9.57 |

- `num_tables` = 1000

| `num_embeddings` (rows) / `embedding_dim` (columns) | **4** | **8** | **16** | **32** | **64** | **128** |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| **4** | 13.81 | 13.56 | 13.33 | 13.33 | 13.24 | 12.86 |
| **8** | 13.44 | 13.4 | 13.39 | 13.41 | 13.39 | 13.09 |
| **16** | 12.55 | 12.88 | 13.22 | 13.19 | 13.27 | 12.95 |
| **32** | 13.17 | 12.84 | 12.8 | 12.78 | 13.13 | 13.07 |
| **64** | 13.06 | 12.84 | 12.84 | 12.9 | 12.83 | 12.89 |
| **128** | 13.14 | 13.04 | 13.16 | 13.21 | 13.08 | 12.91 |
| **256** | 13.71 | 13.59 | 13.76 | 13.24 | 13.36 | 13.59 |
| **1024** | 13.24 | 13.29 | 13.56 | 13.64 | 13.68 | 13.79 |
| **2048** | 13.2 | 13.19 | 13.35 | 12.44 | 13.32 | 13.17 |
| **4096** | 12.96 | 13.24 | 12.51 | 12.99 | 12.47 | 12.34 |
| **8192** | 12.84 | 13.32 | 13.27 | 13.06 | 12.35 | 12.58 |

benchmarks/ebc_benchmarks.py

Lines changed: 255 additions & 0 deletions
@@ -0,0 +1,255 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import sys
from typing import List, Tuple

import torch
from fbgemm_gpu.split_table_batched_embeddings_ops import EmbeddingLocation
from torchrec.github.benchmarks import ebc_benchmarks_utils
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.fused_embedding_modules import FusedEmbeddingBagCollection

# Reference: https://github.com/facebookresearch/dlrm/blob/main/torchrec_dlrm/README.MD
DLRM_NUM_EMBEDDINGS_PER_FEATURE = [
    45833188,
    36746,
    17245,
    7413,
    20243,
    3,
    7114,
    1441,
    62,
    29275261,
    1572176,
    345138,
    10,
    2209,
    11267,
    128,
    4,
    974,
    14,
    48937457,
    11316796,
    40094537,
    452104,
    12606,
    104,
    35,
]


def get_shrunk_dlrm_num_embeddings(reduction_degree: int) -> List[int]:
    return [
        num_emb if num_emb < 10000000 else int(num_emb / reduction_degree)
        for num_emb in DLRM_NUM_EMBEDDINGS_PER_FEATURE
    ]


def main(argv: List[str]) -> None:
    args = parse_args(argv)

    if not args.cpu_only and torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    if args.mode == "ebc_comparison_dlrm":
        print("Running EBC vs. FusedEBC on DLRM EMB")

        for reduction_degree in [128, 64, 32]:
            embedding_bag_configs: List[EmbeddingBagConfig] = [
                EmbeddingBagConfig(
                    name=f"ebc_{idx}",
                    embedding_dim=128,
                    num_embeddings=num_embeddings,
                    feature_names=[f"ebc_{idx}_feat_1"],
                )
                for idx, num_embeddings in enumerate(
                    get_shrunk_dlrm_num_embeddings(reduction_degree)
                )
            ]
            (
                ebc_time_avg,
                ebc_time_std,
                fused_ebc_time_avg,
                fused_ebc_time_std,
                speedup,
            ) = get_ebc_comparison(embedding_bag_configs, device)

            print(f"when DLRM EMB is reduced by {reduction_degree} times:")
            print(f"ebc_time = {ebc_time_avg} +/- {ebc_time_std} sec")
            print(f"fused_ebc_time = {fused_ebc_time_avg} +/- {fused_ebc_time_std} sec")
            print(f"speedup = {speedup}")

    elif args.mode == "fused_ebc_uvm":
        print("Running DLRM EMB on FusedEBC with UVM/UVM-caching")
        embedding_bag_configs: List[EmbeddingBagConfig] = [
            EmbeddingBagConfig(
                name=f"ebc_{idx}",
                embedding_dim=128,
                num_embeddings=num_embeddings,
                feature_names=[f"ebc_{idx}_feat_1"],
            )
            for idx, num_embeddings in enumerate(get_shrunk_dlrm_num_embeddings(2))
        ]
        fused_ebc_time_avg, fused_ebc_time_std = get_fused_ebc_uvm_time(
            embedding_bag_configs, device, EmbeddingLocation.MANAGED_CACHING
        )
        print(
            f"FusedEBC with UVM caching on DLRM: {fused_ebc_time_avg} +/- {fused_ebc_time_std} sec"
        )

        embedding_bag_configs: List[EmbeddingBagConfig] = [
            EmbeddingBagConfig(
                name=f"ebc_{idx}",
                embedding_dim=128,
                num_embeddings=num_embeddings,
                feature_names=[f"ebc_{idx}_feat_1"],
            )
            for idx, num_embeddings in enumerate(DLRM_NUM_EMBEDDINGS_PER_FEATURE)
        ]
        fused_ebc_time_avg, fused_ebc_time_std = get_fused_ebc_uvm_time(
            embedding_bag_configs, device, EmbeddingLocation.MANAGED
        )
        print(
            f"FusedEBC with UVM management on DLRM: {fused_ebc_time_avg} plus/minus {fused_ebc_time_std} sec"
        )

    elif args.mode == "ebc_comparison_scaling":
        print("Running EBC vs. FusedEBC scaling experiment")

        num_tables_list = [10, 100, 1000]
        embedding_dim_list = [4, 8, 16, 32, 64, 128]
        num_embeddings_list = [4, 8, 16, 32, 64, 128, 256, 1024, 2048, 4096, 8192]

        for num_tables in num_tables_list:
            for num_embeddings in num_embeddings_list:
                for embedding_dim in embedding_dim_list:
                    embedding_bag_configs: List[EmbeddingBagConfig] = [
                        EmbeddingBagConfig(
                            name=f"ebc_{idx}",
                            embedding_dim=embedding_dim,
                            num_embeddings=num_embeddings,
                            feature_names=[f"ebc_{idx}_feat_1"],
                        )
                        for idx in range(num_tables)
                    ]
                    ebc_time, _, fused_ebc_time, _, speedup = get_ebc_comparison(
                        embedding_bag_configs, device, epochs=3
                    )
                    print(
                        f"EBC num_tables = {num_tables}, num_embeddings = {num_embeddings}, embedding_dim = {embedding_dim}:"
                    )
                    print(
                        f"ebc_time = {ebc_time} sec, fused_ebc_time = {fused_ebc_time} sec, speedup = {speedup}"
                    )


def get_fused_ebc_uvm_time(
    embedding_bag_configs: List[EmbeddingBagConfig],
    device: torch.device,
    location: EmbeddingLocation,
    epochs: int = 100,
) -> Tuple[float, float]:

    fused_ebc = FusedEmbeddingBagCollection(
        tables=embedding_bag_configs,
        optimizer_type=torch.optim.SGD,
        optimizer_kwargs={"lr": 0.02},
        device=device,
        location=location,
    )

    dataset = ebc_benchmarks_utils.get_random_dataset(
        batch_size=64,
        num_batches=10,
        num_dense_features=1024,
        embedding_bag_configs=embedding_bag_configs,
    )

    fused_ebc_time_avg, fused_ebc_time_std = ebc_benchmarks_utils.train(
        model=fused_ebc,
        optimizer=None,
        dataset=dataset,
        device=device,
        epochs=epochs,
    )

    return fused_ebc_time_avg, fused_ebc_time_std


def get_ebc_comparison(
    embedding_bag_configs: List[EmbeddingBagConfig],
    device: torch.device,
    epochs: int = 100,
) -> Tuple[float, float, float, float, float]:

    # Simple EBC module wrapping a list of nn.EmbeddingBag
    ebc = EmbeddingBagCollection(
        tables=embedding_bag_configs,
        device=device,
    )
    optimizer = torch.optim.SGD(ebc.parameters(), lr=0.02)

    # EBC with fused optimizer backed by fbgemm SplitTableBatchedEmbeddingBagsCodegen
    fused_ebc = FusedEmbeddingBagCollection(
        tables=embedding_bag_configs,
        optimizer_type=torch.optim.SGD,
        optimizer_kwargs={"lr": 0.02},
        device=device,
    )

    dataset = ebc_benchmarks_utils.get_random_dataset(
        batch_size=64,
        num_batches=10,
        num_dense_features=1024,
        embedding_bag_configs=embedding_bag_configs,
    )

    ebc_time_avg, ebc_time_std = ebc_benchmarks_utils.train(
        model=ebc,
        optimizer=optimizer,
        dataset=dataset,
        device=device,
        epochs=epochs,
    )
    fused_ebc_time_avg, fused_ebc_time_std = ebc_benchmarks_utils.train(
        model=fused_ebc,
        optimizer=None,
        dataset=dataset,
        device=device,
        epochs=epochs,
    )
    speedup = ebc_time_avg / fused_ebc_time_avg

    return ebc_time_avg, ebc_time_std, fused_ebc_time_avg, fused_ebc_time_std, speedup


def parse_args(argv: List[str]) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="TorchRec ebc benchmarks")
    parser.add_argument(
        "--cpu_only",
        action="store_true",
        default=False,
        help="specify whether to use cpu",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="ebc_comparison_dlrm",
        help="specify 'ebc_comparison_dlrm', 'ebc_comparison_scaling' or 'fused_ebc_uvm'",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    main(sys.argv[1:])

0 commit comments
