<a href="https://colab.research.google.com/github/saivigneshmn/research-assistant/blob/main/docs/Scientific_Paper_QA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Final version: five-bullet answers

In [None]:
# Install the packages imported by the assistant cell below (PDF parsing, text splitting, embeddings, vector search).
!pip install PyPDF2 langchain sentence-transformers faiss-cpu

Collecting PyPDF2
  Downloading pypdf2-3.0.1-py3-none-any.whl.metadata (6.8 kB)
Collecting faiss-cpu
  Downloading faiss_cpu-1.11.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.11.0->sentence-transformers)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.11.0->sentence-tran

In [3]:
# Upload the PDF(s) to analyse through Colab's file-upload helper;
# the value returned by files.upload() is kept in `uploaded`.
from google.colab import files
uploaded = files.upload()

Saving paper1.pdf to paper1.pdf


In [5]:
# Create ./pdfs (no error if it already exists) and move every PDF from the
# working directory into it; the assistant reads its input from that folder.
!mkdir -p ./pdfs
!mv *.pdf ./pdfs

In [None]:
import os
import re
import PyPDF2
import requests
import json
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from typing import List, Tuple

class CitationAwareResearchAssistant:
    """Question answering over a folder of PDFs with citation-preserving retrieval.

    Pipeline: extract page text with PyPDF2, split it into chunks that never
    cut a citation away from its sentence, embed the chunks with the
    ``all-MiniLM-L6-v2`` sentence-transformer, index them in a FAISS
    ``IndexFlatL2``, and send the best-matching chunks to a Groq-hosted chat
    model to produce a five-bullet summary.
    """

    GROQ_CHAT_URL = "https://api.groq.com/openai/v1/chat/completions"
    GROQ_MODEL = "llama-3.3-70b-versatile"
    # Seconds to wait for the Groq API; without a timeout requests can block forever.
    REQUEST_TIMEOUT = 60

    # Recognised citation markers, tried in this order:
    #   1. numeric:              [1]  [1,2]  [1, 2]  [3-5]
    #   2. single author-year:   (Smith, 2023)
    #   3. general author-year:  (Shao et al., 2023; Wang et al., 2024a)
    _CITATION_RE = re.compile(
        r'\[\d+(?:\s*[,\-]\s*\d+)*\]'
        r'|\(\w+,\s*\d{4}\)'
        r'|\([^()\d]+?,\s*\d{4}[a-z]?(?:;\s*[^()\d]+?,\s*\d{4}[a-z]?)*\)'
    )

    def __init__(self, pdf_dir: str, groq_api_key: str, chunk_size: int = 500, chunk_overlap: int = 50):
        """Store configuration and load the embedding model.

        Args:
            pdf_dir: Directory that is scanned for ``.pdf`` files by ``ingest_pdfs``.
            groq_api_key: Bearer token sent to the Groq chat-completions endpoint.
            chunk_size: Target maximum number of characters per chunk.
            chunk_overlap: Overlap, in characters, used by the initial text splitter.
        """
        self.pdf_dir = pdf_dir
        self.groq_api_key = groq_api_key
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len
        )
        self.primary_embedder = SentenceTransformer('all-MiniLM-L6-v2')
        self.index = None
        self.chunks = []
        self.embeddings = None
        self.citation_map = {}

    def extract_citations(self, text: str) -> List[Tuple[int, re.Match]]:
        """Return ``(start_offset, match)`` for every citation marker in ``text``."""
        return [(m.start(), m) for m in self._CITATION_RE.finditer(text)]

    def _split_at_citation_sentences(self, chunk: str) -> List[str]:
        """Cut ``chunk`` at the sentence boundary that precedes each citation.

        The returned pieces are contiguous slices of ``chunk`` (joining them
        reproduces it exactly), so no text is lost or duplicated, and every
        citation sits in the same piece as the start of its sentence.
        """
        citations = self.extract_citations(chunk)
        if not citations:
            return [chunk] if chunk else []

        spans = [(m.start(), m.end()) for _, m in citations]
        # A sentence end is a period followed by whitespace. Periods inside a
        # citation, e.g. "(e.g. Smith, 2020)", must not become cut points.
        sentence_ends = [
            m.end() for m in re.finditer(r'\.(?=\s)', chunk)
            if not any(start < m.end() < end for start, end in spans)
        ]

        cut_points = set()
        for pos, _ in citations:
            cut = max((end for end in sentence_ends if end <= pos), default=0)
            if cut:  # 0 means the citation's sentence starts the chunk: nothing to cut
                cut_points.add(cut)

        bounds = [0] + sorted(cut_points) + [len(chunk)]
        return [chunk[a:b] for a, b in zip(bounds, bounds[1:])]

    def _pack_pieces(self, pieces: List[str]) -> List[str]:
        """Greedily merge consecutive pieces into chunks of at most ``chunk_size`` characters.

        A single piece longer than ``chunk_size`` is kept whole rather than
        split. Returned chunks are whitespace-stripped and never empty.
        """
        packed = []
        current = ""
        for piece in pieces:
            if current and len(current) + len(piece) > self.chunk_size:
                packed.append(current)
                current = piece
            else:
                current += piece
        if current:
            packed.append(current)
        return [body.strip() for body in packed if body.strip()]

    def citation_aware_chunking(self, text: str) -> List[str]:
        """Split ``text`` into chunks, keeping each citation with its sentence.

        The text is first split by ``self.text_splitter``; each resulting chunk
        is cut at the sentence boundaries preceding its citations and the
        pieces are re-packed up to ``chunk_size`` characters.

        Side effect: ``self.citation_map`` is rebuilt so that it maps each
        returned chunk's index to the citation strings found in that chunk.
        """
        pieces = []
        for chunk in self.text_splitter.split_text(text):
            parts = self._split_at_citation_sentences(chunk)
            if pieces and parts:
                # Keep words of adjacent splitter chunks apart if they get merged.
                parts[0] = " " + parts[0]
            pieces.extend(parts)

        final_chunks = self._pack_pieces(pieces)

        # Rebuild instead of updating in place so indices from an earlier call
        # cannot linger and point at the wrong chunk.
        self.citation_map = {
            i: [match.group() for _, match in self.extract_citations(chunk)]
            for i, chunk in enumerate(final_chunks)
        }
        return final_chunks

    def ingest_pdfs(self) -> None:
        """Read every PDF in ``pdf_dir``, chunk the text, embed it and build the FAISS index.

        A PDF that fails to parse is reported and skipped. If no text can be
        extracted at all, the index is left as ``None`` (so ``process_query``
        raises its "No PDFs ingested" error) instead of failing on an empty
        embedding matrix.
        """
        page_texts = []
        # Sorted for a reproducible chunk order; os.listdir order is arbitrary.
        for pdf_file in sorted(os.listdir(self.pdf_dir)):
            if not pdf_file.lower().endswith(".pdf"):
                continue
            pdf_path = os.path.join(self.pdf_dir, pdf_file)
            try:
                with open(pdf_path, 'rb') as file:
                    reader = PyPDF2.PdfReader(file)
                    page_texts.extend(page.extract_text() or "" for page in reader.pages)
            except Exception as e:  # best effort: one bad PDF must not abort ingestion
                print(f"Error processing {pdf_file}: {e}")

        all_text = "".join(page_text + "\n" for page_text in page_texts)
        self.chunks = self.citation_aware_chunking(all_text)
        if not self.chunks:
            self.embeddings = None
            self.index = None
            print(f"No extractable text found in {self.pdf_dir}")
            return

        self.embeddings = self.primary_embedder.encode(self.chunks, convert_to_numpy=True, show_progress_bar=True)
        dimension = self.embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dimension)
        self.index.add(self.embeddings)

    def similarity_search(self, query: str, k: int = 3) -> List[Tuple[str, float, int]]:
        """Return up to ``k`` ``(chunk, distance, chunk_index)`` tuples, nearest first.

        Raises:
            ValueError: If nothing has been ingested yet.
        """
        if self.index is None or not self.chunks:
            raise ValueError("No PDFs ingested. Run ingest_pdfs() first.")
        k = min(k, len(self.chunks))
        if k <= 0:
            return []
        query_embedding = self.primary_embedder.encode([query], convert_to_numpy=True)
        distances, indices = self.index.search(query_embedding, k)
        # FAISS pads missing neighbours with index -1, which as a Python list
        # index would silently return the last chunk, so drop those.
        return [
            (self.chunks[idx], float(dist), int(idx))
            for dist, idx in zip(distances[0], indices[0])
            if idx >= 0
        ]

    def _chat_completion(self, prompt: str, max_tokens=None) -> str:
        """POST ``prompt`` as a single user message to Groq and return the reply text.

        ``max_tokens`` is only sent when it is not ``None``. Raises whatever
        ``requests`` raises on network/HTTP errors, and ``KeyError`` /
        ``IndexError`` if the response has an unexpected shape.
        """
        payload = {
            "model": self.GROQ_MODEL,
            "messages": [{"role": "user", "content": prompt}]
        }
        if max_tokens is not None:
            payload["max_tokens"] = max_tokens
        response = requests.post(
            self.GROQ_CHAT_URL,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.groq_api_key}"
            },
            data=json.dumps(payload),
            timeout=self.REQUEST_TIMEOUT
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]

    def generate_answer(self, query: str, retrieved_chunks: List[Tuple[str, float, int]]) -> str:
        """Generate a 5-point summary using Groq API with citations.

        Each retrieved chunk is placed in the prompt followed by the citations
        recorded for it in ``self.citation_map``. On any API failure the error
        is returned as a string starting with "Error calling Groq API:" rather
        than raised.
        """
        few_shot_prompt = """
        **Example 1**
        Question: What is the main finding on climate change impacts?
        Summary:
        - Rising temperatures increase hurricane intensity by 20% by 2050 [1].
        - Coastal flooding risks rise due to sea-level increases [2].
        - Drought frequency in arid regions doubles by 2100 [1,2].
        - Ecosystem disruptions affect 30% of species [3].
        - Adaptation measures reduce economic losses by 15% [2].

        **Example 2**
        Question: How does the proposed algorithm improve performance?
        Summary:
        - Reduces runtime by 30% via optimized memory usage (Smith, 2023).
        - Improves accuracy by 10% with adaptive learning (Jones, 2022).
        - Lowers energy consumption in training by 25% (Smith, 2023).
        - Scales better for large datasets (Lee, 2021).
        - Enhances model stability under noisy inputs (Jones, 2022).

        **Current Question**
        Question: {query}
        Context: {context}
        Summary: Provide exactly 5 concise bullet points summarizing the key findings, each including relevant citations from the context.
        """
        context = ""
        for chunk, _, idx in retrieved_chunks:
            citations = self.citation_map.get(idx, [])
            context += f"{chunk} {' '.join(citations)}\n"

        prompt = few_shot_prompt.format(query=query, context=context)

        try:
            return self._chat_completion(prompt)
        except Exception as e:  # boundary: report the failure to the caller as text
            return f"Error calling Groq API: {str(e)}"

    def process_query(self, query: str) -> str:
        """Retrieve the most relevant chunks for ``query`` and return the generated answer.

        Raises:
            ValueError: If ``ingest_pdfs`` has not produced an index yet.
        """
        if self.index is None or not self.chunks:
            raise ValueError("No PDFs ingested. Run ingest_pdfs() first.")
        retrieved_chunks = self.similarity_search(query)
        answer = self.generate_answer(query, retrieved_chunks)
        return answer

