In [26]:
import os
import numpy as np
import fitz
from PIL import Image
import io
import base64
from sentence_transformers import SentenceTransformer
import faiss
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import nltk
from nltk.tokenize import sent_tokenize
import tempfile
import torch

In [27]:
# Fetch the NLTK 'punkt' and 'punkt_tab' packages before any tokenization runs.
# Presumably these are the sentence-tokenizer data that sent_tokenize relies on;
# confirm against the NLTK version in use.
nltk.download('punkt')
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [30]:
# Module-level models shared by the functions defined in later cells:
#   embedder          - sentence-embedding model used for chunking, indexing and queries
#   tokenizer / model - FLAN-T5 base seq2seq model used to generate the final answer
# NOTE(review): the seq2seq weights are loaded as float16 with device_map="auto";
# confirm a GPU is available, since half precision on CPU may be slow or unsupported.
embedder = SentenceTransformer('all-MiniLM-L6-v2')
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base", device_map="auto", torch_dtype=torch.float16)

In [31]:
def extract_content_from_pdf(pdf_path, output_dir='extracted_images'):
    """
    Pull the text and embedded images out of every page of a PDF.

    Page text is stripped and kept only when non-empty. Each embedded image is
    written to `output_dir` as ``page_<page>_img_<n>.<ext>`` (both numbers are
    1-based) and recorded together with its location in the document.

    Args:
        pdf_path (str): Path to the PDF file
        output_dir (str): Directory that receives the extracted image files;
            created if it does not exist

    Returns:
        Tuple[List[Dict], List[Dict]]: (text records, image records). A text
        record has 'content' and 'metadata'; an image record has 'path' and
        'metadata'.
    """
    os.makedirs(output_dir, exist_ok=True)
    page_texts = []
    saved_images = []

    with fitz.open(pdf_path) as document:
        for page_index in range(len(document)):
            page_no = page_index + 1
            current_page = document[page_index]

            stripped = current_page.get_text().strip()
            if stripped:
                page_texts.append({
                    'content': stripped,
                    'metadata': {'source': pdf_path, 'page': page_no, 'type': 'text'},
                })

            # The first field of each image entry is the xref used to fetch its bytes.
            for image_no, image_entry in enumerate(current_page.get_images(full=True), start=1):
                extracted = document.extract_image(image_entry[0])
                if not extracted:
                    continue

                payload, extension = extracted['image'], extracted['ext']
                target_path = os.path.join(output_dir, f"page_{page_no}_img_{image_no}.{extension}")
                with open(target_path, 'wb') as sink:
                    sink.write(payload)

                saved_images.append({
                    'path': target_path,
                    'metadata': {
                        'source': pdf_path,
                        'page': page_no,
                        'image_index': image_no,
                        'type': 'image',
                    },
                })

    return page_texts, saved_images

In [32]:
def _cosine_similarity(vec_a, vec_b):
    """Cosine similarity of two vectors; 0.0 when either vector has zero norm."""
    denom = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
    if denom == 0:
        return 0.0
    return float(np.dot(vec_a, vec_b) / denom)


def _overlap_tail(sentences, overlap):
    """Trailing sentences whose combined length stays within `overlap` characters."""
    tail = []
    tail_chars = 0
    for sentence in reversed(sentences):
        if tail_chars + len(sentence) > overlap:
            break
        tail.insert(0, sentence)
        tail_chars += len(sentence)
    return tail


def _chunk_sentences(sentences, chunk_size, overlap, percentile_threshold):
    """Group the sentences of one text into chunk strings without dropping any sentence."""
    embeddings = embedder.encode(sentences)
    similarities = [
        _cosine_similarity(embeddings[i], embeddings[i + 1])
        for i in range(len(sentences) - 1)
    ]

    # A gap is a breakpoint when its similarity is below the requested percentile
    # of all adjacent-sentence similarities in this text.
    # NOTE(review): with `sim < percentile`, roughly `percentile_threshold` percent
    # of the gaps qualify as breakpoints - confirm this is the intended direction.
    threshold = np.percentile(similarities, percentile_threshold)
    breakpoints = {i for i, sim in enumerate(similarities) if sim < threshold}

    min_chunk_chars = chunk_size // 2
    chunks = []
    buffer = []
    buffer_chars = 0
    carried = 0  # leading sentences in `buffer` that were already emitted as overlap

    for i, sentence in enumerate(sentences):
        buffer.append(sentence)
        buffer_chars += len(sentence)

        size_reached = buffer_chars >= chunk_size
        # Only honour a semantic breakpoint once the chunk is big enough; a short
        # buffer keeps accumulating instead of being discarded.
        semantic_break = i in breakpoints and len(' '.join(buffer)) > min_chunk_chars

        if size_reached or semantic_break:
            chunks.append(' '.join(buffer))
            buffer = _overlap_tail(buffer, overlap)
            buffer_chars = sum(len(s) for s in buffer)
            carried = len(buffer)

    # Flush whatever was read after the last emitted chunk.
    fresh = buffer[carried:]
    if fresh:
        if chunks and len(' '.join(buffer)) <= min_chunk_chars:
            # Too short to stand alone: fold the new sentences into the previous chunk.
            chunks[-1] = chunks[-1] + ' ' + ' '.join(fresh)
        else:
            chunks.append(' '.join(buffer))

    return chunks


