In [None]:
# dependency headers
import numpy as np
import struct
import os
import time
import random
import threading
from scipy.spatial import distance
from concurrent.futures import ThreadPoolExecutor

In [None]:
%pip install numpy



In [None]:
def distance_l2(a, b):
    """Return the squared Euclidean distance between ``a`` and ``b``.

    The square root is intentionally not taken; the value is the sum of the
    squared element-wise differences.
    """
    diff = a - b
    return np.sum(diff ** 2)

def distance_inner_product(a, b):
    """Return the dot product of ``a`` and ``b`` as computed by ``np.dot``."""
    product = np.dot(a, b)
    return product

In [None]:
CompactGraph = list[list[int]]
class IndexNSG:
    """Container for an NSG index: metric, adjacency lists and entry point."""

    def __init__(self, dimension, n, metric):
        """Create an empty index for ``n`` points.

        ``metric`` is either a callable taking two vectors, or one of the
        strings ``"L2"`` / ``"inner_product"``. Any other value raises
        ``ValueError``.
        """
        self.dimension = dimension
        self.n = n

        # Resolve the string shortcuts to their functions; callables pass through.
        if not callable(metric):
            if metric == "L2":
                metric = distance_l2
            elif metric == "inner_product":
                metric = distance_inner_product
            else:
                raise ValueError("Unsupported metric type")
        self.distance = metric

        self.data = None
        # One (initially empty) adjacency list and one lock per node.
        self.final_graph = [[] for _ in range(n)]
        self.locks = [threading.Lock() for _ in range(n)]
        self.has_built = False
        self.ep = 0

K-NN Loader

In [None]:
def load_data(filename, dtype=np.int32):
    """Load a file of fixed-width records into a 2-D array.

    Every record is a 4-byte little-endian header followed by ``dim`` 4-byte
    little-endian values. ``dim`` is taken from the first record's header, and
    the headers themselves are skipped in the result.

    Parameters
    ----------
    filename : str
        Path of the file to read.
    dtype : 4-byte numpy dtype, optional
        How the payload values are decoded. The default (``np.int32``) keeps
        the previous behaviour; pass ``np.float32`` for float payloads.

    Returns
    -------
    (data, num, dim) : (np.ndarray, int, int)
        ``data`` has shape ``(num, dim)`` and dtype ``dtype``; ``num`` is the
        number of complete records in the file.

    Raises
    ------
    ValueError
        If ``dtype`` is not 4 bytes wide.
    struct.error
        If the file is too short to contain the first header.
    """
    dtype = np.dtype(dtype)
    if dtype.itemsize != 4:
        raise ValueError("dtype must be a 4-byte type (e.g. np.int32 or np.float32)")

    with open(filename, 'rb') as file:
        dim_bytes = file.read(4)
        dim = struct.unpack('<I', dim_bytes)[0]
        file.seek(0, os.SEEK_END)
        fsize = file.tell()

    record_len = dim + 1  # header word + payload
    num = fsize // (record_len * 4)

    # One bulk read instead of a Python-level unpack per row. A trailing partial
    # record (if any) is ignored because only num * record_len words are read.
    little_endian = dtype.newbyteorder('<')
    raw = np.fromfile(filename, dtype=little_endian, count=num * record_len)
    records = raw.reshape(num, record_len)

    # Drop the header column and hand back an independent, native-order array.
    data = records[:, 1:].astype(dtype)

    return data, num, dim

In [None]:
# Path of the k-NN graph file to decode (absolute path under /content).
filename = "/content/sift.50NN.graph"
# load_data returns the decoded 2-D array together with its row count and the
# number of values per row.
data, num, dim = load_data(filename)
print("Data loaded:")
# Prints the whole array; numpy abbreviates large arrays with "...".
print(data)
print(f"Number of data points: {num}, Dimension of each point: {dim}")


Data loaded:
[[   2    6 4585 ... 3150 4560 4331]
 [   3   14    7 ... 4678 4640 8720]
 [   0    6    4 ... 8797 1331 4399]
 ...
 [9868 9841 7932 ... 8052 5159 9721]
 [9640 9592  505 ... 6114  298 3578]
 [9911 9927 3996 ...  857 9334 2182]]
