In [None]:
import pandas as pd
import seaborn as sns
from PIL import Image
import os
import openai
import json
import streamlit as st
from streamlit_chat import message
import plotly.graph_objects as go
import plotly.express as px

from langchain_experimental.agents.agent_toolkits import create_pandas_dataframe_agent
from langchain_openai import OpenAI, ChatOpenAI
from langchain.agents import AgentType
from langchain.tools import tool, Tool
from langchain.schema.output_parser import OutputParserException




In [None]:
# Set up Streamlit page configuration: the "wide" layout lets the page content
# use the full browser width instead of the default centered column.
st.set_page_config(layout="wide")

def sidebar():
    """Render the option widgets and return the user's selections.

    The widgets are created with plain ``st.*`` calls, so they are drawn in
    whatever container is active at the call site (the caller in this file
    wraps the call in ``with st.sidebar:``).

    Returns:
        dict: the current widget values under the keys ``"api_key"``,
        ``"model"``, ``"temperature"`` and ``"layout_width"``.
    """

    st.markdown("<h3 style='text-align:center;'>Model Options</h3>", unsafe_allow_html=True)

    # The key is typed into a password field so it is masked on screen.
    api_key_select = st.text_input(
        'Please add your API key',
        placeholder='Paste your API key here',
        type='password',
        help=("""[Get Your OpenAI API key here](https://platform.openai.com/api-keys)""")
    )

    model_select = st.selectbox(
        label="Available Models",
        options=["gpt-3.5-turbo-1106", "gpt-4-1106-preview", "gpt-4o-mini", "gpt-4o"],
        help="""The available models. Same prompt might return different results for
        different models. Experimentation is recommended."""
    )

    # Defaults to 0.0, the most deterministic setting the slider offers.
    temperature_select = st.slider(
        label="Temperature",
        value=0.0,
        min_value=0.,
        max_value=1.,
        step=0.01,
        help=(
            """Controls randomness. What sampling temperature to use, between 0 and 1.
            Higher values like 0.8 will make the output more random, while lower values
            like 0.2 will make it more focused and deterministic."""
        )
    )

    st.markdown("<hr style='border:1px solid black'>", unsafe_allow_html=True)
    st.markdown("<h3 style='text-align:center;'>Page Options</h3>", unsafe_allow_html=True)

    # Width in pixels; the caller injects it into a CSS max-width rule.
    layout_width = st.slider(
        label="Layout Width",
        min_value=600,
        max_value=1800,
        value=1200,
        step=50,
        help="Adjust the width of the layout."
    )

    return {
        "api_key": api_key_select,
        "model": model_select,
        "temperature": temperature_select,
        "layout_width": layout_width
    }


In [None]:
st.markdown("<h1 style='text-align:center; padding:0.5em; border-radius:6px;'>Data Insight Assistant</h1>", unsafe_allow_html=True)

# Draw the option widgets in the sidebar and collect the current selections.
with st.sidebar:
    model_params = sidebar()

# Unpack the selections into the module-level names used further down.
api_key_select = model_params["api_key"]
model_select = model_params["model"]
temperature_select = model_params["temperature"]
layout_width = model_params["layout_width"]

# True only once the user has typed something into the API-key field.
api_key_entered = bool(api_key_select)

# Without a key nothing below can work, so halt this script run here.
if not api_key_entered:
    st.warning("Please enter your API key (sk-...) in the sidebar to proceed.")
    st.stop()

In [None]:
# Continue with file upload and data processing if API key is entered
file = st.file_uploader(label="Choose file (.csv)", type=["csv"])

# Nothing to analyse until a file has been uploaded.
if not file:
    st.stop()

# Decode as UTF-8 first ("utf-8-sig" also drops a leading BOM). Decoding a
# UTF-8 file as windows-1252 rarely raises; it silently produces mojibake for
# every non-ASCII character. Only when the bytes are not valid UTF-8 do we
# fall back to the original windows-1252 decoding.
try:
    df = pd.read_csv(file, encoding='utf-8-sig')
except UnicodeDecodeError:
    file.seek(0)  # the failed attempt consumed part of the upload buffer
    df = pd.read_csv(file, encoding='windows-1252')

