In [None]:
import os
import re
from typing import List

from dotenv import load_dotenv
from IPython.display import Image, display
from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
from langgraph.graph import StateGraph, START, END
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
load_dotenv()

## MODELS

In [None]:
# Local Ollama chat models used by the Cypher-regeneration and SQL chains below.
neo4j_llm = ChatOllama(model="tomasonjo/llama3-text2cypher-demo")
sql_llm = ChatOllama(model="sqlcoder:7b")

# ChatGroq reads GROQ_API_KEY from the environment. Assigning os.getenv()'s None
# into os.environ would raise an unhelpful TypeError, so fail fast with a clear
# message when the key is missing.
groq_api_key = os.getenv("GROQ_API_KEY")
if not groq_api_key:
    raise RuntimeError(
        "GROQ_API_KEY is not set; add it to the environment or to the .env file."
    )
os.environ["GROQ_API_KEY"] = groq_api_key

## SQL Schema (DDL)

In [None]:
# Postgres DDL describing the course-platform tables. It is only passed as text
# to the SQL generation, grading and regeneration prompts; it is not executed here.
ddl_schema = """CREATE TABLE users (
user_id SERIAL PRIMARY KEY,
name VARCHAR(100) NOT NULL,
email VARCHAR(100) UNIQUE NOT NULL);

CREATE TABLE courses (
course_id SERIAL PRIMARY KEY,
title VARCHAR(255) NOT NULL,
description TEXT);

CREATE TABLE instructors (
instructor_id SERIAL PRIMARY KEY,
name VARCHAR(100) NOT NULL,
bio TEXT);

CREATE TABLE enrollments (
enrollment_id SERIAL PRIMARY KEY,
user_id INT REFERENCES users(user_id),
course_id INT REFERENCES courses(course_id),
enrollment_date DATE NOT NULL);

CREATE TABLE reviews (
review_id SERIAL PRIMARY KEY,
user_id INT REFERENCES users(user_id),
course_id INT REFERENCES courses(course_id),
rating INT CHECK (rating BETWEEN 1 AND 5),
comment TEXT);

CREATE TABLE course_instructors (
course_id INT REFERENCES courses(course_id),
instructor_id INT REFERENCES instructors(instructor_id),
PRIMARY KEY (course_id, instructor_id));"""

## Neo4j Connection

In [None]:
from langchain_community.graphs import Neo4jGraph
from langchain_community.chat_models import ChatOllama
from langchain_core.prompts import ChatPromptTemplate

# LangChain wrapper around the Neo4j database; each constructor keyword is read
# from the matching NEO4J_* environment variable.
_NEO4J_ENV_VARS = (
    ("url", "NEO4J_URI"),
    ("database", "NEO4J_DATABASE"),
    ("username", "NEO4J_USERNAME"),
    ("password", "NEO4J_PASSWORD"),
)
graph_schema = Neo4jGraph(
    **{kwarg: os.getenv(env_name) for kwarg, env_name in _NEO4J_ENV_VARS}
)

# Text description of the graph schema, handed to the Cypher grading prompt.
schema = graph_schema.schema

## Grade SQL

In [None]:
# Groq-hosted chat model used for structured grading.
llm = ChatGroq(model="gemma2-9b-it")


# Structured-output schema: the grader fills ``binary_score`` with yes/no.
# Deliberately no class docstring: it would likely be forwarded to the LLM as the
# tool description and change what the model is told.
class GradeSqlQuery(BaseModel):
    binary_score: str = Field(
        description = "Sql is relevant to the schema, 'yes' or 'no'"
    )


structured_sql = llm.with_structured_output(GradeSqlQuery)

# The answer is requested in lowercase so it agrees with the field description
# above and with the lowercase "yes" comparison done by the grading node.
system = """You are a strict SQL schema relevance checker.

Task:
Given:
1. A database schema in DDL statements.
2. A SQL statement.

Rules:
- "Relevant" means the SQL statement only uses table names and columns that exist in the DDL schema.
- If it uses any table or column not present in the DDL, respond "no".
- Ignore syntax errors unless they involve non-existent tables or columns.
- Do not explain your reasoning.
- Respond with exactly one word, in lowercase: yes or no.
"""

