# In [None]:  (notebook cell marker left over from the export)

import streamlit as st
import json
import os
import nest_asyncio
nest_asyncio.apply()
from io import StringIO
import plotly.express as px
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.openai import OpenAIProvider

def _render_download_button(df):
    """Render a "Download as CSV" button for ``df`` in the current container."""
    csv_buffer = StringIO()
    df.to_csv(csv_buffer, index=False)
    st.download_button(
        label="Download as CSV",
        data=csv_buffer.getvalue(),
        file_name="active_data.csv",
        mime="text/csv",
    )


def _show_feature_data(feature_name, start_date, end_date, agg_func):
    """Fetch a feature and render its headline figure, table and download button.

    The displayed frame becomes ``st.session_state.active_df`` and the feature
    name is stored in ``st.session_state.feature_name``.

    Returns:
        The formatted headline value produced by ``im.format_large_number``.
    """
    df = im.get_feature_output(feature_name, start_date, end_date).copy()
    result = im.groupby_aggregation_from_text(df, "ds", feature_name, agg_func)
    result = im.format_large_number(result)
    # "last" reads naturally without the "of" ("The last price ...").
    joiner = "" if agg_func == "last" else " of"
    st.markdown(f'###### The {agg_func}{joiner} {feature_name} from {start_date} to {end_date} was {result}.')

    st.subheader(f"Sample Data for: {feature_name}")
    df_old, _ = im.aggregate_df(df, agg_func, feature_name)
    st.dataframe(df_old, hide_index=True)

    st.session_state.active_df = df_old.copy()
    st.session_state.feature_name = feature_name
    _render_download_button(st.session_state.active_df)
    return result


def _generate_response(prompt, agent, agent_aggregator, agent_graphs, agent_router, agent_else):
    """Resolve the query attributes for ``prompt`` and render the assistant reply.

    The first query of a session (no active DataFrame yet) always fetches data.
    Later queries are routed by ``agent_router`` to fetching, aggregating,
    graphing, outlier analysis or a free-text answer.
    """
    # Earlier messages only; the current prompt is appended separately below.
    conversation_context = "\n".join(
        f"{msg['role']}: {msg['content']}" for msg in st.session_state.messages[:-1]
    )

    # NOTE(review): this stores the *previous* user message, so history entries
    # are labelled with the prior query. Kept as-is because other readers of
    # `user_history` are not visible here; confirm whether `prompt` was meant.
    if len(st.session_state.messages) > 1:
        st.session_state.user_history = st.session_state.messages[-2]['content']
    else:
        st.session_state.user_history = ""

    selected_area = st.session_state.selected_business_area
    user_query = f"{conversation_context}\nuser: {prompt}" if conversation_context else prompt
    user_input = user_query + ". `business_area` is " + selected_area

    (start_date, end_date, feature_name, _business, _domain, _granularity,
     hamilton_description, _source_file, _node_match) = im.get_query_attributes(user_input, agent)

    # `feature_aggregations` is indexed by business area first, then feature;
    # fall back to "sum" when either level has no entry.
    agg_func = feature_aggregations.get(selected_area, {}).get(feature_name, "sum")
    st.session_state.node_description = hamilton_description

    with st.spinner("Thinking..."):
        # "Clear Chat" resets active_df to None, so test the value rather than
        # the presence of the key.
        if st.session_state.get("active_df") is None:
            result = _show_feature_data(feature_name, start_date, end_date, agg_func)
            st.markdown('###### Please let me know if I can filter or aggregate the data differently! Would you like me to plot the data?')
            st.session_state.history.append(
                (st.session_state.user_history, st.session_state.active_df.copy(), str(result))
            )
            return

        action = agent_router.run_sync(user_input).output.action

        if action == 'get_data':
            result = _show_feature_data(feature_name, start_date, end_date, agg_func)
            st.session_state.history.append(
                (st.session_state.user_history, st.session_state.active_df.copy(), str(result))
            )

        elif action == 'aggregate_data':
            # `_system_prompts` must be a tuple of prompts; a bare string would
            # be iterated character by character.
            agent_aggregator._system_prompts = (generate_system_prompt(st.session_state.active_df),)
            result = agent_aggregator.run_sync(user_input)
            if hasattr(result, "output"):
                # Copy so that popping 'function' does not strip the attribute
                # from the model instance itself.
                kwargs = dict(result.output.__dict__)
                st.write(result.__str__())
                st.write(kwargs.get('sql', ''))
                func_name = kwargs.pop('function')
                try:
                    new_df = SAFE_TOOLS[func_name](st.session_state.active_df, **kwargs)
                    st.session_state.history.append((st.session_state.user_history, new_df.copy(), result.__str__()))
                    st.session_state.active_df = new_df.copy()
                    st.success("Applied transformation")
                except Exception as e:
                    st.error(f"Could not request - do nothing!\nerror_message={str(e)}")

            st.session_state.feature_name = feature_name
            st.dataframe(st.session_state.active_df, hide_index=True)
            _render_download_button(st.session_state.active_df)

        elif action == 'graph_data':
            _render_download_button(st.session_state.active_df)

            agent_graphs._system_prompts = (
                generate_graphing_system_prompt(st.session_state.active_df, context=st.session_state.node_description),
            )
            result = agent_graphs.run_sync(user_input)
            graph = None  # stays None when plotting fails or no chart is recommended
            if hasattr(result, "output"):
                try:
                    graph = im.plot_with_agent_recommendation(st.session_state.active_df, result.output)
                    st.session_state.history.append((st.session_state.user_history, st.session_state.active_df.copy(), result.__str__()))
                    if graph is None:
                        advice = getattr(result.output, "advice", None)
                        st.info(advice or "No chart could be generated for the current data.")
                    else:
                        st.success("Plotted successfully")
                except Exception as e:
                    st.error(f"Could not request - do nothing!\nerror_message={str(e)}")

            st.session_state.feature_name = feature_name
            if graph is not None:
                st.plotly_chart(graph, use_container_width=True)

        elif action == 'outlier_analysis':
            ad_result = anomalyDetection(st.session_state.active_df.copy())
            if not ad_result:
                st.write("No outliers detected")
            else:
                for key in ad_result:
                    frame = ad_result[key]
                    # First column on the x-axis, last column on the y-axis.
                    fig = px.line(frame, x=list(frame[frame.columns[0]]), y=list(frame[frame.columns[-1]]), title=f"Outlier Analysis for {key}")
                    st.plotly_chart(fig, use_container_width=True)

        else:
            _render_download_button(st.session_state.active_df)
            fallback = agent_else.run_sync(user_input)
            st.write(fallback.output)
            st.session_state.history.append((st.session_state.user_history, st.session_state.active_df.copy(), fallback.__str__()))


