# Relevance Based Word Embeddings 

This notebook implements the model for training relevance-based word embeddings. The model is a maximum likelihood model based on a bag-of-words representation of queries; the output words are the top-k words occurring in the documents relevant to the input query.

In [None]:
import numpy as np
import torch.nn as nn
import torch
import heapq
import pandas as pd
import pickle
from tqdm import tqdm
import gzip
import csv

# Binary Huffman Tree

As described in the paper, we need an efficient approximation of softmax in order to train the maximum likelihood model. I decided to implement hierarchical softmax based on a binary Huffman tree. The following two cells implement the binary Huffman tree. The hierarchical softmax itself is implemented with matrix multiplication. This is an improvement over the traditional hierarchical softmax, in which estimating the probability of a word requires following the path from the root node to the leaf node corresponding to that word. The approach below estimates the probabilities of all the output words in one shot, with no need to follow each path linearly from the root to the leaf node.

In [None]:
class Node:
    """A node of the binary Huffman tree.

    Nodes are ordered and compared for equality by ``freq`` only, which is
    what lets ``heapq`` use them directly as priority-queue entries.

    Args:
        token: Payload stored on the node (``None`` is used by
            ``HuffmanTree.merge_nodes`` for merged, internal nodes).
        freq: Weight used for ordering and equality.
    """

    # Equality is defined on the ``freq`` attribute, which can be reassigned,
    # so nodes are deliberately unhashable. Defining ``__eq__`` already
    # implies this; it is spelled out here to make the intent visible.
    __hash__ = None

    def __init__(self, token, freq):
        self.token = token
        self.freq = freq
        self.left = None   # left child, or None
        self.right = None  # right child, or None
        self.idx = None    # filled in later by HuffmanTree.assign_ids

    def __lt__(self, other):
        return self.freq < other.freq

    def __gt__(self, other):
        return self.freq > other.freq

    def __eq__(self, other):
        # ``isinstance`` also rejects None, so no separate None test is needed.
        # Returning NotImplemented lets Python fall back to its default
        # comparison, which yields False for ``node == None`` / ``node == 3``.
        if not isinstance(other, Node):
            return NotImplemented
        return self.freq == other.freq

