forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
embedding_lookup_idx.h
57 lines (53 loc) · 1.63 KB
/
embedding_lookup_idx.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#pragma once
#include <cstdint>
namespace caffe2 {
// clang-format off
/**
* Embedding lookup with reduction.
*
* `input` of size data_size * block_size
* `indices` of size index_size
 * `offsets` of size output_size + 1 (the pseudocode below reads offsets[i+1]
 *   for every output row, so the last entry must equal index_size)
* `weights` nullptr or array of size index_size
* `out` of size output_size * block_size
*
* Behavior is roughly equivalent to pseudocode:
*
* pos = 0
* for (i = 0..output_size-1)
* for (k = 0..block_size-1)
* out[i*block_size + k] = 0
* start_offset = offsets[i]
* end_offset = offsets[i+1]
* length = end_offset - start_offset
* for (j = start_offset..end_offset-1)
* for (k = 0..block_size-1)
* out[i*block_size + k] += input[indices[pos]*block_size + k] *
* (weights ? weights[IS_WEIGHT_POSITIONAL ? j - start_offset : pos] : 1.0)
* pos += 1
* if (normalize_weights && length > 0)
* for (k = 0..block_size-1)
* out[i*block_size + k] /= length
*
 * NOTE: unlike the lengths-based EmbeddingLookup variant, this API already
 * takes "offsets", matching the API of PyTorch's EmbeddingBag.
*/
// clang-format on
// Declaration of the offset-based embedding-lookup kernel described in the
// comment above. Declaration only: concrete implementations/instantiations
// live elsewhere (presumably generated/vectorized kernels — see the
// corresponding .cc in perfkernels; TODO confirm).
template <
    typename IndexType, // integer type of `indices` and `offsets`
    typename InType,    // element type of `input`; uint8 inputs also use `scale_bias`
    typename OutType,   // element type of `out`
    // When true, `weights` is indexed by position within each bag
    // (j - start_offset) instead of by the global index position `pos`.
    bool IS_WEIGHT_POSITIONAL = false>
void EmbeddingLookupIdx(
    const std::int64_t block_size,  // embedding width: columns per row of `input`
    const std::int64_t output_size, // number of pooled output rows
    const std::int64_t index_size,  // number of entries in `indices`
    const std::int64_t data_size,   // number of rows in `input`
    const InType* input,            // data_size * block_size embedding table
    const IndexType* indices,       // row indices into `input`
    const IndexType* offsets,       // bag boundaries: offsets[i]..offsets[i+1]
    const float* weights, // optional, can be null for non-weighted sum
    const float* scale_bias, // optional scale & bias params for uint8 input
    bool normalize_by_lengths, // divide each output row by its bag length
    OutType* out); // output_size * block_size result buffer
} // namespace caffe2