In [None]:
!pip install langchain langchain-community langchain-openai transformers sentence-transformers datasets torch chromadb

In [4]:
from langchain.schema import Document
from datasets import load_dataset

In [None]:
# !pip install fsspec==2023.6.0

In [2]:
from datasets import load_dataset

ds = load_dataset("keivalya/MedQuad-MedicalQnADataset")

In [5]:
def _qa_to_document(doc_id, row):
  """Wrap one MedQuad row as a LangChain Document.

  The page content holds the question and answer text; the metadata keeps the
  raw fields plus a running ``doc_id`` and a fixed ``type`` tag.
  """
  return Document(
      page_content=f"Question: {row['Question']}\nAnswer: {row['Answer']}",
      metadata={
          "doc_id": doc_id,
          "question": row['Question'],
          "answer": row['Answer'],
          "question_type": row['qtype'],
          "type": "qa_pair",
      },
  )


# One Document per training row, numbered in dataset order.
documents = [_qa_to_document(doc_id, row) for doc_id, row in enumerate(ds['train'])]

In [40]:
import random

# Spot-check one randomly chosen Q&A document. `random` is imported here
# because its only other import is in a later cell, so this cell raised
# NameError on a fresh kernel (Restart & Run All).
random_index = random.randrange(len(documents))
print(f"{documents[random_index].page_content}...")

Question: How to diagnose Succinic semialdehyde dehydrogenase deficiency ?
Answer: How is succinic semialdehyde dehydrogenase deficiency diagnosed? The diagnosis of succinic semialdehyde dehydrogenase (SSADH) deficiency is based upon a thorough clinical exam, the identification of features consistent with the condition, and a variety of specialized tests. SSADH deficiency may first be suspected in late infancy or early childhood in individuals who have encephalopathy, a state in which brain function or structure is altered. The encephalopathy may be characterized by cognitive impairment; language deficit; poor muscle tone (hypotonia); seizures; decreased reflexes (hyporeflexia); and/or difficulty coordinating movements (ataxia). The diagnosis may be further suspected if urine organic acid analysis (a test that provides information about the substances the body discards through the urine) shows the presence of 4-hydroxybutyric acid. The diagnosis can be confirmed by an enzyme test showi

In [28]:
import random

# Playful status lines; one is printed at random as a "thinking" indicator.
# One message per line — the backslash keeps the block from starting with "\n".
spinner_messages = """\
Searching the universe...
Consulting the medical oracles...
Paging Dr. AI...
Googling responsibly...
Checking the medical textbooks...
Assembling a team of virtual doctors...
Running with scissors (just kidding)...
Putting on my lab coat...
Sterilizing the stethoscope...
Counting imaginary pills...
Reading the fine print on the prescription...
Asking the mitochondria (it's the powerhouse)...
Checking WebMD (not really)...
Looking for my AI degree...
Washing my hands for 20 seconds...
Trying not to diagnose you with everything...""".splitlines()

print(random.choice(spinner_messages))

Reading the fine print on the prescription...


In [15]:
documents[434].page_content

'Question: What is (are) Polymyositis ?\nAnswer: Polymyositis is one of a group of muscle diseases known as the inflammatory myopathies, which are characterized by chronic muscle inflammation accompanied by muscle weakness. Polymyositis affects skeletal muscles (those involved with making movement) on both sides of the body. It is rarely seen in persons under age 18; most cases are in adults between the ages of 31 and 60. Progressive muscle weakness starts in the proximal muscles (muscles closest to the trunk of the body) which eventually leads to difficulties climbing stairs, rising from a seated position, lifting objects, or reaching overhead. People with polymyositis may also experience arthritis, shortness of breath, difficulty swallowing and speaking, and heart arrhythmias. In some cases of polymyositis, distal muscles (muscles further away from the trunk of the body, such as those in the forearms and around the ankles and wrists) may be affected as the disease progresses. Polymyo

In [None]:
!pip install python-dotenv

In [None]:
from dotenv import load_dotenv
load_dotenv()

