In [69]:
import pymupdf  # aka fitz
import uuid
import os
from sentence_transformers import SentenceTransformer
from PIL import Image
import io
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct
from qdrant_client.models import Distance, VectorParams

In [70]:
import os
# Raise Hugging Face Hub network timeouts (seconds) so the CLIP weights can be
# fetched on slow connections. The first two names are kept from the original
# notebook; HF_HUB_DOWNLOAD_TIMEOUT / HF_HUB_ETAG_TIMEOUT are the timeout
# variables documented by huggingface_hub.
# NOTE(review): huggingface_hub appears to read these when it is first imported.
# On a fresh kernel it is already imported (via sentence_transformers) in the
# cell above, so this cell likely needs to run before those imports to take
# effect. Confirm the ordering.
os.environ.update({
    "HF_HUB_READ_TIMEOUT": "60",
    "HF_HUB_CONNECT_TIMEOUT": "60",
    "HF_HUB_DOWNLOAD_TIMEOUT": "60",
    "HF_HUB_ETAG_TIMEOUT": "60",
})

In [71]:
# Client for the Qdrant instance running on localhost.
client = QdrantClient("http://localhost:6333")
# CLIP ViT-B/32 encoder: the same model embeds both page text and page images.
# NOTE(review): initialise_db / populate_db / main each build their own client or
# model instead of reusing these globals.
model = SentenceTransformer(model_name_or_path='clip-ViT-B-32')

Loading weights: 100%|██████████| 398/398 [00:01<00:00, 274.34it/s, Materializing param=visual_projection.weight]                                
CLIPModel LOAD REPORT from: clip-ViT-B-32/0_CLIPModel
Key                                  | Status     |  | 
-------------------------------------+------------+--+-
vision_model.embeddings.position_ids | UNEXPECTED |  | 
text_model.embeddings.position_ids   | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


In [72]:
def initialise_db(url="http://localhost:6333", collection_name="col_1", vector_size=512):
    """Connect to Qdrant and make sure the target collection exists.

    The collection is created with cosine distance only when it is missing; an
    existing collection is reused as-is (its vector size is not checked).

    Args:
        url: Qdrant server URL.
        collection_name: Collection to create or reuse.
        vector_size: Dimensionality of the stored vectors. It must match the
            embedding model's output size (512 is what this notebook uses with
            clip-ViT-B-32).

    Returns:
        The connected ``QdrantClient``.
    """
    client = QdrantClient(url)
    if not client.collection_exists(collection_name=collection_name):
        client.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
        )
    return client

In [None]:
import pymupdf
import uuid
import os
import io
from PIL import Image
from sentence_transformers import SentenceTransformer
from qdrant_client.models import PointStruct

def populate_db(doc_path, client, model=None, collection_name="col_1", image_dir="images"):
    """Embed every page's text and every embedded image of a PDF into Qdrant.

    Text: one point per page with non-blank text, with payload
    ``{"page_no", "text", "type": "text"}``.
    Images: each image listed on a page is written to ``image_dir`` under a
    random file name and stored as one point with payload
    ``{"page_no", "filename", "type": "image"}``.
    ``page_no`` is the 0-based page index.

    Point IDs are random UUIDs, so indexing the same document again stores
    duplicate points.

    Args:
        doc_path: Path of the document to open with ``pymupdf``.
        client: Qdrant client that receives the points.
        model: Encoder used for both text and images. ``clip-ViT-B-32`` is
            loaded when omitted. Pass an already-loaded model to avoid
            reloading the weights on every call.
        collection_name: Target Qdrant collection.
        image_dir: Directory the extracted image files are written to. It is
            created if missing, and re-running no longer fails when it exists.
    """
    if model is None:
        model = SentenceTransformer("clip-ViT-B-32")

    # The context manager guarantees the document is closed, even on error.
    with pymupdf.open(doc_path) as doc:
        # Text is stored before images are processed, so a failure while
        # decoding an image does not lose the already-embedded text.
        text_points = _page_text_points(doc, model)
        if text_points:
            # One batched upsert instead of one round-trip per page.
            client.upsert(collection_name=collection_name, points=text_points)

        image_points = _page_image_points(doc, model, image_dir)
        if image_points:
            client.upsert(collection_name=collection_name, points=image_points)


def _page_text_points(doc, model):
    """Build one text point per page whose stripped text is non-empty."""
    pages = [(page_no, page.get_text().strip()) for page_no, page in enumerate(doc)]
    pages = [(page_no, text) for page_no, text in pages if text]
    if not pages:
        return []
    # Encode all pages in one batch.
    # NOTE(review): CLIP's text encoder has a short context window; long page
    # text is presumably truncated, so consider chunking pages. TODO confirm.
    embeddings = model.encode([text for _, text in pages])
    return [
        PointStruct(
            id=str(uuid.uuid4()),
            vector=embedding.tolist(),
            payload={"page_no": page_no, "text": text, "type": "text"},
        )
        for (page_no, text), embedding in zip(pages, embeddings)
    ]


