In [1]:
import ast
import json
import os
import re
import textwrap

import matplotlib.pyplot as plt
import pandas as pd
import requests
from openpyxl import Workbook
from openpyxl.utils.dataframe import dataframe_to_rows

# --- Utility Functions ---
def set_chart_title_and_labels(title, subtitle, x_title, y_title, ax):
    """Put a two-line heading and both axis captions on a Matplotlib Axes.

    Args:
        title: Main heading text (first line, 16 pt together with the subtitle).
        subtitle: Secondary heading text (second line).
        x_title: Caption for the horizontal axis (12 pt).
        y_title: Caption for the vertical axis (12 pt).
        ax: Axes object decorated in place.
    """
    axis_label_size = 12
    ax.set_xlabel(x_title, fontsize=axis_label_size)
    ax.set_ylabel(y_title, fontsize=axis_label_size)
    ax.set_title("{}\n{}".format(title, subtitle), fontsize=16)

def apply_common_chart_formatting(ax):
    """Give *ax* the house style shared by every chart in this script.

    In order: a legend at the best location, light dashed horizontal grid
    lines, x tick labels rotated by 45 degrees, and all four frame lines
    (spines) hidden.
    """
    ax.legend(loc="best", fontsize='medium')
    ax.grid(axis='y', linestyle='--', alpha=0.5)
    ax.tick_params(axis='x', rotation=45)
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)

# --- Data Reading Function ---
def read_csv(file_path):
    """Load a CSV file into a DataFrame and normalise it for analysis.

    Steps:
      1. Read with a BOM-tolerant UTF-8 decoder, skipping spaces after delimiters.
      2. Coerce the known metric columns to numbers (unparseable values -> NaN).
      3. Print a per-column count of missing values if any are present.
      4. Strip surrounding whitespace from text cells (non-string cells such as
         NaN are left untouched).

    Args:
        file_path: Path or file-like object accepted by ``pd.read_csv``.

    Returns:
        The cleaned DataFrame.

    Raises:
        FileNotFoundError: *file_path* does not exist.
        ValueError: the file has no data.
        Exception: any other read/parse failure. In every case the original
            error is chained as ``__cause__``.
    """
    try:
        data_frame = pd.read_csv(file_path, encoding='utf-8-sig', skipinitialspace=True)
    except FileNotFoundError as e:
        raise FileNotFoundError(f"Error: File not found at path: {file_path}") from e
    except pd.errors.EmptyDataError as e:
        raise ValueError("CSV file is empty.") from e
    except Exception as e:
        raise Exception(f"Error while reading the file: {e}") from e

    # Metric columns that must be numeric; invalid entries become NaN so they
    # show up in the missing-value report below instead of breaking charts.
    numeric_columns = ['Rating', 'CustomerSatisfactionScore', 'MeanRating', 'StdDevRating',
                       'VarianceRating', 'tStatistic', 'pValue']
    for col in numeric_columns:
        if col in data_frame.columns:
            data_frame[col] = pd.to_numeric(data_frame[col], errors='coerce')

    missing_counts = data_frame.isnull().sum()
    if missing_counts.any():
        print("Warning: Missing or invalid data detected. Ensure all necessary values are provided.")
        print(missing_counts)  # Per-column summary of missing values for debugging

    # Include "string" so pandas' dedicated string dtype (the default text
    # dtype in newer pandas) is stripped too. Strip only real str values:
    # `.str.strip()` would turn any non-string cell in a mixed column into NaN.
    for col in data_frame.select_dtypes(include=['object', 'string']).columns:
        data_frame[col] = data_frame[col].map(
            lambda value: value.strip() if isinstance(value, str) else value
        )

    return data_frame

# --- API Interaction Function ---
def _first_candidate_text(result):
    """Return the text of the first part of the first candidate in a Gemini reply.

    Raises:
        ValueError: the reply lacks candidates/content/parts, or the first
            part carries no string ``text``.
    """
    try:
        first_part = result['candidates'][0]['content']['parts'][0]
    except (KeyError, IndexError, TypeError) as e:
        raise ValueError("Error: The API response structure is not as expected.") from e
    text = first_part.get('text') if isinstance(first_part, dict) else None
    if not isinstance(text, str):
        raise ValueError("Error: Unexpected API response format.")
    return text