def _answer_prompt(prompt, agent, agent_aggregator, agent_graphs, agent_router, agent_else):
    """Record the user's prompt and render the assistant's reply.

    Shows an error instead when no business area has been selected yet.
    """
    if st.session_state.selected_business_area is None:
        st.error("Please select a business area first! Choose an area from the buttons above.")
        return

    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        # Top-level UI boundary: surface any failure to the user instead of
        # crashing the page.
        try:
            _generate_response(prompt, agent, agent_aggregator, agent_graphs, agent_router, agent_else)
        except Exception as e:
            st.error(f"Error generating response: {str(e)}")


def app():
    """Render the Insto Querying Engine page.

    Builds the agents for the model chosen in the sidebar, lets the user pick a
    business area, answers a chat prompt against the active DataFrame kept in
    ``st.session_state``, and renders the sidebar "Clear Chat" button and query
    history.

    Relies on names defined outside this function: ``im``, ``base_url``,
    ``get_prompt``, ``router_prompt``, ``defaults``, ``feature_aggregations``,
    ``anomalyDetection``, ``SAFE_TOOLS``, the prompt generators and the pydantic
    output models.
    """
    if "history" not in st.session_state:
        st.session_state.history = []

    if "user_history" not in st.session_state:
        st.session_state.user_history = ""

    st.title("Welcome to the Insto Querying Engine")
    st.markdown("A Querying Engine for Institutional Data. Inputs user query, outputs data based from Hamilton nodes")

    st.divider()

    models = ["gemini-2.5-pro", "gpt-4o", "claude-sonnet-4", "gemini-2.5-flash", "gpt-4o-mini", "claude-3-5-haiku", "claude-3-5-sonnet", "claude-3-5-sonnet-v2", "claude-3-7-sonnet", "claude-3-haiku", "claude-opus-4", "dall-e-3", "gemini-2.0-flash-001", "gemini-2.0-flash-lite-001", "gpt-image-1", "o1", "o1-mini", "o3", "o3-mini", "o4-mini", "self_hosted_llama33_70", "text-embedding-3-large", "text-embedding-3-small", "text-embedding-ada-002"]

    # Create a select box for the models
    if "selected_model" not in st.session_state:
        st.session_state.selected_model = models[0]
    st.session_state.selected_model = st.sidebar.selectbox("Select OpenAI model", models, index=models.index(st.session_state.selected_model))

    def new_model():
        # Each agent gets its own model/provider instance, as before.
        return OpenAIModel(st.session_state.selected_model, provider=OpenAIProvider(api_key=os.environ['CDP_CLIENT_API_KEY'], base_url=base_url))

    agent = Agent(
        model=new_model(),
        output_type=NodeAttributes,
        system_prompt=get_prompt
    )

    agent_aggregator = Agent(
        model=new_model(),
        model_settings={'temperature': 0.0, 'seed': 42},
        output_type=[AggregationPlan],
        system_prompt="""Generate a step by step plan as list."""  # we'll overwrite this later with the correct system prompt
    )

    agent_graphs = Agent(
        model=new_model(),
        model_settings={'temperature': 0.0, 'seed': 42},
        output_type=ChartRecommendation,
        system_prompt="""Generate a step by step plan as list."""  # we'll overwrite this later with the correct system prompt
    )

    agent_router = Agent(
        model=new_model(),
        model_settings={'temperature': 0.0, 'seed': 42},
        output_type=QueryAction,
        system_prompt=router_prompt
    )

    agent_else = Agent(
        model=new_model(),
        model_settings={'temperature': 0.0, 'seed': 42},
        system_prompt="""Take the user's query and provide a response. If the user's query is not related to data, prompt the user in a friendly and professional manner.'"""
    )

    st.markdown('<div class="sticky-container">', unsafe_allow_html=True)

    # Business Area Browser Section (moved to main page)
    st.caption("Select business area:")

    if "selected_business_area" not in st.session_state:
        st.session_state.selected_business_area = None

    business_areas = list(im.feature_descriptions.keys())

    # Create buttons for business areas in a more compact grid layout
    cols = st.columns(6)  # 6 columns for more compact layout

    for i, business_area in enumerate(business_areas):
        with cols[i % 6]:
            # Highlight the currently selected business area
            is_selected = st.session_state.get('selected_business_area') == business_area
            button_type = "primary" if is_selected else "secondary"

            if st.button(business_area, key=f"business_{i}", type=button_type):
                st.session_state.selected_business_area = business_area
                st.rerun()

    st.divider()

    # Initialize messages if not exists
    if "messages" not in st.session_state:
        st.session_state.messages = []

    if "feature_name" not in st.session_state:
        st.session_state.feature_name = None

    default = ""
    if st.session_state.selected_business_area:
        default = defaults[st.session_state.selected_business_area]

    # Accept user input. Handling lives in a helper so that an early exit
    # (no business area selected) still renders the sidebar below.
    if prompt := st.chat_input(default):
        _answer_prompt(prompt, agent, agent_aggregator, agent_graphs, agent_router, agent_else)

    # Add a "Clear Chat" button to the sidebar
    if st.sidebar.button('Clear Chat'):
        st.session_state.user_history = ""
        st.session_state.active_df = None
        st.session_state.history = []
        st.session_state.messages = []
        st.rerun()

    st.sidebar.subheader("History")
    if st.session_state.history and st.session_state.user_history:
        # Most recent first; very short instructions are skipped.
        for idx, (instr, _, res) in enumerate(st.session_state.history[::-1], 1):
            if len(instr) > 3:
                st.sidebar.markdown(f"{idx}. **{instr}**")

