In [2]:
import heapq
import pickle
from collections import deque

import numpy as np
import torch
import torch.nn as nn

In [25]:
class Node:
    """A Huffman-tree node that is ordered and compared by frequency.

    Attributes are plain instance fields:
      * ``token`` - the payload passed to the constructor (``HuffmanTree``
        passes ``None`` for merged/internal nodes).
      * ``freq``  - the weight used by every comparison operator.
      * ``left`` / ``right`` - child nodes, ``None`` until attached.
      * ``idx``   - ``None`` until something (``HuffmanTree.assign_ids``)
        numbers the node.

    Note: because ``__eq__`` is overridden and ``__hash__`` is not, Python
    makes instances unhashable (they cannot be dict keys or set members).
    """

    def __init__(self, token, freq):
        self.token = token
        self.freq = freq
        self.left = None
        self.right = None
        self.idx = None

    def __lt__(self, other):
        """Order by frequency; this is the operator ``heapq`` relies on."""
        return self.freq < other.freq

    def __gt__(self, other):
        """Mirror of ``__lt__`` on frequency."""
        return self.freq > other.freq

    def __eq__(self, other):
        """Return True only for another ``Node`` with the same frequency.

        Anything that is not a ``Node`` (including ``None``) compares unequal.
        """
        # A single isinstance check also covers None. The previous
        # `other == None` test dispatched to the *other* object's __eq__,
        # which can raise or return a non-boolean (e.g. a numpy array).
        if not isinstance(other, Node):
            return False
        return self.freq == other.freq

In [52]:
class HuffmanTree:
    """Builds a Huffman tree and its prefix codes from token frequencies.

    The build is split into stages that are meant to be called in this
    order: ``make_heap`` -> ``merge_nodes`` -> ``make_codes`` ->
    ``assign_ids``.

    Attributes:
        heap: min-heap (a plain list managed by ``heapq``) of pending nodes.
        codes: ``{token: bit-string}`` filled in by ``make_codes``.
        reverse_mapping: ``{bit-string: token}``, the inverse of ``codes``.
        root: root node of the finished tree, or ``None`` if nothing was
            built (e.g. an empty vocabulary).
    """

    def __init__(self):
        self.heap = []
        self.codes = {}
        self.reverse_mapping = {}
        self.root = None

    def make_heap(self, frequency):
        """Push one leaf node per key of ``frequency`` onto the heap.

        Args:
            frequency: mapping of token -> frequency; iterated by key and
                indexed with ``frequency[key]``.
        """
        # Nodes are pushed one at a time (not heapify'd) so that the heap
        # layout, and therefore tie-breaking between equal frequencies,
        # stays exactly as before.
        for key in frequency:
            heapq.heappush(self.heap, Node(key, frequency[key]))

    def merge_nodes(self):
        """Repeatedly merge the two lightest nodes until at most one is left."""
        while len(self.heap) > 1:
            lightest = heapq.heappop(self.heap)
            second = heapq.heappop(self.heap)

            # Internal nodes carry no token; that is how make_codes_helper
            # tells them apart from leaves.
            merged = Node(None, lightest.freq + second.freq)
            merged.left = lightest
            merged.right = second
            heapq.heappush(self.heap, merged)

    def make_codes_helper(self, root, current_code):
        """Depth-first walk that records a code for every leaf under ``root``.

        A left edge appends "0" and a right edge appends "1". Left subtrees
        are visited first, which fixes the insertion order of ``codes``.
        """
        # Identity checks avoid going through Node.__eq__.
        if root is None:
            return
        if root.token is not None:
            self.codes[root.token] = current_code
            self.reverse_mapping[current_code] = root.token
            return

        self.make_codes_helper(root.left, current_code + "0")
        self.make_codes_helper(root.right, current_code + "1")

    def make_codes(self):
        """Pop the root off the heap and fill ``codes`` / ``reverse_mapping``.

        With an empty heap (empty vocabulary) this is a no-op and ``root``
        stays ``None`` instead of raising ``IndexError``.
        """
        if not self.heap:
            return
        self.root = heapq.heappop(self.heap)
        self.make_codes_helper(self.root, "")

    def assign_ids(self):
        """Number the internal nodes 0, 1, 2, ... in breadth-first order.

        Only nodes that have at least one child receive an ``idx``; leaves
        keep ``idx = None``. Does nothing when no tree has been built.
        """
        if self.root is None:
            return

        # deque.popleft() is O(1); list.pop(0) made this walk quadratic.
        pending = deque([self.root])
        next_idx = 0

        while pending:
            node = pending.popleft()
            has_left = node.left is not None
            has_right = node.right is not None

            if has_left:
                pending.append(node.left)
            if has_right:
                pending.append(node.right)

            if has_left or has_right:
                node.idx = next_idx
                next_idx += 1

In [53]:
def make_huffman_tree(vocab):
    """Build and return a fully prepared ``HuffmanTree`` for ``vocab``.

    ``vocab`` is handed to ``HuffmanTree.make_heap``; the remaining build
    stages are then run in their required order.
    """
    tree = HuffmanTree()
    tree.make_heap(vocab)
    # Order matters: merge the heap, derive the codes, then number the nodes.
    for stage in (tree.merge_nodes, tree.make_codes, tree.assign_ids):
        stage()
    return tree

