In [None]:
from typing import Any, Dict
from typing_extensions import TypedDict
from dataclasses import dataclass, field

import pandas as pd
import duckdb

# CSV file loaded into DuckDB below (path is relative to the working directory).
path_to_csv: str = "../data/train.csv"

@dataclass
class Insightly:
    """
    A wrapper around a DuckDB connection that loads CSV data and runs SQL.

    Attributes
    ----------
    conn : duckdb.DuckDBPyConnection
        The DuckDB connection object. By default a new connection is created
        by calling ``duckdb.connect()`` with no arguments (an in-memory
        database).
    """
    conn: duckdb.DuckDBPyConnection = field(default_factory=duckdb.connect)

    # reading the CSV file into a DuckDB table
    def read_csv_to_duckdb(
        self,
        path_to_csv: str,
        table_name: str = "titanic"
    ) -> None:
        """
        Reads a CSV file into a DuckDB table.

        The table is created with ``CREATE TABLE IF NOT EXISTS``, so calling
        this again for an existing ``table_name`` leaves that table unchanged.

        Parameters
        ----------
        path_to_csv : str
            The path to the CSV file. Single quotes in the path are escaped.
        table_name : str, optional
            The name of the table to create in DuckDB (default is "titanic").
            It is inserted into the SQL verbatim, so it must come from trusted
            code, not from user input.

        Returns
        -------
        None
        """
        # Escape single quotes so a path such as "../data/o'neil.csv" still
        # forms a valid SQL string literal.
        csv_literal = "'" + path_to_csv.replace("'", "''") + "'"
        self.conn.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {table_name} AS
            SELECT * FROM read_csv_auto({csv_literal})
            """
        )

    def retrieve_table(
        self,
        table_name: str = "titanic"
    ) -> duckdb.DuckDBPyRelation:
        """
        Retrieves a DuckDB table as a relation.

        Parameters
        ----------
        table_name : str, optional
            The name of the table to retrieve (default is "titanic").

        Returns
        -------
        duckdb.DuckDBPyRelation
            The relation representing the DuckDB table.
        """
        return self.conn.table(table_name)

    # get the schema using duckdb
    def get_schema(self) -> str:
        """
        Describes every column listed in ``information_schema.columns``.

        Returns
        -------
        str
            A plain-text table with ``table_name``, ``column_name`` and
            ``data_type`` columns, ordered by table and column position. The
            table name is included so prompts built from this text know which
            table to query. The full table is rendered, without pandas'
            display truncation and without the index.
        """
        schema = self.conn.execute(
            """
            SELECT table_name, column_name, data_type
            FROM information_schema.columns
            ORDER BY table_name, ordinal_position
            """
        ).df()

        return schema.to_string(index=False)

    def execute_query(
        self,
        query: str
    ) -> "duckdb.DuckDBPyRelation | None":
        """
        Executes a SQL query on the DuckDB connection.

        Parameters
        ----------
        query : str
            The SQL query to execute.

        Returns
        -------
        duckdb.DuckDBPyRelation or None
            Whatever ``DuckDBPyConnection.sql`` returns. Per the DuckDB Python
            API, that is a relation for statements that produce a result set
            and None for statements that are run immediately without one
            (e.g. CREATE / INSERT).
        """
        return self.conn.sql(query)
    
# Example usage: load the CSV into DuckDB and print the schema text that the
# LLM prompts in the following cells are built from.
insightly = Insightly()
insightly.read_csv_to_duckdb(path_to_csv)
table = insightly.retrieve_table()
schema = insightly.get_schema()
print(schema)



In [None]:
from typing import Optional
from enum import Enum, auto

from pydantic import Field, BaseModel
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables.config import RunnableConfig

class PlotType(str, Enum):
    """Plot type for the visualization of the data."""
    # Only scatter plots have a plotting node (generate_scatter_plot) so far.
    # Add further members here together with their own plotting node.
    SCATTER = "SCATTER"
    # HEATMAP = "HEATMAP"

# Bookkeeping for the SQL branch of the workflow:
#   sql_query    - SQL generated from the user's question
#   query_result - human-readable result (or error message) of running it
#   query_rows   - result rows, as stored by execute_sql
#   sql_error    - whether running the query raised an exception
SqlQueryInfo = TypedDict(
    "SqlQueryInfo",
    {
        "sql_query": str,
        "query_result": str,
        "query_rows": list,
        "sql_error": bool,
    },
)

# Bookkeeping for the plot branch of the workflow:
#   plot_type    - a PlotType value such as "SCATTER"
#   columns      - column names chosen for the plot
#   query_result - human-readable result text
#   query_rows   - result rows
PlotQueryInfo = TypedDict(
    "PlotQueryInfo",
    {
        "plot_type": str,
        "columns": list[str],
        "query_result": str,
        "query_rows": list,
    },
)

class AgentState(TypedDict):
    """State passed between the LangGraph nodes of the agent.

    TypedDict fields cannot carry default values (an assignment in the class
    body is not used as a default), so every key is supplied by the caller
    or filled in by one of the nodes.
    """
    # The user's question; regenerate_query may replace it with a rewrite.
    question: str
    # Whether the question should be answered with SQL rather than a plot.
    meant_as_query: bool
    # Set by check_if_sql_or_plot for SQL questions; absent/None before that.
    sql_query_info: Optional[SqlQueryInfo]
    # Set by check_if_sql_or_plot for plot questions; absent/None before that.
    plot_query_info: Optional[PlotQueryInfo]
    # Number of question rewrites performed by regenerate_query.
    attempts: int
    # Label produced by check_relevance ('relevant' or 'not_relevant').
    relevance: str

# Structured-output schema for check_relevance: it is passed to
# `llm.with_structured_output`, so the LLM's reply is parsed into this model
# instead of being returned as free text.
class CheckRelevance(BaseModel):
    # Expected to be 'relevant' or 'not_relevant'. Only the field description
    # below asks for that; the value itself is not validated.
    relevance: str = Field(
        description="Indicates whether the question is related to the database schema. 'relevant' or 'not_relevant'."
    )

def check_relevance(state: AgentState, config: RunnableConfig):
    """Ask the LLM whether ``state["question"]`` concerns the loaded schema.

    Sets ``state["relevance"]`` to the model's label (requested as
    'relevant' or 'not_relevant') and returns the updated state.

    The schema and the question are passed to the prompt as template
    variables instead of being pasted into the template text. Curly braces
    inside either one are therefore never parsed as placeholders.
    """
    question = state["question"]
    schema = insightly.get_schema()
    print(f"Checking relevance of the question: {question}")
    system = """You are an assistant that determines whether a given question is related to the following database schema.