In [None]:
class HuffmanTree:
    """Binary Huffman tree over a vocabulary, plus lookup tensors derived from it.

    After ``build(vocab)``:

    * ``codes[token]`` is the token's binary code as a string of '0'/'1'
      ('0' = left child, '1' = right child) and ``reverse_mapping`` is its
      inverse.
    * ``path_tensors[token]`` is a float tensor with +1 for every '0' and -1
      for every '1' in the code.
    * ``selectors[token]`` is a long tensor with the ``idx`` of every internal
      node visited on the way from the root to the token's leaf.
    * ``global_path_tensor`` / ``global_selector`` are the per-token tensors
      concatenated in ``codes`` order and moved to ``self.device``;
      ``global_splitter`` holds the per-token lengths and
      ``global_split_map`` maps each token to its position in that order.

    Args:
        device: Device for the concatenated global tensors. Defaults to CUDA
            when it is available and to the CPU otherwise.
    """

    def __init__(self, device=None):
        self.heap = []
        self.codes = {}
        self.reverse_mapping = {}
        self.path_tensors = {}
        self.selectors = {}
        self.token_ids = {}
        self.root = None
        self.global_path_tensor = None
        self.global_selector = None
        self.global_splitter = None
        self.global_split_map = {}

        # Fall back to the CPU instead of crashing on machines without a GPU.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

    def make_heap(self, frequency):
        """Push one leaf ``Node`` per ``token -> frequency`` entry onto the heap."""
        # heappush (not heapify) on purpose: the resulting heap layout decides
        # how frequency ties are broken, and therefore the shape of the tree.
        for token, freq in frequency.items():
            heapq.heappush(self.heap, Node(token, freq))

    def merge_nodes(self):
        """Repeatedly merge the two lightest nodes until one root remains."""
        while len(self.heap) > 1:
            lighter = heapq.heappop(self.heap)
            heavier = heapq.heappop(self.heap)

            merged = Node(None, lighter.freq + heavier.freq)
            merged.left = lighter
            merged.right = heavier
            heapq.heappush(self.heap, merged)

    def make_codes_helper(self, root, current_code):
        """Record the code of every leaf below ``root`` (depth-first, left first).

        Args:
            root: Subtree to walk; ``None`` is ignored.
            current_code: Code accumulated on the way down to ``root``.
        """
        if root is None:
            return
        # A node carrying a token is a leaf (token 0 is a valid token, hence
        # the explicit None test rather than truthiness).
        if root.token is not None:
            self.codes[root.token] = current_code
            self.reverse_mapping[current_code] = root.token
            return

        self.make_codes_helper(root.left, current_code + "0")
        self.make_codes_helper(root.right, current_code + "1")

    def make_codes(self):
        """Pop the root off the heap, store it and assign a code to every leaf."""
        self.root = heapq.heappop(self.heap)
        self.make_codes_helper(self.root, "")

    def assign_ids(self):
        """Number the internal nodes 0, 1, 2, ... in breadth-first order."""
        # Local import keeps this block self-contained; deque gives O(1) pops
        # from the left where list.pop(0) is O(n).
        from collections import deque

        pending = deque([self.root])
        next_idx = 0

        while pending:
            node = pending.popleft()

            if node.left is not None:
                pending.append(node.left)
            if node.right is not None:
                pending.append(node.right)

            # Only nodes with at least one child get an index; leaves keep None.
            if node.left is not None or node.right is not None:
                node.idx = next_idx
                next_idx += 1

    def assign_path_tensors(self):
        """Turn each code into a +1/-1 tensor and build the concatenated tensor."""
        for token, code in self.codes.items():
            signs = [1 if bit == '0' else -1 for bit in code]
            self.path_tensors[token] = torch.FloatTensor(signs)

        self.global_path_tensor = torch.cat(
            list(self.path_tensors.values())
        ).to(self.device)

    def assign_selectors(self):
        """Collect, per token, the indices of the internal nodes on its path."""
        for token, code in self.codes.items():
            node = self.root
            visited = [node.idx]

            # The last bit leads to the leaf itself, which has no index.
            for bit in code[:-1]:
                node = node.left if bit == '0' else node.right
                visited.append(node.idx)

            self.selectors[token] = torch.LongTensor(visited)

        self.global_selector = torch.cat(
            list(self.selectors.values())
        ).to(self.device)

    def build_global_splitter(self):
        """Record per-token path lengths and each token's position in ``codes``."""
        self.global_splitter = [sel.shape[0] for sel in self.selectors.values()]
        self.global_split_map = {
            token: position for position, token in enumerate(self.codes)
        }

        # Internal invariant: a path visits exactly one internal node per bit.
        code_lengths = [len(code) for code in self.codes.values()]
        assert code_lengths == self.global_splitter

    def build(self, vocab):
        """Build the tree and all derived tensors from ``vocab``.

        Args:
            vocab: Mapping ``token -> frequency`` whose keys are exactly
                ``0, 1, ..., len(vocab) - 1`` in iteration order.

        Raises:
            ValueError: If ``vocab`` has fewer than two tokens (no internal
                node would exist) or its keys are not ``0..len(vocab) - 1``
                in order.
        """
        # Validate up front with real exceptions: asserts vanish under ``-O``.
        if len(vocab) < 2:
            raise ValueError("vocab must contain at least two tokens")
        for expected, token in enumerate(vocab):
            if token != expected:
                raise ValueError(
                    "vocab keys must be 0..len(vocab)-1 in order; "
                    f"found {token!r} at position {expected}"
                )

        self.make_heap(vocab)
        self.merge_nodes()
        self.make_codes()
        self.assign_ids()
        self.assign_path_tensors()
        self.assign_selectors()
        self.build_global_splitter()
        self.token_ids = dict(enumerate(vocab))

# Relevance based word embeddings model

The code in the following cell is the relevance model, which estimates the probability of each word appearing in the documents relevant to the given input query. We train this model using the negative log-likelihood and cross-entropy loss function, exactly as described in the original work.

In [None]:
class RelevanceModel(nn.Module):
    """Query -> word model scored with hierarchical softmax over a Huffman tree.

    ``embedding`` projects the query rows into the embedding space and ``hs``
    stores one vector per internal node of the Huffman tree (a binary tree
    with ``vocab_size`` leaves has ``vocab_size - 1`` internal nodes).

    Args:
        vocab: Mapping passed unchanged to ``HuffmanTree.build``; its length
            is the vocabulary size.
        embedding_dim: Size of the query/node vectors (default 300).
    """

    def __init__(self, vocab, embedding_dim=300):
        super().__init__()
        self.vocab_size = len(vocab)
        self.embedding_dim = embedding_dim

        # Use the GPU when there is one, otherwise stay on the CPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.embedding = nn.Linear(self.vocab_size, self.embedding_dim).to(device)
        # ``forward`` looks node vectors up by integer index, which is what
        # nn.Embedding does; an nn.Linear cannot be called with index tensors.
        self.hs = nn.Embedding(self.vocab_size - 1, self.embedding_dim).to(device)

        self.ht = HuffmanTree()
        self.ht.build(vocab)

    def forward(self, q, words):
        """Return the negative log-probability of each word in ``words``.

        Args:
            q: Float tensor whose last dimension is ``vocab_size``; its rows
                are projected by ``embedding`` and averaged over dim 0, so it
                is expected to be 2-D (presumably one row per query term —
                confirm against the code that builds the query vectors).
            words: Sequence of token keys present in the Huffman tree.

        Returns:
            1-D tensor with one value per entry of ``words``: the sum over
            the word's tree path of ``-log(sigmoid(sign * <node_vec, h>))``.
        """
        device = self.embedding.weight.device

        # Query representation: mean of the projected rows.
        h = self.embedding(q.to(device)).mean(dim=0)

        selectors = [self.ht.selectors[word] for word in words]
        path_lengths = [sel.shape[0] for sel in selectors]

        node_ids = torch.cat(selectors).to(device)
        signs = torch.cat(
            [self.ht.path_tensors[word] for word in words]
        ).to(device)

        # One row per (word, internal node on its path).
        node_vecs = self.hs(node_ids)
        scores = torch.matmul(node_vecs, h) * signs

        # logsigmoid stays finite where log(sigmoid(x)) underflows to -inf.
        step_nll = -nn.functional.logsigmoid(scores)

        # Regroup the flat per-step values by word and sum each path.
        per_word = torch.split(step_nll, path_lengths, dim=0)
        return torch.stack([steps.sum() for steps in per_word])
        
        