In [None]:
# NOTE(security): pickle.load can execute arbitrary code while loading;
# only open a vocab.p that you produced yourself or otherwise trust.
# The context manager closes the file handle, which a bare open() leaked.
with open("vocab.p", "rb") as vocab_file:
    vocab = pickle.load(vocab_file)
vocab_size = len(vocab)

# Randomly initialised trainable tensors, 300 columns each:
# vocab_size-1 rows for hs_weights (presumably one per internal tree node,
# TODO confirm) and vocab_size rows for word_embeddings.
hs_weights = torch.randn((vocab_size-1, 300), dtype=torch.float, requires_grad=True)
word_embeddings = torch.randn((vocab_size, 300), dtype=torch.float, requires_grad=True)

# vocab is passed to HuffmanTree.make_heap, which reads it as token -> frequency.
ht = make_huffman_tree(vocab)

In [None]:
class Dataset:
    """Placeholder for the training dataset; nothing is implemented yet.

    The original cell was only the unfinished header ``class Dataset()``,
    which is a syntax error. TODO: add the real loading/iteration logic.
    """

In [None]:
class RelevanceModel(nn.Module):
    """Holds the trainable tensors for the (still unfinished) relevance model.

    Args:
        vocab: any sized collection; only ``len(vocab)`` is used.
        embedding_dim: number of columns of every weight matrix. Defaults to
            300, the value that was previously hard-coded.

    Attributes:
        vocab_size: ``len(vocab)``.
        hs_weights: ``nn.Parameter`` of shape ``(vocab_size - 1, embedding_dim)``.
        word_embeddings: ``nn.Parameter`` of shape ``(vocab_size, embedding_dim)``.
        query_embedding: bias-free ``nn.Linear(vocab_size, embedding_dim)``.
    """

    def __init__(self, vocab, embedding_dim=300):
        # nn.Module must be initialised before parameters or sub-modules
        # are assigned, otherwise attribute registration fails.
        super().__init__()

        self.vocab_size = len(vocab)
        self.embedding_dim = embedding_dim

        # Stored as nn.Parameter attributes so they are registered with the
        # module (visible to .parameters() / optimizers). The old code kept
        # them as locals wrapped in the undefined, deprecated `Variable`.
        self.hs_weights = nn.Parameter(
            torch.randn((self.vocab_size - 1, embedding_dim), dtype=torch.float)
        )
        self.word_embeddings = nn.Parameter(
            torch.randn((self.vocab_size, embedding_dim), dtype=torch.float)
        )

        self.query_embedding = nn.Linear(self.vocab_size, embedding_dim, bias=False)

    def forward(self, q, word):
        """Forward pass - not written yet in this notebook.

        Raises:
            NotImplementedError: always, until the computation is defined.
        """
        # The original cell ended at the bare signature (a syntax error).
        raise NotImplementedError("RelevanceModel.forward is not implemented yet")

In [54]:
def char_range(c1, c2):
    """Yield every character from `c1` through `c2`, both ends included.

    Characters are produced in ascending code-point order; nothing is
    yielded when `c1` sorts after `c2`.
    """
    first, last = ord(c1), ord(c2)
    yield from map(chr, range(first, last + 1))

In [73]:
# Toy vocabulary used to exercise the Huffman tree: token -> frequency.
# Key order is kept the same as the original one-by-one assignments.
voc = {
    'a': 641,
    'b': 589,
    'c': 938,
    'd': 312,
    'e': 254,
    'f': 932,
    'g': 0,
    'h': 714,
}

In [75]:
ht = make_huffman_tree(voc)

In [76]:
ht.codes

{'f': '00',
 'c': '01',
 'g': '10000',
 'e': '10001',
 'd': '1001',
 'b': '101',
 'a': '110',
 'h': '111'}

In [77]:
voc

{'a': 641, 'b': 589, 'c': 938, 'd': 312, 'e': 254, 'f': 932, 'g': 0, 'h': 714}

In [118]:
# Turn every Huffman code into a float tensor of branch signs:
# bit '0' becomes +1, any other bit becomes -1.
path_tensors = {}

for k, v in ht.codes.items():
    path = []
    for c in v:
        path.append(1 if c == '0' else -1)
    path_tensors[k] = torch.tensor(path, dtype=torch.float)

In [119]:
path_tensors

{'f': tensor([1., 1.]),
 'c': tensor([ 1., -1.]),
 'g': tensor([-1.,  1.,  1.,  1.,  1.]),
 'e': tensor([-1.,  1.,  1.,  1., -1.]),
 'd': tensor([-1.,  1.,  1., -1.]),
 'b': tensor([-1.,  1., -1.]),
 'a': tensor([-1., -1.,  1.]),
 'h': tensor([-1., -1., -1.])}

In [109]:
# For every token, collect the idx of each node visited on the way down
# from the root, following the token's Huffman code.
selectors = {}

