In [16]:
import os
import logging
import numpy as np
import pandas as pd
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility

# --- Configuration ---
# Must equal the vector size produced by the embedding model and the "embedding" field dim.
EMBEDDING_DIMENSION = 768
COLLECTION_NAME = "medical_conversations_rag"
ZILLIZ_URI = os.getenv("ZILLIZ_URI", "https://in05-ea2907a9e50a525.serverless.gcp-us-west1.cloud.zilliz.com")
# Credentials are read only from the environment: a token committed in a notebook leaks
# with every copy of the file. The token that used to be hard-coded here should be revoked.
ZILLIZ_TOKEN = os.getenv("ZILLIZ_TOKEN")
if not ZILLIZ_TOKEN:
    raise RuntimeError("Set the ZILLIZ_TOKEN environment variable to your Zilliz Cloud API key.")

# --- Logging ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("zillis_rag")

# --- Connect to Zilliz/Milvus ---
connections.connect(uri=ZILLIZ_URI, token=ZILLIZ_TOKEN)
logger.info(f"Connected to Zilliz/Milvus at {ZILLIZ_URI}")


INFO:zillis_rag:Connected to Zilliz/Milvus at https://in05-ea2907a9e50a525.serverless.gcp-us-west1.cloud.zilliz.com


In [25]:
# --- Collection schema: one row per answer chunk, keyed by a caller-supplied UUID string ---
# Fields are non-primary unless stated; "embedding" holds the question vector for the chunk.
fields = [
    FieldSchema(name="id", dtype=DataType.VARCHAR, max_length=64, is_primary=True, auto_id=False),
    FieldSchema(name="qtype", dtype=DataType.VARCHAR, max_length=32),
    FieldSchema(name="Question", dtype=DataType.VARCHAR, max_length=1024),
    FieldSchema(name="Answer", dtype=DataType.VARCHAR, max_length=65000),
    FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=EMBEDDING_DIMENSION),
    FieldSchema(name="chunk_index", dtype=DataType.INT64),
    FieldSchema(name="answer_group_id", dtype=DataType.VARCHAR, max_length=64),
]
schema = CollectionSchema(fields=fields, description="RAG QnA collection with chunking")

# --- Start from an empty collection: any existing one with this name is destroyed ---
if utility.has_collection(COLLECTION_NAME):
    utility.drop_collection(COLLECTION_NAME)
collection = Collection(name=COLLECTION_NAME, schema=schema)
logger.info(f"Collection '{COLLECTION_NAME}' created.")


INFO:zillis_rag:Collection 'medical_conversations_rag' created.


In [26]:
import uuid

