In [5]:
EMBEDDING_LEN: int = 2_048  # length of every vector stored in the 'embedding' field

In [62]:
from pymilvus import CollectionSchema, FieldSchema, DataType, MilvusClient, DataType
import numpy as np
import random
import string
from pymilvus import (
    connections, utility, DataType, FieldSchema, CollectionSchema, Collection, Partition
)
from time import time, sleep
from pymilvus import CollectionSchema, FieldSchema, DataType, MilvusClient, DataType
from tqdm import tqdm

In [63]:
# Milvus endpoint used by the MilvusClient-based cells below.
MILVUS_URI = "http://140.112.28.129:19530"
client = MilvusClient(uri=MILVUS_URI, db_name="default")

In [7]:
# Schema: the primary key is the caller-supplied `filename` string (auto_id is
# off) and each row carries one EMBEDDING_LEN-dimensional float vector.
collection_name = "collection"
schema = CollectionSchema(fields=[
    FieldSchema(name='filename', dtype=DataType.VARCHAR, is_primary=True, max_length=128),
    FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, dim=EMBEDDING_LEN),
], auto_id=False)

# IVF_FLAT with nlist=1 places every vector in a single cluster, so a search
# effectively scans all vectors of the searched partitions.
index_params = client.prepare_index_params()
index_params.add_index(
    field_name="embedding",
    index_type="IVF_FLAT",
    metric_type="COSINE",
    params={"nlist": 1},
)

# Destructive: recreate the collection from scratch. Drop via the
# `collection_name` variable (not a second hard-coded literal) so that
# renaming the collection above drops the matching one.
client.drop_collection(collection_name=collection_name)
client.create_collection(
    collection_name=collection_name,
    schema=schema,
    index_params=index_params,
)

In [8]:
def insert_random_data(partition_name, num, collection_name = "collection", batch_size=1000):
    """Insert ``num`` rows of random test data into one partition.

    Each row has a random 20-letter ``filename`` (the primary key) and an
    ``embedding`` of ``EMBEDDING_LEN`` values from ``np.random.rand``
    (uniform on [0, 1)).

    Args:
        partition_name: Partition to insert into (this notebook creates it
            right before calling this function).
        num: Number of rows to insert; ``num <= 0`` inserts nothing.
        collection_name: Target collection.
        batch_size: Maximum rows per ``client.insert`` request. Sending rows
            in batches cuts the number of round trips from ``num`` (one per
            row) to ``ceil(num / batch_size)``.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    for start in range(0, num, batch_size):
        rows = [
            {
                "filename": ''.join(random.choice(string.ascii_letters) for _ in range(20)),
                "embedding": np.random.rand(EMBEDDING_LEN),
            }
            for _ in range(min(batch_size, num - start))
        ]
        client.insert(
            collection_name=collection_name,
            data=rows,
            partition_name=partition_name,
        )

In [9]:
# Dataset layout: `partition_num` partitions named partition_<i>, each
# filled with `partition_size` random rows.
partition_num, partition_size = 50, 50

for i in range(partition_num):
    partition_name = f"partition_{i}"
    client.create_partition(partition_name=partition_name, collection_name=collection_name)
    insert_random_data(partition_name=partition_name, num=partition_size)
    

In [64]:
# ORM connection for the `Partition`-based helpers below (same server as `client`).
connections.connect(host='140.112.28.129', port='19530', db_name='default')

def load_partion(partition_name, collection_name = "collection"):
    """Load ``partition_name`` of ``collection_name`` through the pymilvus ORM.

    Prints ``load <partition_name>`` before issuing the request.

    NOTE(review): pymilvus' ``Partition(...)`` constructor appears to create
    the partition when it does not exist yet — confirm, since a bad name
    would then be created silently instead of failing.
    """
    print(f"load {partition_name}")
    Partition(collection=collection_name, name=partition_name).load()
    
def release_partion(partition_name, collection_name = "collection"):
    """Release ``partition_name`` of ``collection_name`` through the pymilvus ORM.

    Prints ``release <partition_name>`` before issuing the request.
    """
    print(f"release {partition_name}")
    Partition(collection=collection_name, name=partition_name).release()

def release_all_partition(partition_num = partition_num, collection_name = collection_name):
    """Release partitions ``partition_0`` .. ``partition_<partition_num - 1>``.

    Args:
        partition_num: Number of ``partition_<i>`` partitions to release. The
            default is bound when this cell runs (the global ``partition_num``).
        collection_name: Collection owning the partitions. Also bound when this
            cell runs (the global ``collection_name``). It is exposed as a
            parameter for consistency with ``load_partion``/``release_partion``
            instead of being read implicitly from the notebook globals.
    """
    for i in range(partition_num):
        Partition(collection=collection_name, name=f"partition_{i}").release()
    print("all partition released")

In [77]:
# Cache: LRU bookkeeping of which partitions are loaded

class PartitionNode:
    """A doubly-linked-list node that carries one partition name."""

    def __init__(self, partition_name: str):
        # Neighbour links stay None until the node is linked into a list.
        self.prev = self.next = None
        self.partition_name = partition_name

class DLL:
    """Doubly linked list of ``PartitionNode`` objects with sentinel ends.

    ``appendleft`` inserts next to ``dummy_start``; ``back``/``peek``/``pop``
    operate on the node next to ``dummy_end``. The two sentinels never hold
    a real partition.
    """

    def __init__(self):
        self.dummy_start = PartitionNode("start")
        self.dummy_end = PartitionNode("end")
        self.dummy_start.next = self.dummy_end
        self.dummy_end.prev = self.dummy_start

    def is_empty(self) -> bool:
        """Return True when the list holds only the two sentinels."""
        return self.dummy_start.next is self.dummy_end

    def appendleft(self, node: "PartitionNode") -> None:
        """Insert ``node`` directly after the start sentinel."""
        left, right = self.dummy_start, self.dummy_start.next
        node.prev, node.next = left, right
        left.next = node
        right.prev = node

    def remove(self, node: "PartitionNode") -> None:
        """Unlink ``node``, which must currently be linked into this list."""
        left, right = node.prev, node.next
        left.next = right
        right.prev = left
        # Clear stale links so a detached node cannot still reach the list.
        node.prev = node.next = None

    def move_to_start(self, node: "PartitionNode") -> None:
        """Move an already-linked ``node`` to the start of the list."""
        self.remove(node)
        self.appendleft(node)

    def pop(self) -> "PartitionNode":
        """Remove and return the node at the end of the list.

        Raises:
            IndexError: if the list is empty (instead of the AttributeError the
                sentinel walk used to produce).
        """
        if self.is_empty():
            raise IndexError("pop from empty DLL")
        node = self.dummy_end.prev
        self.remove(node)
        return node

    def back(self) -> str:
        """Return the partition name of the node at the end of the list.

        Raises:
            IndexError: if the list is empty, rather than returning the
                sentinel's name ``"start"`` as if it were a partition.
        """
        if self.is_empty():
            raise IndexError("back from empty DLL")
        return self.dummy_end.prev.partition_name

    def peek(self) -> str:
        """Alias of ``back``."""
        return self.back()

class LRUCache:
    """Least-recently-used cache of loaded Milvus partitions.

    ``put`` records one access to a partition:

    * hit  - the partition is already tracked; it is only marked most recent.
    * miss - the partition is loaded with ``load_partion`` and tracked; if more
      than ``capacity`` partitions are then tracked, the least-recently-used
      one is dropped and released with ``release_partion``.

    ``hit_rate`` reports hits / (hits + misses) over all ``put`` calls.
    """

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.cache = {}     # partition name -> its PartitionNode in self.dll
        self.dll = DLL()    # recency order: most recent at the start, LRU at the end
        self.__hit = 0
        self.__miss = 0

    def put(self, partition_name: str) -> None:
        """Record an access to ``partition_name``, loading/evicting as needed."""
        node = self.cache.get(partition_name)
        if node is not None:
            # Hit: already loaded, just mark it most recently used.
            self.__hit += 1
            self.dll.move_to_start(node)
            return

        # Miss: load first, so a failing load leaves the bookkeeping untouched.
        self.__miss += 1
        load_partion(partition_name)
        node = PartitionNode(partition_name)
        self.cache[partition_name] = node
        self.dll.appendleft(node)

        if len(self.cache) > self.capacity:
            # Evict the least-recently-used entry and release *that* partition,
            # not the one that was just loaded.
            evicted_name = self.dll.back()
            self.dll.pop()
            del self.cache[evicted_name]
            release_partion(evicted_name)

    def hit_rate(self) -> float:
        """Fraction of ``put`` calls that were hits; 0.0 before any access."""
        total = self.__hit + self.__miss
        return self.__hit / total if total else 0.0

In [78]:
CACHE_CAPACITY: int = 10  # partitions LRUCache tracks before it evicts the least recent one

In [84]:
# Benchmark: replay a random partition-access pattern through the LRU cache.
RANDOM_SEED = 0  # fixed seed so runs with different settings reuse the same base accesses
random.seed(RANDOM_SEED)

cache = LRUCache(CACHE_CAPACITY)
num_of_access = 100
click_more_percentage = 0.8  # probability that an access repeats the previous one

# randrange(partition_num) draws 0 .. partition_num - 1, i.e. only partitions
# created above. randint(0, 50) is inclusive of 50 and could request the
# non-existent "partition_50".
partition_list = [random.randrange(partition_num) for _ in range(num_of_access)]

# Start with nothing loaded on the server, then warm the fresh cache.
# NOTE: every warm-up access is a miss (distinct names, empty cache), and the
# warm-up is included in the hit rate printed below, whose denominator is
# warm-up accesses + num_of_access.
release_all_partition()
for i in tqdm(range(min(CACHE_CAPACITY, partition_num))):
    cache.put(partition_name=f"partition_{i}")

t = time()
for i, partition_idx in tqdm(enumerate(partition_list), total=len(partition_list)):
    # With probability click_more_percentage, repeat the previous access. The
    # list entry is rewritten too, so a repeat can chain into the next step.
    if random.random() < click_more_percentage and i >= 1:
        partition_idx = partition_list[i - 1]
        partition_list[i] = partition_list[i - 1]
    cache.put(partition_name=f"partition_{partition_idx}")
print(time() - t)

print(f"hit rate = {cache.hit_rate()}")

# Baseline without a cache: load and immediately release every accessed partition.
# release_all_partition()
# t = time()
# for partition_idx in tqdm(partition_list):
#     load_partion(f"partition_{partition_idx}")
#     release_partion(f"partition_{partition_idx}")
# print(time() - t)

all partition released


100%|██████████| 10/10 [00:09<00:00,  1.01it/s]
100%|██████████| 100/100 [00:11<00:00,  8.86it/s]

11.296098709106445
hit rate = 0.7727272727272727





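Recorded results (CACHE_CAPACITY = 10, num_of_access = 100; times in seconds; "percentage" is click_more_percentage).
Hit rates include the 10 warm-up accesses, i.e. they are hits / 110.

w/o cache (baseline: load + release every accessed partition)
all partition released
99.68985271453857
hit rate = n/a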

percentage = 0
71.07469701766968
hit rate = 0.19090909090909092

percentage = 0.2
54.940876960754395
hit rate = 0.3181818181818182

percentage = 0.4
36.95194172859192
hit rate = 0.5

percentage = 0.6
27.089735746383667
hit rate = 0.6

percentage = 0.8
11.296098709106445
hit rate = 0.7727272727272727

In [75]:
# Re-run the warm-up from a clean state. Releasing every partition on the
# server invalidates the existing cache's bookkeeping: put() would count the
# released partitions as hits and never reload them. So start a fresh cache.
release_all_partition()
cache = LRUCache(CACHE_CAPACITY)
for i in tqdm(range(10)):
    cache.put(partition_name=f"partition_{i}")

all partition released


100%|██████████| 10/10 [00:09<00:00,  1.06it/s]


In [76]:
# Access partitions 0..9 again through the cache.
for idx in tqdm(range(10)):
    name = f"partition_{idx}"
    cache.put(partition_name=name)

100%|██████████| 10/10 [00:10<00:00,  1.01s/it]
