In [None]:
from dotenv import load_dotenv

# Read a local .env file into the process environment before any service client is created.
# NOTE(review): presumably this supplies the AWS / Neo4j settings the later cells rely on — confirm.
load_dotenv()

In [2]:
import boto3
from botocore.config import Config

# Increase the Bedrock read timeout (botocore expresses it in seconds) so that
# slow model responses are not cut off by the default limit.
config = Config(read_timeout=1000)

# Bind the runtime client to its own name instead of overwriting the imported
# `client` factory: the previous `client = client(...)` made this cell fail on
# a second run, because `client` was no longer the factory function.
# The chat models below receive `config` directly rather than this client.
bedrock_client = boto3.client(service_name="bedrock-runtime", config=config)

# Define LLM

In [3]:
from langchain_aws.chat_models.bedrock import ChatBedrock

# Sampling settings shared by both chat models.
LLM_TEMPERATURE = 0
LLM_MAX_TOKENS = 1000

# Stronger model (the model id below is Claude 3 Sonnet): used for Cypher
# generation and for extracting the graph documents. It is the only model
# that receives the extended read-timeout `config`.
advanced_llm = ChatBedrock(
    config=config,
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",
    model_kwargs={"anthropic_version": "bedrock-2023-05-31"},
    temperature=LLM_TEMPERATURE,
    max_tokens=LLM_MAX_TOKENS,
)

# Lighter model (Mixtral 8x7B Instruct) for the simpler tasks; the answer
# chains further down pipe their prompts into it.
basic_llm = ChatBedrock(
    model_id="mistral.mixtral-8x7b-instruct-v0:1",
    temperature=LLM_TEMPERATURE,
    max_tokens=LLM_MAX_TOKENS,
)

# Define Embedding model

In [4]:
from langchain_aws.embeddings import BedrockEmbeddings

# Bedrock embedding model used when the chunks are indexed in the vector store.
EMBEDDING_MODEL_ID = "amazon.titan-embed-text-v2:0"
embeddings = BedrockEmbeddings(model_id=EMBEDDING_MODEL_ID)

# Define a Web Search Agent to Get Data

In [5]:
import arxiv

# arXiv client; the search built below is consumed by `client.results(search)`
# in the vector-store cell.
client = arxiv.Client()

# Query terms are OR-ed together; only a couple of papers are requested.
search_query = "agent OR 'large language model' OR 'prompt engineering'"
max_results = 2

search = arxiv.Search(
    query=search_query,
    max_results=max_results,
    sort_by=arxiv.SortCriterion.Relevance,
)


# Define a Vectorstore

In [None]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_milvus import Milvus


# One record per search hit: title, summary, and the entry id (stored under "url").
docs = [
    {"title": paper.title, "summary": paper.summary, "url": paper.entry_id}
    for paper in client.results(search)
]

# Splitter built via the tiktoken-based constructor: chunks of 500 with an overlap of 50.
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_overlap=50,
    chunk_size=500,
)

# Only the summaries are split; each paper's full record travels along as metadata.
summaries = [doc["summary"] for doc in docs]
doc_splits = text_splitter.create_documents(summaries, metadatas=docs)

print(f"Number of papers: {len(docs)}")
print(f"Number of chunks: {len(doc_splits)}")


# Index the chunks in Milvus and expose the collection as a retriever.
# NOTE(review): the URI is a relative file path, presumably a local Milvus Lite
# database; re-running this cell may insert the same chunks again — confirm
# whether the collection should be reset first.
milvus_connection = {"uri": "./milvus_ingest.db"}
milvus_index = {"index_type": "FLAT"}

vectorstore = Milvus.from_documents(
    collection_name="rag_milvus",
    connection_args=milvus_connection,
    index_params=milvus_index,
    documents=doc_splits,
    embedding=embeddings,
)
retriever = vectorstore.as_retriever()

# Graph Database Setup

In [None]:
from langchain_neo4j.graphs.neo4j_graph import Neo4jGraph
from langchain_experimental.graph_transformers import LLMGraphTransformer

# No connection arguments are passed here.
# NOTE(review): presumably Neo4jGraph reads them from the environment — confirm.
graph = Neo4jGraph()

# Schema the LLM extraction is restricted to.
node_labels = ["Paper", "Author", "Topic"]
relationship_types = ["AUTHORED", "DISCUSSES", "RELATED_TO"]
property_keys = ["title", "summary", "url"]

graph_transformer = LLMGraphTransformer(
    llm=advanced_llm,
    allowed_nodes=node_labels,
    allowed_relationships=relationship_types,
    node_properties=property_keys,
)

# Convert every chunk into a graph document, then write them all to the graph.
graph_documents = graph_transformer.convert_to_graph_documents(doc_splits)
graph.add_graph_documents(graph_documents)

