# Creating Binary Tree From Inverted Index
We are given an inverted index of terms and the corresponding documents in which those terms appear. We also have the weight of the term within that document.  
We need to create a dense representation of each passage from this inverted index. 

In [48]:
import h5py
import pickle

In [10]:
# import the inverted index
# not sure if this is the right file to look at
indexing_fp = '../../../splade-copy/experiments/index-full-data/index/array_index.h5py'
f = h5py.File(indexing_fp, 'r')

In [11]:
list(f.keys())[:10]

['dim',
 'index_doc_id_100',
 'index_doc_id_1000',
 'index_doc_id_10000',
 'index_doc_id_10001',
 'index_doc_id_10002',
 'index_doc_id_10003',
 'index_doc_id_10004',
 'index_doc_id_10005',
 'index_doc_id_10006']

In [12]:
# get the unique term index ids
ids = set([x.split("_")[-1] for x in f.keys() if x != "dim"])

In [13]:
from collections import defaultdict

In [14]:
from tqdm import tqdm

In [15]:
# construct the dense representation of each passage
passages = defaultdict(dict)

for i in tqdm(ids):
    passage_idx = f[f"index_doc_id_{i}"][:]
    values = f[f"index_doc_value_{i}"][:]
    for ip, p in enumerate(passage_idx):
        passages[str(p)][i] = float(values[ip])

100%|██████████| 26098/26098 [00:28<00:00, 931.32it/s] 


In [16]:
# dense rep of passage 65527
passages["65527"]