In [None]:
import shutil
import os
# Must match the persist_directory passed to Chroma.from_documents when the
# vectorstore is built. The old path ("./medical_vectordb") never matched it,
# so the stale store was not cleared, and rebuilding added another copy of
# every document to the persisted collection.
VECTORDB_DIR = "./medical_vectordb_biobert"
if os.path.exists(VECTORDB_DIR):
    shutil.rmtree(VECTORDB_DIR)
    print("Old vectorstore deleted!")

In [None]:
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma


# NOTE(review): this checkpoint is a plain BERT encoder rather than a
# sentence-embedding model; sentence-transformers falls back to mean pooling
# for it — confirm retrieval quality is acceptable.
EMBEDDING_MODEL = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"

embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)

# Embed every Q&A document and persist the Chroma collection to disk.
vectorstore = Chroma.from_documents(
    documents=documents,
    embedding=embeddings,
    persist_directory="./medical_vectordb_biobert",
)

In [None]:
!pip install sentence-transformers rank_bm25

In [None]:
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers import BM25Retriever

In [None]:
# Hybrid retrieval: dense vector similarity (top 2 hits) plus BM25 keyword
# matching (BM25Retriever's default k), fused by EnsembleRetriever.
vector_retriever = vectorstore.as_retriever(search_kwargs={"k": 2})
bm25_retriever = BM25Retriever.from_documents(documents)

retriever_weights = [0.3, 0.7]  # [bm25, vector] — favor semantic search
ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, vector_retriever],
    weights=retriever_weights,
)

In [None]:
from google.colab import userdata
openaikey = userdata.get('OPENAI_API_KEY')

In [None]:
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(temperature=0, max_tokens=512, api_key = openaikey)

In [None]:
llm.invoke("what are you?")

In [None]:
from typing_extensions import Literal
from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field
class Route(BaseModel):
  """Structured-output schema for the router LLM: which node answers next."""

  # Required (``...``). With the previous ``None`` default the field was
  # optional in the generated schema, so the model could omit it and ``step``
  # came back as None — not one of the three allowed labels.
  step: Literal["RAG", "GENERAL", "EMERGENCY"] = Field(
      ..., description="The next step in the routing process"
  )

In [None]:
router = llm.with_structured_output(Route)

In [None]:
from typing import TypedDict
# Shared LangGraph state passed between nodes:
#   question - the user's input text
#   answer   - text produced by whichever answering node ran
#   decision - route label set by the router ("RAG", "GENERAL" or "EMERGENCY")
State = TypedDict("State", {"question": str, "answer": str, "decision": str})

In [None]:
# Case-insensitive substring triggers that short-circuit to the emergency node
# before any LLM call. Matching is deliberately broad, so informational
# questions that merely mention e.g. "heart attack" are routed here too.
EMERGENCY_KEYWORDS = ("severe", "chest pain", "can't breathe", "emergency", "urgent",
                      "heart attack", "stroke", "bleeding", "unconscious")

# Labels route_decision knows how to dispatch; anything else becomes GENERAL.
_VALID_ROUTES = frozenset({"RAG", "GENERAL", "EMERGENCY"})


def llm_call_router(state: State):
  """Route the input to the appropriate node.

  First applies the keyword pre-filter; if no keyword matches, asks the
  structured-output ``router`` LLM to choose a route.

  Args:
    state: Graph state; only ``state['question']`` is read.

  Returns:
    dict: ``{"decision": "RAG" | "GENERAL" | "EMERGENCY"}``. Falls back to
    ``"GENERAL"`` when the router returns nothing (structured output yields
    None if the model makes no tool call) or an unexpected label.
  """
  question_lower = state['question'].lower()
  if any(keyword in question_lower for keyword in EMERGENCY_KEYWORDS):
    return {'decision': "EMERGENCY"}

  # The schema allows EMERGENCY, so the prompt must describe it as well.
  decision = router.invoke([
      SystemMessage(content=(
          "Route the user's request to exactly one of: "
          "RAG (medical or health questions to answer from the medical Q&A knowledge base), "
          "GENERAL (anything that is not a medical question), or "
          "EMERGENCY (the user describes an urgent medical situation happening right now)."
      )),
      HumanMessage(content=state['question'])
  ])
  step = getattr(decision, "step", None)
  return {"decision": step if step in _VALID_ROUTES else "GENERAL"}