def main():
    """Script entry point: render the Streamlit app."""
    app()

if __name__ == '__main__':
    # Go through main() so the script has a single entry point.
    main()


import numpy as np
import pandas as pd
import importlib.util
import plotly.express as px
from coinbase.feature_descriptions import feature_descriptions

# Execute the module described by `spec` and expose its `get` function at module
# level. `spec` and `fs_module_path` are defined outside this section.
if spec is None or spec.loader is None:
    raise ImportError(f"Could not load module from {fs_module_path}")

cbds_fs_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(cbds_fs_module)
get = cbds_fs_module.get

def get_feature_output(feature_name: str, start_date: str, end_date: str):
    """Fetch a single feature over a date range through the loaded ``get`` function.

    The request asks for no extra dimensions, disables the progress display and
    always passes ``product_ids=['BTC-USD']`` as input.

    Args:
        feature_name: Name of the feature to fetch.
        start_date: Start of the range, forwarded unchanged to ``get``.
        end_date: End of the range, forwarded unchanged to ``get``.

    Returns:
        Whatever ``get`` returns for the request.
    """
    request = {
        "start_date": start_date,
        "end_date": end_date,
        "features": [feature_name],
        "dims": [],
        "show_progress": False,
        "inputs": {"product_ids": ['BTC-USD']},
    }
    return get(**request)

def aggregate_df(df, agg_func, feature_name):
    """Collapse ``df`` to one row per combination of its dimension columns.

    Dimension columns are all columns except ``ds``, ``feature_name`` and the
    ``index`` column produced by ``reset_index``.

    * With dimension columns: a text or categorical feature is counted per
      group; any other feature is aggregated with ``agg_func``. The result is
      sorted by the feature, largest first.
    * Without dimension columns: the frame is only sorted by ``ds``, newest
      first (when a ``ds`` column exists).

    Args:
        df: Input frame; its index is moved into columns first.
        agg_func: Aggregation passed to ``SeriesGroupBy.agg`` for non-text features.
        feature_name: Column holding the feature values.

    Returns:
        ``(df_old, df)``: the reset-index copy of the input and the
        aggregated/sorted frame.
    """
    df = df.reset_index()
    df_old = df.copy()

    groupby_cols = [col for col in df.columns if col not in ('ds', feature_name, 'index')]
    if groupby_cols:
        # Inspect the values rather than comparing the dtype to 'string', so that
        # object-dtype text columns are counted too instead of being summed.
        kind = pd.api.types.infer_dtype(df[feature_name], skipna=True)
        grouped = df.groupby(groupby_cols)[feature_name]
        if kind in ("string", "categorical"):
            df = grouped.count().reset_index()
        else:
            df = grouped.agg(agg_func).reset_index()

        df = df.drop(columns='index', errors='ignore').sort_values(by=feature_name, ascending=False)
    elif 'ds' in df.columns:
        df = df.drop(columns='index', errors='ignore').sort_values(by='ds', ascending=False)
    return df_old, df

def groupby_aggregation_from_text(df, groupby_col, agg_col, agg_func):
    """
    Aggregates one column of the DataFrame over all rows.

    Parameters:
    - df: pandas DataFrame
    - groupby_col: accepted for interface compatibility; not used, no grouping is done
    - agg_col: column name to aggregate
    - agg_func: aggregation passed to ``Series.agg`` (e.g. "sum", "mean")

    Returns:
    - The result of ``df[agg_col].agg(agg_func)``
    """
    column = df[agg_col]
    return column.agg(agg_func)


