<a href="https://colab.research.google.com/github/tomasonjo/blogs/blob/master/llm/nvidia_neo4j_langchain.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install --upgrade --quiet langchain-nvidia-ai-endpoints langchain-community neo4j langchain-core nemoguardrails


In [54]:
import os

from typing import Optional, Type, List, Dict, Tuple, Any

from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)

# Import things that are needed generically
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import BaseTool
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from langchain_community.graphs import Neo4jGraph
from langchain_community.vectorstores.neo4j_vector import remove_lucene_chars
from langchain.tools.render import format_tool_to_openai_function
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema import AIMessage, HumanMessage
from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad import format_to_openai_function_messages
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser


In [4]:
os.environ["NVIDIA_API_KEY"] = "nvapi-"

In [5]:
llm = ChatNVIDIA(model="meta/llama-3.1-8b-instruct")

In [6]:
result = llm.invoke("How to use LLMs in combination with Graph Databases? Be concise!")
print(result.content)

To use Large Language Models (LLMs) in combination with Graph Databases, follow these steps:

1. **API Integration**: Use an API like Gremlin or Cypher to query the Graph Database from your LLM development environment.
2. **Graph Data Preparation**: Convert graph data into a format that can be processed by the LLM, such as nodes, edges, and relationships.
3. **Node/Edge Embeddings**: Use techniques like node2vec or GraphSAGE to create embeddings for nodes and edges in the graph.
4. **LLM Training**: Use the prepared graph data and embeddings to train your LLM on graph-specific tasks, such as:
	* **Graph generation**: Train the LLM to predict missing nodes or edges in the graph.
	* **Graph classification**: Train the LLM to classify nodes or edges based on their properties.
	* **Graph property prediction**: Train the LLM to predict properties of nodes or edges, such as centrality or community membership.
5. **Querying the LLM**: Send queries to the trained LLM using the API integration.

In [9]:
os.environ["NEO4J_URI"] = "bolt://18.206.157.187:7687"
os.environ["NEO4J_USERNAME"] = "neo4j"
os.environ["NEO4J_PASSWORD"] = "elevation-reservist-thousands"

graph = Neo4jGraph(refresh_schema=False)

In [23]:
graph.query(
    "CREATE FULLTEXT INDEX drug IF NOT EXISTS FOR (d:Drug) ON EACH [d.name];"
)
graph.query(
    "CREATE FULLTEXT INDEX manufacturer IF NOT EXISTS FOR (d:Manufacturer) ON EACH [d.manufacturerName];"
)

[]

In [26]:
def generate_full_text_query(input: str) -> str:
    """
    Generate a full-text search query for a given input string.

    This function constructs a query string suitable for a full-text search.
    It processes the input string by splitting it into words and appending a
    fuzzy-match edit distance (~2) to each word, then combines them using the
    AND operator. Useful for mapping drugs and manufacturers from user
    questions to database values, and allows for some misspellings.
    """
    words = [el for el in remove_lucene_chars(input).split() if el]
    # " AND ".join handles zero, one, or many words without an IndexError
    return " AND ".join(f"{word}~2" for word in words)


candidate_query = """
CALL db.index.fulltext.queryNodes($index, $fulltextQuery, {limit: $limit})
YIELD node
RETURN coalesce(node.manufacturerName, node.name) AS candidate,
       labels(node)[0] AS label
"""


def get_candidates(input: str, type: str, limit: int = 3) -> List[Dict[str, str]]:
    """
    Retrieve a list of candidate entities from database based on the input string.

    This function queries the Neo4j database using a full-text search. It takes the
    input string, generates a full-text query, and executes this query against the
    specified index in the database. The function returns a list of candidates
    matching the query, with each candidate being a dictionary containing its name
    and label (either 'Drug' or 'Manufacturer').
    """
    ft_query = generate_full_text_query(input)
    candidates = graph.query(
        candidate_query, {"fulltextQuery": ft_query, "index": type, "limit": limit}
    )
    return candidates

In [27]:
get_candidates("Acadia", "manufacturer")

[{'candidate': 'ACADIA PHARMACEUTICALS', 'label': 'Manufacturer'}]

In [28]:
get_candidates("voriconazol", "drug")

[{'candidate': 'VORICONAZOLE', 'label': 'Drug'}]
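
To make the fuzzy matching concrete, the sketch below reproduces the kind of string that `generate_full_text_query` builds, without needing a Neo4j connection. `sanitize` and `fuzzy_and_query` are hypothetical names; `sanitize` is a simplified stand-in for `remove_lucene_chars` that replaces a handful of Lucene special characters with spaces and may not match the library helper exactly. In Lucene query syntax, `~2` requests terms within an edit distance of 2, which is why the misspelled "voriconazol" can still match VORICONAZOLE.

```python
import re

# Hypothetical, simplified stand-in for remove_lucene_chars: replaces
# Lucene special characters with spaces. The real helper may differ.
LUCENE_SPECIAL = re.compile(r'[+\-&|!(){}\[\]^"~*?:\\/]')


def sanitize(text: str) -> str:
    return LUCENE_SPECIAL.sub(" ", text)


def fuzzy_and_query(text: str) -> str:
    """Append ~2 to every word and join the words with AND."""
    words = sanitize(text).split()
    return " AND ".join(f"{word}~2" for word in words)


print(fuzzy_and_query("voriconazol"))
print(fuzzy_and_query("acadia pharma"))
print(repr(fuzzy_and_query("")))
```

Because the words are joined with AND, every word in the input has to fuzzy-match for a node to be returned; joining with OR would be a looser alternative.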

In [42]:



def get_side_effects(
    drug: Optional[str] = None,
    min_age: Optional[int] = None,
    max_age: Optional[int] = None,
    manufacturer: Optional[str] = None,
) -> str:
    """Get the side effects of a drug."""
    params = {}
    filters = []
    side_effects_base_query = """
    MATCH (c:Case)-[:HAS_REACTION]->(r:Reaction), (c)-[:IS_PRIMARY_SUSPECT]->(d:Drug)
    """
    if drug:
        candidate_drugs = [el["candidate"] for el in get_candidates(drug, "drug")]
        if not candidate_drugs:
            return "The mentioned drug was not found"
        filters.append("d.name IN $drugs")
        params["drugs"] = candidate_drugs
    # Compare against None so that an age bound of 0 is not ignored
    if min_age is not None:
        filters.append("c.age > $min_age")
        params["min_age"] = min_age
    if max_age is not None:
        filters.append("c.age < $max_age")
        params["max_age"] = max_age
    if manufacturer:
        candidate_manufacturers = [
            el["candidate"] for el in get_candidates(manufacturer, "manufacturer")
        ]
        if not candidate_manufacturers:
            return "The mentioned manufacturer was not found"
        filters.append(
            "EXISTS {(c)<-[:REGISTERED]-(:Manufacturer {manufacturerName: $manufacturer})}"
        )
        params["manufacturer"] = candidate_manufacturers[0]
    # Only emit a WHERE clause when at least one filter was collected
    if filters:
        side_effects_base_query += " WHERE " + " AND ".join(filters)
    side_effects_base_query += """
    RETURN d.name AS drug, r.description AS side_effect, count(*) AS count
    ORDER BY count DESC
    LIMIT 10
    """
    print(side_effects_base_query)
    print(params)
    data = graph.query(side_effects_base_query, params=params)
    return data

In [43]:
get_side_effects(manufacturer="acadia")


    MATCH (c:Case)-[:HAS_REACTION]->(r:Reaction), (c)-[:IS_PRIMARY_SUSPECT]->(d:Drug)
     WHERE EXISTS {(c)<-[:REGISTERED]-(:Manufacturer {manufacturerName: $manufacturer})}
    RETURN d.name AS drug, r.description AS side_effect, count(*) AS count
    ORDER BY count DESC
    LIMIT 10
    
{'manufacturer': 'ACADIA PHARMACEUTICALS'}


[{'drug': 'NUPLAZID', 'side_effect': 'Hallucination', 'count': 13},
 {'drug': 'NUPLAZID', 'side_effect': 'Confusional state', 'count': 7},
 {'drug': 'NUPLAZID', 'side_effect': 'Fall', 'count': 6},
 {'drug': 'NUPLAZID', 'side_effect': 'Delusion', 'count': 5},
 {'drug': 'NUPLAZID', 'side_effect': 'Gait disturbance', 'count': 5},
 {'drug': 'NUPLAZID', 'side_effect': 'Fatigue', 'count': 4},
 {'drug': 'NUPLAZID', 'side_effect': 'Abnormal behaviour', 'count': 3},
 {'drug': 'NUPLAZID',
  'side_effect': 'Product dose omission issue',
  'count': 3},
 {'drug': 'NUPLAZID', 'side_effect': 'Agitation', 'count': 3},
 {'drug': 'NUPLAZID', 'side_effect': 'Death', 'count': 3}]

In [45]:
get_side_effects(drug="aspirin")


    MATCH (c:Case)-[:HAS_REACTION]->(r:Reaction), (c)-[:IS_PRIMARY_SUSPECT]->(d:Drug)
     WHERE d.name IN $drugs
    RETURN d.name AS drug, r.description AS side_effect, count(*) AS count
    ORDER BY count DESC
    LIMIT 10
    
{'drugs': ['ASPIRINE', 'ASPRIN', 'ASPIRIN']}


[]
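
The filter-building logic in `get_side_effects` can be tried in isolation. The following is a minimal sketch, not the notebook's implementation: `build_side_effects_query` is a hypothetical helper that takes already-resolved drug names and returns the Cypher text and parameter dictionary without touching the database. It assumes the same labels and properties as the query above, and it checks `is not None` so that an age bound of 0 is not silently dropped.

```python
from typing import Any, Dict, List, Optional, Tuple


def build_side_effects_query(
    drugs: Optional[List[str]] = None,
    min_age: Optional[int] = None,
    max_age: Optional[int] = None,
) -> Tuple[str, Dict[str, Any]]:
    """Return (cypher, params); every user value travels as a parameter."""
    filters: List[str] = []
    params: Dict[str, Any] = {}
    if drugs:
        filters.append("d.name IN $drugs")
        params["drugs"] = drugs
    if min_age is not None:
        filters.append("c.age > $min_age")
        params["min_age"] = min_age
    if max_age is not None:
        filters.append("c.age < $max_age")
        params["max_age"] = max_age

    query = (
        "MATCH (c:Case)-[:HAS_REACTION]->(r:Reaction), "
        "(c)-[:IS_PRIMARY_SUSPECT]->(d:Drug)\n"
    )
    # Only add WHERE when there is at least one filter to join
    if filters:
        query += "WHERE " + " AND ".join(filters) + "\n"
    query += (
        "RETURN d.name AS drug, r.description AS side_effect, count(*) AS count\n"
        "ORDER BY count DESC\n"
        "LIMIT 10"
    )
    return query, params


query, params = build_side_effects_query(drugs=["ASPIRIN"], min_age=0)
print(query)
print(params)
```

Keeping the values in `params` rather than formatting them into the Cypher string means user input never becomes part of the query text.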

In [64]:
class SideEffectsInput(BaseModel):
    drug: Optional[str] = Field(description="drug mentioned in the question")
    min_age: Optional[int] = Field(description="Minimum age of the patient")
    max_age: Optional[int] = Field(description="Maximum age of the patient")
    manufacturer: Optional[str] = Field(description="manufacturer of the drug")


class SideEffectsTool(BaseTool):
    name = "SideEffects"
    description = "useful for when you need to find common side effects"
    args_schema: Type[BaseModel] = SideEffectsInput

    def _run(
        self,
        drug: Optional[str],
        min_age: Optional[int],
        max_age: Optional[int],
        manufacturer: Optional[str],
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool."""
        return get_side_effects(drug, min_age, max_age, manufacturer)

    async def _arun(
        self,
        drug: Optional[str],
        min_age: Optional[int],
        max_age: Optional[int],
        manufacturer: Optional[str],
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool asynchronously."""
        return get_side_effects(drug, min_age, max_age, manufacturer)


In [82]:
from langchain_core.pydantic_v1 import Field
from langchain_core.tools import tool


@tool
def get_side_effects(
    drug: Optional[str] = Field(description="drug mentioned in the question"),
    min_age: Optional[int] = Field(description="Minimum age of the patient"),
    max_age: Optional[int] = Field(description="Maximum age of the patient"),
    manufacturer: Optional[str] = Field(description="manufacturer of the drug")
):
    """Useful for when you need to find common side effects."""
    params = {}
    filters = []
    side_effects_base_query = """
    MATCH (c:Case)-[:HAS_REACTION]->(r:Reaction), (c)-[:IS_PRIMARY_SUSPECT]->(d:Drug)
    """
    if drug:
        candidate_drugs = [el["candidate"] for el in get_candidates(drug, "drug")]
        if not candidate_drugs:
            return "The mentioned drug was not found"
        filters.append("d.name IN $drugs")
        params["drugs"] = candidate_drugs
    # Compare against None so that an age bound of 0 is not ignored
    if min_age is not None:
        filters.append("c.age > $min_age")
        params["min_age"] = min_age
    if max_age is not None:
        filters.append("c.age < $max_age")
        params["max_age"] = max_age
    if manufacturer:
        candidate_manufacturers = [
            el["candidate"] for el in get_candidates(manufacturer, "manufacturer")
        ]
        if not candidate_manufacturers:
            return "The mentioned manufacturer was not found"
        filters.append(
            "EXISTS {(c)<-[:REGISTERED]-(:Manufacturer {manufacturerName: $manufacturer})}"
        )
        params["manufacturer"] = candidate_manufacturers[0]
    # Only emit a WHERE clause when at least one filter was collected
    if filters:
        side_effects_base_query += " WHERE " + " AND ".join(filters)
    side_effects_base_query += """
    RETURN d.name AS drug, r.description AS side_effect, count(*) AS count
    ORDER BY count DESC
    LIMIT 10
    """
    print(side_effects_base_query)
    print(params)
    data = graph.query(side_effects_base_query, params=params)
    return data

In [87]:
tools = [get_side_effects]
llm_with_tools = llm.bind_tools(tools=tools)
response = llm_with_tools.invoke("What is the most common side effect when using volazopile?")
print(response)
print(response.tool_calls)

content='{"type": "function", "name": "get_side_effects", "parameters": {"drug": "volazopile", "min_age": "0", "max_age": "100", "manufacturer": "manufacturer of volazopile"}}' response_metadata={'role': 'assistant', 'content': '{"type": "function", "name": "get_side_effects", "parameters": {"drug": "volazopile", "min_age": "0", "max_age": "100", "manufacturer": "manufacturer of volazopile"}}', 'token_usage': {'prompt_tokens': 291, 'total_tokens': 342, 'completion_tokens': 51}, 'finish_reason': 'stop', 'model_name': 'meta/llama-3.1-8b-instruct'} id='run-a9c794d3-baf5-491b-b5a1-995b03568207-0' role='assistant'
[]
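
In the output above, `response.tool_calls` is empty and the function call arrives as a JSON string in `content`. One possible workaround, shown as a sketch only, is to parse that string yourself. `parse_tool_call` is a hypothetical helper, not part of LangChain, and it assumes the model emits exactly the `{"type": "function", "name": ..., "parameters": ...}` shape seen in this response.

```python
import json
from typing import Any, Dict, Optional


def parse_tool_call(content: str) -> Optional[Dict[str, Any]]:
    """Return {'name', 'args'} if content is a JSON function call, else None."""
    try:
        payload = json.loads(content)
    except json.JSONDecodeError:
        return None
    if not isinstance(payload, dict) or payload.get("type") != "function":
        return None
    name = payload.get("name")
    args = payload.get("parameters")
    if not isinstance(name, str) or not isinstance(args, dict):
        return None
    return {"name": name, "args": args}


# Content string copied from the response printed above.
content = (
    '{"type": "function", "name": "get_side_effects", "parameters": '
    '{"drug": "volazopile", "min_age": "0", "max_age": "100", '
    '"manufacturer": "manufacturer of volazopile"}}'
)
call = parse_tool_call(content)
print(call["name"])
print(call["args"]["drug"])
```

Note that the model returned the ages as strings and filled in a placeholder manufacturer, so arguments parsed this way would still need validating before being passed to the tool.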


In [88]:
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful assistant that finds information about common side effects. "
            "If tools require follow up questions, "
            "make sure to ask the user for clarification. Make sure to include any "
            "available options that need to be clarified in the follow up questions. "
            "Do only the things the user specifically requested. ",
        ),
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
)

In [69]:
def _format_chat_history(chat_history: List[Tuple[str, str]]):
    buffer = []
    for human, ai in chat_history:
        buffer.append(HumanMessage(content=human))
        buffer.append(AIMessage(content=ai))
    return buffer


agent = (
    {
        "input": lambda x: x["input"],
        "chat_history": lambda x: _format_chat_history(x["chat_history"])
        if x.get("chat_history")
        else [],
        "agent_scratchpad": lambda x: format_to_openai_function_messages(
            x["intermediate_steps"]
        ),
    }
    | prompt
    | llm_with_tools
    | OpenAIFunctionsAgentOutputParser()
)


# Add typing for input
class AgentInput(BaseModel):
    input: str
    chat_history: List[Tuple[str, str]] = Field(
        ..., extra={"widget": {"type": "chat", "input": "input", "output": "output"}}
    )


class Output(BaseModel):
    output: Any


agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True).with_types(
    input_type=AgentInput, output_type=Output
)

In [89]:
agent_executor.invoke({"input": "What is the most common side effect when using volazopile?"})



> Entering new AgentExecutor chain...
To find the most common side effect of volazopile, I would need to know more about the drug and its usage. Can you please provide me with the following information to help me narrow down the search? Manufacturer, age range, and the specific condition being treated by volazopile.

> Finished chain.


{'input': 'What is the most common side effect when using volazopile?',
 'output': 'To find the most common side effect of volazopile, I would need to know more about the drug and its usage. Can you please provide me with the following information to help me narrow down the search? Manufacturer, age range, and the specific condition being treated by volazopile.'}