Schema:
{schema}

Respond with only "relevant" or "not_relevant".
"""
    check_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", "Question: {question}"),
        ]
    )
    llm = ChatOpenAI(temperature=0)
    structured_llm = llm.with_structured_output(CheckRelevance)
    relevance_checker = check_prompt | structured_llm
    relevance = relevance_checker.invoke({"schema": schema, "question": question})
    state["relevance"] = relevance.relevance
    print(f"Relevance determined: {state['relevance']}")
    return state


In [None]:
# How a question should be answered; used as the type of
# CheckIfSQLOrPlot.meant_as_query and as the path_map keys for type_router.
# (Documented with comments because a class docstring would become part of
# the structured-output schema sent to the LLM.)
class QueryType(str, Enum):
    # Answer by generating and executing a SQL query.
    SQL = "sql"
    # Answer by producing a plot.
    PLOT = "plot"

# Structured-output schema for check_if_sql_or_plot.
class CheckIfSQLOrPlot(BaseModel):
    # 'sql' or 'plot'.
    meant_as_query: QueryType = Field(
        description="Indicates whether the question requires an SQL query or a plot."
    )
    # Only meaningful for plot questions. It is optional so the model does
    # not have to invent a plot type for questions answered with SQL.
    type_of_plot: Optional[PlotType] = Field(
        default=None,
        description="The type of plot to be generated if the question requires a plot."
    )

def check_if_sql_or_plot(state: AgentState, config: RunnableConfig):
    """Decide whether ``state["question"]`` needs a SQL answer or a plot.

    Sets ``state["meant_as_query"]`` to True for SQL and False for a plot.
    It also initialises the matching ``sql_query_info`` / ``plot_query_info``
    entry with every key present, so later nodes can fill it in. Returns the
    updated state.
    """
    question = state["question"]
    schema = insightly.get_schema()
    print(f"Checking if the question requires an SQL query or a plot: {question}")
    # Offer exactly the plot types that exist, rather than a hard-coded list
    # that can drift from PlotType.
    plot_types = ", ".join(plot_type.value for plot_type in PlotType)
    system = """
