In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from sklearn.datasets import fetch_20newsgroups
import nltk
import numpy as np
import hashlib
from torch.utils.data import TensorDataset, DataLoader

In [2]:
categories = ['alt.atheism', 'soc.religion.christian', 'comp.graphics', 'sci.med']
num_classes = len(categories)
twenty_train = fetch_20newsgroups(subset='train', categories=categories, shuffle=True, random_state=42)

In [3]:
def word_encoder(w, max_idx):
    v = int(hashlib.sha1(w.encode('utf-8')).hexdigest(), 16)
    return (v % (max_idx-1)) + 1

In [4]:
max_words = 10**6
num_hash = 2
num_buckets = 5000
embedding_dim = 20
num_classes = len(categories)
num_hidden_units = 50
learning_rate = 1e-3
agg_function = torch.sum
num_epochs = 8
batch_size = 32
test_idx = int(len(twenty_train.data)*0.5)
max_len = 150

data = [nltk.word_tokenize(text)[0:150] for text in twenty_train.data]
data_encoded = [[word_encoder(w, max_words) for w in text] for text in data]
#max_len = max([len(d) for d in data])
data_encoded = torch.LongTensor([d+[0]*(max_len-len(d)) for d in data_encoded])
#data_encoded = [torch.LongTensor(d) for d in data_encoded]
targets = torch.LongTensor(np.asarray(twenty_train.target, 'int32'))

train_dataset = TensorDataset(data_encoded[:test_idx], targets[:test_idx])
test_dataset = TensorDataset(data_encoded[test_idx:], targets[test_idx:])

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

In [5]:
class HashEmbedding(nn.Module):
    
    def __init__(self, num_words, num_hash_functions, num_buckets, embedding_size, agg_function):
        super(HashEmbedding, self).__init__()
        self.num_words = num_words # K
        self.num_hash_functions = num_hash_functions # k
        self.num_buckets = num_buckets # B
        self.embedding_size = embedding_size # d
        self.W = nn.Parameter(torch.FloatTensor(num_buckets, embedding_size)) # B x d
        self.agg_func = agg_function
        self.hash_table = torch.LongTensor(np.random.randint(0, 2**30,
                                size=(num_words, num_hash_functions)))%num_buckets # K x k
        
        self.P = nn.Parameter(torch.FloatTensor(num_words, num_hash_functions)) # K x k

    
    def forward(self, words_as_ids):
        embeddings = []
        pvals = []
        for i in range(self.num_hash_functions):
            hashes = torch.take(self.hash_table[:, i], words_as_ids)
            embeddings.append(self.W[hashes, :]*self.P[hashes, :][:, :, i].unsqueeze(-1))
            pvals.append(self.P[hashes, :][:, :, i].unsqueeze(-1))

        cat_embeddings = torch.stack(embeddings, -1)
        cat_embeddings = self.agg_func(cat_embeddings, -1)
        cat_pvals = torch.cat(pvals, -1)
        output = torch.cat([cat_embeddings, cat_pvals], -1)
        return output
        #return cat_embeddings
    
    def initializeWeights(self):
        nn.init.normal(self.W, 0, 0.1)
        nn.init.normal(self.P, 0, 0.0005)

In [6]:
class Model(nn.Module):
    
    def __init__(self, embedding_model, num_classes, num_hidden_units):
        super(Model, self).__init__()
        self.embedding_model = embedding_model
        self.num_classes = num_classes
        self.dense_layer = nn.Linear(self.embedding_model.embedding_size+self.embedding_model.num_hash_functions,
        #self.dense_layer = nn.Linear(self.embedding_model.embedding_size,
                                     num_hidden_units)
        self.output_layer = nn.Linear(num_hidden_units, num_classes)
    
    def forward(self, words_as_ids):
        mask = Variable(torch.unsqueeze(1-torch.eq(words_as_ids, 0).float(), -1))
        embedded = torch.sum(self.embedding_model(words_as_ids)*mask, 1)
        #import pdb; pdb.set_trace()
        dense_output = F.relu(self.dense_layer(embedded))
        final_output = self.output_layer(dense_output)
        return final_output
    
    def initializeWeights(self):
        nn.init.xavier_uniform(self.dense_layer.weight)
        nn.init.xavier_uniform(self.output_layer.weight)
        model.dense_layer.bias.data.zero_()
        model.output_layer.bias.data.zero_()

In [7]:
embedding_model = HashEmbedding(max_words, num_hash, num_buckets, embedding_dim, agg_function)
embedding_model.initializeWeights()

In [8]:
model = Model(embedding_model, num_classes, num_hidden_units)
model.initializeWeights()

In [9]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(list(model.parameters())+list(embedding_model.parameters()),
                             lr=learning_rate)

for _ in range(num_epochs):
    print("Epoch = {}".format(_))
    for (i, d) in enumerate(train_dataloader):
        output = model(d[0])
        loss = criterion(output, Variable(torch.squeeze(d[1], -1)))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    correct = 0
    total = 0
    for (i, d) in enumerate(test_dataloader):
        pred = model(d[0]).max(1)[1].data
        correct = correct + (pred==d[1]).sum()
        total += pred.size(0)
    print("Accuracy = {:.2f}".format(correct*100/total))

Epoch = 0
Accuracy = 56.07
Epoch = 1
Accuracy = 88.04
Epoch = 2
Accuracy = 91.23
Epoch = 3
Accuracy = 90.26
Epoch = 4
Accuracy = 90.97
Epoch = 5
Accuracy = 92.38
Epoch = 6
Accuracy = 91.94
Epoch = 7
Accuracy = 92.21