def format_large_number(num):
    """Format a number with a K / M / B suffix, rounded to one decimal place.

    Values whose magnitude is below 1,000 are returned via ``str(num)``
    unchanged. A value that rounds up to 1000 of its unit is promoted to the
    next unit, so 999,999 renders as "1.0M" rather than "1000.0K".

    Args:
        num: The number to format; the sign is preserved.

    Returns:
        The formatted string, e.g. "1.5K", "-2.5M", "3.2B" or "999".
    """
    units = ((1_000_000_000, "B"), (1_000_000, "M"), (1_000, "K"))
    abs_num = abs(num)

    for position, (scale, suffix) in enumerate(units):
        if abs_num >= scale:
            scaled = round(num / scale, 1)
            # Rounding can push the value to 1000 of this unit; use the next
            # larger unit when there is one.
            if position > 0 and abs(scaled) >= 1_000:
                scale, suffix = units[position - 1]
                scaled = round(num / scale, 1)
            return f"{scaled}{suffix}"
    return str(num)

from pydantic import BaseModel
from pydantic_ai import Agent
from typing import Literal

from datetime import date
today = date.today()

# Structured output requested from the query-attribute agent: `app` passes this
# class as `output_type`, and `get_query_attributes` reads the fields from
# `result.output`.
# NOTE(review): `get_query_attributes` also reads `result.output.business`, which
# is not declared here, and never reads `aggregation` - confirm which side is stale.
class NodeAttributes(BaseModel):
    start_date: str  # parsed with pd.Timestamp by get_query_attributes
    end_date: str  # parsed with pd.Timestamp by get_query_attributes
    granularity: Literal['daily', 'weekly', 'monthly', 'quarterly', 'yearly']
    domain: str
    hamilton_node: str  # returned by get_query_attributes as the feature name
    hamilton_description: str
    aggregation: Literal['sum', 'mean', 'min', 'max', 'count', 'count_distinct', 'std', 'var', 'skew', 'kurt', 'mode', 'quantile', 'quantile_90', 'quantile_95', 'quantile_99']
    source_file: str
    node_match: bool

def get_query_attributes(user_query_and_business_area: str, agent: Agent):
    """Run ``agent`` on the query and unpack the structured attributes it returns.

    Args:
        user_query_and_business_area: The user query with the business area appended.
        agent: Agent whose ``run_sync(...).output`` exposes the ``NodeAttributes`` fields.

    Returns:
        A 9-tuple ``(start_date, end_date, hamilton_node, business, domain,
        granularity, hamilton_description, source_file, node_match)``. Both
        dates are normalized to ``YYYY-MM-DD`` strings. ``business`` is ``None``
        when the agent output has no such attribute.
    """
    output = agent.run_sync(user_query_and_business_area).output

    def as_iso_date(value):
        return pd.Timestamp(value).strftime('%Y-%m-%d')

    return (
        as_iso_date(output.start_date),
        as_iso_date(output.end_date),
        output.hamilton_node,
        # NodeAttributes declares no `business` field; plain attribute access
        # raised AttributeError and made every query fail.
        getattr(output, "business", None),
        output.domain,
        output.granularity,
        output.hamilton_description,
        output.source_file,
        output.node_match,
    )

from typing import Literal, List, Union, Dict

from pydantic import BaseModel, Field


class DataFrameInfo(BaseModel):
    """Metadata about the DataFrame"""
    # Populated by create_dataframe_info and read by generate_system_prompt.
    shape: tuple[int, int]  # df.shape
    columns: List[str]  # column labels
    dtypes: Dict[str, str]  # column label -> dtype name
    sample_data: str  # df.head(3).to_string()
    memory_usage: str  # deep memory usage formatted in MB
    null_counts: Dict[str, int]  # column label -> number of null values


# Tool-call schema for the `filter_rows` tool: `app` pops `function` to look the
# tool up in SAFE_TOOLS and passes the remaining fields as keyword arguments.
# NOTE(review): the example in the description is unbalanced (stray backtick and
# quote) and refers to `@search_term`, which `filter_rows` does not define -
# confirm the wording before relying on it.
class FilterRows(BaseModel):
    function: Literal["filter_rows"]
    condition: str = Field(
        description="""filter condition to be used with pandas.DataFrame.query(). Always
convert string values to lowercase, e.g df.query('name == "charlie"') -> `df.query("name.str.lower() == @search_term.lower())
"""
    )


# Tool-call schema for the `sort_rows` tool; the fields mirror its keyword arguments.
class SortRows(BaseModel):
    function: Literal["sort_rows"]
    by: str  # column to sort by
    ascending: bool = True


# Tool-call schema for the `top_rows` tool.
class TopRows(BaseModel):
    function: Literal["top_rows"]
    n: int  # number of leading rows to keep


# Tool-call schema for the `select_columns` tool.
class SelectColumns(BaseModel):
    function: Literal["select_columns"]
    columns: List[str]  # columns to keep


# Tool-call schema for the `group_by` tool; the fields mirror its keyword arguments.
# NOTE(review): several metric names (e.g. "count_distinct", "quantile_90") do not
# look like built-in pandas aggregation names - confirm `group_by` handles them.
class GroupBy(BaseModel):
    function: Literal["group_by"]
    group_by_cols: List[str] = Field(
        description="""group by columns used for pandas.DataFrame.groupby(group_by_cols)[df.columns[-1]].agg(metric).
IMPORTANT: Instead of `date`, use `ds`."""
    )
    metric: Literal["first", "last", "mean", "median", "min", "max", "count", "count_distinct", "sum", "std", "var", "skew", "kurt", "mode", "quantile", "quantile_90", "quantile_95", "quantile_99"] = Field(
        description="""group by metric used for pandas.DataFrame.groupby(group_by_cols)[df.columns[-1]].agg(metric).
if unclear, use last."""
    )

