In [None]:
# brew install poppler libomp
#pip install colpali_engine qdrant-client stamina rich einops hf_xet qwen-vl-utils byaldi pdf2image torch==2.5.1 torchvision pymupdf

In [None]:
import base64
import os
import time
from io import BytesIO

import fitz
import stamina
import torch
from colpali_engine.models import ColQwen2, ColQwen2Processor
from IPython.display import Markdown, display
from pdf2image import convert_from_path
from PIL import Image
from qdrant_client import QdrantClient, models
from qwen_vl_utils import process_vision_info
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration, AutoProcessor, AutoTokenizer
from transformers.utils.import_utils import is_flash_attn_2_available
from transformers.utils.import_utils import is_flash_attn_2_available

In [None]:
@stamina.retry(on=Exception, attempts=3)
def _upsert_with_retry(client, collection_name, batch):
    """Send one upsert request to Qdrant; exceptions propagate so stamina can retry."""
    # Exceptions must escape this function: stamina only retries when the decorated
    # call raises, so catching them here would silently disable the retries.
    client.upsert(
        collection_name=collection_name,
        points=batch,
        wait=False,
    )


def upsert_to_qdrant(client, collection_name, batch):
    """Upsert ``batch`` of points into ``collection_name``, retrying on failure.

    Args:
        client: A ``QdrantClient`` instance.
        collection_name: Target collection.
        batch: List of ``models.PointStruct`` to upsert.

    Returns:
        True if the upsert succeeded (possibly after retries), False if every
        attempt failed; in that case the last error is printed.
    """
    try:
        _upsert_with_retry(client, collection_name, batch)
    except Exception as e:
        print(f"Error during upsert: {e}")
        return False
    return True

In [None]:
def conver_pdf_to_images(pdf_path, dpi=None):
    """Render every page of a PDF to a PIL image.

    Args:
        pdf_path: Path to the PDF file.
        dpi: Optional render resolution forwarded to PyMuPDF's ``Page.get_pixmap``.
            ``None`` (the default) calls ``get_pixmap()`` with no arguments,
            exactly as before.

    Returns:
        A list with one ``PIL.Image.Image`` per page, in page order.
    """
    images = []
    # The context manager closes the document (and its file handle) even if a page fails.
    with fitz.open(pdf_path) as document:
        for page in document:
            pix = page.get_pixmap() if dpi is None else page.get_pixmap(dpi=dpi)
            # The PNG bytes live in their own BytesIO, so the image stays valid after close.
            images.append(Image.open(BytesIO(pix.tobytes("png"))))
    return images


# Correctly spelled alias; the original name is kept for existing callers.
convert_pdf_to_images = conver_pdf_to_images

In [None]:
def printmd(string):
    """Render ``string`` as Markdown in the notebook's output area."""
    markdown_obj = Markdown(string)
    display(markdown_obj)

In [None]:
collection_name = "vision_rag"

# Embedded (on-disk) Qdrant. Set QDRANT_DB_DIR to relocate the storage root;
# the default keeps the original location.
qdrant_db_dir = os.environ.get("QDRANT_DB_DIR", "/Users/sachinjalota/Documents/Codes/qdrant_db")
client = QdrantClient(path=os.path.join(qdrant_db_dir, collection_name))

# Remote alternative: read the endpoint and key from the environment; never commit credentials.
# client = QdrantClient(url=os.environ["QDRANT_URL"], api_key=os.environ["QDRANT_API_KEY"])

# Create the collection only once, so re-running this cell keeps already indexed points.
if not client.collection_exists(collection_name):
    client.create_collection(
        collection_name=collection_name,
        on_disk_payload=True,  # store the payload on disk
        vectors_config=models.VectorParams(
            size=128,  # per-token vector size produced by the ColQwen2 retrieval model
            distance=models.Distance.COSINE,
            on_disk=True,  # move original vectors to disk
            multivector_config=models.MultiVectorConfig(
                comparator=models.MultiVectorComparator.MAX_SIM  # late-interaction scoring
            ),
            quantization_config=models.BinaryQuantization(
                binary=models.BinaryQuantizationConfig(
                    always_ram=True  # keep only quantized vectors in RAM
                ),
            ),
        ),
    )

In [None]:
# Retrieval model: ColQwen2 embeds page images and queries as multivectors.
model_name = 'vidore/colqwen2-v1.0'

# Prefer Apple's MPS backend (the original target), then CUDA, then CPU, so the
# notebook also runs on machines without Apple silicon.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

model = ColQwen2.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
).eval()

processor = ColQwen2Processor.from_pretrained(model_name)
model

In [None]:
# Generator: a vision-language model that reads the retrieved page and writes the answer.
# Model and processor come from the same checkpoint so the chat template, image
# preprocessing and vocabulary all match.
gen_model_name = "Qwen/Qwen2.5-VL-7B-Instruct"
gen_processor_name = gen_model_name
# Upper bound on image pixels fed to the generator; Qwen-VL processors count roughly
# one visual token per 28x28 pixel patch, so this caps each page at ~512 tokens.
max_pixels = 512 * 28 * 28

gen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    gen_model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
).to(model.device).eval()  # same device as the retrieval model instead of a hard-coded "mps"

gen_processor = AutoProcessor.from_pretrained(gen_processor_name, max_pixels=max_pixels)

In [None]:
# PDFs to index; every page becomes one image and, later, one Qdrant point.
pdf_path = [
    '/Users/sachinjalota/Downloads/ruhe_catalogue.pdf',
    '/Users/sachinjalota/Downloads/nike.pdf',
    '/Users/sachinjalota/Downloads/RAG_Evaluation.pdf',
]