Number of data points: 10000, Dimension of each point: 50


NSG Loader

In [None]:
# test_nsg_search invokes index_nsg : load
def Load(self, filename):
    """Read a saved NSG graph file into ``self`` and return its adjacency lists.

    File layout (native-endian unsigned 32-bit words): ``width``, entry point,
    then for every node a neighbour count ``k`` followed by ``k`` neighbour ids.

    Effects on ``self``:
      * ``self.width``       - first header word.
      * ``self.ep_``         - entry point (kept for existing callers).
      * ``self.ep``          - same entry point, under the attribute name that
                               ``search`` reads.
      * ``self.final_graph`` - replaced by the rows read from the file, so that
                               row ``i`` is node ``i``'s adjacency list even when
                               the index was created with pre-sized empty rows.

    Raises ``struct.error`` if the file ends in the middle of a header or row.
    """
    graph = []
    with open(filename, 'rb') as f:
        self.width = struct.unpack('I', f.read(4))[0]
        self.ep_ = struct.unpack('I', f.read(4))[0]
        self.ep = self.ep_

        while True:
            k_bytes = f.read(4)
            if not k_bytes:
                # Clean end of file: no further rows.
                break
            k = struct.unpack('I', k_bytes)[0]

            tmp_bytes = f.read(k * 4)
            tmp = struct.unpack(f'{k}I', tmp_bytes)

            graph.append(list(tmp))

    # Publish only after the whole file parsed successfully.
    self.final_graph = graph
    return self.final_graph

In [None]:
filename = "/content/siftsmall0.nsg"
# IndexNSG.__init__ requires (dimension, n, metric); calling it with no
# arguments raises TypeError. n=0 starts from an empty adjacency list, so the
# row count reported below is exactly the number of rows read from the file.
# The dimension value is not read by Load.
self = IndexNSG(dimension=0, n=0, metric="L2")
final_graph = Load(self, filename)
print("Rows:", len(final_graph))
print("Columns:", len(final_graph[0]))


Rows: 10000
Columns: 12


Parameters Class

In [None]:
class Parameters:
    """Minimal name -> value store for algorithm settings."""

    def __init__(self):
        # Backing dictionary; exposed as ``params``.
        self.params = dict()

    def set(self, name, value):
        """Store ``value`` under ``name``, replacing any previous value."""
        self.params.update({name: value})

    def get(self, name, default=None):
        """Return the value stored under ``name``, or ``default`` if absent."""
        try:
            return self.params[name]
        except KeyError:
            return default

Neighbor Class

In [None]:
class Neighbor:
    """A candidate node: its id, its distance to the query, and a flag."""

    def __init__(self, id=0, distance=0, flag=False):
        self.id, self.distance, self.flag = id, distance, flag

    def __lt__(self, other):
        """Order neighbours by ascending distance (used by sort/sorted)."""
        mine = self.distance
        theirs = other.distance
        return mine < theirs

def insert_into_pool(addr, K, nn):
    """Insert ``nn`` into the distance-sorted pool ``addr`` of ``K`` entries.

    Entries at and after the insertion point are shifted one slot to the
    right (up to slot ``K``). Returns the index at which ``nn`` was stored,
    or ``K + 1`` when an entry with the same id is found at an equal-distance
    position and nothing is inserted.
    """
    lo = 0
    hi = K - 1
    target = nn.distance

    # Closer than the current head: shift the first K entries and put nn first.
    if addr[lo].distance > target:
        addr[1 : K + 1] = addr[0:K]
        addr[0] = nn
        return 0

    # Farther than the current tail: store in the spare slot K.
    if addr[hi].distance < target:
        addr[K] = nn
        return K

    # Binary search until lo and hi are adjacent.
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if addr[mid].distance > target:
            hi = mid
        else:
            lo = mid

    # Walk left over entries that are not strictly closer, looking for the same id.
    while lo > 0 and not (addr[lo].distance < target):
        if addr[lo].id == nn.id:
            return K + 1
        lo -= 1

    duplicate_at_lo = addr[lo].id == nn.id
    if duplicate_at_lo or addr[hi].id == nn.id:
        return K + 1

    # Make room at hi and insert.
    addr[hi + 1 : K + 1] = addr[hi:K]
    addr[hi] = nn
    return hi

