The reference for this notebook is [here](https://medium.com/aimpact-all-things-ai/prompt-like-a-pro-using-dspy-a-guide-to-build-a-better-local-rag-model-using-dspy-qdrant-and-d8011a3942d9)

In [None]:
from dspy.datasets import HotPotQA
import dspy

## Load the dataset

In [None]:
# Load the HotPotQA training split: 1000 training examples, no test examples, fixed seed.
dataset = HotPotQA(
    train_seed=1,
    test_size=0,
    train_size=1000,
)
# Mark 'question' as the input field; every other field is treated as a label and/or metadata.
dataset = [example.with_inputs('question') for example in dataset.train]
print(len(dataset))

In [None]:
from dspy.retrieve.qdrant_rm import QdrantRM
from qdrant_client import QdrantClient

In [None]:
# In-memory Qdrant instance.
qdrant_client = QdrantClient(":memory:")
# One document per example, formatted as "<question> -> <answer>".
docs = [" -> ".join((example.question, example.answer)) for example in dataset]
# Sequential integer ids, one per document.
ids = list(range(len(docs)))

In [None]:
# Index the documents into the "hotpotqa" collection under the integer ids.
# NOTE(review): presumably add() embeds the raw text itself through qdrant-client's
# FastEmbed integration — confirm that optional dependency is installed.
qdrant_client.add(
    collection_name="hotpotqa",
    documents=docs,
    ids=ids
    )

## Define the retriever

In [None]:
# DSPy retriever over the "hotpotqa" collection of the Qdrant client, constructed with k=3.
qdrant_retriever_model = QdrantRM("hotpotqa", qdrant_client, k=3)

In [None]:
# Register the Qdrant retriever as the retrieval model (rm) in DSPy's global settings.
dspy.settings.configure(rm=qdrant_retriever_model)

In [None]:
def get_top_passages(question, k=3):
    """Print the top-``k`` passages retrieved for ``question``.

    Retrieval goes through ``dspy.Retrieve``, so a retrieval model must
    already be registered with ``dspy.settings.configure(rm=...)``.

    Args:
        question: Query string handed to the retriever.
        k: Number of passages to retrieve and print. Defaults to 3, the
            value that used to be hard-coded in two places.

    Returns:
        None. The passages are written to stdout, numbered from 1.
    """
    retrieve = dspy.Retrieve(k=k)
    # k is passed per call as well (as the original did), so the request does
    # not depend on how the retriever falls back to its constructor value.
    top_passages = retrieve(question, k=k).passages
    print(f"Top {retrieve.k} passages for question: {question} \n", '-' * 30, '\n')
    for rank, passage in enumerate(top_passages, start=1):
        print(f'{rank}]', passage, '\n')

In [None]:
# Pick one example from the loaded dataset (index 100) as a smoke test for retrieval.
dev_example = dataset[100]

# Print the passages retrieved for that example's question.
get_top_passages(dev_example.question)

## Initialize Llama3 Model Using DSPy-Ollama Integration

In [None]:
# Local Llama 3 model served through Ollama, with explicit sampling settings.
ollama_model = dspy.OllamaLocal(
    model="llama3",
    model_type='text',
    max_tokens=350,
    temperature=0.1,
    top_p=0.8,
    frequency_penalty=1.17,
    top_k=40,
)

In [None]:
# Smoke test: call the language model directly with a free-form prompt.
ollama_model("tell me about interstellar's plot")

In [None]:
# Configure DSPy's global settings with both the language model (lm) and the retriever (rm).
dspy.settings.configure(lm=ollama_model, rm=qdrant_retriever_model)

## Define Signatures for Input and Output

### TODO 
- Add DSPy assertions to enforce that the output is a list of classes

In [None]:
# class RetrieveLabelCandidates(dspy.Signature):
#     """Retrieve relevant label candidates for a given text."""
#     text = dspy.InputField()
#     label_candidates = dspy.OutputField(desc="List of potential labels for the given text")

# DSPy signature for the classification step: inputs `text` and
# `label_candidates`, output `labels`.
# NOTE(review): DSPy builds the prompt from a signature's docstring and field
# `desc` strings, so treat them as runtime text — confirm before rewording.
class ClassifyText(dspy.Signature):
    """Classify the text into multiple labels from the given candidates."""
    # Input: the text to classify.
    text = dspy.InputField()
    # Input: the labels the model may choose from.
    label_candidates = dspy.InputField(desc="List of possible labels for the text")
    # Output: the labels the model judges applicable.
    labels = dspy.OutputField(desc="List of applicable labels for the text")

In [None]:
# Example instantiation of the signature with concrete field values.
# The original cell left every keyword value empty, which is a SyntaxError.
ct = ClassifyText(
    text="The central bank raised interest rates to curb inflation.",
    label_candidates="economics, politics, sports, technology",
    labels="economics, politics",
)

In [None]:
# class GenerateAnswer(dspy.Signature):
#     """Answer questions with short factoid answers."""

#     context = dspy.InputField(desc="may contain relevant facts or answer keywords")
#     question = dspy.InputField()
#     answer = dspy.OutputField(desc="an answer between 1 to 10 words")

In [None]:
# ga = GenerateAnswer(context="My name is sachin and I like writing blogs", question="What is my name?", answer="Sachin")
# print(ga.model_construct)

## Create a DSPy RAG Multi-Label Classification Module

In [None]:
class RAGMultiLabelClassifier(dspy.Module):
    """Multi-label text classifier built from DSPy retrieve/predict steps.

    ``forward`` retrieves passages for the text, asks the LM to propose
    candidate labels, keeps at most ``num_candidates`` of them, and asks the
    LM to choose the applicable labels via the ``ClassifyText`` signature.

    Args:
        num_candidates: Number of passages to retrieve and the maximum number
            of candidate labels forwarded to the classification step.
    """

    def __init__(self, num_candidates=5):
        super().__init__()
        self.retrieve = dspy.Retrieve(k=num_candidates)
        # forward() calls self.generate_candidates, which was never created and
        # raised AttributeError. The inline signature uses the same field names
        # as the commented-out RetrieveLabelCandidates signature.
        self.generate_candidates = dspy.Predict("text -> label_candidates")
        self.classify = dspy.Predict(ClassifyText)
        self.num_candidates = num_candidates

    @staticmethod
    def _parse_candidates(raw):
        """Turn the LM's candidate output into a list of non-empty label strings."""
        if raw is None:
            return []
        if isinstance(raw, str):
            # Completions are free text; accept comma- or newline-separated labels.
            pieces = raw.replace("\n", ",").split(",")
        else:
            pieces = [str(item) for item in raw]
        cleaned = (piece.strip(" \t-*") for piece in pieces)
        return [label for label in cleaned if label]

    def forward(self, text):
        """Return the labels predicted for ``text``.

        Args:
            text: The text to classify.

        Returns:
            The ``labels`` field of the classification prediction.
        """
        # Retrieve relevant documents.
        # NOTE(review): the passages are not passed to either predictor yet;
        # decide whether they should be given to the LM as context.
        retrieved_docs = self.retrieve(text).passages

        # Generate label candidates. Truncate by label, not by character:
        # slicing the raw completion string kept only its first few characters.
        candidate_result = self.generate_candidates(text=text)
        candidates = self._parse_candidates(candidate_result.label_candidates)
        # Pass a plain string so the input field formats the same way
        # regardless of how the DSPy version renders list inputs.
        label_candidates = ", ".join(candidates[:self.num_candidates])

        # Classify text
        classification_result = self.classify(text=text, label_candidates=label_candidates)
        return classification_result.labels

In [None]:
# class RAG(dspy.Module):
#     def __init__(self, num_passages=3):
#         super().__init__()
#         self.retrieve = dspy.Retrieve(k=num_passages)
#         self.generate_answer = dspy.ChainOfThought(GenerateAnswer)

#     def forward(self, question):
#         context = self.retrieve(question).passages
#         prediction = self.generate_answer(context=context, question=question)
#         return dspy.Prediction(context=context, answer=prediction.answer)

In [None]:
uncompiled_rag = RAG()

In [None]:
my_question = "is Bank of America Tower taller than empire state building?"
response = uncompiled_rag(my_question)
print(response.answer)

In [None]:
ollama_model.inspect_history(n=1)

In [None]:
my_question = "Was George Alan O'Dowd the most popular in the late 2000s with his rock band?"
response = uncompiled_rag(my_question)
print(response.answer)

In [None]:
ollama_model.inspect_history(n=1)