-
Notifications
You must be signed in to change notification settings - Fork 0
/
HashedEmbeddingBag_test.py
306 lines (220 loc) · 9.76 KB
/
HashedEmbeddingBag_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
import torch
import HashedEmbeddingBag
import hashed_embedding_bag
def make_offset2bag(offsets, indices):
    """Map every position in ``indices`` to the index of the bag containing it.

    Pure-PyTorch reference for the ``offset2bag`` tensor produced by the
    hashed embedding-bag kernel.

    Args:
        offsets: 1-D integer tensor of bag start positions into ``indices``.
            Assumed non-decreasing with ``offsets[0] == 0`` (not checked);
            repeated values denote empty bags.
        indices: 1-D tensor of flattened bag contents. Only its length and
            dtype are used.

    Returns:
        Tensor of length ``indices.size(0)`` with dtype ``indices.dtype`` on
        ``offsets.device``. Element ``i`` is the bag index of ``indices[i]``.
    """
    num_indices = indices.size(0)
    # One spare slot so offsets equal to num_indices (trailing empty bags)
    # have somewhere to land; it is dropped before returning.
    offsets2bag = torch.zeros(num_indices + 1, dtype=indices.dtype, device=offsets.device)
    # Mark every bag start. index_add_ accumulates, so several bags starting
    # at the same position (empty bags) each add one.
    offsets2bag.index_add_(0, offsets, torch.ones_like(offsets))
    # After the cumsum, slot i counts the bags started at or before i;
    # subtracting 1 up front turns that count into a zero-based bag index.
    offsets2bag[0] -= 1
    return offsets2bag.cumsum(0)[:num_indices]
def test_hashedEmbeddingBag():
    """Check the raw ``hashed_embedding_bag`` forward/backward ops in 'sum' mode.

    Random bags (possibly empty) run through the CUDA ops and are compared
    against a pure-Python reference. The reference is built from
    ``hashed_embedding_bag.hash`` and ``make_offset2bag``. Requires a CUDA
    device.
    """
    # the 'sum' mode
    mode = 0
    bag_num = 18
    num_categories = 100
    num_feature = 200
    hashed_weight_size = 200

    # generate random weight and input for testing
    hashed_weights = torch.rand(hashed_weight_size)
    # bag sizes are drawn from [0, 7), so empty bags are exercised too
    bag_size = torch.randint(low=0, high=7, size=(bag_num,))
    indices_num = bag_size.sum().item()
    # randint's `high` is exclusive; num_categories lets the last id appear
    indices = torch.randint(low=0, high=num_categories, size=(indices_num,))
    # bag b starts where bags 0..b-1 end
    offsets = torch.cat([torch.zeros(1, dtype=torch.long), bag_size.cumsum(dim=0)[:-1]])

    # run forward function on GPU; the CPU tensors above feed the reference
    gpu = torch.cuda.current_device()
    indices_gpu = indices.to(gpu)
    offsets_gpu = offsets.to(gpu)
    output, offset2bag, bag_size_out, max_indices, hashed_idx = hashed_embedding_bag.forward(
        hashed_weights.to(gpu), indices_gpu, offsets_gpu, mode, num_feature)

    # generate expected output by python (on CPU)
    num_weights = hashed_weights.size(0)
    expected_offsets2bag = make_offset2bag(offsets, indices)
    expected_hashed_index = torch.zeros((indices_num, num_feature), dtype=torch.long)
    expected_output = torch.zeros(bag_num, num_feature)
    for i in range(indices_num):
        category = indices[i].item()
        bag = expected_offsets2bag[i].item()
        for j in range(num_feature):
            weight_idx = hashed_embedding_bag.hash(category, j) % num_weights
            expected_hashed_index[i, j] = weight_idx
            expected_output[bag, j] += hashed_weights[weight_idx]

    # assert forward results are correct
    assert (expected_offsets2bag - offset2bag.cpu()).abs().sum().item() == 0
    assert expected_hashed_index.equal(hashed_idx.cpu())
    # bag sums are floating point; tolerate a different accumulation order
    assert torch.allclose(expected_output, output.cpu())

    # the gradient of output, which is the input for backward.
    output_grad = torch.rand_like(expected_output)
    # generate gradient for weight in python, reusing the hashed indices above
    expected_weight_grad = torch.zeros_like(hashed_weights)
    for i in range(indices_num):
        bag = expected_offsets2bag[i].item()
        for j in range(num_feature):
            expected_weight_grad[expected_hashed_index[i, j].item()] += output_grad[bag, j]

    # run backward on GPU; .to(gpu) is a no-op for tensors already there
    weight_grad = hashed_embedding_bag.backward(
        output_grad.to(gpu), indices_gpu, offsets_gpu, offset2bag.to(gpu), bag_size_out.to(gpu),
        max_indices.to(gpu), hashed_idx.to(gpu), num_weights, False, mode, num_feature)
    weight_grad = weight_grad.cpu()
    # Compare element-wise rather than via a signed sum of differences, which
    # lets errors of opposite sign cancel and any negative error pass.
    assert torch.allclose(weight_grad, expected_weight_grad, rtol=1e-4, atol=1e-3)
