In [1]:
import os
import pandas as pd
import plotly.express as px
from dotenv import load_dotenv
from langchain_experimental.agents import create_pandas_dataframe_agent
from langchain.tools import BaseTool
from langchain_openai import ChatOpenAI

# Load environment variables from .env file
load_dotenv()

# Set up OpenAI API key
openai_api_key = os.getenv("OPENAI_API_KEY")

# Create the language model (LLM) instance
llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo")


In [2]:
# Define the tool to import CSV files
class CSVImportTool(BaseTool):
    name: str = "csv_import"
    description: str = "Import a CSV file into a pandas DataFrame."
    
    def _run(self, query: str) -> str:
        try:
            df = pd.read_csv(query)
            return f"CSV file '{query}' imported successfully. The DataFrame has {len(df)} rows and {len(df.columns)} columns."
        except Exception as e:
            return f"Error importing CSV file: {e}"
    
    def _arun(self, query: str) -> pd.DataFrame:
        return pd.read_csv(query)

# Define the tool to suggest the best chart
class ChartSuggestionTool(BaseTool):
    name: str = "chart_suggestion"
    description: str = "Suggest the most appropriate chart for visualizing the data in a pandas DataFrame."
    
    def _run(self, df: pd.DataFrame) -> str:
        num_cols = len([col for col in df.columns if df[col].dtype in ['int64', 'float64']])
        cat_cols = len([col for col in df.columns if df[col].dtype == 'object'])
        
        if num_cols == 1 and cat_cols == 1:
            return "For a single numerical column and a single categorical column, a bar chart or a line chart would be appropriate."
        elif num_cols == 2 and cat_cols == 0:
            return "For two numerical columns, a scatter plot would be a good choice."
        elif num_cols > 2 and cat_cols == 0:
            return "For multiple numerical columns, you could consider a scatter plot matrix or a parallel coordinates plot."
        elif cat_cols > 1:
            return "For multiple categorical columns, you could use a stacked bar chart or a heatmap."
        else:
            return "Based on the data, it's difficult to suggest an appropriate chart. Please provide more information about the data and the analysis you want to perform."
    
    def _arun(self, df: pd.DataFrame) -> str:
        return self._run(df)

# Define the tool to provide DataFrame insights
class DataFrameInsightsTool(BaseTool):
    name: str = "dataframe_insights"
    description: str = "Provide insights about a pandas DataFrame, such as data types, missing values, and basic statistics."
    
    def _run(self, df: pd.DataFrame) -> str:
        insights = []
        
        # Data types
        data_types = df.dtypes.value_counts().to_dict()
        insights.append(f"Data types: {', '.join([f'{dtype} ({count})' for dtype, count in data_types.items()])}")
        
        # Missing values
        missing_values = df.isnull().sum().sum()
        insights.append(f"Missing values: {missing_values}")
        
        # Basic statistics
        if len([col for col in df.columns if df[col].dtype in ['int64', 'float64']]) > 0:
            numeric_cols = df.select_dtypes(include=['int64', 'float64'])
            numeric_stats = numeric_cols.describe().T
            insights.append("Basic statistics for numerical columns:\n" + numeric_stats.to_string())
        
        return "\n".join(insights)
    
    def _arun(self, df: pd.DataFrame) -> str:
        return self._run(df)


In [3]:
# Function to create and save a plot using plotly.express
def create_and_save_plot(df, x_column, y_column, plot_type="bar", file_name="plot.png"):
    if plot_type == "bar":
        fig = px.bar(df, x=x_column, y=y_column)
    elif plot_type == "line":
        fig = px.line(df, x=x_column, y=y_column)
    elif plot_type == "scatter":
        fig = px.scatter(df, x=x_column, y=y_column)
    elif plot_type == "histogram":
        fig = px.histogram(df, x=x_column)
    elif plot_type == "box":
        fig = px.box(df, x=x_column, y=y_column)
    else:
        raise ValueError(f"Unsupported plot type: {plot_type}")
    
    # Save the plot to a file
    fig.write_image(file_name)
    print(f"Plot saved as {file_name}")


In [5]:
csv_file = "basededados_aulat7.csv"
df = pd.read_csv(csv_file)

# Initialize the agent with the actual DataFrame
agent = create_pandas_dataframe_agent(llm=llm, df=df, verbose=False, allow_dangerous_code=True)

print("AI Assistant: Hello! How can I assist you today? (Type 'exit' to end the chat)")

while True:
    query = input("Human: ")
    if query.lower() == "exit":
        break  # Exit the chat loop
    else:
        response = agent.invoke(query)
        print("AI Assistant:", response)


AI Assistant: Hello! How can I assist you today? (Type 'exit' to end the chat)


AI Assistant: {'input': 'quero insghts dos dados', 'output': 'The insights from the data include summary statistics of numerical columns and the number of unique values in categorical columns.'}
AI Assistant: {'input': 'me de o sumario', 'output': 'The summary statistics for the dataframe `df` are displayed above.'}
AI Assistant: {'input': 'qual melhor gráfico para plotar os 5 primeiros produtos', 'output': 'Agent stopped due to iteration limit or time limit.'}
