In [1]:
import dspy
import os
from langchain.text_splitter import SentenceTransformersTokenTextSplitter
import chromadb
from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction, SentenceTransformerEmbeddingFunction

###  Utilities Functions, we can move to common python file so that can be used elsewhere

In [16]:
def format_text_to_width(string, max_width=72):
    """
    Formats text to fit within a specified width, ensuring that words are not broken.

    This function wraps text at the last whitespace character before the max_width limit,
    ensuring that each line does not exceed this width. It applies this logic recursively
    to the entire text string.

    Parameters:
    - string (str): The text to be formatted.
    - max_width (int): The maximum number of characters allowed per line before wrapping.

    Returns:
    - str: The formatted text with appropriate line breaks.
    """
    if len(string) <= max_width:
        return string
    else:
        # Find the last space before the max width to avoid breaking words
        break_point = string[:max_width].rsplit(' ', 1)[0]

        # Handle the case where a single word is longer than the max width
        if not break_point:
            break_point = string[:max_width]

        # Recursively format the remainder of the text
        return break_point + '\n' + format_text_to_width(string[len(break_point)+1:], max_width)

In [2]:
# Load the model
turbo = dspy.OpenAI(model='gpt-3.5-turbo')

In [3]:
# Read the test data
import pandas as pd
df = pd.read_csv('medical_tc_test.csv')


# If you need to convert the DataFrame to a raw text format for whatever reason:
text = df['medical_abstract'].to_string(index=False)
dspy.settings.configure(lm=turbo)
text

'Obstructive sleep apnea following topical oroph...\nNeutrophil function and pyogenic infections in ...\nA phase II study of combined methotrexate and t...\nFlow cytometric DNA analysis of parathyroid tum...\nParaneoplastic vasculitic neuropathy: a treatab...\nTreatment of childhood angiomatous diseases wit...\nExpression of major histocompatibility complex ...\nQuestionable role of CNS radioprophylaxis in th...\nReversibility of hepatic fibrosis in experiment...\nCurrent status of duplex Doppler ultrasound in ...\nThe importance of congenital hypertrophy of the...\nHuman papillomavirus in women with vulvar intra...\nGentamicin iontophoresis in the treatment of ba...\nRepeat hepatic resection for primary and metast...\nEvidence for intraluminal Ca++ regulatory site ...\nGlutamic acid and gamma-aminobutyric acid neuro...\nA useful technique for measurement of back stre...\nThe natural history of ultraviolet radiation-in...\nHereditary internal anal sphincter myopathy cau...\nImmune resp

In [4]:
from langchain.text_splitter import RecursiveCharacterTextSplitter

character_splitter = RecursiveCharacterTextSplitter(
    separators=["\n\n", "\n", ". ", " ", ""],
    chunk_size=256,
    chunk_overlap=0
)
character_split_texts = character_splitter.split_text(text)

print(f"\nTotal chunks: {len(character_split_texts)}\n")


Total chunks: 578



In [5]:
token_splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=0, tokens_per_chunk=256)

token_split_texts = []
for text in character_split_texts:
    token_split_texts += token_splitter.split_text(text)

print(f"\nTotal chunks: {len(token_split_texts)}")


Total chunks: 578


In [6]:
token_split_texts[1]

'treatment of childhood angiomatous diseases wit... expression of major histocompatibility complex... questionable role of cns radioprophylaxis in th... reversibility of hepatic fibrosis in experiment... current status of duplex doppler ultrasound in...'

In [7]:
embedding_function = SentenceTransformerEmbeddingFunction()
print("Length of embedding:")
print(len(embedding_function([token_split_texts[0]])[0]))

Length of embedding:


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

384


In [8]:
chroma_client = chromadb.PersistentClient("local_chroma.db")

In [9]:
# Create a new collection

chroma_collection = chroma_client.get_or_create_collection("test-medical_abstract_collection", embedding_function=embedding_function)

ids = [str(i) for i in range(len(token_split_texts))]

In [10]:
chroma_collection.add(ids=ids, documents=token_split_texts)

Batches:   0%|          | 0/19 [00:00<?, ?it/s]

In [11]:
chroma_client.list_collections()

[Collection(name=test-medical_abstract_collection)]

In [None]:
chroma_client.list_collections()