def _page_image_points(doc, model, image_dir):
    """Write every page image to `image_dir` and build one image point for each."""
    os.makedirs(image_dir, exist_ok=True)
    points = []
    for page_no, page in enumerate(doc):
        for img in page.get_images(full=True):
            xref = img[0]  # first tuple entry is the image's cross-reference number
            extracted = doc.extract_image(xref)
            image_bytes = extracted["image"]
            filename = f"{uuid.uuid4()}.{extracted['ext']}"

            with open(os.path.join(image_dir, filename), "wb") as f:
                f.write(image_bytes)

            # CLIP expects 3-channel input; normalise palette/alpha/CMYK images.
            pil_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            points.append(
                PointStruct(
                    id=str(uuid.uuid4()),
                    vector=model.encode(pil_img).tolist(),
                    payload={"page_no": page_no, "filename": filename, "type": "image"},
                )
            )
    return points


In [74]:
def get_topk(client: "QdrantClient", query_vector, k: int = 50, collection_name: str = "col_1"):
    """Return the `k` stored points nearest to `query_vector`.

    Args:
        client: Qdrant client to query.
        query_vector: Query embedding; it must have the collection's dimensionality.
        k: Maximum number of hits to return (default 50, the previous hard-coded limit).
        collection_name: Collection to search.

    Returns:
        The ``client.query_points`` response. Its ``.points`` hold the hits,
        each with ``.score`` and ``.payload``.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    return client.query_points(
        collection_name=collection_name,
        query=query_vector,
        limit=k,
    )

In [76]:
def main(doc_path, query):
    """Index `doc_path` into Qdrant, then print the stored points closest to `query`.

    The document is passed to populate_db on every call before the search runs.

    NOTE(review): page-text and image points are ranked in one list. In the
    recorded run, every text hit scored above every image hit (roughly 0.50-0.77
    vs <=0.26), so the scores of the two types may not be directly comparable.
    """
    db = initialise_db()
    populate_db(doc_path=doc_path, client=db)

    encoder = SentenceTransformer('clip-ViT-B-32')
    hits = get_topk(db, encoder.encode(query)).points

    for hit in hits:
        payload = hit.payload
        print(f"page_no = {payload['page_no']}, score = {hit.score}, type = {payload['type']} \n")

In [77]:
# Index the sample PDF and list the stored points closest to the text query.
main(doc_path='./docs/ss.pdf', query='green grass')

Loading weights: 100%|██████████| 398/398 [00:01<00:00, 278.09it/s, Materializing param=visual_projection.weight]                                
CLIPModel LOAD REPORT from: clip-ViT-B-32/0_CLIPModel
Key                                  | Status     |  | 
-------------------------------------+------------+--+-
vision_model.embeddings.position_ids | UNEXPECTED |  | 
text_model.embeddings.position_ids   | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
Loading weights: 100%|██████████| 398/398 [00:00<00:00, 653.51it/s, Materializing param=visual_projection.weight]                                
CLIPModel LOAD REPORT from: clip-ViT-B-32/0_CLIPModel
Key                                  | Status     |  | 
-------------------------------------+------------+--+-
vision_model.embeddings.position_ids | UNEXPECTED |  | 
text_model.embeddings.position_ids   | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ig

page_no = 15, score = 0.7725273, type = text 

page_no = 4, score = 0.7016804, type = text 

page_no = 3, score = 0.68394315, type = text 

page_no = 5, score = 0.67258656, type = text 

page_no = 6, score = 0.67138594, type = text 

page_no = 7, score = 0.66847545, type = text 

page_no = 8, score = 0.66141737, type = text 

page_no = 9, score = 0.6606197, type = text 

page_no = 12, score = 0.6599006, type = text 

page_no = 11, score = 0.6529188, type = text 

page_no = 13, score = 0.6067456, type = text 

page_no = 14, score = 0.5983734, type = text 

page_no = 0, score = 0.57742244, type = text 

page_no = 10, score = 0.54114044, type = text 

page_no = 2, score = 0.521142, type = text 

page_no = 1, score = 0.50299656, type = text 

page_no = 14, score = 0.26119924, type = image 

page_no = 2, score = 0.2580132, type = image 

page_no = 0, score = 0.25643256, type = image 

page_no = 4, score = 0.23159611, type = image 

page_no = 6, score = 0.22593552, type = image 

page_no = 7