diff --git a/benchmarks/_plotter_helper.py b/benchmarks/_plotter_helper.py index 4030085..4555728 100644 --- a/benchmarks/_plotter_helper.py +++ b/benchmarks/_plotter_helper.py @@ -19,7 +19,7 @@ def convert_to_dataframe_from_benchmark(benchmark: "Benchmark") -> tuple: "tn_list": benchmark.tn_list, "fn_list": benchmark.fn_list, "latency_direct_list": benchmark.latency_direct_list, - "latency_vectorq_list": benchmark.latency_vcach_list, + "latency_vectorq_list": benchmark.latency_vcache_list, } df = pd.DataFrame(data) diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py index a9c920f..b39de34 100644 --- a/benchmarks/benchmark.py +++ b/benchmarks/benchmark.py @@ -33,6 +33,7 @@ HNSWLibVectorDB, SimilarityMetricType, ) +from vcache.vcache_core.similarity_evaluator import SimilarityEvaluator from vcache.vcache_core.similarity_evaluator.strategies.llm_comparison import ( LLMComparisonSimilarityEvaluator, ) @@ -101,7 +102,7 @@ class Dataset(Enum): ECOMMERCE_DATASET = "ecommerce_dataset" -class GenerateResultsOnly(Enum): +class GeneratePlotsOnly(Enum): YES = True NO = False @@ -112,62 +113,37 @@ class GenerateResultsOnly(Enum): MAX_SAMPLES: int = 60000 -CONFIDENCE_INTERVALS_ITERATIONS: int = 5 -IS_LLM_JUDGE_BENCHMARK: bool = False -DISABLE_PROGRESS_BAR: bool = True +CONFIDENCE_INTERVALS_ITERATIONS: int = 2 +DISABLE_PROGRESS_BAR: bool = False KEEP_SPLIT: int = 100 RUN_COMBINATIONS: List[ - Tuple[EmbeddingModel, LargeLanguageModel, Dataset, GenerateResultsOnly] + Tuple[EmbeddingModel, LargeLanguageModel, Dataset, GeneratePlotsOnly, SimilarityEvaluator] ] = [ ( EmbeddingModel.GTE, LargeLanguageModel.LLAMA_3_8B, - Dataset.SEM_BENCHMARK_SEARCH_QUERIES, - GenerateResultsOnly.YES, + Dataset.SEM_BENCHMARK_CLASSIFICATION, + GeneratePlotsOnly.NO, + StringComparisonSimilarityEvaluator(), ), ( EmbeddingModel.GTE, LargeLanguageModel.GPT_4O_MINI, Dataset.SEM_BENCHMARK_ARENA, - GenerateResultsOnly.YES, - ), - ( - EmbeddingModel.E5_LARGE_V2, - LargeLanguageModel.GPT_4O_MINI, - Dataset.SEM_BENCHMARK_ARENA, - GenerateResultsOnly.YES, - ), - ( - EmbeddingModel.E5_LARGE_V2, - LargeLanguageModel.LLAMA_3_8B, - Dataset.SEM_BENCHMARK_CLASSIFICATION, - GenerateResultsOnly.YES, - ), - ( - EmbeddingModel.GTE, - LargeLanguageModel.LLAMA_3_8B, - Dataset.SEM_BENCHMARK_CLASSIFICATION, - GenerateResultsOnly.YES, - ), - ( - EmbeddingModel.GTE, - LargeLanguageModel.LLAMA_3_70B, - Dataset.SEM_BENCHMARK_CLASSIFICATION, - GenerateResultsOnly.YES, + GeneratePlotsOnly.NO, + LLMComparisonSimilarityEvaluator(), ), ] BASELINES_TO_RUN: List[Baseline] = [ # Baseline.IID, # Baseline.GPTCache, - # Baseline.VCacheLocal, + Baseline.VCacheLocal, # Baseline.BerkeleyEmbedding, # Baseline.VCacheBerkeleyEmbedding, ] -DATASETS_TO_RUN: List[str] = [Dataset.SEM_BENCHMARK_SEARCH_QUERIES] - STATIC_THRESHOLDS: List[float] = [ 0.80, 0.81, @@ -220,7 +196,7 @@ def stats_set_up(self): self.tn_list: List[int] = [] self.fn_list: List[int] = [] self.latency_direct_list: List[float] = [] - self.latency_vcach_list: List[float] = [] + self.latency_vcache_list: List[float] = [] self.observations_dict: Dict[str, Dict[str, float]] = {} self.gammas_dict: Dict[str, float] = {} self.t_hats_dict: Dict[str, float] = {} @@ -465,7 +441,7 @@ def dump_results_to_json(self): "tn_list": self.tn_list, "fn_list": self.fn_list, "latency_direct_list": self.latency_direct_list, - "latency_vectorq_list": self.latency_vcach_list, + "latency_vectorq_list": self.latency_vcache_list, "observations_dict": self.observations_dict, "gammas_dict": self.gammas_dict, "t_hats_dict": self.t_hats_dict, @@
-498,12 +474,8 @@ def __run_baseline( timestamp: str, delta: float, threshold: float, + similarity_evaluator: SimilarityEvaluator, ): - if IS_LLM_JUDGE_BENCHMARK: - similarity_evaluator = LLMComparisonSimilarityEvaluator() - else: - similarity_evaluator = StringComparisonSimilarityEvaluator() - vcache_config: VCacheConfig = VCacheConfig( inference_engine=BenchmarkInferenceEngine(), embedding_engine=BenchmarkEmbeddingEngine(), @@ -547,8 +519,15 @@ def main(): timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M") - for embedding_model, llm_model, dataset, generate_results_only in RUN_COMBINATIONS: + for ( + embedding_model, + llm_model, + dataset, + generate_plots_only, + similarity_evaluator, + ) in RUN_COMBINATIONS: try: + print(f"DatasetPath: {datasets_dir}, Dataset: {dataset.value}") dataset_file = os.path.join(datasets_dir, f"{dataset.value}.json") logging.info( f"Running benchmark for dataset: {dataset}, embedding model: {embedding_model.value[1]}, LLM model: {llm_model.value[1]}\n" @@ -557,7 +536,10 @@ def main(): ##################################################### ### Baseline: vCache Local - if Baseline.VCacheLocal in BASELINES_TO_RUN and not generate_results_only: + if ( + Baseline.VCacheLocal in BASELINES_TO_RUN + and not generate_plots_only.value + ): for delta in DELTAS: for i in range(0, CONFIDENCE_INTERVALS_ITERATIONS): path = os.path.join( @@ -583,11 +565,15 @@ def main(): timestamp=timestamp, delta=delta, threshold=-1, + similarity_evaluator=similarity_evaluator, ) ##################################################### ### Baseline: vCache Global - if Baseline.VCacheGlobal in BASELINES_TO_RUN and not generate_results_only: + if ( + Baseline.VCacheGlobal in BASELINES_TO_RUN + and not generate_plots_only.value + ): for delta in DELTAS: path = os.path.join( results_dir, @@ -612,13 +598,14 @@ def main(): timestamp=timestamp, delta=delta, threshold=-1, + similarity_evaluator=similarity_evaluator, ) ##################################################### ### Baseline: Berkeley Embedding if ( Baseline.BerkeleyEmbedding in BASELINES_TO_RUN - and not generate_results_only + and not generate_plots_only.value ): for threshold in STATIC_THRESHOLDS: if embedding_model == EmbeddingModel.E5_MISTRAL_7B: @@ -658,13 +645,14 @@ def main(): timestamp=timestamp, delta=-1, threshold=threshold, + similarity_evaluator=similarity_evaluator, ) ##################################################### ### Baseline: vCache + Berkeley Embedding if ( Baseline.VCacheBerkeleyEmbedding in BASELINES_TO_RUN - and not generate_results_only + and not generate_plots_only.value ): for delta in DELTAS: for i in range(0, CONFIDENCE_INTERVALS_ITERATIONS): @@ -707,11 +695,12 @@ def main(): timestamp=timestamp, delta=delta, threshold=-1, + similarity_evaluator=similarity_evaluator, ) ##################################################### ### Baseline: IID Local - if Baseline.IID in BASELINES_TO_RUN and not generate_results_only: + if Baseline.IID in BASELINES_TO_RUN and not generate_plots_only.value: for delta in DELTAS: for i in range(0, CONFIDENCE_INTERVALS_ITERATIONS): path = os.path.join( @@ -737,11 +726,12 @@ def main(): timestamp=timestamp, delta=delta, threshold=-1, + similarity_evaluator=similarity_evaluator, ) ##################################################### ### Baseline: GPTCache - if Baseline.GPTCache in BASELINES_TO_RUN and not generate_results_only: + if Baseline.GPTCache in BASELINES_TO_RUN and not generate_plots_only.value: for threshold in STATIC_THRESHOLDS: path = os.path.join( results_dir, @@ -764,6 
+754,7 @@ def main(): timestamp=timestamp, delta=-1, threshold=threshold, + similarity_evaluator=similarity_evaluator, ) ##################################################### diff --git a/tests/ReadMe.md b/tests/ReadMe.md index 7ab2408..6d4cb66 100644 --- a/tests/ReadMe.md +++ b/tests/ReadMe.md @@ -25,7 +25,11 @@ vCache includes both **unit tests** and **integration tests** to ensure correctn Unit tests verify the **logic of individual module strategies** (e.g., caching policies, embedding engines, similarity evaluators) in isolation. They are designed to be fast, deterministic, and independent of external services. +#### Running Unit Tests +```bash +python -m pytest tests/unit/ +``` ### Integration Tests diff --git a/tests/integration/test_concurrency.py b/tests/integration/test_concurrency.py new file mode 100644 index 0000000..576d75b --- /dev/null +++ b/tests/integration/test_concurrency.py @@ -0,0 +1,114 @@ +import random +import time +import unittest +from concurrent.futures import ThreadPoolExecutor +from unittest.mock import MagicMock, patch + +from dotenv import load_dotenv + +from vcache import ( + DynamicLocalThresholdPolicy, + HNSWLibVectorDB, + InMemoryEmbeddingMetadataStorage, + LangChainEmbeddingEngine, + StringComparisonSimilarityEvaluator, + VCache, + VCacheConfig, +) +from vcache.vcache_policy.strategies.dynamic_local_threshold import _Action + +load_dotenv() + + +class TestConcurrency(unittest.TestCase): + def test_async_label_generation_and_timeout(self): + similarity_evaluator = StringComparisonSimilarityEvaluator() + + mock_answers_similar = MagicMock() + + def answers_similar(a, b): + if "Return 'xxxxxxxxx' as the answer" in a: + time.sleep(10) + print(f"Answers Similar (Execution time: 10s) => a: {a}, b: {b}\n") + return True + else: + execution_time = random.uniform(0.5, 3) + time.sleep(execution_time) + print( + f"Answers Similar (Execution time: {execution_time}s) => a: {a}, b: {b}\n" + ) + return True + + mock_answers_similar.side_effect = answers_similar + + config = VCacheConfig( + embedding_engine=LangChainEmbeddingEngine( + model_name="sentence-transformers/all-mpnet-base-v2" + ), + vector_db=HNSWLibVectorDB(), + embedding_metadata_storage=InMemoryEmbeddingMetadataStorage(), + similarity_evaluator=similarity_evaluator, + ) + + with DynamicLocalThresholdPolicy(delta=0.05) as policy: + vcache: VCache = VCache(config, policy) + vcache.vcache_policy.setup(config) + + with ( + patch.object( + policy.similarity_evaluator, + "answers_similar", + new=mock_answers_similar, + ), + patch.object( + policy.bayesian, "select_action", return_value=_Action.EXPLORE + ), + ): + initial_prompt = "What is the capital of Germany?" 
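+                # The first inference seeds the cache with a single embedding; the concurrent prompts below are similar enough to map to this one entry.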
+ vcache.infer(prompt=initial_prompt) + + concurrent_prompts_chunk_1 = [ + "What is the capital of Germany?", + "Germany's capital?", + "Capital of Germany is...", + "Return 'xxxxxxxxx' as the answer", # This is the slow prompt + "Berlin is the capital of what country?", + ] + concurrent_prompts_chunk_2 = [ + "Which city is the seat of the German government?", + "What is Germany's primary city?", + "Tell me about Berlin.", + "Is Frankfurt the capital of Germany?", + "What's the main city of Germany?", + "Where is the German government located?", + ] + + def do_inference(prompt): + prompt_index = total_prompts.index(prompt) + print(f"Inferring prompt {prompt_index}: {prompt}\n") + vcache.infer(prompt=prompt) + + total_prompts = concurrent_prompts_chunk_1 + concurrent_prompts_chunk_2 + with ThreadPoolExecutor(max_workers=len(total_prompts)) as executor: + executor.map(do_inference, concurrent_prompts_chunk_1) + time.sleep(1.5) + executor.map(do_inference, concurrent_prompts_chunk_2) + + all_metadata_objects = vcache.vcache_config.embedding_metadata_storage.get_all_embedding_metadata_objects() + final_observation_count = len(all_metadata_objects) + + for i, metadata_object in enumerate(all_metadata_objects): + print(f"metadata_object {i}: {metadata_object}") + + print(f"\nfinal_observation_count: {final_observation_count}") + + assert final_observation_count == 1, ( + f"Expected 1 metadata object, got {final_observation_count}" + ) + # We expect the 'slow prompt' to be the only prompt missing from the observations + assert len(all_metadata_objects[0].observations) == 12, ( + f"Expected 12 observations (10 + 2 initial labels), got {len(all_metadata_objects[0].observations)}" + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/Benchmark/test.py b/tests/unit/Benchmark/test_benchmark.py similarity index 100% rename from tests/unit/Benchmark/test.py rename to tests/unit/Benchmark/test_benchmark.py diff --git a/tests/unit/EmbeddingEngineStrategy/test.py b/tests/unit/EmbeddingEngineStrategy/test_embedding_engine_strategy.py similarity index 100% rename from tests/unit/EmbeddingEngineStrategy/test.py rename to tests/unit/EmbeddingEngineStrategy/test_embedding_engine_strategy.py diff --git a/tests/unit/EmbeddingMetadataStrategy/test.py b/tests/unit/EmbeddingMetadataStrategy/test.py deleted file mode 100644 index e266159..0000000 --- a/tests/unit/EmbeddingMetadataStrategy/test.py +++ /dev/null @@ -1,32 +0,0 @@ -import unittest - -from vcache.vcache_core.cache.embedding_store.embedding_metadata_storage import ( - InMemoryEmbeddingMetadataStorage, -) -from vcache.vcache_core.cache.embedding_store.embedding_metadata_storage.embedding_metadata_obj import ( - EmbeddingMetadataObj, -) - - -class TestEmbeddingMetadataStorageStrategy(unittest.TestCase): - def test_in_memory_strategy(self): - embedding_metadata_storage = InMemoryEmbeddingMetadataStorage() - - initial_obj = EmbeddingMetadataObj(embedding_id=0, response="test") - embedding_id = embedding_metadata_storage.add_metadata( - embedding_id=0, metadata=initial_obj - ) - assert embedding_id == 0 - assert embedding_metadata_storage.get_metadata(embedding_id=0) == initial_obj - - updated_obj = EmbeddingMetadataObj(embedding_id=0, response="test2") - embedding_metadata_storage.update_metadata(embedding_id=0, metadata=updated_obj) - assert embedding_metadata_storage.get_metadata(embedding_id=0) == updated_obj - - embedding_metadata_storage.flush() - with self.assertRaises(ValueError): -
embedding_metadata_storage.get_metadata(embedding_id=0) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/unit/EmbeddingMetadataStrategy/test_embedding_metadata_strategy.py b/tests/unit/EmbeddingMetadataStrategy/test_embedding_metadata_strategy.py new file mode 100644 index 0000000..f134b46 --- /dev/null +++ b/tests/unit/EmbeddingMetadataStrategy/test_embedding_metadata_strategy.py @@ -0,0 +1,62 @@ +import unittest +from concurrent.futures import ThreadPoolExecutor + +from vcache.vcache_core.cache.embedding_store.embedding_metadata_storage import ( + InMemoryEmbeddingMetadataStorage, +) +from vcache.vcache_core.cache.embedding_store.embedding_metadata_storage.embedding_metadata_obj import ( + EmbeddingMetadataObj, +) + + +class TestEmbeddingMetadataStorageStrategy(unittest.TestCase): + def test_in_memory_strategy(self): + embedding_metadata_storage = InMemoryEmbeddingMetadataStorage() + + initial_obj = EmbeddingMetadataObj(embedding_id=0, response="test") + embedding_id = embedding_metadata_storage.add_metadata( + embedding_id=0, metadata=initial_obj + ) + assert embedding_id == 0 + assert embedding_metadata_storage.get_metadata(embedding_id=0) == initial_obj + + updated_obj = EmbeddingMetadataObj(embedding_id=0, response="test2") + embedding_metadata_storage.update_metadata(embedding_id=0, metadata=updated_obj) + assert embedding_metadata_storage.get_metadata(embedding_id=0) == updated_obj + + embedding_metadata_storage.flush() + with self.assertRaises(ValueError): + embedding_metadata_storage.get_metadata(embedding_id=0) + + +class TestInMemoryEmbeddingMetadataStorageThreadSafety(unittest.TestCase): + def test_concurrent_add_observation(self): + storage = InMemoryEmbeddingMetadataStorage() + embedding_id = 0 + initial_obj = EmbeddingMetadataObj(embedding_id=embedding_id, response="test") + storage.add_metadata(embedding_id=embedding_id, metadata=initial_obj) + + num_threads = 10 + num_observations_per_thread = 100 + total_observations = num_threads * num_observations_per_thread + + def add_observations_task(storage, embedding_id): + for i in range(num_observations_per_thread): + observation = (float(i), 1) + storage.add_observation(embedding_id, observation) + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [ + executor.submit(add_observations_task, storage, embedding_id) + for _ in range(num_threads) + ] + for future in futures: + future.result() + + metadata = storage.get_metadata(embedding_id) + # The initial object has 2 observations, so we add the total observations to that. 
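+        # With 10 threads appending 100 observations each, any update lost to a race would leave the final count short of 1000 + 2.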
+ self.assertEqual(len(metadata.observations), total_observations + 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/EvictionPolicyStrategy/test.py b/tests/unit/EvictionPolicyStrategy/test_eviction_policy_strategy.py similarity index 100% rename from tests/unit/EvictionPolicyStrategy/test.py rename to tests/unit/EvictionPolicyStrategy/test_eviction_policy_strategy.py diff --git a/tests/unit/InferenceEngineStrategy/test.py b/tests/unit/InferenceEngineStrategy/test_inference_engine_strategy.py similarity index 100% rename from tests/unit/InferenceEngineStrategy/test.py rename to tests/unit/InferenceEngineStrategy/test_inference_engine_strategy.py diff --git a/tests/unit/SimilarityEvalutatorStrategy/test.py b/tests/unit/SimilarityEvalutatorStrategy/test_similarity_evaluator_strategy.py similarity index 100% rename from tests/unit/SimilarityEvalutatorStrategy/test.py rename to tests/unit/SimilarityEvalutatorStrategy/test_similarity_evaluator_strategy.py diff --git a/tests/unit/VectorDB/test_thread_safety.py b/tests/unit/VectorDB/test_thread_safety.py new file mode 100644 index 0000000..41b6531 --- /dev/null +++ b/tests/unit/VectorDB/test_thread_safety.py @@ -0,0 +1,189 @@ +import threading +import time +import unittest +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import List, Set + +from vcache.vcache_core.cache.embedding_store.vector_db.strategies.chroma import ( + ChromaVectorDB, +) +from vcache.vcache_core.cache.embedding_store.vector_db.strategies.faiss import ( + FAISSVectorDB, +) +from vcache.vcache_core.cache.embedding_store.vector_db.strategies.hnsw_lib import ( + HNSWLibVectorDB, +) +from vcache.vcache_core.cache.embedding_store.vector_db.vector_db import ( + SimilarityMetricType, +) + + +class TestVectorDBThreadSafety(unittest.TestCase): + """Test thread safety of vector database implementations.""" + + def setUp(self): + """Set up test fixtures.""" + self.embedding_dim = 128 + self.num_threads = 10 + self.embeddings_per_thread = 20 + + def _generate_random_embedding(self, seed: int) -> List[float]: + """Generate a deterministic random embedding based on seed.""" + import random + + random.seed(seed) + return [random.random() for _ in range(self.embedding_dim)] + + def _add_embeddings_worker( + self, vector_db, start_seed: int, count: int + ) -> Set[int]: + """Worker function to add embeddings in a thread.""" + added_ids = set() + for i in range(count): + embedding = self._generate_random_embedding(start_seed + i) + embedding_id = vector_db.add(embedding) + added_ids.add(embedding_id) + return added_ids + + def test_hnswlib_thread_safety(self): + """Test HNSWLibVectorDB thread safety.""" + vector_db = HNSWLibVectorDB(similarity_metric_type=SimilarityMetricType.COSINE) + self._test_vector_db_thread_safety(vector_db) + + def test_chroma_thread_safety(self): + """Test ChromaVectorDB thread safety.""" + vector_db = ChromaVectorDB(similarity_metric_type=SimilarityMetricType.COSINE) + self._test_vector_db_thread_safety(vector_db) + + def test_faiss_thread_safety(self): + """Test FAISSVectorDB thread safety.""" + vector_db = FAISSVectorDB(similarity_metric_type=SimilarityMetricType.COSINE) + self._test_vector_db_thread_safety(vector_db) + + def _test_vector_db_thread_safety(self, vector_db): + """Generic test for vector database thread safety.""" + all_ids = set() + + # Use ThreadPoolExecutor to simulate concurrent access + with ThreadPoolExecutor(max_workers=self.num_threads) as executor: + # Submit tasks to add embeddings 
concurrently + futures = [] + for i in range(self.num_threads): + start_seed = i * self.embeddings_per_thread + future = executor.submit( + self._add_embeddings_worker, + vector_db, + start_seed, + self.embeddings_per_thread, + ) + futures.append(future) + + # Collect results + for future in as_completed(futures): + thread_ids = future.result() + # Check for ID collisions + intersection = all_ids.intersection(thread_ids) + self.assertEqual( + len(intersection), 0, f"ID collision detected: {intersection}" + ) + all_ids.update(thread_ids) + + # Verify total number of unique IDs + expected_total = self.num_threads * self.embeddings_per_thread + self.assertEqual( + len(all_ids), + expected_total, + f"Expected {expected_total} unique IDs, got {len(all_ids)}", + ) + + # Test that we can query the database + test_embedding = self._generate_random_embedding(999) + results = vector_db.get_knn(test_embedding, k=5) + self.assertLessEqual(len(results), 5, "Should return at most 5 results") + + # Verify that all returned IDs exist in our added IDs + for _, embedding_id in results: + self.assertIn( + embedding_id, + all_ids, + f"Returned ID {embedding_id} was not in added IDs", + ) + + def test_concurrent_add_and_query(self): + """Test concurrent add and query operations.""" + vector_db = HNSWLibVectorDB(similarity_metric_type=SimilarityMetricType.COSINE) + + def add_worker(): + """Worker that adds embeddings.""" + for i in range(10): + embedding = self._generate_random_embedding(i) + vector_db.add(embedding) + time.sleep(0.001) # Small delay to increase chance of race conditions + + def query_worker(): + """Worker that queries embeddings.""" + query_embedding = self._generate_random_embedding(999) + for _ in range(10): + try: + results = vector_db.get_knn(query_embedding, k=3) + # Should not raise exceptions + self.assertIsInstance(results, list) + except Exception as e: + self.fail(f"Query worker failed with exception: {e}") + time.sleep(0.001) + + # Run concurrent add and query operations + threads = [] + for _ in range(3): + threads.append(threading.Thread(target=add_worker)) + threads.append(threading.Thread(target=query_worker)) + + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + # Verify final state + self.assertFalse(vector_db.is_empty()) + + def test_concurrent_reset_operations(self): + """Test concurrent reset operations don't cause issues.""" + vector_db = HNSWLibVectorDB(similarity_metric_type=SimilarityMetricType.COSINE) + + # Add some initial data + for i in range(5): + embedding = self._generate_random_embedding(i) + vector_db.add(embedding) + + def reset_worker(): + """Worker that resets the database.""" + vector_db.reset() + + def add_worker(): + """Worker that adds embeddings.""" + for i in range(5): + embedding = self._generate_random_embedding(100 + i) + try: + vector_db.add(embedding) + except Exception: + # Reset might have happened, which is fine + pass + + # Run concurrent operations + threads = [] + threads.append(threading.Thread(target=reset_worker)) + threads.append(threading.Thread(target=add_worker)) + + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + # Should not crash and database should be in a valid state + self.assertIsInstance(vector_db.is_empty(), bool) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/VectorDBStrategy/test.py b/tests/unit/VectorDBStrategy/test_vector_db_strategy.py similarity index 94% rename from tests/unit/VectorDBStrategy/test.py rename to 
tests/unit/VectorDBStrategy/test_vector_db_strategy.py index 2ca1de9..ff1815d 100644 --- a/tests/unit/VectorDBStrategy/test.py +++ b/tests/unit/VectorDBStrategy/test_vector_db_strategy.py @@ -71,8 +71,11 @@ def test_remove(self, vector_db_class, similarity_metric_type): # Verify only one remains knn = vector_db.get_knn(embedding=[0.1, 0.2, 0.3], k=2) - assert len(knn) == 1 - assert knn[0][1] == id2 + returned_ids = {result[1] for result in knn} + + # Check that the removed ID is gone and the other one remains. + assert id1 not in returned_ids + assert id2 in returned_ids @pytest.mark.parametrize( "vector_db_class, similarity_metric_type", diff --git a/vcache/__init__.py b/vcache/__init__.py index 1cccbbd..2961e9c 100644 --- a/vcache/__init__.py +++ b/vcache/__init__.py @@ -2,6 +2,12 @@ vCache: Reliable and Efficient Semantic Prompt Caching """ +import os + +# Disable Hugging Face tokenizer parallelism to prevent deadlocks when using +# vCache in multi-threaded applications. This is a library-level fix. +os.environ["TOKENIZERS_PARALLELISM"] = "false" + # Main vCache classes from vcache.config import VCacheConfig diff --git a/vcache/inference_engine/inference_engine.py b/vcache/inference_engine/inference_engine.py index 8f89e1c..c19cb52 100644 --- a/vcache/inference_engine/inference_engine.py +++ b/vcache/inference_engine/inference_engine.py @@ -2,17 +2,17 @@ class InferenceEngine(ABC): - """ - Abstract base class for inference engines - """ + """Abstract base class for inference engines.""" @abstractmethod def create(self, prompt: str, system_prompt: str = None) -> str: - """ - Args - prompt: str - The prompt to create an answer for - system_prompt: str - The optional output format to use for the response - Returns - str - The answer to the prompt + """Create an answer for the given prompt. + + Args: + prompt (str): The prompt to create an answer for. + system_prompt (str, optional): The optional system prompt to use for the response. + + Returns: + str: The answer to the prompt. """ pass diff --git a/vcache/vcache_core/cache/cache.py b/vcache/vcache_core/cache/cache.py index 8a64525..7c0501f 100644 --- a/vcache/vcache_core/cache/cache.py +++ b/vcache/vcache_core/cache/cache.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Tuple from vcache.vcache_core.cache.embedding_engine.embedding_engine import EmbeddingEngine from vcache.vcache_core.cache.embedding_store.embedding_metadata_storage.embedding_metadata_obj import ( @@ -9,6 +9,8 @@ class Cache: + """Cache that manages prompt embeddings and responses using a vector database and metadata store.""" + def __init__( self, embedding_store: EmbeddingStore, @@ -20,68 +22,109 @@ def __init__( self.eviction_policy = eviction_policy def add(self, prompt: str, response: str) -> int: - """ - prompt: str - The prompt to add to the cache - response: str - The response to add to the cache - returns: int - The id of the embedding + """Generate an embedding for the prompt and add it to the cache. + + This method obtains an embedding for the prompt via the embedding engine, + stores it in the vector database, and initializes its metadata with the response. + + Args: + prompt (str): The prompt to cache. + response (str): The generated response for the prompt. + + Returns: + int: The unique ID of the new embedding. 
""" embedding = self.embedding_engine.get_embedding(prompt) - self.embedding_store.add_embedding(embedding, response) + return self.embedding_store.add_embedding(embedding, response) def remove(self, embedding_id: int) -> int: + """Remove an embedding and its metadata from the cache. + + Args: + embedding_id (int): The ID of the embedding to remove. + + Returns: + int: The ID of the removed embedding. """ - embedding_id: int - The id of the embedding to remove - returns: int - The id of the embedding - """ - self.embedding_store.remove(embedding_id) + return self.embedding_store.remove(embedding_id) def get_knn(self, prompt: str, k: int) -> List[tuple[float, int]]: - """ - prompt: str - The prompt to get the k-nearest neighbors for - k: int - The number of nearest neighbors to get - returns: List[tuple[float, int]] - A list of tuples, each containing a similarity score and an embedding id + """Retrieve the k closest embeddings for a prompt. + + This method encodes the prompt to a vector and queries the vector database + for its k nearest neighbors. + + Args: + prompt (str): The prompt to query. + k (int): Number of nearest neighbors to return. + + Returns: + List[tuple[float, int]]: List of (similarity_score, embedding_id) tuples. """ embedding = self.embedding_engine.get_embedding(prompt) return self.embedding_store.get_knn(embedding, k) def flush(self) -> None: - """ - Flushes the cache - """ + """Clear all embeddings and metadata from the cache.""" self.embedding_store.reset() def get_metadata(self, embedding_id: int) -> EmbeddingMetadataObj: - """ - embedding_id: int - The id of the embedding to get the metadata for - returns: EmbeddingMetadataObj - The metadata of the embedding + """Get metadata associated with a specific embedding. + + Args: + embedding_id (int): The ID of the embedding. + + Returns: + EmbeddingMetadataObj: The metadata for the embedding. """ return self.embedding_store.get_metadata(embedding_id) def update_metadata( self, embedding_id: int, embedding_metadata: EmbeddingMetadataObj ) -> EmbeddingMetadataObj: + """Update metadata for an existing embedding. + + Args: + embedding_id (int): The ID of the embedding. + embedding_metadata (EmbeddingMetadataObj): The new metadata object. + + Returns: + EmbeddingMetadataObj: The updated metadata object. """ - embedding_id: int - The id of the embedding to update - embedding_metadata: EmbeddingMetadataObj - The metadata to update the embedding with - returns: EmbeddingMetadataObj - The updated metadata of the embedding + return self.embedding_store.update_metadata(embedding_id, embedding_metadata) + + def add_observation( + self, embedding_id: int, observation: Tuple[float, int] + ) -> None: + """Atomically add an observation to an embedding's metadata. + + Args: + embedding_id (int): The ID of the embedding. + observation (Tuple[float, int]): A tuple (similarity_score, label). """ - self.embedding_store.update_metadata(embedding_id, embedding_metadata) + self.embedding_store.add_observation(embedding_id, observation) def get_current_capacity(self) -> int: - """ - returns: int - The current capacity of the cache + """Return the current capacity of the cache. + + Returns: + int: The number of embeddings currently stored. """ # TODO return None def is_empty(self) -> bool: - """ - returns: bool - Whether the cache is empty + """Check if the cache has no embeddings. + + Returns: + bool: True if empty, False otherwise. 
""" return self.embedding_store.is_empty() def get_all_embedding_metadata_objects(self) -> List[EmbeddingMetadataObj]: - """ - returns: List["EmbeddingMetadataObj"] - A list of all the embedding metadata objects in the cache + """Retrieve all embedding metadata objects. + + Returns: + List[EmbeddingMetadataObj]: All metadata objects in the cache. """ return self.embedding_store.embedding_metadata_storage.get_all_embedding_metadata_objects() diff --git a/vcache/vcache_core/cache/embedding_store/embedding_metadata_storage/embedding_metadata_storage.py b/vcache/vcache_core/cache/embedding_store/embedding_metadata_storage/embedding_metadata_storage.py index c598873..e8fd249 100644 --- a/vcache/vcache_core/cache/embedding_store/embedding_metadata_storage/embedding_metadata_storage.py +++ b/vcache/vcache_core/cache/embedding_store/embedding_metadata_storage/embedding_metadata_storage.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List +from typing import List, Tuple from vcache.vcache_core.cache.embedding_store.embedding_metadata_storage.embedding_metadata_obj import ( EmbeddingMetadataObj, @@ -7,20 +7,30 @@ class EmbeddingMetadataStorage(ABC): + """Abstract base class for embedding metadata storage.""" + @abstractmethod def add_metadata(self, embedding_id: int, metadata: EmbeddingMetadataObj) -> int: - """ - embedding_id: int - The id of the embedding to add the metadata for - metadata: EmbeddingMetadataObj - The metadata to add to the embedding - returns: int - The id of the embedding + """Add metadata entry for an embedding. + + Args: + embedding_id (int): The ID of the embedding. + metadata (EmbeddingMetadataObj): The metadata to add. + + Returns: + int: The ID of the embedding. """ pass @abstractmethod def get_metadata(self, embedding_id: int) -> EmbeddingMetadataObj: - """ - embedding_id: int - The id of the embedding to get the metadata for - returns: EmbeddingMetadataObj - The metadata of the embedding + """Retrieve metadata for an embedding. + + Args: + embedding_id (int): The ID of the embedding. + + Returns: + EmbeddingMetadataObj: The metadata for the embedding. """ pass @@ -28,30 +38,48 @@ def get_metadata(self, embedding_id: int) -> EmbeddingMetadataObj: def update_metadata( self, embedding_id: int, metadata: EmbeddingMetadataObj ) -> EmbeddingMetadataObj: + """Update metadata for an existing embedding. + + Args: + embedding_id (int): The ID of the embedding. + metadata (EmbeddingMetadataObj): The new metadata object. + + Returns: + EmbeddingMetadataObj: The updated metadata object. """ - embedding_id: int - The id of the embedding to update the metadata for - metadata: EmbeddingMetadataObj - The metadata to update the embedding with - returns: EmbeddingMetadataObj - The updated metadata of the embedding + pass + + @abstractmethod + def add_observation( + self, embedding_id: int, observation: Tuple[float, int] + ) -> None: + """Atomically add an observation to an embedding's metadata. + + Args: + embedding_id (int): The ID of the embedding. + observation (Tuple[float, int]): A tuple (similarity_score, label). """ pass @abstractmethod def remove_metadata(self, embedding_id: int) -> None: - """ - embedding_id: int - The id of the embedding to remove the metadata for + """Remove metadata for an embedding. + + Args: + embedding_id (int): The ID of the embedding. 
""" pass @abstractmethod def flush(self) -> None: - """ - Flushes the metadata storage - """ + """Clear all metadata from storage.""" pass @abstractmethod def get_all_embedding_metadata_objects(self) -> List[EmbeddingMetadataObj]: - """ - returns: List["EmbeddingMetadataObj"] - A list of all the embedding metadata objects in the storage + """Retrieve all metadata objects in storage. + + Returns: + List[EmbeddingMetadataObj]: All metadata objects in storage. """ pass diff --git a/vcache/vcache_core/cache/embedding_store/embedding_metadata_storage/strategies/in_memory.py b/vcache/vcache_core/cache/embedding_store/embedding_metadata_storage/strategies/in_memory.py index d36326c..4721380 100644 --- a/vcache/vcache_core/cache/embedding_store/embedding_metadata_storage/strategies/in_memory.py +++ b/vcache/vcache_core/cache/embedding_store/embedding_metadata_storage/strategies/in_memory.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, List, Optional +import threading +from typing import Dict, List, Tuple from vcache.vcache_core.cache.embedding_store.embedding_metadata_storage.embedding_metadata_obj import ( EmbeddingMetadataObj, @@ -9,42 +10,137 @@ class InMemoryEmbeddingMetadataStorage(EmbeddingMetadataStorage): + """In-memory implementation of embedding metadata storage. + + This class stores embedding metadata in a dictionary and uses locks to ensure + thread-safe access. + + Attributes: + metadata_storage (Dict[int, EmbeddingMetadataObj]): The dictionary storing metadata. + """ + def __init__(self): - self.metadata_storage: Dict[int, "EmbeddingMetadataObj"] = {} + self.metadata_storage: Dict[int, EmbeddingMetadataObj] = {} + self._store_lock = threading.RLock() + self._entry_locks: Dict[int, threading.Lock] = {} - def add_metadata( - self, embedding_id: int, metadata: Optional[Dict[str, Any]] = None - ) -> None: - self.metadata_storage[embedding_id] = metadata - return embedding_id - - def get_metadata(self, embedding_id: int) -> Optional[Dict[str, Any]]: - if embedding_id not in self.metadata_storage: - raise ValueError( - f"Embedding metadata for embedding id {embedding_id} not found" - ) - else: + def _get_entry_lock(self, embedding_id: int) -> threading.Lock: + """Get a lock for a specific embedding ID, creating it if needed. + + Args: + embedding_id (int): The ID of the embedding. + + Returns: + threading.Lock: The lock for the given embedding ID. + """ + with self._store_lock: + if embedding_id not in self._entry_locks: + self._entry_locks[embedding_id] = threading.Lock() + return self._entry_locks[embedding_id] + + def add_metadata(self, embedding_id: int, metadata: EmbeddingMetadataObj) -> int: + """Add metadata for an embedding to the in-memory store. + + Args: + embedding_id (int): The ID of the embedding. + metadata (EmbeddingMetadataObj): The metadata object to associate with the ID. + + Returns: + int: The ID of the embedding. + """ + with self._store_lock: + self.metadata_storage[embedding_id] = metadata + return embedding_id + + def get_metadata(self, embedding_id: int) -> EmbeddingMetadataObj: + """Retrieve metadata for an embedding from the in-memory store. + + Args: + embedding_id (int): The ID of the embedding to retrieve. + + Returns: + EmbeddingMetadataObj: The metadata object for the given ID. + + Raises: + ValueError: If no metadata is found for the embedding ID. 
+ """ + with self._store_lock: + if embedding_id not in self.metadata_storage: + raise ValueError( + f"Embedding metadata for embedding id {embedding_id} not found" + ) return self.metadata_storage[embedding_id] def update_metadata( - self, embedding_id: int, metadata: Optional[Dict[str, Any]] = None - ) -> bool: - if embedding_id not in self.metadata_storage: - raise ValueError( - f"Embedding metadata for embedding id {embedding_id} not found" - ) - else: + self, embedding_id: int, metadata: EmbeddingMetadataObj + ) -> EmbeddingMetadataObj: + """Update metadata for an existing embedding in the in-memory store. + + Args: + embedding_id (int): The ID of the embedding to update. + metadata (EmbeddingMetadataObj): The new metadata object. + + Returns: + EmbeddingMetadataObj: The updated metadata object. + + Raises: + ValueError: If no metadata is found for the embedding ID. + """ + with self._store_lock: + if embedding_id not in self.metadata_storage: + raise ValueError( + f"Embedding metadata for embedding id {embedding_id} not found" + ) self.metadata_storage[embedding_id] = metadata return metadata + def add_observation( + self, embedding_id: int, observation: Tuple[float, int] + ) -> None: + """Atomically add an observation to an embedding's metadata. + + This method ensures that appending an observation to the list of + observations is a thread-safe operation. + + Args: + embedding_id (int): The ID of the embedding to update. + observation (Tuple[float, int]): The observation tuple (similarity, label). + """ + entry_lock = self._get_entry_lock(embedding_id) + with entry_lock: + metadata = self.get_metadata(embedding_id) + metadata.observations.append(observation) + self.update_metadata(embedding_id, metadata) + def remove_metadata(self, embedding_id: int) -> bool: - if embedding_id in self.metadata_storage: - del self.metadata_storage[embedding_id] - return True + """Remove metadata for a specific embedding ID. + + Args: + embedding_id (int): The ID of the embedding metadata to remove. + + Returns: + bool: True if the metadata was found and removed, False otherwise. + """ + with self._store_lock: + if embedding_id in self.metadata_storage: + del self.metadata_storage[embedding_id] + # Also remove the associated lock + if embedding_id in self._entry_locks: + del self._entry_locks[embedding_id] + return True return False def flush(self) -> None: - self.metadata_storage = {} + """Clear all metadata from the in-memory store.""" + with self._store_lock: + self.metadata_storage = {} + self._entry_locks = {} def get_all_embedding_metadata_objects(self) -> List[EmbeddingMetadataObj]: - return list(self.metadata_storage.values()) + """Retrieve all metadata objects from the in-memory store. + + Returns: + List[EmbeddingMetadataObj]: A list of all metadata objects. 
+ """ + with self._store_lock: + return list(self.metadata_storage.values()) diff --git a/vcache/vcache_core/cache/embedding_store/embedding_store.py b/vcache/vcache_core/cache/embedding_store/embedding_store.py index bc1ca04..2eecd8c 100644 --- a/vcache/vcache_core/cache/embedding_store/embedding_store.py +++ b/vcache/vcache_core/cache/embedding_store/embedding_store.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Tuple from vcache.vcache_core.cache.embedding_store.embedding_metadata_storage import ( EmbeddingMetadataStorage, @@ -52,5 +52,10 @@ def update_metadata( ) -> "EmbeddingMetadataObj": return self.embedding_metadata_storage.update_metadata(embedding_id, metadata) + def add_observation( + self, embedding_id: int, observation: Tuple[float, int] + ) -> None: + self.embedding_metadata_storage.add_observation(embedding_id, observation) + def is_empty(self) -> bool: return self.vector_db.is_empty() diff --git a/vcache/vcache_core/cache/embedding_store/vector_db/strategies/chroma.py b/vcache/vcache_core/cache/embedding_store/vector_db/strategies/chroma.py index 1649833..fc78dbf 100644 --- a/vcache/vcache_core/cache/embedding_store/vector_db/strategies/chroma.py +++ b/vcache/vcache_core/cache/embedding_store/vector_db/strategies/chroma.py @@ -1,3 +1,4 @@ +import threading from typing import List import chromadb @@ -9,6 +10,17 @@ class ChromaVectorDB(VectorDB): + """A vector database implementation using ChromaDB. + + This class provides a thread-safe vector database that stores embeddings and + performs k-nearest neighbor searches using the ChromaDB library. + + Attributes: + collection (chromadb.Collection): The ChromaDB collection object. + client (chromadb.Client): The ChromaDB client instance. + similarity_metric_type (SimilarityMetricType): The metric for measuring similarity. + """ + def __init__( self, similarity_metric_type: SimilarityMetricType = SimilarityMetricType.COSINE ): @@ -16,48 +28,100 @@ def __init__( self.collection = None self.client = None self.similarity_metric_type = similarity_metric_type + self._operation_lock = threading.RLock() def add(self, embedding: List[float]) -> int: - if self.collection is None: - self._init_vector_store(len(embedding)) - id = self.__next_embedding_id - self.collection.add(embeddings=[embedding], ids=[str(id)]) - self.__next_embedding_id += 1 - return id + """Add an embedding to the database, initializing the collection if needed. + + This method is thread-safe. + + Args: + embedding (List[float]): The embedding vector to add. + + Returns: + int: The unique ID assigned to the added embedding. + """ + with self._operation_lock: + if self.collection is None: + self._init_vector_store(len(embedding)) + + # Atomic ID generation and assignment + embedding_id = self.__next_embedding_id + self.collection.add(embeddings=[embedding], ids=[str(embedding_id)]) + self.__next_embedding_id += 1 + + return embedding_id def remove(self, embedding_id: int) -> int: - if self.collection is None: - raise ValueError("Collection is not initialized") - self.collection.delete(ids=[str(embedding_id)]) - return embedding_id + """Remove an embedding from the database by its ID. + + This method is thread-safe. + + Args: + embedding_id (int): The ID of the embedding to remove. + + Returns: + int: The ID of the removed embedding. 
+ """ + with self._operation_lock: + if self.collection is None: + raise ValueError("Collection is not initialized") + self.collection.delete(ids=[str(embedding_id)]) + return embedding_id def get_knn(self, embedding: List[float], k: int) -> List[tuple[float, int]]: - if self.collection is None: - raise ValueError("Collection is not initialized") - if self.collection.count() == 0: - return [] - k_ = min(k, self.collection.count()) - results = self.collection.query( - query_embeddings=[embedding], n_results=k_, include=["distances"] - ) - distances = results.get("distances", [[]])[0] - ids = results.get("ids", [[]])[0] - return [ - ( - self.transform_similarity_score( - float(dist), self.similarity_metric_type.value - ), - int(idx), + """Find the k-nearest neighbors for a given embedding. + + This method is thread-safe. If `k` is larger than the number of items, + it returns the maximum number of neighbors possible. + + Args: + embedding (List[float]): The query embedding. + k (int): The number of nearest neighbors to return. + + Returns: + List[tuple[float, int]]: List of (similarity_score, embedding_id) tuples. + """ + with self._operation_lock: + if self.collection is None: + # Initialize the store with the dimension of the query embedding + self._init_vector_store(len(embedding)) + if self.collection.count() == 0: + return [] + k_ = min(k, self.collection.count()) + results = self.collection.query( + query_embeddings=[embedding], n_results=k_, include=["distances"] ) - for dist, idx in zip(distances, ids) - ] + distances = results.get("distances", [[]])[0] + ids = results.get("ids", [[]])[0] + return [ + ( + self.transform_similarity_score( + float(dist), self.similarity_metric_type.value + ), + int(idx), + ) + for dist, idx in zip(distances, ids) + ] def reset(self) -> None: - if self.collection is not None: - self.collection.delete(ids=self.collection.get()["ids"]) - self.__next_embedding_id = 0 + """Clear all embeddings from the collection.""" + with self._operation_lock: + if self.collection is not None: + self.collection.delete(ids=self.collection.get()["ids"]) + self.__next_embedding_id = 0 def _init_vector_store(self, embedding_dim: int): + """Initialize the ChromaDB client and collection. + + This method creates a ChromaDB client and a new collection with a unique + name. It configures the collection's metadata to use the specified + similarity metric ('cosine' or 'l2'). This method should be called + within a locked context. + + Args: + embedding_dim (int): The dimension of the embedding vectors. + """ self.client = chromadb.Client() collection_name = f"vcache_collection_{id(self)}" metric_type = self.similarity_metric_type.value @@ -70,9 +134,19 @@ def _init_vector_store(self, embedding_dim: int): raise ValueError(f"Invalid similarity metric type: {metric_type}") self.collection = self.client.create_collection( name=collection_name, - metadata={"dimension": embedding_dim, "hnsw:space": space}, + metadata={"hnsw:space": space}, get_or_create=True, ) def is_empty(self) -> bool: - return self.collection.count() == 0 + """Check if the collection contains any embeddings. + + This method is thread-safe. + + Returns: + bool: True if the collection has no embeddings, False otherwise. 
+ """ + with self._operation_lock: + if self.collection is None: + return True + return self.collection.count() == 0 diff --git a/vcache/vcache_core/cache/embedding_store/vector_db/strategies/faiss.py b/vcache/vcache_core/cache/embedding_store/vector_db/strategies/faiss.py index 913ff8e..6f8119e 100644 --- a/vcache/vcache_core/cache/embedding_store/vector_db/strategies/faiss.py +++ b/vcache/vcache_core/cache/embedding_store/vector_db/strategies/faiss.py @@ -1,3 +1,4 @@ +import threading from typing import List import faiss @@ -10,17 +11,41 @@ class FAISSVectorDB(VectorDB): + """A vector database implementation using FAISS. + + This class provides a thread-safe vector database that stores embeddings and + performs k-nearest neighbor searches using the FAISS library. It supports + both cosine similarity and L2 (Euclidean) distance. + + Attributes: + similarity_metric_type (SimilarityMetricType): The metric for measuring similarity. + index (faiss.Index): The underlying FAISS index. + """ + def __init__( self, similarity_metric_type: SimilarityMetricType = SimilarityMetricType.COSINE ): self.similarity_metric_type = similarity_metric_type self.__next_embedding_id = 0 self.index = None + self._operation_lock = threading.RLock() def transform_similarity_score( self, similarity_score: float, metric_type: str ) -> float: - # Override the default transform_similarity_score method + """Transform a raw score from FAISS to a normalized similarity score. + + For cosine similarity, FAISS's IndexFlatIP returns the inner product, which + is already the similarity, so no transformation is needed. For L2 distance, + the raw distance is converted to a similarity score. + + Args: + similarity_score (float): The raw score from the FAISS index. + metric_type (str): The similarity metric used ('cosine' or 'euclidean'). + + Returns: + float: The transformed similarity score, normalized to [0, 1]. + """ match metric_type: case "cosine": return similarity_score @@ -30,67 +55,126 @@ def transform_similarity_score( raise ValueError(f"Invalid similarity metric type: {metric_type}") def add(self, embedding: List[float]) -> int: - if self.index is None: - self._init_vector_store(len(embedding)) - id = self.__next_embedding_id - ids = np.array([id], dtype=np.int64) - embedding_array = np.array([embedding], dtype=np.float32) - metric_type = self.similarity_metric_type.value - # Normalize the embedding vector if the metric type is cosine - if metric_type == "cosine": - faiss.normalize_L2(embedding_array) - self.index.add_with_ids(embedding_array, ids) - self.__next_embedding_id += 1 - return id + """Add an embedding to the database, initializing the index if needed. + + This method is thread-safe. For cosine similarity, embeddings are + L2-normalized before being added to the index. + + Args: + embedding (List[float]): The embedding vector to add. + + Returns: + int: The unique ID assigned to the added embedding. 
+ """ + with self._operation_lock: + if self.index is None: + self._init_vector_store(len(embedding)) + + # Atomic ID generation and assignment + embedding_id = self.__next_embedding_id + ids = np.array([embedding_id], dtype=np.int64) + embedding_array = np.array([embedding], dtype=np.float32) + metric_type = self.similarity_metric_type.value + # Normalize the embedding vector if the metric type is cosine + if metric_type == "cosine": + faiss.normalize_L2(embedding_array) + self.index.add_with_ids(embedding_array, ids) + self.__next_embedding_id += 1 + + return embedding_id def remove(self, embedding_id: int) -> int: - if self.index is None: - raise ValueError("Index is not initialized") - id_array = np.array([embedding_id], dtype=np.int64) - self.index.remove_ids( - faiss.IDSelectorBatch(id_array.size, faiss.swig_ptr(id_array)) - ) - return embedding_id + """Remove an embedding from the database by its ID. + + This method is thread-safe. + + Args: + embedding_id (int): The ID of the embedding to remove. + + Returns: + int: The ID of the removed embedding. + """ + with self._operation_lock: + if self.index is None: + raise ValueError("Index is not initialized") + id_array = np.array([embedding_id], dtype=np.int64) + self.index.remove_ids( + faiss.IDSelectorBatch(id_array.size, faiss.swig_ptr(id_array)) + ) + return embedding_id def get_knn(self, embedding: List[float], k: int) -> List[tuple[float, int]]: - if self.index is None: - raise ValueError("Index is not initialized") - if self.index.ntotal == 0: - return [] - k_ = min(k, self.index.ntotal) - query_vector = np.array([embedding], dtype=np.float32) - metric_type = self.similarity_metric_type.value - # Normalize the query vector if the metric type is cosine - if metric_type == "cosine": - faiss.normalize_L2(query_vector) - distances, indices = self.index.search(query_vector, k_) - # Filter out results where index is -1 (deleted embeddings) - filtered_results = [ - (distances[0][i], indices[0][i]) - for i in range(len(indices[0])) - if indices[0][i] != -1 - ] - return [ - (self.transform_similarity_score(dist, metric_type), int(idx)) - for dist, idx in filtered_results - ] + """Find the k-nearest neighbors for a given embedding. + + This method is thread-safe. For cosine similarity, the query embedding is + L2-normalized before searching. Invalid IDs (-1) are filtered from results. + + Args: + embedding (List[float]): The query embedding. + k (int): The number of nearest neighbors to return. + + Returns: + List[tuple[float, int]]: List of (similarity_score, embedding_id) tuples. 
+ """ + with self._operation_lock: + if self.index is None: + return [] + if self.index.ntotal == 0: + return [] + k_ = min(k, self.index.ntotal) + embedding_array = np.array([embedding], dtype=np.float32) + metric_type = self.similarity_metric_type.value + # Normalize the embedding vector if the metric type is cosine + if metric_type == "cosine": + faiss.normalize_L2(embedding_array) + similarities, ids = self.index.search(embedding_array, k_) + similarity_scores = [ + self.transform_similarity_score(sim, metric_type) + for sim in similarities[0] + ] + id_list = [int(id) for id in ids[0] if id != -1] # Filter out invalid IDs + return list(zip(similarity_scores[: len(id_list)], id_list)) def reset(self) -> None: - if self.index is not None: - dim = self.index.d - self._init_vector_store(dim) - self.__next_embedding_id = 0 + """Clear all embeddings from the database and reset the index.""" + with self._operation_lock: + self.index = None + self.__next_embedding_id = 0 def _init_vector_store(self, embedding_dim: int): + """Initialize the FAISS index. + + This method selects the appropriate FAISS index type based on the + similarity metric (IndexFlatIP for cosine, IndexFlatL2 for Euclidean) + and wraps it with an IDMap to support custom embedding IDs. It must be + called within a locked context. + + Args: + embedding_dim (int): The dimension of the embedding vectors. + """ metric_type = self.similarity_metric_type.value match metric_type: case "cosine": - faiss_metric = faiss.METRIC_INNER_PRODUCT + # Use IndexFlatIP for cosine similarity (inner product after normalization) + self.index = faiss.IndexFlatIP(embedding_dim) case "euclidean": - faiss_metric = faiss.METRIC_L2 + # Use IndexFlatL2 for euclidean distance + self.index = faiss.IndexFlatL2(embedding_dim) case _: raise ValueError(f"Invalid similarity metric type: {metric_type}") - self.index = faiss.index_factory(embedding_dim, "IDMap,Flat", faiss_metric) + + # Wrap with IDMap to support custom IDs + self.index = faiss.IndexIDMap(self.index) def is_empty(self) -> bool: - return self.index.ntotal == 0 + """Check if the database contains any embeddings. + + This method is thread-safe. + + Returns: + bool: True if the database has no embeddings, False otherwise. + """ + with self._operation_lock: + if self.index is None: + return True + return self.index.ntotal == 0 diff --git a/vcache/vcache_core/cache/embedding_store/vector_db/strategies/hnsw_lib.py b/vcache/vcache_core/cache/embedding_store/vector_db/strategies/hnsw_lib.py index 1370bf5..480c798 100644 --- a/vcache/vcache_core/cache/embedding_store/vector_db/strategies/hnsw_lib.py +++ b/vcache/vcache_core/cache/embedding_store/vector_db/strategies/hnsw_lib.py @@ -1,3 +1,4 @@ +import threading from typing import List import hnswlib @@ -13,6 +14,17 @@ class HNSWLibVectorDB(VectorDB): + """A vector database implementation using HNSWLib. + + This class provides a thread-safe vector database that stores embeddings and + performs k-nearest neighbor searches using the HNSWLib library. + + Attributes: + embedding_count (int): The current number of embeddings in the database. + similarity_metric_type (SimilarityMetricType): The metric for measuring similarity. + index (hnswlib.Index): The underlying HNSWLib index. 
+ """ + def __init__( self, similarity_metric_type: SimilarityMetricType = SimilarityMetricType.COSINE, @@ -28,44 +40,96 @@ def __init__( self.M = None self.ef = None self.index = None + self._operation_lock = threading.RLock() def add(self, embedding: List[float]) -> int: - if self.index is None: - self._init_vector_store(len(embedding)) - self.index.add_items(embedding, self.__next_embedding_id) - self.embedding_count += 1 - self.__next_embedding_id += 1 - return self.__next_embedding_id - 1 + """Add an embedding to the database, initializing the index if needed. + + This method is thread-safe. + + Args: + embedding (List[float]): The embedding vector to add. + + Returns: + int: The unique ID assigned to the added embedding. + """ + with self._operation_lock: + if self.index is None: + self._init_vector_store(len(embedding)) + + # Atomic ID generation and assignment + embedding_id = self.__next_embedding_id + self.index.add_items(embedding, embedding_id) + self.embedding_count += 1 + self.__next_embedding_id += 1 + + return embedding_id def remove(self, embedding_id: int) -> int: - if self.index is None: - raise ValueError("Index is not initialized") - self.index.mark_deleted(embedding_id) - self.embedding_count -= 1 - return embedding_id + """Mark an embedding as deleted in the database. + + This method is thread-safe. Note that HNSWLib only marks items for + deletion; it does not immediately reclaim the space. + + Args: + embedding_id (int): The ID of the embedding to remove. + + Returns: + int: The ID of the removed embedding. + """ + with self._operation_lock: + if self.index is None: + raise ValueError("Index is not initialized") + self.index.mark_deleted(embedding_id) + self.embedding_count -= 1 + return embedding_id def get_knn(self, embedding: List[float], k: int) -> List[tuple[float, int]]: - if self.index is None: - return [] - k_ = min(k, self.embedding_count) - if k_ == 0: - return [] - ids, similarities = self.index.knn_query(embedding, k=k_) - metric_type = self.similarity_metric_type.value - similarity_scores = [ - self.transform_similarity_score(sim, metric_type) for sim in similarities[0] - ] - id_list = [int(id) for id in ids[0]] - return list(zip(similarity_scores, id_list)) + """Find the k-nearest neighbors for a given embedding. + + This method is thread-safe. If the database is empty or `k` is larger + than the number of items, it returns the maximum number of neighbors possible. + + Args: + embedding (List[float]): The query embedding. + k (int): The number of nearest neighbors to return. + + Returns: + List[tuple[float, int]]: List of (similarity_score, embedding_id) tuples. + """ + with self._operation_lock: + if self.index is None: + return [] + k_ = min(k, self.embedding_count) + if k_ == 0: + return [] + ids, similarities = self.index.knn_query(embedding, k=k_) + metric_type = self.similarity_metric_type.value + similarity_scores = [ + self.transform_similarity_score(sim, metric_type) + for sim in similarities[0] + ] + id_list = [int(id) for id in ids[0]] + return list(zip(similarity_scores, id_list)) def reset(self) -> None: - if self.dim is None: - return - self._init_vector_store(self.dim) - self.embedding_count = 0 - self.__next_embedding_id = 0 + """Clear all embeddings from the database.""" + with self._operation_lock: + if self.dim is None: + return + self._init_vector_store(self.dim) + self.embedding_count = 0 + self.__next_embedding_id = 0 def _init_vector_store(self, embedding_dim: int): + """Initialize the HNSWLib index. 
+ + This method configures the index with the specified space ('cosine' or 'l2') + and dimensions. It must be called within a locked context. + + Args: + embedding_dim (int): The dimension of the embedding vectors. + """ metric_type = self.similarity_metric_type.value match metric_type: case "cosine": @@ -87,4 +151,12 @@ def _init_vector_store(self, embedding_dim: int): self.index.set_ef(self.ef) def is_empty(self) -> bool: - return self.embedding_count == 0 + """Check if the database contains any embeddings. + + This method is thread-safe. + + Returns: + bool: True if the database has no embeddings, False otherwise. + """ + with self._operation_lock: + return self.embedding_count == 0 diff --git a/vcache/vcache_core/cache/embedding_store/vector_db/vector_db.py b/vcache/vcache_core/cache/embedding_store/vector_db/vector_db.py index dba81c5..adbdb14 100644 --- a/vcache/vcache_core/cache/embedding_store/vector_db/vector_db.py +++ b/vcache/vcache_core/cache/embedding_store/vector_db/vector_db.py @@ -9,13 +9,23 @@ class SimilarityMetricType(Enum): class VectorDB(ABC): + """Abstract base class for vector database implementations.""" + def transform_similarity_score( self, similarity_score: float, metric_type: str ) -> float: - """ - similarity_score: float - The similarity score to transform - metric_type: SimilarityMetricType - The type of similarity metric - returns: float - The transformed similarity score in the range of [0, 1] + """Transform a distance-based score into a normalized similarity score. + + This function converts raw distance scores from vector databases (like + Euclidean or cosine distance) into a unified similarity score from 0 to 1, + where 1 means most similar. + + Args: + similarity_score (float): The raw distance score from the vector database. + metric_type (str): The similarity metric used (e.g., 'cosine', 'euclidean'). + + Returns: + float: The transformed similarity score, normalized to the range [0, 1]. """ match metric_type: case "cosine": @@ -27,46 +37,63 @@ def transform_similarity_score( @abstractmethod def add(self, embedding: List[float]) -> int: - """ - embedding: List[float] - The embedding to add to the vector db - returns: int - The id of the embedding + """Add an embedding to the vector database. + + Args: + embedding (List[float]): The embedding vector to add. + + Returns: + int: The unique ID assigned to the embedding. """ pass @abstractmethod def remove(self, embedding_id: int) -> int: - """ - embedding_id: int - The id of the embedding to remove - returns: int - The id of the embedding + """Remove an embedding from the vector database. + + Args: + embedding_id (int): The ID of the embedding to remove. + + Returns: + int: The ID of the removed embedding. """ pass @abstractmethod def get_knn(self, embedding: List[float], k: int) -> List[tuple[float, int]]: - """ - embedding: List[float] - The embedding to get the k-nearest neighbors for - k: int - The number of nearest neighbors to get - returns: List[tuple[float, int]] - A list of tuples, each containing a similarity score and an embedding id + """Find the k-nearest neighbors for a given embedding. + + Args: + embedding (List[float]): The query embedding. + k (int): The number of nearest neighbors to return. + + Returns: + List[tuple[float, int]]: A list of (similarity_score, embedding_id) tuples. 
""" pass @abstractmethod def reset(self) -> None: - """ - Resets the vector db - """ + """Clear all embeddings from the vector database.""" pass @abstractmethod def _init_vector_store(self, embedding_dim: int): - """ - embedding_dim: int - The dimension of the embedding + """Initialize the underlying vector store. + + This method should be called within a lock and is responsible for setting up + the database with the correct dimensions and parameters. + + Args: + embedding_dim (int): The dimension of the embedding vectors. """ pass @abstractmethod def is_empty(self) -> bool: - """ - Returns: bool - Whether the vector db is empty + """Check if the vector database contains no embeddings. + + Returns: + bool: True if the database is empty, False otherwise. """ pass diff --git a/vcache/vcache_core/similarity_evaluator/similarity_evaluator.py b/vcache/vcache_core/similarity_evaluator/similarity_evaluator.py index ebdf704..ab1c9b8 100644 --- a/vcache/vcache_core/similarity_evaluator/similarity_evaluator.py +++ b/vcache/vcache_core/similarity_evaluator/similarity_evaluator.py @@ -10,9 +10,14 @@ def __init__(self): @abstractmethod def answers_similar(self, a: str, b: str) -> bool: """ - a: str - The first answer - b: str - The second answer - returns: bool - True if the answers are similar, False otherwise + Evaluates the similarity between two answers. + + Args: + a: str - The first answer + b: str - The second answer + + Returns: + bool - True if the answers are similar, False otherwise """ pass diff --git a/vcache/vcache_policy/strategies/dynamic_global_threshold.py b/vcache/vcache_policy/strategies/dynamic_global_threshold.py index 53ac95d..abc7354 100644 --- a/vcache/vcache_policy/strategies/dynamic_global_threshold.py +++ b/vcache/vcache_policy/strategies/dynamic_global_threshold.py @@ -26,6 +26,8 @@ def __init__( delta: float = 0.01, ): """ + IMPORTANT: This policy is used as an ablation for the DynamicLocalThresholdPolicy and should not be used in production. + This policy uses the vCache algorithm to compute the optimal threshold across all embeddings. Each threshold is used to determine if a response is a cache hit. This is suboptimal in cases when the embeddings cannot seperate correct from incorrect responses. 
@@ -93,12 +95,12 @@ def process_request( similarity_score=similarity_score, is_correct=should_have_exploited, metadata=metadata, + cache=self.cache, + embedding_id=embedding_id, ) if not should_have_exploited: self.cache.add(prompt=prompt, response=response) - self.cache.update_metadata( - embedding_id=embedding_id, embedding_metadata=metadata - ) + return False, response, metadata.response @@ -171,7 +173,12 @@ def __init__(self, delta: float): } def update_metadata( - self, similarity_score: float, is_correct: bool, metadata: EmbeddingMetadataObj + self, + similarity_score: float, + is_correct: bool, + metadata: EmbeddingMetadataObj, + cache: Cache, + embedding_id: int, ) -> None: """ Update the metadata with the new observation @@ -179,12 +186,16 @@ def update_metadata( similarity_score: float - The similarity score between the query and the embedding is_correct: bool - Whether the query was correct metadata: EmbeddingMetadataObj - The metadata of the embedding + cache: Cache - The cache to update the metadata for + embedding_id: int - The id of the embedding to update the metadata for """ if is_correct: self.global_observations.append((round(similarity_score, 3), 1)) else: self.global_observations.append((round(similarity_score, 3), 0)) + cache.update_metadata(embedding_id=embedding_id, embedding_metadata=metadata) + def select_action( self, similarity_score: float, metadata: EmbeddingMetadataObj ) -> _Action: diff --git a/vcache/vcache_policy/strategies/dynamic_local_threshold.py b/vcache/vcache_policy/strategies/dynamic_local_threshold.py index 6e4a26f..4a822eb 100644 --- a/vcache/vcache_policy/strategies/dynamic_local_threshold.py +++ b/vcache/vcache_policy/strategies/dynamic_local_threshold.py @@ -1,4 +1,6 @@ +import logging import random +from concurrent.futures import ThreadPoolExecutor from enum import Enum from typing import Dict, List, Optional, Tuple @@ -23,22 +25,38 @@ class DynamicLocalThresholdPolicy(VCachePolicy): - def __init__(self, delta: float = 0.01): - """ - This policy uses the vCache algorithm to compute the optimal threshold for each - embedding in the cache. - Each threshold is used to determine if a response is a cache hit. + """A policy that uses a dynamic, per-embedding threshold to make cache decisions. + + This policy implements the vCache algorithm, which uses a probabilistic approach + to learn an optimal similarity threshold for each cached item. It balances + exploiting the cache and exploring new responses to refine its decision boundaries. + + Attributes: + bayesian (_Algorithm): The core algorithm for action selection and updates. + similarity_evaluator (SimilarityEvaluator): Component for comparing responses. + inference_engine (InferenceEngine): The LLM for generating new responses. + cache (Cache): The vCache instance. + """ + + def __init__(self, delta: float = 0.01, max_background_workers: int = 100): + """Initializes the policy. - Args - delta: float - The delta value to use + Args: + delta (float): The desired error bound the cache needs to maintain. + max_background_workers (int): Max threads for background processing. 
""" self.bayesian = _Algorithm(delta=delta) self.similarity_evaluator: SimilarityEvaluator = None self.inference_engine: InferenceEngine = None self.cache: Cache = None + self._executor = ThreadPoolExecutor( + max_workers=max_background_workers, thread_name_prefix="vcache-bg" + ) + self._logger = logging.getLogger(__name__) @override def setup(self, config: VCacheConfig): + """Configure the policy with the necessary components from VCacheConfig.""" self.similarity_evaluator = config.similarity_evaluator self.inference_engine = config.inference_engine self.cache = Cache( @@ -54,12 +72,22 @@ def setup(self, config: VCacheConfig): def process_request( self, prompt: str, system_prompt: Optional[str] ) -> tuple[bool, str, str]: - """ - Args - prompt: str - The prompt to check for cache hit - system_prompt: Optional[str] - The optional system prompt to use for the response. It will override the system prompt in the VCacheConfig if provided. - Returns - tuple[bool, str, str] - [is_cache_hit, actual_response, nn_response] + """Process a request to decide whether to serve from cache or generate a new response. + + This method finds the nearest neighbor in the cache. If none exists, it + generates a new response. Otherwise, it uses the Bayesian algorithm to + decide whether to EXPLOIT (use the cached response) or EXPLORE (generate a + new one). In the EXPLORE case, label generation happens in the background. + + Args: + prompt (str): The user's prompt. + system_prompt (str, optional): An optional system prompt to guide the LLM. + + Returns: + tuple[bool, str, str]: A tuple containing: + - is_cache_hit (bool): True if the response is from the cache (EXPLOIT). + - actual_response (str): The response served. + - nn_response (str): The nearest neighbor's response, if one was found. """ if self.inference_engine is None or self.cache is None: raise ValueError("Policy has not been setup") @@ -85,21 +113,68 @@ def process_request( response = self.inference_engine.create( prompt=prompt, system_prompt=system_prompt ) - should_have_exploited = self.similarity_evaluator.answers_similar( - a=response, b=metadata.response - ) - self.bayesian.update_metadata( + + self._executor.submit( + self._generate_label, + response=response, + nn_response=metadata.response, similarity_score=similarity_score, - is_correct=should_have_exploited, - metadata=metadata, - ) - if not should_have_exploited: - self.cache.add(prompt=prompt, response=response) - self.cache.update_metadata( - embedding_id=embedding_id, embedding_metadata=metadata + embedding_id=embedding_id, + prompt=prompt, ) + return False, response, metadata.response + def _generate_label( + self, + response: str, + nn_response: str, + similarity_score: float, + embedding_id: int, + prompt: str, + ): + """Generate a label for a response and update metadata. + + This function runs in a background thread. It compares the newly generated + response with the nearest neighbor's response to determine if the cache + *should* have been hit. It then updates the metadata with this new + observation and adds the new response to the cache if it was dissimilar. + + Args: + response (str): The newly generated response. + nn_response (str): The cached response of the nearest neighbor. + similarity_score (float): The similarity between the query and the neighbor. + embedding_id (int): The ID of the nearest neighbor embedding to update. + prompt (str): The original prompt, to be cached if the new response is kept. 
+ """ + try: + should_have_exploited = self.similarity_evaluator.answers_similar( + a=response, b=nn_response + ) + + label: int = 1 if should_have_exploited else 0 + observation: Tuple[float, int] = (round(similarity_score, 3), label) + self.cache.add_observation( + embedding_id=embedding_id, observation=observation + ) + + if not should_have_exploited: + self.cache.add(prompt=prompt, response=response) + + except Exception as e: + self._logger.error( + f"Error in background label generation: {e}", exc_info=True + ) + + def __enter__(self): + """Enter the runtime context related to this object.""" + return self + + def __exit__(self, exc_type, exc_value, traceback): + """Exit the runtime context and shutdown the thread pool.""" + if hasattr(self, "_executor") and self._executor: + self._executor.shutdown(wait=True) + class _Action(Enum): EXPLORE = "explore" @@ -107,7 +182,14 @@ class _Action(Enum): class _Algorithm: + """Implements the Bayesian algorithm for the DynamicLocalThresholdPolicy.""" + def __init__(self, delta: float): + """Initializes the algorithm. + + Args: + delta (float): The desired error bound the cache needs to maintain. + """ self.delta: float = delta self.P_c: float = 1.0 - self.delta self.epsilon_grid: np.ndarray = np.linspace(1e-6, 1 - 1e-6, 50) @@ -161,31 +243,21 @@ def __init__(self, delta: float): 48: 0.01531, } - def update_metadata( - self, similarity_score: float, is_correct: bool, metadata: EmbeddingMetadataObj - ) -> None: - """ - Update the metadata with the new observation - Args - similarity_score: float - The similarity score between the query and the embedding - is_correct: bool - Whether the query was correct - metadata: EmbeddingMetadataObj - The metadata of the embedding - """ - if is_correct: - metadata.observations.append((round(similarity_score, 3), 1)) - else: - metadata.observations.append((round(similarity_score, 3), 0)) - def select_action( self, similarity_score: float, metadata: EmbeddingMetadataObj ) -> _Action: - """ - Select the action to take based on the similarity score, observations, and accuracy target - Args - similarity_score: float - The similarity score between the query and the embedding - metadata: EmbeddingMetadataObj - The metadata of the embedding - Returns - Action - Explore or Exploit + """Select whether to EXPLORE or EXPLOIT based on the learned threshold. + + This method estimates the current threshold `t_hat` from observations, + calculates a confidence-based exploration probability `tau`, and then + randomly decides whether to explore or exploit. + + Args: + similarity_score (float): The similarity of the current prompt to the cache entry. + metadata (EmbeddingMetadataObj): The metadata of the cache entry. + + Returns: + _Action: The action to perform, either EXPLORE or EXPLOIT. """ similarity_score = round(similarity_score, 3) similarities: np.ndarray = np.array([obs[0] for obs in metadata.observations]) @@ -217,16 +289,20 @@ def select_action( def _estimate_parameters( self, similarities: np.ndarray, labels: np.ndarray ) -> Tuple[float, float, float]: - """ - Optimize parameters with logistic regression - Args - similarities: np.ndarray - The similarities of the embeddings - labels: np.ndarray - The labels of the embeddings - metadata: EmbeddingMetadataObj - The metadata of the embedding - Returns - t_hat: float - The estimated threshold - gamma: float - The estimated gamma - var_t: float - The estimated variance of t + """Estimate logistic regression parameters from observations. 
+ + This method fits a logistic regression model to the similarity scores and + labels to estimate the decision boundary parameters. + + Args: + similarities (np.ndarray): The observed similarity scores. + labels (np.ndarray): The observed labels (1 for correct, 0 for incorrect). + + Returns: + Tuple[float, float, float]: A tuple containing: + - t_hat: The estimated similarity threshold. + - gamma: The steepness of the logistic curve. + - var_t: The variance of the threshold estimate. """ similarities = sm.add_constant(similarities) @@ -269,19 +345,21 @@ def _get_var_t( gamma: float, intercept: float, ) -> float: - """ - Compute the variance of t using the delta method - Args - perfect_seperation: bool - Whether the data is perfectly separable - n_observations: int - The number of observations - X: np.ndarray - The design matrix - gamma: float - The gamma parameter - intercept: float - The intercept parameter - Returns - float - The variance of t - Note: - If the data is perfectly separable, we use the variance map to estimate the variance of t - Otherwise, we use the delta method to estimate the variance of t + """Compute the variance of the threshold estimate `t_hat`. + + If the data is perfectly separable, it uses a pre-computed variance map. + Otherwise, it uses the delta method to approximate the variance from the + logistic regression's covariance matrix. + + Args: + perfect_seperation (bool): True if the data is perfectly separable. + n_observations (int): The number of observations. + X (np.ndarray): The design matrix for the regression. + gamma (float): The gamma parameter from the regression. + intercept (float): The intercept from the regression. + + Returns: + float: The variance of the threshold estimate. """ if perfect_seperation: if n_observations in self.variance_map: @@ -310,15 +388,20 @@ def _get_tau( t_hat: float, metadata: EmbeddingMetadataObj, ) -> float: - """ - Find the minimum tau value for the given similarity score - Args - var_t: float - The variance of t - s: float - The similarity score between the query and the nearest neighbor - t_hat: float - The estimated threshold - metadata: EmbeddingMetadataObj - The metadata of the nearest neighbor - Returns - float - The minimum tau value + """Calculate the exploration probability `tau`. + + This method computes `tau`, the probability of choosing to EXPLORE. It's + based on finding the worst-case (minimum) confidence that the current + action is correct, considering the uncertainty in the threshold `t_hat`. + + Args: + var_t (float): The variance of the threshold estimate. + s (float): The similarity score of the current request. + t_hat (float): The estimated threshold. + metadata (EmbeddingMetadataObj): The metadata of the cache entry. + + Returns: + float: The calculated exploration probability, `tau`. """ t_primes: List[float] = self._get_t_primes(t_hat=t_hat, var_t=var_t) likelihoods = self._likelihood(s=s, t=t_primes, gamma=metadata.gamma) @@ -329,13 +412,14 @@ def _get_tau( return round(np.min(taus), 5) def _get_t_primes(self, t_hat: float, var_t: float) -> List[float]: - """ - Compute all possible t_prime values. - Args - t_hat: float - The estimated threshold - var_t: float - The variance of t - Returns - List[float] - The t_prime values + """Compute a grid of possible threshold values based on the confidence interval. + + Args: + t_hat (float): The estimated threshold. + var_t (float): The variance of the threshold estimate. + + Returns: + List[float]: A list of potential threshold values (`t_prime`). 
""" t_primes: List[float] = np.array( [ @@ -350,29 +434,37 @@ def _get_t_primes(self, t_hat: float, var_t: float) -> List[float]: def _confidence_interval( self, t_hat: float, var_t: float, quantile: float ) -> float: - """ - Return the (upper) quantile-threshold t' such that - P_est( t > t' ) <= 1 - quantile - Args - t_hat: float - The estimated threshold - var_t: float - The variance of t - quantile: float - The quantile - Returns - float - The t_prime value + """Calculate the upper bound of a confidence interval for the threshold `t`. + + This computes a threshold `t'` such that the estimated probability of the true + threshold being greater than `t'` is less than or equal to `1 - quantile`. + + Args: + t_hat (float): The estimated threshold. + var_t (float): The variance of the threshold estimate. + quantile (float): The desired quantile for the confidence interval. + + Returns: + float: The upper bound of the confidence interval (`t_prime`). """ z = norm.ppf(quantile) t_prime = t_hat + z * np.sqrt(var_t) return float(np.clip(t_prime, 0.0, 1.0)) def _likelihood(self, s: float, t: float, gamma: float) -> float: - """ - Compute the likelihood of the given similarity score and threshold - Args - s: float - The similarity score between the query and the nearest neighbor - t: float - The threshold - gamma: float - The gamma parameter - Returns - float - The likelihood of the given similarity score and threshold + """Compute the likelihood of a correct cache hit given a similarity score. + + This function uses the logistic (sigmoid) function to model the + probability of a correct match based on the similarity `s`, a threshold + `t`, and a steepness parameter `gamma`. + + Args: + s (float): The similarity score. + t (float): The decision threshold. + gamma (float): The steepness of the logistic curve. + + Returns: + float: The likelihood of a correct match. 
""" z = gamma * (s - t) return expit(z) diff --git a/vcache/vcache_policy/strategies/iid_local_threshold.py b/vcache/vcache_policy/strategies/iid_local_threshold.py index 8151beb..b60f582 100644 --- a/vcache/vcache_policy/strategies/iid_local_threshold.py +++ b/vcache/vcache_policy/strategies/iid_local_threshold.py @@ -92,12 +92,12 @@ def process_request( similarity_score=similarity_score, is_correct=should_have_exploited, metadata=metadata, + cache=self.cache, + embedding_id=embedding_id, ) if not should_have_exploited: self.cache.add(prompt=prompt, response=response) - self.cache.update_metadata( - embedding_id=embedding_id, embedding_metadata=metadata - ) + return False, response, metadata.response @@ -113,7 +113,12 @@ def __init__(self, delta: float): self.thold_grid: np.ndarray = np.linspace(0, 1, 20) def update_metadata( - self, similarity_score: float, is_correct: bool, metadata: EmbeddingMetadataObj + self, + similarity_score: float, + is_correct: bool, + metadata: EmbeddingMetadataObj, + cache: Cache, + embedding_id: int, ) -> None: """ Update the metadata with the new observation @@ -121,12 +126,16 @@ def update_metadata( similarity_score: float - The similarity score between the query and the embedding is_correct: bool - Whether the query was correct metadata: EmbeddingMetadataObj - The metadata of the embedding + cache: Cache - The cache to update the metadata for + embedding_id: int - The id of the embedding to update the metadata for """ if is_correct: metadata.observations.append((round(similarity_score, 3), 1)) else: metadata.observations.append((round(similarity_score, 3), 0)) + cache.update_metadata(embedding_id=embedding_id, embedding_metadata=metadata) + def wilson_proportion_ci(self, cdf_estimates, n, confidence): """ Vectorized Wilson score confidence interval for binomial proportions. diff --git a/vcache/vcache_policy/strategies/static_global_threshold.py b/vcache/vcache_policy/strategies/static_global_threshold.py index ff4248a..4b41492 100644 --- a/vcache/vcache_policy/strategies/static_global_threshold.py +++ b/vcache/vcache_policy/strategies/static_global_threshold.py @@ -61,6 +61,7 @@ def process_request( similarity_score, embedding_id = knn[0] metadata = self.cache.get_metadata(embedding_id=embedding_id) is_cache_hit = similarity_score >= self.threshold + if is_cache_hit: return True, metadata.response, metadata.response else: diff --git a/vcache/vcache_policy/vcache_policy.py b/vcache/vcache_policy/vcache_policy.py index 920d1d8..2019db4 100644 --- a/vcache/vcache_policy/vcache_policy.py +++ b/vcache/vcache_policy/vcache_policy.py @@ -5,11 +5,17 @@ class VCachePolicy(ABC): + """Abstract base class for vCache caching policies.""" + @abstractmethod def setup(self, config: VCacheConfig): - """ - Setup the policy with the given config. - config: VCacheConfig - The config to setup the policy with. + """Configure the policy with the necessary components. + + This method is called once to initialize the policy with inference engines, + cache configurations, and other required components. + + Args: + config (VCacheConfig): The configuration object for the policy. """ pass @@ -17,9 +23,19 @@ def setup(self, config: VCacheConfig): def process_request( self, prompt: str, system_prompt: Optional[str] ) -> tuple[bool, str, str]: - """ - prompt: str - The prompt to check for cache hit - system_prompt: Optional[str] - The optional system prompt to use for the response. It will override the system prompt in the VCacheConfig if provided. 
- returns: tuple[bool, str, str] - [is_cache_hit, actual_response, nn_response] + """Process a request to decide whether to use a cached response. + + This method determines if a prompt can be served from the cache (a hit) + or if it requires a new generation from the inference engine (a miss). + + Args: + prompt (str): The user's prompt. + system_prompt (str, optional): An optional system prompt to guide the LLM. + + Returns: + tuple[bool, str, str]: A tuple containing: + - is_cache_hit (bool): True if the response is from the cache. + - actual_response (str): The response served (from cache or new). + - nn_response (str): The nearest neighbor's response if one was found. """ pass
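
Reviewer note (not part of the patch): because the local policy now owns a ThreadPoolExecutor, the new __enter__/__exit__ methods are the intended way to guarantee that background labeling work drains before shutdown. A minimal usage sketch, again assuming a pre-built VCacheConfig named config:

    from vcache.vcache_policy.strategies.dynamic_local_threshold import (
        DynamicLocalThresholdPolicy,
    )

    with DynamicLocalThresholdPolicy(delta=0.01, max_background_workers=8) as policy:
        policy.setup(config)  # config: a fully wired VCacheConfig (assumed)
        is_hit, response, nn_response = policy.process_request(
            prompt="What is vCache?", system_prompt=None
        )
        # On an EXPLORE decision, answers_similar() runs on a "vcache-bg"
        # worker thread and the observation is recorded via
        # cache.add_observation().
    # __exit__ calls self._executor.shutdown(wait=True), so every pending
    # _generate_label task completes before the block exits.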