if __name__ == "__main__":
    from getpass import getpass

    # The PDF folder must exist before ingestion lists its contents.
    os.makedirs("./pdfs", exist_ok=True)

    # Read the key interactively so it never appears in the notebook source.
    groq_api_key = getpass("Enter your Groq API key: ")
    pdf_dir = "./pdfs"

    # Build the assistant and index every PDF found in the folder.
    assistant = CitationAwareResearchAssistant(pdf_dir=pdf_dir, groq_api_key=groq_api_key)
    print("Ingesting PDFs...")
    assistant.ingest_pdfs()

    # Ask a single demo question and show the generated summary.
    query = "What are the key findings on proposed GNN-Ret?"
    print(f"Query: {query}")
    answer = assistant.process_query(query)
    print(f"Answer:\n{answer}")

Enter your Groq API key: ··········
Ingesting PDFs...


Batches:   0%|          | 0/6 [00:00<?, ?it/s]

Query: What are the key findings on proposed GNN-Ret?
Answer:
* GNN-Ret maintains consistent performance across different layer configurations, suggesting a single-layer GNN can address queries requiring only two-hop supporting passages (Shao et al., 2023; Wang et al., 2024).
* The use of minimum semantic distance as the message in GNN-Ret effectively filters out interfering messages from irrelevant neighbors and preserves the most relevant message (Karpukhin et al., 2020a; Robertson et al., 2009).
* GNN-Ret outperforms baselines, including Direct and retrievers bm25, DPR, and SentenceBert, in terms of accuracy (Karpukhin et al., 2020a; Robertson et al., 2009).
* Ablation studies demonstrate the effectiveness of components in RGNN-Ret, highlighting the importance of each component in achieving optimal performance (Shao et al., 2023).
* The evaluation of GNN-Ret on the Quality dataset, which is not a multi-hop dataset, shows promising accuracy performance, validating the model's effecti

