In [1]:
import pyodbc
import google.generativeai as genai
import pandas as pd
from IPython.display import display

In [2]:
genai.configure(api_key="AIzaSyA_F6QrdhtzQ-VPWQVugsGNWX9sg9EjN7Y")
model = genai.GenerativeModel('gemini-1.5-pro-latest')

In [3]:
def setup_database_connection():
    conn = pyodbc.connect(
        "DRIVER={ODBC Driver 17 for SQL Server};"
        "SERVER=PC-GHILEB;"
        "DATABASE=WireBreak;"
        "Trusted_Connection=yes;"
    )
    return conn, conn.cursor()

conn, cursor = setup_database_connection()

In [4]:
def get_schema(cursor):
    cursor.execute("""
        SELECT t.name AS table_name, c.name AS column_name, ty.name AS type_name
        FROM sys.tables t
        JOIN sys.columns c ON t.object_id = c.object_id
        JOIN sys.types ty ON c.user_type_id = ty.user_type_id
        ORDER BY t.name, c.column_id
    """)
    schema = {}
    for table, column, dtype in cursor.fetchall():
        schema.setdefault(table, []).append(f"{column} ({dtype})")
    return "\n".join([f"{table}({', '.join(cols)})" for table, cols in schema.items()])

In [5]:
SCHEMA_INFO = get_schema(cursor)
print("Database Schema:")
print(SCHEMA_INFO)

