Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions bootstraprag/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def create_zip(project_name):

@click.command()
@click.argument('project_name')
@click.option('--framework', type=click.Choice([]),prompt=False)
@click.option('--framework', type=click.Choice([]), prompt=False)
@click.option('--template', type=click.Choice([]), prompt=False)
@click.option('--observability', type=click.Choice([]), prompt=False)
def create(project_name, framework, template, observability):
Expand All @@ -40,7 +40,8 @@ def create(project_name, framework, template, observability):
]
elif framework == 'None':
framework = 'qdrant'
template_choices = ['simple-search', 'hybrid-search', 'hybrid-search-advanced']
template_choices = ['simple-search', 'multimodal-search', 'hybrid-search', 'hybrid-search-advanced',
'retrieval-quality']
# Use InquirerPy to select template with arrow keys
template = inquirer.select(
message="Which template would you like to use?",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
qdrant-client==1.11.1
qdrant-client==1.11.3
python-dotenv==1.0.1
fastembed==0.3.6
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
qdrant-client==1.11.1
qdrant-client==1.11.3
python-dotenv==1.0.1
fastembed==0.3.6
datasets==3.0.0
datasets==3.0.1
5 changes: 5 additions & 0 deletions bootstraprag/templates/qdrant/multimodal_search/.env
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
DB_URL='http://localhost:6333'
DB_API_KEY='th3s3cr3tk3y'
COLLECTION_NAME='YOUR_COLLECTION'
HF_TOKEN='hf_'
TOKENIZERS_PARALLELISM=false
Empty file.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
21 changes: 21 additions & 0 deletions bootstraprag/templates/qdrant/multimodal_search/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from multimodal_search import MultiModalSearch


documents = [{"caption": "An Architecture describing MediaQ platform",
"image": "images/MediaQ.png"},
{"caption": "An Architecture describing the Advanced RAG",
"image": "images/adv-RAG.png"},
{"caption": "An Architecture describing Vision based RAG",
"image": "images/VisionRAG.png"}
]

mm_search = MultiModalSearch(documents=documents)
# mm_search.search_image_by_text(user_query="propose an advanced RAG architecture")
comment = mm_search.search_text_by_image(image_path='images/VisionRAG.png')
print(comment)






Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from fastembed import TextEmbedding, ImageEmbedding
from qdrant_client import QdrantClient, models
from PIL import Image
from typing import List


class MultiModalSearch:
def __init__(self, documents: List[dict]):
self.documents = documents
text_model_name = "Qdrant/clip-ViT-B-32-text" # CLIP text encoder
self.text_model = TextEmbedding(model_name=text_model_name)
self.text_embeddings_size = self.text_model._get_model_description(text_model_name)[
"dim"] # dimension of text embeddings, produced by CLIP text encoder (512)
self.texts_embeded = list(
self.text_model.embed(
[document["caption"] for document in documents])) # embedding captions with CLIP text encoder

image_model_name = "Qdrant/clip-ViT-B-32-vision" # CLIP image encoder
self.image_model = ImageEmbedding(model_name=image_model_name)
self.image_embeddings_size = self.image_model._get_model_description(image_model_name)[
"dim"] # dimension of image embeddings, produced by CLIP image encoder (512)
self.images_embeded = list(
self.image_model.embed(
[document["image"] for document in documents])) # embedding images with CLIP image encoder

self.client = QdrantClient(url="http://localhost:6333", api_key="th3s3cr3tk3y")

# this method will create the collection if dones not exist and inserts the data into it
def _create_and_insert(self):
if not self.client.collection_exists("text_image"): # creating a Collection
self.client.create_collection(
collection_name="text_image",
vectors_config={ # Named Vectors
"image": models.VectorParams(size=self.image_embeddings_size, distance=models.Distance.COSINE),
"text": models.VectorParams(size=self.text_embeddings_size, distance=models.Distance.COSINE),
}
)

self.client.upload_points(
collection_name="text_image",
points=[
models.PointStruct(
id=idx, # unique id of a point, pre-defined by the user
vector={
"text": self.texts_embeded[idx], # embeded caption
"image": self.images_embeded[idx] # embeded image
},
payload=doc # original image and its caption
)
for idx, doc in enumerate(self.documents)
]
)

def search_image_by_text(self, user_query: str):
find_image = self.text_model.embed(
[
"suggest an architecture for designing Vision RAG platform"]) # query, we embed it, so it also becomes a vector

image_path = self.client.search(
collection_name="text_image", # searching in our collection
query_vector=("image", list(find_image)[0]), # searching only among image vectors with our textual query
with_payload=["image"],
# user-readable information about search results, we are interested to see which image we will find
limit=1 # top-1 similar to the query result
)[0].payload['image']

Image.open(image_path).show()

def search_text_by_image(self, image_path: str):
find_image = self.image_model.embed([image_path]) # embedding our image query

response = self.client.search(
collection_name="text_image",
query_vector=("text", list(find_image)[0]),
# now we are searching only among text vectors with our image query
with_payload=["caption"],
# user-readable information about search results, we are interested to see which caption we will get
limit=1
)[0].payload['caption']

return response
29 changes: 29 additions & 0 deletions bootstraprag/templates/qdrant/multimodal_search/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
## What is this project all about

this is a bootstrap project using bootstrap-rag cli tool. This project assume you have docker for desktop installed in your machine.

### Project scaffolding
```
.
├── __init__.py
├── __pycache__
├── .env
├── main.py
├── readme.md
├── requirements.txt
└── measure_retrieval_quality.py
```
- docker-compose.yml: if your machine does not have qdrant installed don't worry run this `docker-compose-dev.yml` in setups folder
- `docker-compose -f docker-compose-dev.yml up -d`
- requirements.txt: this file has all the dependencies that a project need
- measure_retrieval_quality.py: the core logic for retrieval evaluation is present in this file
- main.py: this is the driver code to test.

### How to bring in your own custom logics
- open `measure_retrieval_quality.py` and modify your `_upset_and_index` and `compute_avg_precision_at_k` functions.

or

- create a `new_search_file.py` and extend it from `measure_retrieval_quality.py` then override the base functionality in the new one.


Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
qdrant-client==1.11.3
python-dotenv==1.0.1
fastembed==0.3.6
pillow==10.4.0
4 changes: 4 additions & 0 deletions bootstraprag/templates/qdrant/retrieval_quality/.env
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
DB_URL='http://localhost:6333'
DB_API_KEY='th3s3cr3tk3y'
COLLECTION_NAME='YOUR_COLLECTION'
HF_TOKEN='hf_'
Empty file.
12 changes: 12 additions & 0 deletions bootstraprag/templates/qdrant/retrieval_quality/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from measure_retrieval_quality import MeasureRetrievalQuality

measure_rq = MeasureRetrievalQuality(collection_name='arxiv-titles-instructorxl-embeddings',
dataset_path='Qdrant/arxiv-titles-instructorxl-embeddings')

# before tuning
print(f"avg(precision@5) = {measure_rq.compute_avg_precision_at_k(k=5)}")

measure_rq.tune_hnsw_configs()

# after tuning
print(f"avg(precision@5) = {measure_rq.compute_avg_precision_at_k(k=5)}")
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import os

from datasets import load_dataset
from qdrant_client import QdrantClient, models
from qdrant_client.conversions.common_types import CollectionInfo
from dotenv import load_dotenv, find_dotenv


class MeasureRetrievalQuality:
def __init__(self, dataset_path: str, collection_name: str, streaming: bool = True):
# path = "Qdrant/arxiv-titles-instructorxl-embeddings"
dataset = load_dataset(
path=dataset_path, split="train", streaming=True,
token=os.environ.get('HF_TOKEN')
)
self.collection_name = collection_name or os.environ.get('COLLECTION_NAME')
dataset_iterator = iter(dataset)
self.train_dataset = [next(dataset_iterator) for _ in range(10000)]
self.test_dataset = [next(dataset_iterator) for _ in range(1000)]
self.client = QdrantClient(url=os.environ.get('DB_URL'), api_key=os.environ.get('DB_API_KEY'))

self._upset_and_index()

def _upset_and_index(self):

if not self.client.collection_exists(collection_name=self.collection_name):
self.client.create_collection(
collection_name=self.collection_name,
vectors_config=models.VectorParams(
size=768, # Size of the embeddings generated by InstructorXL model
distance=models.Distance.COSINE,
)
)

self.client.upload_points( # upload_points is available as of qdrant-client v1.7.1
collection_name=self.collection_name,
points=[
models.PointStruct(
id=item["id"],
vector=item["vector"],
payload=item,
)
for item in self.train_dataset
]
)

# Collection status is green, which means the indexing is finished
while True:
collection_info = self.client.get_collection(collection_name=self.collection_name)
if collection_info.status == models.CollectionStatus.GREEN:
break

def compute_avg_precision_at_k(self, k: int):
precisions = []
for item in self.test_dataset:
ann_result = self.client.query_points(
collection_name=self.collection_name,
query=item["vector"],
limit=k,
).points

knn_result = self.client.query_points(
collection_name=self.collection_name,
query=item["vector"],
limit=k,
search_params=models.SearchParams(
exact=True, # Turns on the exact search mode
),
).points

# We can calculate the precision@k by comparing the ids of the search results
ann_ids = set(item.id for item in ann_result)
knn_ids = set(item.id for item in knn_result)
precision = len(ann_ids.intersection(knn_ids)) / k
precisions.append(precision)

return sum(precisions) / len(precisions)

def tune_hnsw_configs(self):
# Tweaking the HNSW parameters
self.client.update_collection(
collection_name=self.collection_name,
hnsw_config=models.HnswConfigDiff(
m=32, # Increase the number of edges per node from the default 16 to 32
ef_construct=200, # Increase the number of neighbours from the default 100 to 200
)
)

# Collection status is green, which means the indexing is finished
while True:
collection_info = self.client.get_collection(collection_name=self.collection_name)
if collection_info.status == models.CollectionStatus.GREEN:
break
29 changes: 29 additions & 0 deletions bootstraprag/templates/qdrant/retrieval_quality/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
## What is this project all about

this is a bootstrap project using bootstrap-rag cli tool. This project assume you have docker for desktop installed in your machine.

### Project scaffolding
```
.
├── __init__.py
├── __pycache__
├── .env
├── main.py
├── readme.md
├── requirements.txt
└── measure_retrieval_quality.py
```
- docker-compose.yml: if your machine does not have qdrant installed don't worry run this `docker-compose-dev.yml` in setups folder
- `docker-compose -f docker-compose-dev.yml up -d`
- requirements.txt: this file has all the dependencies that a project need
- measure_retrieval_quality.py: the core logic for retrieval evaluation is present in this file
- main.py: this is the driver code to test.

### How to bring in your own custom logics
- open `measure_retrieval_quality.py` and modify your `_upset_and_index` and `compute_avg_precision_at_k` functions.

or

- create a `new_search_file.py` and extend it from `measure_retrieval_quality.py` then override the base functionality in the new one.


Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
qdrant-client==1.11.3
python-dotenv==1.0.1
fastembed==0.3.6
datasets==3.0.1
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
qdrant-client==1.10.1
qdrant-client==1.11.3
python-dotenv==1.0.1