def interact_with_ai(prompt, data, timeout=60, model="gemini-1.5-flash"):
    """Send *prompt*, prefixed with a text dump of *data*, to the Google Gemini API.

    The API key is read from the ``GEMINI_API_KEY`` environment variable
    (falling back to the original ``"KEY"`` placeholder). It is sent in the
    ``x-goog-api-key`` header rather than the URL query string, because
    requests' HTTP error messages echo the request URL and would otherwise
    leak the key into printed errors.

    Args:
        prompt: Question or instruction appended after the data dump.
        data: pandas DataFrame rendered with ``to_string()`` into the prompt.
        timeout: Seconds to wait for the HTTP response before giving up.
        model: Gemini model name used in the ``generateContent`` endpoint.

    Returns:
        ``(response_text, full_prompt)``: the first candidate's text with every
        ``*`` removed (Markdown emphasis), and the exact prompt that was sent.

    Raises:
        Exception: network/HTTP failure or a non-JSON response body.
        ValueError: the JSON reply does not contain candidate text.
    """
    api_key = os.environ.get("GEMINI_API_KEY", "KEY")
    api_url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent"
    headers = {"Content-Type": "application/json", "x-goog-api-key": api_key}

    full_prompt = f"Here's some information about the data:\n{data.to_string()}\n\n{prompt}"
    payload = {"contents": [{"parts": [{"text": full_prompt}]}]}

    try:
        response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()  # Raise HTTPError for 4xx/5xx responses
    except requests.exceptions.RequestException as e:
        raise Exception(f"Error communicating with the API: {e}") from e

    try:
        result = response.json()
    except ValueError as e:
        # requests may raise its own JSONDecodeError (backed by simplejson when
        # installed), which is not always a json.JSONDecodeError; both derive
        # from ValueError.
        raise Exception(f"Error decoding JSON response: {e}") from e

    ai_response_text = _first_candidate_text(result).replace("*", "")
    return ai_response_text, full_prompt

# --- JSON Validation Function ---
import jsonschema

def validate_json(data):
    """Check a parsed chart specification against the expected JSON schema.

    ``chart`` must contain ``type``, ``title``, ``subtitle``, ``xAxis`` and
    ``series``. ``yAxis`` is additionally required for every type except
    ``cross-table``, because the cross-table format requested from the model
    has no y axis. The type comparison is exact, so a differently cased value
    such as ``"Cross-Table"`` still requires ``yAxis``.

    Args:
        data: Parsed JSON, normally a dict with a top-level ``chart`` key.

    Returns:
        ``(True, data)`` when valid; otherwise ``(False, None)`` after printing
        the validation error.
    """
    schema = {
        # Pin the draft so the if/else keywords below are always understood.
        "$schema": "http://json-schema.org/draft-07/schema#",
        "type": "object",
        "properties": {
            "chart": {
                "type": "object",
                "properties": {
                    "type": {"type": "string"},
                    "title": {"type": "object"},
                    "subtitle": {"type": "object"},
                    "xAxis": {"type": "object"},
                    "yAxis": {"type": "object"},
                    "series": {"type": "array"}
                },
                "required": ["type", "title", "subtitle", "xAxis", "series"],
                "if": {"properties": {"type": {"const": "cross-table"}}},
                "else": {"required": ["yAxis"]}
            }
        },
        "required": ["chart"]
    }
    try:
        jsonschema.validate(instance=data, schema=schema)
        return True, data
    except jsonschema.exceptions.ValidationError as e:
        print(f"JSON validation error: {e}")
        return False, None

