In [106]:
import os
from getpass import getpass

import google.generativeai as genai
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import pyodbc
from IPython.display import display
from plotly.subplots import make_subplots


In [107]:
# Read the Gemini API key from the environment (or prompt for it) instead of
# hardcoding it: a key written into a notebook leaks to anyone who sees the file
# or its history.
# NOTE(review): the key previously hardcoded in this cell must be revoked/rotated.
GEMINI_MODEL_NAME = 'gemini-1.5-pro-latest'

genai.configure(api_key=os.environ.get("GOOGLE_API_KEY") or getpass("Gemini API key: "))
model = genai.GenerativeModel(GEMINI_MODEL_NAME)

In [108]:
def setup_database_connection(server="PC-GHILEB", database="WireBreak",
                              driver="ODBC Driver 17 for SQL Server"):
    """Open an ODBC connection to SQL Server using Windows authentication.

    The connection uses ``Trusted_Connection=yes``, so the current Windows user's
    credentials are used. No password is stored here.

    Args:
        server: SQL Server host name. The default keeps the original hardcoded value.
        database: Database to connect to.
        driver: Name of the installed ODBC driver.

    Returns:
        A ``(connection, cursor)`` tuple.

    NOTE(review): ``natural_language_to_sql`` executes LLM-generated SQL on this
    connection. Prefer a login that only has read permissions.
    """
    conn = pyodbc.connect(
        f"DRIVER={{{driver}}};"
        f"SERVER={server};"
        f"DATABASE={database};"
        "Trusted_Connection=yes;"
    )
    return conn, conn.cursor()

conn, cursor = setup_database_connection()

In [109]:
def get_schema(cursor):
    """Describe every user table of the connected database as text.

    Queries ``sys.tables`` / ``sys.columns`` / ``sys.types`` through ``cursor`` and
    builds one line per table of the form ``table(col (type), col (type), ...)``.
    Tables appear in the order the query returns them (by name), and columns by
    ``column_id``. The lines are joined with newlines; the result is "" when no
    rows come back.
    """
    catalog_sql = """
        SELECT t.name AS table_name, c.name AS column_name, ty.name AS type_name
        FROM sys.tables t
        JOIN sys.columns c ON t.object_id = c.object_id
        JOIN sys.types ty ON c.user_type_id = ty.user_type_id
        ORDER BY t.name, c.column_id
    """
    cursor.execute(catalog_sql)

    columns_by_table = {}
    for table_name, column_name, type_name in cursor.fetchall():
        if table_name not in columns_by_table:
            columns_by_table[table_name] = []
        columns_by_table[table_name].append(f"{column_name} ({type_name})")

    lines = []
    for table_name, column_descriptions in columns_by_table.items():
        lines.append(f"{table_name}({', '.join(column_descriptions)})")
    return "\n".join(lines)

In [110]:
# Schema text reused in every LLM prompt below; printed once for inspection.
SCHEMA_INFO = get_schema(cursor)
print("Database Schema:", SCHEMA_INFO, sep="\n")