You are an assistant that determines whether a given question requires an SQL query or a plot based on the following schema:
{schema}

Respond with 'sql' if the question is related to data retrieval or manipulation that can be expressed in SQL,
and 'plot' if the question is related to data visualization.
If the question is related to data visualization, provide the type of plot as one of the following with no explanation: {plot_types}.
"""
    check_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", "Question: {question}"),
        ]
    )
    llm = ChatOpenAI(temperature=0)
    structured_llm = llm.with_structured_output(CheckIfSQLOrPlot)
    type_checker = check_prompt | structured_llm
    # Every template placeholder must be supplied here; the schema,
    # plot-type list and question are passed as variables, not pasted into
    # the template, so braces inside them are never parsed.
    result = type_checker.invoke(
        {"schema": schema, "plot_types": plot_types, "question": question}
    )
    # QueryType members are non-empty strings and therefore always truthy,
    # so compare explicitly instead of relying on truthiness.
    is_sql = result.meant_as_query == QueryType.SQL
    state["meant_as_query"] = is_sql
    if is_sql:
        state["sql_query_info"] = SqlQueryInfo(
            sql_query="",
            query_result="",
            query_rows=[],
            sql_error=False,
        )
    else:
        # Scatter is the only implemented plot, so it is the fallback when
        # the model leaves the plot type out.
        plot_type = result.type_of_plot or PlotType.SCATTER
        state["plot_query_info"] = PlotQueryInfo(
            plot_type=plot_type.value,
            columns=[],
            query_result="",
            query_rows=[],
        )
    print(f"Determined type: {QueryType.SQL.value if is_sql else QueryType.PLOT.value}")
    return state

In [None]:
# Structured-output schema holding one string produced by the LLM. It is used
# by convert_nl_to_sql, and get_scatter_plot_columns reuses it to receive a
# comma-separated column list in the same field.
class ConvertToSQL(BaseModel):
    # The generated SQL text (or, when reused for plots, the column list).
    sql_query: str = Field(
        description="The SQL query generated from the natural language question."
    )

def convert_nl_to_sql(state: AgentState, config: RunnableConfig):
    """Translate ``state["question"]`` into a SQL query with the LLM.

    Stores the generated query in ``state["sql_query_info"]["sql_query"]``,
    creating ``sql_query_info`` if no earlier node did. Returns the updated
    state.
    """
    question = state["question"]
    schema = insightly.get_schema()
    print(f"Converting question to SQL: {question}")
    # `schema` and `question` are template variables filled in by `invoke`,
    # so curly braces inside them are never parsed as prompt placeholders.
    system = """You are an assistant that converts natural language questions into DuckDB SQL queries based on the following schema:

{schema}

