Testing out ways to extract information from given documents and store it into vector stores

In [5]:
import PyPDF2
import re
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Load the PDF document
pdf_path = "/home/vidhij2/nivi/9789240045989-eng.pdf"
with open(pdf_path, "rb") as file:
    reader = PyPDF2.PdfReader(file)
    text = "\n".join([page.extract_text() for page in reader.pages if page.extract_text()])
    print(text)

# Define a regex pattern to detect sections
section_pattern = r"(\n\d+\.\s[A-Z][^\n]+)"  # Matches headings like "1. Introduction"

# Split text based on sections
sections = re.split(section_pattern, text)
# print(sections)
# Create structured chunks
structured_chunks = []
current_section = "Introduction"

for i in range(len(sections)):
    if re.match(section_pattern, sections[i]):  # If it's a section heading
        current_section = sections[i].strip()
    else:
        structured_chunks.append({"section": current_section, "content": sections[i].strip()})

# Use RecursiveCharacterTextSplitter to split into smaller chunks
text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=100)
final_chunks = []

for chunk in structured_chunks:
    sub_chunks = text_splitter.split_text(chunk["content"])
    for sub in sub_chunks:
        final_chunks.append({"section": chunk["section"], "content": sub})

# Display some example chunks


KeyboardInterrupt: 

In [6]:
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Define the PDF file path
pdf_path = "/home/vidhij2/nivi/By-gestational month cards.pdf"

# Load the PDF using LangChain's PyPDFLoader
loader = PyPDFLoader(pdf_path)
pages = loader.load()

# Combine all text from pages
full_text = "\n".join([page.page_content for page in pages])

# Define the chunking strategy
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=512,  # Number of characters per chunk
    chunk_overlap=100,  # Overlap to maintain context
    separators=["\n\n", "\n", " ", ""],  # Prioritize splitting by paragraphs
)

# Split text into structured chunks
chunks = text_splitter.split_text(full_text)

# Store chunks in a structured format
chunked_data = [{"chunk_id": idx, "content": chunk} for idx, chunk in enumerate(chunks)]

# Display some chunk examples
for i in range(3):  # Display first 3 chunks
    print(f"Chunk {i+1}:\n{chunked_data[i]['content']}\n{'-'*50}")


Chunk 1:
Maternal nutrition for 
safe motherhood
 Messages: By-gestational month
Cards
1  1 st Month
2  2 nd Month
3  3 rd Month
4  4 th Month
5  5 th Month
6  6 th Month
7  7 th Month
8  8 th Month
9  9 th Month
10  Common Messages 1
11  Common Messages 2
2–3
4–5
6–7
8–9
10–11
12–13
14–15
16–17
18–19
20–21
22–23
No.  Gestational month Page No.
Instructions for users
 Who is the user of these cards?
--------------------------------------------------
Chunk 2:
22–23
No.  Gestational month Page No.
Instructions for users
 Who is the user of these cards? 
Auxiliary Nurse Midwife (ANM) in health facility and during outreach activities like Village Health Sanitation and Nutrition Days (VHSNDs); to be 
used with pregnant women, and also their husband or other family members, if available
In situations where ANMs are unable to provide counselling services, a trained ASHA or other nurses should provide the counselling
 How much time is needed to go through a card?
---------------------------

In [7]:
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS

# Load OpenAI embeddings (Replace with any embedding model)
embeddings = OpenAIEmbeddings()

# Store chunks in FAISS
vector_store = FAISS.from_texts(
    [chunk["content"] for chunk in chunked_data], 
    embedding=embeddings
)

# Save FAISS index
vector_store.save_local("gestational_rag_index")


  embeddings = OpenAIEmbeddings()