# Quick inspection of what was extracted.
print(f"Graph documents: {len(graph_documents)}")
first_graph_doc = graph_documents[0]
print(f"Nodes from 1st graph doc:{first_graph_doc.nodes}")
print(f"Relationships from 1st graph doc:{first_graph_doc.relationships}")

# Define Native RAG Chain

In [9]:
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Prompt for the vector-store RAG chain: answers `question` from the retrieved `context`.
# `input_variables` now names the placeholders that actually appear in the
# template (it previously listed "document", which the template never uses).
prompt = PromptTemplate(
    template="""You are an assistant for question-answering tasks. 
    Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. 
    Use three sentences maximum and keep the answer concise:
    Question: {question} 
    Context: {context} 
    Answer: 
    """,
    input_variables=["question", "context"],
)



# Post-processing
def format_docs(docs):
    """Return the `page_content` of every document, joined by a blank line."""
    contents = [document.page_content for document in docs]
    return "\n\n".join(contents)


# Vector RAG chain: prompt -> lightweight model -> string output parser.
# The `prompt` object is captured here, so the later cell that rebinds the
# name `prompt` to the composite template does not alter this chain.
answer_parser = StrOutputParser()
rag_chain = prompt | basic_llm | answer_parser

# Define GraphRAG Chain

In [12]:
# GraphRAG setup

from langchain.prompts import PromptTemplate
from langchain_neo4j.chains.graph_qa.cypher import GraphCypherQAChain


# Instructions for turning a natural-language question into a Cypher query,
# given the graph schema.
cypher_template = """You are an expert at generating Cypher queries for Neo4j.
    Use the following schema to generate a Cypher query that answers the given question.
    Make the query flexible by using case-insensitive matching and partial string matching where appropriate.
    Focus on searching paper titles as they contain the most relevant information.
    
    Schema:
    {schema}
    
    Question: {question}
    
    Cypher Query:"""

cypher_prompt = PromptTemplate(
    input_variables=["schema", "question"],
    template=cypher_template,
)


# Prompt for phrasing a final answer from the Cypher query and its results.
# NOTE(review): the chain built from this prompt sets return_direct=True, so this
# template may never be rendered; if the QA step is enabled, confirm that the
# chain actually supplies a `query` value alongside `question` and `context`.
qa_template = """You are an assistant for question-answering tasks. 
    Use the following Cypher query results to answer the question. If you don't know the answer, just say that you don't know. 
    Use three sentences maximum and keep the answer concise. If topic information is not available, focus on the paper titles.
    
    Question: {question} 
    Cypher Query: {query}
    Query Results: {context} 
    
    Answer:"""

qa_prompt = PromptTemplate(
    input_variables=["question", "query", "context"],
    template=qa_template,
)


# GraphRAG chain: the advanced model is used for both the Cypher and the QA role.
graph_rag_chain = GraphCypherQAChain.from_llm(
    # What the chain operates on
    graph=graph,
    cypher_llm=advanced_llm,
    qa_llm=advanced_llm,
    # Prompts for the two roles
    cypher_prompt=cypher_prompt,
    qa_prompt=qa_prompt,
    # Execution options
    allow_dangerous_requests=True,  # explicit opt-in demanded by the chain
    validate_cypher=True,
    # Output options
    return_direct=True,
    return_intermediate_steps=True,
    verbose=True,
)

# Test Run

In [13]:
# Example question fed to the vector, graph, and composite chains in the cells below.
# Alternative question to try: "What techniques are used for Multi-Agent? "
question = "What paper talk about Multi-Agent?"

In [None]:
# Retrieve chunks for the question, flatten them to text, and run the vector RAG chain.
retrieved_docs = retriever.invoke(question)
vector_context = rag_chain.invoke(
    {"question": question, "context": format_docs(retrieved_docs)}
)

print(vector_context)

In [None]:
# Run the GraphRAG chain; it takes the question under the "query" key.
graph_inputs = {"query": question}
graph_context = graph_rag_chain.invoke(graph_inputs)

print(graph_context)

In [16]:
### Composite Vector + Graph Generations

from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser


# Composite prompt: combines the vector-store context and the graph context
# into a single question-answering request. Rebinds the name `prompt`.
composite_template = """You are an assistant for question-answering tasks. 
    Use the following pieces of retrieved context from a vector store and a graph database to answer the question. If you don't know the answer, just say that you don't know. 
    Use three sentences maximum and keep the answer concise:
    Question: {question} 
    Vector Context: {context} 
    Graph Context: {graph_context}
    Answer: 
    """

prompt = PromptTemplate(
    input_variables=["question", "context", "graph_context"],
    template=composite_template,
)




In [None]:
# Feed the question plus both earlier results into the composite chain.
composite_inputs = {
    "question": question,
    "context": vector_context,
    "graph_context": graph_context,
}
composite_chain = prompt | basic_llm | StrOutputParser()
answer = composite_chain.invoke(composite_inputs)

print(answer)