In [12]:
chroma_collection.peek(1)

{'ids': ['0'],
 'embeddings': [[-0.05739233270287514,
   -0.026065049692988396,
   0.025520015507936478,
   -0.01608973555266857,
   -0.1337699145078659,
   -0.03625405579805374,
   -0.03604433313012123,
   0.1593852937221527,
   -0.019802089780569077,
   0.009291920810937881,
   -0.015202006325125694,
   -0.017199542373418808,
   0.0726933404803276,
   0.07681836187839508,
   -0.0880700945854187,
   0.04864553362131119,
   -0.02266610972583294,
   0.029413221403956413,
   0.1520378738641739,
   -0.01761646941304207,
   -0.004725049715489149,
   0.1467331349849701,
   -0.023301294073462486,
   -0.04613426700234413,
   -0.009664880111813545,
   -0.01699734665453434,
   -0.022988280281424522,
   0.046910300850868225,
   0.0511341318488121,
   0.019145147874951363,
   -0.06288857012987137,
   0.09893061220645905,
   -0.14823494851589203,
   0.006717968266457319,
   0.122281014919281,
   0.06078067049384117,
   -0.09140251576900482,
   0.049731604754924774,
   -0.053235068917274475,
   0.0

In [13]:
query = "What is obstructive sleep apnea?"

results = chroma_collection.query(query_texts=[query], n_results=2)
retrieved_documents = results['documents'][0]
print(f"Query: {query}")
print(f"\nRetrieved {len(retrieved_documents)} documents\n")

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Query: What is obstructive sleep apnea?

Retrieved 2 documents



In [18]:
for docs in retrieved_documents:
    print(format_text_to_width(docs))
    print("\n")

systemic hypertension in sleep apnea syndrome.... interruption of
critical aortoiliac collateral... estimating and testing an index of
responsivene... prophylaxis of aphakic cystoid macular edema wi...
carotid - esophageal fistula following a penetrat...


otolaryngologic management of patients with sub... esophageal
adenocarcinoma in a patient with sur... serological arguments for
classifying raynaud's... management of chronic middle ear effusion
with... arterial oxygen saturation during induction of...




In [19]:
turbo = dspy.OpenAI(model='gpt-3.5-turbo')
dspy.settings.configure(lm=turbo)

In [20]:
class GenerateAnswer(dspy.Signature):
    """Answer questions with short factoid answers."""

    context = dspy.InputField(desc="may contain relevant facts")
    question = dspy.InputField()
    answer = dspy.OutputField(desc="Explain with words between 1 and 10 words")

In [27]:
chroma_client.list_collections()

[Collection(name=test-medical_abstract_collection)]

In [28]:
# Modifying the default RAG module because it doesn't work with the SentenceTransformerEmbeddingFunction
class MedicalAbstractRag(dspy.Module):
    def __init__(self, num_passages=3):
        super().__init__()
        self.chroma_collection = chroma_client.get_collection("test-medical_abstract_collection")
        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
        self.num_passages = num_passages
    
    def forward(self, question):
        context = self.chroma_collection.query(query_texts=[question], n_results=self.num_passages)
        context = context['documents']
        prediction = self.generate_answer(context=context, question=question)
        return dspy.Prediction(context=context, answer=prediction.answer)

In [29]:
rag = MedicalAbstractRag(num_passages=3)

In [30]:
question = "What is sleep apnea?"
rag(question)

C:\Users\Hem Chandra\.cache\chroma\onnx_models\all-MiniLM-L6-v2\onnx.tar.gz: 100%|██████████| 79.3M/79.3M [02:49<00:00, 490kiB/s] 


Prediction(
    context=[['systemic hypertension in sleep apnea syndrome.... interruption of critical aortoiliac collateral... estimating and testing an index of responsivene... prophylaxis of aphakic cystoid macular edema wi... carotid - esophageal fistula following a penetrat...', 'predicting recurrence time of esophageal carcin... the use of ultrasound in evaluating neurologic... control of total peripheral resistance during h... disturbance in daily sleep / wake patterns in pat... percent tumor necrosis as a predictor of treatm...', 'bretylium tosylate versus lidocaine in experime... wrist flexion as an adjunct to the diagnosis of... apneic oxygenation in apnea tests for brain dea... response of spinal cord blood flow and motor an... extensive aneurysmal bone cyst of the mandible :...']],
    answer='Breathing disorder during sleep.'
)

In [31]:
turbo.inspect_history(n=1)





Answer questions with short factoid answers.

---

Follow the following format.

Context: may contain relevant facts

Question: ${question}

Reasoning: Let's think step by step in order to ${produce the answer}. We ...

Answer: Explain with words between 1 and 10 words

---

Context: «['systemic hypertension in sleep apnea syndrome.... interruption of critical aortoiliac collateral... estimating and testing an index of responsivene... prophylaxis of aphakic cystoid macular edema wi... carotid - esophageal fistula following a penetrat...', 'predicting recurrence time of esophageal carcin... the use of ultrasound in evaluating neurologic... control of total peripheral resistance during h... disturbance in daily sleep / wake patterns in pat... percent tumor necrosis as a predictor of treatm...', 'bretylium tosylate versus lidocaine in experime... wrist flexion as an adjunct to the diagnosis of... apneic oxygenation in apnea tests for brain dea... response of spinal cord blood flow and

## Now using the RAG used in the db_retriever_module.py

In [32]:
import sys
import os

# Add the parent directory to sys.path to allow importing from the same level as 'Evaluation'
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))

# Now you can import the module abc as if it were in the same directory
import db_retriever_module


In [None]:
chroma_rm = db_retriever_module.ChromadbRetrieverModule(db_collection_name="medical_abstract_data_collection",
                                                        persist_directory="local_chroma.db", 
                                                        local_embed_model="sentence-transformers/paraphrase-MiniLM-L6-v2",
                                                        api_key=os.environ["OPENAI_API_KEY"])

In [None]:
dspy.settings.configure(lm=turbo, rm=chroma_rm)

In [None]:
class RAG(dspy.Module):
    def __init__(self, num_passages=3):
        super().__init__()
        self.retrieve = dspy.Retrieve(k=num_passages)
        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
    
    def forward(self, question):
        context = self.retrieve(question).passages
        prediction = self.generate_answer(context=context, question=question)
        return dspy.Prediction(context=context, answer=prediction.answer)

In [None]:
rag = RAG(num_passages=3)
question = "What is sleep apnea?"
rag(question)

In [35]:
from dspy.datasets import HotPotQA

# Load the dataset.
dataset = HotPotQA(train_seed=1, train_size=20, eval_seed=2023, dev_size=50, test_size=0)

# Tell DSPy that the 'question' field is the input. Any other fields are labels and/or metadata.
trainset = [x.with_inputs('question') for x in dataset.train]
devset = [x.with_inputs('question') for x in dataset.dev]

len(trainset), len(devset)

Downloading builder script:   0%|          | 0.00/6.42k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/9.19k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/566M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/47.5M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/46.2M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/90447 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/7405 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/7405 [00:00<?, ? examples/s]

  table = cls._concat_blocks(blocks, axis=0)


(20, 50)

In [None]:
from dspy.teleprompt import BootstrapFewShot

# Validation logic: check that the predicted answer is correct.
# Also check that the retrieved context does actually contain that answer.
def check_answer_and_context_validity(example, pred, trace=None):
    answer_EM = dspy.evaluate.answer_exact_match(example, pred)
    answer_PM = dspy.evaluate.answer_passage_match(example, pred)
    return answer_EM and answer_PM

# Set up a basic teleprompter, which will compile our RAG program.
teleprompter = BootstrapFewShot(metric=check_answer_and_context_validity)

# Compile!
compiled_rag = teleprompter.compile(RAG(), trainset=trainset)

In [None]:
# Ask any question you like to this simple RAG program.
my_question = "What is obstructive sleep apnea?"

# Get the prediction. This contains `pred.context` and `pred.answer`.
pred = compiled_rag(my_question)

# Print the contexts and the answer.
print(f"Question: {my_question}")
print(f"Predicted Answer: {pred.answer}")
print(f"Retrieved Contexts (truncated): {[c[:200] + '...' for c in pred.context]}")

In [None]:
turbo.inspect_history(n=1)

In [None]:
for name, parameter in compiled_rag.named_predictors():
    print(name)
    print(parameter.demos[0])
    print()