Search Function

In [None]:
def search(self, query: np.ndarray, x: np.ndarray, K: int, parameters: Parameters):
    """Greedy best-first search over ``self.final_graph`` for the K nearest points.

    Parameters
    ----------
    self : index object
        Must provide ``n``, ``ep``, ``final_graph`` and ``distance``.
    query : vector
        The query point; passed to ``self.distance`` together with rows of ``x``.
    x : indexable of vectors
        Base data; ``x[i]`` is the vector of node ``i``.
    K : int
        Number of result ids to return.
    parameters : Parameters
        Must contain ``"L_search"``, the size of the candidate pool.

    Returns
    -------
    np.ndarray of int, length K
        Ids of the K closest candidates found, nearest first.

    Raises
    ------
    ValueError
        If ``K > L_search`` (not enough candidates to return) or
        ``L_search > self.n`` (the pool could never be filled with distinct ids).
    """
    L = int(parameters.get("L_search"))
    if K > L:
        raise ValueError("K cannot be larger than L_search")
    if L > self.n:
        raise ValueError("L_search cannot be larger than the number of indexed points")

    data = x
    indices = np.empty(K, int)
    # Slot L is spare room used by insert_into_pool when shifting entries.
    retset = np.empty(L + 1, dtype=object)
    init_ids = np.zeros(L, int)
    # Visited markers must start out False; np.empty would leave them undefined.
    flags = np.zeros(self.n, dtype=bool)

    # Seed the pool with the entry point's neighbours ...
    entry_neighbors = self.final_graph[self.ep]
    tmp_l = 0
    while tmp_l < L and tmp_l < len(entry_neighbors):
        init_ids[tmp_l] = entry_neighbors[tmp_l]
        flags[init_ids[tmp_l]] = True
        tmp_l += 1

    # ... and top it up with distinct random nodes.
    while tmp_l < L:
        candidate = random.randint(0, self.n - 1)
        if flags[candidate]:
            continue
        flags[candidate] = True
        init_ids[tmp_l] = candidate
        tmp_l += 1

    for i in range(len(init_ids)):
        node_id = init_ids[i]
        dist = self.distance(data[node_id], query)
        retset[i] = Neighbor(node_id, dist, True)

    retset[0:L] = sorted(retset[0:L])

    k = 0
    while k < L:
        # Smallest pool position touched while expanding retset[k].
        nk = L

        if retset[k].flag:
            retset[k].flag = False
            n = retset[k].id

            for node_id in self.final_graph[n]:
                if flags[node_id]:
                    continue
                flags[node_id] = True
                dist = self.distance(query, data[node_id])
                if dist >= retset[L - 1].distance:
                    continue
                nn = Neighbor(node_id, dist, True)
                r = insert_into_pool(retset, L, nn)

                if r < nk:
                    nk = r

        # Resume from the earliest inserted candidate, otherwise move on.
        if nk <= k:
            k = nk
        else:
            k += 1

    for i in range(K):
        indices[i] = retset[i].id

    return indices

Test_NSG_Search

In [None]:
def test_nsg_search_main(args):
  if(len(args) != 5):
        print("data_file query_file nsg_path search_L search_K")
        exit(-1)

  filename0 = args[0]
  data_load, points_num, dim = load_data(filename)

  filename1 = args[1]
  query_load, query_num, query_dim = load_data(filename)
  assert dim == query_dim

  L = int(args[3])
  K = int(args[4])
  if( L < K):
      print("search_L cannot be smaller than search_K")
      exit(-1)

  index = IndexNSG(dim,points_num,"L2")
  Load(self,args[2])

  paras = Parameters()
  paras.set("L_search", L)
  paras.set("P_search", L)

  start = time.process_time()
  res = []
  for i in range(0,query_num):
      tmp = search(index,query_load[(i*dim):],data_load,K,paras)
      res.append(tmp)
  end = time.process_time()
  diff = end - start
  print("Search Time: ", diff)

In [None]:
# Arguments, in order: data file, query file, NSG graph file, search_L, search_K.
args = [
    "/content/siftsmall_base.fvecs",
    "/content/siftsmall_query.fvecs",
    "/content/siftsmall0.nsg",
    50,
    40,
]
test_nsg_search_main(args)
print(args)