In [None]:
@tool
def plotChart(data: str) -> str:
    """
    Plots JSON data using a Plotly Figure.
    Args:
        data (str): JSON string representing the figure.
    Returns:
        str: Confirmation message after plotting.
    """
    # NOTE(review): the docstring above is kept verbatim on purpose. LangChain's
    # @tool decorator is documented to use it as the tool description shown to
    # the model, so rewording it would change runtime behavior.
    #
    # In practice this function only validates the payload: it parses the JSON,
    # builds a Plotly figure from it and applies a tick format, but the figure
    # is never rendered or returned from here.
    from plotly.io import from_json

    try:
        # Round-trip through json so malformed input fails before Plotly sees it.
        parsed_spec = json.loads(data)
        figure = from_json(json.dumps(parsed_spec))

        # Thousands separator and two decimals on both axes.
        tick_style = {'tickformat': ',.2f'}
        figure.update_layout({
            'yaxis': dict(tick_style),
            'xaxis': dict(tick_style)
        })
    except Exception as e:
        st.error(f"An error occurred while creating the plot: {e}")
        return "An error occurred while creating the plot. Please provide the dataset or dataframe for troubleshooting."

    return "Charts and table plotted successfully."

# Additional tools handed to the dataframe agent via its `extra_tools` argument.
extra_tools = [plotChart]

def extract_input_output(result):
    """Split an agent result into its tool inputs and its final answer.

    Args:
        result: mapping with an ``'intermediate_steps'`` sequence, whose items
            carry an object with a ``tool_input`` attribute at index 0, and an
            ``'output'`` entry.

    Returns:
        tuple: ``(input_cmds, output)`` where ``input_cmds`` is the list of
        ``tool_input`` values in step order and ``output`` is
        ``result['output']``.
    """
    tool_inputs = []
    for step in result['intermediate_steps']:
        action = step[0]
        tool_inputs.append(action.tool_input)
    return tool_inputs, result['output']

def convert_to_streamlit_format(data):
    """Decode a JSON document into the corresponding Python object.

    Args:
        data: JSON text (anything ``json.loads`` accepts).

    Returns:
        The decoded value, e.g. a dict for a JSON object.
    """
    decoded = json.loads(data)
    return decoded

def display_string_in_list(data):
    """Return the ``'query'`` entry of ``data`` with blank lines removed.

    Every remaining line is stripped of leading and trailing whitespace and
    the lines are re-joined with ``'\\n'``.

    Returns:
        The compacted query string, or ``None`` when ``data`` has no
        ``'query'`` entry.
    """
    if 'query' not in data:
        return None

    kept_lines = []
    for raw_line in data['query'].split('\n'):
        trimmed = raw_line.strip()
        if trimmed:  # skip lines that are empty or whitespace-only
            kept_lines.append(trimmed)
    return '\n'.join(kept_lines)

def remove_repeated_words(query: str) -> str:
    """Drop a trailing line that merely repeats the variable assigned just above it.

    If the second-to-last line contains ``=`` and the last line, once
    stripped, equals the text left of the first ``=`` (also stripped), the
    last line is removed. Any other input is returned unchanged.

    Args:
        query: newline-separated text.

    Returns:
        The query, without the redundant final line when one was found.
    """
    rows = query.split('\n')

    # At least two lines are needed to compare the last against the one before.
    if len(rows) >= 2:
        assignment = rows[-2].strip()
        if '=' in assignment:
            target_name = assignment.partition('=')[0].strip()
            if rows[-1].strip() == target_name:
                del rows[-1]

    return '\n'.join(rows)

def extract_fig_from_query(query: str) -> str:
    """Return the part of ``query`` from the first ``"fig ="`` onward, minus comment lines.

    Lines whose stripped text starts with ``#`` are dropped from the returned
    snippet; all other lines are kept as they are.

    Returns:
        The snippet, or ``None`` when ``query`` contains no ``"fig ="``.
    """
    start = query.find("fig =")
    if start == -1:
        return None

    snippet_lines = query[start:].split('\n')
    code_lines = [ln for ln in snippet_lines if not ln.strip().startswith('#')]
    return '\n'.join(code_lines)

# Chat model configured entirely from the sidebar selections (model name,
# sampling temperature and the user's API key).
model = ChatOpenAI(
    model=model_select,
    temperature=temperature_select,
    api_key=api_key_select
)

# Agent that answers questions about the uploaded dataframe `df`.
# - return_intermediate_steps=True: extract_input_output() reads
#   result['intermediate_steps'], so this flag is required by the chat display.
# - extra_tools adds plotChart to the tools the agent can call.
# - max_iterations / max_execution_time bound a single request
#   (presumably the time limit is in seconds — TODO confirm against LangChain docs).
# - NOTE(review): allow_dangerous_code=True appears to be the explicit opt-in
#   that lets the agent execute model-generated Python in this process; only
#   run this app where that is acceptable.
pandas_agent = create_pandas_dataframe_agent(
    llm=model,
    df=df,
    max_iterations=30,
    max_execution_time=45,
    agent_type=AgentType.OPENAI_FUNCTIONS,
    return_intermediate_steps=True,
    verbose=True,
    allow_dangerous_code=True,
    extra_tools=extra_tools,
)