ValidationError: 1 validation error for OpenAIEmbeddings
  Value error, Did not find openai_api_key, please add an environment variable `OPENAI_API_KEY` which contains it, or pass `openai_api_key` as a named parameter. [type=value_error, input_value={'model_kwargs': {}, 'cli...20, 'http_client': None}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.10/v/value_error

In [1]:
import pdfplumber
import re
import pandas as pd
# import ace_tools as tools

# Define the PDF file path
pdf_path = "/home/vidhij2/nivi/By-gestational month cards.pdf"

# Extract text from the PDF
text_data = []
with pdfplumber.open(pdf_path) as pdf:
    for page in pdf.pages:
        text = page.extract_text()
        if text:
            text_data.append(text)

# Combine extracted text
full_text = "\n".join(text_data)

# Define a regex pattern to detect major sections (e.g., "1st Month", "2nd Month", etc.)
section_pattern = r"(\d{1,2}(?:st|nd|rd|th)\sMonth|Common Messages \d+|Counselling Tips)"

# Split text based on detected sections
sections = re.split(section_pattern, full_text)

# Organizing extracted text into structured chunks
structured_chunks = []
current_section = "Introduction"

for i in range(len(sections)):
    if re.match(section_pattern, sections[i]):  # If it's a section heading
        current_section = sections[i].strip()
    else:
        structured_chunks.append({"section": current_section, "content": sections[i].strip()})

# Function to split text into smaller chunks while maintaining overlap
def split_text_into_chunks(text, chunk_size=512, overlap=100):
    words = text.split()
    chunks = []
    start = 0

    while start < len(words):
        end = min(start + chunk_size, len(words))
        chunk = " ".join(words[start:end])
        chunks.append(chunk)
        start = end - overlap  # Ensure overlap for better continuity

    return chunks

# Apply chunking to extracted sections
final_chunks = []
for chunk in structured_chunks:
    sub_chunks = split_text_into_chunks(chunk["content"])
    for sub in sub_chunks:
        final_chunks.append({"section": chunk["section"], "content": sub})

# Convert to DataFrame for better visualization
# df_chunks = pd.DataFrame(final_chunks)

# Display chunked content to the user
# tools.display_dataframe_to_user(name="Chunked Gestational Month Cards", dataframe=df_chunks)


: 

In [11]:
structured_chunks

[{'section': 'Introduction',
  'content': 'WHO recommendations on  \nmaternal and newborn care for  \na positive postnatal experience\nWHO recommendations on \nmaternal and newborn care for  \na positive postnatal experience\nWHO recommendations on maternal and newborn care for a positive postnatal experience\nThis publication is the update of the document published in 2014 entitled “WHO recommendations on postnatal \ncare of the mother and newborn”.\nISBN 978-92-4-004598-9 (electronic version)\nISBN 978-92-4-004599-6 (print version)\n© World Health Organization 2022\nSome rights reserved. This work is available under the Creative Commons Attribution-NonCommercial-ShareAlike 3.0 IGO licence (CC BY-NC-SA 3.0 IGO; https:/ / creativecommons.org/licenses/by-nc-sa/3.0/igo ). \nUnder the terms of this licence, you may copy, redistribute and adapt the work for non-commercial purposes, provided the work is appropriately cited, as indicated below. In any use of this work, there should be no sug

In [None]:
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np

# Load embedding model
model = SentenceTransformer("all-MiniLM-L6-v2")

# Convert text chunks into embeddings
embeddings = np.array([model.encode(chunk["content"]) for chunk in final_chunks])

# Store in FAISS index
dimension = embeddings.shape[1]
faiss_index = faiss.IndexFlatL2(dimension)
faiss_index.add(embeddings)

# Save FAISS index
faiss.write_index(faiss_index, "medical_rag_index.faiss")


In [9]:
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np

# Load embedding model
model = SentenceTransformer("all-MiniLM-L6-v2")

# Convert text chunks into embeddings
embeddings = np.array([model.encode(chunk["content"]) for chunk in chunked_data])

# Store in FAISS index
dimension = embeddings.shape[1]
faiss_index = faiss.IndexFlatL2(dimension)
faiss_index.add(embeddings)

# Save FAISS index
faiss.write_index(faiss_index, "medical_rag_index_1.faiss")


In [14]:
import faiss

# Load FAISS index
faiss_index = faiss.read_index("medical_rag_index_1.faiss")

query = "What dietary recommendations are given for the 6th month of pregnancy?"
query_embedding = model.encode(query).reshape(1, -1)

# Search for top 3 relevant chunks
D, I = faiss_index.search(query_embedding, 3)

# Retrieve relevant text
retrieved_chunks = [chunked_data[idx] for idx in I[0]]

for chunk in retrieved_chunks:
    # print(f"Section: {chunk['section']}\nContent: {chunk['content']}\n")
    print(chunk)


{'chunk_id': 28, 'content': '(kulthi), soyabean, sweet potato, banana, groundnut, walnut, milk, paneer, egg, chicken, jaggery/sugar, ﬁ sh, and oils in your diet\n3.  Ensure that you eat at least 3 main meals and 2 nutritious snacks to meet increased nutrient requirements\n4.  Continue consuming one IFA tablet with water or lemon juice 1 hour after a meal\n5.  Consume 2 calcium tablets every day with water or milk immediately after meals\nIf the pregnant woman is receiving her ﬁ rst ANC contact in this month, then also refer to ‘What is'}
{'chunk_id': 55, 'content': 'Energy, Vitamin A, Vitamin D, Vitamin E, \nProtein, Fat, Essential fatty acids\n3.  No food should be restricted during pregnancy; at the same time, under or over consumption of any food should be avoided as  it may lead to \nmalnutrition in pregnant woman and her baby\n4.  Always use double forti ﬁ ed salt (iodine and iron); iodine is very critical for your baby’s brain development and iron is necessary for blood'}
{'chunk

In [10]:
final_chunks

[{'section': 'Introduction',
  'content': 'WHO recommendations on  \nmaternal and newborn care for  \na positive postnatal experience\nWHO recommendations on \nmaternal and newborn care for  \na positive postnatal experience\nWHO recommendations on maternal and newborn care for a positive postnatal experience\nThis publication is the update of the document published in 2014 entitled “WHO recommendations on postnatal \ncare of the mother and newborn”.\nISBN 978-92-4-004598-9 (electronic version)\nISBN 978-92-4-004599-6 (print version)'},
 {'section': 'Introduction',
  'content': 'ISBN 978-92-4-004598-9 (electronic version)\nISBN 978-92-4-004599-6 (print version)\n© World Health Organization 2022\nSome rights reserved. This work is available under the Creative Commons Attribution-NonCommercial-ShareAlike 3.0 IGO licence (CC BY-NC-SA 3.0 IGO; https:/ / creativecommons.org/licenses/by-nc-sa/3.0/igo ).'},
 {'section': 'Introduction',
  'content': 'Under the terms of this licence, you may co

In [None]:
import pandas as pd
import ace_tools as tools

df = pd.DataFrame(final_chunks)
tools.display_dataframe_to_user(name="Chunked Medical Document", dataframe=df)


Hierarchical Markdown-Based Chunking

In [3]:
import re
import json
import os
from typing import List, Dict, Any, Optional
from tqdm import tqdm

# LangChain imports
from langchain.document_loaders import TextLoader, PyPDFLoader
from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    MarkdownHeaderTextSplitter
)
from langchain.docstore.document import Document
from langchain.schema import BaseDocumentTransformer

class MedicalHeaderTextSplitter(MarkdownHeaderTextSplitter):
    """Custom text splitter for medical documents that respects section headers."""
    
    def __init__(self):
        headers_to_split_on = [
            ("#", "chapter"),
            ("##", "section"),
            ("###", "subsection"),
            ("####", "recommendation"),
            ("#####", "remarks")
        ]
        super().__init__(headers_to_split_on=headers_to_split_on)
        
    def _add_md_header(self, text):
        """Convert medical document headers to markdown format."""
        # Convert chapter headers
        text = re.sub(r'(?m)^(?:Chapter|CHAPTER)\s+(\d+)[\.:]?\s+(.+)$', r'# \1 \2', text)
        
        # Convert section headers
        text = re.sub(r'(?m)^(?:\d+\.\d+\.?\s+|\d+\.\s+)([A-Z][A-Za-z\s\-:]+)$', r'## \1', text)
        
        # Convert subsection headers
        text = re.sub(r'(?m)^(?:[A-Z]\.\d+\.?\s+|[A-Z]\.\s+)([A-Za-z][A-Za-z\s\-:]+)$', r'### \1', text)
        
        # Convert recommendation headers
        text = re.sub(
            r'(?m)^RECOMMENDATION\s+([A-Z0-9\.]+):\s+(.+?)(?:\((?:Recommended|Context-specific|Not recommended).*?\))?$', 
            r'#### RECOMMENDATION \1: \2', 
            text
        )
        
        # Convert remarks sections
        text = re.sub(r'(?m)^Remarks:$', r'##### Remarks:', text)
        
        return text

class MedicalEvidenceExtractor(BaseDocumentTransformer):
    """Extract evidence levels and recommendation types from medical text."""
    
    def __init__(self):
        self.evidence_pattern = re.compile(r'(?:high|moderate|low|very\s+low)(?:-|\s+)(?:quality|certainty)\s+evidence', re.IGNORECASE)
        self.recommendation_type_pattern = re.compile(r'\((Recommended|Context-specific recommendation|Not recommended).*?\)')
    
    def transform_documents(
        self, documents: List[Document], **kwargs
    ) -> List[Document]:
        """Extract evidence levels and enhance document metadata."""
        for doc in documents:
            # Only process if it's a recommendation
            if doc.metadata.get('heading_type') == 'recommendation':
                # Extract evidence level
                evidence_match = self.evidence_pattern.search(doc.page_content)
                if evidence_match:
                    doc.metadata['evidence_level'] = evidence_match.group(0)
                
                # Extract recommendation type
                rec_type_match = self.recommendation_type_pattern.search(doc.page_content)
                if rec_type_match:
                    doc.metadata['recommendation_type'] = rec_type_match.group(1)
                
                # Extract recommendation ID
                rec_id_match = re.search(r'RECOMMENDATION\s+([A-Z0-9\.]+):', doc.page_content)
                if rec_id_match:
                    doc.metadata['recommendation_id'] = rec_id_match.group(1)
        
        return documents

class TableExtractor(BaseDocumentTransformer):
    """Extract tables as separate documents with metadata."""
    
    def transform_documents(
        self, documents: List[Document], **kwargs
    ) -> List[Document]:
        """Identify and mark table content."""
        table_pattern = re.compile(r'(Table\s+\d+[\.:]?\s+.*?)(?:\n\n|\Z)', re.DOTALL)
        
        result_docs = []
        for doc in documents:
            # Find tables in the document
            tables = table_pattern.findall(doc.page_content)
            
            # If tables found, create separate documents for them
            if tables:
                # Create a copy of the original document with tables removed
                modified_content = doc.page_content
                for table in tables:
                    modified_content = modified_content.replace(table, "")
                
                # Add the modified document if it still has significant content
                if len(modified_content.strip()) > 100:
                    modified_doc = Document(
                        page_content=modified_content,
                        metadata=doc.metadata.copy()
                    )
                    result_docs.append(modified_doc)
                
                # Add each table as a separate document
                for table in tables:
                    if len(table.strip()) > 50:  # Skip very small tables
                        table_doc = Document(
                            page_content=table,
                            metadata={
                                **doc.metadata.copy(),
                                "chunk_type": "table",
                                "parent_section_path": doc.metadata.get("section_path", [])
                            }
                        )
                        result_docs.append(table_doc)
            else:
                # No tables, keep the original document
                result_docs.append(doc)
                
        return result_docs

class RecursiveHierarchicalSplitter:
    """Split documents hierarchically, preserving parent-child relationships."""
    
    def __init__(
        self, 
        chunk_size: int = 1000, 
        chunk_overlap: int = 200,
        include_metadata: bool = True
    ):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.include_metadata = include_metadata
        
        # Initialize splitters
        self.header_splitter = MedicalHeaderTextSplitter()
        self.paragraph_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=["\n\n", "\n", ". ", " ", ""]
        )
        self.evidence_extractor = MedicalEvidenceExtractor()
        self.table_extractor = TableExtractor()
    
    def _build_section_path(self, doc: Document) -> List[str]:
        """Build the hierarchical section path for a document."""
        path = []
        
        # Add chapter if available
        if 'chapter' in doc.metadata:
            path.append(doc.metadata['chapter'])
            
        # Add section if available
        if 'section' in doc.metadata:
            path.append(doc.metadata['section'])
            
        # Add subsection if available
        if 'subsection' in doc.metadata:
            path.append(doc.metadata['subsection'])
        
        return path
    
    def _extract_document_metadata(self, text: str) -> Dict[str, str]:
        """Extract document-level metadata."""
        title = ""
        doc_type = ""
        
        # Extract title from WHO document
        title_match = re.search(r'(?:WHO|World Health Organization)\s+(?:recommendations|guidelines)\s+(?:on|for)\s+([A-Za-z\s\-,]+)', text[:3000])
        if title_match:
            title = title_match.group(1).strip()
            doc_type = "WHO Guidelines"
        
        return {
            "title": title,
            "document_type": doc_type
        }
    
    def process_text(self, text: str) -> List[Document]:
        """Process text into hierarchical chunks."""
        # Extract document metadata
        doc_metadata = self._extract_document_metadata(text)
        
        # Convert headers to markdown format for the splitter
        md_text = self.header_splitter._add_md_header(text)
        
        # Split on headers
        docs = self.header_splitter.split_text(md_text)
        
        # Extract evidence levels and recommendation metadata
        docs = self.evidence_extractor.transform_documents(docs)
        
        # Extract tables
        docs = self.table_extractor.transform_documents(docs)
        
        # Build section paths for each document
        for doc in docs:
            section_path = self._build_section_path(doc)
            doc.metadata['section_path'] = section_path
            
            # Add document metadata
            doc.metadata['document_title'] = doc_metadata.get('title', '')
            doc.metadata['document_type'] = doc_metadata.get('document_type', '')
            
            # Determine chunk type if not already set
            if 'chunk_type' not in doc.metadata:
                if doc.metadata.get('heading_type') == 'recommendation':
                    doc.metadata['chunk_type'] = 'recommendation'
                elif doc.metadata.get('heading_type') == 'remarks':
                    doc.metadata['chunk_type'] = 'remarks'
                else:
                    doc.metadata['chunk_type'] = 'text'
        
        # Further split large chunks while preserving metadata
        final_docs = []
        for doc in docs:
            # Don't split recommendation or remarks sections
            if doc.metadata.get('chunk_type') in ['recommendation', 'remarks', 'table']:
                final_docs.append(doc)
            else:
                # Split text sections into smaller chunks
                if len(doc.page_content) > self.chunk_size:
                    smaller_chunks = self.paragraph_splitter.split_text(doc.page_content)
                    for i, chunk in enumerate(smaller_chunks):
                        chunk_doc = Document(
                            page_content=chunk,
                            metadata={
                                **doc.metadata,
                                'chunk_index': i,
                                'total_chunks': len(smaller_chunks)
                            }
                        )
                        final_docs.append(chunk_doc)
                else:
                    final_docs.append(doc)
        
        # Build relationships between chunks
        chunk_dict = {i: doc for i, doc in enumerate(final_docs)}
        for i, doc in chunk_dict.items():
            # Find parent-child relationships
            if doc.metadata.get('heading_type') == 'remarks':
                # Find the recommendation this belongs to
                for j, other_doc in chunk_dict.items():
                    if (other_doc.metadata.get('heading_type') == 'recommendation' and
                        other_doc.metadata.get('recommendation_id') == doc.metadata.get('recommendation_id')):
                        doc.metadata['parent_id'] = j
                        break
        
        return final_docs
    
    def process_file(self, input_file: str, output_file: str = None) -> List[Document]:
        """Process a file and return LangChain documents."""
        # Load the document
        if input_file.lower().endswith('.pdf'):
            loader = PyPDFLoader(input_file)
            pages = loader.load()
            text = "\n\n".join([page.page_content for page in pages])
        else:
            loader = TextLoader(input_file)
            documents = loader.load()
            text = documents[0].page_content
        
        # Process the text
        docs = self.process_text(text)
        
        # Save to output file if requested
        if output_file:
            # Convert to serializable format
            serializable_docs = []
            for i, doc in enumerate(docs):
                serializable_docs.append({
                    "chunk_id": i,
                    "text": doc.page_content,
                    "metadata": doc.metadata
                })
                
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(serializable_docs, f, indent=2, ensure_ascii=False)
            print(f"Saved {len(docs)} chunks to {output_file}")
        
        return docs



    
    # parser = argparse.ArgumentParser(description="Process medical documents with LangChain")
    # parser.add_argument("input_file", help="Path to input file (PDF or text)")
    # parser.add_argument("--output_file", help="Path to output JSON file", default=None)
    # parser.add_argument("--chunk_size", type=int, default=1000, help="Maximum chunk size in characters")
    # parser.add_argument("--chunk_overlap", type=int, default=200, help="Overlap between chunks in characters")
    
    # args = parser.parse_args()
    
    # Default output file if not specified
    # if not args.output_file:
    #     base_name = os.path.splitext(args.input_file)[0]
    #     args.output_file = f"{base_name}_langchain_chunks.json"
    
    # splitter = RecursiveHierarchicalSplitter(
    #     chunk_size=1000,
    #     chunk_overlap=200
    # )

    # docs = splitter.process_file("9789240020306-eng.pdf", "chunks.json")
    # print(f"Generated {len(docs)} chunks")

# if __name__ == "__main__":
#     main()

Saved 282 chunks to chunks.json
Generated 282 chunks


In [5]:
splitter = RecursiveHierarchicalSplitter(
    chunk_size=1000,
    chunk_overlap=200
)

docs = splitter.process_file("9789240020306-eng.pdf", "chunks2.json")
print(f"Generated {len(docs)} chunks")

Saved 282 chunks to chunks2.json
Generated 282 chunks


Above chunking method with embedding and chroma database and querying - openAI

In [6]:
import os
import re
import json
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
import argparse
from tqdm import tqdm

# LangChain imports
from langchain.document_loaders import TextLoader, PyPDFLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter, MarkdownHeaderTextSplitter
from langchain.docstore.document import Document
from langchain.schema import BaseDocumentTransformer
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
import os
class MedicalHeaderTextSplitter(MarkdownHeaderTextSplitter):
    """Custom text splitter for medical documents that respects section headers."""
    
    def __init__(self):
        headers_to_split_on = [
            ("#", "chapter"),
            ("##", "section"),
            ("###", "subsection"),
            ("####", "recommendation"),
            ("#####", "remarks")
        ]
        super().__init__(headers_to_split_on=headers_to_split_on)
        
    def _add_md_header(self, text):
        """Convert medical document headers to markdown format."""
        # Convert chapter headers
        text = re.sub(r'(?m)^(?:Chapter|CHAPTER)\s+(\d+)[\.:]?\s+(.+)$', r'# \1 \2', text)
        
        # Convert section headers
        text = re.sub(r'(?m)^(?:\d+\.\d+\.?\s+|\d+\.\s+)([A-Z][A-Za-z\s\-:]+)$', r'## \1', text)
        
        # Convert subsection headers
        text = re.sub(r'(?m)^(?:[A-Z]\.\d+\.?\s+|[A-Z]\.\s+)([A-Za-z][A-Za-z\s\-:]+)$', r'### \1', text)
        
        # Convert recommendation headers
        text = re.sub(
            r'(?m)^RECOMMENDATION\s+([A-Z0-9\.]+):\s+(.+?)(?:\((?:Recommended|Context-specific|Not recommended).*?\))?$', 
            r'#### RECOMMENDATION \1: \2', 
            text
        )
        
        # Convert remarks sections
        text = re.sub(r'(?m)^Remarks:$', r'##### Remarks:', text)
        
        return text

class MedicalEvidenceExtractor(BaseDocumentTransformer):
    """Extract evidence levels and recommendation types from medical text."""
    
    def __init__(self):
        self.evidence_pattern = re.compile(r'(?:high|moderate|low|very\s+low)(?:-|\s+)(?:quality|certainty)\s+evidence', re.IGNORECASE)
        self.recommendation_type_pattern = re.compile(r'\((Recommended|Context-specific recommendation|Not recommended).*?\)')
    
    def transform_documents(
        self, documents: List[Document], **kwargs
    ) -> List[Document]:
        """Extract evidence levels and enhance document metadata."""
        for doc in documents:
            # Only process if it's a recommendation
            if doc.metadata.get('heading_type') == 'recommendation':
                # Extract evidence level
                evidence_match = self.evidence_pattern.search(doc.page_content)
                if evidence_match:
                    doc.metadata['evidence_level'] = evidence_match.group(0)
                
                # Extract recommendation type
                rec_type_match = self.recommendation_type_pattern.search(doc.page_content)
                if rec_type_match:
                    doc.metadata['recommendation_type'] = rec_type_match.group(1)
                
                # Extract recommendation ID
                rec_id_match = re.search(r'RECOMMENDATION\s+([A-Z0-9\.]+):', doc.page_content)
                if rec_id_match:
                    doc.metadata['recommendation_id'] = rec_id_match.group(1)
        
        return documents

class TableExtractor(BaseDocumentTransformer):
    """Extract tables as separate documents with metadata."""
    
    def transform_documents(
        self, documents: List[Document], **kwargs
    ) -> List[Document]:
        """Identify and mark table content."""
        table_pattern = re.compile(r'(Table\s+\d+[\.:]?\s+.*?)(?:\n\n|\Z)', re.DOTALL)
        
        result_docs = []
        for doc in documents:
            # Find tables in the document
            tables = table_pattern.findall(doc.page_content)
            
            # If tables found, create separate documents for them
            if tables:
                # Create a copy of the original document with tables removed
                modified_content = doc.page_content
                for table in tables:
                    modified_content = modified_content.replace(table, "")
                
                # Add the modified document if it still has significant content
                if len(modified_content.strip()) > 100:
                    modified_doc = Document(
                        page_content=modified_content,
                        metadata=doc.metadata.copy()
                    )
                    result_docs.append(modified_doc)
                
                # Add each table as a separate document
                for table in tables:
                    if len(table.strip()) > 50:  # Skip very small tables
                        table_doc = Document(
                            page_content=table,
                            metadata={
                                **doc.metadata.copy(),
                                "chunk_type": "table",
                                "parent_section_path": doc.metadata.get("section_path", [])
                            }
                        )
                        result_docs.append(table_doc)
            else:
                # No tables, keep the original document
                result_docs.append(doc)
                
        return result_docs

class MedicalDocumentProcessor:
    """Process medical documents into semantically meaningful chunks."""
    
    def __init__(
        self, 
        chunk_size: int = 1000, 
        chunk_overlap: int = 200,
        embedding_model_name: str = "text-embedding-ada-002"
    ):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.embedding_model_name = embedding_model_name
        
        # Initialize document processing pipeline
        self.header_splitter = MedicalHeaderTextSplitter()
        self.paragraph_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=["\n\n", "\n", ". ", " ", ""]
        )
        self.evidence_extractor = MedicalEvidenceExtractor()
        self.table_extractor = TableExtractor()
        
        # Initialize embedding model
        self.embeddings = OpenAIEmbeddings(model=embedding_model_name)
    
    def _extract_document_metadata(self, text: str) -> Dict[str, str]:
        """Extract document-level metadata."""
        title = ""
        doc_type = ""
        
        # Extract title from WHO document
        title_match = re.search(r'(?:WHO|World Health Organization)\s+(?:recommendations|guidelines)\s+(?:on|for)\s+([A-Za-z\s\-,]+)', text[:3000])
        if title_match:
            title = title_match.group(1).strip()
            doc_type = "WHO Guidelines"
            
        # If no specific match, try to extract a general title
        if not title:
            title_match = re.search(r'^([A-Z][A-Za-z\s\-:,]+(?:Guidelines|Recommendations|Guidance))', text[:1000])
            if title_match:
                title = title_match.group(1).strip()
                doc_type = "Medical Guidelines"
        
        return {
            "title": title,
            "document_type": doc_type
        }
    
    def _build_section_path(self, doc: Document) -> List[str]:
        """Build the hierarchical section path for a document."""
        path = []
        
        # Add chapter if available
        if 'chapter' in doc.metadata:
            path.append(doc.metadata['chapter'])
            
        # Add section if available
        if 'section' in doc.metadata:
            path.append(doc.metadata['section'])
            
        # Add subsection if available
        if 'subsection' in doc.metadata:
            path.append(doc.metadata['subsection'])
        
        return path
    
    def process_text(self, text: str, source_name: str = "") -> List[Document]:
        """Process text into hierarchical chunks."""
        # Extract document metadata
        doc_metadata = self._extract_document_metadata(text)
        
        # Add source information to metadata
        if source_name:
            doc_metadata["source"] = source_name
        
        # Convert headers to markdown format for the splitter
        md_text = self.header_splitter._add_md_header(text)
        
        # Split on headers
        docs = self.header_splitter.split_text(md_text)
        
        # Extract evidence levels and recommendation metadata
        docs = self.evidence_extractor.transform_documents(docs)
        
        # Extract tables
        docs = self.table_extractor.transform_documents(docs)
        
        # Build section paths for each document
        for doc in docs:
            section_path = self._build_section_path(doc)
            doc.metadata['section_path'] = section_path
            
            # Add document metadata
            doc.metadata['document_title'] = doc_metadata.get('title', '')
            doc.metadata['document_type'] = doc_metadata.get('document_type', '')
            
            # Determine chunk type if not already set
            if 'chunk_type' not in doc.metadata:
                if doc.metadata.get('heading_type') == 'recommendation':
                    doc.metadata['chunk_type'] = 'recommendation'
                elif doc.metadata.get('heading_type') == 'remarks':
                    doc.metadata['chunk_type'] = 'remarks'
                else:
                    doc.metadata['chunk_type'] = 'text'
        
        # Further split large chunks while preserving metadata
        final_docs = []
        for doc in docs:
            # Don't split recommendation or remarks sections
            if doc.metadata.get('chunk_type') in ['recommendation', 'remarks', 'table']:
                final_docs.append(doc)
            else:
                # Split text sections into smaller chunks
                if len(doc.page_content) > self.chunk_size:
                    smaller_chunks = self.paragraph_splitter.split_text(doc.page_content)
                    for i, chunk in enumerate(smaller_chunks):
                        chunk_doc = Document(
                            page_content=chunk,
                            metadata={
                                **doc.metadata,
                                'chunk_index': i,
                                'total_chunks': len(smaller_chunks)
                            }
                        )
                        final_docs.append(chunk_doc)
                else:
                    final_docs.append(doc)
        
        return final_docs
    
    def load_documents(self, input_path: str) -> List[Document]:
        """Load documents from file or directory."""
        documents = []
        
        if os.path.isdir(input_path):
            # Process all files in directory
            for filename in os.listdir(input_path):
                file_path = os.path.join(input_path, filename)
                if os.path.isfile(file_path):
                    documents.extend(self._load_single_document(file_path))
        else:
            # Process single file
            documents = self._load_single_document(input_path)
        
        return documents
    
    def _load_single_document(self, file_path: str) -> List[Document]:
        """Load and process a single document."""
        print(f"Processing {file_path}...")
        
        # Load the document
        if file_path.lower().endswith('.pdf'):
            loader = PyPDFLoader(file_path)
            pages = loader.load()
            text = "\n\n".join([page.page_content for page in pages])
        else:
            with open(file_path, 'r', encoding='utf-8') as f:
                text = f.read()
        
        # Process the text
        source_name = os.path.basename(file_path)
        return self.process_text(text, source_name)
    
    def create_vector_store(self, documents: List[Document], persist_directory: str = None) -> Chroma:
        """Create a vector store from processed documents."""
        # Create Chroma DB with metadata
        db = Chroma.from_documents(
            documents=documents,
            embedding=self.embeddings,
            persist_directory=persist_directory,
            collection_metadata={"hnsw:space": "cosine"}  # Optimize for medical text similarity
        )
        
        if persist_directory:
            db.persist()
            print(f"Vector database persisted to {persist_directory}")
        
        return db
    
    def load_vector_store(self, persist_directory: str) -> Chroma:
        """Load an existing vector store."""
        return Chroma(
            persist_directory=persist_directory,
            embedding_function=self.embeddings
        )



In [9]:
processor = MedicalDocumentProcessor(
        chunk_size=1000,
        chunk_overlap=200
    )

In [8]:
pip install openai

Collecting openai
  Downloading openai-1.65.2-py3-none-any.whl.metadata (27 kB)
Collecting jiter<1,>=0.4.0 (from openai)
  Downloading jiter-0.8.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.2 kB)
Downloading openai-1.65.2-py3-none-any.whl (473 kB)
Downloading jiter-0.8.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (345 kB)
Installing collected packages: jiter, openai
Successfully installed jiter-0.8.2 openai-1.65.2
Note: you may need to restart the kernel to use updated packages.


In [1]:
def main():
    parser = argparse.ArgumentParser(description="Medical Document RAG Pipeline")
    parser.add_argument("--input", required=True, help="Path to input file or directory")
    parser.add_argument("--db_path", required=True, help="Path to store/load Chroma database")
    parser.add_argument("--rebuild_db", action="store_true", help="Rebuild the vector database")
    parser.add_argument("--chunk_size", type=int, default=1000, help="Chunk size for text splitting")
    parser.add_argument("--chunk_overlap", type=int, default=200, help="Chunk overlap for text splitting")
    parser.add_argument("--query", help="Query to run against the RAG system")
    parser.add_argument("--filter_recommendations", action="store_true", help="Filter results to only recommendations")
    parser.add_argument("--evidence_level", choices=["high", "moderate", "low"], help="Filter by evidence level")
    
    args = parser.parse_args()
    
    # Set up processor
    processor = MedicalDocumentProcessor(
        chunk_size=1000,
        chunk_overlap=200
    )
    
    # Set up vector store
    if args.rebuild_db or not os.path.exists(args.db_path):
        # Create database directory if it doesn't exist
        os.makedirs(args.db_path, exist_ok=True)
        
        # Process documents
        documents = processor.load_documents(args.input)
        print(f"Processed {len(documents)} chunks")
        
        # Create vector store
        vector_store = processor.create_vector_store(documents, args.db_path)
    else:
        # Load existing vector store
        vector_store = processor.load_vector_store(args.db_path)
        print(f"Loaded existing vector database from {args.db_path}")
    
    # Set up RAG system
    rag = MedicalRAG(vector_store)
    
    # Run query if provided
    if args.query:
        if args.filter_recommendations:
            result = rag.query_recommendations(args.query, args.evidence_level)
        else:
            result = rag.query(args.query)
        
        print("\n" + "="*50)
        print("QUESTION:")
        print(result["question"])
        print("\nANSWER:")
        print(result["answer"])
        print("\nSOURCES:")
        for i, source in enumerate(result["sources"]):
            print(f"\n{i+1}. {source['metadata'].get('document_title', 'Unknown document')}")
            if "section_path" in source["metadata"]:
                print(f"   Section: {' > '.join(source['metadata']['section_path'])}")
            if "chunk_type" in source["metadata"]:
                print(f"   Type: {source['metadata']['chunk_type']}")
            if "evidence_level" in source["metadata"]:
                print(f"   Evidence level: {source['metadata']['evidence_level']}")
    else:
        print("\nVector store is ready. Use --query to ask a question.")
        print("Example: python medical_rag.py --db_path ./chroma_db --query \"What are the recommendations for iron supplementation in pregnant women?\"")

if __name__ == "__main__":
    main()

NameError: name 'argparse' is not defined

# FAISS Vector store with heirarchical chunking and text-embedding-ada-002 embedding

In [None]:
import os
import re
import json
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
import argparse
from tqdm import tqdm

# LangChain imports
from langchain.document_loaders import TextLoader, PyPDFLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter, MarkdownHeaderTextSplitter
from langchain.docstore.document import Document
from langchain.schema import BaseDocumentTransformer
from langchain.vectorstores import Chroma
from langchain.embeddings import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
import os
import uuid
from langchain.vectorstores import FAISS
class MedicalHeaderTextSplitter(MarkdownHeaderTextSplitter):
    """Custom text splitter for medical documents that respects section headers."""
    
    def __init__(self):
        headers_to_split_on = [
            ("#", "chapter"),
            ("##", "section"),
            ("###", "subsection"),
            ("####", "recommendation"),
            ("#####", "remarks")
        ]
        super().__init__(headers_to_split_on=headers_to_split_on)
        
    def _add_md_header(self, text):
        """Convert medical document headers to markdown format."""
        # Convert chapter headers
        text = re.sub(r'(?m)^(?:Chapter|CHAPTER)\s+(\d+)[\.:]?\s+(.+)$', r'# \1 \2', text)
        
        # Convert section headers
        text = re.sub(r'(?m)^(?:\d+\.\d+\.?\s+|\d+\.\s+)([A-Z][A-Za-z\s\-:]+)$', r'## \1', text)
        
        # Convert subsection headers
        text = re.sub(r'(?m)^(?:[A-Z]\.\d+\.?\s+|[A-Z]\.\s+)([A-Za-z][A-Za-z\s\-:]+)$', r'### \1', text)
        
        # Convert recommendation headers
        text = re.sub(
            r'(?m)^RECOMMENDATION\s+([A-Z0-9\.]+):\s+(.+?)(?:\((?:Recommended|Context-specific|Not recommended).*?\))?$', 
            r'#### RECOMMENDATION \1: \2', 
            text
        )
        
        # Convert remarks sections
        text = re.sub(r'(?m)^Remarks:$', r'##### Remarks:', text)
        
        return text

class MedicalEvidenceExtractor(BaseDocumentTransformer):
    """Extract evidence levels and recommendation types from medical text."""
    
    def __init__(self):
        self.evidence_pattern = re.compile(r'(?:high|moderate|low|very\s+low)(?:-|\s+)(?:quality|certainty)\s+evidence', re.IGNORECASE)
        self.recommendation_type_pattern = re.compile(r'\((Recommended|Context-specific recommendation|Not recommended).*?\)')
    
    def transform_documents(
        self, documents: List[Document], **kwargs
    ) -> List[Document]:
        """Extract evidence levels and enhance document metadata."""
        for doc in documents:
            # Only process if it's a recommendation
            if doc.metadata.get('heading_type') == 'recommendation':
                # Extract evidence level
                evidence_match = self.evidence_pattern.search(doc.page_content)
                if evidence_match:
                    doc.metadata['evidence_level'] = evidence_match.group(0)
                
                # Extract recommendation type
                rec_type_match = self.recommendation_type_pattern.search(doc.page_content)
                if rec_type_match:
                    doc.metadata['recommendation_type'] = rec_type_match.group(1)
                
                # Extract recommendation ID
                rec_id_match = re.search(r'RECOMMENDATION\s+([A-Z0-9\.]+):', doc.page_content)
                if rec_id_match:
                    doc.metadata['recommendation_id'] = rec_id_match.group(1)
        
        return documents

class TableExtractor(BaseDocumentTransformer):
    """Extract tables as separate documents with metadata."""
    
    def transform_documents(
        self, documents: List[Document], **kwargs
    ) -> List[Document]:
        """Identify and mark table content."""
        table_pattern = re.compile(r'(Table\s+\d+[\.:]?\s+.*?)(?:\n\n|\Z)', re.DOTALL)
        
        result_docs = []
        for doc in documents:
            # Find tables in the document
            tables = table_pattern.findall(doc.page_content)
            
            # If tables found, create separate documents for them
            if tables:
                # Create a copy of the original document with tables removed
                modified_content = doc.page_content
                for table in tables:
                    modified_content = modified_content.replace(table, "")
                
                # Add the modified document if it still has significant content
                if len(modified_content.strip()) > 100:
                    modified_doc = Document(
                        page_content=modified_content,
                        metadata=doc.metadata.copy()
                    )
                    result_docs.append(modified_doc)
                
                # Add each table as a separate document
                for table in tables:
                    if len(table.strip()) > 50:  # Skip very small tables
                        table_doc = Document(
                            page_content=table,
                            metadata={
                                **doc.metadata.copy(),
                                "chunk_type": "table",
                                "parent_section_path": doc.metadata.get("section_path", [])
                            }
                        )
                        result_docs.append(table_doc)
            else:
                # No tables, keep the original document
                result_docs.append(doc)
                
        return result_docs

class MedicalDocumentProcessor:
    """Process medical documents into semantically meaningful chunks."""
    
    def __init__(
        self, 
        chunk_size: int = 1000, 
        chunk_overlap: int = 200,
        embedding_model_name: str = "text-embedding-ada-002",
        db_directory: str = "medical_db"
    ):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.embedding_model_name = embedding_model_name
        
        self.db_directory = db_directory
        # Initialize document processing pipeline
        self.header_splitter = MedicalHeaderTextSplitter()
        self.paragraph_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=["\n\n", "\n", ". ", " ", ""]
        )
        self.evidence_extractor = MedicalEvidenceExtractor()
        self.table_extractor = TableExtractor()
        
        # Initialize embedding model
        self.embeddings = OpenAIEmbeddings(model=embedding_model_name)
        index_path = os.path.join(db_directory, "index.faiss")
        if os.path.exists(index_path):
            self.db = FAISS.load_local(
                folder_path=db_directory,
                embeddings=self.embeddings,
                allow_dangerous_deserialization=True
            )
        else:
            self.db = None
    
    def _extract_document_metadata(self, text: str) -> Dict[str, str]:
        """Extract document-level metadata."""
        title = ""
        doc_type = ""
        
        # Extract title from WHO document
        title_match = re.search(r'(?:WHO|World Health Organization)\s+(?:recommendations|guidelines)\s+(?:on|for)\s+([A-Za-z\s\-,]+)', text[:3000])
        if title_match:
            title = title_match.group(1).strip()
            doc_type = "WHO Guidelines"
            
        # If no specific match, try to extract a general title
        if not title:
            title_match = re.search(r'^([A-Z][A-Za-z\s\-:,]+(?:Guidelines|Recommendations|Guidance))', text[:1000])
            if title_match:
                title = title_match.group(1).strip()
                doc_type = "Medical Guidelines"
        
        return {
            "title": title,
            "document_type": doc_type
        }
    
    def _build_section_path(self, doc: Document) -> List[str]:
        """Build the hierarchical section path for a document."""
        path = []
        
        # Add chapter if available
        if 'chapter' in doc.metadata:
            path.append(doc.metadata['chapter'])
            
        # Add section if available
        if 'section' in doc.metadata:
            path.append(doc.metadata['section'])
            
        # Add subsection if available
        if 'subsection' in doc.metadata:
            path.append(doc.metadata['subsection'])
        
        return path
    
    def process_text(self, text: str, source_name: str = "") -> List[Document]:
        """Process text into hierarchical chunks."""
        # Extract document metadata
        doc_metadata = self._extract_document_metadata(text)
        
        # Add source information to metadata
        if source_name:
            doc_metadata["source"] = source_name
        
        # Convert headers to markdown format for the splitter
        md_text = self.header_splitter._add_md_header(text)
        
        # Split on headers
        docs = self.header_splitter.split_text(md_text)
       
        
        # Extract evidence levels and recommendation metadata
        docs = self.evidence_extractor.transform_documents(docs)
        
        # Extract tables
        docs = self.table_extractor.transform_documents(docs)
        
        # Build section paths for each document
        for doc in docs:
            section_path = self._build_section_path(doc)
            doc.metadata['section_path'] = section_path
            
            # Add document metadata
            doc.metadata['document_title'] = doc_metadata.get('title', '')
            doc.metadata['document_type'] = doc_metadata.get('document_type', '')
            
            # Determine chunk type if not already set
            if 'chunk_type' not in doc.metadata:
                if doc.metadata.get('heading_type') == 'recommendation':
                    doc.metadata['chunk_type'] = 'recommendation'
                elif doc.metadata.get('heading_type') == 'remarks':
                    doc.metadata['chunk_type'] = 'remarks'
                else:
                    doc.metadata['chunk_type'] = 'text'
        
        # Further split large chunks while preserving metadata
        final_docs = []
        for doc in docs:
            # Don't split recommendation or remarks sections
            if doc.metadata.get('chunk_type') in ['recommendation', 'remarks', 'table']:
                final_docs.append(doc)
            else:
                # Split text sections into smaller chunks
                if len(doc.page_content) > self.chunk_size:
                    smaller_chunks = self.paragraph_splitter.split_text(doc.page_content)
                    for i, chunk in enumerate(smaller_chunks):
                        chunk_doc = Document(
                            page_content=chunk,
                            metadata={
                                **doc.metadata,
                                'chunk_index': i,
                                'total_chunks': len(smaller_chunks)
                            }
                        )
                        final_docs.append(chunk_doc)
                else:
                    final_docs.append(doc)
        
        return final_docs
        # return docs
    
    def load_documents(self, input_path: str) -> List[Document]:
        """Load documents from file or directory."""
        documents = []
        
        if os.path.isdir(input_path):
            # Process all files in directory
            for filename in os.listdir(input_path):
                file_path = os.path.join(input_path, filename)
                if os.path.isfile(file_path):
                    documents.extend(self._load_single_document(file_path))
        else:
            # Process single file
            documents = self._load_single_document(input_path)
        
        return documents
    
    def _load_single_document(self, file_path: str) -> List[Document]:
        """Load and process a single document."""
        print(f"Processing {file_path}...")
        
        # Load the document
        if file_path.lower().endswith('.pdf'):
            loader = PyPDFLoader(file_path)
            pages = loader.load()
            text = "\n\n".join([page.page_content for page in pages])
        else:
            with open(file_path, 'r', encoding='utf-8') as f:
                text = f.read()
        
        # Process the text
        source_name = os.path.basename(file_path)
        return self.process_text(text, source_name)
    
    def create_vector_store(self, documents: List[Document], collection_name: Optional[str] = None) :
        """Create a vector store from processed documents."""
        # Create Chroma DB with metadata
        if collection_name is None:
            collection_name = f"medical_{uuid.uuid4().hex[:8]}"
        
        # Create database directory if it doesn't exist
        index_path = os.path.join(self.db_directory, collection_name)
        os.makedirs(index_path, exist_ok=True)
        
        # Initialize FAISS index from documents
        self.db = FAISS.from_documents(
            documents=documents,
            embedding=self.embeddings
        )
        
        self.db.save_local(index_path)
        # print(f"Stored {len()} chunks in collection '{collection_name}'")
        
        
        return self.db

    




In [None]:
processor = MedicalDocumentProcessor(
        chunk_size=1000,
        chunk_overlap=200
    )

NameError: name 'MedicalDocumentProcessor' is not defined

In [None]:

documents = processor.load_documents("/home/vidhij2/nivi/maternal-documents")
print(f"Processed {len(documents)} chunks")
        
# Create vector store
vector_store = processor.create_vector_store(documents,collection_name="heirchical_chunking_rag")

    #     # Load existing vector store
    #     vector_store = processor.load_vector_store(args.db_path)
    #     print(f"Loaded existing vector database from {args.db_path}")
    

Processing /home/vidhij2/nivi/maternal-documents/Training_Manual_for_Medical_Methods_of_Abortion_(MMA)_in_Early_Gestation.pdf...
Processing /home/vidhij2/nivi/maternal-documents/Care During Pregnancy and Childbirth Training Manual for CHO at AB-HWC.pdf...
Processing /home/vidhij2/nivi/maternal-documents/9789241549912-eng.pdf...
Processing /home/vidhij2/nivi/maternal-documents/Midwifery-Educators-and-Nurse-Practitioner-Midwives.pdf...
Processing /home/vidhij2/nivi/maternal-documents/Operationalization-of-Midwifery-Units.pdf...
Processing /home/vidhij2/nivi/maternal-documents/Scope of Practice Document .pdf...
Processing /home/vidhij2/nivi/maternal-documents/sba_guidelines_for_skilled_attendance_at_birth.pdf...
Processing /home/vidhij2/nivi/maternal-documents/Guidelines_on_Midwifery_Services_in_India.pdf...
Processing /home/vidhij2/nivi/maternal-documents/Guidance_Note_on_optimizing_post_natal_care.pdf...
Processing /home/vidhij2/nivi/maternal-documents/JSSK_Final_English.pdf...
Processi

In [None]:
vector_store = processor.create_vector_store(documents,collection_name="heirchical_chunking_rag_openAI")

In [None]:
# vector_store.save_local("/home/vidhij2/nivi/medical_db/heirchical_chunking_rag_openAI")
print(f"✅ FAISS index saved successfully with {vector_store.index.ntotal} v vectors.")


✅ FAISS index saved successfully with 3020 v vectors.


In [None]:
vector_store

<langchain_community.vectorstores.faiss.FAISS at 0x7ff89f07a220>

In [None]:
documents

[Document(metadata={'section_path': [], 'document_title': '', 'document_type': '', 'chunk_type': 'text', 'chunk_index': 0, 'total_chunks': 8}, page_content='i\nTraining Manual\nFor MMA in Early Gestation\nTRAINING MANUAL\nFor Medical Methods of Abortion (MMA)\nin Early Gestation  \nTraining Manual\nFor MMA in Early Gestation\nTRAINING MANUAL\nFor Medical Methods of Abortion (MMA)\nin Early Gestation\nJULY 2022  \nTraining Manual\nFor MMA in Early Gestation\ni\nHealthy Village, Healthy Nation  \nDR. SUMITA GHOSH\nAdditional Commissioner\nTelefax : 011-23063178\nE-mail : sumita.ghosh@nic.in  \nPreface\nGOVERNMENT OF INDIA\nMINISTRY OF HEALTH & FAMILY WELFARE\nNIRMAN BHAVAN, MAULANA AZAD ROAD\nNEW DELHI - 110011  \n29’h July, 2022  \nThe Medical Termination of Pregnancy Act, 1971 has recently been amended and the\nMedical Termination of Pregnancy (Amendment ) Act, 2021 and the Medical\nTermination of Pregnancy (Amendment) Rules, 2021 as prescribed under the Act have\ncome into force. Thes

Loading the vector store

In [None]:
embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")

# Path where FAISS index is stored
faiss_index_path = "medical_db/heirchical_chunking_rag_openAI"

# Load FAISS vector store
vector_store = FAISS.load_local(folder_path=faiss_index_path, embeddings=embeddings, allow_dangerous_deserialization=True)

print(f"FAISS index loaded successfully with {len(vector_store.index)} vectors.")

TypeError: object of type 'IndexFlat' has no len()

*-- only rag 

In [2]:
class MedicalRAG:
    """RAG system for medical documents with advanced filtering and retrieval."""
    
    def __init__(
        self,
        vector_store: Chroma,
        model_name: str = "gpt-4-turbo",
        temperature: float = 0.0,
    ):
        self.vector_store = vector_store
        self.llm = ChatOpenAI(model_name=model_name, temperature=temperature)
        
        # Create RAG prompt template
        self.rag_prompt = PromptTemplate(
            template="""You are a medical information assistant that helps healthcare professionals by providing evidence-based information from medical guidelines and literature.

Context information from medical documents:
{context}

Question: {question}

Instructions:
1. Answer based only on the provided context. If the information isn't in the context, say "I don't have enough information to answer this question based on the provided medical guidelines."
2. Cite specific recommendations, evidence levels, and document sources when available.
3. Be concise but comprehensive.
4. If multiple recommendations or conflicting guidance exists in the context, present all perspectives.
5. When quoting recommendations, preserve their exact wording.

Answer:""",
            input_variables=["context", "question"]
        )
        
        # Setup retrieval chain
        self.retriever = vector_store.as_retriever(
            search_type="similarity",
            search_kwargs={"k": 6}
        )
        
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": self.rag_prompt}
        )
    
    def query(self, question: str, filters: Dict[str, Any] = None) -> Dict[str, Any]:
        """Query the RAG system with optional metadata filtering."""
        if filters:
            # Apply metadata filters to the retriever
            self.retriever.search_kwargs["filter"] = filters
        
        # Get answer
        result = self.qa_chain({"query": question})
        
        # Format response
        sources = []
        for doc in result.get("source_documents", []):
            # Add source information
            source_info = {
                "content": doc.page_content[:200] + "..." if len(doc.page_content) > 200 else doc.page_content,
                "metadata": doc.metadata
            }
            sources.append(source_info)
        
        return {
            "question": question,
            "answer": result["result"],
            "sources": sources
        }
    
    def query_recommendations(self, question: str, evidence_level: str = None) -> Dict[str, Any]:
        """Query specifically for medical recommendations with optional evidence filtering."""
        filters = {"chunk_type": "recommendation"}
        
        if evidence_level:
            # Filter by evidence level (e.g., "high", "moderate", "low")
            filters["evidence_level"] = {"$regex": f"{evidence_level.lower()}.*evidence"}
        
        return self.query(question, filters)



NameError: name 'Chroma' is not defined

A complete Retrieval-Augmented Generation pipeline with document loading, chunking,
    embedding, database storage, and querying capabilities.

In [4]:
import os
import uuid
import torch
from typing import List, Dict, Any, Optional

import numpy as np
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import TextLoader, PyPDFLoader, DirectoryLoader
from langchain.vectorstores import Chroma, FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document
from langchain.llms import OpenAI
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from dotenv import load_dotenv

# Load environment variables
load_dotenv()


class RAGPipeline:
    """
    A complete Retrieval-Augmented Generation pipeline with document loading, chunking,
    embedding, database storage, and querying capabilities.
    """

    def __init__(
        self,
        embedding_model: str = "BAAI/bge-large-en-v1.5",
        llm_model: str = "gpt-3.5-turbo-instruct",
        db_directory: str = "db",
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        model_kwargs: dict = {"device": "cpu"},
        encode_kwargs: dict = {"normalize_embeddings": True}
    ):
        """
        Initialize the RAG pipeline.
        
        Args:
            embedding_model: The Hugging Face model to use for embeddings (can be Llama or any other compatible model)
            llm_model: The LLM model to use for generation
            db_directory: Directory to store the vector database
            chunk_size: Size of text chunks
            chunk_overlap: Overlap between chunks
            model_kwargs: Additional keyword arguments for the embedding model
            encode_kwargs: Additional keyword arguments for the encoding process
        """
        self.embedding_model = embedding_model
        # self.llm_model = llm_model
        self.db_directory = db_directory
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        
        # Initialize embeddings with Hugging Face (compatible with Llama models)
        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model,
            model_kwargs=model_kwargs,
            encode_kwargs=encode_kwargs
        )
        
        # Initialize LLM
        # self.llm = OpenAI(model_name=llm_model)
        
        # Initialize text splitter
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, 
            chunk_overlap=chunk_overlap
        )
        
        # Initialize vectorstore if it exists
        if os.path.exists(db_directory):
            self.db = Chroma(
                persist_directory=db_directory,
                embedding_function=self.embeddings
            )
        else:
            self.db = None

    def load_documents(self, file_path: str) -> List[Document]:
        """
        Load documents from a file or directory.
        
        Args:
            file_path: Path to file or directory
            
        Returns:
            List of loaded documents
        """
        if os.path.isdir(file_path):
            # Load from directory
            loader = DirectoryLoader(
                file_path,
                glob="**/*.*",
                loader_cls=self._get_loader_for_extension
            )
            documents = loader.load()
        else:
            # Load single file
            extension = os.path.splitext(file_path)[1].lower()
            loader_cls = self._get_loader_for_extension(extension)
            loader = loader_cls(file_path)
            documents = loader.load()
            
        print(f"Loaded {len(documents)} documents from {file_path}")
        return documents

    def _get_loader_for_extension(self, extension: str):
        """
        Get the appropriate document loader for a file extension.
        
        Args:
            extension: File extension
            
        Returns:
            Document loader class
        """
        if extension == '.pdf':
            return PyPDFLoader
        else:
            # Default to text loader
            return TextLoader

    def process_documents(self, documents: List[Document]) -> List[Document]:
        """
        Split documents into chunks.
        
        Args:
            documents: List of documents to process
            
        Returns:
            List of document chunks
        """
        chunks = self.text_splitter.split_documents(documents)
        print(f"Split into {len(chunks)} chunks")
        return chunks

    def store_documents(self, chunks: List[Document], collection_name: Optional[str] = None):
        """
        Store document chunks in the vector database.
        
        Args:
            chunks: Document chunks to store
            collection_name: Optional name for the collection
        """
        # Generate a collection name if not provided
        if collection_name is None:
            collection_name = f"collection_{uuid.uuid4().hex[:8]}"
            
        # Initialize or get the database
        self.db = FAISS.from_documents(
            documents=chunks,
            embedding=self.embeddings
        )
        os.makedirs(self.db_directory, exist_ok=True)
        
        # Save the index to disk with collection name in the path
        index_path = os.path.join(self.db_directory, collection_name)
        self.db.save_local(index_path)
        
        print(f"Stored {len(chunks)} chunks in collection '{collection_name}'")
        
        return collection_name
    
    def retrieve_chunks(
        self, 
        query: str, 
        n_results: int = 4,
        collection_name: Optional[str] = None
    ) -> List[Document]:
        """
        Retrieve relevant document chunks for a query.
        
        Args:
            query: The query string
            n_results: Number of chunks to retrieve
            collection_name: Optional collection to search in
            
        Returns:
            List of relevant document chunks
        """
        if self.db is None:
            raise ValueError("No database has been created or loaded.")
            
        # If collection name is provided, use that collection
        if collection_name:
            index_path = os.path.join(self.db_directory, collection_name)
            if os.path.exists(index_path):
                db = FAISS.load_local(
                    folder_path=index_path,
                    embeddings=self.embeddings,
                    allow_dangerous_deserialization=True
                )
            else:
                raise ValueError(f"Collection '{collection_name}' not found.")
        else:
            db = self.db
            
        # Retrieve chunks
        chunks = db.similarity_search(query, k=n_results)
        return chunks
    def query(
        self, 
        query: str, 
        n_results: int = 4,
        collection_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """
    #     Perform a query using the RAG pipeline.
        
    #     Args:
    #         query: The query string
    #         n_results: Number of chunks to retrieve
    #         collection_name: Optional collection to search in
            
    #     Returns:
    #         Dictionary with the query result and relevant chunks
    #     """
    #     if self.db is None:
    #         raise ValueError("No database has been created or loaded.")
            
    #     # If collection name is provided, use that collection
    #     if collection_name:
    #         db = Chroma(
    #             persist_directory=self.db_directory,
    #             embedding_function=self.embeddings,
    #             collection_name=collection_name
    #         )
    #     else:
    #         db = self.db
            
    #     # Create a retriever
    #     retriever = db.as_retriever(search_kwargs={"k": n_results})
        
    #     # Create a prompt template
    #     template = """
    #     You are an assistant that answers questions based on the provided context.
        
    #     Context:
    #     {context}
        
    #     Question:
    #     {question}
        
    #     Answer:
    #     """
        
    #     prompt = PromptTemplate(
    #         input_variables=["context", "question"],
    #         template=template
    #     )
        
    #     # Create a QA chain
    #     qa_chain = RetrievalQA.from_chain_type(
    #         llm=self.llm,
    #         chain_type="stuff",
    #         retriever=retriever,
    #         chain_type_kwargs={"prompt": prompt}
    #     )
        
    #     # Execute the query
    #     result = qa_chain({"query": query})
        
    #     # Also retrieve the chunks for reference
        chunks = self.retrieve_chunks(query, n_results, collection_name)
        return chunks
    #     return {
    #         "query": query,
    #         "answer": result["result"],
    #         "chunks": chunks
    #     }

    def ingest_and_store(self, file_path: str, collection_name: Optional[str] = None) -> str:
        """
        Complete pipeline to ingest, process, and store documents.
        
        Args:
            file_path: Path to file or directory to ingest
            collection_name: Optional name for the collection
            
        Returns:
            Collection name
        """
        # Load documents
        documents = self.load_documents(file_path)
        
        # Process into chunks
        chunks = self.process_documents(documents)
        
        # Store in database
        collection_name = self.store_documents(chunks, collection_name)
        
        return collection_name


# Example usage
if __name__ == "__main__":
    # Initialize the RAG pipeline with Llama or other Hugging Face embedding model
    pipeline = RAGPipeline(
        embedding_model="sentence-transformers/all-MiniLM-L6-v2",  # Lightweight model for testing
        # For Llama embeddings, you can use models like:
        # - "meta-llama/Llama-2-7b-hf" (need appropriate permissions)
        # - "NousResearch/Llama-2-7b-hf" (open-access version)
        # - Or any other compatible model from Hugging Face
        model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"}
    )
    
    # To use a specific Llama model variant optimized for embeddings:
    # pipeline = RAGPipeline(
    #     embedding_model="llamaindex/text-embedding-llama",
    #     model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"}
    # )
    
    # Ingest and store documents
    collection_name = pipeline.ingest_and_store("/home/vidhij2/nivi/documents/9789240020306-eng.pdf")
    
    # Query the documents
    result = pipeline.query("What is the main topic of these documents?", collection_name=collection_name)
    
    print("\nQuery:", result["query"])
    print("\nAnswer:", result["answer"])
    print("\nRelevant chunks:")
    for i, chunk in enumerate(result["chunks"]):
        print(f"\nChunk {i+1}:")
        print(chunk.page_content[:200] + "..." if len(chunk.page_content) > 200 else chunk.page_content)

Loaded 98 documents from /home/vidhij2/nivi/documents/9789240020306-eng.pdf
Split into 289 chunks
Stored 289 chunks in collection 'collection_283f791b'


TypeError: list indices must be integers or slices, not str

In [7]:
result

[Document(id='6e71a534-7d45-44f9-adf2-d833ef029dac', metadata={'producer': 'Adobe PDF Library 15.0', 'creator': 'Adobe InDesign 16.1 (Macintosh)', 'creationdate': '2021-03-03T11:04:16-08:00', 'moddate': '2021-03-03T11:04:30-08:00', 'trapped': '/False', 'source': '/home/vidhij2/nivi/documents/9789240020306-eng.pdf', 'total_pages': 98, 'page': 8, 'page_label': '1'}, page_content='1\noverview\nOverview\nOverviewPart'),
 Document(id='2a74957c-7e1a-47fe-ac4c-eb8e128ef0ba', metadata={'producer': 'Adobe PDF Library 15.0', 'creator': 'Adobe InDesign 16.1 (Macintosh)', 'creationdate': '2021-03-03T11:04:16-08:00', 'moddate': '2021-03-03T11:04:30-08:00', 'trapped': '/False', 'source': '/home/vidhij2/nivi/documents/9789240020306-eng.pdf', 'total_pages': 98, 'page': 93, 'page_label': '86'}, page_content='used for any other notes, annotations, or communication messages within the team.\nConcept mappings Depending on which systems you plan on interoperating with, other columns will most likely need t

In [1]:
import os
import uuid
import torch
from typing import List, Dict, Any, Optional
from transformers import pipeline
import numpy as np
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import TextLoader, PyPDFLoader, DirectoryLoader
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document
from langchain.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from dotenv import load_dotenv
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load environment variables
load_dotenv()


class RAGPipeline:
    """
    A complete Retrieval-Augmented Generation pipeline with document loading, chunking,
    embedding, database storage, and querying capabilities.
    """

    def __init__(
        self,
        embedding_model: str = "BAAI/bge-large-en-v1.5",
        llm_model: str = "google/flan-t5-base",
        db_directory: str = "db",
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        model_kwargs: dict = {"device": "cpu"},
        encode_kwargs: dict = {"normalize_embeddings": True},
        use_gpu: bool = None
    ):
        """
        Initialize the RAG pipeline.
        
        Args:
            embedding_model: The Hugging Face model to use for embeddings (can be Llama or any other compatible model)
            llm_model: The Hugging Face model to use for generation
            db_directory: Directory to store the vector database
            chunk_size: Size of text chunks
            chunk_overlap: Overlap between chunks
            model_kwargs: Additional keyword arguments for the embedding model
            encode_kwargs: Additional keyword arguments for the encoding process
            use_gpu: Whether to use GPU for LLM. If None, will auto-detect GPU.
        """
        self.embedding_model = embedding_model
        self.llm_model = "/data/models/huggingface/meta-llama/Llama-3.2-3B-Instruct"
        self.db_directory = db_directory
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        
        # Determine device
        if use_gpu is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
        
        # Initialize embeddings with Hugging Face (compatible with Llama models)
        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model,
            model_kwargs=model_kwargs,
            encode_kwargs=encode_kwargs
        )
        
        # Initialize local LLM using Hugging Face models
        tokenizer = AutoTokenizer.from_pretrained(llm_model)
        model = AutoModelForCausalLM.from_pretrained(llm_model
)
        
        # Create text generation pipeline with LLM
        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.15
        )
        
        # Create LangChain wrapper around the pipeline
        self.llm = HuggingFacePipeline(pipeline=pipe)
        
        # Initialize text splitter
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, 
            chunk_overlap=chunk_overlap
        )
        
        # Initialize vectorstore if it exists
        if os.path.exists(os.path.join(db_directory, "index.faiss")):
            self.db = FAISS.load_local(
                folder_path=db_directory,
                embeddings=self.embeddings,
                allow_dangerous_deserialization=True
            )
        else:
            self.db = None

    def load_documents(self, file_path: str) -> List[Document]:
        """
        Load documents from a file or directory.
        
        Args:
            file_path: Path to file or directory
            
        Returns:
            List of loaded documents
        """
        if os.path.isdir(file_path):
            # Load from directory
            loader = DirectoryLoader(
                file_path,
                glob="**/*.*",
                loader_cls=self._get_loader_for_extension
            )
            documents = loader.load()
        else:
            # Load single file
            extension = os.path.splitext(file_path)[1].lower()
            loader_cls = self._get_loader_for_extension(extension)
            loader = loader_cls(file_path)
            documents = loader.load()
            
        print(f"Loaded {len(documents)} documents from {file_path}")
        return documents

    def _get_loader_for_extension(self, extension: str):
        """
        Get the appropriate document loader for a file extension.
        
        Args:
            extension: File extension
            
        Returns:
            Document loader class
        """
        if extension == '.pdf':
            return PyPDFLoader
        else:
            # Default to text loader
            return TextLoader

    def process_documents(self, documents: List[Document]) -> List[Document]:
        """
        Split documents into chunks.
        
        Args:
            documents: List of documents to process
            
        Returns:
            List of document chunks
        """
        chunks = self.text_splitter.split_documents(documents)
        print(f"Split into {len(chunks)} chunks")
        return chunks

    def store_documents(self, chunks: List[Document], collection_name: Optional[str] = None):
        """
        Store document chunks in the vector database.
        
        Args:
            chunks: Document chunks to store
            collection_name: Optional name for the collection
        """
        # Generate a collection name if not provided
        if collection_name is None:
            collection_name = f"collection_{uuid.uuid4().hex[:8]}"
            
        # Initialize FAISS index from documents
        self.db = FAISS.from_documents(
            documents=chunks,
            embedding=self.embeddings
        )
        
        # Create directory if it doesn't exist
        os.makedirs(self.db_directory, exist_ok=True)
        
        # Save the index to disk with collection name in the path
        index_path = os.path.join(self.db_directory, collection_name)
        self.db.save_local(index_path)
        
        print(f"Stored {len(chunks)} chunks in collection '{collection_name}'")
        
        return collection_name

    def retrieve_chunks(
        self, 
        query: str, 
        n_results: int = 4,
        collection_name: Optional[str] = None
    ) -> List[Document]:
        """
        Retrieve relevant document chunks for a query.
        
        Args:
            query: The query string
            n_results: Number of chunks to retrieve
            collection_name: Optional collection to search in
            
        Returns:
            List of relevant document chunks
        """
        if self.db is None:
            raise ValueError("No database has been created or loaded.")
            
        # If collection name is provided, load that collection
        if collection_name:
            index_path = os.path.join(self.db_directory, collection_name)
            if os.path.exists(index_path):
                db = FAISS.load_local(
                    folder_path=index_path,
                    embeddings=self.embeddings,
                    allow_dangerous_deserialization=True
                )
            else:
                raise ValueError(f"Collection '{collection_name}' not found.")
        else:
            db = self.db
            
        # Retrieve chunks
        chunks = db.similarity_search(query, k=n_results)
        return chunks

    def query(
        self, 
        query: str, 
        n_results: int = 4,
        collection_name: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Perform a query using the RAG pipeline.
        
        Args:
            query: The query string
            n_results: Number of chunks to retrieve
            collection_name: Optional collection to search in
            
        Returns:
            Dictionary with the query result and relevant chunks
        """
        if self.db is None:
            raise ValueError("No database has been created or loaded.")
            
        # If collection name is provided, load that collection
        if collection_name:
            index_path = os.path.join(self.db_directory, collection_name)
            if os.path.exists(index_path):
                db = FAISS.load_local(
                    folder_path=index_path,
                    embeddings=self.embeddings,
                    allow_dangerous_deserialization=True
                )
            else:
                raise ValueError(f"Collection '{collection_name}' not found.")
        else:
            db = self.db
            
        # Create a retriever
        retriever = db.as_retriever(search_kwargs={"k": n_results})
        
        # Create a prompt template
        template = """
        You are an assistant that answers questions based on the provided context.
        
        Context:
        {context}
        
        Question:
        {question}
        
        Answer:
        """
        
        prompt = PromptTemplate(
            input_variables=["context", "question"],
            template=template
        )
        
        # Create a QA chain
        qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=retriever,
            chain_type_kwargs={"prompt": prompt}
        )
        
        # Execute the query
        result = qa_chain({"query": query})
        
        # Also retrieve the chunks for reference
        chunks = self.retrieve_chunks(query, n_results, collection_name)
        
        return {
            "query": query,
            "answer": result["result"],
            "chunks": chunks
        }

    def ingest_and_store(self, file_path: str, collection_name: Optional[str] = None) -> str:
        """
        Complete pipeline to ingest, process, and store documents.
        
        Args:
            file_path: Path to file or directory to ingest
            collection_name: Optional name for the collection
            
        Returns:
            Collection name
        """
        # Load documents
        documents = self.load_documents(file_path)
        
        # Process into chunks
        chunks = self.process_documents(documents)
        
        # Store in database
        collection_name = self.store_documents(chunks, collection_name)
        
        return collection_name


# Example usage
if __name__ == "__main__":
    # Initialize the RAG pipeline with local models
    pipeline = RAGPipeline(
        # Use smaller embedding model for faster processing
        embedding_model="sentence-transformers/all-MiniLM-L6-v2",
        # Options for LLM:
        # - "google/flan-t5-base" (smaller, faster)
        # - "google/flan-t5-large" (better quality)
        # - "tiiuae/falcon-7b-instruct" (more capable but requires more GPU memory)
        # - "TheBloke/Llama-2-7B-Chat-GGML" (quantized Llama model, good for CPU)
        llm_model="/data/models/huggingface/meta-llama/Llama-3.2-3B-Instruct",
        model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"}
    )
    
    # Ingest and store documents
    

  from .autonotebook import tqdm as notebook_tqdm
  self.embeddings = HuggingFaceEmbeddings(
Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.31s/it]
Device set to use cuda:0
  self.llm = HuggingFacePipeline(pipeline=pipe)


In [3]:
collection_name = pipeline.ingest_and_store("/home/vidhij2/nivi/documents/")
    
    # Query the documents
result = pipeline.query("What is the main topic of these documents?", collection_name=collection_name)

print("\nQuery:", result["query"])
print("\nAnswer:", result["answer"])
print("\nRelevant chunks:")
for i, chunk in enumerate(result["chunks"]):
    print(f"\nChunk {i+1}:")
    print(chunk.page_content[:200] + "..." if len(chunk.page_content) > 200 else chunk.page_content)

Error loading file /home/vidhij2/nivi/documents/Care During Pregnancy and Childbirth Training Manual for CHO at AB-HWC.pdf


TypeError: TextLoader.lazy_load() missing 1 required positional argument: 'self'

In [12]:
tokenizer = AutoTokenizer.from_pretrained("/data/models/huggingface/meta-llama/Llama-3.2-3B-Instruct")
model = AutoModelForCausalLM.from_pretrained(
    "/data/models/huggingface/meta-llama/Llama-3.2-3B-Instruct")

Loading checkpoint shards: 100%|██████████| 2/2 [01:45<00:00, 52.81s/it]


In [4]:
llm_model = "/data/models/huggingface/meta-llama/Llama-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(llm_model)
model = AutoModelForCausalLM.from_pretrained(llm_model)

Loading checkpoint shards: 100%|██████████| 2/2 [00:06<00:00,  3.49s/it]


Llama 3.2 3b and pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb embedding in the database

In [None]:
import os
import re
import uuid
import torch
import numpy as np
from typing import List, Dict, Any, Optional, Tuple, Union
from pathlib import Path

# Document loading and processing
import PyPDF2
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import sent_tokenize

# Vector database and embeddings
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter

# LLM and generation
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate


class MedicalPDFProcessor:
    """Process medical PDFs with specialized techniques for handling medical content."""
    
    def __init__(self):
        # Download necessary NLTK resources
        try:
            nltk.data.find('tokenizers/punkt')
        except LookupError:
            nltk.download('punkt', quiet=True)
        
        try:
            nltk.data.find('corpora/stopwords')
        except LookupError:
            nltk.download('stopwords', quiet=True)
            
        self.stop_words = set(stopwords.words('english'))
        
        # Medical-specific abbreviations and terms
        self.medical_abbreviations = {
            "pt": "patient", "pts": "patients", "dx": "diagnosis", 
            "tx": "treatment", "hx": "history", "fx": "fracture",
            "sx": "symptoms", "rx": "prescription", "appt": "appointment",
            "vs": "vital signs", "yo": "year old", "y/o": "year old",
            "labs": "laboratory tests", "hpi": "history of present illness",
            "w/": "with", "s/p": "status post", "c/o": "complains of",
            "p/w": "presents with", "h/o": "history of", "f/u": "follow up"
        }
        
    def extract_text_from_pdf(self, pdf_path: str) -> str:
        """Extract text from a PDF file with medical-specific preprocessing."""
        with open(pdf_path, 'rb') as file:
            reader = PyPDF2.PdfReader(file)
            text = ""
            
            # Extract text from each page
            for page in reader.pages:
                text += page.extract_text() + "\n"
                
        # Basic cleaning
        text = self._clean_text(text)
        
        return text
    
    def _clean_text(self, text: str) -> str:
        """Clean and normalize medical text."""
        # Remove extra whitespace
        text = re.sub(r'\s+', ' ', text).strip()
        
        # Expand common medical abbreviations
        for abbr, expansion in self.medical_abbreviations.items():
            # Only replace when it's a whole word (with word boundaries)
            text = re.sub(r'\b' + re.escape(abbr) + r'\b', expansion, text, flags=re.IGNORECASE)
            
        # Normalize spacing after periods for better sentence splitting
        text = re.sub(r'\.(?! )', '. ', text)
        
        return text
    
    def split_into_sections(self, text: str) -> List[str]:
        """Split medical document into logical sections based on common headers."""
        common_sections = [
            "History", "Physical Examination", "Assessment", "Plan", "Diagnosis",
            "Chief Complaint", "Past Medical History", "Medications", "Allergies",
            "Family History", "Social History", "Review of Systems", "Labs",
            "Imaging", "Discussion", "Conclusion", "Recommendations"
        ]
        
        # Create regex pattern for section headers
        pattern = r'(?i)(?:^|\n)(' + '|'.join(re.escape(s) for s in common_sections) + r')(?::|:)?\s*(?:\n|\s)'
        
        # Find all section headers with their positions
        matches = list(re.finditer(pattern, text))
        
        sections = []
        
        # Extract each section
        for i, match in enumerate(matches):
            start = match.start()
            end = matches[i+1].start() if i < len(matches) - 1 else len(text)
            
            # Get the section header and content
            header = match.group(1)
            content = text[start:end].strip()
            
            # Add the section
            sections.append(f"{header}:\n{content}")
            
        # If no sections were identified, return the whole text as one section
        if not sections:
            sections = [text]
            
        return sections
    
    def process_pdf(self, pdf_path: str) -> List[Document]:
        """Process a medical PDF and return LangChain Document objects."""
        # Extract text
        text = self.extract_text_from_pdf(pdf_path)
        
        # Try to split into sections if possible
        sections = self.split_into_sections(text)
        
        # Create Document objects
        documents = []
        
        filename = os.path.basename(pdf_path)
        
        for i, section in enumerate(sections):
            # Create metadata to track source and section
            metadata = {
                "source": filename,
                "page": i,  # Using i as a proxy for page if real page info isn't available
                "section": section.split(":", 1)[0] if ":" in section else "General"
            }
            
            documents.append(Document(page_content=section, metadata=metadata))
            
        return documents


class MedicalRAGPipeline:
    """
    Retrieval-Augmented Generation pipeline specialized for medical documents.
    Uses FAISS for vector storage and optimized for medical domain content.
    """

    def __init__(
        self,
        embedding_model: str = "pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb",
        llm_model: str = "/data/models/huggingface/meta-llama/Llama-3.2-3B-Instruct",
        db_directory: str = "medical_db",
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        use_gpu: bool = None
    ):
        """
        Initialize the Medical RAG pipeline.
        
        Args:
            embedding_model: Hugging Face model for embeddings (preferably biomedical)
            llm_model: Hugging Face model for generation (preferably with medical knowledge)
            db_directory: Directory to store the FAISS database
            chunk_size: Size of document chunks
            chunk_overlap: Overlap between chunks
            use_gpu: Whether to use GPU. If None, will auto-detect.
        """
        self.embedding_model = embedding_model
        self.llm_model = "/data/models/huggingface/meta-llama/Llama-3.2-3B-Instruct"
        self.db_directory = db_directory
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        
        # Determine device
        if use_gpu is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
            
        print(f"Using device: {self.device}")
        
        # Initialize PDF processor
        self.pdf_processor = MedicalPDFProcessor()
        
        # Initialize embeddings - use biomedical-specific models if available
        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model,
            model_kwargs={"device": self.device},
            encode_kwargs={"normalize_embeddings": True}
        )
        
        # Initialize text splitter with medical domain settings
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=["\n\n", "\n", " ", ""]
        )
        
        # Initialize LLM - use a model with medical knowledge if possible
        try:
            tokenizer = AutoTokenizer.from_pretrained(llm_model)
            model = AutoModelForCausalLM.from_pretrained(
                llm_model
            )
            
            # Create text generation pipeline
            pipe = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                max_new_tokens=512,
                temperature=0.1,  # Lower temperature for medical accuracy
                top_p=0.95,
                repetition_penalty=1.15
            )
            
            # Create LangChain wrapper
            self.llm = HuggingFacePipeline(pipeline=pipe)
            
        except Exception as e:
            print(f"Error loading LLM: {e}")
            print("Falling back to smaller model...")
            
            # Fallback to a smaller, more widely compatible model
            try:
                tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
                model = AutoModelForCausalLM.from_pretrained(
                    "google/flan-t5-base",
                    device_map=self.device,
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
                )
                
                pipe = pipeline(
                    "text-generation",
                    model=model,
                    tokenizer=tokenizer,
                    max_new_tokens=512,
                    temperature=0.1
                )
                
                self.llm = HuggingFacePipeline(pipeline=pipe)
                
            except Exception as inner_e:
                print(f"Error loading fallback model: {inner_e}")
                raise RuntimeError("Could not initialize LLM. Please check model compatibility.")
        
        # Initialize or load vectorstore if it exists
        index_path = os.path.join(db_directory, "index.faiss")
        if os.path.exists(index_path):
            self.db = FAISS.load_local(
                folder_path=db_directory,
                embeddings=self.embeddings,
                allow_dangerous_deserialization=True
            )
        else:
            self.db = None

    def process_pdf(self, pdf_path: str) -> List[Document]:
        """
        Process a medical PDF into Document objects.
        
        Args:
            pdf_path: Path to the PDF file
            
        Returns:
            List of processed Document objects
        """
        # Extract raw documents from PDF
        raw_documents = self.pdf_processor.process_pdf(pdf_path)
        
        # Split into chunks
        chunks = self.text_splitter.split_documents(raw_documents)
        
        print(f"Processed {pdf_path} into {len(chunks)} chunks")
        
        return chunks

    def process_pdf_directory(self, directory_path: str) -> List[Document]:
        """
        Process all PDFs in a directory.
        
        Args:
            directory_path: Path to directory containing PDFs
            
        Returns:
            List of all Document chunks from all PDFs
        """
        all_chunks = []
        
        # Get all PDF files in the directory
        pdf_files = [f for f in os.listdir(directory_path) if f.lower().endswith('.pdf')]
        
        if not pdf_files:
            print(f"No PDF files found in {directory_path}")
            return all_chunks
        
        print(f"Found {len(pdf_files)} PDF files")
        
        # Process each PDF
        for pdf_file in pdf_files:
            pdf_path = os.path.join(directory_path, pdf_file)
            chunks = self.process_pdf(pdf_path)
            all_chunks.extend(chunks)
            
        return all_chunks

    def store_documents(self, chunks: List[Document], collection_name: Optional[str] = None) -> str:
        """
        Store document chunks in FAISS.
        
        Args:
            chunks: Document chunks to store
            collection_name: Optional name for the collection
            
        Returns:
            Collection name
        """
        # Generate a collection name if not provided
        if collection_name is None:
            collection_name = f"medical_{uuid.uuid4().hex[:8]}"
        
        # Create database directory if it doesn't exist
        index_path = os.path.join(self.db_directory, collection_name)
        os.makedirs(index_path, exist_ok=True)
        
        # Initialize FAISS index from documents
        self.db = FAISS.from_documents(
            documents=chunks,
            embedding=self.embeddings
        )
        
        # Save to disk
        self.db.save_local(index_path)
        
        print(f"Stored {len(chunks)} chunks in collection '{collection_name}'")
        
        return collection_name

    def retrieve_chunks(
        self, 
        query: str, 
        n_results: int = 5,  # Return more results for medical context
        collection_name: Optional[str] = None
    ) -> List[Document]:
        """
        Retrieve relevant document chunks for a medical query.
        
        Args:
            query: The query string
            n_results: Number of chunks to retrieve
            collection_name: Optional collection to search in
            
        Returns:
            List of relevant document chunks
        """
        if self.db is None:
            raise ValueError("No database has been created or loaded.")
            
        # If collection name is provided, load that collection
        if collection_name:
            index_path = os.path.join(self.db_directory, collection_name)
            if os.path.exists(index_path):
                db = FAISS.load_local(
                    folder_path=index_path,
                    embeddings=self.embeddings,
                    allow_dangerous_deserialization=True
                )
            else:
                raise ValueError(f"Collection '{collection_name}' not found.")
        else:
            db = self.db
            
        # Retrieve chunks with MMR for diversity
        chunks = db.max_marginal_relevance_search(
            query, 
            k=n_results,
            fetch_k=n_results*2  # Fetch more candidates for diversity
        )
        
        return chunks

    def query(
        self, 
        query: str, 
        n_results: int = 5,
        collection_name: Optional[str] = None,
        use_mmr: bool = True
    ) -> Dict[str, Any]:
        """
        Perform a medical query using the RAG pipeline.
        
        Args:
            query: The medical query string
            n_results: Number of chunks to retrieve
            collection_name: Optional collection to search in
            use_mmr: Whether to use Maximum Marginal Relevance for diverse results
            
        Returns:
            Dictionary with the query result and relevant chunks
        """
        if self.db is None:
            raise ValueError("No database has been created or loaded.")
            
        # If collection name is provided, load that collection
        if collection_name:
            index_path = os.path.join(self.db_directory, collection_name)
            if os.path.exists(index_path):
                db = FAISS.load_local(
                    folder_path=index_path,
                    embeddings=self.embeddings,
                    allow_dangerous_deserialization=True
                )
            else:
                raise ValueError(f"Collection '{collection_name}' not found.")
        else:
            db = self.db
            
        # Create the appropriate retriever
        search_kwargs = {"k": n_results}
        if use_mmr:
            search_type = "mmr"
            search_kwargs["fetch_k"] = n_results * 2  # Fetch more for diversity
        else:
            search_type = "similarity"
            
        retriever = db.as_retriever(
            search_type=search_type,
            search_kwargs=search_kwargs
        )
        
        # Create a medical-specific prompt template
        template = """
        You are a medical AI assistant. Answer the medical question based only on the following context.
        If you don't know the answer based on the context, admit that you don't know rather than making up information.
        Always maintain patient confidentiality and provide evidence-based answers when possible.
        
        Context:
        {context}
        
        Medical Question:
        {question}
        
        Answer:
        """
        
        prompt = PromptTemplate(
            input_variables=["context", "question"],
            template=template
        )
        
        # Create a QA chain
        qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=retriever,
            chain_type_kwargs={"prompt": prompt}
        )
        
        # Execute the query
        result = qa_chain({"query": query})
        
        # Also retrieve the chunks for reference
        chunks = self.retrieve_chunks(query, n_results, collection_name)
        
        return {
            "query": query,
            "answer": result["result"],
            "chunks": chunks
        }
        
    def ingest_and_store(
        self, 
        pdf_path: str, 
        collection_name: Optional[str] = None
    ) -> str:
        """
        Complete pipeline to ingest, process, and store PDF documents.
        
        Args:
            pdf_path: Path to PDF file or directory containing PDFs
            collection_name: Optional name for the collection
            
        Returns:
            Collection name
        """
        # Check if path is file or directory
        if os.path.isdir(pdf_path):
            # Process directory of PDFs
            chunks = self.process_pdf_directory(pdf_path)
        else:
            # Process single PDF
            chunks = self.process_pdf(pdf_path)
            
        if not chunks:
            raise ValueError(f"No content could be extracted from {pdf_path}")
            
        # Store chunks in FAISS
        collection_name = self.store_documents(chunks, collection_name)
        
        return collection_name


# Example usage
if __name__ == "__main__":
    # Initialize the Medical RAG pipeline
    pipeline = MedicalRAGPipeline(
        # Use PubMedBERT or BioBERT embeddings for medical content
        embedding_model="pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb",
        # Use a medically trained LLM when possible
        llm_model="/data/models/huggingface/meta-llama/Llama-3.2-3B-Instruct",  # Not medical-specific but generally competent
    )
    
    # Ingest a directory of medical PDFs
    collection_name = pipeline.ingest_and_store("/home/vidhij2/nivi/documents")
    
    # Query the system
    result = pipeline.query(
        "What are the treatment options for Type 2 Diabetes?", 
        collection_name=collection_name
    )
    
    print("\nQuery:", result["query"])
    print("\nAnswer:", result["answer"])
    print("\nRelevant evidence:")
    for i, chunk in enumerate(result["chunks"][:3]):  # Show top 3 chunks
        print(f"\nSource {i+1}: {chunk.metadata.get('source', 'Unknown')}")
        print(f"Section: {chunk.metadata.get('section', 'Unknown')}")
        print(chunk.page_content[:200] + "..." if len(chunk.page_content) > 200 else chunk.page_content)

Using device: cuda


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Loading checkpoint shards: 100%|██████████| 2/2 [01:41<00:00, 50.56s/it]
Device set to use cuda:0
  self.llm = HuggingFacePipeline(pipeline=pipe)


Found 4 PDF files
Processed /home/vidhij2/nivi/documents/Care During Pregnancy and Childbirth Training Manual for CHO at AB-HWC.pdf into 183 chunks
Processed /home/vidhij2/nivi/documents/By-gestational month cards.pdf into 34 chunks
Processed /home/vidhij2/nivi/documents/9789240020306-eng.pdf into 253 chunks
Processed /home/vidhij2/nivi/documents/9789240045989-eng.pdf into 1003 chunks


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Stored 1473 chunks in collection 'medical_aecfc5bd'

Query: What are the treatment options for Type 2 Diabetes?

Answer: 
        You are a medical AI assistant. Answer the medical question based only on the following context.
        If you don't know the answer based on the context, admit that you don't know rather than making up information.
        Always maintain patient confidentiality and provide evidence-based answers when possible.
        
        Context:
        to use either pharmacological or non-pharmacological interventions as they adhere to strict dietary routines associated with traditional postnatal practices (low confidence in the evidence). Additional considerations Indirect evidence from a qualitative evidence synthesis exploring uptake of antenatal care (80) indicates that women in a variety of LMICs are more likely to turn to traditional healers, herbal remedies, or traditional birth attendants to treat constipation (moderate confidence). Feasibility A qualitati

In [10]:
collection_name = pipeline.ingest_and_store("/home/vidhij2/nivi/documents/")
    
    # Query the system
result = pipeline.query(
    "What are the treatment options for Type 2 Diabetes?", 
    collection_name=collection_name
)

print("\nQuery:", result["query"])
print("\nAnswer:", result["answer"])
print("\nRelevant evidence:")
for i, chunk in enumerate(result["chunks"][:3]):  # Show top 3 chunks
    print(f"\nSource {i+1}: {chunk.metadata.get('source', 'Unknown')}")
    print(f"Section: {chunk.metadata.get('section', 'Unknown')}")
    print(chunk.page_content[:200] + "..." if len(chunk.page_content) > 200 else chunk.page_content)

Found 4 PDF files
Processed /home/vidhij2/nivi/documents/Care During Pregnancy and Childbirth Training Manual for CHO at AB-HWC.pdf into 183 chunks
Processed /home/vidhij2/nivi/documents/By-gestational month cards.pdf into 34 chunks
Processed /home/vidhij2/nivi/documents/9789240020306-eng.pdf into 253 chunks
Processed /home/vidhij2/nivi/documents/9789240045989-eng.pdf into 1003 chunks


  result = qa_chain({"query": query})
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Stored 1473 chunks in collection 'medical_a0334338'

Query: What are the treatment options for Type 2 Diabetes?

Answer: 
        You are a medical AI assistant. Answer the medical question based only on the following context.
        If you don't know the answer based on the context, admit that you don't know rather than making up information.
        Always maintain patient confidentiality and provide evidence-based answers when possible.
        
        Context:
        to use either pharmacological or non-pharmacological interventions as they adhere to strict dietary routines associated with traditional postnatal practices (low confidence in the evidence). Additional considerations Indirect evidence from a qualitative evidence synthesis exploring uptake of antenatal care (80) indicates that women in a variety of LMICs are more likely to turn to traditional healers, herbal remedies, or traditional birth attendants to treat constipation (moderate confidence). Feasibility A qualitati

In [12]:
result = pipeline.query(
    "Is it normal to have morning sickness all day during pregnancy?", 
    collection_name=collection_name
)

print("\nQuery:", result["query"])
print("\nAnswer:", result["answer"])
print("\nRelevant evidence:")
for i, chunk in enumerate(result["chunks"][:3]):  # Show top 3 chunks
    print(f"\nSource {i+1}: {chunk.metadata.get('source', 'Unknown')}")
    print(f"Section: {chunk.metadata.get('section', 'Unknown')}")
    print(chunk.page_content[:200] + "..." if len(chunk.page_content) > 200 else chunk.page_content)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.



Query: Is it normal to have morning sickness all day during pregnancy?

Answer: 
        You are a medical AI assistant. Answer the medical question based only on the following context.
        If you don't know the answer based on the context, admit that you don't know rather than making up information.
        Always maintain patient confidentiality and provide evidence-based answers when possible.
        
        Context:
        which appears in the evening and disappears in the morning after a full night’s sleep, could be a normal manifestation of pregnancy. ∙Any oedema of the face, hands, abdominal wall, and vulva is abnormal. Oedema can be suspected if a woman complains of abnormal tightening of any rings on her fingers. ∙If there is oedema in association with high blood pressure, heart disease, anaemia or proteinuria, the woman should be referred to FRU. ∙Non-pitting oedema indicates hypothyroidism or filariasis and requires immediate referral to FRU for investigations. 20Mea

In [6]:
pip install nltk

Collecting nltk
  Downloading nltk-3.9.1-py3-none-any.whl.metadata (2.9 kB)
Downloading nltk-3.9.1-py3-none-any.whl (1.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m19.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: nltk
Successfully installed nltk-3.9.1
Note: you may need to restart the kernel to use updated packages.


# OpenAI and pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb embedding in the database

In [3]:
import os
import re
import uuid
import torch
import numpy as np
from typing import List, Dict, Any, Optional, Tuple, Union
from pathlib import Path

# Document loading and processing
import PyPDF2
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import sent_tokenize

# Vector database and embeddings
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chat_models import ChatOpenAI
# LLM and generation
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from langchain.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate


class MedicalPDFProcessor:
    """Process medical PDFs with specialized techniques for handling medical content."""
    
    def __init__(self):
        # Download necessary NLTK resources
        try:
            nltk.data.find('tokenizers/punkt')
        except LookupError:
            nltk.download('punkt', quiet=True)
        
        try:
            nltk.data.find('corpora/stopwords')
        except LookupError:
            nltk.download('stopwords', quiet=True)
            
        self.stop_words = set(stopwords.words('english'))
        
        # Medical-specific abbreviations and terms
        self.medical_abbreviations = {
            "pt": "patient", "pts": "patients", "dx": "diagnosis", 
            "tx": "treatment", "hx": "history", "fx": "fracture",
            "sx": "symptoms", "rx": "prescription", "appt": "appointment",
            "vs": "vital signs", "yo": "year old", "y/o": "year old",
            "labs": "laboratory tests", "hpi": "history of present illness",
            "w/": "with", "s/p": "status post", "c/o": "complains of",
            "p/w": "presents with", "h/o": "history of", "f/u": "follow up"
        }
        
    def extract_text_from_pdf(self, pdf_path: str) -> str:
        """Extract text from a PDF file with medical-specific preprocessing."""
        with open(pdf_path, 'rb') as file:
            reader = PyPDF2.PdfReader(file)
            text = ""
            
            # Extract text from each page
            for page in reader.pages:
                text += page.extract_text() + "\n"
                
        # Basic cleaning
        text = self._clean_text(text)
        
        return text
    
    def _clean_text(self, text: str) -> str:
        """Clean and normalize medical text."""
        # Remove extra whitespace
        text = re.sub(r'\s+', ' ', text).strip()
        
        # Expand common medical abbreviations
        for abbr, expansion in self.medical_abbreviations.items():
            # Only replace when it's a whole word (with word boundaries)
            text = re.sub(r'\b' + re.escape(abbr) + r'\b', expansion, text, flags=re.IGNORECASE)
            
        # Normalize spacing after periods for better sentence splitting
        text = re.sub(r'\.(?! )', '. ', text)
        
        return text
    
    def split_into_sections(self, text: str) -> List[str]:
        """Split medical document into logical sections based on common headers."""
        common_sections = [
            "History", "Physical Examination", "Assessment", "Plan", "Diagnosis",
            "Chief Complaint", "Past Medical History", "Medications", "Allergies",
            "Family History", "Social History", "Review of Systems", "Labs",
            "Imaging", "Discussion", "Conclusion", "Recommendations"
        ]
        
        # Create regex pattern for section headers
        pattern = r'(?i)(?:^|\n)(' + '|'.join(re.escape(s) for s in common_sections) + r')(?::|:)?\s*(?:\n|\s)'
        
        # Find all section headers with their positions
        matches = list(re.finditer(pattern, text))
        
        sections = []
        
        # Extract each section
        for i, match in enumerate(matches):
            start = match.start()
            end = matches[i+1].start() if i < len(matches) - 1 else len(text)
            
            # Get the section header and content
            header = match.group(1)
            content = text[start:end].strip()
            
            # Add the section
            sections.append(f"{header}:\n{content}")
            
        # If no sections were identified, return the whole text as one section
        if not sections:
            sections = [text]
            
        return sections
    
    def process_pdf(self, pdf_path: str) -> List[Document]:
        """Process a medical PDF and return LangChain Document objects."""
        # Extract text
        text = self.extract_text_from_pdf(pdf_path)
        
        # Try to split into sections if possible
        sections = self.split_into_sections(text)
        
        # Create Document objects
        documents = []
        
        filename = os.path.basename(pdf_path)
        
        for i, section in enumerate(sections):
            # Create metadata to track source and section
            metadata = {
                "source": filename,
                "page": i,  # Using i as a proxy for page if real page info isn't available
                "section": section.split(":", 1)[0] if ":" in section else "General"
            }
            
            documents.append(Document(page_content=section, metadata=metadata))
            
        return documents


class MedicalRAGPipeline:
    """
    Retrieval-Augmented Generation pipeline specialized for medical documents.
    Uses FAISS for vector storage and optimized for medical domain content.
    """

    def __init__(
        self,
        embedding_model: str = "pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb",
        llm_model: str = "gpt-4-turbo",
        db_directory: str = "medical_db",
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        use_gpu: bool = None
    ):
        """
        Initialize the Medical RAG pipeline.
        
        Args:
            embedding_model: Hugging Face model for embeddings (preferably biomedical)
            llm_model: Hugging Face model for generation (preferably with medical knowledge)
            db_directory: Directory to store the FAISS database
            chunk_size: Size of document chunks
            chunk_overlap: Overlap between chunks
            use_gpu: Whether to use GPU. If None, will auto-detect.
        """
        self.embedding_model = embedding_model
        self.llm_model = "gpt-4-turbo"
        self.db_directory = db_directory
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        
        self.llm = ChatOpenAI(model_name= self.llm_model)
        # Determine device
        if use_gpu is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
            
        print(f"Using device: {self.device}")
        
        # Initialize PDF processor
        self.pdf_processor = MedicalPDFProcessor()
        
        # Initialize embeddings - use biomedical-specific models if available
        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model,
            model_kwargs={"device": self.device},
            encode_kwargs={"normalize_embeddings": True}
        )
        
        # Initialize text splitter with medical domain settings
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=["\n\n", "\n", " ", ""]
        )
        
        # Initialize LLM - use a model with medical knowledge if possible
        # try:
        #     tokenizer = AutoTokenizer.from_pretrained(llm_model)
        #     model = AutoModelForCausalLM.from_pretrained(
        #         llm_model
        #     )
            
        #     # Create text generation pipeline
        #     pipe = pipeline(
        #         "text-generation",
        #         model=model,
        #         tokenizer=tokenizer,
        #         max_new_tokens=512,
        #         temperature=0.1,  # Lower temperature for medical accuracy
        #         top_p=0.95,
        #         repetition_penalty=1.15
        #     )
            
        #     # Create LangChain wrapper
        #     self.llm = HuggingFacePipeline(pipeline=pipe)
            
        # except Exception as e:
        #     print(f"Error loading LLM: {e}")
        #     print("Falling back to smaller model...")
            
        #     # Fallback to a smaller, more widely compatible model
        #     try:
        #         tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
        #         model = AutoModelForCausalLM.from_pretrained(
        #             "google/flan-t5-base",
        #             device_map=self.device,
        #             torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
        #         )
                
        #         pipe = pipeline(
        #             "text-generation",
        #             model=model,
        #             tokenizer=tokenizer,
        #             max_new_tokens=512,
        #             temperature=0.1
        #         )
                
        #         self.llm = HuggingFacePipeline(pipeline=pipe)
                
        #     except Exception as inner_e:
        #         print(f"Error loading fallback model: {inner_e}")
        #         raise RuntimeError("Could not initialize LLM. Please check model compatibility.")
        
        # Initialize or load vectorstore if it exists
        index_path = os.path.join(db_directory, "index.faiss")
        if os.path.exists(index_path):
            self.db = FAISS.load_local(
                folder_path=db_directory,
                embeddings=self.embeddings,
                allow_dangerous_deserialization=True
            )
        else:
            self.db = None

    def process_pdf(self, pdf_path: str) -> List[Document]:
        """
        Process a medical PDF into Document objects.
        
        Args:
            pdf_path: Path to the PDF file
            
        Returns:
            List of processed Document objects
        """
        # Extract raw documents from PDF
        raw_documents = self.pdf_processor.process_pdf(pdf_path)
        
        # Split into chunks
        chunks = self.text_splitter.split_documents(raw_documents)
        
        print(f"Processed {pdf_path} into {len(chunks)} chunks")
        
        return chunks

    def process_pdf_directory(self, directory_path: str) -> List[Document]:
        """
        Process all PDFs in a directory.
        
        Args:
            directory_path: Path to directory containing PDFs
            
        Returns:
            List of all Document chunks from all PDFs
        """
        all_chunks = []
        
        # Get all PDF files in the directory
        pdf_files = [f for f in os.listdir(directory_path) if f.lower().endswith('.pdf')]
        
        if not pdf_files:
            print(f"No PDF files found in {directory_path}")
            return all_chunks
        
        print(f"Found {len(pdf_files)} PDF files")
        
        # Process each PDF
        for pdf_file in pdf_files:
            pdf_path = os.path.join(directory_path, pdf_file)
            chunks = self.process_pdf(pdf_path)
            all_chunks.extend(chunks)
            
        return all_chunks

    def store_documents(self, chunks: List[Document], collection_name: Optional[str] = None) -> str:
        """
        Store document chunks in FAISS.
        
        Args:
            chunks: Document chunks to store
            collection_name: Optional name for the collection
            
        Returns:
            Collection name
        """
        # Generate a collection name if not provided
        if collection_name is None:
            collection_name = f"medical_{uuid.uuid4().hex[:8]}"
        
        # Create database directory if it doesn't exist
        index_path = os.path.join(self.db_directory, collection_name)
        os.makedirs(index_path, exist_ok=True)
        
        # Initialize FAISS index from documents
        self.db = FAISS.from_documents(
            documents=chunks,
            embedding=self.embeddings
        )
        
        # Save to disk
        self.db.save_local(index_path)
        
        print(f"Stored {len(chunks)} chunks in collection '{collection_name}'")
        
        return collection_name

    def retrieve_chunks(
        self, 
        query: str, 
        n_results: int = 5,  # Return more results for medical context
        collection_name: Optional[str] = None
    ) -> List[Document]:
        """
        Retrieve relevant document chunks for a medical query.
        
        Args:
            query: The query string
            n_results: Number of chunks to retrieve
            collection_name: Optional collection to search in
            
        Returns:
            List of relevant document chunks
        """
        if self.db is None:
            raise ValueError("No database has been created or loaded.")
            
        # If collection name is provided, load that collection
        if collection_name:
            index_path = os.path.join(self.db_directory, collection_name)
            if os.path.exists(index_path):
                db = FAISS.load_local(
                    folder_path=index_path,
                    embeddings=self.embeddings,
                    allow_dangerous_deserialization=True
                )
            else:
                raise ValueError(f"Collection '{collection_name}' not found.")
        else:
            db = self.db
            
        # Retrieve chunks with MMR for diversity
        chunks = db.max_marginal_relevance_search(
            query, 
            k=n_results,
            fetch_k=n_results*2  # Fetch more candidates for diversity
        )
        
        return chunks

    def query(
        self, 
        query: str, 
        n_results: int = 5,
        collection_name: Optional[str] = None,
        use_mmr: bool = True
    ) -> Dict[str, Any]:
        """
        Perform a medical query using the RAG pipeline.
        
        Args:
            query: The medical query string
            n_results: Number of chunks to retrieve
            collection_name: Optional collection to search in
            use_mmr: Whether to use Maximum Marginal Relevance for diverse results
            
        Returns:
            Dictionary with the query result and relevant chunks
        """
        if self.db is None:
            raise ValueError("No database has been created or loaded.")
            
        # If collection name is provided, load that collection
        if collection_name:
            index_path = os.path.join(self.db_directory, collection_name)
            if os.path.exists(index_path):
                db = FAISS.load_local(
                    folder_path=index_path,
                    embeddings=self.embeddings,
                    allow_dangerous_deserialization=True
                )
            else:
                raise ValueError(f"Collection '{collection_name}' not found.")
        else:
            db = self.db
            
        # Create the appropriate retriever
        search_kwargs = {"k": n_results}
        if use_mmr:
            search_type = "mmr"
            search_kwargs["fetch_k"] = n_results * 2  # Fetch more for diversity
        else:
            search_type = "similarity"
            
        retriever = db.as_retriever(
            search_type=search_type,
            search_kwargs=search_kwargs
        )
        
        # Create a medical-specific prompt template
        template = """
        You are a medical AI assistant. Answer the medical question based only on the following context.
        If you don't know the answer based on the context, admit that you don't know rather than making up information.
        Always maintain patient confidentiality and provide evidence-based answers when possible.
        
        Context:
        {context}
        
        Medical Question:
        {question}
        
        Answer:
        """
        
        prompt = PromptTemplate(
            input_variables=["context", "question"],
            template=template
        )
        
        # Create a QA chain
        qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=retriever,
            chain_type_kwargs={"prompt": prompt}
        )
        
        # Execute the query
        result = qa_chain({"query": query})
        
        # Also retrieve the chunks for reference
        chunks = self.retrieve_chunks(query, n_results, collection_name)
        
        return {
            "query": query,
            "answer": result["result"],
            "chunks": chunks
        }
        
    def ingest_and_store(
        self, 
        pdf_path: str, 
        collection_name: Optional[str] = None
    ) -> str:
        """
        Complete pipeline to ingest, process, and store PDF documents.
        
        Args:
            pdf_path: Path to PDF file or directory containing PDFs
            collection_name: Optional name for the collection
            
        Returns:
            Collection name
        """
        # Check if path is file or directory
        if os.path.isdir(pdf_path):
            # Process directory of PDFs
            chunks = self.process_pdf_directory(pdf_path)
        else:
            # Process single PDF
            chunks = self.process_pdf(pdf_path)
            
        if not chunks:
            raise ValueError(f"No content could be extracted from {pdf_path}")
            
        # Store chunks in FAISS
        collection_name = self.store_documents(chunks, collection_name)
        
        return collection_name


# Example usage
if __name__ == "__main__":
    # Initialize the Medical RAG pipeline
    pipeline = MedicalRAGPipeline(
        # Use PubMedBERT or BioBERT embeddings for medical content
        embedding_model="pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb",
        # Use a medically trained LLM when possible
        llm_model="gpt-4-turbo",  # Not medical-specific but generally competent
    )
    
    # Ingest a directory of medical PDFs
    collection_name = pipeline.ingest_and_store("/home/vidhij2/nivi/documents")
    
    # Query the system
    result = pipeline.query(
        "What are the treatment options for Type 2 Diabetes?", 
        collection_name=collection_name
    )
    
    print("\nQuery:", result["query"])
    print("\nAnswer:", result["answer"])
    print("\nRelevant evidence:")
    for i, chunk in enumerate(result["chunks"][:3]):  # Show top 3 chunks
        print(f"\nSource {i+1}: {chunk.metadata.get('source', 'Unknown')}")
        print(f"Section: {chunk.metadata.get('section', 'Unknown')}")
        print(chunk.page_content[:200] + "..." if len(chunk.page_content) > 200 else chunk.page_content)

Using device: cuda
Found 4 PDF files
Processed /home/vidhij2/nivi/documents/Care During Pregnancy and Childbirth Training Manual for CHO at AB-HWC.pdf into 183 chunks
Processed /home/vidhij2/nivi/documents/By-gestational month cards.pdf into 34 chunks
Processed /home/vidhij2/nivi/documents/9789240020306-eng.pdf into 253 chunks
Processed /home/vidhij2/nivi/documents/9789240045989-eng.pdf into 1003 chunks
Stored 1473 chunks in collection 'medical_9a54d0f7'


  result = qa_chain({"query": query})



Query: What are the treatment options for Type 2 Diabetes?

Answer: Based on the provided context, there is no specific information about the treatment options for Type 2 Diabetes. Therefore, I cannot provide an answer based solely on the given context. For comprehensive and accurate information regarding the treatment of Type 2 Diabetes, consulting current medical guidelines or a healthcare professional would be advisable.

Relevant evidence:

Source 1: 9789240045989-eng.pdf
Section: WHO recommendations on maternal and newborn care for a positive postnatal experience WHO recommendations on maternal and newborn care for a positive postnatal experience WHO recommendations on maternal and newborn care for a positive postnatal experience This publication is the update of the document published in 2014 entitled “WHO recommendations on postnatal care of the mother and newborn”. ISBN 978-92-4-004598-9 (electronic version) ISBN 978-92-4-004599-6 (print version) © World Health Organization 20

In [6]:
result = pipeline.query(
        "What is morning sickness", 
        collection_name=collection_name
    )
    
print("\nQuery:", result["query"])
print("\nAnswer:", result["answer"])
print("\nRelevant evidence:")
for i, chunk in enumerate(result["chunks"][:3]):  # Show top 3 chunks
    print(f"\nSource {i+1}: {chunk.metadata.get('source', 'Unknown')}")
    print(f"Section: {chunk.metadata.get('section', 'Unknown')}")
    print(chunk.page_content[:200] + "..." if len(chunk.page_content) > 200 else chunk.page_content)


Query: What is morning sickness

Answer: Morning sickness refers to nausea and vomiting that typically occurs during pregnancy, most commonly in the first trimester. It can happen at any time of the day but is often worse in the morning. While the exact cause of morning sickness is not known, it is believed to be related to the hormonal changes occurring in pregnancy. For most women, morning sickness subsides by the second trimester. However, if it is severe, it may require medical attention to ensure the health and hydration of the mother.

Relevant evidence:

Source 1: Care During Pregnancy and Childbirth Training Manual for CHO at AB-HWC.pdf
Section: 1Training Manual on Care During Pregnancy and Child Birth for Community Health Officer at Ayushman Bharat - Health and Wellness Centres 2021 2 3Table of Contents Page No. Chapter 1
which appears in the evening and disappears in the morning after a full night’s sleep, could be a normal manifestation of pregnancy. ∙Any oedema of the face

In [None]:
import os
import re
import json
import pickle
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
import argparse
from tqdm import tqdm

# LangChain imports
from langchain.document_loaders import TextLoader, PyPDFLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter, MarkdownHeaderTextSplitter
from langchain.docstore.document import Document
from langchain.schema import BaseDocumentTransformer
from langchain.vectorstores import FAISS
from langchain.embeddings import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
import os

class MedicalHeaderTextSplitter(MarkdownHeaderTextSplitter):
    """Custom text splitter for medical documents that respects section headers."""
    
    def __init__(self):
        headers_to_split_on = [
            ("#", "chapter"),
            ("##", "section"),
            ("###", "subsection"),
            ("####", "recommendation"),
            ("#####", "remarks")
        ]
        super().__init__(headers_to_split_on=headers_to_split_on)
        
    def _add_md_header(self, text):
        """Convert medical document headers to markdown format."""
        # Convert chapter headers
        text = re.sub(r'(?m)^(?:Chapter|CHAPTER)\s+(\d+)[\.:]?\s+(.+)$', r'# \1 \2', text)
        
        # Convert section headers
        text = re.sub(r'(?m)^(?:\d+\.\d+\.?\s+|\d+\.\s+)([A-Z][A-Za-z\s\-:]+)$', r'## \1', text)
        
        # Convert subsection headers
        text = re.sub(r'(?m)^(?:[A-Z]\.\d+\.?\s+|[A-Z]\.\s+)([A-Za-z][A-Za-z\s\-:]+)$', r'### \1', text)
        
        # Convert recommendation headers
        text = re.sub(
            r'(?m)^RECOMMENDATION\s+([A-Z0-9\.]+):\s+(.+?)(?:\((?:Recommended|Context-specific|Not recommended).*?\))?$', 
            r'#### RECOMMENDATION \1: \2', 
            text
        )
        
        # Convert remarks sections
        text = re.sub(r'(?m)^Remarks:$', r'##### Remarks:', text)
        
        return text

class MedicalEvidenceExtractor(BaseDocumentTransformer):
    """Extract evidence levels and recommendation types from medical text."""
    
    def __init__(self):
        self.evidence_pattern = re.compile(r'(?:high|moderate|low|very\s+low)(?:-|\s+)(?:quality|certainty)\s+evidence', re.IGNORECASE)
        self.recommendation_type_pattern = re.compile(r'\((Recommended|Context-specific recommendation|Not recommended).*?\)')
    
    def transform_documents(
        self, documents: List[Document], **kwargs
    ) -> List[Document]:
        """Extract evidence levels and enhance document metadata."""
        for doc in documents:
            # Only process if it's a recommendation
            if doc.metadata.get('heading_type') == 'recommendation':
                # Extract evidence level
                evidence_match = self.evidence_pattern.search(doc.page_content)
                if evidence_match:
                    doc.metadata['evidence_level'] = evidence_match.group(0)
                
                # Extract recommendation type
                rec_type_match = self.recommendation_type_pattern.search(doc.page_content)
                if rec_type_match:
                    doc.metadata['recommendation_type'] = rec_type_match.group(1)
                
                # Extract recommendation ID
                rec_id_match = re.search(r'RECOMMENDATION\s+([A-Z0-9\.]+):', doc.page_content)
                if rec_id_match:
                    doc.metadata['recommendation_id'] = rec_id_match.group(1)
        
        return documents

class TableExtractor(BaseDocumentTransformer):
    """Extract tables as separate documents with metadata."""
    
    def transform_documents(
        self, documents: List[Document], **kwargs
    ) -> List[Document]:
        """Identify and mark table content."""
        table_pattern = re.compile(r'(Table\s+\d+[\.:]?\s+.*?)(?:\n\n|\Z)', re.DOTALL)
        
        result_docs = []
        for doc in documents:
            # Find tables in the document
            tables = table_pattern.findall(doc.page_content)
            
            # If tables found, create separate documents for them
            if tables:
                # Create a copy of the original document with tables removed
                modified_content = doc.page_content
                for table in tables:
                    modified_content = modified_content.replace(table, "")
                
                # Add the modified document if it still has significant content
                if len(modified_content.strip()) > 100:
                    modified_doc = Document(
                        page_content=modified_content,
                        metadata=doc.metadata.copy()
                    )
                    result_docs.append(modified_doc)
                
                # Add each table as a separate document
                for table in tables:
                    if len(table.strip()) > 50:  # Skip very small tables
                        table_doc = Document(
                            page_content=table,
                            metadata={
                                **doc.metadata.copy(),
                                "chunk_type": "table",
                                "parent_section_path": doc.metadata.get("section_path", [])
                            }
                        )
                        result_docs.append(table_doc)
            else:
                # No tables, keep the original document
                result_docs.append(doc)
                
        return result_docs

class MedicalDocumentProcessor:
    """Process medical documents into semantically meaningful chunks."""
    
    def __init__(
        self, 
        chunk_size: int = 1000, 
        chunk_overlap: int = 200,
        embedding_model_name: str = "text-embedding-ada-002"
    ):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.embedding_model_name = embedding_model_name
        
        # Initialize document processing pipeline
        self.header_splitter = MedicalHeaderTextSplitter()
        self.paragraph_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=["\n\n", "\n", ". ", " ", ""]
        )
        self.evidence_extractor = MedicalEvidenceExtractor()
        self.table_extractor = TableExtractor()
        
        # Initialize embedding model
        self.embeddings = OpenAIEmbeddings(model=embedding_model_name)
    
    def _extract_document_metadata(self, text: str) -> Dict[str, str]:
        """Extract document-level metadata."""
        title = ""
        doc_type = ""
        
        # Extract title from WHO document
        title_match = re.search(r'(?:WHO|World Health Organization)\s+(?:recommendations|guidelines)\s+(?:on|for)\s+([A-Za-z\s\-,]+)', text[:3000])
        if title_match:
            title = title_match.group(1).strip()
            doc_type = "WHO Guidelines"
            
        # If no specific match, try to extract a general title
        if not title:
            title_match = re.search(r'^([A-Z][A-Za-z\s\-:,]+(?:Guidelines|Recommendations|Guidance))', text[:1000])
            if title_match:
                title = title_match.group(1).strip()
                doc_type = "Medical Guidelines"
        
        return {
            "title": title,
            "document_type": doc_type
        }
    
    def _build_section_path(self, doc: Document) -> List[str]:
        """Build the hierarchical section path for a document."""
        path = []
        
        # Add chapter if available
        if 'chapter' in doc.metadata:
            path.append(doc.metadata['chapter'])
            
        # Add section if available
        if 'section' in doc.metadata:
            path.append(doc.metadata['section'])
            
        # Add subsection if available
        if 'subsection' in doc.metadata:
            path.append(doc.metadata['subsection'])
        
        return path
    
    def process_text(self, text: str, source_name: str = "") -> List[Document]:
        """Process text into hierarchical chunks."""
        # Extract document metadata
        doc_metadata = self._extract_document_metadata(text)
        
        # Add source information to metadata
        if source_name:
            doc_metadata["source"] = source_name
        
        # Convert headers to markdown format for the splitter
        md_text = self.header_splitter._add_md_header(text)
        
        # Split on headers
        docs = self.header_splitter.split_text(md_text)
        
        # Extract evidence levels and recommendation metadata
        docs = self.evidence_extractor.transform_documents(docs)
        
        # Extract tables
        docs = self.table_extractor.transform_documents(docs)
        
        # Build section paths for each document
        for doc in docs:
            section_path = self._build_section_path(doc)
            doc.metadata['section_path'] = section_path
            
            # Add document metadata
            doc.metadata['document_title'] = doc_metadata.get('title', '')
            doc.metadata['document_type'] = doc_metadata.get('document_type', '')
            
            # Determine chunk type if not already set
            if 'chunk_type' not in doc.metadata:
                if doc.metadata.get('heading_type') == 'recommendation':
                    doc.metadata['chunk_type'] = 'recommendation'
                elif doc.metadata.get('heading_type') == 'remarks':
                    doc.metadata['chunk_type'] = 'remarks'
                else:
                    doc.metadata['chunk_type'] = 'text'
        
        # Further split large chunks while preserving metadata
        final_docs = []
        for doc in docs:
            # Don't split recommendation or remarks sections
            if doc.metadata.get('chunk_type') in ['recommendation', 'remarks', 'table']:
                final_docs.append(doc)
            else:
                # Split text sections into smaller chunks
                if len(doc.page_content) > self.chunk_size:
                    smaller_chunks = self.paragraph_splitter.split_text(doc.page_content)
                    for i, chunk in enumerate(smaller_chunks):
                        chunk_doc = Document(
                            page_content=chunk,
                            metadata={
                                **doc.metadata,
                                'chunk_index': i,
                                'total_chunks': len(smaller_chunks)
                            }
                        )
                        final_docs.append(chunk_doc)
                else:
                    final_docs.append(doc)
        
        return final_docs
    
    def load_documents(self, input_path: str) -> List[Document]:
        """Load documents from file or directory."""
        documents = []
        
        if os.path.isdir(input_path):
            # Process all files in directory
            for filename in os.listdir(input_path):
                file_path = os.path.join(input_path, filename)
                if os.path.isfile(file_path):
                    documents.extend(self._load_single_document(file_path))
        else:
            # Process single file
            documents = self._load_single_document(input_path)
        
        return documents
    
    def _load_single_document(self, file_path: str) -> List[Document]:
        """Load and process a single document."""
        print(f"Processing {file_path}...")
        
        # Load the document
        if file_path.lower().endswith('.pdf'):
            loader = PyPDFLoader(file_path)
            pages = loader.load()
            text = "\n\n".join([page.page_content for page in pages])
        else:
            with open(file_path, 'r', encoding='utf-8') as f:
                text = f.read()
        
        # Process the text
        source_name = os.path.basename(file_path)
        return self.process_text(text, source_name)
    
    def create_vector_store(self, documents: List[Document], persist_directory: str = None) -> FAISS:
        """Create a FAISS vector store from processed documents."""
        # Create FAISS index from documents
        db = FAISS.from_documents(
            documents=documents,
            embedding=self.embeddings
        )
        
        # Save to disk if persist_directory is provided
        if persist_directory:
            # Ensure the directory exists
            os.makedirs(persist_directory, exist_ok=True)
            
            # Save FAISS index
            index_path = os.path.join(persist_directory, "faiss_index")
            db.save_local(index_path)
            
            # Save documents separately to preserve metadata
            docs_path = os.path.join(persist_directory, "documents.pkl")
            with open(docs_path, 'wb') as f:
                pickle.dump(documents, f)
                
            print(f"Vector database persisted to {persist_directory}")
        
        return db
    
    def load_vector_store(self, persist_directory: str) -> FAISS:
        """Load an existing FAISS vector store."""
        # Ensure the paths exist
        index_path = os.path.join(persist_directory, "faiss_index")
        
        # Load FAISS index
        db = FAISS.load_local(index_path, self.embeddings)
        
        # Optionally load documents to restore metadata
        docs_path = os.path.join(persist_directory, "documents.pkl")
        if os.path.exists(docs_path):
            with open(docs_path, 'rb') as f:
                documents = pickle.load(f)
                
            # Map the loaded documents to the FAISS index for accurate retrieval
            db.docstore.documents = {str(i): doc for i, doc in enumerate(documents)}
            
        return db
    

class MedicalRAG:
    """RAG system for medical documents with advanced filtering and retrieval using FAISS."""
    
    def __init__(
        self,
        vector_store: FAISS,
        model_name: str = "gpt-4-turbo",
        temperature: float = 0.0,
    ):
        self.vector_store = vector_store
        self.llm = ChatOpenAI(model_name=model_name, temperature=temperature)
        
        # Create RAG prompt template
        self.rag_prompt = PromptTemplate(
            template="""You are a medical information assistant that helps healthcare professionals by providing evidence-based information from medical guidelines and literature.

Context information from medical documents:
{context}

Question: {question}

Instructions:
1. Answer based only on the provided context. If the information isn't in the context, say "I don't have enough information to answer this question based on the provided medical guidelines."
2. Cite specific recommendations, evidence levels, and document sources when available.
3. Be concise but comprehensive.
4. If multiple recommendations or conflicting guidance exists in the context, present all perspectives.
5. When quoting recommendations, preserve their exact wording.

Answer:""",
            input_variables=["context", "question"]
        )
        
        # Setup retrieval chain with FAISS
        self.retriever = vector_store.as_retriever(
            search_type="similarity",
            search_kwargs={"k": 6}
        )
        
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": self.rag_prompt}
        )
    
    def query(self, question: str, metadata_filters: Dict[str, Any] = None) -> Dict[str, Any]:
        """Query the RAG system with optional metadata filtering."""
        # Note: FAISS doesn't support native metadata filtering like Chroma
        # We'll retrieve more documents and filter them manually
        
        # Get the original result
        result = self.qa_chain({"query": question})
        
        # If metadata filters are applied, we need to manually filter the results
        source_documents = result.get("source_documents", [])
        
        if metadata_filters:
            filtered_docs = []
            for doc in source_documents:
                is_match = True
                for key, value in metadata_filters.items():
                    # Handle regex patterns in metadata filters
                    if isinstance(value, dict) and "$regex" in value:
                        pattern = value["$regex"]
                        if key not in doc.metadata or not re.search(pattern, str(doc.metadata.get(key, ''))):
                            is_match = False
                            break
                    # Handle exact matches
                    elif key not in doc.metadata or doc.metadata[key] != value:
                        is_match = False
                        break
                
                if is_match:
                    filtered_docs.append(doc)
            
            # If we have filtered docs, create a new result
            if filtered_docs:
                # Join filtered documents content
                filtered_context = "\n\n".join([doc.page_content for doc in filtered_docs])
                
                # Get a new answer using the filtered context
                prompt = self.rag_prompt.format(context=filtered_context, question=question)
                result["result"] = self.llm.invoke(prompt).content
                result["source_documents"] = filtered_docs
        
        # Format response
        sources = []
        for doc in result.get("source_documents", []):
            # Add source information
            source_info = {
                "content": doc.page_content[:200] + "..." if len(doc.page_content) > 200 else doc.page_content,
                "metadata": doc.metadata
            }
            sources.append(source_info)
        
        return {
            "question": question,
            "answer": result["result"],
            "sources": sources
        }
    
    def query_recommendations(self, question: str, evidence_level: str = None) -> Dict[str, Any]:
        """Query specifically for medical recommendations with optional evidence filtering."""
        filters = {"chunk_type": "recommendation"}
        
        if evidence_level:
            # Filter by evidence level (e.g., "high", "moderate", "low")
            filters["evidence_level"] = {"$regex": f"{evidence_level.lower()}.*evidence"}
        
        return self.query(question, filters)

# Example usage
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Medical RAG with FAISS")
    parser.add_argument("--input", type=str, required=True, help="Input file or directory path")
    parser.add_argument("--output", type=str, default="./faiss_index", help="Directory to save the FAISS index")
    parser.add_argument("--query", type=str, help="Query to run against the loaded index")
    
    args = parser.parse_args()
    
    # Initialize processor
    processor = MedicalDocumentProcessor()
    
    # If the output directory exists and contains a FAISS index, load it
    if os.path.exists(args.output) and os.path.exists(os.path.join(args.output, "faiss_index")):
        print(f"Loading existing FAISS index from {args.output}")
        db = processor.load_vector_store(args.output)
    else:
        # Process documents and create a new index
        print(f"Processing documents from {args.input}")
        documents = processor.load_documents(args.input)
        print(f"Processed {len(documents)} document chunks")
        
        # Create and save vector store
        db = processor.create_vector_store(documents, args.output)
    
    # Initialize RAG system
    rag = MedicalRAG(db)
    
    # Run a query if provided
    if args.query:
        result = rag.query(args.query)
        print("\nQuestion:", result["question"])
        print("\nAnswer:", result["answer"])
        print("\nSources:")
        for i, source in enumerate(result["sources"]):
            print(f"\nSource {i+1}:")
            print(f"Content: {source['content']}")
            print(f"Metadata: {source['metadata']}")