Provide only the SQL query without any explanations. Query only the tables and columns listed in the schema, and give computed columns short, descriptive aliases.
"""
    convert_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", "Question: {question}"),
        ]
    )
    llm = ChatOpenAI(temperature=0)
    structured_llm = llm.with_structured_output(ConvertToSQL)
    sql_generator = convert_prompt | structured_llm
    result = sql_generator.invoke({"schema": schema, "question": question})
    # check_if_sql_or_plot normally creates sql_query_info; create it here
    # too so this node does not fail when that entry is missing.
    if not state.get("sql_query_info"):
        state["sql_query_info"] = SqlQueryInfo(
            sql_query="", query_result="", query_rows=[], sql_error=False
        )
    state["sql_query_info"]["sql_query"] = result.sql_query
    # Single quotes inside the f-string keep this valid before Python 3.12.
    print(f"Generated SQL query: {state['sql_query_info']['sql_query']}")
    return state

In [None]:
def execute_sql(state: AgentState):
    """Run the SQL stored in ``state["sql_query_info"]["sql_query"]``.

    On success it sets, inside ``sql_query_info``:

    - ``query_rows``: one ``{column: value}`` dict per row.
    - ``query_result``: a readable rendering, i.e. a comma-separated header
      line followed by one comma-separated line per row, or
      "No results found.".
    - ``sql_error``: False.

    Statements that return no relation get a generic success message. If
    execution raises, the error text goes into ``query_result`` and
    ``sql_error`` is set to True so the workflow can retry. Returns the
    updated state.
    """
    sql_info = state["sql_query_info"]
    sql_query = sql_info["sql_query"].strip()
    print(f"Executing SQL query: {sql_query}")
    # NOTE(review): the query comes from the LLM and is executed verbatim,
    # including data-modifying statements. Restrict it to read-only SQL if
    # the connection ever holds data that must not change.
    try:
        result = insightly.execute_query(sql_query)
        if result is not None:
            # Anything that yields a result set (SELECT, WITH ..., etc.)
            # comes back as a relation; detecting this by return value
            # avoids guessing from the statement's first keyword.
            columns = list(result.columns)
            rows = result.fetchall()
            sql_info["query_rows"] = [dict(zip(columns, row)) for row in rows]
            print(f"Raw SQL Query Result: {sql_info['query_rows']}")
            if rows:
                header = ", ".join(columns)
                data = "\n".join(", ".join(str(value) for value in row) for row in rows)
                formatted_result = f"{header}\n{data}"
            else:
                formatted_result = "No results found."
            sql_info["query_result"] = formatted_result
            sql_info["sql_error"] = False
            print("SQL SELECT query executed successfully.")
        else:
            # The statement has already been run by DuckDB, which commits
            # automatically outside explicit transactions, so there is no
            # relation and nothing to commit here.
            sql_info["query_result"] = "The action has been successfully completed."
            sql_info["sql_error"] = False
            print("SQL command executed successfully.")
    except Exception as e:
        sql_info["query_result"] = f"Error executing SQL query: {str(e)}"
        sql_info["sql_error"] = True
        print(f"Error executing SQL query: {str(e)}")
    return state

In [None]:
import plotly.express as px

# Structured-output schema for a scatter plot described as a string.
# NOTE(review): no node shown in this notebook uses this model
# (get_scatter_plot_columns uses ConvertToSQL); remove it if it stays unused.
class ScatterPlotter(BaseModel):
    plot: str = Field(
        description="The generated scatter plot as a string representation."
    )

def get_scatter_plot_columns(state: AgentState, config: RunnableConfig):
    """Ask the LLM which two columns to plot for ``state["question"]``.

    The answer is requested as a comma-separated list. It is parsed into
    ``state["plot_query_info"]["columns"]``, creating ``plot_query_info`` if
    needed. Returns the updated state.
    """
    question = state["question"]
    schema = insightly.get_schema()
    print(f"Selecting scatter plot columns for question: {question}")
    # `schema` and `question` are template variables filled in by `invoke`,
    # so curly braces inside them are never parsed as prompt placeholders.
    system = """You are an assistant that chooses the appropriate columns for a scatter plot based on the following schema:

{schema}