def test_hashedEmbeddingBag_single():
    """Check ``HashedEmbeddingBag`` forward/backward with one index per bag.

    ``indices`` has shape (bag_num, 1), and each row is treated as one bag.
    The module output and the gradient of its ``hashed_weight`` are compared
    against a pure-Python reference built from ``hashed_embedding_bag.hash``.
    Requires a CUDA device.
    """
    bag_num = 180
    num_categories = 100
    num_feature = 200
    hashed_weight_size = 200

    # generate random weight and input for testing
    hashed_weights = torch.rand(hashed_weight_size, requires_grad=True)
    embedding = HashedEmbeddingBag.HashedEmbeddingBag(
        num_categories, num_feature, compression=0.1, _weight=hashed_weights)
    embedding = embedding.cuda()
    indices_num = bag_num
    # randint's `high` is exclusive; num_categories lets the last id appear
    indices = torch.randint(low=0, high=num_categories, size=(indices_num, 1))

    output = embedding.forward(indices.to(torch.cuda.current_device()))
    # give a 'weight' to different locations in output, so that the elements
    # of output_grad differ from each other.
    x = torch.rand_like(output)
    loss = (output * x).sum()
    loss.backward()

    # Build the reference on CPU from a detached copy. Accumulating elements of
    # a requires_grad tensor into expected_output would record one autograd
    # node per (row, feature) pair.
    weights = hashed_weights.detach()
    num_weights = weights.size(0)
    output_grad = x.cpu()
    expected_output = torch.zeros(bag_num, num_feature)
    expected_weight_grad = torch.zeros_like(weights)
    for i in range(indices_num):
        category = indices[i].item()
        for j in range(num_feature):
            weight_idx = hashed_embedding_bag.hash(category, j) % num_weights
            expected_output[i, j] += weights[weight_idx]
            expected_weight_grad[weight_idx] += output_grad[i, j]

    # assert forward results are correct (each bag holds one element, so an
    # exact match is expected)
    assert expected_output.equal(output.detach().cpu())
    # Compare gradients element-wise rather than via a signed sum of
    # differences, which lets errors of opposite sign cancel.
    weight_grad = embedding.hashed_weight.grad.detach().cpu()
    assert torch.allclose(weight_grad, expected_weight_grad, rtol=1e-4, atol=1e-3)
def test_HashedEmbeddingBagAPI_mean():
    """Smoke-test ``HashedEmbeddingBag`` in 'mean' mode with (indices, offsets) input.

    Runs forward and backward on CUDA and only checks that neither raises.
    Bag sizes are drawn from [0, 3), so empty bags may occur.
    """
    bag_num = 18
    num_categories = 10
    num_feature = 5
    hashed_weight_size = 200
    # generate random weight and input for testing
    hashed_weights = torch.rand(hashed_weight_size)
    embedding = HashedEmbeddingBag.HashedEmbeddingBag(
        num_categories, num_feature, compression=0.1, mode="mean", _weight=hashed_weights)
    embedding = embedding.cuda()
    bag_size = torch.randint(low=0, high=3, size=(bag_num,))
    indices_num = bag_size.sum().item()
    # randint's `high` is exclusive; num_categories lets the last id appear
    indices = torch.randint(low=0, high=num_categories, size=(indices_num,))
    # bag b starts where bags 0..b-1 end
    offsets = torch.cat([torch.zeros(1, dtype=torch.long), bag_size.cumsum(dim=0)[:-1]])
    # move all inputs to GPU
    device = torch.cuda.current_device()
    x = embedding.forward(indices.to(device), offsets.to(device))
    loss = x.sum()
    loss.backward()