Database Schema:
alembic_version(version_num (varchar))
Dim_Calendrier(Cal_id (smallint), Cal_Date (date), Cal_Jour (char), Cal_Mois (varchar), Cal_NumMois (int), Cal_Annee (int), Cal_Semaine (int), Cal_NomJour (varchar), Cal_Trimestre (smallint), Cal_Semestre (smallint))
Inventory(plant (varchar), supplier (varchar), Year (int), initial_stock (real))
machine(codeMachine (varchar), typeM (varchar), is_active (bit))
machinetype(machinetype (varchar), minBreakDiameter (real), maxBreakDiameter (real))
notification(id (int), user_id (varchar), message (varchar), created_at (datetime), is_read (bit))
plant(Plant (varchar), Region (varchar), Sales_Company (int), Inv_Company (int), plant_Description (varchar))
supplier(supplierid (varchar), is_active (bit), target (float))
sysdiagrams(name (sysname), principal_id (int), diagram_id (int), version (int), definition (varbinary))
token_blacklist(id (int), jti (nvarchar), created_at (datetime))
user(id (int), name (nvarchar), email (nvarchar), pas

In [111]:
def is_relevant_query(user_query, model, schema_info):
    """Ask the LLM whether ``user_query`` can be answered from the database schema.

    Args:
        user_query: The user's natural-language question.
        model: Object exposing ``generate_content(prompt)`` whose response has ``.text``.
        schema_info: Schema description (as produced by ``get_schema``).

    Returns:
        True when the model answers "true", ignoring case, surrounding whitespace,
        quotes, backticks and periods; False otherwise.

    Note: nothing visible in this notebook calls this function.
    """
    prompt = f"""
    You are a database query classifier. Your task is to determine if this user query:
    \"{user_query}\"
    
    is relevant to a database with this schema:
    {schema_info}
    
    Respond with ONLY 'True' if the query is relevant and can be answered with this database,
    or 'False' if the query is completely unrelated to this database context.
    """
    
    response = model.generate_content(prompt)
    # The prompt itself shows the answer as 'True' (in quotes), and models often echo
    # the quotes or add a trailing period. Normalize those before comparing.
    verdict = response.text.strip().strip("'\"`.").strip().lower()
    return verdict == 'true'

In [112]:
def execute_query(sql_query, cursor):
    """Run a SQL statement and return its result set as a DataFrame.

    Args:
        sql_query: SQL text to execute.
        cursor: DB-API style cursor (``execute``, ``description``, ``fetchall``).

    Returns:
        A DataFrame whose columns are named from ``cursor.description``. On failure,
        a string starting with "Error executing query:" is returned instead. Callers
        detect errors by checking for a ``str`` containing 'Error'. A statement that
        produces no result set (``cursor.description is None``) is reported as an
        error.

    NOTE(review): in this notebook ``sql_query`` is generated by an LLM from free-form
    user text and is executed unchecked. Use a read-only database login.
    """
    try:
        cursor.execute(sql_query)
        if cursor.description is None:
            # No result set (e.g. DDL/DML). Fetching would fail with an obscure
            # "'NoneType' object is not iterable", so report it explicitly.
            return "Error executing query: statement returned no result set"
        columns = [column[0] for column in cursor.description]
        results = cursor.fetchall()
        return pd.DataFrame.from_records(results, columns=columns)
    except Exception as e:
        return f"Error executing query: {str(e)}"

In [113]:
def smart_visualization_fallback(df, user_query):
    """Build a Plotly Express figure for a query result without calling the LLM.

    The chart depends on the number of columns:
    - 1 column: histogram of that column.
    - 2 columns (x = first, y = second):
        - line chart when x is datetime;
        - bar chart when y is numeric and x is not (a category -> value aggregate
          such as a count per month), or when x has fewer than 10 distinct values
          and y more than 10;
        - scatter plot otherwise.
    - 3+ columns: scatter matrix.

    Args:
        df: Query result. Anything that is not a non-empty DataFrame (for example
            the error string returned by ``execute_query``) yields ``None``.
        user_query: The user's question. Not used for the chart; kept so callers
            can keep passing it.

    Returns:
        A Plotly figure, or ``None`` when there is nothing to plot or Plotly fails.
    """
    # Nothing to plot: return early instead of relying on an exception below.
    if not isinstance(df, pd.DataFrame) or df.empty:
        return None

    try:
        num_cols = len(df.columns)

        # Single column - histogram
        if num_cols == 1:
            fig = px.histogram(df, x=df.columns[0], title=f"Distribution of {df.columns[0]}")

        # Two columns - relationship plot
        elif num_cols == 2:
            x_col, y_col = df.columns[0], df.columns[1]
            x_is_numeric = pd.api.types.is_numeric_dtype(df[x_col])
            y_is_numeric = pd.api.types.is_numeric_dtype(df[y_col])

            if pd.api.types.is_datetime64_any_dtype(df[x_col]):
                # Checked first so that a short time series is drawn as a line.
                fig = px.line(df, x=x_col, y=y_col, title=f"{y_col} over time")
            elif (y_is_numeric and not x_is_numeric) or (
                df[x_col].nunique() < 10 and df[y_col].nunique() > 10
            ):
                fig = px.bar(df, x=x_col, y=y_col, title=f"{y_col} by {x_col}")
            else:
                fig = px.scatter(df, x=x_col, y=y_col, title=f"{y_col} vs {x_col}")

        # Multiple columns - pairwise relationships
        else:
            fig = px.scatter_matrix(df, title="Multi-variable Relationships")

        return fig

    except Exception as e:
        print(f"Automatic visualization failed: {str(e)}")
        return None

In [114]:
def generate_plotly_code(df, user_query, model):
    """Generate appropriate visualization for any query result"""
    # Analyze dataframe structure
    num_cols = len(df.columns)
    num_rows = len(df)
    date_cols = [col for col in df.columns if pd.api.types.is_datetime64_any_dtype(df[col])]
    numeric_cols = [col for col in df.columns if np.issubdtype(df[col].dtype, np.number)]
    cat_cols = [col for col in df.columns if col not in numeric_cols + date_cols]
    
    # Determine best chart type automatically
    chart_type = determine_chart_type(df, numeric_cols, date_cols, cat_cols)
    
    prompt = f"""
    DATAFRAME STRUCTURE:
    - Shape: {num_rows} rows × {num_cols} columns
    - Numeric columns: {numeric_cols}
    - Categorical columns: {cat_cols}
    - Date columns: {date_cols}
    - Suggested chart type: {chart_type}
    
    USER QUESTION: "{user_query}"
    
    Generate Plotly visualization code with these requirements:
    1. MUST start with `import plotly.graph_objects as go`
    2. Use {chart_type} as the primary chart type
    3. Include proper titles and axis labels based on the user query
    4. Make the visualization clear and professional
    5. Return ONLY the Python code wrapped in ```python ``` blocks
    
    Example for bar chart:
    ```python
    import plotly.graph_objects as go
    fig = go.Figure()
    fig.add_trace(go.Bar(x=df['category'], y=df['value']))
    fig.update_layout(title='Clear Title', xaxis_title='X Label', yaxis_title='Y Label')
    ```
    """
    
    response = model.generate_content(prompt)
    code = response.text.strip()
    if '```python' in code:
        code = code.split('```python')[1].split('```')[0].strip()
    return code

def determine_chart_type(df, numeric_cols, date_cols, cat_cols):
    """Suggest a chart type name from the column groups of a result frame.

    Rules, first match wins:
    - at least one date and one numeric column: "line chart" for more than 10 rows,
      otherwise "bar chart";
    - at least one categorical and one numeric column: "pie chart" for 7 rows or
      fewer; otherwise "bar chart" when the first numeric column has at most 12
      distinct values, else "histogram";
    - two or more numeric columns: "scatter plot";
    - anything else: "bar chart".

    Args:
        df: The result frame. Only its row count and first numeric column are read.
        numeric_cols, date_cols, cat_cols: Column-name lists as built by
            ``generate_plotly_code``.

    Returns:
        The chart type as a human-readable string for the LLM prompt.
    """
    row_count = len(df)
    has_numeric = len(numeric_cols) > 0

    if has_numeric and len(date_cols) > 0:
        if row_count > 10:
            return "line chart"
        return "bar chart"

    if has_numeric and len(cat_cols) > 0:
        if row_count <= 7:
            return "pie chart"
        distinct_values = df[numeric_cols[0]].nunique()
        return "bar chart" if distinct_values <= 12 else "histogram"

    if len(numeric_cols) >= 2:
        return "scatter plot"

    return "bar chart"

def execute_plotly_code(code, df, user_query):
    """Execute LLM-generated Plotly code against ``df`` and return the figure it builds.

    NOTE(review): this definition is shadowed. Cell In [115] defines another
    ``execute_plotly_code(code, df)`` under the same name, and when the cells run in
    order that later version is the one ``natural_language_to_sql`` calls (it
    passes two arguments). Consider deleting this copy or merging the two.

    SECURITY: ``exec`` is not a sandbox. Python inserts ``__builtins__`` into the
    globals dict, so the generated code can still import modules and use any
    builtin despite the short name list below.

    Args:
        code: Python source expected to assign a figure to ``fig``, ``figure`` or ``plot``.
        df: Query result. A copy is exposed to the code as ``df``.
        user_query: The user's question. Its first 50 characters become the title
            when the generated figure has none.

    Returns:
        The figure found in the executed namespace, with a missing title and axis
        titles filled in. Falls back to ``smart_visualization_fallback`` when
        ``code`` is empty, ``df`` is not a non-empty DataFrame, no figure variable
        was assigned, or any exception is raised.
    """
    if not code or not isinstance(df, pd.DataFrame) or df.empty:
        return smart_visualization_fallback(df, user_query)
    
    try:
        # Globals for the generated code (builtins are still added by exec itself)
        allowed_objects = {
            'go': go,
            'px': px,
            'df': df.copy(),
            'make_subplots': make_subplots,
            'np': np,
            'pd': pd
        }
        
        # Execute the code; names it assigns end up in allowed_objects
        exec(code, allowed_objects)
        
        # Find the figure object under the first conventional name that was assigned
        fig = None
        for fig_name in ['fig', 'figure', 'plot']:
            if fig_name in allowed_objects:
                fig = allowed_objects[fig_name]
                break
                
        if fig is None:
            return smart_visualization_fallback(df, user_query)
            
        # Standardize the figure layout: fill in only the titles the code left empty
        if not fig.layout.title.text:
            fig.update_layout(title=user_query[:50])
        if not fig.layout.xaxis.title.text and len(df.columns) > 0:
            fig.update_layout(xaxis_title=df.columns[0])
        if not fig.layout.yaxis.title.text and len(df.columns) > 1:
            fig.update_layout(yaxis_title=df.columns[1])
            
        return fig
        
    except Exception as e:
        print(f"Visualization code execution failed: {str(e)}")
        return smart_visualization_fallback(df, user_query)

In [115]:
def execute_plotly_code(code, df, user_query=None):
    """Execute LLM-generated Plotly code and return the figure it builds.

    The code runs with these names available as globals:
    - ``df`` (a copy of the input frame);
    - ``go``, ``px``, ``make_subplots``, ``np`` and ``pd``;
    - every public name of ``plotly.graph_objects``, so ``Bar(...)`` works without
      the ``go.`` prefix.
    The figure is looked up under the names ``fig``, ``figure`` and ``plot``.

    SECURITY: ``exec`` is not a sandbox. Python adds ``__builtins__`` to the globals
    dict, so the generated code can import modules, touch files, etc.

    Args:
        code: Python source produced by ``generate_plotly_code``.
        df: Query result to plot.
        user_query: Optional user question. When given, it is passed to the
            fallback, and its first 50 characters become the title of a figure
            that has none. Defaults to None, so two-argument calls still work.

    Returns:
        The generated figure, with missing axis titles set to the first two column
        names. Falls back to ``smart_visualization_fallback`` when there is nothing
        to plot, the code raises, or it does not assign a Plotly figure.
    """
    fallback_label = user_query or "Data Overview"
    if not code or not isinstance(df, pd.DataFrame) or df.empty:
        return smart_visualization_fallback(df, fallback_label)

    # Public graph_objects names first, so the core names below win on any clash.
    namespace = {name: getattr(go, name) for name in dir(go) if not name.startswith('_')}
    namespace.update({
        'go': go,
        'px': px,
        'df': df.copy(),
        'make_subplots': make_subplots,
        'np': np,
        'pd': pd,  # generated code frequently uses pandas; it was missing before
    })

    try:
        exec(code, namespace)
    except Exception as e:
        print(f"Visualization code execution failed: {str(e)}")
        print(f"Problematic code:\n{code}")
        return smart_visualization_fallback(df, fallback_label)

    fig = next((namespace[name] for name in ('fig', 'figure', 'plot') if name in namespace), None)
    if not callable(getattr(fig, 'update_layout', None)):
        # Nothing assigned, or a variable with a figure-like name holds something else.
        return smart_visualization_fallback(df, fallback_label)

    # Fill in only the labels the generated code left empty.
    try:
        if user_query and not fig.layout.title.text:
            fig.update_layout(title=user_query[:50])
        if not fig.layout.xaxis.title.text:
            fig.update_layout(xaxis_title=str(df.columns[0]))
        if not fig.layout.yaxis.title.text and len(df.columns) > 1:
            fig.update_layout(yaxis_title=str(df.columns[1]))
    except Exception as e:
        print(f"Could not standardize figure labels: {str(e)}")
    return fig

In [116]:
def natural_language_to_sql(user_query, model, cursor, schema_info):
    # Generate SQL
    prompt = f"""
    Database Schema:
    {schema_info}

    Convert this to SQL (ONLY the query, using SQL Server syntax):
    - Use YEAR() instead of strftime('%Y')
    - Use MONTH() instead of strftime('%m')
    - Use DAY() instead of strftime('%d')
    - Use CONVERT() for date formatting

    User query: \"{user_query}\"
    """
    response = model.generate_content(prompt)
    sql_query = response.text.strip().replace('```sql', '').replace('```', '').strip()

    # Execute query
    results = execute_query(sql_query, cursor)

    # Generate explanation
    is_report = any(keyword in user_query.lower() for keyword in 
                   ["analyze", "report", "trend", "compare", "summary", "breakdown"])

    report_instruction = """Generate output EXACTLY in this format:

    [Report Title Brief Description]

    - [Item 1]: [Metric] ([Percentage if available])
    - [Item 2]: [Metric]
    - (...)

    Helpful Insights: [The top 3(if possible) remarks made from the result of SQL query execution,The remarks provided need to be helpful to a Business Intelligence perspective].

    Suggested Actions: [The top 3(if possible) actions that needs to be done about the remarks made in the Helpful Insights section even if the insights are not enough for you to suggest actions , suggest the ones you would if you were running a multi million dollar company that produces cables for car manufacturing and other industries  ].

    RULES:
    1. Title must be <10 words
    2. List all the items returned in the results of the SQL execution result 
    3. Helpful Insights must include a percentage or multiplier
    4. Suggested Actions must specify both what and where
    5. Never show SQL or technical details
    6. Use same punctuation/capitalization as example"""

    if is_report :
        explanation_prompt = f"""
        Database Context:
        {schema_info}

        User Question: {user_query}
        SQL Used: {sql_query}
        Query Results: {str(results)}

        {report_instruction}

        STRICT FORMATTING:
        - Blank line after title
        - Dash-start for list items
        - "Helpful Insights:" and "Suggested Actions:" labels exactly as shown
        """
    else:
        explanation_prompt = f"""
        Database Context:
        {schema_info}

        User Question: {user_query}
        SQL Used: {sql_query}
        Query Results: {str(results)}

        Answer in 1 line with the key number.

        STRICT FORMATTING:
        - No report structure, just answer the question directly.
        - Do NOT provide Helpful Insights or Suggested Actions.
        """

    explanation = model.generate_content(explanation_prompt)

    # Generate visualization
    plotly_code = generate_plotly_code(results, user_query, model)
    visualization = execute_plotly_code(plotly_code, results)

    return {
        "query": sql_query,
        "results": results,
        "explanation": explanation.text,
        "visualization": visualization,
        "plotly_code": plotly_code
    }



In [117]:
def interactive_demo():
    """Run an interactive question loop over ``natural_language_to_sql``.

    Reads questions with ``input()`` until one of these happens:
    - the user types 'exit' (case-insensitive, surrounding whitespace ignored);
    - stdin is closed;
    - input is interrupted.
    Blank lines are skipped.

    For each answered question, prints the generated SQL, displays the result
    DataFrame, prints the LLM answer and displays the figure when one was
    produced. A query error (a string containing 'Error') is printed instead. An
    exception raised while answering one question is reported, and the loop
    continues.
    """
    print("NL-to-SQL System (type 'exit' to quit)")
    while True:
        try:
            user_input = input("\nEnter your question: ").strip()
        except (EOFError, KeyboardInterrupt):
            break
        if user_input.lower() == 'exit':
            break
        if not user_input:
            continue

        try:
            result = natural_language_to_sql(user_input, model, cursor, SCHEMA_INFO)
        except Exception as e:
            print(f"\n⚠️ Failed to answer the question: {e}")
            continue

        if isinstance(result['results'], str) and 'Error' in result['results']:
            print(f"\n⚠️ {result['results']}")
        else:
            print("\nGenerated SQL:")
            print(result['query'])
            print("\nResults:")
            display(result['results'])
            print("\nAnswer:")
            print(result['explanation'])
            # The figure was generated (at the cost of an LLM call) but previously never shown.
            if result.get('visualization') is not None:
                display(result['visualization'])


In [118]:
# Start the interactive question loop; it blocks on input() until 'exit' is entered.
interactive_demo()

NL-to-SQL System (type 'exit' to quit)



Generated SQL:
SELECT
    Cal_Mois,  -- Month name
    COUNT(wb.id) AS TotalWireBreaks
FROM
    wirebreakDetails AS wb
JOIN
    Dim_Calendrier AS cal ON wb.Break_date = cal.Cal_Date
WHERE
    YEAR(wb.Break_date) = 2024
GROUP BY
    Cal_Mois
ORDER BY
    Cal_Mois;

Results:


Unnamed: 0,Cal_Mois,TotalWireBreaks
0,April,282
1,August,117
2,February,358
3,January,645
4,July,215
5,June,114
6,March,380
7,May,110
8,November,146
9,October,425



Answer:
Wire Breaks by Month in 2024

- January: 645
- February: 358
- March: 380
- April: 282
- May: 110
- June: 114
- July: 215
- August: 117
- September: 251
- October: 425
- November: 146

Helpful Insights: January had the highest number of wire breaks (representing about 22% of total breaks in 2024 so far), followed by February and then March.  October also saw a significant spike in wire breaks.  Wire breaks were lowest in May, June and August.

Suggested Actions:
- Investigate the causes of the high number of wire breaks in January and implement corrective actions at all plants.
- Analyze the spike in October wire breaks and address any systemic issues found at the affected plants.
- Review maintenance schedules and procedures for all machinery across all plants to minimize future breaks, especially during peak months.