# --- Chart Generation Functions ---
def generate_column_chart(chart_data):
    """Draw a grouped column (vertical bar) chart with value labels.

    Args:
        chart_data: The ``chart`` object of the AI JSON spec. Reads
            ``title.text``, ``subtitle.text``, ``xAxis.categories``,
            ``xAxis.title.text``, ``yAxis.title.text`` and ``series`` (list of
            ``{'name', 'data'}``; one bar per series within each category).

    Returns:
        The Matplotlib Figure.

    Raises:
        KeyError: a required key is missing.
        ValueError: pandas rejects series whose length differs from the categories.
    """
    title = chart_data['title']['text']
    subtitle = chart_data['subtitle']['text']
    x_categories = chart_data['xAxis']['categories']
    x_title = chart_data['xAxis']['title']['text']
    y_title = chart_data['yAxis']['title']['text']
    series = chart_data['series']

    fig, ax = plt.subplots(figsize=(10, 6))
    df = pd.DataFrame({serie['name']: serie['data'] for serie in series}, index=x_categories)
    df.plot(kind='bar', stacked=False, rot=45, ax=ax)

    # Label each bar at its horizontal centre, just above (or, for negative
    # values, just below) its end. Scaling get_x() by a factor, as before,
    # shifted labels further right the further a bar sat from x=0.
    for patch in ax.patches:
        height = patch.get_height()
        if pd.isna(height):
            continue  # missing value: no bar to label
        label = f"{height:.2f}".rstrip("0").rstrip(".")  # 5.0 -> "5", 4.333 -> "4.33"
        above = height >= 0
        ax.annotate(
            label,
            (patch.get_x() + patch.get_width() / 2, height),
            xytext=(0, 3 if above else -3),
            textcoords="offset points",
            ha="center",
            va="bottom" if above else "top",
            fontsize=8,
        )

    set_chart_title_and_labels(title, subtitle, x_title, y_title, ax)
    apply_common_chart_formatting(ax)
    plt.tight_layout(pad=2)

    return fig

def generate_pie_chart(chart_data):
    """Draw one pie per series, side by side, and return the figure.

    Args:
        chart_data: The ``chart`` object of the AI JSON spec. Reads
            ``title.text``, ``subtitle.text``, ``xAxis.categories`` (wedge
            labels), ``xAxis.title.text``, ``yAxis.title.text`` and ``series``
            (list of ``{'name', 'data'}`` with numeric ``data``).

    Returns:
        The Matplotlib Figure. A single series gets ``title``/``subtitle`` as
        its own Axes title. With several series, each pie is titled with its
        series name and ``title``/``subtitle`` become the figure title.
    """
    title = chart_data['title']['text']
    subtitle = chart_data['subtitle']['text']
    x_categories = chart_data['xAxis']['categories']
    x_title = chart_data['xAxis']['title']['text']
    y_title = chart_data['yAxis']['title']['text']
    series = chart_data['series']

    # One Axes per series: drawing several pies on one Axes stacks them on top
    # of each other so only the last one is visible.
    n_pies = max(1, len(series))
    fig, axes = plt.subplots(1, n_pies, figsize=(6 * n_pies + 2, 8), squeeze=False)
    axes = axes[0]

    for ax, serie in zip(axes, series):
        values = serie['data']
        colors = [plt.cm.tab20(j / len(values)) for j in range(len(values))]
        _, _, autotexts = ax.pie(
            values, labels=x_categories, autopct='%1.1f%%',
            startangle=90, colors=colors, explode=[0.01] * len(values),
        )
        # Only the percentages sit on the coloured wedges, so only they are
        # white. The category labels outside the pie keep the default dark
        # text colour; white there was invisible against the background.
        plt.setp(autotexts, color='w', size=10, weight="bold")
        ax.set_title(serie['name'])
        ax.axis('equal')  # Keep each pie circular
        apply_common_chart_formatting(ax)

    if len(series) == 1:
        set_chart_title_and_labels(title, subtitle, x_title, y_title, axes[0])
    else:
        fig.suptitle(f"{title}\n{subtitle}", fontsize=16)
    plt.tight_layout(pad=2)

    return fig