# DataLoader 

The dataframe `data_frame.p` is the training data. For each query, it contains the top 500 words and their probabilities. The probabilities were estimated using a state-of-the-art language modelling approach. There is a separate notebook for preparing the training data from the corpus.

In [None]:
from torch.utils.data import Dataset, DataLoader
class Rel_dataset(Dataset):
    """Dataset over a pickled DataFrame with one row per query.

    Each row holds ``top_k`` word ids followed by ``top_k`` probabilities;
    the row's index label is returned as the query name.

    Args:
        path: Pickle file read with ``pandas.read_pickle``.
        top_k: Number of words (and of probabilities) per row. The frame's
            columns are relabelled ``0 .. 2 * top_k - 1``.
    """

    def __init__(self, path='data_frame.p', top_k=500):
        self.top_k = top_k
        self.df = pd.read_pickle(path)
        self.df.columns = list(range(2 * top_k))

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(name, words, probs)`` for the row at position ``index``.

        ``words`` is an int64 array of the first ``top_k`` values and
        ``probs`` a float32 array of the remaining ones.
        """
        # Materialize the row once instead of once per field.
        row = self.df.iloc[index]

        # .iloc makes the split positional regardless of the column labels.
        words = row.iloc[:self.top_k].values.astype(np.int64)
        probs = row.iloc[self.top_k:].values.astype(np.float32)

        return row.name, words, probs
    
# One query per batch: the training loop reads element 0 of every batch field.
dataset = Rel_dataset()
dataloader = DataLoader(dataset=dataset, batch_size=1)

In [None]:
# Load the pickled vocabulary. The context manager closes the file handle,
# which the bare ``open(...)`` call never did.
# NOTE: pickle can execute arbitrary code; only load files you trust.
with open("rel_vocab.p", "rb") as vocab_file:
    v = pickle.load(vocab_file)

# Re-key the values of ``v`` by their position: 0, 1, 2, ...
vocab = dict(enumerate(v.values()))

model = RelevanceModel(vocab)

In [None]:
# Plain SGD with momentum over every parameter of the model.
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=0.1,
    momentum=0.9,
)
def my_loss(out_probs, probs):
    """Return the sum of the element-wise product of the two tensors.

    Args:
        out_probs: Per-word values returned by the model.
        probs: Target weights, broadcastable against ``out_probs``.

    Returns:
        Scalar tensor ``sum(out_probs * probs)``.
    """
    weighted = torch.mul(out_probs, probs)
    return weighted.sum()

In [None]:
import os

# Training loop: one optimisation step per query, one checkpoint per epoch.
qvecs = torch.load("qvecs.p")
num_epochs = 100

losses = []
PATH = "model_checkpoints/model"

# torch.save does not create missing directories; without this the first
# checkpoint fails only after a whole epoch of training.
os.makedirs(os.path.dirname(PATH), exist_ok=True)

for epoch in range(num_epochs):
    print("Epoch : ",epoch)
    epoch_loss = 0.0
    for name, words, probs in tqdm(dataloader):
        # Batches have size 1, so element 0 is the only query in the batch.
        q = qvecs[name[0]]
        out_probs = model(q, words.tolist()[0])
        # Follow the model's device instead of assuming CUDA.
        probs = probs.to(out_probs.device)
        loss = my_loss(out_probs, probs)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # .item() yields a plain float, so neither ``losses`` nor the
        # checkpoints hold on to device tensors.
        epoch_loss += loss.item()

    losses.append(epoch_loss)
    print("Epoch : ", epoch, "\t Loss : ", epoch_loss)

    MODEL_PATH = PATH + str(epoch) + ".pt"

    torch.save({
        'epoch' : epoch,
        'model_state_dict' : model.state_dict(),
        'optimizer_state_dict' : optimizer.state_dict(),
        'loss' : epoch_loss
    }, MODEL_PATH)

# The context manager flushes and closes the file deterministically.
with open('training_loss.p', 'wb') as loss_file:
    pickle.dump(losses, loss_file)