for k, v in ht.codes.items():
    root = ht.root
    sel = [root.idx]

    # The final bit is skipped: v[:-1] stops one step before the end of the
    # code, so the node reached by the last bit is never appended.
    for c in v[:-1]:
        root = root.left if c == '0' else root.right
        sel.append(root.idx)
    selectors[k] = torch.tensor(sel)

In [110]:
selectors

{'f': tensor([0, 1]),
 'c': tensor([0, 1]),
 'g': tensor([0, 2, 3, 5, 6]),
 'e': tensor([0, 2, 3, 5, 6]),
 'd': tensor([0, 2, 3, 5]),
 'b': tensor([0, 2, 3]),
 'a': tensor([0, 2, 4]),
 'h': tensor([0, 2, 4])}

In [111]:
# Debug trace: print the node idx at every step of the walk for token 'a'
# (same traversal as the selectors cell, last bit of the code excluded).
root = ht.root

v = ht.codes['a']
print('s', root.idx)
for c in v[:-1]:
    root = root.left if c == '0' else root.right
    print(c, root.idx)

s 0
1 2
1 4


In [112]:
# Random stand-in weight matrix with len(voc)-1 rows and 300 columns;
# later cells index its rows with the node ids stored in `selectors`.
we = torch.randn((len(voc)-1, 300))

In [113]:
h = torch.randn(300)

In [114]:
m1 = we[selectors['a']]

In [115]:
mul = torch.matmul(m1, h)

In [116]:
mul

tensor([-21.2404,  -8.4863,   4.7103])

In [121]:
# Element-wise product of the per-node scores in `mul` with the +1/-1
# branch signs recorded for token 'a'.
prods = mul * path_tensors['a']

In [150]:
t1 = torch.stack((torch.tensor([0, 1]), torch.tensor([0, 1])))

In [154]:
l1 = list(selectors.values())

In [160]:
l1 = [i.unsqueeze(0) for i in l1]

In [177]:
l1 = [torch.tensor([0, 1]),
 torch.tensor([0, 1]),
 torch.tensor([0, 2, 3, 5, 6]),
 torch.tensor([0, 2, 3, 5, 6]),
 torch.tensor([0, 2, 3, 5]),
 torch.tensor([0, 2, 3]),
 torch.tensor([0, 2, 4]),
 torch.tensor([0, 2, 4])]

In [200]:
# Lengths of the tensors in l1, kept as a tuple of one-element tensors;
# used further down as the section sizes for torch.split.
l3 = tuple(torch.tensor([i.shape for i in l1]))

In [230]:
l3

(tensor([2]),
 tensor([2]),
 tensor([5]),
 tensor([5]),
 tensor([4]),
 tensor([3]),
 tensor([3]),
 tensor([3]))

In [191]:
t1 = torch.cat(l1)

In [232]:
t1.shape

torch.Size([27])

In [195]:
# Gather one weight row per node id in t1. `W` is never defined anywhere in
# this notebook (the cell only ran thanks to leftover kernel state); `we` is
# the (len(voc)-1, 300) matrix that the earlier cells index by node id.
m1 = we[t1]

In [196]:
m1.shape

torch.Size([27, 300])

In [204]:
# Cut the stacked rows of m1 back into consecutive chunks along dim 0,
# with chunk sizes taken from l3 (the lengths of the tensors in l1).
m2 = torch.split(m1, l3, dim=0)

In [211]:
torch.cat(m2, dim=0).shape

torch.Size([27, 300])

In [213]:
# Concatenate the per-token +1/-1 path tensors into one flat tensor,
# in the insertion order of path_tensors.
m3 = torch.cat(tuple(path_tensors.values()))

In [215]:
m3.shape

torch.Size([27])

In [217]:
m3

tensor([ 1.,  1.,  1., -1., -1.,  1.,  1.,  1.,  1., -1.,  1.,  1.,  1., -1.,
        -1.,  1.,  1., -1., -1.,  1., -1., -1., -1.,  1., -1., -1., -1.])

In [218]:
m21 = torch.cat(m2, dim=0)

In [219]:
m21.shape

torch.Size([27, 300])

In [220]:
m3.shape

torch.Size([27])

In [224]:
h = torch.randn(300)

In [226]:
h.shape, m21.shape

(torch.Size([300]), torch.Size([27, 300]))

In [231]:
# Multiply every gathered row of m21 with h in a single matmul, apply the
# +1/-1 signs in m3 element-wise, then split the flat result into chunks
# whose sizes come from l3.
torch.split(torch.matmul(m21, h) * m3, l3)

(tensor([ -8.1478, -21.9214]),
 tensor([-8.1478, 21.9214]),
 tensor([  8.1478,  33.1918,  -7.2199, -17.2290,   6.6762]),
 tensor([  8.1478,  33.1918,  -7.2199, -17.2290,  -6.6762]),
 tensor([ 8.1478, 33.1918, -7.2199, 17.2290]),
 tensor([ 8.1478, 33.1918,  7.2199]),
 tensor([  8.1478, -33.1918,  24.0853]),
 tensor([  8.1478, -33.1918, -24.0853]))