[pt] Add half precision support for nn.EmbeddingBag (CPU) (#74844)
Summary:
Pull Request resolved: #74844

- Use the FBGEMM/perf kernel implementation for the fast path.
- Use FP32 accumulation for FP16 weight embeddings (`index_select_add` and `index_select_scale_add`).
- Add unit test coverage.
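
To illustrate why FP32 accumulation matters for FP16 weights, here is a minimal pure-Python sketch. It is not the kernel code from this change: `round_fp16`, `round_fp32`, and `bag_sum` are hypothetical helpers that emulate rounding with the standard-library `struct` module, and the table and bag sizes are made up for the demonstration.

```python
import struct


def round_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def round_fp32(x: float) -> float:
    """Round a Python float to the nearest IEEE single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def bag_sum(weights, indices, scales=None, acc_round=round_fp32):
    """Sum the FP16 rows selected by `indices` into one pooled row.

    `scales` (optional) multiplies each selected row, loosely mirroring a
    scale-add variant. `acc_round` models the accumulator precision; the
    pooled row is rounded back to FP16 at the end.
    """
    dim = len(weights[0])
    acc = [0.0] * dim
    for pos, idx in enumerate(indices):
        scale = 1.0 if scales is None else scales[pos]
        for d in range(dim):
            term = acc_round(scale * weights[idx][d])
            acc[d] = acc_round(acc[d] + term)
    return [round_fp16(a) for a in acc]


# Hypothetical one-row table of FP16 values, pooled 4096 times into one bag.
table = [[round_fp16(0.1), round_fp16(0.25)]]
indices = [0] * 4096

fp32_acc = bag_sum(table, indices, acc_round=round_fp32)
fp16_acc = bag_sum(table, indices, acc_round=round_fp16)
scaled = bag_sum(table, [0, 0], scales=[2.0, 0.5])

print("FP32 accumulation:", fp32_acc)
print("FP16 accumulation:", fp16_acc)
print("Scaled bag:", scaled)
```

With an FP16 accumulator, the running sum stops growing once each addend is smaller than half the spacing between adjacent FP16 values at the accumulator's magnitude, so large bags are undercounted. Accumulating in FP32 and rounding to FP16 only once at the end avoids that stall in this sketch.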

Test Plan:
```
buck run mode/opt //ai_codesign/nonprod/jianyuhuang/pytorch_examples:eb
Parsing buck files: finished in 0.6 sec
Downloaded 0/2 artifacts, 0.00 bytes, 100.0% cache miss (for updated rules)
Building: finished in 01:52.1 min (100%) 12247/12247 jobs, 2/12247 updated
  Total time: 01:52.8 min
BUILD SUCCEEDED
tensor([[ 0.1282, -0.0244,  1.0996],
        [-1.2285, -0.8643,  2.6621]], dtype=torch.float16,
       grad_fn=<EmbeddingBagBackward0>)
tensor([[[-0.1643,  0.1266, -0.4851],
         [ 0.0710,  0.5024,  0.2798],
         [ 0.4797,  0.5991, -0.0188],
         [ 0.8843,  1.2090,  1.6494]],

        [[ 0.4797,  0.5991, -0.0188],
         [ 0.0662, -0.4121,  1.5752],
         [ 0.0710,  0.5024,  0.2798],
         [-0.8242,  0.2668, -0.6177]]], dtype=torch.float16,
       grad_fn=<EmbeddingBackward0>)
```

```
$ buck run mode/opt //caffe2/test:nn -- -r test_embedding_bag_half 2>&1 | tee output.log
PARSING BUCK FILES: FINISHED IN 0.8s
CREATING ACTION GRAPH: FINISHED IN 0.0s
test_embedding_bag_half_cpu_int32_int32 (test_nn.TestNNDeviceTypeCPU) ... ok
test_embedding_bag_half_cpu_int32_int64 (test_nn.TestNNDeviceTypeCPU) ... ok
test_embedding_bag_half_cpu_int64_int32 (test_nn.TestNNDeviceTypeCPU) ... ok
test_embedding_bag_half_cpu_int64_int64 (test_nn.TestNNDeviceTypeCPU) ... ok
test_embedding_bag_half_cuda_int32_int32 (test_nn.TestNNDeviceTypeCUDA) ... ok
test_embedding_bag_half_cuda_int32_int64 (test_nn.TestNNDeviceTypeCUDA) ... ok
test_embedding_bag_half_cuda_int64_int32 (test_nn.TestNNDeviceTypeCUDA) ... ok
test_embedding_bag_half_cuda_int64_int64 (test_nn.TestNNDeviceTypeCUDA) ... ok

----------------------------------------------------------------------
Ran 8 tests in 44.621s

OK
```

```
TORCH_SHOW_CPP_STACKTRACES=1  buck run mode/opt //caffe2/test:nn -- -r test_EmbeddingBag_per_sample_weights_and_new_offsets 2>&1 | tee output.log
```

Reviewed By: jasonjk-park

Differential Revision: D35190299

fbshipit-source-id: d1daa6e837660259b92a1f316b09f38e509ee077
(cherry picked from commit 86f575f)
jianyuh authored and pytorchmergebot committed Apr 6, 2022
1 parent 57ba615 commit b8a4708
Showing 3 changed files with 388 additions and 70 deletions.