# Tool-call schema for the `run_duckdb_sql` tool; `sql` is executed against the
# active DataFrame, which the tool registers under the table name `df`.
# NOTE(review): the example query uses a `ts` column while the instructions say
# to use `ds` for dates - confirm the example is intentional.
class RunDuckDBSQL(BaseModel):
    function: Literal["run_duckdb_sql"]
    sql: str = Field(
        description="""
Executes a SQL query against the in-memory dataframe using DuckDB syntax.

Instructions for Formatting:
- Put each WHERE clause on a new line
- Use `ILIKE '%...%'` for string filters
- Prefer CTEs for multi-step logic
- Use 2-space indents for readability

Instructions for Filtering and Aggregating:
- If the user mentions the `date` column, use the `ds` column.

Example:

WITH filtered AS (
  SELECT *
  FROM df
  WHERE
    country ILIKE '%US%'
    AND category ILIKE '%Retail%'
)
SELECT
  date_trunc('month', ts)   AS month,
  AVG(revenue)              AS avg_revenue
FROM filtered
GROUP BY month
ORDER BY month
""")


# Union of every tool-call schema the aggregation agent may return; `app` passes
# it (wrapped in a list) as the agent's `output_type`.
AggregationPlan = Union[FilterRows, SortRows, TopRows, SelectColumns, GroupBy, RunDuckDBSQL] # add get hamilton tool

import json
import duckdb
import pandas as pd
from pydantic_ai import Agent

# Registry of DataFrame tools, keyed by function name.
SAFE_TOOLS = {}


def register_tool(fn):
    """Decorator: record ``fn`` in ``SAFE_TOOLS`` under its own name and return it unchanged."""
    tool_name = fn.__name__
    SAFE_TOOLS.update({tool_name: fn})
    return fn


@register_tool
def filter_rows(df: pd.DataFrame, condition: str) -> pd.DataFrame:
    """Return the rows of ``df`` that satisfy ``condition``.

    Args:
        df: Frame to filter.
        condition: Expression in ``pandas.DataFrame.query`` syntax.

    Returns:
        The filtered frame.
    """
    # Registration is done once by the decorator; re-registering on every call
    # was redundant.
    # NOTE(security): `condition` is written by the model and evaluated by
    # pandas' expression engine - review before exposing to untrusted users.
    return df.query(condition)


@register_tool
def sort_rows(df: pd.DataFrame, by: str, ascending: bool = True) -> pd.DataFrame:
    """Return ``df`` sorted by the column ``by``, ascending unless told otherwise."""
    ordered = df.sort_values(by, ascending=ascending)
    return ordered


@register_tool
def top_rows(df: pd.DataFrame, n: int) -> pd.DataFrame:
    """Return the first ``n`` rows of ``df`` (``DataFrame.head`` semantics)."""
    leading = df.head(n)
    return leading


def last_rows(df: pd.DataFrame, n: int) -> pd.DataFrame:
    """Return the last ``n`` rows of ``df`` (``DataFrame.tail`` semantics).

    Unlike its siblings this helper is not registered in ``SAFE_TOOLS``.
    """
    trailing = df.tail(n)
    return trailing


@register_tool
def select_columns(df: pd.DataFrame, columns: list) -> pd.DataFrame:
    """Return ``df`` restricted to ``columns``, in the order given."""
    subset = df[columns]
    return subset


@register_tool
def group_by(df: pd.DataFrame, group_by_cols: list, metric: str) -> pd.DataFrame:
    """Group ``df`` by ``group_by_cols`` and aggregate its last value column.

    The aggregated column is the right-most column that is not a grouping key
    (normally ``df.columns[-1]``).

    Args:
        df: Frame to aggregate.
        group_by_cols: Columns to group by.
        metric: Aggregation name. Names offered by the ``GroupBy`` schema that
            pandas does not know ("count_distinct", "mode", "kurt",
            "quantile_90/95/99") are translated here; anything else is passed
            to ``SeriesGroupBy.agg`` unchanged.

    Returns:
        One row per group, with the grouping keys as regular columns.

    Raises:
        ValueError: If every column of ``df`` is a grouping key.
    """
    def first_mode(values):
        modes = values.mode()
        return modes.iloc[0] if not modes.empty else None

    metric_aliases = {
        "count_distinct": "nunique",
        "mode": first_mode,
        "kurt": lambda values: values.kurt(),
        "quantile_90": lambda values: values.quantile(0.90),
        "quantile_95": lambda values: values.quantile(0.95),
        "quantile_99": lambda values: values.quantile(0.99),
    }

    # Aggregating a grouping key makes reset_index() fail on the duplicate name.
    value_cols = [col for col in df.columns if col not in group_by_cols]
    if not value_cols:
        raise ValueError("group_by needs at least one column that is not a grouping key")

    aggregated = df.groupby(group_by_cols)[value_cols[-1]].agg(metric_aliases.get(metric, metric))
    return aggregated.reset_index()


@register_tool
def resample(df: pd.DataFrame, freq: str) -> pd.DataFrame:
    """Resample ``df`` at frequency ``freq``, keeping the last value of each bin."""
    resampler = df.resample(freq)
    return resampler.last()


@register_tool
def run_duckdb_sql(df: pd.DataFrame, sql: str) -> pd.DataFrame:
    """
    Execute ``sql`` (DuckDB syntax) against the in-memory DataFrame.
    The frame is registered under the table name `df` on a fresh connection.
    """
    # NOTE(security): `sql` is written by the model and run as-is - review
    # before exposing to untrusted users.
    with duckdb.connect() as connection:
        connection.register('df', df)
        return connection.execute(sql).fetchdf()