Database Schema:
alembic_version(version_num (varchar))
Dim_Calendrier(Cal_id (smallint), Cal_Date (date), Cal_Jour (char), Cal_Mois (varchar), Cal_NumMois (int), Cal_Annee (int), Cal_Semaine (int), Cal_NomJour (varchar), Cal_Trimestre (smallint), Cal_Semestre (smallint))
Inventory(plant (varchar), supplier (varchar), Year (int), initial_stock (real))
machine(codeMachine (varchar), typeM (varchar))
machinetype(machinetype (varchar), minBreakDiameter (real), maxBreakDiameter (real))
notification(id (int), user_id (varchar), message (varchar), created_at (datetime), is_read (bit))
plant(Plant (varchar), Region (varchar), Sales_Company (int), Inv_Company (int), plant_Description (varchar))
supplier(supplierid (varchar))
sysdiagrams(name (sysname), principal_id (int), diagram_id (int), version (int), definition (varbinary))
token_blacklist(id (int), jti (nvarchar), created_at (datetime))
user(id (int), name (nvarchar), email (nvarchar), password_hash (nvarchar), role (nvarchar), plant_name

In [6]:
def is_relevant_query(user_query, model, schema_info):
    prompt = f"""
    You are a database query classifier. Your task is to determine if this user query:
    \"{user_query}\"
    
    is relevant to a database with this schema:
    {schema_info}
    
    Respond with ONLY 'True' if the query is relevant and can be answered with this database,
    or 'False' if the query is completely unrelated to this database context.
    """
    
    response = model.generate_content(prompt)
    return response.text.strip().lower() == 'true'

In [7]:
def execute_query(sql_query, cursor):
    try:
        cursor.execute(sql_query)
        columns = [column[0] for column in cursor.description]
        results = cursor.fetchall()
        return pd.DataFrame.from_records(results, columns=columns)
    except Exception as e:
        return f"Error executing query: {str(e)}"

In [10]:
def get_few_shot_examples():
    return """
    === FEW-SHOT EXAMPLES ===
    
    [SIMPLE QUERY - COUNT]
    User: "Count breaks for Plant A in January 2024"
    SQL: SELECT SUM(num_of_break) FROM wirebreakDetails WHERE Plant = 'A' AND MONTH(Break_date) = 1 AND YEAR(Break_date) = 2024
    Response: "Plant A had 127 wire breaks in January 2024."

    [SIMPLE QUERY - LIST]
    User: "Show machines with >15 breaks last month"
    SQL: SELECT Machine_Number, SUM(num_of_break) FROM wirebreakDetails WHERE Break_date >= DATEADD(month, -1, GETDATE()) GROUP BY Machine_Number HAVING SUM(num_of_break) > 15
    Response: "Machines exceeding 15 breaks: M-107 (23), M-203 (18)."

    [REPORT - TREND ANALYSIS]
    User: "Analyze break trends by week for Plant B"
    SQL: SELECT Week_Number, SUM(num_of_break) FROM wirebreakDetails WHERE Plant = 'B' GROUP BY Week_Number ORDER BY Week_Number
    Response: '''
    **Summary**: Breaks peaked in Week 32 (42 incidents), 58% higher than average.
    **Root Cause**: 80% of Week 32 breaks were on Machine M-107 with diameter < 0.3mm.
    **Next Steps**: Inspect M-107's tension settings and review Supplier Gamma's 0.3mm wires.
    '''

    [REPORT - COMPARISON]
    User: "Compare material vs process breaks by supplier"
    SQL: SELECT w.Supplier, b.typeB, COUNT(*) FROM wirebreakDetails w JOIN wirebreaktype b ON w.Wire_Break_Type = b.wirebreaktype GROUP BY w.Supplier, b.typeB
    Response: '''
    **Summary**: Supplier Alpha has 3x more material breaks (45) than process breaks (15).
    **Root Cause**: Alpha's material breaks correlate with humidity > 80%.
    **Next Steps**: Store Alpha's wires in climate-controlled areas during summer.
    '''

    [COMPLEX JOIN]
    User: "Find batches with breaks outside diameter specs"
    SQL: SELECT w.Batch_Number FROM wirebreakDetails w JOIN machinetype m ON w.Machine_Number = m.machinetype WHERE w.Break_Diameter < m.minBreakDiameter OR w.Break_Diameter > m.maxBreakDiameter
    Response: "Batches with out-of-spec diameters: BX-205, BX-209 (all from Supplier Alpha)."

    [TIME-BASED ANALYSIS]
    User: "Show monthly consumption vs break rates"
    SQL: SELECT MONTH(Break_date) AS month, SUM(c.Real_Consumption), SUM(w.num_of_break)/SUM(c.Real_Consumption) FROM wirebreakDetails w JOIN wireConsumption c ON w.Plant = c.Plant AND w.Week_Number = c.Week_Number GROUP BY MONTH(Break_date)
    Response: '''
    **Trend**: Break rate doubled in July (0.15 breaks/ton) vs June (0.07).
    **Threshold**: Rates exceed 0.1 when temperature > 30°C.
    **Action**: Implement cooling systems for wires in summer months.
    '''
    """

In [11]:
def natural_language_to_sql(user_query, model, cursor, schema_info):
    # Generate SQL
    prompt = f"""
    {get_few_shot_examples()}
    
    Database Schema:
    {schema_info}
    
    Convert this to SQL (ONLY the query):
    "{user_query}"
    """
    response = model.generate_content(prompt)
    sql_query = response.text.strip().replace('```sql', '').replace('```', '').strip()
    
    # Execute query
    results = execute_query(sql_query, cursor)
    
    # Generate response
    is_report = any(keyword in user_query.lower() for keyword in 
                   ["analyze", "report", "trend", "compare", "summary", "breakdown"])
    
    explanation_prompt = f"""
    {get_few_shot_examples()}
    
    Database Schema: {schema_info}
    User Question: {user_query}
    SQL Used: {sql_query}
    Results: {str(results)}
    
    {"Generate a MANAGER-READY report with **Summary/Root Cause/Next Steps** sections." if is_report 
     else "Answer concisely like the simple examples above."}
    """
    
    explanation = model.generate_content(explanation_prompt)
    
    return {
        "query": sql_query,
        "results": results,
        "explanation": explanation.text
    }


In [12]:
def interactive_demo():
    print("NL-to-SQL System (type 'exit' to quit)")
    while True:
        user_input = input("\nEnter your question: ")
        if user_input.lower() == 'exit':
            break
            
        result = natural_language_to_sql(user_input, model, cursor, SCHEMA_INFO)
        
        if isinstance(result['results'], str) and 'Error' in result['results']:
            print(f"\n⚠️ {result['results']}")
        else:
            print("\nGenerated SQL:")
            print(result['query'])
            print("\nResults:")
            display(result['results'])
            print("\nAnswer:")
            print(result['explanation'])


In [16]:
interactive_demo()

NL-to-SQL System (type 'exit' to quit)

Generated SQL:
SELECT DISTINCT Supplier FROM wirebreakDetails

Results:


Unnamed: 0,Supplier



Answer:
No suppliers found.


Generated SQL:
SELECT COUNT(DISTINCT supplierid) FROM supplier

Results:


Unnamed: 0,Unnamed: 1
0,30



Answer:
We have 30 suppliers.


Generated SQL:
SELECT COUNT(DISTINCT machinetype) FROM machinetype

Results:


Unnamed: 0,Unnamed: 1
0,2



Answer:
There are 2 machine types.


⚠️ Error executing query: ('42000', "[42000] [Microsoft][ODBC Driver 17 for SQL Server][SQL Server]Incorrect syntax near the keyword 'user'. (156) (SQLExecDirectW)")


ResourceExhausted: 429 You exceeded your current quota, please check your plan and billing details. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits. [violations {
}
, links {
  description: "Learn more about Gemini API quotas"
  url: "https://ai.google.dev/gemini-api/docs/rate-limits"
}
, retry_delay {
  seconds: 34
}
]