# --- Chunking utility ---
def chunk_text(text, chunk_size):
    """Split ``text`` into consecutive slices of at most ``chunk_size`` characters.

    Joining the returned chunks reproduces ``text`` exactly. An empty string
    yields an empty list.

    Args:
        text: The string to split.
        chunk_size: Maximum length of each chunk, in characters. Must be positive.

    Returns:
        list[str]: The chunks, in original order.

    Raises:
        ValueError: If ``chunk_size`` is not positive. A negative step would
            otherwise make ``range`` empty and silently drop the whole text.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

# --- Prepare chunked data for insertion ---
# NOTE(review): `df` is not defined in any visible cell. It presumably comes from an
# earlier (possibly deleted) cell and must have 'qtype', 'Question' and 'Answer' columns.
# Confirm, since a fresh-kernel "Run All" depends on it.
chunk_size = 65000  # same as the max_length of the "Answer" VARCHAR field
# NOTE(review): chunking counts characters. If Milvus enforces max_length in bytes,
# multi-byte text near this limit could still be rejected. Confirm.
ids = []
qtypes = []
questions = []
answers = []
chunk_indexes = []
group_ids = []

for qtype, question, answer in zip(df['qtype'], df['Question'], df['Answer']):
    # A missing answer behaves like an empty one (no chunks, row skipped) instead of
    # being stored as the literal string "nan".
    answer_text = "" if pd.isna(answer) else str(answer)
    group_id = str(uuid.uuid4())  # shared by every chunk of this answer
    for chunk_idx, chunk in enumerate(chunk_text(answer_text, chunk_size)):
        ids.append(str(uuid.uuid4()))
        qtypes.append(str(qtype))
        questions.append(str(question))
        answers.append(chunk)
        chunk_indexes.append(chunk_idx)
        group_ids.append(group_id)

# Every chunk is indexed by the embedding of its question. The same question text repeats
# across chunks (and across rows), so embed each distinct text once and fan the vectors out.
embedding_model = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
unique_questions = list(dict.fromkeys(questions))  # order-preserving de-duplication
if unique_questions:
    unique_vectors = np.asarray(embedding_model.embed_documents(unique_questions), dtype=np.float32)
else:
    unique_vectors = np.empty((0, EMBEDDING_DIMENSION), dtype=np.float32)
if unique_vectors.shape != (len(unique_questions), EMBEDDING_DIMENSION):
    raise ValueError(
        f"Embedding model returned shape {unique_vectors.shape}; expected "
        f"({len(unique_questions)}, {EMBEDDING_DIMENSION}) to match the collection schema."
    )
question_position = {text: pos for pos, text in enumerate(unique_questions)}
# Row i of `embeds` is the vector for questions[i], aligned with every other column list.
embeds = unique_vectors[[question_position[text] for text in questions]]


In [28]:
# --- Insert rows in batches so no single request grows too large ---
batch_size = 500
column_lengths = {len(col) for col in (ids, qtypes, questions, answers, embeds, chunk_indexes, group_ids)}
# All columns must line up row-for-row; fail early with a clear message instead of a
# server-side error partway through the batches.
if len(column_lengths) != 1:
    raise ValueError(f"Column lengths differ: {sorted(column_lengths)}")

inserted_count = 0
for start in range(0, len(ids), batch_size):
    end = start + batch_size
    batch = [
        ids[start:end],
        qtypes[start:end],
        questions[start:end],
        answers[start:end],
        embeds[start:end].tolist(),
        chunk_indexes[start:end],
        group_ids[start:end]
    ]
    # Count what the server reports as inserted rather than assuming the whole batch landed.
    inserted_count += collection.insert(batch).insert_count
collection.flush()
logger.info(f"Inserted {inserted_count} rows (with chunking) into '{COLLECTION_NAME}'.")


INFO:zillis_rag:Inserted 16407 rows (with chunking) into 'medical_conversations_rag'.


In [30]:
# --- Build the ANN index on the vector field, then load the collection for search ---
# IVF_FLAT with L2 distance. Searches must use the same metric_type ("L2").
index_params = dict(
    metric_type="L2",
    index_type="IVF_FLAT",
    params=dict(nlist=1024),
)
collection.create_index(field_name="embedding", index_params=index_params)
logger.info("Index created on embedding field.")

collection.load()
logger.info("Collection loaded for search.")


INFO:zillis_rag:Index created on embedding field.
INFO:zillis_rag:Collection loaded for search.
INFO:zillis_rag:Collection loaded for search.


In [37]:
# --- RAG Chain Setup ---
# Prompt text kept as a named constant so it can be read or adjusted on its own.
# Placeholders: {history} (prior turns), {context} (retrieved documents), {input} (question).
_RAG_PROMPT_TEXT = """
    **Note:** You are an AI assistant using only the provided context from a dataset of patient-doctor conversations. You are not a doctor. Do not provide medical advice beyond the context. If the context lacks information, say so. Use the conversation history to understand references (e.g., pronouns like 'it') if relevant.

    **Conversation History (Recent Queries and Answers):**
    {history}

    **Context (from dataset):**
    {context}

    **Question:**
    {input}

    **Answer (based only on context, using history for clarity if needed):**
    If the question is about hospitals, doctors, or appointments, respond: "Please use the appointment booking system for this query."
    Otherwise, provide an answer based on the context.
    """

llm = ChatGoogleGenerativeAI(temperature=0.3, model="gemini-1.5-flash-latest")
prompt_template = ChatPromptTemplate.from_template(_RAG_PROMPT_TEXT)
# "Stuff" chain: the list of Documents passed as "context" is formatted into {context}.
document_chain = create_stuff_documents_chain(llm=llm, prompt=prompt_template)

# --- Retrieval Function ---
from langchain_core.documents import Document

def retrieve_context(query, top_k=3, nprobe=10):
    """Return the ``top_k`` answer chunks whose question embedding is nearest to ``query``.

    Args:
        query: Free-text query; embedded with ``embedding_model.embed_query``.
        top_k: Maximum number of hits to return (passed to Milvus as ``limit``).
        nprobe: IVF clusters to scan per search. Higher values trade speed for recall.
            The default of 10 matches the previous hard-coded value.

    Returns:
        list[Document]: One Document per hit, nearest first. ``page_content`` is the
        answer chunk. ``metadata`` carries qtype, Question, chunk_index,
        answer_group_id and the L2 ``distance`` (smaller is closer).
    """
    query_emb = embedding_model.embed_query(query)
    results = collection.search(
        data=[query_emb],
        anns_field="embedding",
        # metric_type must match the one the index was built with.
        param={"metric_type": "L2", "params": {"nprobe": nprobe}},
        limit=top_k,
        output_fields=["qtype", "Question", "Answer", "chunk_index", "answer_group_id"]
    )
    docs = []
    for hits in results:  # one result set per query vector; only one vector is sent
        for hit in hits:
            entity = hit.entity
            docs.append(Document(
                page_content=getattr(entity, "Answer", ""),
                metadata={
                    "qtype": getattr(entity, "qtype", ""),
                    "Question": getattr(entity, "Question", ""),
                    "chunk_index": getattr(entity, "chunk_index", 0),
                    "answer_group_id": getattr(entity, "answer_group_id", ""),
                    "distance": hit.distance,
                }
            ))
    return docs

# --- RAG QA Function ---
def rag_qa(question, history="", top_k=3):
    """Answer ``question`` using retrieved dataset context and the RAG prompt chain.

    Args:
        question: The user's query. It is used for retrieval and as ``{input}`` in the prompt.
        history: Free-form conversation history inserted as ``{history}``. Use an empty
            string when there is none.
        top_k: Number of chunks to retrieve as context. The default of 3 keeps the
            previous behavior.

    Returns:
        The output of ``document_chain.invoke``. The example below prints it as text.
    """
    context_docs = retrieve_context(question, top_k=top_k)
    return document_chain.invoke({
        "history": history,
        "context": context_docs,  # list of Document; the stuff chain formats them into the prompt
        "input": question,
    })

# --- Example Usage: one question, no conversation history ---
sample_question = "What is cancer?"
response = rag_qa(sample_question)
print(response)  # print() shows the answer text without repr quotes




Cancer is a disease where cells grow and multiply uncontrollably, forming a mass called a tumor.  These tumors can be benign (not cancerous) or malignant (cancerous). Malignant tumors can invade nearby tissues and spread to other parts of the body (metastasis).  There are over 100 different types of cancer, most named for where they originate (e.g., lung cancer).  The growth of cancer cells is uncontrolled, unlike normal cells which grow and die in a controlled manner.
