In [2]:
import os
import time
import faiss
import numpy as np
from pprint import pprint
# from faiss import write_index, read_index

In [3]:
def random_floats(size, low=0, high=1):
    """Return a list of `size` floats, each drawn independently from U(low, high)."""
    samples = []
    for _ in range(size):
        samples.append(np.random.uniform(low, high))
    return samples

def random_normal_vectors(num_embeds, dim, mean=0, std=1):
    """Return a (num_embeds, dim) float32 matrix with i.i.d. N(mean, std**2) entries."""
    shape = (num_embeds, dim)
    return np.random.normal(loc=mean, scale=std, size=shape).astype(np.float32)

def random_embeddings(num_embeds, dim):
    """Return a (num_embeds, dim) float32 matrix of uniform samples in [0, 1)."""
    uniform = np.random.random(size=(num_embeds, dim))
    return uniform.astype(np.float32)

def random_queries(num_queries, dim):
    """Return `num_queries` random query vectors of length `dim` (float32, uniform in [0, 1))."""
    return np.random.random_sample((num_queries, dim)).astype(np.float32)

def save_np_to_file(file_path, np_array):
    '''
    Save `np_array` in .npy format and report where it was written.

    Args:
        file_path: destination path (str or os.PathLike) or an open file object.
            For paths, ".npy" is appended when missing (mirroring np.save) and
            the parent directory is created if it does not exist.
        np_array: array to save.

    Returns:
        The path (or file object) the array was written to.
    '''
    if isinstance(file_path, os.PathLike):
        file_path = os.fspath(file_path)
    if isinstance(file_path, str):
        parent_dir = os.path.dirname(file_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        # np.save would add the suffix silently; add it here so the message
        # and return value name the file that actually exists on disk
        if not file_path.endswith(".npy"):
            file_path += ".npy"
    np.save(file_path, np_array)
    print(f"Saved to {file_path}")
    return file_path

def load_npy(npy_path):
    """Load and return the array stored in the .npy file at `npy_path`."""
    array = np.load(npy_path)
    return array

def create_ivf_index(npy_path, nlist=100):
    '''
    Build an IVF-Flat (L2) faiss index over the vectors stored in `npy_path`.

    The index is trained (k-means into `nlist` cells) on the same vectors it
    then stores, so the file must contain at least `nlist` vectors.

    Args:
        npy_path: path to a .npy file holding a 2-D (num_vectors, dim) array.
        nlist: number of inverted lists / coarse clusters (default 100).

    Returns:
        A trained faiss.IndexIVFFlat containing every row of the file.
        Nothing is written to disk here.
    '''
    # faiss operates on contiguous float32 buffers
    data = np.ascontiguousarray(load_npy(npy_path), dtype='float32')
    d = data.shape[1]
    # coarse quantizer that assigns each vector to one of the nlist cells
    quantizer = faiss.IndexFlatL2(d)
    index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)
    assert not index.is_trained
    index.train(data)
    assert index.is_trained
    # populate the index with the same vectors it was trained on
    index.add(data)
    return index

def create_flat_index(npy_path):
    '''
    Build an exact (brute-force) L2 faiss index over the vectors in `npy_path`.

    Args:
        npy_path: path to a .npy file holding a 2-D (num_vectors, dim) array.

    Returns:
        A faiss.IndexFlatL2 containing every row of the file.
    '''
    # Convert explicitly: callers may pass float64 data (e.g. an np.zeros-based
    # centroid file), and acceptance of non-float32 input varies across faiss
    # Python wrapper versions.
    data = np.ascontiguousarray(load_npy(npy_path), dtype='float32')
    index = faiss.IndexFlatL2(data.shape[1])
    index.add(data)
    return index