def create_dataframe_info(df: pd.DataFrame) -> DataFrameInfo:
    """Create comprehensive DataFrame metadata

    Column labels are converted to ``str`` everywhere (not only in ``dtypes``)
    and null counts to plain ``int``, so frames with non-string labels match
    the ``DataFrameInfo`` field types.
    """
    return DataFrameInfo(
        shape=df.shape,
        columns=[str(col) for col in df.columns],
        dtypes={str(col): str(dtype) for col, dtype in df.dtypes.items()},
        sample_data=df.head(3).to_string(),
        memory_usage=f"{df.memory_usage(deep=True).sum() / 1024 / 1024:.2f} MB",
        null_counts={str(col): int(count) for col, count in df.isnull().sum().items()},
    )

def generate_system_prompt(df: pd.DataFrame = pd.DataFrame()):
    """Build the aggregation agent's system prompt for ``df``.

    The prompt embeds the frame's shape, columns, dtypes, memory usage, null
    counts, a three-row sample and the raw dtype mapping, followed by fixed
    instructions for the agent.

    Returns:
        The prompt text.
    """
    info = create_dataframe_info(df)

    # Dynamic system prompt with DataFrame context
    return f"""
You're an expert data analyst assistant working with a DataFrame.

DATASET OVERVIEW:
- Shape: {info.shape[0]:,} rows by {info.shape[1]} columns
- Columns: {', '.join(info.columns)}
- Data Types: {json.dumps(info.dtypes, indent=2)}
- Memory Usage: {info.memory_usage}
- Missing Values: {json.dumps(info.null_counts, indent=2)}

SAMPLE DATA:
{info.sample_data}

SCHEMA:
{df.dtypes.to_dict()}

INSTRUCTIONS:
1. Always validate column names exist before operations
2. Provide clear descriptions of what you're doing
3. Handle errors gracefully
4. Summarize results meaningfully
5. Remember previous operations in the conversation

You have access to various tools: If you can, always prefer to use of the `run_duckdb_sql` tool.
"""


# Structured output of the router agent; `app` branches on `action`.
# NOTE(review): `app` also has a catch-all branch for any other action, which
# these four literals can never produce - confirm whether a fifth value is missing.
class QueryAction(BaseModel):
    action: Literal['get_data', 'aggregate_data', 'graph_data', 'outlier_analysis']

import pandas as pd
import numpy as np

def combinations(iterable, r):
    """Yield every ``r``-length combination of the items in ``iterable``.

    Items keep their input order within each tuple. The previous version
    ignored ``r`` and always produced pairs; for ``r == 2`` the output is
    unchanged.

    Args:
        iterable: Any iterable (no longer needs ``len`` or indexing).
        r: Number of items per combination.

    Returns:
        An iterator of ``r``-tuples.
    """
    # Imported locally: this function shadows the itertools name it wraps.
    import itertools

    return itertools.combinations(iterable, r)


class ChartRecommendation(BaseModel):
    """
    A complete recommendation for generating a single, insightful chart.
    """
    chart_type: Literal["stacked_bar", "bar", "line_plot", "histogram", "pie_chart", "none"] = Field(
        ..., description="The type of chart to generate."
    )
    # The fields below default to None, so they must also accept None: typed as
    # plain `str`, a null from the model (e.g. no y-axis for a pie chart) or an
    # explicit `x_column=None` failed validation.
    x_column: Union[str, None] = Field(None, description="The column for the x-axis.")
    y_column: Union[str, None] = Field(None, description="The column for the y-axis (numerical).")
    segment_column: Union[str, None] = Field(None, description="The column for color segmentation.")
    title: str = Field(..., description="A publication-quality title.")
    reasoning: str = Field(..., description="Explanation for why this chart was chosen.")
    color_palette_suggestion: Union[Dict[str, str], None] = Field(None, description="Mapping of category names to hex color codes.")
    advice: Union[str, None] = Field(None, description="Helpful advice if no chart is plotted.")

def create_dataframe_info(df: pd.DataFrame) -> DataFrameInfo:
    """Create comprehensive DataFrame metadata

    Column labels are converted to ``str`` everywhere (not only in ``dtypes``)
    and null counts to plain ``int``, so frames with non-string labels match
    the ``DataFrameInfo`` field types.
    """
    return DataFrameInfo(
        shape=df.shape,
        columns=[str(col) for col in df.columns],
        dtypes={str(col): str(dtype) for col, dtype in df.dtypes.items()},
        sample_data=df.head(3).to_string(),
        memory_usage=f"{df.memory_usage(deep=True).sum() / 1024 / 1024:.2f} MB",
        null_counts={str(col): int(count) for col, count in df.isnull().sum().items()},
    )

def _summarize_dataframe(df: pd.DataFrame) -> (str, bool):
        """
        Creates a concise text summary of the DataFrame's columns,
        ignoring columns with only one unique value.
        """
        summary_lines = []
        candidates_found = False
       
        for col in df.columns:
            unique_count = df[col].nunique()
           
            # Edge Case: Ignore columns with only one value
            if unique_count <= 1 and len(df) > 1:
                continue

            col_type = "Unknown"
            if pd.api.types.is_numeric_dtype(df[col]):
                col_type = "Numerical"
            elif pd.api.types.is_datetime64_any_dtype(df[col]):
                col_type = "Temporal"
            elif df[col].dtype == 'object' or df[col].dtype.name == 'category':
                col_type = "Categorical"
           
            samples = df[col].dropna().unique()[:3]
            samples_str = ", ".join([str(s) for s in samples])
           
            summary_lines.append(
                f"- Column: '{col}', Type: {col_type}, Unique Values: {unique_count}, Samples: [{samples_str}]"
            )
            candidates_found = True
       
        return "\n".join(summary_lines), candidates_found