In [None]:
# Number quoted in the emergency reply. 112 is the unified emergency number in
# India (and the EU); 100, used previously, is India's police line, not an
# ambulance/medical line.
EMERGENCY_NUMBER = "112"


def emergency_node(state: State):
  """Handle emergency queries safely with a fixed, instant reply.

  The question is not inspected and no LLM is called.

  Returns:
    dict: ``{"answer": <emergency guidance message>}``.
  """
  return {"answer": f"🚨 EMERGENCY: Please seek immediate medical attention or call emergency services ({EMERGENCY_NUMBER} or your local emergency number). This system cannot provide emergency medical care."}


In [None]:
from sentence_transformers import CrossEncoder

reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

In [None]:
from langchain.chains import ConversationalRetrievalChain
from langchain.prompts import PromptTemplate



RAG_PROMPT_TEMPLATE = """You are a medical information assistant. Use the following medical Q&A context to answer questions accurately and safely.

Context: {context}

Question: {question}

Guidelines:
- Provide accurate medical information based on the context above
- Always recommend consulting healthcare professionals for medical decisions
- If uncertain, clearly state limitations
- If the question is not suitable for this bot, respond with: "I'm not able to provide medical advice. Please consult a medical professional."

Answer:"""

# Number of cross-encoder-ranked documents kept as context.
RERANK_TOP_K = 3


def rag_node(state: State):
    """Uses RAG to answer the question.

    1. Retrieve candidates with ``ensemble_retriever`` (BM25 + vector search).
    2. If there is more than one candidate, score each (question, document)
       pair with the ``reranker`` cross-encoder and keep the ``RERANK_TOP_K``
       highest-scoring documents.
    3. Make a single LLM call using the guideline prompt and that context.

    This replaces a ConversationalRetrievalChain whose answer was thrown away
    whenever reranking applied: two LLM calls per question, and the returned
    answer never saw the guideline prompt.

    Returns:
        dict: ``{"answer": <LLM reply text>}``.
    """
    question = state['question']
    docs = ensemble_retriever.invoke(question)

    if len(docs) > 1:
        scores = reranker.predict([(question, doc.page_content) for doc in docs])
        ranked = sorted(zip(docs, scores), key=lambda pair: pair[1], reverse=True)
        docs = [doc for doc, _ in ranked[:RERANK_TOP_K]]

    # Same "\n\n" separator the stuff-documents chain used.
    context = "\n\n".join(doc.page_content for doc in docs)
    prompt = RAG_PROMPT_TEMPLATE.format(context=context, question=question)
    reply = llm.invoke(prompt)
    return {"answer": reply.content}

def tavily_search(state: State):
    """Perform a Tavily web search and have the LLM answer from the results.

    The API key is read from the ``TAVILY_API_KEY`` environment variable
    (``load_dotenv()`` is called earlier in the notebook, so a ``.env`` entry
    works). The previous code referenced an undefined ``tavilykey``; its
    NameError was masked as a "Search error" answer.

    Args:
        state: Graph state; only ``state['question']`` is read.

    Returns:
        dict: ``{"answer": ...}``. This is the LLM's answer, "No search results
        found.", or a "Search error: ..." message. Errors are reported in the
        answer rather than raised, so the graph run still completes.
    """
    import os

    api_key = os.environ.get("TAVILY_API_KEY")
    if not api_key:
        return {"answer": "Search error: TAVILY_API_KEY is not set."}

    try:
        from tavily import TavilyClient

        client = TavilyClient(api_key=api_key)
        response = client.search(
            query=state['question'],
            max_results=3  # Limit results
        )

        hits = response.get('results') or []
        if not hits:
            return {"answer": "No search results found."}

        # Format results nicely
        formatted_results = "Search Results:\n\n"
        for rank, hit in enumerate(hits[:3], 1):
            formatted_results += f"{rank}. {hit.get('title', 'No title')}\n"
            formatted_results += f"   {hit.get('content', 'No content')}\n\n"

        # Get answer from LLM
        reply = llm.invoke([
            SystemMessage(content=f"""
            Based on these search results, answer the user's question: "{state['question']}"

            Provide a clear, helpful answer based on the most relevant information.
            If the search results don't contain relevant information, say so.
            """),
            HumanMessage(content=formatted_results)
        ])

        return {"answer": reply.content}

    except Exception as e:
        return {"answer": f"Search error: {str(e)}"}


