In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np

In [12]:
class HashEmbedding(nn.Module):
    
    def __init__(self, num_words, num_hash_functions, num_buckets, embedding_size, agg_function):
        super(HashEmbedding, self).__init__()
        self.num_words = num_words # K
        self.num_hash_functions = num_hash_functions # k
        self.num_buckets = num_buckets # B
        self.embedding_size = embedding_size # d
        self.W = nn.Parameter(torch.FloatTensor(num_buckets, embedding_size)) # B x d
        self.agg_func = agg_function
        self.hash_table = torch.LongTensor(np.random.randint(0, 2**30,
                                size=(num_words, num_hash_functions)))%num_buckets # K x k
        
        self.P = nn.Parameter(torch.FloatTensor(num_words, num_hash_functions)) # K x k

    
    def forward(self, words_as_ids):
        hashes = torch.index_select(self.hash_table, 0, words_as_ids)
        z = torch.gather(self.hash_table, 0, hashes)
        embeddings = []
        for i in range(self.num_hash_functions):
            embeddings.append(torch.mul(self.W[z[:, i]].t(), self.P[z[:, i]][:, i]).t())
        cat_embeddings = torch.stack(embeddings, -1)
        return self.agg_func(cat_embeddings, -1)
    
    def initializeWeights(self):
        nn.init.normal(self.W, 0, 0.1)
        nn.init.normal(self.P, 0, 0.0005)

In [13]:
max_words = 10**6
num_hash = 2
num_buckets = 5000
embedding_dim = 20
agg_function = torch.sum

embedding_model = HashEmbedding(max_words, num_hash, num_buckets, embedding_dim, agg_function)
embedding_model.initializeWeights()

In [15]:
list(embedding_model.parameters())

[Parameter containing:
 -6.5621e-02 -2.9306e-01  7.4280e-02  ...  -8.2494e-02 -9.6310e-02 -2.9069e-02
  2.2511e-02 -2.3906e-02 -9.9755e-02  ...   1.6616e-01  1.3403e-03 -6.1546e-02
 -1.4984e-01  2.5936e-02  2.6952e-02  ...  -8.3221e-02 -8.0320e-02 -6.5011e-02
                 ...                   ⋱                   ...                
 -2.0936e-01  1.9901e-01 -1.0038e-01  ...   9.7263e-02  2.1844e-02  1.7956e-01
 -5.3256e-02  3.2948e-02  3.0606e-02  ...  -4.7108e-02 -2.0300e-02 -5.2797e-02
 -8.7721e-02  1.6017e-01 -8.0196e-02  ...   1.2384e-02  8.0157e-02  6.2865e-03
 [torch.FloatTensor of size 5000x20], Parameter containing:
  8.0226e-04  5.3378e-04
 -3.3916e-04  1.8015e-04
 -3.3546e-04 -8.7273e-04
            ⋮            
  7.3516e-04 -4.6255e-04
 -3.8472e-04 -1.8639e-04
 -1.1415e-03 -4.2809e-04
 [torch.FloatTensor of size 1000000x2]]

In [16]:
embedding_model

HashEmbedding(
)