def generate_graphing_system_prompt(df: pd.DataFrame = pd.DataFrame(), context: str = "general business data"):
    """
    Builds the charting agent's system prompt from a summary of the DataFrame.

    When the frame is empty, or none of its columns has more than one unique
    value, the returned prompt tells the agent to answer with
    ``chart_type`` "none" and fixed advice. These paths used to reference an
    undefined ``ChartType`` and raised NameError. They return a prompt string,
    like the main path, because the caller installs the result as the agent's
    system prompt.

    Args:
        df: The DataFrame to be charted.
        context: Short description of what the data represents.

    Returns:
        The system prompt text.
    """

    def no_chart_prompt(title, reasoning, advice):
        return f"""
        You are an expert data visualization specialist. No chart can be generated for this request.

        BUSINESS CONTEXT: The data represents {context}.

        Respond with `chart_type` set to "none", `title` set to "{title}", `reasoning` set to "{reasoning}" and `advice` set to "{advice}". Leave every other field empty.
        """

    # Edge Case: Empty DataFrame
    if df.empty:
        return no_chart_prompt(
            "No Data Found",
            "The provided DataFrame is empty.",
            "Please provide a DataFrame with data to generate a visualization.",
        )

    summary, candidates_found = _summarize_dataframe(df)
    print(summary)

    if not candidates_found:
        return no_chart_prompt(
            "No Suitable Data Found",
            "The DataFrame contains no columns suitable for plotting (e.g., all columns have only one unique value).",
            "Please check your data. A plottable column needs more than one unique value.",
        )

    # Dynamic system prompt with DataFrame context
    system_prompt = f"""
        You are an expert data visualization specialist. Your task is to analyze a summary of a DataFrame and recommend the SINGLE BEST chart to represent the data, following a strict hierarchy of choices. You must provide all parameters needed to build the chart.

        BUSINESS CONTEXT: The data represents {context}.

        Follow this exact decision-making process in order of preference:

        **1. If there are multiple categorical/temporal columns and at least one numerical column, make a STACKED BAR CHART. Prioritize the STACKED BAR CHART over the other chart types.**
           - **Condition:** The DataFrame must contain at least one suitable categorical/temporal column.
           - **Action:** If the condition is met, you must choose the first categorical/temporal column for the x-axis (`x_column`). Choose the other categorical/temporal column with the least number of unique values for segmentation (`segment_column`), EXCEPT if there is a column called 'symbol' or 'currency', in which case use that column for segmentation. The `y_column` should be the most relevant numerical column.
           - If a stacked bar chart is chosen, suggest a professional color palette dictionary mapping the unique values of the `segment_column` to distinct hex codes.

        **2. If there is one categorical column and one numerical column, try a BAR CHART.**
           - **Condition:** The DataFrame must contain at least one suitable categorical/temporal column.
           - **Action:** If the condition is met, you must choose the BEST categorical/temporal column for the x-axis (`x_column`) and another for segmentation (`segment_column`). The `y_column` should be the most relevant numerical column.

        **3. If there is only one temporal column and one numerical column, try a LINE PLOT.**
           - **Condition:** The DataFrame must contain ONE clear temporal column with many unique values (e.g., dates) AND ONE numerical column.
           - **Action:** Set `x_column` to the temporal column and `y_column` to the numerical column.

        **4. If neither of the above are suitable, consider a HISTOGRAM.**
           - **Condition:** The DataFrame contains only numerical columns or no clear relationships between categorical and numerical data.
           - **Action:** Choose the most important numerical column for the `x_column` to show its distribution. `y_column` will be 'Frequency'.

        **5. Finally, if none of the above are possible, consider a PIE CHART.**
           - **Condition:** The DataFrame has only ONE categorical column with a reasonable number of categories (2-10).
           - **Action:** Set the `segment_column` to this categorical column.

        **Edge Case:** If there is a temporal column but no numerical or categorical columns to plot against it, state this in your reasoning and provide advice.

        Here is the summary of the available columns:
        {summary}

        Now, provide your final recommendation as a complete JSON object.
        """
    return system_prompt

