In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from sklearn.datasets import fetch_20newsgroups
import nltk
import numpy as np
import hashlib

In [None]:
# The four newsgroups used in this experiment; each one is a target class.
categories = [
    'alt.atheism',
    'soc.religion.christian',
    'comp.graphics',
    'sci.med',
]
num_classes = len(categories)

# Fetch the training split restricted to the groups above, shuffled with a
# fixed random_state.
twenty_train = fetch_20newsgroups(
    subset='train',
    categories=categories,
    shuffle=True,
    random_state=42,
)

In [None]:
def word_encoder(w, max_idx):
    """Deterministically map the token `w` to an integer id in [1, max_idx - 1].

    The SHA-1 digest of the UTF-8 encoded token is read as one big-endian
    integer, reduced modulo ``max_idx - 1`` and shifted up by one, so the
    value 0 is never returned.
    """
    digest_bytes = hashlib.sha1(w.encode('utf-8')).digest()
    digest_as_int = int.from_bytes(digest_bytes, 'big')
    slot = digest_as_int % (max_idx - 1)
    return slot + 1

In [None]:
# --- Reproducibility -------------------------------------------------------
# The hash table is drawn from NumPy's global RNG and the weight initialisers
# draw from torch's RNG, so both are seeded before any model is constructed.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# --- Hash-embedding hyper-parameters ---------------------------------------
max_words = 10**6          # K: size of the hashed word-id space
num_hash = 2               # k: hash functions per word
num_buckets = 5000         # B: rows in the shared embedding table
embedding_dim = 20         # d: width of each embedding vector
agg_function = torch.sum   # how the k component vectors are combined

# --- Classifier / optimisation hyper-parameters ----------------------------
num_classes = len(categories)
num_hidden_units = 50
learning_rate = 1e-4
num_epochs = 5

# --- Data preparation ------------------------------------------------------
max_tokens = 150           # each document is truncated to its first tokens
# Documents before this index are used for training, the rest for testing.
test_idx = int(len(twenty_train.data)*0.5)

# Tokenise every document and keep at most `max_tokens` tokens.
data = [nltk.word_tokenize(text)[:max_tokens] for text in twenty_train.data]
# Hash each token to an id in [1, max_words - 1] (word_encoder never returns 0)
# and store every document as its own variable-length LongTensor.
data_encoded = [
    torch.LongTensor([word_encoder(w, max_words) for w in tokens])
    for tokens in data
]
# Column vector of class indices, one row per document.
targets = np.asarray(twenty_train.target, 'int32').reshape((-1, 1))

In [None]:
class HashEmbedding(nn.Module):
    """Embedding layer that shares a small table of vectors between many words.

    Each of the K word ids is mapped by k fixed random hash functions to k of
    the B rows of the shared table ``W`` (B x d). The k selected rows are
    scaled by the word's own trainable importance weights (row of ``P``,
    K x k) and combined with ``agg_function``.

    NOTE(review): the K/k/B/d notation looks like the "hash embeddings"
    scheme of Svenstrup et al. (2017) - confirm against the intended reference.
    """

    def __init__(self, num_words, num_hash_functions, num_buckets, embedding_size, agg_function):
        """
        Args:
            num_words: K, number of distinct word ids (ids must be < K).
            num_hash_functions: k, number of hash functions per word.
            num_buckets: B, number of rows in the shared embedding table.
            embedding_size: d, width of every embedding vector.
            agg_function: callable invoked as ``agg_function(x, -1)`` where
                the last axis of ``x`` has length k (e.g. ``torch.sum``).
        """
        super(HashEmbedding, self).__init__()
        self.num_words = num_words # K
        self.num_hash_functions = num_hash_functions # k
        self.num_buckets = num_buckets # B
        self.embedding_size = embedding_size # d
        self.agg_func = agg_function

        self.W = nn.Parameter(torch.FloatTensor(num_buckets, embedding_size)) # B x d
        self.P = nn.Parameter(torch.FloatTensor(num_words, num_hash_functions)) # K x k

        # Fixed (non-trainable) word -> bucket assignment, K x k.  Registered
        # as a buffer so it follows .to(device) and is stored in state_dict;
        # without that a reloaded model would pair the trained weights with a
        # different random table.
        raw_hashes = np.random.randint(0, 2**30, size=(num_words, num_hash_functions))
        hash_table = torch.from_numpy(raw_hashes.astype(np.int64)) % num_buckets
        self.register_buffer('hash_table', hash_table)

        # torch.FloatTensor(...) allocates uninitialised memory; give the
        # parameters defined values even if the caller forgets to initialise.
        self.initializeWeights()

    def forward(self, words_as_ids):
        """Embed a tensor of word ids.

        Args:
            words_as_ids: LongTensor of word ids in [0, K), of any shape.

        Returns:
            Whatever ``agg_function`` returns for a tensor of shape
            ``words_as_ids.shape + (d, k)`` reduced over its last axis; with
            ``torch.sum`` this is ``words_as_ids.shape + (d,)``.
        """
        # Bucket chosen by every hash function for every word: (..., k).
        buckets = self.hash_table[words_as_ids]
        components = []
        for i in range(self.num_hash_functions):
            vectors = self.W[buckets[..., i]]          # (..., d)
            # Importance weights belong to the *word*, so they are looked up
            # by word id, not by bucket id.
            importance = self.P[words_as_ids, i]       # (...)
            components.append(vectors * importance.unsqueeze(-1))
        stacked = torch.stack(components, -1)          # (..., d, k)
        return self.agg_func(stacked, -1)

    def initializeWeights(self):
        """(Re-)initialise ``W`` ~ N(0, 0.1) and ``P`` ~ N(0, 0.0005) in place."""
        nn.init.normal_(self.W, 0, 0.1)
        nn.init.normal_(self.P, 0, 0.0005)