# Final version: concise 1–2 sentence answers (section-aware chunking, HyDE, multi-hop retrieval)

In [2]:
# Install the packages imported by the assistant cell below; compared with the
# first version this adds pdfplumber, scikit-learn, ipywidgets and requests.
!pip install PyPDF2 langchain sentence-transformers faiss-cpu requests ipywidgets pdfplumber scikit-learn
!pip install hf_xet  # Optional, to suppress Xet warning

Collecting PyPDF2
  Downloading pypdf2-3.0.1-py3-none-any.whl.metadata (6.8 kB)
Collecting faiss-cpu
  Downloading faiss_cpu-1.11.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.8 kB)
Collecting pdfplumber
  Downloading pdfplumber-0.11.6-py3-none-any.whl.metadata (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.8/42.8 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
Collecting pdfminer.six==20250327 (from pdfplumber)
  Downloading pdfminer_six-20250327-py3-none-any.whl.metadata (4.1 kB)
Collecting pypdfium2>=4.18.0 (from pdfplumber)
  Downloading pypdfium2-4.30.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (48 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m48.2/48.2 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
Collecting jedi>=0.16 (from ipython>=4.0.0->ipywidgets)
  Downloading jedi-0.19.2-py2.py3-none-any.whl.metadata (22 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.11.0->sentence-transformers)
  

In [6]:
import os
import re
import PyPDF2
import pdfplumber
import requests
import json
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
import faiss
import numpy as np
from typing import List, Tuple
from functools import lru_cache

class CitationAwareResearchAssistant:
    """Question answering over a folder of PDFs with section-, semantics- and citation-aware retrieval.

    Pipeline: extract page text with PyPDF2, group it by section header,
    cluster each section's sentences with KMeans, split the clusters without
    separating citations from their sentences, embed the chunks with
    ``all-MiniLM-L6-v2`` and index them in a FAISS ``IndexFlatL2``. Queries
    are expanded with a hypothetical answer (HyDE), optionally refined over
    several retrieval hops, and answered by a Groq-hosted chat model.
    """

    GROQ_CHAT_URL = "https://api.groq.com/openai/v1/chat/completions"
    GROQ_MODEL = "llama-3.3-70b-versatile"
    # Seconds to wait for the Groq API; without a timeout requests can block forever.
    REQUEST_TIMEOUT = 60
    # Maximum number of hypothetical answers remembered per instance.
    HYDE_CACHE_SIZE = 100

    # Recognised citation markers, tried in this order:
    #   1. numeric:              [1]  [1,2]  [1, 2]  [3-5]
    #   2. single author-year:   (Smith, 2023)
    #   3. general author-year:  (Shao et al., 2023; Wang et al., 2024a)
    _CITATION_RE = re.compile(
        r'\[\d+(?:\s*[,\-]\s*\d+)*\]'
        r'|\(\w+,\s*\d{4}\)'
        r'|\([^()\d]+?,\s*\d{4}[a-z]?(?:;\s*[^()\d]+?,\s*\d{4}[a-z]?)*\)'
    )

    # A line consisting only of a common section title, optionally preceded by
    # a section number such as "1", "2." or "3.1".
    _SECTION_RE = re.compile(
        r'^(?:\d+(?:\.\d+)*\.?\s+)?'
        r'(Abstract|Introduction|Background|Related Work|Methodology|Methods?|'
        r'Experiments?|Results|Discussion|Conclusions?)\s*$',
        re.IGNORECASE
    )

    def __init__(self, pdf_dir: str, groq_api_key: str, chunk_size: int = 500, chunk_overlap: int = 50):
        """Initialize with PDF directory, Groq API key, and chunking parameters.

        Args:
            pdf_dir: Directory that is scanned for ``.pdf`` files by ``ingest_pdfs``.
            groq_api_key: Bearer token sent to the Groq chat-completions endpoint.
            chunk_size: Target maximum number of characters per chunk body.
            chunk_overlap: Overlap, in characters, configured on the text splitter.
        """
        self.pdf_dir = pdf_dir
        self.groq_api_key = groq_api_key
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len
        )
        self.primary_embedder = SentenceTransformer('all-MiniLM-L6-v2')
        self.index = None
        self.chunks = []
        self.embeddings = None
        self.citation_map = {}
        self._hyde_cache = {}

    def extract_citations(self, text: str) -> List[Tuple[int, re.Match]]:
        """Return ``(start_offset, match)`` for every citation marker in ``text``."""
        return [(m.start(), m) for m in self._CITATION_RE.finditer(text)]

    def extract_sections(self, text: str) -> List[Tuple[str, str]]:
        """Group ``text`` into ``(header, content)`` pairs using common section headers.

        A header is a line that (after stripping) is just a known section
        title, optionally numbered ("1 Introduction", "2. Methods"). Content
        lines are joined with single spaces. Text before the first header and
        headers with no content are not returned.
        """
        sections = []
        current_section = None
        current_content = []

        for line in text.split('\n'):
            stripped = line.strip()
            if self._SECTION_RE.match(stripped):
                if current_section and current_content:
                    sections.append((current_section, ' '.join(current_content)))
                current_section = stripped
                current_content = []
            elif current_section:
                current_content.append(line)

        if current_section and current_content:
            sections.append((current_section, ' '.join(current_content)))

        return sections

    def semantic_chunking(self, text: str) -> List[str]:
        """Cluster the sentences of ``text`` by embedding similarity.

        Sentences are obtained by splitting on ``'. '``; roughly one KMeans
        cluster is created per five sentences and each non-empty cluster is
        returned as one string, sentences in their original relative order.
        Returns an empty list for blank text.
        """
        sentences = [s.strip() for s in text.split('. ') if s.strip()]
        if not sentences:
            return []

        def join_cluster(cluster):
            # Restore the period removed by the split, without doubling the
            # punctuation that the final sentence still carries.
            return ' '.join(s if s[-1] in '.!?' else s + '.' for s in cluster)

        num_clusters = max(1, len(sentences) // 5)
        if num_clusters == 1:
            # A single cluster contains every sentence; no need to embed.
            return [join_cluster(sentences)]

        sentence_embeddings = self.primary_embedder.encode(sentences, convert_to_numpy=True)
        kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(sentence_embeddings)
        clusters = [[] for _ in range(num_clusters)]
        for sentence, label in zip(sentences, kmeans.labels_):
            clusters[label].append(sentence)

        return [join_cluster(cluster) for cluster in clusters if cluster]

    def _split_at_citation_sentences(self, chunk: str) -> List[str]:
        """Cut ``chunk`` at the sentence boundary that precedes each citation.

        The returned pieces are contiguous slices of ``chunk`` (joining them
        reproduces it exactly), so no text is lost or duplicated, and every
        citation sits in the same piece as the start of its sentence.
        """
        citations = self.extract_citations(chunk)
        if not citations:
            return [chunk] if chunk else []

        spans = [(m.start(), m.end()) for _, m in citations]
        # A sentence end is a period followed by whitespace. Periods inside a
        # citation, e.g. "(e.g. Smith, 2020)", must not become cut points.
        sentence_ends = [
            m.end() for m in re.finditer(r'\.(?=\s)', chunk)
            if not any(start < m.end() < end for start, end in spans)
        ]

        cut_points = set()
        for pos, _ in citations:
            cut = max((end for end in sentence_ends if end <= pos), default=0)
            if cut:  # 0 means the citation's sentence starts the chunk: nothing to cut
                cut_points.add(cut)

        bounds = [0] + sorted(cut_points) + [len(chunk)]
        return [chunk[a:b] for a, b in zip(bounds, bounds[1:])]

    def _pack_pieces(self, pieces: List[str]) -> List[str]:
        """Greedily merge consecutive pieces into bodies of at most ``chunk_size`` characters.

        A single piece longer than ``chunk_size`` is kept whole rather than
        split. Returned bodies are whitespace-stripped and never empty.
        """
        packed = []
        current = ""
        for piece in pieces:
            if current and len(current) + len(piece) > self.chunk_size:
                packed.append(current)
                current = piece
            else:
                current += piece
        if current:
            packed.append(current)
        return [body.strip() for body in packed if body.strip()]

    def citation_aware_chunking(self, text: str) -> List[str]:
        """Split text into citation-aware chunks with section and semantic awareness.

        Every returned chunk has the form ``"<section header>: <body>"``. When
        no section header is recognised, the whole text is treated as a single
        section named "Document" so that ingestion still yields chunks.

        Side effect: ``self.citation_map`` is rebuilt so that it maps each
        returned chunk's index to the citation strings found in that chunk.
        """
        sections = self.extract_sections(text)
        if not sections and text.strip():
            sections = [("Document", ' '.join(text.split('\n')))]

        final_chunks = []
        for section_name, section_text in sections:
            for chunk in self.semantic_chunking(section_text):
                pieces = self._split_at_citation_sentences(chunk)
                final_chunks.extend(
                    f"{section_name}: {body}" for body in self._pack_pieces(pieces)
                )

        # Rebuild instead of updating in place so indices from an earlier call
        # cannot linger and point at the wrong chunk.
        self.citation_map = {
            i: [match.group() for _, match in self.extract_citations(chunk)]
            for i, chunk in enumerate(final_chunks)
        }
        return final_chunks

    def ingest_pdfs(self) -> None:
        """Read PDFs, extract text, perform advanced chunking, and compute embeddings.

        A PDF that fails to parse is reported and skipped. If no text can be
        extracted at all, the index is left as ``None`` (so ``process_query``
        raises its "No PDFs ingested" error) instead of failing on an empty
        embedding matrix.
        """
        page_texts = []
        # Sorted for a reproducible chunk order; os.listdir order is arbitrary.
        for pdf_file in sorted(os.listdir(self.pdf_dir)):
            if not pdf_file.lower().endswith(".pdf"):
                continue
            pdf_path = os.path.join(self.pdf_dir, pdf_file)
            try:
                with open(pdf_path, 'rb') as file:
                    reader = PyPDF2.PdfReader(file)
                    page_texts.extend(page.extract_text() or "" for page in reader.pages)
            except Exception as e:  # best effort: one bad PDF must not abort ingestion
                print(f"Error processing {pdf_file}: {e}")

        all_text = "".join(page_text + "\n" for page_text in page_texts)
        self.chunks = self.citation_aware_chunking(all_text)
        if not self.chunks:
            self.embeddings = None
            self.index = None
            print(f"No extractable text found in {self.pdf_dir}")
            return

        self.embeddings = self.primary_embedder.encode(self.chunks, convert_to_numpy=True, show_progress_bar=True)
        dimension = self.embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dimension)
        self.index.add(self.embeddings)

    def _chat_completion(self, prompt: str, max_tokens=None) -> str:
        """POST ``prompt`` as a single user message to Groq and return the reply text.

        ``max_tokens`` is only sent when it is not ``None``. Raises whatever
        ``requests`` raises on network/HTTP errors, and ``KeyError`` /
        ``IndexError`` if the response has an unexpected shape.
        """
        payload = {
            "model": self.GROQ_MODEL,
            "messages": [{"role": "user", "content": prompt}]
        }
        if max_tokens is not None:
            payload["max_tokens"] = max_tokens
        response = requests.post(
            self.GROQ_CHAT_URL,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.groq_api_key}"
            },
            data=json.dumps(payload),
            timeout=self.REQUEST_TIMEOUT
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]

    def generate_hypothetical_answer(self, query: str) -> str:
        """Generate a cached hypothetical answer using Groq API for HyDE.

        Successful answers are remembered per instance (oldest entry evicted
        once ``HYDE_CACHE_SIZE`` is reached). On API failure the error is
        printed and ``query`` itself is returned; failures are not cached, so
        a later call retries.
        """
        # Per-instance dict rather than functools.lru_cache on the method: a
        # method-level lru_cache keys on `self` and keeps instances alive.
        cache = self.__dict__.setdefault("_hyde_cache", {})
        if query in cache:
            return cache[query]

        prompt = f"Provide a brief hypothetical answer to the question: {query}"
        try:
            answer = self._chat_completion(prompt, max_tokens=100).strip()
        except Exception as e:  # boundary: fall back to the plain query
            print(f"Error generating hypothetical answer: {str(e)}")
            return query

        if len(cache) >= self.HYDE_CACHE_SIZE:
            cache.pop(next(iter(cache)))  # dicts keep insertion order: drop the oldest
        cache[query] = answer
        return answer

    def extract_new_terms(self, chunks: List[Tuple[str, float, int]]) -> str:
        """Extract key terms from chunks for multi-hop refinement.

        Returns the first five distinct purely-alphabetic words longer than
        four characters, in order of first appearance, joined by spaces. The
        order is deterministic so that refined queries are reproducible.
        """
        terms = []
        for chunk, _, _ in chunks:
            terms.extend(w for w in chunk.split() if w.isalpha() and len(w) > 4)
        # dict.fromkeys de-duplicates while preserving first-seen order.
        return ' '.join(list(dict.fromkeys(terms))[:5])

    def multi_hop_retrieval(self, query: str, k: int = 3, depth: int = 2) -> List[Tuple[str, float, int]]:
        """Perform multi-hop retrieval by iteratively refining the query.

        Each hop searches with the current query, then appends key terms from
        the hits to form the next query. Returns the ``k`` distinct chunks
        with the smallest distance seen in any hop, nearest first.
        """
        current_query = query
        best_by_chunk = {}

        for _ in range(depth):
            hits = self.similarity_search(current_query, k)
            for chunk, dist, idx in hits:
                # Keep the best (smallest) distance when a chunk recurs.
                if chunk not in best_by_chunk or dist < best_by_chunk[chunk][0]:
                    best_by_chunk[chunk] = (dist, idx)
            new_terms = self.extract_new_terms(hits)
            current_query = f"{current_query} {new_terms}"

        ranked = sorted(best_by_chunk.items(), key=lambda item: item[1][0])[:k]
        return [(chunk, dist, idx) for chunk, (dist, idx) in ranked]

    def similarity_search(self, query: str, k: int = 3) -> List[Tuple[str, float, int]]:
        """Perform similarity search with HyDE-style enhancement.

        The query is concatenated with a hypothetical answer before being
        embedded. Returns up to ``k`` ``(chunk, distance, chunk_index)``
        tuples, nearest first.

        Raises:
            ValueError: If nothing has been ingested yet.
        """
        if self.index is None or not self.chunks:
            raise ValueError("No PDFs ingested. Run ingest_pdfs() first.")
        k = min(k, len(self.chunks))
        if k <= 0:
            return []
        hypothetical_answer = self.generate_hypothetical_answer(query)
        combined_query = f"{query} {hypothetical_answer}"
        query_embedding = self.primary_embedder.encode([combined_query], convert_to_numpy=True)
        distances, indices = self.index.search(query_embedding, k)
        # FAISS pads missing neighbours with index -1, which as a Python list
        # index would silently return the last chunk, so drop those.
        return [
            (self.chunks[idx], float(dist), int(idx))
            for dist, idx in zip(distances[0], indices[0])
            if idx >= 0
        ]

    def generate_answer(self, query: str, retrieved_chunks: List[Tuple[str, float, int]]) -> str:
        """Generate a concise Q&A response using Groq API with citations.

        Each retrieved chunk is placed in the prompt followed by the citations
        recorded for it in ``self.citation_map``. On any API failure the error
        is returned as a string starting with "Error calling Groq API:" rather
        than raised.
        """
        few_shot_prompt = """
        **Example 1**
        Question: What is the main cause of climate change according to recent studies?
        Answer: Recent studies identify greenhouse gas emissions, particularly CO2 from fossil fuel combustion, as the primary cause of climate change [1,2].

        **Example 2**
        Question: How does the new algorithm improve neural network training?
        Answer: The algorithm enhances training by reducing runtime by 30% through optimized memory usage and improving accuracy with adaptive learning rates (Smith, 2023).

        **Current Question**
        Question: {query}
        Context: {context}
        Answer: Provide a concise, direct answer to the question in 1-2 sentences, including relevant citations from the context.
        """
        context = ""
        for chunk, _, idx in retrieved_chunks:
            citations = self.citation_map.get(idx, [])
            context += f"{chunk} {' '.join(citations)}\n"

        prompt = few_shot_prompt.format(query=query, context=context)

        try:
            return self._chat_completion(prompt, max_tokens=200).strip()
        except Exception as e:  # boundary: report the failure to the caller as text
            return f"Error calling Groq API: {str(e)}"

    def process_query(self, query: str, use_multi_hop: bool = True) -> str:
        """Process a single query and return the answer.

        Args:
            query: The question to answer.
            use_multi_hop: Use ``multi_hop_retrieval`` when true, otherwise a
                single ``similarity_search``.

        Raises:
            ValueError: If ``ingest_pdfs`` has not produced an index yet.
        """
        if self.index is None or not self.chunks:
            raise ValueError("No PDFs ingested. Run ingest_pdfs() first.")
        retrieved_chunks = self.multi_hop_retrieval(query) if use_multi_hop else self.similarity_search(query)
        answer = self.generate_answer(query, retrieved_chunks)
        return answer