grade_sql_prompt = ChatPromptTemplate.from_messages([
    ("system", system),
    ("human", "\nSchema (DDL):\n{ddl_schema}\nSQL statement:\n{sql_statement}")
])

# prompt -> Groq model with structured output -> GradeSqlQuery instance
sql_grader = grade_sql_prompt | structured_sql

## Grade Cypher

In [None]:
# Structured-output schema for the Cypher relevance grader; the model fills
# ``binary_score``. Kept docstring-free on purpose: a docstring would likely be
# forwarded to the LLM as the tool description.
class GradeCypherQuery(BaseModel):
    binary_score: str = Field(description="Cypher is relevant to the schema, 'yes' or 'no'")


structured_cypher = llm.with_structured_output(GradeCypherQuery)

system = """You are given:
1. A Neo4j database schema, including:
   - Node labels and their properties
   - Relationship types, their properties, and which nodes they connect
2. A Cypher query.

Your task:
Determine if the Cypher query is relevant to the schema. 
"Relevant" means:
- The query only references node labels, properties, relationship types, and relationship properties that exist in the schema.
- The relationships used connect the correct node labels as per the schema.
- The query is semantically valid for the given schema.

Respond with exactly one word:
- "yes" if the query is relevant to the schema.
- "no" if it is not.
"""

cypher_grader_messages = [
    ("system", system),
    ("human", "\nSchema:\n{cypher_schema}\nCypher Query:\n{cypher_query}"),
]
grade_cypher_prompt = ChatPromptTemplate.from_messages(cypher_grader_messages)

# prompt -> Groq model with structured output -> GradeCypherQuery instance
cypher_grader = grade_cypher_prompt | structured_cypher

## SQL Generation Prompt

In [None]:
# Instruction prompt for the SQLCoder model: it receives the user's question and
# the Postgres DDL and is asked to output only the query.
_SQL_GENERATION_TEMPLATE = """
### Instructions:
Your task is to convert a question into a SQL query, given a Postgres database schema.
Adhere to these rules:
- **Deliberately go through the question and database schema word by word** to appropriately answer the question
- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
- When creating a ratio, always cast the numerator as float
- **Do not provide any explanation or text. Only output the SQL query inside a code block.**

### Input:
Generate a SQL query that answers the question `{question}`.
This query will run on a database whose schema is represented in this string:
{ddl_schema}

### Answer
Based on the provided database schema, here is the SQL query that answers {question}:
"""
prompt_template = ChatPromptTemplate.from_template(_SQL_GENERATION_TEMPLATE)

# question + DDL -> SQLCoder (Ollama) -> plain string
sql_chain = prompt_template | sql_llm | StrOutputParser()

## Cypher Generation Prompt

In [None]:
# Chat prompt that turns a question plus the Neo4j schema text into Cypher.
# The chain returns a chat message; callers read its ``.content``.
# NOTE(review): the Ollama ``neo4j_llm`` is only used for regeneration, while
# first-pass generation goes through the Groq ``llm`` -- confirm this is intended.
_CYPHER_SYSTEM_MESSAGE = "Given an input question, convert it to a Cypher query. No pre-amble."
_CYPHER_HUMAN_MESSAGE = (
    "Based on the Neo4j graph schema below, write a Cypher query that would answer the user's question: "
    "\n{schema} \nQuestion: {question} \nCypher query:"
)

neo4j_prompt = ChatPromptTemplate.from_messages(
    [("system", _CYPHER_SYSTEM_MESSAGE), ("human", _CYPHER_HUMAN_MESSAGE)]
)
cypher_chain = neo4j_prompt | llm

## SQL Regeneration Prompt

In [None]:
# Repair prompt: show SQLCoder the rejected query, the DDL and the original
# request, and ask for a corrected query with the same intent.
sql_regenerate = ChatPromptTemplate.from_template(
    "The following SQL was generated:\n"
    "{previous_sql}\n"
    "\n"
    "Please regenerate the SQL so that it is correct for the given schema, "
    "fixing only the error and keeping the original intent.\n"
    "\n"
    "Schema:\n"
    "{ddl_schema}\n"
    "\n"
    "Request:\n"
    "{question}\n"
    "\n"
    "Corrected SQL:"
)