In [None]:
class Model(nn.Module):
    """Bag-of-embeddings document classifier.

    A document (tensor of word ids) is embedded word by word, the word vectors
    are summed over axis 0 into one document vector, and a hidden ReLU layer
    followed by a linear layer yields class log-probabilities.
    """

    def __init__(self, embedding_model, num_classes, num_hidden_units):
        """
        Args:
            embedding_model: module mapping word ids to vectors; must expose
                an ``embedding_size`` attribute giving the vector width.
            num_classes: number of output classes.
            num_hidden_units: width of the hidden layer.
        """
        super(Model, self).__init__()
        # Assigning a module as an attribute registers it as a sub-module, so
        # its parameters are included in self.parameters().
        self.embedding_model = embedding_model
        self.num_classes = num_classes
        self.dense_layer = nn.Linear(self.embedding_model.embedding_size, num_hidden_units)
        self.output_layer = nn.Linear(num_hidden_units, num_classes)

    def forward(self, words_as_ids):
        """Return the class log-probabilities for one document.

        Args:
            words_as_ids: LongTensor of word ids; axis 0 indexes the words.

        Returns:
            Tensor of log-probabilities whose last axis has ``num_classes``
            entries (1-D for a 1-D input).
        """
        # Sum the word vectors over the word axis -> one document vector.
        embedded = torch.sum(self.embedding_model(words_as_ids), 0)
        hidden = F.relu(self.dense_layer(embedded))
        logits = self.output_layer(hidden)
        # dim=-1 is the class axis; identical to dim=0 for a 1-D logit vector.
        return F.log_softmax(logits, dim=-1)

    def initializeWeights(self):
        """Xavier-initialise both linear layers' weights and zero their biases.

        Operates on this instance only (it must not depend on any global
        variable holding a model).
        """
        nn.init.xavier_uniform_(self.dense_layer.weight)
        nn.init.xavier_uniform_(self.output_layer.weight)
        nn.init.constant_(self.dense_layer.bias, 0.0)
        nn.init.constant_(self.output_layer.bias, 0.0)

In [None]:
# Build the hash-embedding layer from the notebook's hyper-parameters, then
# fill its weight tensors with their initial values.
embedding_model = HashEmbedding(
    num_words=max_words,
    num_hash_functions=num_hash,
    num_buckets=num_buckets,
    embedding_size=embedding_dim,
    agg_function=agg_function,
)
embedding_model.initializeWeights()

In [None]:
# Wrap the embedding layer in the classifier and initialise its linear layers.
model = Model(
    embedding_model=embedding_model,
    num_classes=num_classes,
    num_hidden_units=num_hidden_units,
)
model.initializeWeights()

In [None]:
# The model already returns log-probabilities (it ends in log_softmax), so
# NLLLoss is the matching criterion.  CrossEntropyLoss would apply log_softmax
# a second time, which is a no-op on log-probabilities but redundant work.
criterion = nn.NLLLoss()

# `embedding_model` is stored as an attribute of `model` and is therefore a
# registered sub-module: model.parameters() already yields its weights.
# Passing them a second time puts duplicates in the parameter group, so each
# embedding weight would receive two Adam updates per step.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Documents before `test_idx` are trained on; the remainder is held out.
train_docs = data_encoded[:test_idx]
test_docs = data_encoded[test_idx:]

for epoch in range(num_epochs):
    print("Epoch = {}".format(epoch))

    # --- one pass over the training documents, one document per step -------
    for i, doc in enumerate(train_docs):
        log_probs = model(doc)
        # targets[i] is a length-1 array holding the class index.
        target = torch.as_tensor(targets[i], dtype=torch.long)
        # Add a batch axis of size 1 so the shapes are (1, C) and (1,).
        loss = criterion(log_probs.unsqueeze(0), target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # --- accuracy on the held-out documents ---------------------------------
    num_correct = 0
    with torch.no_grad():  # no gradients needed while scoring
        for i, doc in enumerate(test_docs):
            # .item() converts the 0-dim index tensor to a Python int
            # (indexing a 0-dim tensor with [0] raises on current PyTorch).
            pred = torch.argmax(model(doc), dim=0).item()
            num_correct += int(pred == targets[i + test_idx][0])
    # Guard against an empty held-out set to avoid a division by zero.
    accuracy = num_correct*100/len(test_docs) if test_docs else float('nan')
    print("Accuracy = {:.2f}".format(accuracy))