Search Time:  168.69698365000002
['/content/siftsmall_base.fvecs', '/content/siftsmall_query.fvecs', '/content/siftsmall0.nsg', 50, 40]


NSG Build

In [None]:
class IndexNSG:
    """NSG index: builds a pruned graph from a k-NN graph, repairs connectivity
    and writes the result to disk.

    Attributes set in ``__init__``: ``dimension``, ``n``, ``distance``, ``data``,
    ``final_graph`` (one adjacency list per node), ``locks``, ``has_built``,
    ``ep`` (entry point id) and ``width`` (largest out-degree seen so far).
    """

    # Out-degree bound used when the parameter set has no 'R' entry.
    DEFAULT_R = 12

    def __init__(self, dimension, n, metric):
        """Create an empty index for ``n`` points.

        ``metric`` is a callable, ``"L2"`` or ``"inner_product"``; anything
        else raises ``ValueError``.
        """
        self.dimension = dimension
        self.n = n
        if callable(metric):
            self.distance = metric
        elif metric == "L2":
            self.distance = distance_l2
        elif metric == "inner_product":
            self.distance = distance_inner_product
        else:
            raise ValueError("Unsupported metric type")
        self.data = None
        self.final_graph = [[] for _ in range(n)]
        self.locks = [threading.Lock() for _ in range(n)]
        self.has_built = False
        self.ep = 0
        # tree_grow and save read width, so it must exist before any pruning ran.
        self.width = 0

    def build(self, n, data, parameters):
        """Build the graph for ``n`` points stored in ``data``.

        Reads ``'nn_graph_path'`` (falling back to '/content/sift.50NN.graph')
        and ``'R'`` (falling back to ``DEFAULT_R``) from ``parameters``; the
        helpers it calls additionally read ``'L'`` and ``'C'``. Prints degree
        statistics and sets ``has_built``.
        """
        nn_graph_path = parameters.get('nn_graph_path', '/content/sift.50NN.graph')
        # Must be the same R that link/sync_prune use to address cut_graph rows.
        range_r = parameters.get('R', self.DEFAULT_R)
        self.load_nn_graph(nn_graph_path)
        self.data = data
        self.init_graph(parameters)
        # Flat table with range_r slots per node.
        cut_graph = [Neighbor() for _ in range(n * range_r)]
        self.link(parameters, cut_graph)
        self.final_graph = [[] for _ in range(n)]

        for i in range(n):
            row = cut_graph[i * range_r:(i + 1) * range_r]
            kept = []
            for neighbor in row:
                # distance == -1 terminates the used part of a row.
                if neighbor.distance == -1:
                    break
                kept.append(neighbor.id)
            self.final_graph[i] = kept

        self.tree_grow(parameters)

        max_degree = max(len(g) for g in self.final_graph)
        min_degree = min(len(g) for g in self.final_graph)
        avg_degree = sum(len(g) for g in self.final_graph) / n
        print(f"Degree Statistics: Max = {max_degree}, Min = {min_degree}, Avg = {avg_degree:.2f}")

        self.has_built = True

    def load_nn_graph(self, filename):
        """Replace ``final_graph`` with the k-NN lists stored in ``filename``.

        The file is read as unsigned 32-bit words; the first word is taken as
        ``k`` and every record is ``k + 1`` words, of which the first is dropped.
        """
        with open(filename, 'rb') as f:
            graph_data = np.fromfile(f, dtype=np.uint32)
            k = int(graph_data[0])
            self.final_graph = graph_data.reshape(-1, k + 1)[:, 1:].tolist()

    def init_graph(self, parameters):
        """Pick the entry point: the node found closest to the data centroid."""
        center = np.mean(self.data, axis=0)
        self.ep = random.randint(0, self.n - 1)
        retset, _ = self.get_neighbors(center, parameters)
        self.ep = retset[0].id

    def link(self, parameters, cut_graph):
        """Fill ``cut_graph``: prune every node's candidates, then add reverse edges.

        Both passes run on a thread pool; a per-node lock list is created for
        the reverse-edge pass.
        """
        range_r = parameters.get('R', self.DEFAULT_R)
        # n // 100 is 0 for small inputs, which ThreadPoolExecutor rejects.
        worker_count = max(1, self.n // 100)
        locks = [threading.Lock() for _ in range(self.n)]

        def process_node(n):
            point = self.data[n]
            flags = np.zeros(self.n, dtype=bool)
            _, pool = self.get_neighbors(point, parameters, flags)
            self.sync_prune(n, pool, parameters, flags, cut_graph)

        # First pass: candidate search + pruning per node.
        with ThreadPoolExecutor(max_workers=worker_count) as executor:
            list(executor.map(process_node, range(self.n)))

        def call_interinsert(n):
            self.inter_insert(n, range_r, locks, cut_graph)

        # Second pass: insert reverse edges.
        with ThreadPoolExecutor(max_workers=worker_count) as executor:
            list(executor.map(call_interinsert, range(self.n)))

    def tree_grow(self, parameter):
        """Link unreachable nodes until a DFS from the entry point covers all nodes.

        Also raises ``width`` to the largest out-degree in ``final_graph``.
        """
        root = self.ep
        flags = [False] * self.n
        unlinked_cnt = 0
        while unlinked_cnt < self.n:
            flags, unlinked_cnt = self.DFS(root, flags, unlinked_cnt)
            if unlinked_cnt >= self.n:
                break
            # Continue the traversal from the node that received the new edge;
            # restarting from the old root would never reach the linked node.
            root = self.findroot(flags, root, parameter)

        for i in range(self.n):
            if len(self.final_graph[i]) > self.width:
                self.width = len(self.final_graph[i])

    def findroot(self, flag, root, parameter):
        """Attach the first unvisited node to an already visited node.

        The visited node is the nearest visited candidate returned by
        ``get_neighbors``, or a random visited node if none qualifies.
        Returns the id of the node that received the new edge, or ``root``
        unchanged when every node is already visited.
        """
        id = self.n
        for i in range(self.n):
            if not flag[i]:
                id = i
                break

        if id == self.n:
            return root  # No unlinked node

        tmp, pool = self.get_neighbors(self.data[id], parameter)
        pool.sort()

        found = False
        for neighbor in pool:
            if flag[neighbor.id]:
                root = neighbor.id
                found = True
                break

        if not found:
            while True:
                rid = random.randint(0, self.n - 1)
                if flag[rid]:
                    root = rid
                    break

        self.final_graph[root].append(id)
        return root

    def DFS(self, root, flag, cnt):
        """Iterative depth-first walk from ``root`` over unvisited nodes.

        Marks every node it reaches in ``flag`` and returns ``(flag, cnt)``
        where ``cnt`` was increased by the number of newly visited nodes.
        """
        tmp = root
        stack = [root]
        if not flag[root]:
            cnt += 1
        flag[root] = True
        while stack:
            next_node = None
            for neighbor in self.final_graph[tmp]:
                if not flag[neighbor]:
                    next_node = neighbor
                    break
            if next_node is None:
                # Dead end: backtrack.
                stack.pop()
                if not stack:
                    break
                tmp = stack[-1]
                continue
            tmp = next_node
            flag[tmp] = True
            stack.append(tmp)
            cnt += 1
        return flag, cnt

    def sync_prune(self, q, pool, parameter, flags, cut_graph):
        """Select up to R neighbours for node ``q`` and store them in its row.

        ``pool`` is extended with q's current neighbours that are not marked in
        ``flags``, sorted by distance, and filtered: a candidate is dropped
        when an already selected neighbour is closer to it than q is. At most
        ``'C'`` candidates are examined. Unused rows are terminated with a
        distance of -1.
        """
        range_r = parameter.get('R', self.DEFAULT_R)
        maxc = parameter.get('C')
        self.width = range_r
        start = 0

        # Add current graph neighbours the search did not visit.
        for nn in self.final_graph[q]:
            if flags[nn]:
                continue
            dist = distance.euclidean(self.data[q], self.data[nn])
            pool.append(Neighbor(nn, dist))

        pool.sort()

        # Skip q itself if it is the closest candidate.
        if pool and pool[start].id == q:
            start += 1
        # The pool may contain nothing but q; then there is nothing to select.
        result = [pool[start]] if start < len(pool) else []

        while len(result) < range_r and start < len(pool) and start < maxc:
            p = pool[start]
            occlude = False
            for t in result:
                if p.id == t.id:
                    occlude = True
                    break
                djk = distance.euclidean(self.data[t.id], self.data[p.id])
                if djk < p.distance:
                    occlude = True
                    break
            if not occlude:
                result.append(p)
            start += 1

        for t, neighbor in enumerate(result):
            cut_graph[q * range_r + t] = Neighbor(id=neighbor.id, distance=neighbor.distance)

        # Terminate a partially filled row.
        if len(result) < range_r:
            cut_graph[q * range_r + len(result)].distance = -1

    def inter_insert(self, n, range_r, locks, cut_graph):
        """Try to add ``n`` as a neighbour of each node in n's own row.

        A full destination row is re-pruned with the same occlusion rule as
        ``sync_prune``; otherwise ``n`` is written to the first free slot.
        Reads and writes of a destination row are done under that row's lock.
        """
        for i in range(range_r):
            if cut_graph[n * range_r + i].distance == -1:
                break

            sn = Neighbor(n, cut_graph[n * range_r + i].distance)
            des = cut_graph[n * range_r + i].id

            temp_pool = []
            dup = False
            lock = locks[des]
            with lock:
                for j in range(range_r):
                    if cut_graph[des * range_r + j].distance == -1:
                        break
                    if n == cut_graph[des * range_r + j].id:
                        dup = True
                        break
                    temp_pool.append(cut_graph[des * range_r + j])

            if dup:
                continue

            temp_pool.append(sn)
            if len(temp_pool) > range_r:
                start = 0
                temp_pool.sort()
                result = [temp_pool[start]]
                while len(result) < range_r and start + 1 < len(temp_pool):
                    start += 1
                    p = temp_pool[start]
                    occlude = False
                    for t in result:
                        if p.id == t.id:
                            occlude = True
                            break
                        djk = distance.euclidean(self.data[t.id], self.data[p.id])
                        if djk < p.distance:
                            occlude = True
                            break
                    if not occlude:
                        result.append(p)

                with lock:
                    for slot, kept in enumerate(result):
                        cut_graph[des * range_r + slot] = kept
            else:
                with lock:
                    for t in range(range_r):
                        if cut_graph[t + des * range_r].distance == -1:
                            cut_graph[t + des * range_r] = sn
                            if t + 1 < range_r:
                                cut_graph[t + 1 + des * range_r].distance = -1
                            break

    def save(self, filename):
        """Write ``width``, ``ep`` and every adjacency list as unsigned 32-bit words."""
        assert len(self.final_graph) == self.n, "Final graph size does not match the expected size"
        with open(filename, 'wb') as f:
            f.write(struct.pack('I', self.width))
            f.write(struct.pack('I', self.ep))
            for neighbors in self.final_graph:
                GK = len(neighbors)
                f.write(struct.pack('I', GK))
                if GK > 0:
                    f.write(struct.pack(f'{GK}I', *neighbors))

    def get_neighbors(self, query, parameters, flags=None):
        """Greedy search over ``final_graph`` starting from the entry point.

        Parameters
        ----------
        query : vector compared against rows of ``self.data``.
        parameters : provides ``'L'`` (pool size, default 10).
        flags : optional boolean numpy array of length ``n``; visited nodes are
            marked in place. A fresh array is allocated per call when omitted,
            so separate calls never share visited state.

        Returns
        -------
        (retset, fullset)
            ``retset`` is the sorted pool of at most L nearest candidates;
            ``fullset`` lists every node whose distance was computed during
            expansion.
        """
        L = int(parameters.get("L", 10))
        if flags is None:
            flags = np.zeros(self.n, dtype=bool)
        retset = []
        fullset = []

        # Seed with the entry point's neighbours, then random unvisited nodes.
        init_ids = list(self.final_graph[self.ep][:L])
        flags[init_ids] = True
        missing = L - len(init_ids)
        if missing > 0:
            unvisited = np.flatnonzero(~flags)
            additional_ids = np.random.choice(unvisited, size=missing, replace=False)
            init_ids.extend(additional_ids)
            flags[additional_ids] = True

        for node_id in init_ids:
            if node_id < self.n:
                dist = distance.euclidean(self.data[node_id], query)
                retset.append(Neighbor(node_id, dist, True))

        retset.sort()

        k = 0
        while k < len(retset):
            nk = len(retset)
            if retset[k].flag:
                current_id = retset[k].id
                retset[k].flag = False
                for neighbor_id in self.final_graph[current_id]:
                    if flags[neighbor_id]:
                        continue
                    flags[neighbor_id] = True
                    dist = distance.euclidean(self.data[neighbor_id], query)
                    new_neighbor = Neighbor(neighbor_id, dist, True)
                    fullset.append(new_neighbor)
                    if dist < retset[-1].distance:
                        r = insert_into_pool(retset, L, new_neighbor)
                        if r < nk:
                            nk = r
                        # Inserting grows the list by one; keep it at L entries.
                        if len(retset) > L:
                            retset.pop()
            # `<=`: a candidate inserted exactly at k is still unexpanded and
            # must be processed next instead of being stepped over.
            if nk <= k:
                k = nk
            else:
                k += 1

        return retset, fullset


In [None]:
def main(args):
    """Build an NSG index from a data file and a k-NN graph, then save it.

    ``args`` is ``[data_file, nn_graph_path, L, R, C, save_graph_file]``.
    With any other number of arguments the usage line is printed and the
    function returns without doing anything.
    """
    # Six arguments are consumed below, matching the usage line.
    if len(args) != 6:
        print("Usage: data_file nn_graph_path L R C save_graph_file")
        return

    data_file = args[0]
    nn_graph_path = args[1]
    L = int(args[2])
    R = int(args[3])
    C = int(args[4])
    save_graph_file = args[5]

    data_load, points_num, dim = load_data(data_file)
    # load_data, as called here, hands back the raw 32-bit words as int32. The
    # data file is expected to be a .fvecs file (float32 payload), so
    # reinterpret the bits instead of converting the values.
    # NOTE(review): confirm the input really holds float32 vectors.
    data_load = data_load.view(np.float32)
    index = IndexNSG(dim, points_num, 'L2')

    paras = Parameters()
    paras.set('L', L)
    paras.set('R', R)
    paras.set('C', C)
    paras.set('nn_graph_path', nn_graph_path)

    # Announce before the (long) build starts rather than after it finished.
    print(f"Building index with parameters L={L}, R={R}, C={C}, using graph {nn_graph_path}")
    start_time = time.time()
    index.build(points_num, data_load, paras)
    end_time = time.time()

    indexing_time = end_time - start_time
    print(f"indexing time: {indexing_time}")
    index.save(save_graph_file)


In [None]:
# Arguments, in order: data file, k-NN graph file, L, R, C, output graph file.
args = [
    "/content/siftsmall_base.fvecs",
    "/content/sift.50NN.graph",
    10,
    12,
    125,
    "/content/siftsmall0.nsg",
]
main(args)
print(args)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
run 5005
runrun 5007
run 5008
run 5009
run 5010
run 5011
runrun 5013
run 5014
run 5015
run  5012
run 5017
run 5018
 4993
run 5019
run 5020
run 5021
run 5022
run 5023
run 5024
run 5025
runrun 5027
run 5028
run 5029
run 5006
5016
 5026
run 5031
run 5032
run 5033
run 5034
run 5035
runrun 5037
run 5038
run 5039
runrun 5041
run 5042
run 5043
 5030
run 5044
run 5045
runrun 5047
run 5048
 5040
run 5049
 5036
run 5050
run 5051
run 5052
run 5053
run 5054run 5055
run 5056
run 5057
run 5058run 5059
runrun 5061
run 5062
run 5063
run 5064
run 5065


 5060
 5046run 5066
run
run 5068
run  5069
5067run 5070
run 5071
run 5072

run 5073
run run 5075
5074
run 5076
run 5077
runrun 5079
 5078
run 5080
run 5081
run 5082
run 5083
run 5084
run 5085
run 5086
run 5087
run 5088
run 5089
run 5090
run 5091
run 5092
run 5093
run 5094
run run 5096
run 5097
5095run 5098
run 5099
runrun 5101
 run 5102

5100
run 5103
run 5104
run 5105
run 5106
run 5107
ru