def generate_line_chart(chart_data):
    """Draw one line (with point markers) per series over the x categories.

    Args:
        chart_data: The ``chart`` object of the AI JSON spec. Reads
            ``title.text``, ``subtitle.text``, ``xAxis.categories``,
            ``xAxis.title.text``, ``yAxis.title.text`` and ``series`` (list of
            ``{'name', 'data'}`` whose ``data`` aligns with the categories).

    Returns:
        The Matplotlib Figure.
    """
    title = chart_data['title']['text']
    subtitle = chart_data['subtitle']['text']
    x_categories = chart_data['xAxis']['categories']
    x_title = chart_data['xAxis']['title']['text']
    y_title = chart_data['yAxis']['title']['text']
    series = chart_data['series']

    fig, ax = plt.subplots(figsize=(10, 6))
    # Sample the colormap with floats in [0, 0.85]. Calling a colormap with
    # integers indexes its 256-entry lookup table directly, so range(n) gave n
    # nearly identical dark purples. 0.85 stops short of viridis' pale-yellow
    # end, which is hard to see on a white background.
    last_index = max(len(series) - 1, 1)
    for i, serie in enumerate(series):
        color = plt.cm.viridis(0.85 * i / last_index)
        ax.plot(x_categories, serie['data'], label=serie['name'], marker='o', linestyle='-',
                color=color, linewidth=2, markersize=5)

    set_chart_title_and_labels(title, subtitle, x_title, y_title, ax)
    apply_common_chart_formatting(ax)
    plt.tight_layout(pad=2)

    return fig

def _scatter_xy(points, categories=None):
    """Split one series' scatter ``data`` into parallel x and y lists.

    Accepted shapes: ``[x, y]`` pairs (extra items ignored), ``{'x', 'y'}``
    objects, or a flat list of y values. Flat values are plotted against
    *categories* when it has the same length, otherwise against the index.
    """
    if not points:
        return [], []
    if all(isinstance(p, dict) for p in points):
        return [p.get('x') for p in points], [p.get('y') for p in points]
    if all(isinstance(p, (list, tuple)) and len(p) >= 2 for p in points):
        return [p[0] for p in points], [p[1] for p in points]
    y_vals = list(points)
    if categories and len(categories) == len(y_vals):
        return list(categories), y_vals
    return list(range(len(y_vals))), y_vals


def generate_scatter_chart(chart_data):
    """Draw a scatter plot with one colour per series.

    Args:
        chart_data: The ``chart`` object of the AI JSON spec. Reads
            ``title.text``, ``subtitle.text``, ``xAxis.title.text``,
            ``yAxis.title.text``, optional ``xAxis.categories`` and ``series``
            (list of ``{'name', 'data'}``; see ``_scatter_xy`` for the
            accepted ``data`` shapes). Flat numeric lists are accepted because
            the chart prompt asks the model for plain number lists.

    Returns:
        The Matplotlib Figure.
    """
    title = chart_data['title']['text']
    subtitle = chart_data['subtitle']['text']
    x_title = chart_data['xAxis']['title']['text']
    y_title = chart_data['yAxis']['title']['text']
    x_categories = chart_data['xAxis'].get('categories')
    series = chart_data['series']

    fig, ax = plt.subplots(figsize=(10, 6))
    # Float samples spread the series across the colormap. Integer arguments
    # would index the lookup table and give near-identical colours.
    last_index = max(len(series) - 1, 1)
    for i, serie in enumerate(series):
        x_vals, y_vals = _scatter_xy(serie.get('data') or [], x_categories)
        color = plt.cm.viridis(0.85 * i / last_index)
        ax.scatter(x_vals, y_vals, label=serie['name'], color=color, s=50, alpha=0.8)

    set_chart_title_and_labels(title, subtitle, x_title, y_title, ax)
    apply_common_chart_formatting(ax)
    plt.tight_layout(pad=2)

    return fig