def semantic_chunking(text_data, chunk_size=500, overlap=50, percentile_threshold=85):
    """
    Splits text into semantic chunks with overlap.

    Each text item is split into sentences. A chunk is closed when its sentences
    reach `chunk_size` characters, or at a semantic breakpoint once the chunk is
    longer than `chunk_size // 2` characters. Every sentence ends up in at least
    one chunk: short buffers keep accumulating, and a short remainder at the end
    of an item is merged into that item's previous chunk (or emitted on its own
    when the item produced no other chunk).

    Args:
        text_data (List[Dict]): Extracted text data; each item needs 'content'
            and 'metadata'
        chunk_size (int): Approximate size of each chunk in characters
        overlap (int): Maximum number of characters of trailing sentences carried
            over into the next chunk
        percentile_threshold (float): Percentile of adjacent-sentence cosine
            similarities below which a gap counts as a breakpoint

    Returns:
        List[Dict]: Chunked text data; each chunk's metadata is the item's
        metadata plus a running 'chunk_index'
    """
    chunked_data = []

    for item in text_data:
        text = item['content']
        metadata = item['metadata']

        # Blank content would only add an empty vector to the index.
        if not text or not text.strip():
            continue

        sentences = sent_tokenize(text)
        if len(sentences) < 2:
            # Nothing to compare: keep the text as a single chunk.
            item_chunks = [text]
        else:
            item_chunks = _chunk_sentences(sentences, chunk_size, overlap, percentile_threshold)

        for chunk_text in item_chunks:
            chunked_data.append({
                'content': chunk_text,
                'metadata': {**metadata, 'chunk_index': len(chunked_data)}
            })

    return chunked_data

In [33]:
def generate_image_caption(image_path):
    """
    Generates a caption for an image using a simple description (placeholder).

    The image is opened only to confirm that it is readable; the caption itself
    is a fixed template filled with the page number taken from the file name
    (the second '_'-separated field, e.g. ``page_3_img_1.png`` -> ``3``).

    Args:
        image_path (str): Path to the image file

    Returns:
        str: Generated caption, or a message starting with
        "Error generating caption:" when the file cannot be opened or its name
        does not follow the expected pattern
    """
    try:
        # Use a context manager so the file handle held by the opened image is
        # released instead of leaking one handle per processed image.
        with Image.open(image_path):
            pass
        page_label = os.path.basename(image_path).split('_')[1]
        return f"Image from page {page_label} containing academic content, likely a chart or diagram."
    except Exception as e:
        # Best effort by design: a bad image yields an error caption, not a crash.
        return f"Error generating caption: {str(e)}"

In [34]:
def process_images(image_paths):
    """
    Build one captioned record per extracted image.

    Args:
        image_paths (List[Dict]): Image records, each with a 'path' and a
            'metadata' entry

    Returns:
        List[Dict]: Records with the caption as 'content', the original
        'metadata', and the file location as 'image_path'
    """
    return [
        {
            'content': generate_image_caption(entry['path']),
            'metadata': entry['metadata'],
            'image_path': entry['path'],
        }
        for entry in image_paths
    ]

In [37]:
class VectorStore:
    """In-memory FAISS flat-L2 index paired with the content and metadata of each vector."""

    def __init__(self):
        self.index = None    # FAISS index; None until add_items stores at least one vector
        self.contents = []   # content per stored vector, aligned with index positions
        self.metadata = []   # metadata dict per stored vector, aligned with contents

    def add_items(self, items, embeddings):
        """
        Adds items and their embeddings to the FAISS index.

        Each call rebuilds the store from scratch: previously stored vectors,
        contents and metadata are replaced. Passing no items leaves the store
        empty instead of raising.

        Args:
            items (List[Dict]): Content items; each needs 'content', 'metadata'
                is optional
            embeddings (List[np.ndarray]): Corresponding embeddings, one per item

        Raises:
            ValueError: If the number of items and embeddings differ, or the
                embeddings do not form a 2-D array
        """
        vectors = np.array(embeddings).astype('float32')
        if len(items) != len(vectors):
            raise ValueError(
                f"Got {len(items)} items but {len(vectors)} embeddings; they must match one-to-one."
            )

        # Collect the payload first so a malformed item cannot leave the index
        # and the content lists out of sync.
        contents = [item['content'] for item in items]
        metadata = [item.get('metadata', {}) for item in items]

        if len(vectors) == 0:
            self.index = None
            self.contents = []
            self.metadata = []
            return

        if vectors.ndim != 2:
            raise ValueError("Embeddings must form a 2-D array of shape (n_items, dimension).")

        index = faiss.IndexFlatL2(vectors.shape[1])
        index.add(vectors)

        self.index = index
        self.contents = contents
        self.metadata = metadata

    def search(self, query_embedding, k=5):
        """
        Performs similarity search in the vector store.

        Args:
            query_embedding (np.ndarray): Query embedding
            k (int): Maximum number of results to return

        Returns:
            List[Dict]: Up to k items ordered as returned by the index, each with
            'content', 'metadata' and 'similarity' (1 / (1 + distance)). Fewer
            than k items are returned when the store holds fewer vectors; an
            empty list is returned for an empty store or non-positive k.
        """
        if self.index is None or k <= 0:
            return []

        distances, indices = self.index.search(np.array([query_embedding]).astype('float32'), k)
        results = []
        for idx, distance in zip(indices[0], distances[0]):
            # FAISS pads the result with -1 when fewer than k vectors exist;
            # indexing with -1 would silently return the last stored item.
            if idx < 0:
                continue
            position = int(idx)
            results.append({
                'content': self.contents[position],
                'metadata': self.metadata[position],
                'similarity': float(1 / (1 + distance))
            })
        return results