def run_queries(assistant, queries: List[str]):
    """Answer each query in turn with ``assistant`` and print the question/answer pairs.

    For every query the question is printed first, then ``assistant.process_query``
    is called, and its result is printed followed by a blank line.
    """
    for question in queries:
        print(f"Query: {question}")
        # The question is echoed before the (potentially slow) lookup starts.
        print(f"Answer: {assistant.process_query(question)}\n")


if __name__ == "__main__":
    from getpass import getpass

    # The PDF folder must exist before ingestion lists its contents.
    os.makedirs("./pdfs", exist_ok=True)

    # Read the key interactively so it never appears in the notebook source.
    groq_api_key = getpass("Enter your Groq API key: ")
    pdf_dir = "./pdfs"

    # Build the assistant and index every PDF found in the folder.
    assistant = CitationAwareResearchAssistant(pdf_dir=pdf_dir, groq_api_key=groq_api_key)
    print("Ingesting PDFs...")
    assistant.ingest_pdfs()

    # Demo questions, answered one after another.
    queries = [
        "What are the key findings on proposed GNN-Ret?",
        "How does GNN-Ret compare to baselines like BM25?",
    ]
    run_queries(assistant, queries)

Enter your Groq API key: ··········
Ingesting PDFs...


Batches:   0%|          | 0/6 [00:00<?, ?it/s]

Query: What are the key findings on proposed GNN-Ret?
Answer: The key findings on the proposed GNN-Ret method indicate that it effectively enhances the retrieval of supporting passages for question answering (QA) by leveraging the relatedness between passages, outperforming baselines such as SBERT. The experiments demonstrate the superiority of GNN-Ret, with significant improvements in retrieval coverage and accuracy, as discussed in the context (Section 3.2).

Query: How does GNN-Ret compare to baselines like BM25?
Answer: GNN-Ret outperforms baselines like BM25, achieving higher accuracy for question answering with a single query of LLMs, and its extension RGNN-Ret further improves accuracy and achieves state-of-the-art performance (Ho et al., 2020). Specifically, GNN-Ret significantly outperforms SBERT, with improvements of up to 29 and 43 exact-match test samples for questions requiring 2 supporting passages.

