# Semantic Cache


**Content:** Integrate a semantic caching system into a typical RAG solution.

**Model:** [google/gemma-2b-it](https://huggingface.co/google/gemma-2b-it)

**Vector Database:** ChromaDB, a popular open-source vector database.

**Vector Similarity Search Library**: FAISS

**Dataset:** [keivalya/MedQuad-MedicalQnADataset](https://huggingface.co/datasets/keivalya/MedQuad-MedicalQnADataset)

**Environment:** Kaggle (free)

**Keywords:** RAG, Semantic Caching, Vector Database, Vector Similarity.

[Source](https://github.com/peremartra/Large-Language-Model-Notebooks-Course/blob/main/2-Vector%20Databases%20with%20LLMs/semantic_cache_chroma_vector_database.ipynb)


The **semantic cache system** stores user queries and decides whether to build a prompt enriched with information from the vector database or to reuse information from the cache. Its goal is to identify similar or identical **user requests**. When a matching request is found, the corresponding information is retrieved from the cache, reducing the need to query the original source. In **production** environments, where the same requests can recur tens or even thousands of times, a RAG system without a semantic cache quickly becomes inefficient, because every repeated request goes through the full retrieval step. One way to improve performance is to introduce **one or more semantic caches**.

Because the comparison is based on **semantic meaning**, requests may be phrased differently while still referring to the same information need. Although the model’s final response may vary depending on the wording of the request, the information retrieved from the vector database should remain the same. Therefore, the cache system is positioned **between** the **user** and the **vector database**, rather than between the user and the LLM.


In a RAG system, there are **two main time-consuming** steps:
1. Retrieving the information used to construct the enriched prompt.
2. Calling the large language model to generate the response.

A semantic cache can be introduced at either of these points, and in some cases two separate caches may be used, one for each step.

Placing the cache at the model-response stage reduces the user's influence over the response: if the cache stores model responses, users may perceive that their instructions are not being followed. In contrast, similar user requests usually require the same information to enrich the prompt. For this reason, in this notebook the semantic cache is positioned between the user’s request and the retrieval of information from the vector database.

This is a **design decision**. Depending on the nature of the requests and the desired behavior of the system, the cache can be placed at **different points in the pipeline**. While caching model responses can offer the greatest performance gains, it comes at the cost of losing user influence over the response.
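
To make the design decision concrete, here is a minimal sketch of where the cache sits in the request flow. It is not the notebook's implementation: `RetrievalCache`, `embed`, `retrieve_from_db` and `call_llm` are hypothetical stand-ins for the FAISS cache, the encoder, ChromaDB and the LLM, and the cache is a plain list scanned with squared Euclidean distance. The point it illustrates is that only the retrieval step is cached, so the LLM always receives the user's original wording.

```python
from typing import Callable, List, Tuple

Vector = List[float]


def squared_l2(a: Vector, b: Vector) -> float:
    """Squared Euclidean distance, the metric FAISS's IndexFlatL2 reports."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


class RetrievalCache:
    """Hypothetical cache placed between the user and the vector database."""

    def __init__(self, embed: Callable[[str], Vector],
                 retrieve_from_db: Callable[[str], str], threshold: float = 0.35):
        self.embed = embed                        # question -> embedding
        self.retrieve_from_db = retrieve_from_db  # question -> context text
        self.threshold = threshold
        self.entries: List[Tuple[Vector, str]] = []  # (embedding, context)
        self.db_calls = 0

    def get_context(self, question: str) -> str:
        vec = self.embed(question)
        if self.entries:
            # Exact nearest neighbor among the cached questions
            best_vec, best_ctx = min(self.entries, key=lambda e: squared_l2(vec, e[0]))
            if squared_l2(vec, best_vec) <= self.threshold:
                return best_ctx  # cache hit: the vector database is skipped
        context = self.retrieve_from_db(question)  # cache miss
        self.db_calls += 1
        self.entries.append((vec, context))
        return context


def answer(question: str, cache: RetrievalCache, call_llm: Callable[[str], str]) -> str:
    # Only retrieval is cached: the LLM always receives the user's own wording,
    # so instructions such as "in 20 words" still reach the model
    context = cache.get_context(question)
    prompt = f"Relevant context: {context}\n\n The user's question: {question}"
    return call_llm(prompt)
```

Caching model responses instead would mean storing the output of `call_llm` and returning it on a hit, which is faster but would silently ignore wording differences such as "in 20 words".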

# 1. Set Up

## 1.1. Import & Load the libraries

Install/reinstall the correct PyTorch / CUDA / NLP stack:

```python
!pip uninstall -y torch torchvision torchaudio fastai
!pip install --no-cache-dir \
  torch==2.8.0 \
  torchvision==0.23.0 \
  torchaudio==2.8.0 \
  --index-url https://download.pytorch.org/whl/cu126
```

[accelerate](https://github.com/huggingface/accelerate): necessary to load and run the model on a GPU.

```python
!pip install -q \
  transformers \
  accelerate \
  sentence-transformers \
  chromadb \
  datasets \
  faiss-cpu
```

## 1.2. Load the Dataset

Due to limited free resources and memory constraints, the number of dataset rows is restricted using the `MAX_ROWS` variable.

**Note:** you must be logged in to Hugging Face (and have accepted the Gemma license on the model page) to download the Gemma model.


```python
import numpy as np
import pandas as pd
from datasets import load_dataset

data = load_dataset("keivalya/MedQuad-MedicalQnADataset", split='train')
```


```python
data = data.to_pandas()
data["id"] = data.index   # ChromaDB requires that each record has a unique identifier
data.head(10)
```

| | qtype | Question | Answer | id |
|---|---|---|---|---|
| 0 | susceptibility | Who is at risk for Lymphocytic Choriomeningiti... | LCMV infections can occur after exposure to fr... | 0 |
| 1 | symptoms | What are the symptoms of Lymphocytic Choriomen... | LCMV is most commonly recognized as causing ne... | 1 |
| 2 | susceptibility | Who is at risk for Lymphocytic Choriomeningiti... | Individuals of all ages who come into contact ... | 2 |
| 3 | exams and tests | How to diagnose Lymphocytic Choriomeningitis (... | During the first phase of the disease, the mos... | 3 |
| 4 | treatment | What are the treatments for Lymphocytic Chorio... | Aseptic meningitis, encephalitis, or meningoen... | 4 |
| 5 | prevention | How to prevent Lymphocytic Choriomeningitis (L... | LCMV infection can be prevented by avoiding co... | 5 |
| 6 | information | What is (are) Parasites - Cysticercosis ? | Cysticercosis is an infection caused by the la... | 6 |
| 7 | susceptibility | Who is at risk for Parasites - Cysticercosis? ? | Cysticercosis is an infection caused by the la... | 7 |
| 8 | exams and tests | How to diagnose Parasites - Cysticercosis ? | If you think that you may have cysticercosis, ... | 8 |
| 9 | treatment | What are the treatments for Parasites - Cystic... | Some people with cysticercosis do not need to ... | 9 |


```python
MAX_ROWS = 1000
DOCUMENT = "Answer"
TOPIC = "qtype"

subset_data = data.head(MAX_ROWS)
```

## 1.3. Vector DB

### 1.3.1. Import & Configure

```python
import chromadb
from chromadb.config import Settings   # to change the settings of the ChromaDB system and customize its behavior
```

```python
chroma_client = chromadb.PersistentClient(path="/path/to/persist/directory")   # path: where the vector database will be stored
```

### 1.3.2. Filling & Querying

```python
# Create the collection, dropping any previous version of it
collection_name = "med_collection"

try:
    chroma_client.delete_collection(name=collection_name)
except Exception:
    pass  # the collection did not exist yet

collection = chroma_client.create_collection(name=collection_name)
```

```python
# Add the data to the collection
collection.add(
    documents=subset_data[DOCUMENT].tolist(),
    metadatas=[{TOPIC: topic} for topic in subset_data[TOPIC].tolist()],
    ids=[f"id{x}" for x in range(MAX_ROWS)],
)
```



Define a function to query the ChromaDB database:

```python
def query_database(query_text, n_results=10):
    results = collection.query(query_texts=query_text, n_results=n_results)
    return results
```

# 2. Semantic Cache System

**FAISS** is used to implement the cache system. It provides **in-memory storage** of embeddings and works similarly to ChromaDB, but without persistent storage.


**FAISS** (Facebook AI Similarity Search) is a high-performance **library** for **fast similarity search** and **clustering** of dense vectors, commonly used for embedding-based retrieval:
- It provides fast **nearest-neighbor** search over embeddings.
- It is a low-level engine, not a database.
- It runs inside the application process (on CPU or GPU).

```python
import json
import time

import faiss
from sentence_transformers import SentenceTransformer
```

## 2.1. Class: semantic_cache

A class called `semantic_cache` is created; it has its own encoder and provides the methods the user needs to run queries. **FAISS** (the cache) is queried first: if the distance to the closest cached question is within the specified threshold, the result is returned from the cache. Otherwise, the result is fetched from **ChromaDB** and added to the cache. The cache is persisted in a `.json` file.

The following functions will be used within the `semantic_cache` class.

### 2.1.1. Function: Initializing the Cache

The `IndexFlatL2` **index** is used. It is not the fastest option, since it compares each query against every stored vector, but it works well for **small datasets**.

**Note:** depending on the **characteristics of the data** intended for the cache and the expected **dataset size**, another index such as `HNSW` or `IVF` could be used.

`IndexFlatL2` is a simple, **exact nearest-neighbor index** using **L2 (Euclidean)** distance.
- The **dimension of the embeddings** should **match** the output size of the chosen Sentence Transformer model.
- This **index** stores vectors in memory and allows quick similarity search.
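
What `IndexFlatL2.search` computes can be reproduced with NumPy as a brute-force scan. The sketch below is an illustration, not FAISS's own code, and `flat_l2_search` is a hypothetical name. Like FAISS's L2 indices it returns *squared* distances, and it mimics FAISS's convention of returning label `-1` when fewer than `k` vectors are stored; the filler distance for those slots is an assumption (`np.inf` here).

```python
import numpy as np


def flat_l2_search(database: np.ndarray, queries: np.ndarray, k: int):
    """Exact k-nearest-neighbor search with squared L2 distances.

    database: (n, d) stored vectors; queries: (m, d) query vectors.
    Returns (D, I), both of shape (m, k), in the same layout as faiss search.
    """
    m = queries.shape[0]
    D = np.full((m, k), np.inf, dtype=np.float32)
    I = np.full((m, k), -1, dtype=np.int64)
    n = database.shape[0]
    if n == 0:
        return D, I  # empty index: no neighbors found
    # ||q - x||^2 for every (query, stored vector) pair -> shape (m, n)
    dists = ((queries[:, None, :] - database[None, :, :]) ** 2).sum(axis=-1)
    kk = min(k, n)
    order = np.argsort(dists, axis=1)[:, :kk]  # closest first
    D[:, :kk] = np.take_along_axis(dists, order, axis=1)
    I[:, :kk] = order
    return D, I
```

With `k=1` this is the lookup the `ask` method performs: `I[0][0]` is the cache row and `D[0][0]` is the value compared against the threshold.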

`is_trained`:
- For **flat (non-compressed) indices** like IndexFlatL2, the index is **always** considered **trained**.
- For **other** FAISS indices (like IVF), you would **need to train the index** on sample vectors first.

`'all-mpnet-base-v2'`:
- It is the **pretrained model** used to convert *text* into *768-dimensional embeddings*.
- It **transforms** user queries or documents into **vectors** that can be stored in the FAISS index.

```python
def init_cache():
    '''
    Initialize the components needed for an in-memory FAISS-based semantic cache
    '''
    # Create a FAISS index for storing embeddings
    index = faiss.IndexFlatL2(768)   # 768 is the dimension of the embeddings

    # Check whether the FAISS index is "trained"
    if index.is_trained:
        print('Index trained')  # always printed: IndexFlatL2 is a flat index

    # Initialize the Sentence Transformer model
    encoder = SentenceTransformer('all-mpnet-base-v2')

    return index, encoder
```

### 2.1.2. Function: Retrieving from Cache

```python
def retrieve_cache(json_file):
    '''
    Load the cache from json_file so it can be reused across sessions;
    start with an empty cache if the file does not exist yet.
    '''
    try:
        with open(json_file, 'r') as file:
            cache = json.load(file)
    except FileNotFoundError:
        cache = {'questions': [], 'embeddings': [], 'answers': [], 'response_text': []}

    return cache
```

### 2.1.3. Function: Storing in Cache

```python
def store_cache(json_file, cache):
    '''
    Save the cache data to disk as JSON
    '''
    with open(json_file, 'w') as file:
        json.dump(cache, file)
```
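
Because JSON has no array type, cached embeddings come back from disk as nested lists of Python floats. If the FAISS index needs to be rebuilt from a cache file written in an earlier session, those lists have to be turned back into a contiguous `float32` matrix before calling `index.add`. A minimal sketch, assuming the cache layout used by `retrieve_cache`; `load_embedding_matrix` is a hypothetical helper, not part of the notebook.

```python
import json

import numpy as np


def load_embedding_matrix(json_file: str, dim: int = 768) -> np.ndarray:
    """Read a cache file written by store_cache and return its embeddings
    as an (n, dim) float32 array, ready to pass to index.add()."""
    try:
        with open(json_file, 'r') as f:
            embeddings = json.load(f).get('embeddings', [])
    except FileNotFoundError:
        embeddings = []  # no previous session: start with an empty matrix
    if not embeddings:
        return np.empty((0, dim), dtype=np.float32)
    matrix = np.ascontiguousarray(embeddings, dtype=np.float32)
    if matrix.shape[1] != dim:
        raise ValueError(f"expected {dim}-d embeddings, got {matrix.shape[1]}")
    return matrix
```

Re-adding the rows in file order keeps the FAISS ids aligned with the positions in the `questions` and `response_text` lists.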

### 2.1.4. `semantic_cache` Class & `ask` Method

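Euclidean distance **threshold**:
- A distance of **0** means the sentences are **identical**.
- Only questions whose distance is **at or below** the threshold are answered **from the cache**.
- `IndexFlatL2` reports **squared** Euclidean distances, so the threshold applies to the squared distance.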
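
Since the threshold is applied to squared L2 distances, it can also be read as a minimum cosine similarity, provided the embeddings have unit length: then ‖a − b‖² = 2 − 2·cos(a, b). Whether `all-mpnet-base-v2` returns normalized vectors is an assumption here (the model is commonly distributed with a normalization layer, but it is worth checking the norms of `encoder.encode(...)`). The helper names below are hypothetical.

```python
import numpy as np


def threshold_to_min_cosine(sq_l2_threshold: float) -> float:
    """Minimum cosine similarity implied by a squared-L2 threshold.

    Valid only for unit-length vectors, where ||a - b||^2 = 2 - 2 * cos(a, b).
    """
    return 1.0 - sq_l2_threshold / 2.0


def min_cosine_to_threshold(min_cosine: float) -> float:
    """Inverse mapping: squared-L2 threshold for a target cosine similarity."""
    return 2.0 - 2.0 * min_cosine


# Numerical check of the identity on two random 768-d unit vectors
rng = np.random.default_rng(0)
a, b = rng.normal(size=(2, 768))
a /= np.linalg.norm(a)
b /= np.linalg.norm(b)
sq_dist = float(np.sum((a - b) ** 2))
cosine = float(a @ b)
```

Under the unit-norm assumption, a threshold of 0.35 means a cache hit requires a cosine similarity of at least 0.825; the distances 0.018 and 0.220 observed in section 2.2 correspond to cosine similarities of about 0.991 and 0.890.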


`self.index.nprobe`:

**Note:** `nprobe` controls how many clusters (partitions) FAISS searches. This parameter is only **relevant** for **partitioned indices** such as `IVF`. It has **no effect** on **flat indices** such as `IndexFlatL2`, which is why the line is commented out in the code below.

A higher `nprobe`:
- Searches more partitions.
- Improves recall (better accuracy).
- Increases query time.

Setting `nprobe = 8` is a compromise between speed and accuracy.
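
To see what `nprobe` does, here is a toy inverted-file (IVF) search in NumPy. It simplifies what FAISS IVF indices such as `faiss.IndexIVFFlat` do, under stated assumptions: the `nlist` centroids are sampled from the data instead of being trained with k-means, and `build_ivf` / `ivf_search` are hypothetical names. Each vector goes into the list of its nearest centroid; a query scans only the `nprobe` lists whose centroids are closest, so a small `nprobe` can miss the true nearest neighbor, while `nprobe == nlist` degenerates into an exact flat search.

```python
import numpy as np


def build_ivf(data: np.ndarray, nlist: int, seed: int = 0):
    """Partition data into nlist inverted lists around sampled centroids."""
    rng = np.random.default_rng(seed)
    # Assumption: sample centroids from the data instead of running k-means
    centroids = data[rng.choice(len(data), size=nlist, replace=False)]
    # Coarse quantizer: assign every vector to its nearest centroid
    assign = ((data[:, None, :] - centroids[None, :, :]) ** 2).sum(-1).argmin(1)
    lists = [np.flatnonzero(assign == c) for c in range(nlist)]
    return centroids, lists


def ivf_search(data, centroids, lists, query: np.ndarray, nprobe: int):
    """Return (squared distance, id) of the nearest vector found by scanning
    only the nprobe lists whose centroids are closest to the query."""
    centroid_d = ((centroids - query) ** 2).sum(-1)
    probe = np.argsort(centroid_d)[:nprobe]
    candidates = np.concatenate([lists[c] for c in probe])
    if candidates.size == 0:
        return float("inf"), -1
    d = ((data[candidates] - query) ** 2).sum(-1)
    best = int(d.argmin())
    return float(d[best]), int(candidates[best])
```

In FAISS itself, an IVF index must be trained with `index.train(...)` before vectors are added, and `index.nprobe` is then set on that index.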


`D, I = self.index.search(embedding, 1)` performs a **nearest-neighbor search** in the FAISS index.
- Input parameters:
    - `embedding`: a **vector** (or batch of vectors) representing the **query**.
    - `1`: requests the single **closest vector** (top-1 result).
- FAISS returns:
    - `D`: **squared L2 distances** to the nearest neighbors.
    - `I`: **indices** (IDs) of the **nearest vectors** in the index.

```python
class semantic_cache:
    def __init__(self, json_file="cache_file.json", threshold=0.35):
        # Initialize the FAISS index (L2 distance) and the sentence encoder
        self.index, self.encoder = init_cache()

        # Distance threshold (IndexFlatL2 reports squared Euclidean distances)
        self.euclidean_threshold = threshold

        self.json_file = json_file
        self.cache = retrieve_cache(self.json_file)

        # Re-add embeddings persisted in a previous session, so that FAISS
        # row ids stay aligned with the positions in the cache lists
        if self.cache['embeddings']:
            self.index.add(np.array(self.cache['embeddings'], dtype='float32'))

    def ask(self, question: str) -> str:
        '''
        Retrieve an answer from the cache, or fetch it from ChromaDB and cache it
        '''
        start_time = time.time()
        try:
            # Embed the user question (array of shape (1, 768))
            embedding = self.encoder.encode([question])

            # Search for the nearest neighbor in the index
            # (self.index.nprobe = 8 would only matter for an IVF index)
            D, I = self.index.search(embedding, 1)

            # I[0][0] is -1 while the index is still empty
            if I[0][0] >= 0 and D[0][0] <= self.euclidean_threshold:
                row_id = int(I[0][0])

                print('Answer recovered from Cache. ')
                print(f'{D[0][0]:.3f} smaller than {self.euclidean_threshold}')
                print(f'Found cache in row: {row_id} with score {D[0][0]:.3f}')
                print(f'response_text: {self.cache["response_text"][row_id]}')

                elapsed_time = time.time() - start_time
                print(f"Time taken: {elapsed_time:.3f} seconds")
                return self.cache['response_text'][row_id]

            # Cache miss (empty index or distance above the threshold):
            # query ChromaDB for the single most relevant document
            answer = query_database([question], 1)
            response_text = answer['documents'][0][0]

            self.cache['questions'].append(question)
            self.cache['embeddings'].append(embedding[0].tolist())
            self.cache['answers'].append(answer)
            self.cache['response_text'].append(response_text)

            print('Answer recovered from ChromaDB.')
            print(f'response_text: {response_text}')

            self.index.add(embedding)
            store_cache(self.json_file, self.cache)
            elapsed_time = time.time() - start_time
            print(f"Time taken: {elapsed_time:.3f} seconds")

            return response_text
        except Exception as e:
            raise RuntimeError(f"Error during 'ask' method: {e}") from e
```

## 2.2. Test the semantic_cache Class

```python
# Initialize the cache
cache = semantic_cache('4cache_file.json')
```

```
Index trained
```



```python
# 1st question
results = cache.ask("How work a vaccine?")
```

```
Answer recovered from ChromaDB.
response_text: Why Are Childhood Vaccines So Important? It is always better to prevent a disease than to treat it after it occurs. Diseases that used to be common in this country and around the world, including polio, measles, diphtheria, pertussis (whooping cough), rubella (German measles), mumps, tetanus, rotavirus and Haemophilus influenzae type b (Hib) can now be prevented by vaccination. Thanks to a vaccine, one of the most terrible diseases in history – smallpox – no longer exists outside the laboratory. Over the years vaccines have prevented countless cases of disease and saved millions of lives. Immunity Protects us From Disease Immunity is the body’s way of preventing disease.  Children are born with an immune system composed of cells, glands, organs, and fluids located throughout the body. The immune system recognizes germs that enter the body as "foreign invaders” (called antigens) and produces proteins called antibodies to fight them. The fir
```

```python
# 2nd question that is quite different from the 1st one
results = cache.ask("Explain briefly what is a Periodic Paralyses")
```

```
Answer recovered from ChromaDB.
response_text: Familial periodic paralyses are a group of inherited neurological disorders caused by mutations in genes that regulate sodium and calcium channels in nerve cells. They are characterized by episodes in which the affected muscles become slack, weak, and unable to contract. Between attacks, the affected muscles usually work as normal.

The two most common types of periodic paralyses are: Hypokalemic periodic paralysis is characterized by a fall in potassium levels in the blood. In individuals with this mutation attacks often begin in adolescence and are triggered by strenuous exercise, high carbohydrate meals, or by injection of insulin, glucose, or epinephrine. Weakness may be mild and limited to certain muscle groups, or more severe and affect the arms and legs. Attacks may last for a few hours or persist for several days. Some patients may develop chronic muscle weakness later in life. Hyperkalemic periodic paralysis is cha
```

```python
# 3rd question that is very similar to the 2nd one
results = cache.ask("Briefly explain me what is a periodic paralyses")
```

```
Answer recovered from Cache. 
0.018 smaller than 0.35
Found cache in row: 1 with score 0.018
response_text: Familial periodic paralyses are a group of inherited neurological disorders caused by mutations in genes that regulate sodium and calcium channels in nerve cells. They are characterized by episodes in which the affected muscles become slack, weak, and unable to contract. Between attacks, the affected muscles usually work as normal.

The two most common types of periodic paralyses are: Hypokalemic periodic paralysis is characterized by a fall in potassium levels in the blood. In individuals with this mutation attacks often begin in adolescence and are triggered by strenuous exercise, high carbohydrate meals, or by injection of insulin, glucose, or epinephrine. Weakness may be mild and limited to certain muscle groups, or more severe and affect the arms and legs. Attacks may last for a few hours or persist for several days. Some patients may develop chronic muscle w
```

```python
# 4th question that is also similar but a bit more distinct
question_def = "Write in 20 words what is a periodic paralyses"
results = cache.ask(question_def)
```

```
Answer recovered from Cache. 
0.220 smaller than 0.35
Found cache in row: 1 with score 0.220
response_text: Familial periodic paralyses are a group of inherited neurological disorders caused by mutations in genes that regulate sodium and calcium channels in nerve cells. They are characterized by episodes in which the affected muscles become slack, weak, and unable to contract. Between attacks, the affected muscles usually work as normal.

The two most common types of periodic paralyses are: Hypokalemic periodic paralysis is characterized by a fall in potassium levels in the blood. In individuals with this mutation attacks often begin in adolescence and are triggered by strenuous exercise, high carbohydrate meals, or by injection of insulin, glucose, or epinephrine. Weakness may be mild and limited to certain muscle groups, or more severe and affect the arms and legs. Attacks may last for a few hours or persist for several days. Some patients may develop chronic muscle w
```

# 3. Load the Model & Create the Prompt

## 3.1. Load the Model

```python
import torch
from torch import cuda

# On Apple Silicon, use: device = torch.device('mps')
device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
```

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "google/gemma-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map=device,      # device selected in the previous cell
    dtype=torch.bfloat16,   # older transformers versions call this torch_dtype
)
```


## 3.2. Create the Extended Prompt

The prompt is constructed from the result returned by the `semantic_cache` class and the user’s question.

```python
prompt_template = f"Relevant context: {results}\n\n The user's question: {question_def}"
prompt_template
```

```
"Relevant context: Familial periodic paralyses are a group of inherited neurological disorders caused by mutations in genes that regulate sodium and calcium channels in nerve cells. They are characterized by episodes in which the affected muscles become slack, weak, and unable to contract. Between attacks, the affected muscles usually work as normal.\n                \nThe two most common types of periodic paralyses are: Hypokalemic periodic paralysis is characterized by a fall in potassium levels in the blood. In individuals with this mutation attacks often begin in adolescence and are triggered by strenuous exercise, high carbohydrate meals, or by injection of insulin, glucose, or epinephrine. Weakness may be mild and limited to certain muscle groups, or more severe and affect the arms and legs. Attacks may last for a few hours or persist for several days. Some patients may develop chronic muscle weakness later in life. Hyperkalemic periodic paralysis is characterized by a rise in po
```

```python
# Tokenize the prompt and move the tensors to the model's device
inputs = tokenizer(prompt_template, return_tensors="pt").to(model.device)
```
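
The notebook stops after tokenizing the prompt. A minimal sketch of the remaining generation step, assuming the `model` and `tokenizer` loaded above; `generate_answer` is a hypothetical helper and `max_new_tokens=256` an arbitrary choice. It relies on the standard `transformers` calls `model.generate` and `tokenizer.decode`, and drops the prompt tokens so that only the newly generated text is returned.

```python
def generate_answer(model, tokenizer, prompt: str, max_new_tokens: int = 256) -> str:
    """Generate a completion for `prompt` and return only the new text."""
    # Tokenize and move the tensors to the device the model lives on
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + continuation; keep only the continuation
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
```

Usage: `print(generate_answer(model, tokenizer, prompt_template))`. For an instruction-tuned model such as Gemma-it, wrapping the prompt with `tokenizer.apply_chat_template(..., add_generation_prompt=True)` usually yields better-formed answers than raw text.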


# Takeaways

Retrieving from the cache is approximately 50% faster than querying ChromaDB. The gain is modest because only a **small** amount of **data** is stored in ChromaDB and only a **single instance** of the **cache class** is used.

In larger projects, this difference typically increases, leading to improvements of around 90–95%. In real-world systems, the data behind the cache is usually **much larger** and may involve retrieving information from **multiple sources**, not just a vector database. It is also common to deploy **multiple cache instances**, often segmented by user typology, since questions tend to repeat more frequently among users who share similar characteristics.
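
Segmenting caches by user typology can be sketched as a thin router that keeps one cache per segment. This is a hypothetical illustration: `SegmentedCache`, the segment labels and the `make_cache` factory are assumptions; in this notebook's setting `make_cache` could be something like `lambda seg: semantic_cache(f"cache_{seg}.json")`. Keeping segments apart means a question is only matched against earlier questions from the same group of users, where repeats are more likely.

```python
from typing import Any, Callable, Dict


class SegmentedCache:
    """Routes each question to a cache dedicated to the user's segment."""

    def __init__(self, make_cache: Callable[[str], Any]):
        self.make_cache = make_cache  # builds a new cache for a segment name
        self.caches: Dict[str, Any] = {}

    def ask(self, segment: str, question: str) -> str:
        # Lazily create one cache instance per segment
        if segment not in self.caches:
            self.caches[segment] = self.make_cache(segment)
        return self.caches[segment].ask(question)
```

Note that each `semantic_cache` instance loads its own encoder in `init_cache`, so with many segments it would be worth sharing a single encoder between them.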