In [38]:
def process_document(pdf_path, chunk_size=500, percentile_threshold=90):
    """
    Processes a PDF document for RAG.

    Extracts text and images, chunks the text, captions the images, embeds
    everything and stores it in a fresh VectorStore.

    Args:
        pdf_path (str): Path to the PDF
        chunk_size (int): Approximate chunk size
        percentile_threshold (float): Percentile for semantic chunking

    Returns:
        Tuple[VectorStore, Dict]: Vector store and document info
        ('text_count', 'image_count', 'total_items'). The store is left empty
        when the document yields no text chunks and no images.
    """
    text_data, image_paths = extract_content_from_pdf(pdf_path)

    # Pass by keyword: the third positional parameter of semantic_chunking is
    # `overlap`, so a positional call would feed the percentile into the overlap.
    chunked_text = semantic_chunking(
        text_data,
        chunk_size=chunk_size,
        percentile_threshold=percentile_threshold,
    )
    image_data = process_images(image_paths)

    all_items = chunked_text + image_data

    vector_store = VectorStore()
    if all_items:
        # Skip embedding and indexing for an empty document; searching a store
        # without an index returns no results.
        embeddings = embedder.encode([item['content'] for item in all_items])
        vector_store.add_items(all_items, embeddings)

    doc_info = {
        'text_count': len(chunked_text),
        'image_count': len(image_data),
        'total_items': len(all_items)
    }

    return vector_store, doc_info

In [40]:
def query_rag(query, vector_store, k=5):
    """
    Processes a query using the RAG system.

    Retrieves the k most similar items, formats them as labelled context blocks
    and asks the seq2seq model to answer the query from that context.

    Args:
        query (str): User query
        vector_store (VectorStore): Vector store with document content
        k (int): Number of results to retrieve

    Returns:
        Dict: {'response': str} with the generated answer, or a fixed message
        when the store returns no context for the query
    """
    query_embedding = embedder.encode([query])[0]
    results = vector_store.search(query_embedding, k)

    if not results:
        # Nothing was retrieved, so there is no context for the model to answer from.
        return {'response': 'No relevant context was found in the document for this query.'}

    context_blocks = []
    for hit in results:
        # Items may have been stored without metadata; treat those as plain text
        # instead of failing on a missing 'type' key.
        meta = hit.get('metadata') or {}
        label = 'Text' if meta.get('type', 'text') == 'text' else 'Image Caption'
        context_blocks.append(f"[{label} from page {meta.get('page', 'unknown')}]:\n{hit['content']}")
    context = '\n\n'.join(context_blocks)

    prompt = f"""You are an AI assistant specializing in answering questions based on clinical and mental health documents. Use the provided context to answer the query accurately. If the context is insufficient, state so clearly.
Query: {query}
Context:
{context}
Answer: """
    # Inputs longer than 512 tokens are truncated by the tokenizer, so with a
    # large k the later context blocks may not reach the model.
    inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(model.device)

    # Deterministic beam search: sampling made repeated runs of the same query
    # return different answers, and temperature only applies when sampling.
    outputs = model.generate(
        **inputs,
        max_length=200,
        num_beams=4,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id
    )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Keep only the text after an echoed "Answer:" marker, if the model produced one.
    response = response.split('Answer:')[-1].strip() if 'Answer:' in response else response.strip()

    return {'response': response}

In [45]:
# End-to-end run: index one PDF, then ask a single question against it.
# pdf_path has no directory component, so it is resolved relative to the
# notebook's working directory.
pdf_path = "9241544228_eng.pdf"
# doc_info holds the text/image/total item counts; it is not displayed in this cell.
vector_store, doc_info = process_document(pdf_path)

query = "Give me the correct coded classification for the following diagnosis: Recurrent depressive disorder, currently in remission"
result = query_rag(query, vector_store)
# query_rag returns a dict; the generated answer is under the 'response' key.
print(result)

{'response': 'Recurrent depressive disorder, currently in remission'}
