In [1]:
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

print("Loading NCCN/ESMO Guidelines...")
loader = PyPDFLoader("data/nccn_multiple_myeloma_guidelines.pdf")
documents = loader.load()

# Split into overlapping clinical chunks
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000, 
    chunk_overlap=150,
    separators=["\n\n", "\n", ".", " "] # Tries to break at paragraphs first
)
medical_chunks = text_splitter.split_documents(documents)
print(f"Created {len(medical_chunks)} clinical chunks.")

ModuleNotFoundError: No module named 'langchain.document_loaders'

In [None]:
from langchain_community.vectorstores import FAISS
from langchain_core.embeddings import Embeddings
from sentence_transformers import SentenceTransformer
from typing import List

# 1. Build the Custom Medical Embedding Class
class PubMedEmbeddings(Embeddings):
    def __init__(self, model_name="NeuML/pubmedbert-base-embeddings"):
        self.model = SentenceTransformer(model_name)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.model.encode(t).tolist() for t in texts]

    def embed_query(self, text: str) -> List[float]:
        return self.model.encode(text).tolist()

print("Initializing PubMedBERT Embeddings...")
medical_embeddings = PubMedEmbeddings()

# 2. Build and Save the FAISS Vector Database
print("Embedding chunks and building FAISS index...")
vector_db = FAISS.from_documents(medical_chunks, medical_embeddings)
vector_db.save_local("faiss_myeloma_index")
print("âœ… Vector DB successfully built and saved locally.")

In [None]:
from langchain.tools import tool

# Load the DB we just built
retriever = vector_db.as_retriever(search_kwargs={"k": 4}) # Fetch top 4 most relevant chunks

@tool
def retrieve_oncology_guidelines(query: str) -> str:
    """Use this tool to search the NCCN and ESMO guidelines for Multiple Myeloma treatment regimens, dosing, and eligibility."""
    docs = retriever.invoke(query)
    return "\n\n".join([d.page_content for d in docs])

@tool
def get_patient_vajram_profile(patient_id: str) -> str:
    """Use this tool to fetch the patient's latest risk score, bone marrow WSI results, and disease progression status."""
    # In production, this queries your VAJRAM backend database
    return """
    Patient ID: MMRF_2240
    Age: 65, Gender: Male
    - Module 2 (Risk): Critical Risk - Active MM with Renal Impairment (Cr: 2.5).
    - Module 3 (Vision): 45% abnormal plasma cell infiltration in bone marrow.
    - Module 4 (Progression): Rapid progression detected. M-Spike increased by 1.2 g/dL.
    - Transplant Eligibility: Eligible.
    """

tools = [retrieve_oncology_guidelines, get_patient_vajram_profile]

In [None]:
from langchain.agents import create_tool_calling_agent, AgentExecutor
from langchain_core.prompts import ChatPromptTemplate
from langchain_huggingface import HuggingFacePipeline
from transformers import pipeline

# 1. Load MedGemma 1.5 locally (using Hugging Face pipelines)
llm_pipeline = pipeline("text-generation", model="google/medgemma-1.5-4b-it", device_map="auto", max_new_tokens=500)
local_medgemma = HuggingFacePipeline(pipeline=llm_pipeline)

# 2. Bind the tools to the LLM
llm_with_tools = local_medgemma.bind_tools(tools)

# 3. Define the Agent's Persona
prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a world-class hematology-oncology AI assistant. You must ALWAYS use the `get_patient_vajram_profile` tool to understand the patient, and the `retrieve_oncology_guidelines` tool to formulate your treatment plan. Never guess treatment regimens."),
    ("user", "{input}"),
    ("placeholder", "{agent_scratchpad}"),
])

# 4. Create the Executor
agent = create_tool_calling_agent(llm_with_tools, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

# --- INFERENCE TEST ---
response = agent_executor.invoke({
    "input": "Can you recommend a first-line induction therapy for patient MMRF_2240?"
})
print("\n--- FINAL ONCOLOGY RECOMMENDATION ---")
print(response["output"])