Skip to content

Commit

Permalink
Fix the int4 table batched embedding benchmark with mixed dim (#609)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #609

Fix the bench with "--mixed" dimension as reported by xcliang in https://www.internalfb.com/diff/D28248236 (0fe80ee014b936733278a77d0a24c9fe9a431c31)?dst_version_fbid=834500627143449&transaction_fbid=331571331658562

Differential Revision: D28466825

fbshipit-source-id: ac4725f37d89a3ecd2bcccd564b0424173f81230
  • Loading branch information
jianyuh authored and facebook-github-bot committed May 16, 2021
1 parent 0fe80ee commit c48ba9b
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,8 @@ def cpu( # noqa C901
feature_requires_grad = None
if mixed:
Ds = [
div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
# int4 table batched emb op can only handle mixed D where D is multiple of 8
div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 8)
for _ in range(T)
]
D = np.average(Ds)
Expand Down Expand Up @@ -905,8 +906,9 @@ def int4_device( # noqa C901
else:
feature_requires_grad = None
if mixed:
# int4 table batched emb op can only handle mixed D where D is multiple of 8
Ds = [
div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 4)
div_round_up(np.random.randint(low=int(0.5 * D), high=int(1.5 * D)), 8)
for _ in range(T)
]
D = np.average(Ds)
Expand Down

0 comments on commit c48ba9b

Please sign in to comment.