Provide only the columns to be used in the scatter plot without any explanations.
The columns should be suitable for a scatter plot, typically two numerical columns.
Return exactly two column names separated by a comma, with the x-axis column first.
"""
    convert_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", "Question: {question}"),
        ]
    )
    llm = ChatOpenAI(temperature=0)
    # ConvertToSQL is reused as a single-string output schema; its
    # `sql_query` field carries the comma-separated column list here.
    structured_llm = llm.with_structured_output(ConvertToSQL)
    column_selector = convert_prompt | structured_llm
    result = column_selector.invoke({"schema": schema, "question": question})
    # Split on commas and trim whitespace plus any surrounding quotes or
    # backticks, so "a,b", "a , b" and '"a", "b"' all parse like "a, b".
    columns = [
        column.strip().strip("\"'`")
        for column in result.sql_query.split(",")
        if column.strip().strip("\"'`")
    ]
    if not state.get("plot_query_info"):
        state["plot_query_info"] = PlotQueryInfo(
            plot_type=PlotType.SCATTER.value,
            columns=[],
            query_result="",
            query_rows=[],
        )
    state["plot_query_info"]["columns"] = columns
    print(f"Selected columns for scatter plot: {state['plot_query_info']['columns']}")
    return state
    

In [None]:
def generate_scatter_plot(state: AgentState, config: RunnableConfig):
    """Show a Plotly scatter plot of the two columns chosen earlier.

    Reads the column names from ``state["plot_query_info"]["columns"]`` and
    loads them from the ``titanic`` table. The first column goes on the
    x-axis and the second on the y-axis. Returns the state so the graph can
    continue to END.

    Raises
    ------
    ValueError
        If exactly two columns were not selected.
    """
    plot_info = state.get("plot_query_info") or {}
    columns = plot_info.get("columns") or []
    if len(columns) != 2:
        raise ValueError("Scatter plot requires exactly two columns.")
    x_column, y_column = columns

    # The names come from the LLM, so quote them as identifiers (doubling any
    # embedded double quote) to keep them from altering the query.
    quoted_columns = ", ".join('"' + name.replace('"', '""') + '"' for name in columns)
    query = f"SELECT {quoted_columns} FROM titanic"
    print(f"Generating scatter plot with query: {query}")

    # Execute the query and fetch the data
    df = insightly.conn.execute(query).df()

    px.scatter(df, x=x_column, y=y_column, title="Scatter Plot").show()
    return state

In [None]:
from langchain_core.output_parsers import StrOutputParser

def generate_funny_response(state: AgentState):
    """Produce a playful reply for questions judged not relevant.

    Stores the LLM's message in ``state["sql_query_info"]["query_result"]``
    and returns the updated state.
    """
    print("Generating a funny response for an unrelated question.")
    system = """You are a charming and funny assistant who responds in a playful manner.
    """
    human_message = "I can not help with that, but doesn't asking questions make you hungry? You can always order something delicious."
    funny_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            ("human", human_message),
        ]
    )
    llm = ChatOpenAI(temperature=0.7)
    funny_response = funny_prompt | llm | StrOutputParser()
    message = funny_response.invoke({})
    # The workflow routes here directly from check_relevance, before
    # check_if_sql_or_plot has created sql_query_info, so create it if missing.
    if not state.get("sql_query_info"):
        state["sql_query_info"] = SqlQueryInfo(
            sql_query="", query_result="", query_rows=[], sql_error=False
        )
    state["sql_query_info"]["query_result"] = message
    print("Generated funny response.")
    return state

In [None]:
# Structured-output schema for regenerate_query: the LLM's reformulated question.
class RewrittenQuestion(BaseModel):
    question: str = Field(description="The rewritten question.")

def regenerate_query(state: AgentState):
    """Rewrite ``state["question"]`` so the next SQL generation may succeed.

    Replaces the question with the LLM's reformulation and increments
    ``state["attempts"]``, treating a missing count as 0. Returns the updated
    state.
    """
    question = state["question"]
    print("Regenerating the SQL query by rewriting the question.")
    system = """You are an assistant that reformulates an original question to enable more precise SQL queries. Ensure that all necessary details, such as table joins, are preserved to retrieve complete and accurate data.
    """
    # The question is passed as the `{question}` template variable rather
    # than being baked into the template text, so braces in it are safe.
    rewrite_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system),
            (
                "human",
                "Original Question: {question}\nReformulate the question to enable more precise SQL queries, ensuring all necessary details are preserved.",
            ),
        ]
    )
    llm = ChatOpenAI(temperature=0)
    structured_llm = llm.with_structured_output(RewrittenQuestion)
    rewriter = rewrite_prompt | structured_llm
    rewritten = rewriter.invoke({"question": question})
    state["question"] = rewritten.question
    # `attempts` is not guaranteed to be part of the initial input.
    state["attempts"] = state.get("attempts", 0) + 1
    print(f"Rewritten question: {state['question']}")
    return state

In [None]:
from langgraph.graph import StateGraph, END

class State(str, Enum):
    """Names of the LangGraph nodes in the workflow.

    Members compare equal to their plain string values. ``__str__`` is
    overridden so ``str()`` and f-strings also yield the value (e.g.
    "check_relevance") instead of the default "State.CHECK_RELEVANCE".
    Since Python 3.11, ``format()`` follows ``__str__`` for str-mixin enums.
    """
    CHECK_RELEVANCE = "check_relevance"
    CHECK_IF_SQL_OR_PLOT = "check_if_sql_or_plot"
    CONVERT_NL_TO_SQL = "convert_nl_to_sql"
    GET_SCATTER_PLOT_COLUMNS = "get_scatter_plot_columns"
    GENERATE_SCATTER_PLOT = "generate_scatter_plot"
    GENERATE_FUNNY_RESPONSE = "generate_funny_response"
    REGENERATE_QUERY = "regenerate_query"
    EXECUTE_SQL = "execute_sql"
    END_MAX_ITERATIONS = "end_max_iterations"

    def __str__(self) -> str:
        return self.value
    
# TERMINAL NODE AND ROUTERS
def end_max_iterations(state: AgentState):
    """Terminal node used once the SQL retry budget is exhausted.

    Sets ``state["sql_query_info"]["query_result"]`` to "Please try again.",
    creating ``sql_query_info`` if it is missing, and returns the state.
    """
    # AgentState has no top-level `query_result` key; the result text lives
    # inside sql_query_info.
    sql_info = state.get("sql_query_info") or {}
    sql_info["query_result"] = "Please try again."
    state["sql_query_info"] = sql_info
    print("Maximum attempts reached. Ending the workflow.")
    return state

def relevance_router(state: AgentState) -> State:
    """Send relevant questions on to classification, others to the funny reply.

    The label from check_relevance is compared case-insensitively and
    ignoring surrounding whitespace, so an answer such as " Relevant" still
    routes correctly. A missing label is treated as not relevant.
    """
    relevance = str(state.get("relevance", "")).strip().lower()
    if relevance == "relevant":
        return State.CHECK_IF_SQL_OR_PLOT
    return State.GENERATE_FUNNY_RESPONSE
    
def type_router(state: AgentState) -> QueryType:
    """Pick the SQL or the plot branch after check_if_sql_or_plot.

    Returns a QueryType member because the conditional edge registered for
    this router uses QueryType.SQL / QueryType.PLOT as its path_map keys.
    ``meant_as_query`` may be a bool (True meaning SQL) or a QueryType or
    'sql'/'plot' string. Anything other than True or 'sql' routes to the
    plot branch.
    """
    meant_as_query = state.get("meant_as_query")
    if meant_as_query is True or meant_as_query == QueryType.SQL:
        return QueryType.SQL
    return QueryType.PLOT

# Maximum number of question rewrites before a failing SQL query is given up.
MAX_SQL_ATTEMPTS = 3


def sql_execution_router(state: AgentState) -> str:
    """Decide what follows execute_sql.

    A successful query ends the run. A failed query is retried through
    regenerate_query while ``attempts`` is below MAX_SQL_ATTEMPTS; after
    that the run ends via end_max_iterations.
    """
    sql_info = state.get("sql_query_info") or {}
    if not sql_info.get("sql_error", False):
        return END
    if state.get("attempts", 0) < MAX_SQL_ATTEMPTS:
        return State.REGENERATE_QUERY
    return State.END_MAX_ITERATIONS


workflow = StateGraph(AgentState)

# relevancy checks at the beginning to ensure question is relevant and can be answered
workflow.add_node(
    State.CHECK_RELEVANCE,
    check_relevance,
)
workflow.add_conditional_edges(
    State.CHECK_RELEVANCE,
    relevance_router,
    {
        State.CHECK_IF_SQL_OR_PLOT: State.CHECK_IF_SQL_OR_PLOT,
        State.GENERATE_FUNNY_RESPONSE: State.GENERATE_FUNNY_RESPONSE
    },
)
# checking if the question is meant to be answered with SQL statement or a plot
workflow.add_node(
    State.CHECK_IF_SQL_OR_PLOT,
    check_if_sql_or_plot,
)
# adding conditional edge to go different directions based on query type
workflow.add_conditional_edges(
    State.CHECK_IF_SQL_OR_PLOT,
    type_router,
    {
        QueryType.SQL: State.CONVERT_NL_TO_SQL,
        QueryType.PLOT: State.GET_SCATTER_PLOT_COLUMNS
    }
)
# if the question is meant to be answered with SQL,
# convert the natural language question to SQL query
workflow.add_node(
    State.CONVERT_NL_TO_SQL,
    convert_nl_to_sql,
)
workflow.add_node(
    State.EXECUTE_SQL,
    execute_sql,
)
workflow.add_edge(
    State.CONVERT_NL_TO_SQL,
    State.EXECUTE_SQL
)
# after executing, either finish, retry via a rewritten question, or give up;
# without this edge regenerate_query and end_max_iterations are unreachable
workflow.add_conditional_edges(
    State.EXECUTE_SQL,
    sql_execution_router,
    {
        State.REGENERATE_QUERY: State.REGENERATE_QUERY,
        State.END_MAX_ITERATIONS: State.END_MAX_ITERATIONS,
        END: END,
    },
)

# if the question is relevant but the SQL query is not executed successfully,
# regenerate the query by rewriting the question
workflow.add_node(
    State.REGENERATE_QUERY,
    regenerate_query,
)
workflow.add_edge(
    State.REGENERATE_QUERY,
    State.CONVERT_NL_TO_SQL
)
workflow.add_node(
    State.END_MAX_ITERATIONS,
    end_max_iterations
)
workflow.add_edge(
    State.END_MAX_ITERATIONS,
    END
)

# if the question is meant to be answered with a plot,
# get the columns for the scatter plot
workflow.add_node(
    State.GET_SCATTER_PLOT_COLUMNS,
    get_scatter_plot_columns,
)
workflow.add_node(
    State.GENERATE_SCATTER_PLOT,
    generate_scatter_plot,
)
workflow.add_edge(
    State.GET_SCATTER_PLOT_COLUMNS,
    State.GENERATE_SCATTER_PLOT
)
workflow.add_edge(
    State.GENERATE_SCATTER_PLOT,
    END
)

# if the question is not relevant, generate a funny response
workflow.add_node(
    State.GENERATE_FUNNY_RESPONSE,
    generate_funny_response,
)
workflow.add_edge(
    State.GENERATE_FUNNY_RESPONSE,
    END
)

workflow.set_entry_point(State.CHECK_RELEVANCE)

app = workflow.compile()


In [None]:
# Export the compiled graph's structure as Mermaid source.
with open("diagram.md", "w") as diagram_file:
    mermaid_source = app.get_graph().draw_mermaid()
    diagram_file.write(mermaid_source)