def save_index(index, index_path):
    '''
    Write a faiss index to `index_path`, creating the parent directory if needed.

    Args:
        index: any faiss index.
        index_path: destination file path.
    '''
    parent_dir = os.path.dirname(index_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    faiss.write_index(index, index_path)

def compute_embeds_avgs(embeds):
    """Return the per-dimension mean of `embeds` (averaged over axis 0, i.e. across rows)."""
    column_means = np.mean(embeds, axis=0)
    return column_means

In [5]:
# Generate `num_shards` synthetic embedding shards and record each shard's mean
# vector ("centroid"). Then build one IVF index per shard, plus a flat index over
# the centroids that can be used to route queries to shards.
SEED = 0
dim = 64
num_shards = 50
num_embeds_per_shard = 100000
shard_std = 0.5  # all shards share the same spread; only the mean varies
npy_root = "../shards/npys/"
index_root = "../shards/idxs/"

np.random.seed(SEED)  # make the generated shards reproducible
os.makedirs(npy_root, exist_ok=True)
os.makedirs(index_root, exist_ok=True)

# float32 to match the shard embeddings and what faiss expects
embeds_centroids = np.zeros((num_shards, dim), dtype='float32')
for i in range(num_shards):
    random_mean = random_floats(1)[0]
    embeds = random_normal_vectors(num_embeds_per_shard, dim, random_mean, shard_std)
    embeds_centroids[i] = compute_embeds_avgs(embeds)
    save_np_to_file(os.path.join(npy_root, f"embeds_{i}.npy"), embeds)

centroids_path = os.path.join(npy_root, "embeds_centroids.npy")
save_np_to_file(centroids_path, embeds_centroids)

# Build indexes from exactly the files written above. Iterating os.listdir would
# also pick up stale shards from an earlier run with a larger num_shards.
index = create_flat_index(centroids_path)
save_index(index, os.path.join(index_root, "embeds_centroids.index"))

for i in range(num_shards):
    npy_name = f"embeds_{i}.npy"
    print(npy_name)
    index = create_ivf_index(os.path.join(npy_root, npy_name))
    save_index(index, os.path.join(index_root, f"embeds_{i}.index"))

Saved to ../shards/npys/embeds_0.npy
Saved to ../shards/npys/embeds_1.npy
Saved to ../shards/npys/embeds_2.npy
Saved to ../shards/npys/embeds_3.npy
Saved to ../shards/npys/embeds_4.npy
Saved to ../shards/npys/embeds_5.npy
Saved to ../shards/npys/embeds_6.npy
Saved to ../shards/npys/embeds_7.npy
Saved to ../shards/npys/embeds_8.npy
Saved to ../shards/npys/embeds_9.npy
Saved to ../shards/npys/embeds_10.npy
Saved to ../shards/npys/embeds_11.npy
Saved to ../shards/npys/embeds_12.npy
Saved to ../shards/npys/embeds_13.npy
Saved to ../shards/npys/embeds_14.npy
Saved to ../shards/npys/embeds_15.npy
Saved to ../shards/npys/embeds_16.npy
Saved to ../shards/npys/embeds_17.npy
Saved to ../shards/npys/embeds_18.npy
Saved to ../shards/npys/embeds_19.npy
Saved to ../shards/npys/embeds_20.npy
Saved to ../shards/npys/embeds_21.npy
Saved to ../shards/npys/embeds_22.npy
Saved to ../shards/npys/embeds_23.npy
Saved to ../shards/npys/embeds_24.npy
Saved to ../shards/npys/embeds_25.npy
Saved to ../shards/npy

In [5]:
def load_index(index_path):
    """Read a faiss index from the file at `index_path` and return it."""
    loaded_index = faiss.read_index(index_path)
    return loaded_index

def query_index(index, queries, k):
    """
    Search `index` for the `k` nearest neighbours of each row of `queries`.

    Returns:
        (D, I): the distances and ids returned by `index.search`.
    """
    distances, ids = index.search(queries, k)
    return distances, ids

def query_index_file(index_path, queries, k):
    """Load the faiss index stored at `index_path` and search it; returns (D, I)."""
    return query_index(load_index(index_path), queries, k)

In [6]:
# Smoke test: load one shard index and run a 1-NN search on random queries.
# Build the path from index_root so this cell reads the directory the build cell wrote.
index = load_index(os.path.join(index_root, "embeds_1.index"))
queries = random_queries(100, dim)
# raise index.nprobe to probe more inverted lists (higher recall, slower search)
D, I = query_index(index, queries, 1)
assert D.shape == I.shape == (100, 1)

index.is_trained

True

In [7]:
# User-supplied query batch (`queries`, from the smoke-test cell) and k.
# Each visited shard returns its own top-`subk`. Because subk >= k, no shard's
# best k candidates are cut off before the cross-shard merge.
k = 5
num_file_visit = 15
subk = (k // num_file_visit) + k

# Only shard indexes are candidates. The centroid index lives in the same
# directory but holds centroids, not embeddings, so its ids and distances must
# not be mixed into the shard results. Sorting makes the (seeded) choice
# independent of directory listing order.
idx_paths = sorted(
    os.path.join(index_root, f)
    for f in os.listdir(index_root)
    if f.endswith(".index") and "centroids" not in f
)
# For now, pick the shards to visit at random (later: route via the centroid index).
idxs = list(np.random.choice(idx_paths, min(num_file_visit, len(idx_paths)), replace=False))

D_parts, I_parts, file_idx_parts = [], [], []
for i, idx_path in enumerate(idxs):
    start_time = time.perf_counter()
    D, I = query_index_file(idx_path, queries, subk)
    # includes the time to read the index from disk, not only the search
    query_idx_time = time.perf_counter() - start_time
    print(f"query {idx_path:30}: {query_idx_time:.8f}s")
    D_parts.append(D)
    I_parts.append(I)
    # same shape as D/I: position in `idxs` of the file each result came from
    file_idx_parts.append(np.full(D.shape, i, dtype=int))

# concatenate once instead of re-copying the growing arrays every iteration
D_concat = np.concatenate(D_parts, axis=1)
I_concat = np.concatenate(I_parts, axis=1)
file_idx_concat = np.concatenate(file_idx_parts, axis=1)

# sort each query's candidates by distance and reorder ids / file indices to match
sort_idx = np.argsort(D_concat, axis=1)
D_sorted = np.take_along_axis(D_concat, sort_idx, axis=1)
I_sorted = np.take_along_axis(I_concat, sort_idx, axis=1)
file_idx_sorted = np.take_along_axis(file_idx_concat, sort_idx, axis=1)

query shards/idxs/embeds_43.index   : 0.03076972s
query shards/idxs/embeds_23.index   : 0.02722288s
query shards/idxs/embeds_46.index   : 0.02610893s
query shards/idxs/embeds_20.index   : 0.02791506s
query shards/idxs/embeds_19.index   : 0.02585185s
query shards/idxs/embeds_41.index   : 0.02887433s
query shards/idxs/embeds_21.index   : 0.01999626s
query shards/idxs/embeds_24.index   : 0.01883139s
query shards/idxs/embeds_22.index   : 0.02385323s
query shards/idxs/embeds_17.index   : 0.02589414s
query shards/idxs/embeds_38.index   : 0.01358158s
query shards/idxs/embeds_45.index   : 0.02614252s
query shards/idxs/embeds_9.index    : 0.02599484s
query shards/idxs/embeds_8.index    : 0.02581422s
query shards/idxs/embeds_18.index   : 0.02590331s


In [None]:
def batch_search(idx_file_list, queries, subk, verbose=False, search_fn=None):
    '''
    Search every index file in `idx_file_list` with the same query batch and
    merge the per-file results. Each query's results are sorted by ascending distance.

    Args:
        idx_file_list: non-empty sequence of index file paths to visit.
        queries: query matrix passed unchanged to `search_fn`.
        subk: number of neighbours requested from each index file.
        verbose: if True, print the wall-clock time spent on each file. With the
            default search_fn this includes loading the index from disk.
        search_fn: callable (path, queries, k) -> (D, I). Defaults to
            query_index_file. This lets callers plug in e.g. a cached search.

    Returns:
        (D_sorted, I_sorted, file_idx_sorted), each with as many columns as the
        per-file results combined. file_idx_sorted[q, j] is the position in
        idx_file_list of the file that produced I_sorted[q, j].

    Raises:
        ValueError: if idx_file_list is empty.
    '''
    if len(idx_file_list) == 0:
        raise ValueError("idx_file_list must contain at least one index file")
    if search_fn is None:
        search_fn = query_index_file

    D_parts, I_parts, file_parts = [], [], []
    for file_pos, idx_path in enumerate(idx_file_list):
        start_time = time.perf_counter()
        D, I = search_fn(idx_path, queries, subk)
        elapsed = time.perf_counter() - start_time
        if verbose:
            print(f"query {idx_path:30}: {elapsed:.8f}s")
        D_parts.append(D)
        I_parts.append(I)
        file_parts.append(np.full(np.shape(D), file_pos, dtype=np.int64))

    D_concat = np.concatenate(D_parts, axis=1)
    I_concat = np.concatenate(I_parts, axis=1)
    file_idx_concat = np.concatenate(file_parts, axis=1)

    # stable sort: equal distances keep the order of idx_file_list
    sort_idx = np.argsort(D_concat, axis=1, kind="stable")
    return (
        np.take_along_axis(D_concat, sort_idx, axis=1),
        np.take_along_axis(I_concat, sort_idx, axis=1),
        np.take_along_axis(file_idx_concat, sort_idx, axis=1),
    )

In [8]:
# Keep only the k globally closest candidates per query across all visited shards.
# The *_sorted arrays are already ordered by ascending distance, so the first
# k columns are the answer.
top_k = k
top_k_dist, top_k_idx, top_k_file_idx = (
    arr[:, :top_k] for arr in (D_sorted, I_sorted, file_idx_sorted)
)

In [None]:
class IndexFileProfile:
    '''
    Plain record describing one action taken on an index file.

    Attributes:
        idx_path: path of the index file.
        location: where the file resides, "DRAM" or "CXL".
        runtime: how long the action took (units as supplied by the caller;
            presumably seconds, TODO confirm).
        action: "search", "promotion" or "demotion".
    '''
    # attribute names, in the order return_dict emits them
    _FIELDS = ("idx_path", "location", "runtime", "action")

    def __init__(self, idx_path, location, runtime, action):
        self.idx_path, self.location, self.runtime, self.action = (
            idx_path, location, runtime, action,
        )

    def return_dict(self):
        '''Return the profile as a new dict keyed by attribute name.'''
        return {field: getattr(self, field) for field in self._FIELDS}
    
class CXLLatencyFactorCalculator:
    '''
    Only DRAM time is measured, so CXL time is simulated by scaling it with a
    latency factor (CXL is expected to be slower than DRAM).
    Factors that could eventually determine it:
        - file size
        - CXL memory size
        - CXL bandwidth
        - CXL memory usage (under high usage the CXL controller is slower)

    For now the factor is a constant chosen at construction (default 2,
    i.e. CXL is 2 times slower than DRAM).
    '''
    def __init__(self, factor=2):
        '''
        Args:
            factor: multiplier applied to DRAM time to estimate CXL time.

        Raises:
            ValueError: if factor is not positive.
        '''
        if factor <= 0:
            raise ValueError(f"factor must be positive, got {factor}")
        self.factor = factor

    def calculate_factor(self):
        '''
        Return the CXL latency factor (simulated CXL time = DRAM time * factor).
        '''
        return self.factor

class IndexFileManager:
    '''
    Manage index files. The goal is to move index files between DRAM and CXL
    based on some metric such as query time (info could potentially come from
    read logs).

    Work in progress: currently only records the managed index paths.
    '''
    def __init__(self, idx_paths):
        '''
        Args:
            idx_paths: iterable of index file paths to manage.
        '''
        # private copy so later changes to the caller's list do not leak in
        self.idx_paths = list(idx_paths)

    def init_idx_location(self):
        '''
        Assign each managed index file its initial location ("DRAM" or "CXL").

        TODO: not implemented yet; currently does nothing and returns None.
        '''
        pass

In [None]:
'''
1. Query batch comes in
2. Search top centroids
3. Based on top centroids, find top index files to visit

*. Index files placement optimization
    Case 1: all index files are in DRAM. 
        - No NEED to move
    Case 2: all index files are in CXL
    Case 3: some index files are in DRAM, some are in CXL

4. Query top index files
5. Merge results and return

Case 2 and 3: Need to move index files between DRAM and CXL
    - Need to determine which index files to move
        - Needs to be visited frequently
    - Need to determine when to move
        - Can overlap with search


Traces
    - Time steps
    - Search profile


'''

In [15]:
# PriorityQueue always pops the smallest item first. Pushing negated keys turns
# it into a max-priority queue, so the largest value comes out first.
from queue import PriorityQueue

pq = PriorityQueue()
for value in range(10):
    pq.put((-value, value))

while pq.qsize() > 0:
    print(pq.get())

(-9, 9)
(-8, 8)
(-7, 7)
(-6, 6)
(-5, 5)
(-4, 4)
(-3, 3)
(-2, 2)
(-1, 1)
(0, 0)


In [11]:
'''
# Create a Priority Queue (-visit_count, idx)
# Bookkeeping for each (idx, location)
# Get one from the queue   
    - if location is DRAM, query
    - if location is CXL

search priority queue
    - (location, -visit_count, idx)

move priority queue (CXL to DRAM or DRAM to CXL)
    - (-visit_count, loc, idx)

'''

<queue.PriorityQueue at 0x7f58c1685880>