# Adjust layout width using CSS: inject a <style> rule that caps the content
# container at the width chosen with the "Layout Width" slider and centers it.
# The doubled braces ({{ }}) are literal CSS braces inside the f-string.
# NOTE(review): the `.main .block-container` selector presumably matches
# Streamlit's main content wrapper; such class names can change between
# Streamlit versions — confirm against the version in use.
st.markdown(
    f"""
    <style>
        .main .block-container {{
            max-width: {layout_width}px;
            margin-left: auto;
            margin-right: auto;
        }}
    </style>
    """,
    unsafe_allow_html=True
)

In [None]:
###################################################################################################################################
# 1.0 Streamlit App
###################################################################################################################################


def _as_tool_args(tool_input):
    """Return ``tool_input`` if it is a dict of tool arguments, else an empty dict.

    Tool inputs may also be plain strings; treating those as "no arguments"
    avoids substring matches and ``TypeError`` on key lookups.
    """
    return tool_input if isinstance(tool_input, dict) else {}


def _first_chart_payload(input_cmds):
    """Return the first non-empty ``'data'`` argument among the tool inputs, or None."""
    for tool_input in input_cmds:
        payload = _as_tool_args(tool_input).get('data')
        if payload:
            return payload
    return None


def _render_figure_from_code(fig_code):
    """Run a ``fig = ...`` snippet and render the figure it defines.

    SECURITY: ``fig_code`` is code written by the LLM agent and is executed
    in this process. That matches the app's existing behavior, but it is only
    acceptable in a trusted environment.
    """
    # Execute in a copy of the module namespace so the snippet can use df, px,
    # go, ... without leaking names (or a stale `fig`) into later iterations.
    namespace = dict(globals())
    namespace.pop("fig", None)
    try:
        exec(fig_code, namespace)
        figure = namespace.get("fig")
        if figure is not None:
            st.plotly_chart(figure)
    except Exception as e:
        # The snippet is replayed on every rerun; one failing entry must not
        # take down the whole chat history.
        st.warning(f"Could not re-create the chart for this answer: {e}")


def _render_chart_payload(payload):
    """Render a chart from a plotChart ``data`` argument (JSON text or a dict)."""
    try:
        chart_data = payload if isinstance(payload, dict) else convert_to_streamlit_format(payload)
        if chart_data:
            st.plotly_chart(go.Figure(chart_data))
    except Exception as e:
        # Invalid JSON or an invalid figure spec from the model: report it
        # instead of crashing the page.
        st.warning(f"Could not render the chart for this answer: {e}")


with st.expander("🔎 Dataframe Preview"):
    st.write(df.head(3))

# Initialize chat history (kept in session state so it survives reruns)
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# Chat message input
user_question = st.chat_input("🗣️ Ask question about dataset")

if user_question:
    try:
        with st.spinner("Generating output..."):
            # Get the response from the agent
            response = pandas_agent.invoke(user_question)
            # Keep the tool inputs and the final answer for the history replay below
            input_cmds, output = extract_input_output(response)
            st.session_state.chat_history.append((user_question, input_cmds, output))

    except OutputParserException:
        st.error("OutputParserException error occurred in LangChain agent. Refine your query.")
    except Exception as e:
        st.error(f"Unknown error occurred in LangChain agent. Refine your query. Error: {e}")


# Display chat history
for question, input_cmds, output in st.session_state.chat_history:
    st.chat_message("user").write(f"{question}")

    # Only the first tool call is shown as "the query" for this answer.
    first_args = _as_tool_args(input_cmds[0]) if input_cmds else {}
    query = first_args.get('query')
    if isinstance(query, str):
        formatted_query = display_string_in_list(first_args)
        if formatted_query:
            formatted_query = remove_repeated_words(formatted_query)
            # Fences go on their own lines: with ```code``` on one line, the
            # first line of a multi-line query is parsed as the fence's info
            # string and disappears from the rendered block.
            st.chat_message("Q").write(f"```python\n{formatted_query}\n```")

        # Extract fig from query if it exists
        fig_code = extract_fig_from_query(query)
        if fig_code:
            _render_figure_from_code(fig_code)

    st.chat_message("assistant").write(f"{output}")

    # Chart requested through the plotChart tool, whichever step it occurred in.
    chart_payload = _first_chart_payload(input_cmds)
    if chart_payload:
        _render_chart_payload(chart_payload)

    st.markdown("<hr style='border:3px solid black'>", unsafe_allow_html=True)