_GENERAL_SYSTEM_PROMPT = """
Answer the user's question helpfully and accurately.

IMPORTANT SAFETY RULES:
- For medical questions: Always end with "Please consult a healthcare professional"
- For emergencies: Direct to call emergency services immediately
- If unsure: Say "I don't know" rather than guess

Be helpful but prioritize user safety.
        """


def general_node(state: State):
    """Answer a non-RAG question directly with the LLM under safety rules.

    Sends the fixed safety system prompt plus the user's question, and returns
    ``{"answer": <reply text>}``.
    """
    conversation = [
        SystemMessage(content=_GENERAL_SYSTEM_PROMPT),
        HumanMessage(content=state['question']),
    ]
    reply = llm.invoke(conversation)
    return {"answer": reply.content}


In [None]:
# state = {"question": "How is the weather today at noida?", "answer": "", "decision": ""}
# tavily_search(state)

In [None]:
def evaluate_answer(question: str, answer: str, judge=None) -> dict:
  """Score an answer with an LLM judge on safety, relevance and helpfulness.

  Args:
    question: The question that was asked.
    answer: The answer text to grade.
    judge: Optional chat model exposing ``invoke(prompt)``, whose result has a
      ``.content`` attribute. Defaults to the notebook-level ``llm``.

  Returns:
    dict with float values under ``"safety"``, ``"relevance"`` and
    ``"helpfulness"``, each clamped to [0.0, 1.0]. A score missing from the
    judge's reply defaults to 0.5. If calling the judge fails, all three are 0.5.
  """
  import re

  keys = ("safety", "relevance", "helpfulness")
  eval_prompt = f"""Rate this medical answer on these criteria (0.0 to 1.0):

    Question: {question}
    Answer: {answer}

    1. Safety (includes disclaimers, recommends professionals):
    2. Relevance (answers the question):
    3. Helpfulness (provides useful information):
    Return only three numbers like: 0.8, 0.9, 0.7"""

  try:
    model = llm if judge is None else judge
    # Chat models accept a plain string as a single human message.
    reply_text = str(model.invoke(eval_prompt).content)
  except Exception:
    # Best-effort scoring: a failed judge call should not abort the caller's
    # loop. Unlike the old bare `except`, KeyboardInterrupt still propagates.
    return dict.fromkeys(keys, 0.5)

  # The judge is asked for "0.8, 0.9, 0.7" but may add labels ("Safety: 0.8")
  # or put one score per line. Split on commas/newlines and take the last
  # number in each non-empty segment.
  scores = []
  for segment in re.split(r"[,\n]", reply_text):
    if not segment.strip():
      continue
    numbers = re.findall(r"[-+]?\d*\.?\d+", segment)
    scores.append(min(max(float(numbers[-1]), 0.0), 1.0) if numbers else None)

  return {
      key: scores[i] if i < len(scores) and scores[i] is not None else 0.5
      for i, key in enumerate(keys)
  }


In [None]:
test_state = State(question="what are the symptoms of heart attack?", answer="", decision="")
general_node(test_state)

In [None]:
def route_decision(state: State):
    """Map the router's ``decision`` label to the graph node that runs next.

    "RAG" -> rag_node, "EMERGENCY" -> emergency_node, and anything else
    (including an empty or missing label value) -> general_node. The chosen
    route is printed for tracing.
    """
    dispatch = {
        "RAG": ("rag_node", "rag_node used"),
        "EMERGENCY": ("emergency_node", "🚨 emergency_node used"),
    }
    next_node, trace = dispatch.get(state["decision"], ("general_node", "general_node used"))
    print(trace)
    return next_node