def plot_with_agent_recommendation(df: pd.DataFrame, recommendation: ChartRecommendation):
    """Build the Plotly Express figure described by a chart recommendation.

    Args:
        df: Data to plot. The caller's frame is not modified; sorting
            returns a new frame and the "Other" regrouping is applied to
            that new frame.
        recommendation: Chart specification. The attributes read here are
            ``chart_type``, ``x_column``, ``y_column``, ``segment_column``,
            ``title``, ``color_palette_suggestion`` and ``advice``.

    Returns:
        The figure for the chart types ``"stacked_bar"``, ``"bar"``,
        ``"line_plot"``, ``"histogram"`` and ``"pie_chart"``. ``None`` when
        ``chart_type`` is ``"none"`` or any other value; a message is
        printed in both of those cases.
    """
    chart_type = recommendation.chart_type

    if chart_type == "none":
        print("Plotting skipped.")
        if recommendation.advice:
            print(f"Agent's Advice: {recommendation.advice}")
        return None

    x_col = recommendation.x_column
    y_col = recommendation.y_column
    seg_col = recommendation.segment_column
    title = recommendation.title
    palette = recommendation.color_palette_suggestion

    def _axis_label(column_name):
        # "trade_volume" -> "Trade Volume"; an unset column yields no label.
        return column_name.replace('_', ' ').title() if column_name else None

    fig = None
    if chart_type == "stacked_bar":
        # sort_values returns a new frame, so the regrouping below never
        # touches the caller's data.
        ordered = df.sort_values(by=[x_col, y_col], ascending=False)
        # Keep the first five DISTINCT segment values in sorted order.
        # Taking the first five rows instead would let a repeated segment
        # use up several slots and leave fewer than five segments visible.
        kept_segments = ordered[seg_col].drop_duplicates().head(5).tolist()
        ordered[seg_col] = ordered[seg_col].apply(
            lambda value: value if value in kept_segments else "Other"
        )
        fig = px.bar(
            ordered,
            x=x_col,
            y=y_col,
            color=seg_col,
            title=title,
            color_discrete_map=palette,
        )
    elif chart_type == "bar":
        if x_col == "ds":
            df = df.sort_values(by=x_col)
        # NOTE(review): no `color` column is passed here, so the palette is
        # presumably unused by Plotly Express for this chart; kept as-is.
        fig = px.bar(df, x=x_col, y=y_col, title=title, color_discrete_map=palette)
    elif chart_type == "line_plot":
        if x_col == "ds":
            df = df.sort_values(by=x_col)
        fig = px.line(df, x=x_col, y=y_col, title=title, markers=True)
    elif chart_type == "histogram":
        fig = px.histogram(df, x=x_col, title=title, marginal="box")
    elif chart_type == "pie_chart":
        # `color_discrete_map` is keyed by the values of the `color` column,
        # so `color` must be given for the suggested palette to be applied.
        fig = px.pie(
            df,
            names=seg_col,
            color=seg_col,
            title=title,
            color_discrete_map=palette,
        )

    if fig is None:
        print("Could not generate a chart for the given recommendation.")
        return None

    fig.update_layout(
        title_font_size=20,
        xaxis_title=_axis_label(x_col),
        yaxis_title=_axis_label(y_col),
        legend_title=_axis_label(seg_col),
        font=dict(family="Arial, sans-serif", size=12),
    )
    return fig

# System prompt for the routing step. It asks the model to classify the
# user's query into exactly one action phrase, checked in this order:
#   'get data'         - the user wants new data
#   'aggregate data'   - aggregate, sort or filter data already provided
#   'graph data'       - plot the data
#   'outlier analysis' - analyse trends in the data
#   'else'             - none of the above (with a stated preference for
#                        'get data' whenever that reading is possible)
# NOTE(review): the code that consumes the returned phrase is outside this
# section; confirm it matches on exactly these strings before rewording.
router_prompt = """You are an expert at comprehending the user's query and discerning the action they desire.
        Take the user's query and understand whether they are asking for new data. If they are, return 'get data'.
        If they are not (perhaps there are keywords like 'take this data' which means they want to use the data you've already provided).
        Then understand whether they are asking for data to be aggregated, sorted, or filtered. If they are, return 'aggregate data'.
        If they are not, take the user's query and understand whether they are asking for data to be graphed. If they are, return 'graph data'.
        If they are not, take the user's query and understand whether they are asking for data trends to be analyzed. If they are, return 'outlier analysis'.
        If none of these are what the user wants, return 'else', but be sure to try to understand if the user wants data first, and prefer to return 'get data' if possible.
        You need to work quickly, so do not return more than one phrase"""

# System prompt for the query-parsing agent (passed as
# `system_prompt=get_prompt` when the agent is built in app()). It tells the
# model how to fill each output field: `start_date`, `end_date`, `business`,
# `aggregation`, `granularity`, `hamilton_node`, `hamilton_description`,
# `source_file` and `node_match`, including the defaults to use when the
# query does not mention a value ('Unknown', 'mean', 'daily').
#
# This is an f-string, so `today` and the full text of `feature_descriptions`
# are interpolated once, when this assignment runs, not when the agent is
# later invoked.
# NOTE(review): `today` and `feature_descriptions` are defined outside this
# section; presumably a date value and a dict keyed by business area, to be
# confirmed against their definitions.
get_prompt = f"""Parse the question and the business area and return a result in the output format.
    For the output_fields `start_date` and `end_date`, please give start and end dates the user desires in 'YYYY-MM-DD' string format.
    Please note that today's date is {today}.
    For output field `business`, if the business name is not explicitly mentioned in the query, leave it as 'Unknown'.
    For output field `aggregation`, if the aggregation function is not explicitly mentioned in the query, leave it as 'mean'.
    For output field `granularity`, if the granularity is not explicitly mentioned in the query, leave it as 'daily'.
    For output field `hamilton_node`, please search the following dictionary object for the `business_area` key. There should be a sub-dictionary for this `business_area`. Please exclusively search this sub-dictionary and return the key of the description from the sub-dictionary which is the best match. The dictionary object is here: {str(feature_descriptions)}.
    For the output field `hamilton_description`, please return the value of the description of the hamilton node that is the best match.
    For the output field `source_file`, please return the dictionary key of the source file that contains the feature.
    For the output field `node_match`, please return True if the hamilton node found is in the sub-dictionary of the `business_area` key, and False otherwise."""