def generate_histogram_chart(chart_data):
    """Draw one semi-transparent 10-bin histogram per series on shared axes.

    Args:
        chart_data: The ``chart`` object of the AI JSON spec. Reads
            ``title.text``, ``subtitle.text``, ``xAxis.title.text``,
            ``yAxis.title.text`` and ``series`` (list of ``{'name', 'data'}``,
            where ``data`` holds the raw values to bin).

    Returns:
        The Matplotlib Figure containing the histogram.
    """
    # Read every required field before creating the figure, so a missing key
    # raises before any figure exists.
    heading = chart_data['title']['text']
    subheading = chart_data['subtitle']['text']
    x_label = chart_data['xAxis']['title']['text']
    y_label = chart_data['yAxis']['title']['text']
    all_series = chart_data['series']

    fig, ax = plt.subplots(figsize=(10, 6))
    hist_style = dict(bins=10, alpha=0.7, edgecolor='black')
    for entry in all_series:
        ax.hist(entry['data'], label=entry['name'], **hist_style)

    set_chart_title_and_labels(heading, subheading, x_label, y_label, ax)
    apply_common_chart_formatting(ax)
    plt.tight_layout(pad=2)
    return fig

def generate_cross_table(chart_data):
    """Generates a cross-table in Excel."""
    title = chart_data['title']['text']
    subtitle = chart_data['subtitle']['text']
    x_categories = chart_data['xAxis']['categories']
    series = chart_data['series']
    try:
        # Create a Pandas DataFrame
        data = {}
        for serie in series:
            data[serie['name']] = serie['data']

        df = pd.DataFrame(data, index=x_categories)

        # Dynamically Handle nested row categories and extract level names
        #num_levels = len(x_categories[0].split('-'))
        #index_names = [f'Level {i+1}' for i in range(num_levels)]

        # Create MultiIndex
        new_index = []
        for cat in x_categories:
            # Split only if there are 4 parts for each row, if the row do not have five parts add an empty spot
            parts = cat.split('-')

            new_index.append(tuple(parts))
        num_levels = len(new_index[0])
        index_names = [f'Level {i+1}' for i in range(num_levels)]

        df.index = pd.MultiIndex.from_tuples(new_index, names=index_names) # index names will now be the columns names


        # Handle nested column names.
        column_names = []
        for serie in series:
            column_names.append(serie['name'])

        # Convert series' names to multiindex if needed:  seriesnameA-seriesNameB-etc
        new_columns = []
        for name in column_names:
            if '-' in name:
                new_columns.append(tuple(name.split('-'))) #splits the data to create columns names
            else:
                new_columns.append(name)

        #assign the data to the df
        df.columns = pd.MultiIndex.from_tuples([new_columns], names=['ColumnNames']) if all(isinstance(col, tuple) for col in new_columns) else new_columns


        # Create an Excel workbook and sheet
        wb = Workbook()
        ws = wb.active

        # Add the DataFrame to the worksheet
        for r in dataframe_to_rows(df, index=True, header=True):
            ws.append(r)

        # Save the Excel file
        filename = "cross_table.xlsx"
        wb.save(filename)
        print(f"Cross-table generated and saved to {filename}")

    except (TypeError, KeyError, AttributeError) as e:
        print(f"Error generating cross-table: {e}")
    return  # Exit the function after generating the cross-table

