#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Ravi Krishna 07/25/21
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from nas_embedding import EmbeddingDLRM
from nas_supernet import SuperNet


class EmbeddingCardSuperNet(SuperNet):
    def __init__(self,
                 cardinality_options,
                 dim):
        """
        Implements an embedding cardinality search supernet.

        We cannot use the FBNetv2 method of just creating the largest-
        cardinality embedding and reusing it for every option, because
        hashed and unhashed embedding indices may then be mapped to the
        same rows.

        Further, because we will typically be choosing between the original
        embedding size and sizes reduced by factors of 10, 100, 1000, etc.,
        the overhead of directly storing all of these embeddings is at most
        1/(1 - 0.1) ~= 1.11x the largest table alone, and thus properly
        representing the search space is worth it.
        """
        # Superclass initialization.
        super(EmbeddingCardSuperNet, self).__init__()

        # Store for later.
        self.cardinality_options = cardinality_options
        self.num_card_options = len(self.cardinality_options)
        self.dim = dim
        self.params_options = nn.Parameter(torch.Tensor([self.dim * curr_card for curr_card in self.cardinality_options]), requires_grad=False)

        # For compatibility with DLRM.
        self.num_embeddings = max(self.cardinality_options)

        # Create the embedding matrices.
        self.embedding_options = nn.ModuleList([])
        for card_option in self.cardinality_options:
            self.embedding_options.append(EmbeddingDLRM(card_option, self.dim))

        # Create other parameters.
        self.theta_parameters = nn.ParameterList([nn.Parameter(torch.Tensor([0.00] * self.num_card_options), requires_grad=True)])
        self.mask_values = [None] * len(self.theta_parameters)
        self.num_mask_lists = len(self.mask_values)

        # For compatibility with DLRM, store the current cost
        # instead of returning it after each call to forward().
        self.curr_cost = None
    def calculate_cost(self):
        """
        Calculates the cost as the weighted average number of parameters.
        """

        # Get the mask values. This will be a tensor of size
        # (batch_size, self.num_card_options).
        curr_mask_values = self.mask_values[0]

        # Take the dot product with the number of parameters for each
        # of the cardinality options. Should be of size (batch_size).
        weighted_avg_cost = torch.matmul(curr_mask_values, self.params_options)

        # Return the weighted average cost.
        return weighted_avg_cost
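
    # Illustrative arithmetic for calculate_cost() (not executed; the
    # numbers below are hypothetical): with cardinality options [1000, 100]
    # and dim = 4, params_options = [4000, 400], so a soft mask row of
    # [0.9, 0.1] yields a per-sample cost of 0.9 * 4000 + 0.1 * 400 = 3640.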
    def to(self, device):
        """
        Overrides the original to() and also moves self.params_options
        and self.theta_parameters.
        """

        nn.Module.to(self, device)
        self.params_options = self.params_options.to(device)
        self.theta_parameters = self.theta_parameters.to(device)

        # Match the nn.Module.to() contract so calls can be chained.
        return self
    def forward(self, indices, offsets, sampling="None", temperature=-1.0):
        """
        Note that DLRM actually uses an nn.EmbeddingBag, not an nn.Embedding.
        Thus, the input includes both indices and offsets. However, this
        SuperNet only implements an nn.Embedding search and ignores the offsets.
        """

        # Get the batch size.
        curr_batch_size = int(list(indices.size())[0])

        # Run the sampling if necessary.
        if sampling == "soft":
            self.soft_sample(temperature, curr_batch_size)

        # Calculate the cost of the network.
        self.curr_cost = self.calculate_cost()

        # Create output for different cardinalities.
        card_outputs = []
        for i, current_card in enumerate(self.cardinality_options):

            # Hash the indices.
            curr_hashed_indices = indices % current_card

            # Get the output for the correct embedding.
            curr_output = self.embedding_options[i](curr_hashed_indices)

            # Add to list of outputs.
            card_outputs.append(curr_output)

        # Take the weighted average of the outputs.
        weighted_average_output = self.calculate_weighted_sum(self.mask_values[0], card_outputs, n_mats=self.num_card_options)

        # Return the weighted average outputs.
        return weighted_average_output
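
    # A minimal sketch of what calculate_weighted_sum() is assumed to do
    # (it is inherited from SuperNet in nas_supernet.py, which is not shown
    # here, so this describes assumed behavior, not the actual
    # implementation): given per-sample masks of shape
    # (batch_size, n_mats) and n_mats embedding outputs of shape
    # (batch_size, dim), it blends the outputs per sample, e.g.
    #
    #   output = sum(mask[:, i:i + 1] * card_outputs[i] for i in range(n_mats))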
    def sample_emb_arch(self):
        """
        Hard-samples from the Gumbel-Softmax distribution
        and returns the resulting embedding cardinality as
        an integer.
        """

        # Hard-sample from the distribution.
        sampled_mask_values = self.hard_sample()

        # Find the one-hot index of the mask.
        one_hot_ix = np.argmax(sampled_mask_values[0])

        # Return the embedding cardinality size at
        # that index.
        sampled_card = self.cardinality_options[one_hot_ix]
        return sampled_card
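

# A minimal sketch of the Gumbel-Softmax sampling that soft_sample() and
# hard_sample() (inherited from SuperNet in nas_supernet.py, not shown
# here) are assumed to perform; the name and signature below are
# hypothetical, for illustration only.
def _gumbel_mask_sketch(theta, temperature, batch_size, hard=False):
    # theta: (num_options,) architecture logits; one mask row per sample.
    logits = theta.unsqueeze(0).expand(batch_size, -1)

    # hard=False returns differentiable relaxed one-hot masks; hard=True
    # returns discrete one-hot masks with straight-through gradients.
    return F.gumbel_softmax(logits, tau=temperature, hard=hard)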


if __name__ == "__main__":

    # Create a supernet.
    supernet = EmbeddingCardSuperNet([10000000, 1000000, 100000, 10000, 1000], 128)

    # Move it to GPU.
    supernet.to("cuda:7")
    print(f"PARAMS OPTIONS DEVICE: {supernet.params_options.device}")
    print(f"THETA PARAMETERS DEVICE: {supernet.theta_parameters[0].device}")

    # Run the forward pass with a batch size of 10.
    random_indices = torch.Tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).to(dtype=torch.long).to("cuda:7")
    random_offsets = None
    sampling_type = "soft"
    temperature = 0.1
    avg_embs = supernet(indices=random_indices, offsets=random_offsets, sampling=sampling_type, temperature=temperature)
    print(f"SIZE OF AVERAGE EMBEDDINGS TENSOR: {avg_embs.size()}")
    print(f"MASK VALUES: {supernet.mask_values[0]}")
    print(f"CURRENT COST: {supernet.curr_cost}")

    # Try to sample architectures.
    print(f"CURRENT THETA PARAMETERS: {supernet.theta_parameters[0]}")
    archs = {k: 0 for k in [10000000, 1000000, 100000, 10000, 1000]}
    for i in range(10000):
        with torch.no_grad():
            curr_arch = supernet.sample_emb_arch()
        archs[curr_arch] += 1
        print(f"SAMPLED ARCHITECTURE: {curr_arch} {archs}")