# rejected SQL + DDL + request -> SQLCoder (Ollama) -> plain string
sql_regen_chain = sql_regenerate | sql_llm | StrOutputParser()

## Cypher Regeneration Prompt

In [None]:
# Repair prompt: show the text2cypher model the rejected query, the graph schema
# and the original request, and ask for a corrected query with the same intent.
cypher_regenerate = ChatPromptTemplate.from_template(
    "The following Cypher was generated:\n"
    "{previous_cypher}\n"
    "\n"
    "Please regenerate the Cypher so that it is correct for the given schema, "
    "fixing only the error and keeping the original intent.\n"
    "\n"
    "Schema:\n"
    "{cypher_schema}\n"
    "\n"
    "Request:\n"
    "{question}\n"
    "\n"
    "Corrected Cypher:"
)

# rejected Cypher + schema + request -> Ollama model -> plain string (not a message)
cypher_regen_chain = cypher_regenerate | neo4j_llm | StrOutputParser()

## Connect to Neo4j and Postgres to get results

In [None]:
from neo4j import GraphDatabase
import psycopg2
import pandas as pd

def cypher_connection(query):
    """Run *query* on Neo4j and return the values of every result record.

    Connection settings are read from NEO4J_URI, NEO4J_DATABASE, NEO4J_USERNAME
    and NEO4J_PASSWORD. A new driver is created for each call and is always
    closed before returning.

    Returns:
        The ``values()`` of the query result, or ``None`` when running the query
        raised; in that case the error is printed rather than propagated.
    """
    uri = os.getenv("NEO4J_URI")
    db_name = os.getenv("NEO4J_DATABASE")
    credentials = (os.getenv("NEO4J_USERNAME"), os.getenv("NEO4J_PASSWORD"))
    driver = GraphDatabase.driver(uri, auth=credentials)
    try:
        with driver.session(database=db_name) as session:
            # Materialise the records while the session is still open.
            records = session.run(query).values()
        return records
    except Exception as err:
        # Best effort: report the failure and fall through to an implicit None.
        print("Error:", err)
    finally:
        driver.close()

def sql_connection(query, first_column_only=True):
    """Execute *query* on Postgres and return the fetched rows.

    Connection settings are read from POSTGRES_HOST, POSTGRES_PORT,
    POSTGRES_DB_NAME, POSTGRES_USERNAME and POSTGRES_PASSWORD. A new connection
    is opened per call and is closed even when the query fails.

    Args:
        query: SQL text to execute. It must produce a result set, otherwise
            ``fetchall`` raises ``psycopg2.ProgrammingError``.
        first_column_only: When True (the default) return only the first column
            of each row. When False, return every column of each row as a list.

    Returns:
        A list with one entry per fetched row.

    Raises:
        psycopg2.Error: on connection or query failure (after cleanup).
    """
    conn = psycopg2.connect(
        host=os.getenv("POSTGRES_HOST"),
        port=os.getenv("POSTGRES_PORT"),
        dbname=os.getenv("POSTGRES_DB_NAME"),
        user=os.getenv("POSTGRES_USERNAME"),
        password=os.getenv("POSTGRES_PASSWORD"),
    )
    try:
        # The cursor context manager closes the cursor, and the ``finally``
        # closes the connection, even if execute()/fetchall() raises -- likely
        # with LLM-written SQL.
        with conn.cursor() as cur:
            cur.execute(query)
            rows = cur.fetchall()
    finally:
        conn.close()

    if first_column_only:
        return [row[0] for row in rows]
    return [list(row) for row in rows]


# Agentic Workflow

### Define State