# Pages are concatenated in document order, so position j in images_lst is page j overall.
images_lst = [page_img for pdf_file in pdf_path for page_img in conver_pdf_to_images(pdf_file)]

batch_size = 4  # pages per embedding forward pass

In [None]:
# Embed every page image with ColQwen2; each page yields a multivector (one row per token).
dataloader = DataLoader(
    dataset=images_lst,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=processor.process_images,
)

ds = []
for batch_doc in tqdm(dataloader, desc="Embedding pages"):
    with torch.no_grad():
        batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
        embeddings_doc = model(**batch_doc)
    # Move results to CPU right away so accumulated embeddings don't hold accelerator memory.
    ds.extend(torch.unbind(embeddings_doc.cpu()))

# Point id j is the page's position in images_lst, which later cells use to show the hit.
points = [
    models.PointStruct(
        id=j,
        vector=embedding.float().numpy().tolist(),
        payload={
            "source": "internet archive"
        },
    )
    for j, embedding in enumerate(ds)
]

# Upsert in small chunks: a single request carrying every page's multivector can exceed
# a server's request-size limit. upsert_to_qdrant reports failure by returning False.
failed_chunk_starts = []
for start in range(0, len(points), batch_size):
    if not upsert_to_qdrant(client, collection_name, points[start : start + batch_size]):
        failed_chunk_starts.append(start)

if failed_chunk_starts:
    print(f"Indexing finished with {len(failed_chunk_starts)} failed chunk(s) starting at ids {failed_chunk_starts}")
else:
    print("Indexing complete!")

In [None]:
# Lower the indexing threshold (Qdrant measures it in kilobytes) so the vector index is
# built even for this small collection, and print the collection info before/after.
# NOTE(review): embedded (path=...) mode likely ignores optimizer settings — TODO confirm.
collection_info_before = client.get_collection(collection_name)
print("Collection info before update:", collection_info_before)

# `optimizers_config` is the documented keyword; `optimizer_config` is a deprecated alias.
result = client.update_collection(
    collection_name=collection_name,
    optimizers_config=models.OptimizersConfigDiff(indexing_threshold=10),
)
print("Collection update result:", result)

collection_info_after = client.get_collection(collection_name)
print("Collection info after update:", collection_info_after)

In [None]:
query_text = "give me revenue sku wise"

# Embed the query with the same ColQwen2 model used for the pages.
with torch.no_grad():
    batch_query = processor.process_queries([query_text]).to(model.device)
    query_embedding = model(**batch_query)

# Qdrant expects the multivector as a list of per-token float lists.
multivector_query = query_embedding[0].cpu().float().numpy().tolist()

# Show (num_vectors, dim) instead of dumping every float into the notebook output.
len(multivector_query), len(multivector_query[0])

In [None]:
# Search with binary-quantized vectors, oversample candidates 2x, then rescore them
# with the original vectors before returning the top `limit` pages.
# perf_counter is monotonic and high-resolution, unlike time.time(), so it suits intervals.
start_time = time.perf_counter()
search_result = client.query_points(
    collection_name=collection_name,
    query=multivector_query,
    limit=10,
    timeout=100,
    search_params=models.SearchParams(
        quantization=models.QuantizationSearchParams(
            ignore=False,
            rescore=True,
            oversampling=2.0,
        )
    ),
)
elapsed_time = time.perf_counter() - start_time
print(f"Search completed in {elapsed_time:.4f} seconds")

In [None]:
# Ranked hits: each point id indexes images_lst; score is the MaxSim similarity.
hits = search_result.points
hits

In [None]:
# Display the best-ranked page; point ids were assigned from positions in images_lst.
top_hit = search_result.points[0]
idx = top_hit.id
images_lst[idx]

In [None]:
# Instruction template for the answer generator. `{query}` is filled in with
# PROMPT.format(query=...) inside get_answer_local, so any literal braces added
# to this text would have to be doubled ({{ and }}) to survive .format().
PROMPT = """
You are an advanced assistant capable of understanding text, images, and tables from documents. Your task is to answer the query using the provided PDF pages with multi-modal data. Follow these steps:

1. Focus only on the provided pages and avoid making assumptions.
2. Identify the type of information required (text, image, table).
3. Extract relevant data from the PDF pages based on the query and analyze thoroughly.
4. If applicable, describe visual elements (e.g., charts, diagrams) in detail.
5. Provide a clear and accurate answer, ensuring it is grounded in the provided pages.
6. If the query cannot be answered based on the pages, explain why and suggest alternative sources.

Query: {query}

Your response:
"""


In [None]:
def get_answer_local(query: str, max_token: int, image=None):
    """Answer ``query`` with the local vision-language generator, grounded on one page image.

    Args:
        query: User question; inserted into ``PROMPT``.
        max_token: Maximum number of new tokens to generate.
        image: Page image to ground the answer on. Defaults to ``images_lst[idx]``
            (the top search hit chosen in an earlier cell), so existing
            two-argument calls behave as before.

    Returns:
        A one-element list holding the decoded answer text.
    """
    if image is None:
        image = images_lst[idx]

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": PROMPT.format(query=query)},
            ],
        }
    ]

    # Use the generator's own processor: its chat template and image preprocessing
    # must match gen_model. `processor` is the ColQwen2 retrieval processor.
    text = gen_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    image_inputs, video_inputs = process_vision_info(messages)
    inputs = gen_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(gen_model.device)

    generated_ids = gen_model.generate(**inputs, max_new_tokens=max_token)
    # generate() returns prompt + continuation; drop the prompt tokens so only the answer is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    return gen_processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

In [None]:
# Generate an answer grounded on the top-ranked page and render it as Markdown.
answer = get_answer_local(query=query_text, max_token=500)[0]
printmd(answer)