def _load_json_or_literal(text):
    """Parse *text* as JSON, falling back to a Python literal.

    The chart prompt's own example uses single-quoted keys, so the model
    sometimes answers in that style. ``ast.literal_eval`` accepts it and,
    unlike ``eval``, only evaluates literals. If both parsers fail, the
    original JSON error is raised.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError as json_error:
        try:
            return ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            raise json_error from None


def generate_chart_from_json(json_str):
    """Parse the AI's chart specification and render it.

    Args:
        json_str: Raw model output; the outermost ``{...}`` span is extracted,
            so surrounding prose or Markdown fences are ignored.

    Returns:
        A Matplotlib Figure for column/bar, pie, line, scatter and histogram
        specs. Returns None for cross-tables (written to Excel instead), for
        unsupported or missing types, and for any parsing/rendering error;
        the reason is printed.
    """
    # Greedy DOTALL match: from the first '{' to the last '}'.
    match = re.search(r'\{.*\}', json_str or "", re.DOTALL)
    if not match:
        print("Error: Could not find valid JSON in the response.")
        return None

    try:
        json_data = _load_json_or_literal(match.group(0))

        chart = json_data.get('chart') if isinstance(json_data, dict) else None
        if not isinstance(chart, dict):
            print("Error: JSON structure is invalid.")
            return None

        chart_type = chart.get('type', 'column')  # Default to a column chart
        if not isinstance(chart_type, str):
            print(f"Unsupported chart type: {chart_type}")
            return None
        chart_type = chart_type.strip().lower()  # Case/whitespace-insensitive

        if chart_type == 'cross-table':
            generate_cross_table(chart)
            return None  # Cross-table is saved to Excel, no figure
        elif chart_type in ('column', 'bar'):
            return generate_column_chart(chart)
        elif chart_type == 'pie':
            return generate_pie_chart(chart)
        elif chart_type == 'line':
            return generate_line_chart(chart)
        elif chart_type == 'scatter':
            return generate_scatter_chart(chart)
        elif chart_type == 'histogram':
            return generate_histogram_chart(chart)

        print(f"Unsupported chart type: {chart_type}")
        return None

    except (ValueError, TypeError, KeyError, AttributeError, IndexError) as e:
        # ValueError also covers json.JSONDecodeError and pandas length
        # mismatches raised while building a chart from model-supplied data.
        print(f"Error processing the JSON input: {e}")
        print(json_str)  # Raw response for further inspection
        return None


# --- Menu Display Function ---
def display_menu():
    """Print the main menu: a framed, centred heading followed by the three options."""
    width = 40
    rule = "=" * width
    heading = format(" MAIN MENU ", f"^{width}")
    options = (
        "\n1. 💬 Ask an open question about the data (no chart option)",
        "2. 📈 Select a chart Or cross-table option from predefined prompts",
        "3. ❌ Exit\n",
    )
    print("\n" + rule)
    print(heading)
    print(rule)
    for option in options:
        print(option)
    print(rule)

# --- Main Function ---
def _wrap_text(text, width=80):
    """Wrap *text* to at most *width* columns per line, paragraph by paragraph.

    The model's own line breaks (bullets, blank lines) are kept. Words longer
    than *width*, such as URLs, are left whole instead of being split.
    """
    return "\n".join(
        textwrap.fill(line, width=width, break_long_words=False, break_on_hyphens=False)
        for line in text.splitlines()
    )


def _answer_open_question(data):
    """Ask the user for a free-form question, send it to the AI and print the wrapped answer."""
    open_question = input("Ask your question about the data: ").strip()
    if not open_question:
        print("Please enter a non-empty question.")
        return
    try:
        ai_response, _ = interact_with_ai(open_question, data)
    except ValueError as e:
        print(e)
        return
    except Exception as e:
        print(f"Error during AI interaction: {e}")
        return

    if ai_response is None:
        print("Error: Failed to get a response from the AI assistant.")
        return
    print("AI Response:")
    print(_wrap_text(ai_response))


def _build_chart_instruction(chart_prompt):
    """Return the JSON-only instruction sent to the AI for *chart_prompt*."""
    return (
        f"You are a JSON data provider for chart generation or cross-table data. "
        f"Given the prompt: '{chart_prompt}', "
        f"your task is to generate a JSON object that strictly adheres to the following requirements:\n\n"
        f"1. **Strict JSON Structure**:\n"
        f"   The JSON must follow one of these formats exactly, depending on the requested chart type:\n"
        f"   {{\n"
        f"       'chart': {{\n"
        f"           'type': '[chart_type]',\n"  # Specify chart type: 'column', 'bar', 'pie', 'line', 'scatter', 'histogram'
        f"           'title': {{'text': '[chart_title]'}},\n"
        f"           'subtitle': {{'text': '[chart_subtitle]'}},\n"
        f"           'xAxis': {{\n"
        f"               'categories': [categories_list],\n"
        f"               'title': {{'text': '[x_axis_title]'}}\n"
        f"           }},\n"
        f"           'yAxis': {{\n"
        f"               'title': {{'text': '[y_axis_title]'}}\n"
        f"           }},\n"
        f"           'series': [\n"
        f"               {{'name': '[series_1_name]', 'data': [series_1_data_list]}},\n"
        f"               {{'name': '[series_2_name]', 'data': [series_2_data_list]}},\n"  # Add more series as needed
        f"           ]\n"
        f"       }}\n"
        f"   }}\n\n"
        f"   **B. For Cross-Table Generation ('cross-table'):**\n"
        f"   {{\n"
        f"       'chart': {{\n"
        f"           'type': 'cross-table',\n"
        f"           'title': {{'text': '[table_title]'}},\n"
        f"           'subtitle': {{'text': '[table_subtitle]'}},\n"
        f"           'xAxis': {{\n"
        f"               'categories': [row_categories_list],\n"  # Categories for the rows
        f"               'title': {{'text': '[row_axis_title]'}}\n"
        f"           }},\n"
        f"           'series': [\n"
        f"               {{'name': '[column_1_name]', 'data': [column_1_data_list]}},\n"
        f"               {{'name': '[column_2_name]', 'data': [column_2_data_list]}},\n"  # Add more columns as needed
        f"           ]\n"
        f"       }}\n"
        f"   }}\n\n"
        f"2. **Prompt Alignment**:\n"
        f"   - Ensure the JSON is tailored to the context of the prompt.\n"
        f"   - For pie charts, ensure all 'data' values are valid **numbers** (e.g., integers or floats) and avoid nested structures like dictionaries.\n"
        f"   - Use logical and complete placeholders where details are missing.\n\n"
        f"3. **Validation and Accuracy**:\n"
        f"   - The JSON must be syntactically correct and ready for direct use in a Python charting library or Excel generation.\n"
        f"   - Validate that 'data' fields in the 'series' array contain only a list of numeric values (not dictionaries or other non-numeric types).\n"
        f"   - Ensure the number of 'categories' matches the number of 'data' points for pie charts and cross-tables.\n\n"
        f"4. **Supported Chart/Table Types**:\n"
        f"   - Supported chart types include: column, bar, pie, line, scatter, and histogram.\n"
        f"   - A cross-table can also be generated (type: 'cross-table').\n"
        f"   - Generate the appropriate structure based on the provided chart/table type.\n"
        f"   - For pie charts:\n"
        f"       - 'data' must contain a list of numeric values only.\n"
        f"       - Labels for the pie chart (from 'categories') must be clearly defined.\n"
        f" - For cross-tables:\n"
        f" - The 'xAxis' exclusively defines the row categories. The 'categories' list within 'xAxis' contains the labels for each row.\n"
        f" - Each item in 'series' represents a column. Each item must contain 'name' (column header) and 'data' (column values).\n"
        f" - The length of each 'data' list in 'series' must exactly match the length of the 'categories' list in 'xAxis'. This consistency is mandatory.\n"
        f"5. **Output Only JSON**:\n"
        f"   - Do not include additional text, explanations, or comments. Return only the JSON object.\n"
    )


def _choose_chart_prompt(chart_prompts):
    """Print the numbered prompts and return the chosen one, or None on invalid input."""
    print("\nSelect a chart prompt:")
    for i, prompt in enumerate(chart_prompts, start=1):
        print(f"{i}. {prompt}")

    try:
        chart_choice = int(input("Enter the prompt number: ").strip()) - 1
    except ValueError:
        print("Invalid input. Please enter a number corresponding to a prompt.")
        return None
    if not 0 <= chart_choice < len(chart_prompts):
        print("Invalid prompt selection. Please try again.")
        return None
    return chart_prompts[chart_choice]


def _run_chart_prompt(data, chart_prompts):
    """Let the user pick a chart prompt, ask the AI for a JSON spec and display the result."""
    chart_prompt = _choose_chart_prompt(chart_prompts)
    if chart_prompt is None:
        return
    print(f"Selected Chart Prompt: {chart_prompt}")

    try:
        chart_response, _ = interact_with_ai(_build_chart_instruction(chart_prompt), data)
    except ValueError as e:
        print(e)
        return
    except Exception as e:
        print(f"Error during AI interaction: {e}")
        return

    if chart_response is None:
        print("Error: Failed to get a chart response from the AI assistant.")
        return

    try:
        fig = generate_chart_from_json(chart_response)
    except Exception as e:
        # Session boundary: a malformed spec from the model must not end the
        # program or be misreported as a bad menu number.
        print(f"Error generating chart: {e}")
        return

    if fig is not None:
        fig.tight_layout(pad=2)
        plt.show()
        plt.close(fig)  # Release the figure so it is not redrawn with later charts


def main(file_path="TestData.csv"):
    """Run the interactive menu: load the CSV, then answer questions or draw charts.

    Args:
        file_path: CSV file to analyse (defaults to ``TestData.csv`` in the
            working directory).
    """
    try:
        data = read_csv(file_path)
    except Exception as e:
        # read_csv raises FileNotFoundError/ValueError for the common cases and
        # wraps any other parser failure in a plain Exception; all are fatal here.
        print(e)
        return

    # Predefined chart prompts, grouped by analysis intent with chart type hints.
    chart_prompts = [
        # Comparison prompts
        "Compare the customer satisfaction scores across different products using a column chart.",
        "Compare the purchase frequencies between premium and budget products using a column chart.",
        "Show the count of ratings (1 to 5) by gender using a grouped column chart.",

        # Composition prompts
        "Show the distribution of purchase contexts (e.g., online purchase, in-store purchase) using a pie chart.",
        "Show the gender distribution among customers for premium product using a pie chart.",

        # Distribution prompts
        "Show the ratings provided by customers in different income levels using a column chart.",
        "Show the variation in purchase frequencies across regions using a column chart.",

        # Relationship prompts
        "Show the relationship between customer satisfaction scores and ratings for each product using a scatter plot.",
        "Show the distribution of income levels across different price categories using a grouped column chart.",

        # Trend prompts
        "Show the mean ratings for products categorized by age groups using a line chart.",
        "Show the customer satisfaction scores for different purchase contexts using a column chart.",

        # Cross-table prompt
        "Generate a cross-table with rows nested hierarchically by 'Product', 'Region', "
        "'PurchaseFrequency', and 'AgeGroup'.\n"
        "    Display 'CustomerSatisfactionScore' in columns, ensuring clarity and "
        "highlighting trends explicitly present in the data.",
    ]

    while True:
        display_menu()
        try:
            choice = input("Please enter your choice (1-3): ").strip()
        except EOFError:
            # stdin is closed: no further choice can ever arrive, so stop.
            print("\nNo more input available. Exiting program.\n")
            break
        except Exception as e:
            print(f"\n⚠️ An error occurred: {e}\n")
            continue

        if choice not in ("1", "2", "3"):
            print("\n⚠️ Invalid input. Please select a valid option (1-3).\n")
            continue

        print(f"\nYou selected option {choice}.")
        if choice == "3":
            print("Exiting program. Have a great day!\n")
            break

        try:
            if choice == "1":
                _answer_open_question(data)
            else:
                _run_chart_prompt(data, chart_prompts)
        except EOFError:
            print("\nNo more input available. Exiting program.\n")
            break

# Start the interactive menu when this code is executed directly (including
# as a notebook cell, where __name__ is "__main__"), but not when imported.
if __name__ == "__main__":
    main()


               MAIN MENU                

1. 💬 Ask an open question about the data (no chart option)
2. 📈 Select a chart Or cross-table option from predefined prompts
3. ❌ Exit

Please enter your choice (1-3): 2

You selected option 2.

Select a chart prompt:
1. Compare the customer satisfaction scores across different products using a column chart.
2. Compare the purchase frequencies between premium and budget products using a column chart.
3. Show the count of ratings (1 to 5) by gender using a grouped column chart.
4. Show the distribution of purchase contexts (e.g., online purchase, in-store purchase) using a pie chart.
5. Show the gender distribution among customers for premium product using a pie chart.
6. Show the ratings provided by customers in different income levels using a column chart.
7. Show the variation in purchase frequencies across regions using a column chart.
8. Show the relationship between customer satisfaction scores and ratings for each product using a scatter