In [None]:
class State(TypedDict):
    """Shared LangGraph state for the parallel SQL and Cypher branches.

    Every node returns a partial dict whose keys update this state.
    """

    sql_question: str  # natural-language question for the SQL branch (graph input)
    cypher_question: str  # natural-language question for the Cypher branch (graph input)
    sql_stmt: str  # current SQL statement (written by generate_sql)
    cypher_stmt: str  # current Cypher statement (written by generate_neo4j)
    cypher_ans: List  # result values returned by cypher_connection
    sql_ans: List  # result values returned by sql_connection
    sql_regenerate: str  # "yes"/"no" routing flag set by grade_sql_prompt
    cypher_regenerate: str  # "yes"/"no" routing flag set by grade_cypher_prompt
    regen_sql: str  # corrected SQL produced by regenerate_sql
    regen_cypher: str  # corrected Cypher produced by regenerate_cypher

### Nodes to Generate SQL and Cypher

In [None]:
# Markdown code-fence markers, optionally followed by a language tag, that the
# models wrap around generated queries.
_CODE_FENCE_PATTERN = re.compile(
    r"```[ \t]*(?:postgresql|postgres|psql|sql|cypher)?", re.IGNORECASE
)


def strip_code_fence(text, remove_backticks=True):
    """Return *text* with LLM formatting noise removed so it can be executed.

    Removes Markdown fence markers together with their language tag, the
    ``<s>``/``</s>`` special tokens, and surrounding whitespace. Deleting only
    the backticks would leave a bare ``sql`` line in front of the query, which
    the database then rejects as a syntax error.

    Args:
        text: Raw model output.
        remove_backticks: Also delete any remaining inline backticks. Pass
            ``False`` for Cypher, where backticks quote identifiers.

    Returns:
        The cleaned query text.
    """
    cleaned = _CODE_FENCE_PATTERN.sub("", text)
    cleaned = cleaned.replace("<s>", "").replace("</s>", "")
    if remove_backticks:
        cleaned = cleaned.replace("`", "")
    return cleaned.strip()


def generate_sql(state):
    """Generate a Postgres query for ``state["sql_question"]``.

    Returns:
        ``{"sql_stmt": <cleaned SQL>}``.
    """
    print("======== GENERATE: SQL QUERY ========")
    question = state["sql_question"]
    raw_sql = sql_chain.invoke({"ddl_schema": ddl_schema, "question": question})
    sql_stmt = strip_code_fence(raw_sql)
    print(sql_stmt)
    return {"sql_stmt": sql_stmt}


def generate_neo4j(state):
    """Generate a Cypher query for ``state["cypher_question"]``.

    The live Neo4j schema description is included in the prompt.

    Returns:
        ``{"cypher_stmt": <cleaned Cypher>}``.
    """
    print("======== GENERATE: CYPHER QUERY ========")
    question = state["cypher_question"]
    message = cypher_chain.invoke({"question": question, "schema": graph_schema.schema})
    # Strip fences but keep inline backticks: Cypher uses them to quote identifiers.
    cypher_stmt = strip_code_fence(message.content, remove_backticks=False)
    print(cypher_stmt)
    return {"cypher_stmt": cypher_stmt}


### Get Postgres and Neo4j results

In [None]:
def generate_sql_res(state):
    """Execute this run's SQL on Postgres and store the result.

    Prefers ``regen_sql`` (present only after the grader requested a rewrite)
    over the first-pass ``sql_stmt``, so a corrected query is the one executed.

    Returns:
        ``{"sql_ans": <values from sql_connection>}``.
    """
    print("======== GENERATE: SQL QUERY RESULT ========")
    sql_stmt = state.get("regen_sql") or state["sql_stmt"]
    sql_ans = sql_connection(sql_stmt)
    return {"sql_ans": sql_ans}
    

def generate_cypher_res(state):
    """Execute this run's Cypher on Neo4j and store the result.

    Prefers ``regen_cypher`` (present only after the grader requested a rewrite)
    over the first-pass ``cypher_stmt``, so a corrected query is the one executed.

    Returns:
        ``{"cypher_ans": <values from cypher_connection>}``.
    """
    print("======== GENERATE: CYPHER QUERY RESULT ========")
    cypher_stmt = state.get("regen_cypher") or state["cypher_stmt"]
    cypher_ans = cypher_connection(cypher_stmt)
    return {"cypher_ans": cypher_ans}