{'3668': 1.5903515815734863,
 '2823': 0.06369511038064957,
 '2296': 1.0219107866287231,
 '4703': 0.14530520141124725,
 '2116': 0.21547049283981323,
 '5002': 0.11930034309625626,
 '2467': 0.6826394200325012,
 '5004': 0.8320001363754272,
 '2146': 2.235799551010132,
 '3160': 0.10148779302835464,
 '2205': 0.2753998041152954,
 '10984': 0.062274351716041565,
 '4256': 0.08308016508817673,
 '2558': 0.20775261521339417,
 '13162': 1.0077824592590332,
 '2681': 0.21058611571788788,
 '6098': 0.32030370831489563,
 '2933': 2.2978293895721436,
 '3198': 1.8764160871505737,
 '2411': 1.7142620086669922,
 '6976': 0.7544534206390381,
 '5075': 0.4816349446773529,
 '2054': 0.5362437963485718,
 '12026': 1.636215090751648,
 '3677': 0.2754783630371094,
 '4357': 0.1885817050933838,
 '2051': 0.13864165544509888,
 '6485': 0.39879557490348816,
 '3740': 0.1884268969297409,
 '2129': 0.7910882234573364,
 '3967': 0.06687205284833908,
 '2360': 0.001524953986518085,
 '2061': 0.24100495874881744,
 '3524': 0.61306142807006

In [17]:
import json

In [18]:
# save to json
with open("dense_passages.json", "w") as fp:
    json.dump(passages , fp)

In [10]:
import json
f = open("dense_passages.json")

passages = json.load(f)

In [11]:
passages

{'1487': {'6452': 0.05136897787451744,
  '2223': 0.09542778134346008,
  '2374': 0.31804773211479187,
  '3410': 0.6295368671417236,
  '3873': 1.4411897659301758,
  '10453': 0.005180981010198593,
  '2862': 0.07993192970752716,
  '8263': 0.09498373419046402,
  '2447': 0.23809461295604706,
  '5088': 0.13260714709758759,
  '4433': 2.353710651397705,
  '2355': 2.2908172607421875,
  '2161': 0.439339280128479,
  '2330': 0.22310397028923035,
  '9046': 0.06670334190130234,
  '2621': 0.006091007497161627,
  '6134': 0.24523115158081055,
  '4443': 1.6521860361099243,
  '2977': 0.32332518696784973,
  '2418': 1.1897939443588257,
  '3789': 0.16459393501281738,
  '2097': 1.199805736541748,
  '2602': 0.05072692409157753,
  '4060': 0.23658186197280884,
  '3607': 0.3629269003868103,
  '2724': 0.20663565397262573,
  '2782': 0.490182489156723,
  '2022': 0.5553063750267029,
  '7097': 2.474431276321411,
  '2136': 0.6067743301391602,
  '2267': 0.02332359552383423,
  '7462': 1.6742274761199951,
  '2713': 0.1073

# Making The Tree

In [2]:
import math
import numpy as np
import functools
from copy import deepcopy

In [3]:
def sparse_max(s1, s2):
    s3 = deepcopy(s2)
    for k, v in s1.items():
        s3[k] = max(s2.get(k, 0), v)
    return s3

def sparse_similarity(s1, s2):
    s3 = {}
    for k, v in s1.items():
        s3[k] = s2.get(k, 0) * v
    return sum(s3.values())

In [4]:
def next_smallest_power_of_2(x: int) -> int:
    """Get the next smallest power of 2

    Args:
        x (int)
    
    Result:
        next smallest power of 2 from x
    """
    exponent = math.floor(math.log2(x))
    return 2**exponent

In [5]:
def make_tree(array, min_bucket_size=100, func=...):
    """Make balanced binary tree with a minimum bucket size

    Aggregate using the given function
    """
    number_of_buckets = math.ceil(len(array) / min_bucket_size)
    balanced_number_bottom_buckets = next_smallest_power_of_2(number_of_buckets)
    bottom_bucket_size = math.ceil(len(array) / balanced_number_bottom_buckets)

    #TODO: you still need to store the indices of the underlying array 
    # associated with each aggregate score
    bottom_buckets = [
        functools.reduce(func, array[i:i+bottom_bucket_size]) 
        for i in range(0, len(array), bottom_bucket_size)]
    
    tree = {}
    bottom_level = int(np.log2(balanced_number_bottom_buckets))
    tree[bottom_level] = bottom_buckets
    bucks = bottom_buckets
    level = bottom_level - 1
    while len(bucks) > 2:
        b = [
                functools.reduce(func, bucks[i:i+1])
                for i in range(0, len(bucks), 2)
            ]
        tree[level] = b
        bucks = b
        level-=1
    
    return tree

In [6]:
def get_result(query, tree, max_depth=None):
    level = min(tree.keys())
    index = 0
    if not max_depth:
        max_depth = len(tree.keys())

    while (level in tree) and max_depth > 0:
        left = tree[level][index]
        right = tree[level][index+1]
        go_left = sparse_similarity(left, query) > sparse_similarity(right, query)
        index = index*2 if go_left else index*2+1
        level+=1
        max_depth-=1
    
    final_level = level-1
    return index, final_level

In [13]:
passage_list = list(passages.items())
passage_list

[('1487',
  {'6452': 0.05136897787451744,
   '2223': 0.09542778134346008,
   '2374': 0.31804773211479187,
   '3410': 0.6295368671417236,
   '3873': 1.4411897659301758,
   '10453': 0.005180981010198593,
   '2862': 0.07993192970752716,
   '8263': 0.09498373419046402,
   '2447': 0.23809461295604706,
   '5088': 0.13260714709758759,
   '4433': 2.353710651397705,
   '2355': 2.2908172607421875,
   '2161': 0.439339280128479,
   '2330': 0.22310397028923035,
   '9046': 0.06670334190130234,
   '2621': 0.006091007497161627,
   '6134': 0.24523115158081055,
   '4443': 1.6521860361099243,
   '2977': 0.32332518696784973,
   '2418': 1.1897939443588257,
   '3789': 0.16459393501281738,
   '2097': 1.199805736541748,
   '2602': 0.05072692409157753,
   '4060': 0.23658186197280884,
   '3607': 0.3629269003868103,
   '2724': 0.20663565397262573,
   '2782': 0.490182489156723,
   '2022': 0.5553063750267029,
   '7097': 2.474431276321411,
   '2136': 0.6067743301391602,
   '2267': 0.02332359552383423,
   '7462': 1.

In [14]:
# sort the passages by the passage id
sorted_passage_list = sorted(passage_list, key=lambda x: x[0])

In [15]:
# this is probably too many passages
len(sorted_passage_list)

276142

In [16]:
sorted_passage_list[0]

('0',
 {'5875': 0.27495330572128296,
  '2831': 0.06040031462907791,
  '6970': 0.09528885036706924,
  '2056': 0.2581539750099182,
  '3112': 0.06525882333517075,
  '3739': 2.2572896480560303,
  '2426': 1.3699175119400024,
  '5792': 0.09570199996232986,
  '12799': 0.04179960489273071,
  '2018': 0.16452254354953766,
  '7155': 1.165930986404419,
  '8826': 0.23535940051078796,
  '5921': 1.443000078201294,
  '4807': 2.431880474090576,
  '17171': 1.1926664113998413,
  '4471': 0.3649667799472809,
  '5081': 0.3923812210559845,
  '4167': 0.29389825463294983,
  '3271': 0.013347674161195755,
  '4301': 0.4548214375972748,
  '5456': 0.3844963610172272,
  '2470': 0.18671047687530518,
  '4069': 0.2388460338115692,
  '2045': 0.0013369916705414653,
  '2568': 1.3893468379974365,
  '10617': 0.020997727289795876,
  '13463': 2.3916304111480713,
  '4676': 0.05684340000152588,
  '4400': 0.009104876779019833,
  '8553': 0.008865873329341412,
  '2306': 0.6787437796592712,
  '12610': 0.2685873508453369,
  '20805':

In [17]:
# the order of the passages
sorted_passage_index_dense = [v[0] for v in sorted_passage_list]

In [82]:
# save the passage order
import pickle
with open("passage_index_order", "wb") as fp:   #Pickling
    pickle.dump(sorted_passage_index_dense, fp)

In [18]:
sorted_passage_dense = [v[1] for v in sorted_passage_list]

In [29]:
# construct the tree
binary_tree = make_tree(sorted_passage_dense, min_bucket_size=2000, func=sparse_max)

In [80]:
# save the tree
with open("binary_tree_scores.json", "w") as fp:
    json.dump(binary_tree , fp)

In [19]:
with open("binary_tree_scores.json", "r") as fp:
    xx = json.load(fp)

In [60]:
import pickle
sorted_passage_dense = pickle.loads("doc_ids.pkl")

TypeError: a bytes-like object is required, not 'str'

In [20]:
xx.keys()

dict_keys(['7', '6', '5', '4', '3', '2', '1'])

In [49]:
f = open("binary_tree_scores.json")

binary_tree = json.load(f)

In [50]:
binary_tree = {int(key): value for key, value in binary_tree.items()}

In [51]:
# how many nodes at each tree depth (level)
for k,v in binary_tree.items():
    print(k, len(v))

7 128
6 64
5 32
4 16
3 8
2 4
1 2


In [52]:
res = get_result(sorted_passage_dense[0], binary_tree)

In [53]:
res

(0, 7)

In [55]:
qi = 1
q = sorted_passage_dense[qi]  # use the first passage as a dummy query
max_depth = None # specify how many levels into the tree you want to go
res_index, lev = get_result(q, binary_tree, max_depth)  # get the index and level of the tree where the result is located

bucket_size = math.ceil(len(sorted_passage_dense) / len(binary_tree[lev]))  # 
bucket_start = bucket_size*res_index
bucket_end = bucket_size*(res_index+1)

res_passages = sorted_passage_dense[bucket_start:bucket_end]
res_indx = sorted_passage_index_dense[bucket_start:bucket_end]

tree_scores = [sparse_similarity(r, q) for r in res_passages]
tree_loc = np.argmax(tree_scores) + bucket_start
len(res_indx)
# tree_loc

# # look up the passage index
# sorted_passage_index_dense[tree_loc]

2158

In [57]:
bucket_end

2158

In [93]:
sorted_passage_index_dense[0]

'0'

In [92]:
bucket_size

2158

['0', '1', '10', '100', '1000']

In [71]:
sparse_similarity(sorted_passage_dense[tree_loc], q)

15.505542186613347

# Hierarchical Sorting

In [72]:
# imports
from scipy.cluster import hierarchy
from sklearn.metrics import pairwise_distances

In [73]:
def dense_to_sparse(d, len):
    output_vec = np.zeros(len)
    for k, v in d.items():
        output_vec[k] = v
    return output_vec

In [None]:
passages_sparse = [dense_to_sparse(p, tokenizer.vocab_size) for p in passages]
links = hierarchy.linkage(np.vstack(passages_sparse), method="ward")
optimal_order = hierarchy.leaves_list(hierarchy.optimal_leaf_ordering(links, passages_sparse))
hierarchical_ordered_passages = np.array(passages)[optimal_order]
hier_dict = dict(
    zip(
        optimal_order,
        range(len(optimal_order))
    )
)