def test_HashedEmbeddingBagAPI_max():
    """Smoke-test ``HashedEmbeddingBag`` in 'max' mode with (indices, offsets) input.

    Runs forward and backward on CUDA and only checks that neither raises.
    Bag sizes are drawn from [0, 3), so empty bags may occur.
    """
    bag_num = 18
    num_categories = 10
    num_feature = 5
    embedding = HashedEmbeddingBag.HashedEmbeddingBag(
        num_categories, num_feature, compression=0.1, mode="max")
    embedding = embedding.cuda()
    bag_size = torch.randint(low=0, high=3, size=(bag_num,))
    indices_num = bag_size.sum().item()
    # randint's `high` is exclusive; num_categories lets the last id appear
    indices = torch.randint(low=0, high=num_categories, size=(indices_num,))
    # bag b starts where bags 0..b-1 end
    offsets = torch.cat([torch.zeros(1, dtype=torch.long), bag_size.cumsum(dim=0)[:-1]])
    # move all inputs to GPU
    device = torch.cuda.current_device()
    x = embedding.forward(indices.to(device), offsets.to(device))
    loss = x.sum()
    loss.backward()
def test_HashedEmbeddingBagAPI_sum():
    """Smoke-test ``HashedEmbeddingBag`` in 'sum' mode with (indices, offsets) input.

    Runs forward and backward on CUDA and only checks that neither raises.
    Bag sizes are drawn from [0, 3), so empty bags may occur.
    """
    bag_num = 18
    num_categories = 10
    num_feature = 5
    embedding = HashedEmbeddingBag.HashedEmbeddingBag(
        num_categories, num_feature, compression=0.1, mode="sum")
    embedding = embedding.cuda()
    bag_size = torch.randint(low=0, high=3, size=(bag_num,))
    indices_num = bag_size.sum().item()
    # randint's `high` is exclusive; num_categories lets the last id appear
    indices = torch.randint(low=0, high=num_categories, size=(indices_num,))
    # bag b starts where bags 0..b-1 end
    offsets = torch.cat([torch.zeros(1, dtype=torch.long), bag_size.cumsum(dim=0)[:-1]])
    # move all inputs to GPU
    device = torch.cuda.current_device()
    x = embedding.forward(indices.to(device), offsets.to(device))
    loss = x.sum()
    loss.backward()
def test_HashedEmbeddingBagAPI_single():
    """Smoke-test ``HashedEmbeddingBag`` ('sum' mode) with 2-D indices and no offsets.

    ``indices`` has shape (bag_num, 1), i.e. one index per row. Runs forward
    and backward on CUDA and only checks that neither raises.
    """
    bag_num = 200
    num_categories = 10
    num_feature = 5
    embedding = HashedEmbeddingBag.HashedEmbeddingBag(
        num_categories, num_feature, compression=0.1, mode="sum")
    embedding = embedding.cuda()
    indices_num = bag_num
    # randint's `high` is exclusive; num_categories lets the last id appear
    indices = torch.randint(low=0, high=num_categories, size=(indices_num, 1))
    # move all inputs to GPU
    device = torch.cuda.current_device()
    x = embedding.forward(indices.to(device))
    loss = x.sum()
    loss.backward()
def test_HashedEmbeddingBagAPI_embeddding():
    """Smoke-test ``HashedEmbedding`` with a 1-D tensor of category ids.

    Runs forward and backward on CUDA and only checks that neither raises.
    """
    bag_num = 200
    num_categories = 10
    num_feature = 5
    embedding = HashedEmbeddingBag.HashedEmbedding(num_categories, num_feature, compression=0.1)
    embedding = embedding.cuda()
    indices_num = bag_num
    # randint's `high` is exclusive; num_categories lets the last id appear
    indices = torch.randint(low=0, high=num_categories, size=(indices_num,))
    # move all inputs to GPU
    device = torch.cuda.current_device()
    x = embedding.forward(indices.to(device))
    loss = x.sum()
    loss.backward()