### Grade the SQL and Cypher generated

In [None]:
def grade_sql_prompt(state):
    """Ask the LLM grader whether the generated SQL only uses schema objects.

    NOTE: this function's name rebinds the module-level ``grade_sql_prompt``
    ChatPromptTemplate; ``sql_grader`` captured the template before that.

    Returns:
        ``sql_regenerate`` set to ``"no"`` when the grader approved the query and
        ``"yes"`` otherwise, plus the unchanged statement and question.
    """
    print("======== GRADE: SQL QUERY ========")
    question = state["sql_question"]
    sql_stmt = state["sql_stmt"]
    sql_score = sql_grader.invoke({"ddl_schema": ddl_schema, "sql_statement": sql_stmt})
    # The model may answer "Yes", " yes" or "yes."; normalise before comparing so
    # an approval is not misread as a rejection and sent to regeneration.
    sql_grade = str(sql_score.binary_score).strip().rstrip(".").lower()
    sql_regenerate = "no" if sql_grade == "yes" else "yes"
    print(sql_regenerate)
    return {"sql_regenerate": sql_regenerate, "sql_stmt": sql_stmt, "sql_question": question}
    
def grade_cypher_prompt(state):
    """Ask the LLM grader whether the generated Cypher fits the graph schema.

    NOTE: this function's name rebinds the module-level ``grade_cypher_prompt``
    ChatPromptTemplate; ``cypher_grader`` captured the template before that.

    Returns:
        ``cypher_regenerate`` set to ``"no"`` when the grader approved the query
        and ``"yes"`` otherwise, plus the unchanged statement and question.
    """
    print("======== GRADE: CYPHER QUERY ========")
    question = state["cypher_question"]
    cypher_stmt = state["cypher_stmt"]
    cypher_score = cypher_grader.invoke({"cypher_schema": schema, "cypher_query": cypher_stmt})
    # The model may answer "Yes", " yes" or "yes."; normalise before comparing so
    # an approval is not misread as a rejection and sent to regeneration.
    cypher_grade = str(cypher_score.binary_score).strip().rstrip(".").lower()
    cypher_regenerate = "no" if cypher_grade == "yes" else "yes"
    print(cypher_regenerate)
    return {"cypher_regenerate": cypher_regenerate, "cypher_stmt": cypher_stmt, "cypher_question": question}

### Regenerate SQL

In [None]:
def regenerate_sql(state):
    """Ask SQLCoder to repair the SQL that failed grading.

    Returns:
        The cleaned query under ``regen_sql`` and also under ``sql_stmt``, so
        that a node executing ``state["sql_stmt"]`` runs the corrected query
        rather than the rejected one, plus the unchanged question.
    """
    print("======== REGENERATE: SQL QUERY ========")
    question = state["sql_question"]
    sql_stmt = state["sql_stmt"]
    regen_sql = sql_regen_chain.invoke(
        {"previous_sql": sql_stmt, "ddl_schema": ddl_schema, "question": question}
    )
    # Remove fence markers together with their language tag (a bare "sql" line
    # left in front of the query is a syntax error), special tokens and backticks.
    regen_sql = re.sub(
        r"```[ \t]*(?:postgresql|postgres|psql|sql)?", "", regen_sql, flags=re.IGNORECASE
    )
    regen_sql = regen_sql.replace("<s>", "").replace("</s>", "").replace("`", "").strip()
    print(regen_sql)
    return {"regen_sql": regen_sql, "sql_stmt": regen_sql, "sql_question": question}
    

### Regenerate Cypher

