@jhyearsley
Contributor

Added a MongoDB Atlas retrieval model. It currently supports only the OpenAI embedding provider.

@okhat
Collaborator

okhat commented Dec 28, 2023

Hey @jhyearsley, I want to merge this. It's probably affected by the recent OpenAI update, though. Are any changes needed here?

@DanielUH2019 some changes are probably still needed for 1-2 of the retrievers too, per my comments in your (merged) PR.

@jhyearsley
Contributor Author

jhyearsley commented Dec 28, 2023

@okhat thanks for the ping, it should be good to go now. I reran the script I've been testing against and it's working as expected.

@jhyearsley
Contributor Author

jhyearsley commented Dec 28, 2023

Attaching the script in case others find it helpful, want to test it themselves, or have feedback to share. As written, the script implements RAG in fewer than 30 lines of code!

The script requires a MongoDB Atlas cluster that stores data in the namespace kb.embedded_content. The embedded_content collection holds chunked data that was embedded with the OpenAI text-embedding-ada-002 model. The question in the script is embedded with the same model, and Atlas Vector Search retrieves its approximate nearest neighbors (k=5 here). Those passages are then passed to the GPT model as domain-specific context for answering the question.
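
For anyone curious what the retrieval step might look like at the database level, here is a minimal sketch of an Atlas Vector Search aggregation pipeline. The helper name, the `embedding` field name, and the `numCandidates` oversampling factor are assumptions for illustration, not necessarily what MongoDBAtlasRM does internally.

```python
from typing import Any, Dict, List


def build_vector_search_pipeline(
    index_name: str,
    query_vector: List[float],
    k: int = 5,
    path: str = "embedding",  # assumed name of the field holding stored embeddings
    candidates_per_result: int = 10,  # assumed oversampling factor for the ANN search
) -> List[Dict[str, Any]]:
    """Build a $vectorSearch aggregation pipeline that returns the top-k chunks."""
    return [
        {
            "$vectorSearch": {
                "index": index_name,
                "path": path,
                "queryVector": query_vector,
                # Atlas considers this many candidates before keeping `limit` results.
                "numCandidates": k * candidates_per_result,
                "limit": k,
            }
        },
        {
            # Keep only the chunk text and its similarity score.
            "$project": {
                "_id": 0,
                "text": 1,
                "score": {"$meta": "vectorSearchScore"},
            }
        },
    ]


# A toy 3-dimensional vector stands in for a real 1536-dimensional ada-002 embedding.
pipeline = build_vector_search_pipeline(
    "embedded_content_vector_index", [0.1, 0.2, 0.3]
)
print(pipeline[0]["$vectorSearch"]["limit"])
```

With pymongo, a pipeline like this would be passed to `collection.aggregate(pipeline)` on the kb.embedded_content collection.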

Really cool to see the power and simplicity of DSPy in action 🎉 thanks @okhat for all the recent help!

import dspy
from dspy.retrieve.mongodb_atlas_rm import MongoDBAtlasRM

lm = dspy.OpenAI(model="gpt-3.5-turbo", model_type="chat", max_tokens=1500)
rm = MongoDBAtlasRM(
    db_name="kb",
    collection_name="embedded_content",
    index_name="embedded_content_vector_index",
    k=5,
)

dspy.settings.configure(lm=lm, rm=rm)


## Basic Q&A
class BasicQA(dspy.Signature):
    """Answer MongoDB questions with paragraph answers."""

    question = dspy.InputField()
    answer = dspy.OutputField(
        desc="Focus on MongoDB. Answers should be 1 to 3 paragraphs"
    )


question = "Why is MongoDB Atlas the best way to run MongoDB in the cloud?"
generate_basic_answer = dspy.Predict(BasicQA)
pred = generate_basic_answer(question=question)
print(f"Question: {question}")
print(f"Predicted Answer: {pred.answer}")
lm.inspect_history(n=1)

## Chain of Thought
generate_answer_with_chain_of_thought = dspy.ChainOfThought(BasicQA)
pred = generate_answer_with_chain_of_thought(question=question)
print(f"Question: {question}")
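# Drop everything up to the first period; [-1] avoids an IndexError when the rationale has no period.
print(f"Thought: {pred.rationale.split('.', 1)[-1].strip()}")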
print(f"Predicted Answer: {pred.answer}")

## Inspecting semantic retrieval results
topK_passages = rm(question).passages
print(f"Top {rm.k} passages for question: {question} \n", "-" * 30, "\n")
for idx, passage in enumerate(topK_passages):
    print(f"{idx+1}]", passage["text"], "\n")


## RAG
class GenerateAnswer(dspy.Signature):
    """Answer MongoDB questions with thorough and in-depth answers."""

    context = dspy.InputField(
        desc="may contain relevant data for answering the question"
    )
    question = dspy.InputField()
    answer = dspy.OutputField(desc="Answers should be 1 to 3 paragraphs")


class RAG(dspy.Module):
    def __init__(self):
        super().__init__()

        # Reuse the MongoDBAtlasRM retriever configured above (already set to k=5).
        self.retrieve = rm
        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)

    def forward(self, question):
        context = self.retrieve(question).passages
        prediction = self.generate_answer(context=context, question=question)
        return dspy.Prediction(context=context, answer=prediction.answer)


rag = RAG()
predRAG = rag(question=question)
print(f"Question: {question}")
print(f"Answer: {predRAG.answer}")
lm.inspect_history(n=1)
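
In the script above each retrieved passage is read with `passage["text"]`, so the context handed to the LM is a list of dicts rather than plain strings. If you would rather pass only the text, a small helper along these lines might do it. `passages_to_text` is a hypothetical name, and the `score` key in the sample data is an assumption about what else a passage may carry.

```python
from typing import Any, Dict, List, Union

Passage = Union[str, Dict[str, Any]]


def passages_to_text(passages: List[Passage], key: str = "text") -> List[str]:
    """Return plain-text strings from passages that may be dicts or strings."""
    texts = []
    for passage in passages:
        if isinstance(passage, dict):
            # Dict passages: pull out the text field, falling back to an empty string.
            texts.append(str(passage.get(key, "")))
        else:
            texts.append(str(passage))
    return texts


sample = [
    {"text": "Atlas is a managed service.", "score": 0.93},
    "plain string passage",
]
print(passages_to_text(sample))
```

Inside `RAG.forward`, that would read `context = passages_to_text(self.retrieve(question).passages)`.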

@okhat okhat merged commit f94861f into stanfordnlp:main Jan 13, 2024