PaperQA CLI Script draft (ib.bsb.br/clipaper-qa) #286

Open · marioseixas opened this issue Jun 17, 2024 · 0 comments
Kind of works: a draft CLI that indexes a directory of PDFs with paperqa and answers questions passed on the command line, read from a file, or piped in on stdin.

from paperqa import Docs, PromptCollection, OpenAILLMModel
import argparse
import sys
import pickle
import json
import glob
from pathlib import Path
from typing import Dict, List

def main():
    parser = argparse.ArgumentParser(description="PaperQA CLI")
    parser.add_argument(
        "--pdf_dir",
        type=str,
        default="/home/mario/gpt-researcher/BIBTEX",
        help="Directory containing PDF files",
    )
    parser.add_argument(
        "--question", type=str, help="Question to ask (use quotes for multi-word questions)"
    )
    parser.add_argument(
        "--questions_file",
        type=str,
        help="Path to a text file containing a list of questions, one per line",
    )
    parser.add_argument(
        "--save_embeddings",
        type=str,
        default="paperqa_embeddings.pkl",
        help="Path to save the Docs object with embeddings",
    )
    parser.add_argument(
        "--load_embeddings",
        type=str,
        help="Path to load a pre-saved Docs object with embeddings",
    )
    parser.add_argument(
        "--save_answers",
        type=str,
        default="answers.txt",
        help="Path to save the answers to a text file",
    )
    parser.add_argument(
        "--save_data",
        type=str,
        default="paperqa_data.pkl",
        help="Path to save all data (answers, Docs object) to a pickle file",
    )
    parser.add_argument(
        "--save_embeddings_txt",
        type=str,
        help="Path to save embeddings in a text file (for debugging/analysis)",
    )
    parser.add_argument(
        "--llm",
        type=str,
        default="gpt-4o",
        help="OpenAI LLM model name (e.g., 'gpt-4o', 'gpt-4-turbo')",
    )
    parser.add_argument(
        "--embedding",
        type=str,
        default="text-embedding-3-large",
        help="OpenAI embedding model name (e.g., 'text-embedding-3-large')",
    )
    parser.add_argument(
        "--summary_llm",
        type=str,
        help="OpenAI LLM model name for summarization (defaults to same as --llm)",
    )
    parser.add_argument(
        "--k",
        type=int,
        default=50,
        help="Number of top-k results to retrieve for each query",
    )
    parser.add_argument(
        "--max_sources",
        type=int,
        default=50,
        help="Maximum number of sources to use in the final answer",
    )
    parser.add_argument(
        "--chunk_chars",
        type=int,
        default=3200,
        help="Number of characters per chunk when splitting documents",
    )
    parser.add_argument(
        "--overlap",
        type=int,
        default=1600,
        help="Number of overlapping characters between chunks",
    )
    parser.add_argument(
        "--json_summary",
        action="store_true",
        help="Use JSON format for summarization (requires GPT-3.5-turbo or later)",
    )
    parser.add_argument(
        "--detailed_citations",
        action="store_true",
        help="Include full citations in the context",
    )
    parser.add_argument(
        "--disable_vector_search",
        action="store_true",
        help="Disable vector search and use all text chunks",
    )
    parser.add_argument(
        "--key_filter",
        action="store_true",
        help="Filter evidence by document keys based on question similarity",
    )
    parser.add_argument(
        "--custom_prompt_file",
        type=str,
        help="Path to a JSON file containing custom prompts",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=10,  # Increased batch size for efficiency
        help="Batch size for processing documents (adjust for performance)",
    )
    parser.add_argument(
        "--max_concurrent",
        type=int,
        default=4,
        help="Maximum number of concurrent requests (adjust for performance)",
    )
    parser.add_argument(
        "--strip_citations",
        action="store_true",
        help="Strip citations from the generated answers",
    )
    parser.add_argument(
        "--jit_texts_index",
        action="store_true",
        help="Enable just-in-time text indexing",
    )
    parser.add_argument(
        "--answer_length",
        type=str,
        default="about 100 words",
        help="Specify the desired length of the answer (e.g., 'about 200 words')",
    )
    args = parser.parse_args()

    # Load custom prompts from JSON file if provided
    custom_prompts: Dict[str, str] = {}
    if args.custom_prompt_file:
        try:
            with open(args.custom_prompt_file, "r", encoding="utf-8") as f:
                custom_prompts = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError) as e:
            print(f"Error loading custom prompts: {e}", file=sys.stderr)
            sys.exit(1)
    # Start from the default PromptCollection, then override any fields
    # supplied in the custom prompts JSON (keys must match attribute names)
    prompts = PromptCollection()
    if custom_prompts:
        for key, value in custom_prompts.items():
            setattr(prompts, key, value)
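    # For reference, a minimal custom_prompts.json might look like this
    # (hypothetical template: "qa" is a PromptCollection field, and the
    # template should keep the {context} and {question} placeholders that
    # paperqa fills in at query time):
    #
    # {
    #   "qa": "Using the context below, answer: {question}\n\n{context}\n\nAnswer:"
    # }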

    # Build model objects so the temperature setting is actually applied;
    # passing plain model-name strings to Docs would drop it
    llm_model = OpenAILLMModel(config={"model": args.llm, "temperature": 0.1})
    summary_llm_model = (
        OpenAILLMModel(config={"model": args.summary_llm, "temperature": 0.1})
        if args.summary_llm
        else llm_model
    )

    # Create the Docs object (the embedding model is configured by name).
    # Note: batch_size, overlap, and a few of the flags parsed above are
    # not wired into any paperqa call yet in this draft
    docs = Docs(
        llm_model=llm_model,
        summary_llm_model=summary_llm_model,
        embedding=args.embedding,
        prompts=prompts,
        max_concurrent=args.max_concurrent,
        jit_texts_index=args.jit_texts_index,
    )

    # Load existing embeddings if provided
    if args.load_embeddings:
        try:
            with open(args.load_embeddings, "rb") as f:
                docs = pickle.load(f)
            docs.set_client()  # Required after loading from pickle
        except FileNotFoundError as e:
            print(f"Error loading embeddings: {e}", file=sys.stderr)
            sys.exit(1)

    # Add PDF documents, skipping files already present. Keying documents by
    # file path (dockey=pdf_file) makes the duplicate check meaningful: the
    # default dockey is a content hash, which never equals a path
    pdf_dir = Path(args.pdf_dir)
    pdf_files = glob.glob(str(pdf_dir / "*.pdf"))
    for pdf_file in pdf_files:
        if pdf_file not in docs.docs:
            docs.add(pdf_file, dockey=pdf_file, chunk_chars=args.chunk_chars)

    # Get questions from the command line, a file, or standard input,
    # skipping blank lines so we never issue empty queries
    questions: List[str] = []
    if args.question:
        questions.append(args.question)
    if args.questions_file:
        try:
            with open(args.questions_file, "r", encoding="utf-8") as f:
                questions.extend(line.strip() for line in f if line.strip())
        except FileNotFoundError as e:
            print(f"Error reading questions file: {e}", file=sys.stderr)
            sys.exit(1)
    if not questions:
        questions = [line.strip() for line in sys.stdin if line.strip()]

    # Get answers for each question, wiring the answer-length and
    # key-filter options into Docs.query
    answers = []
    for question in questions:
        answer = docs.query(
            question,
            k=args.k,
            max_sources=args.max_sources,
            length_prompt=args.answer_length,
            key_filter=True if args.key_filter else None,
        )
        print(answer.formatted_answer)  # echo each answer as it is produced
        answers.append(answer.formatted_answer)

    # Save answers to file
    with open(args.save_answers, "w", encoding="utf-8") as f:
        for answer in answers:
            f.write(answer + "\n")

    # Save embeddings to file
    if args.save_embeddings:
        with open(args.save_embeddings, "wb") as f:
            pickle.dump(docs, f)

    # Save all data to file
    if args.save_data:
        with open(args.save_data, "wb") as f:
            pickle.dump({"docs": docs, "answers": answers}, f)

    # Save embeddings to a text file (for debugging/analysis). Embeddings
    # live on the text chunks (docs.texts), not on the Doc objects
    if args.save_embeddings_txt:
        with open(args.save_embeddings_txt, "w", encoding="utf-8") as f:
            for text in docs.texts:
                if text.embedding is not None:
                    f.write(f"{text.name}\t{text.embedding}\n")

if __name__ == "__main__":
    main()
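
Example run (script name and paths are placeholders):

python clipaper_qa.py --pdf_dir ./papers --question "What are the main limitations reported?" --k 30 --max_sources 10

or with questions piped in:

cat questions.txt | python clipaper_qa.py --pdf_dir ./papers --save_answers out.txt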

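To pick a session back up later, something like this should work (file names match the script's defaults; set_client() re-attaches the OpenAI client, which is not pickled, as the paperqa README describes):

import pickle

with open("paperqa_data.pkl", "rb") as f:
    data = pickle.load(f)
docs, answers = data["docs"], data["answers"]
docs.set_client()  # re-attach the OpenAI client after unpickling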