In [None]:
def regenerate_cypher(state):
    """Ask the text2cypher model to repair the Cypher that failed grading.

    Returns:
        The cleaned query under ``regen_cypher`` and also under ``cypher_stmt``,
        so that a node executing ``state["cypher_stmt"]`` runs the corrected
        query rather than the rejected one, plus the unchanged question.
    """
    print("======== REGENERATE: CYPHER QUERY ========")
    question = state["cypher_question"]
    cypher_stmt = state["cypher_stmt"]
    regen_cypher = cypher_regen_chain.invoke(
        {"previous_cypher": cypher_stmt, "question": question, "cypher_schema": graph_schema.schema}
    )
    # cypher_regen_chain ends with StrOutputParser and so yields a plain str,
    # which has no ``.content``; accept a chat message too in case it is rewired.
    regen_cypher = getattr(regen_cypher, "content", regen_cypher)
    # Drop fence markers (and a "cypher" tag) but keep inline backticks, which
    # Cypher uses to quote identifiers.
    regen_cypher = re.sub(r"```[ \t]*(?:cypher)?", "", regen_cypher, flags=re.IGNORECASE).strip()
    print(regen_cypher)
    return {"regen_cypher": regen_cypher, "cypher_stmt": regen_cypher, "cypher_question": question}

### Decide Whether to Regenerate

In [None]:
def decision_cypher(state):
    """Route after Cypher grading.

    Returns the name of the next node: ``"regenerate_cypher"`` when the
    ``cypher_regenerate`` flag is exactly ``"yes"``, otherwise
    ``"generate_cypher_res"``.
    """
    print("======== DECISION: CYPHER QUERY ========")
    flag = state["cypher_regenerate"]
    print(flag)
    return "regenerate_cypher" if flag == "yes" else "generate_cypher_res"
    
def decision_sql(state):
    """Route after SQL grading.

    Returns the name of the next node: ``"regenerate_sql"`` when the
    ``sql_regenerate`` flag is exactly ``"yes"``, otherwise ``"generate_sql_res"``.
    """
    # Log label matches the routing step (and decision_cypher's "DECISION" label).
    print("======== DECISION: SQL QUERY ========")
    regenerate = state["sql_regenerate"]
    print(regenerate)
    if regenerate == "yes":
        return "regenerate_sql"
    return "generate_sql_res"

### Create Nodes

In [None]:
graph = StateGraph(State)

# Register every node as (node name, callable). The names are the strings used
# by the edges and by the routing functions' return values.
for node_name, node_fn in (
    ("generate_sql", generate_sql),
    ("generate_cypher", generate_neo4j),
    ("grade_sql_prompt", grade_sql_prompt),
    ("grade_cypher_prompt", grade_cypher_prompt),
    ("generate_sql_res", generate_sql_res),
    ("generate_cypher_res", generate_cypher_res),
    ("regenerate_sql", regenerate_sql),
    ("regenerate_cypher", regenerate_cypher),
):
    graph.add_node(node_name, node_fn)

### Create Edges

In [None]:
# Both branches start in parallel from START and go straight to grading.
for entry_node in ("generate_sql", "generate_cypher"):
    graph.add_edge(START, entry_node)
graph.add_edge("generate_sql", "grade_sql_prompt")
graph.add_edge("generate_cypher", "grade_cypher_prompt")

# After grading, each branch either regenerates its query or runs it directly.
graph.add_conditional_edges(
    "grade_sql_prompt",
    decision_sql,
    {target: target for target in ("regenerate_sql", "generate_sql_res")},
)
graph.add_conditional_edges(
    "grade_cypher_prompt",
    decision_cypher,
    {target: target for target in ("regenerate_cypher", "generate_cypher_res")},
)

# A regenerated query is executed next; each execution node ends its branch.
for source, target in (
    ("regenerate_sql", "generate_sql_res"),
    ("regenerate_cypher", "generate_cypher_res"),
    ("generate_sql_res", END),
    ("generate_cypher_res", END),
):
    graph.add_edge(source, target)

app = graph.compile()

In [None]:
# Render the compiled graph's topology as a Mermaid PNG inline in the notebook.
graph_png = app.get_graph().draw_mermaid_png()
display(Image(data=graph_png))

In [None]:
# Send the same natural-language question to both the SQL and Cypher branches.
question = "Find top 3 most-reviewed courses"
initial_state = {"sql_question": question, "cypher_question": question}
results = app.invoke(initial_state)

In [None]:
# Neo4j result values produced by the Cypher branch (displayed as the cell output).
results["cypher_ans"]

In [None]:
# Postgres result values produced by the SQL branch (displayed as the cell output).
results["sql_ans"]