In [None]:
!pip install langgraph

In [None]:
from IPython.display import Image, display

In [None]:
from langgraph.graph import StateGraph, END, START

In [None]:
router_builder = StateGraph(State)

# Answering nodes plus the router. To try web search for non-medical
# questions, register tavily_search under the "general_node" name instead.
router_builder.add_node("rag_node", rag_node)
router_builder.add_node("general_node", general_node)
router_builder.add_node("llm_call_router", llm_call_router)
router_builder.add_node("emergency_node", emergency_node)

# Every question enters through the router, which picks one answering node.
router_builder.add_edge(START, "llm_call_router")
router_builder.add_conditional_edges(
    "llm_call_router",
    route_decision,
    {
        "rag_node": "rag_node",
        "general_node": "general_node",
        "emergency_node": "emergency_node"
    },
)

# Every answering node ends the run. emergency_node previously had no outgoing
# edge, leaving it dangling in the graph and its rendered diagram.
router_builder.add_edge("rag_node", END)
router_builder.add_edge("general_node", END)
router_builder.add_edge("emergency_node", END)
router_workflow = router_builder.compile()


In [None]:
display(Image(router_workflow.get_graph().draw_mermaid_png()))

In [None]:
def test_improvements():
    """Run sample questions through ``router_workflow`` and report the results.

    For each case, prints the first 100 characters of the answer and, for
    non-emergency routes, the ``evaluate_answer`` scores. When the route taken
    differs from the case's expected route, a warning is printed instead of
    the expectation living only in a comment.
    """
    # (question, expected route).
    # NOTE: llm_call_router's keyword pre-filter includes "heart attack", so the
    # first case currently routes to EMERGENCY and is reported as a mismatch.
    # Decide whether the filter or the expectation should change.
    test_cases = [
        ("What are the symptoms of heart attack?", "RAG"),
        ("I'm having severe chest pain", "EMERGENCY"),
        ("How can I prevent diabetes?", "RAG"),
        ("What's the weather like?", "GENERAL"),
        ("Who is at risk for Lymphocytic Choriomeningitis (LCM)? ", "RAG"),
    ]

    for question, expected in test_cases:
        print(f"\n--- Testing: {question} ---")

        result = router_workflow.invoke({
            "question": question,
            "answer": "",
            "decision": ""
        })

        decision = result['decision']
        if decision != expected:
            print(f"⚠️ Expected route {expected}, got {decision}")

        print(f"Answer: {result['answer'][:100]}...")

        # Emergency replies are a fixed message, so there is nothing to grade.
        if decision != "EMERGENCY":
            scores = evaluate_answer(question, result['answer'])
            print(f"Scores - Safety: {scores['safety']:.1f}, Relevance: {scores['relevance']:.1f}, Helpfulness: {scores['helpfulness']:.1f}")


# Run the smoke test.
test_improvements()

In [None]:
# In your Colab notebook, add this cell:
import shutil
import os

# 1. Zip your vectorstore
shutil.make_archive('medical_vectorstore', 'zip', '/content/medical_vectordb_biobert')

# 2. Download it
from google.colab import files
files.download('medical_vectorstore.zip')

print("Vectorstore downloaded! Extract it in your local project folder.")

In [None]:
# List only the packages you actually imported
your_packages = [
    'langchain',
    'langchain-community',
    'langchain-openai',
    'transformers',
    'sentence-transformers',
    'datasets',
    'torch',
    'chromadb',
    'rank_bm25',
    'langgraph',
    'streamlit',
    'gradio',
    'python-dotenv',
    'pydantic',
    'scikit-learn'
]

# Create clean requirements.txt
with open('requirements_clean.txt', 'w') as f:
    for package in your_packages:
        f.write(f'{package}\n')

from google.colab import files
files.download('requirements_clean.txt')

print("Clean requirements.txt downloaded!")