In [1]:
import os
from sqlalchemy import create_engine,text

DB_HOST = os.getenv("DB_HOST", "localhost") 
DATABASE_URL = f"oracle+oracledb://hr:hr@{DB_HOST}:1521/?service_name=XEPDB1"

engine = create_engine(
    DATABASE_URL,
    pool_pre_ping=True
)

In [2]:
# import os
# from sqlalchemy import create_engine

# DATABASE_URL = os.getenv("DATABASE_URL")

# if not DATABASE_URL:
#     DB_HOST = os.getenv("DB_HOST", "localhost")
#     # Add READ ONLY mode in the connection string
#     DATABASE_URL = f"oracle+oracledb://hr:hr@{DB_HOST}:1521/?service_name=XEPDB1"


# engine = create_engine(
#     DATABASE_URL,
#     pool_pre_ping=True,
#     connect_args={
#         'events': True,
#         'readonly': True  # Explicit readonly flag
#     }
# )

In [3]:
sql = """UPDATE EMPLOYEES e
SET SALARY = SALARY * 1.05
WHERE EXTRACT(YEAR FROM SYSDATE) - EXTRACT(YEAR FROM HIRE_DATE) >= 5"""

sql2 = """SELECT e.EMPLOYEE_ID, e.FIRST_NAME, e.LAST_NAME, e.SALARY
FROM EMPLOYEES e
JOIN DEPARTMENTS d ON e.DEPARTMENT_ID = d.DEPARTMENT_ID
WHERE e.SALARY > (
SELECT AVG(SALARY)
FROM EMPLOYEES
WHERE DEPARTMENT_ID = d.DEPARTMENT_ID
)"""

with engine.connect() as conn:
    result = conn.execute(text(sql2)) # type: ignore
    for row in result.fetchmany(5):
        print(row)

(100, 'Steven', 'King', Decimal('24000'))
(103, 'Alexander', 'Hunold', Decimal('9000'))
(104, 'Bruce', 'Ernst', Decimal('6000'))
(108, 'Nancy', 'Greenberg', Decimal('12008'))
(109, 'Daniel', 'Faviet', Decimal('9000'))


In [5]:
sql = text("""SELECT e.EMPLOYEE_ID, e.FIRST_NAME, e.LAST_NAME, e.SALARY
FROM EMPLOYEES e
JOIN DEPARTMENTS d ON e.DEPARTMENT_ID = d.DEPARTMENT_ID
WHERE e.SALARY > (
SELECT AVG(SALARY)
FROM EMPLOYEES
WHERE DEPARTMENT_ID = d.DEPARTMENT_ID
)""")

In [6]:
with engine.connect() as conn:
    result = conn.execute(sql)
    for row in result.fetchmany(5):
        print(row)

(100, 'Steven', 'King', Decimal('24000'))
(103, 'Alexander', 'Hunold', Decimal('9000'))
(104, 'Bruce', 'Ernst', Decimal('6000'))
(108, 'Nancy', 'Greenberg', Decimal('12008'))
(109, 'Daniel', 'Faviet', Decimal('9000'))


In [7]:
from sqlalchemy import create_engine,event


@event.listens_for(engine, "connect")
def set_default_schema(dbapi_connection, connection_record):
    cursor = dbapi_connection.cursor()
    cursor.execute("ALTER SESSION SET CURRENT_SCHEMA = HR")
    cursor.close()    


In [57]:
def extract_oracle_schema(engine, schema="HR") -> str:
    query = f"""
    SELECT
        table_name,
        column_name,
        data_type
    FROM all_tab_columns
    WHERE owner = '{schema}'
    ORDER BY table_name, column_id
    """

    schema_dict = {}

    with engine.connect() as conn:
        rows = conn.exec_driver_sql(query).fetchall()

    for table, column, dtype in rows:
        schema_dict.setdefault(table, []).append(f"{column} {dtype}")

    schema_text = []
    for table, cols in schema_dict.items():
        schema_text.append(
            f"{table} ({', '.join(cols)})"
        )

    return "\n".join(schema_text)


In [58]:
schema = extract_oracle_schema(engine=engine)

In [59]:
schema

'COUNTRIES (COUNTRY_ID CHAR, COUNTRY_NAME VARCHAR2, REGION_ID NUMBER)\nDEPARTMENTS (DEPARTMENT_ID NUMBER, DEPARTMENT_NAME VARCHAR2, MANAGER_ID NUMBER, LOCATION_ID NUMBER)\nEMPLOYEES (EMPLOYEE_ID NUMBER, FIRST_NAME VARCHAR2, LAST_NAME VARCHAR2, EMAIL VARCHAR2, PHONE_NUMBER VARCHAR2, HIRE_DATE DATE, JOB_ID VARCHAR2, SALARY NUMBER, COMMISSION_PCT NUMBER, MANAGER_ID NUMBER, DEPARTMENT_ID NUMBER)\nEMP_DETAILS_VIEW (EMPLOYEE_ID NUMBER, JOB_ID VARCHAR2, MANAGER_ID NUMBER, DEPARTMENT_ID NUMBER, LOCATION_ID NUMBER, COUNTRY_ID CHAR, FIRST_NAME VARCHAR2, LAST_NAME VARCHAR2, SALARY NUMBER, COMMISSION_PCT NUMBER, DEPARTMENT_NAME VARCHAR2, JOB_TITLE VARCHAR2, CITY VARCHAR2, STATE_PROVINCE VARCHAR2, COUNTRY_NAME VARCHAR2, REGION_NAME VARCHAR2)\nJOBS (JOB_ID VARCHAR2, JOB_TITLE VARCHAR2, MIN_SALARY NUMBER, MAX_SALARY NUMBER)\nJOB_HISTORY (EMPLOYEE_ID NUMBER, START_DATE DATE, END_DATE DATE, JOB_ID VARCHAR2, DEPARTMENT_ID NUMBER)\nLOCATIONS (LOCATION_ID NUMBER, STREET_ADDRESS VARCHAR2, POSTAL_CODE VARCH

In [None]:
# import oracledb
# import requests
# import re

# # 1. Database Connection Configuration
# # Ensure your Docker container is running and XEPDB1 is active
# DB_CONFIG = {
#     "user": "hr",
#     "password": "hr",
#     "dsn": "localhost:1521/XEPDB1"
# }

# # Raw URL for the GitHub SQL file
# SQL_URL = "https://raw.githubusercontent.com/bbrumm/databasestar/main/sample_databases/oracle_hr/02%20create%20tables.sql"
# # SQL_URL = "https://raw.githubusercontent.com/bbrumm/databasestar/refs/heads/main/sample_databases/oracle_hr/03%20populate.sql"



# def run_github_sql():
#     try:
#         # Step 1: Download the SQL script
#         print(f"Downloading SQL from GitHub...")
#         response = requests.get(SQL_URL)
#         if response.status_code != 200:
#             print("Failed to download script. Check the URL.")
#             return
        
#         sql_content = response.text

#         # Step 2: Connect to Oracle
#         # Using Thin mode (default in python-oracledb)
#         conn = oracledb.connect(**DB_CONFIG)
#         cursor = conn.cursor()
#         print(" Connected to XEPDB1")

#         # Step 3: Clean up and Split SQL
#         # Oracle executes one statement at a time. We split by semicolon.
#         # We also remove the 'REM' comments often found in Oracle scripts.
#         statements = sql_content.split(';')
        
#         for statement in statements:
#             clean_stmt = statement.strip()
#             # Skip empty statements or comments
#             if not clean_stmt or clean_stmt.startswith('REM'):
#                 continue
                
#             try:
#                 print(f"Executing: {clean_stmt[:50]}...")
#                 cursor.execute(clean_stmt)
#             except oracledb.DatabaseError as e:
#                 error, = e.args
#                 # Ignore "table does not exist" errors during drops
#                 if error.code != 942: 
#                     print(f" SQL Error: {error.message}")

#         conn.commit()
#         print(" HR Schema created successfully from GitHub source!")

#     except Exception as e:
#         print(f" Error: {e}")
#     finally:
#         if 'conn' in locals():
#             conn.close()

# if __name__ == "_main_":
#     run_github_sql()

In [None]:
# import oracledb
# import requests

# # 1 Database Connection Configuration
# DB_CONFIG = {
#     "user": "hr",
#     "password": "hr",
#     "dsn": "localhost:1521/XEPDB1"
# }

# # 2 GitHub SQL script URL (Populate HR tables)
# SQL_URL = "https://raw.githubusercontent.com/bbrumm/databasestar/refs/heads/main/sample_databases/oracle_hr/03%20populate.sql"
# # SQL_URL = "https://raw.githubusercontent.com/bbrumm/databasestar/main/sample_databases/oracle_hr/02%20create%20tables.sql"


# def run_github_sql():
#     try:
#         # Step 1: Download the SQL script
#         print("Downloading SQL from GitHub...")
#         response = requests.get(SQL_URL)
#         if response.status_code != 200:
#             print(" Failed to download script. Check the URL.")
#             return

#         sql_content = response.text

#         # Step 2: Connect to Oracle using oracledb (Thin mode)
#         conn = oracledb.connect(**DB_CONFIG)
#         cursor = conn.cursor()
#         print(" Connected to XEPDB1 as HR")

#         # Step 3: Clean and split SQL statements
#         # Remove 'REM' comments and split by semicolon
#         statements = [stmt.strip() for stmt in sql_content.split(';') if stmt.strip() and not stmt.strip().startswith('REM')]

#         for stmt in statements:
#             try:
#                 print(f"Executing: {stmt[:60]}...")
#                 cursor.execute(stmt)
#             except oracledb.DatabaseError as e:
#                 error, = e.args
#                 # Ignore "table does not exist" errors (e.g., during drops)
#                 if error.code != 942:
#                     print(f" SQL Error {error.code}: {error.message}")

#         conn.commit()
#         print(" HR schema populated successfully from GitHub!")

#     except Exception as e:
#         print(f" Error: {e}")
#     finally:
#         if 'conn' in locals():
#             conn.close()

# #  Correct entry point
# if __name__ == "__main__":
#     run_github_sql()


In [None]:
# import oracledb
# import requests

# DB_CONFIG = {
#     "user": "hr",
#     "password": "hr",
#     "dsn": "localhost:1521/XEPDB1" # Change to 1522 if needed
# }

# # The sequence is critical: Create -> Populate
# URLS = [
#     "https://raw.githubusercontent.com/bbrumm/databasestar/main/sample_databases/oracle_hr/02%20create%20tables.sql",
#     "https://raw.githubusercontent.com/bbrumm/databasestar/main/sample_databases/oracle_hr/03%20populate.sql"
# ]

# def setup_hr_schema():
#     try:
#         conn = oracledb.connect(**DB_CONFIG)
#         cursor = conn.cursor()
        
#         # KEY FIX: Ensure we are working in the HR schema context
#         cursor.execute("ALTER SESSION SET CURRENT_SCHEMA = HR")
#         print(" Connected and switched to HR schema")

#         for url in URLS:
#             print(f"\nFetching: {url.split('/')[-1]}...")
#             content = requests.get(url).text
            
#             # Split by semicolon, but ignore semicolons inside quotes if they exist
#             statements = content.split(';')
            
#             for stmt in statements:
#                 clean_stmt = stmt.strip()
#                 if not clean_stmt or clean_stmt.startswith(('REM', 'SET', 'PROMPT')):
#                     continue
                
#                 try:
#                     cursor.execute(clean_stmt)
#                 except oracledb.Error as e:
#                     # Ignore "table already exists" (955) or "table does not exist" (942) during setup
#                     if e.args[0].code not in [942, 955]:
#                         print(f" SQL Notice: {e}")

#         conn.commit()
#         print("\n Database is now fully created and populated!")

#     except Exception as e:
#         print(f" Critical Error: {e}")
#     finally:
#         if 'conn' in locals(): conn.close()

# if __name__ == "__main__":
#     setup_hr_schema()

In [7]:
import arabic_reshaper
from bidi.algorithm import get_display
from langchain_community.utilities import SQLDatabase
from langchain_ollama import ChatOllama
import pandas as pd
from sqlalchemy import create_engine
from langchain_community.utilities import SQLDatabase
from decimal import Decimal
from sqlalchemy import text
from typing import List, Dict, Any



model = "qwen2.5-coder:3b"
translator_model = "qwen2.5:3b-instruct"



def fix_arabic_for_terminal(text: str) -> str:
    try:
        reshaped = arabic_reshaper.reshape(text)
        return str(get_display(reshaped))
    except Exception as e:
        print(f"[warn] Arabic reshaping failed, using raw text. Error: {e}")
        return text

def build_client() -> ChatOllama:
    return ChatOllama(model="qwen2.5-coder:3b", temperature=0)

def build_translate_client() -> ChatOllama:
    return ChatOllama(model="qwen2.5:3b-instruct", temperature=0)

def chat_once(prompt: str, client: ChatOllama) -> str:
    try:
        resp = client.invoke(prompt)  # resp is an AIMessage
        # LangChain messages always have `.content` as str or list of chunks
        content = getattr(resp, "content", "")
        if isinstance(content, list):
            # join any chunked content
            content = "".join(
                part.get("text", "") if isinstance(part, dict) else str(part)
                for part in content
            )
        return content or ""
    except Exception as e:
        print("Ollama call failed:", e)
        return ""

def translate_question(question: str, client: ChatOllama) -> str:
    """Translate Arabic to English only if needed, keep English as-is."""

    # Quick heuristic: Arabic has these common characters
    arabic_chars = any(0x0600 <= ord(c) <= 0x06FF for c in question)

    if not arabic_chars:
        print(f"English question (kept original): '{question}'")
        return question

    # Arabic detected - translate
    print(f" Arabic detected, translating...")
    prompt = f"""Translate ONLY this Arabic question to clear English for SQL querying.


Arabic: {question}


English:"""

    english = chat_once(prompt, client).strip()

    # Retry if bad translation
    if not english or len(english) < 5 or "?" not in english:
        prompt = f"""Translate ONLY: {question}


One English question:"""
        english = chat_once(prompt, client).strip()

    print(f" Translated: '{english}'")
    return english

def extract_oracle_schema(engine, schema="HR") -> str:
    query = f"""
    SELECT
        table_name,
        column_name,
        data_type
    FROM all_tab_columns
    WHERE owner = '{schema}'
    ORDER BY table_name, column_id
    """

    schema_dict = {}

    with engine.connect() as conn:
        rows = conn.exec_driver_sql(query).fetchall()

    for table, column, dtype in rows:
        schema_dict.setdefault(table, []).append(f"{column} {dtype}")

    schema_text = []
    for table, cols in schema_dict.items():
        schema_text.append(
            f"{table} ({', '.join(cols)})"
        )

    return "\n".join(schema_text)

def generate_sql(question: str, schema: str, client: ChatOllama) -> str:
    prompt = f"""
You are an expert Oracle SQL assistant for the HR database.

SCHEMA:
{schema}

IMPORTANT TABLES:
- EMPLOYEES (EMPLOYEE_ID, FIRST_NAME, LAST_NAME, SALARY, DEPARTMENT_ID, JOB_ID, HIRE_DATE)
- DEPARTMENTS (DEPARTMENT_ID, DEPARTMENT_NAME, LOCATION_ID)
- JOBS (JOB_ID, JOB_TITLE, MIN_SALARY, MAX_SALARY)
- LOCATIONS (LOCATION_ID, CITY, COUNTRY_ID)
- COUNTRIES (COUNTRY_ID, COUNTRY_NAME, REGION_ID)
- REGIONS (REGION_ID, REGION_NAME)

RULES:
- Use Oracle SQL syntax only.
- Use table aliases (employees e, departments d, jobs j).
- Always qualify columns with table aliases.
- Use FETCH FIRST N ROWS ONLY instead of LIMIT.
- For year extraction, use EXTRACT(YEAR FROM date_column).
- Use SYSDATE for current date.
- Return ONE valid Oracle SELECT query only.
- NO explanation, NO markdown.
- SELECT queries ONLY (NO INSERT, UPDATE, DELETE, DROP, CREATE, ALTER).
- Table and column names are UPPERCASE.
- DO NOT end the SQL statement with a semicolon (;).


Question:
{question}

SQL:
"""
    sql = chat_once(prompt, client).strip().strip("`")
    return sql

def is_safe_sql(sql: str) -> bool:
    """Block all DML/DDL - allow SELECT only for Oracle."""
    sql_upper = sql.strip().upper()
    
    # Block dangerous keywords
    dangerous = [
        'INSERT', 'UPDATE', 'DELETE', 'DROP', 'ALTER', 'CREATE', 
        'TRUNCATE', 'ATTACH', 'DETACH', 'REINDEX', 'ANALYZE',
        'PRAGMA', 'BEGIN', 'COMMIT', 'ROLLBACK', 'SAVEPOINT'
    ]
    
    for keyword in dangerous:
        if sql_upper.startswith(keyword):
            return False
    
    # Must start with SELECT
    return sql_upper.startswith('SELECT')


def ask_db(
    question: str,
    engine,                # SQLAlchemy engine
    schema: str,
    client                 # ChatOllama or LLM client
) -> tuple[str, str, List[Dict[str, Any]]]:
    sql = ""
    rows_as_dict: List[Dict[str, Any]] = []
    message = ""
    error_msg = None

    # Clean question
    question = question.strip("“”\"").strip(".")

    for attempt in range(2):
        #  Generate SQL from LLM
        sql = generate_sql(question, schema, client).rstrip(";")
        print("Raw SQL from model:\n", sql)

        # Safety check
        if not is_safe_sql(sql):
            message = "الاستعلام غير آمن ولا يمكن تنفيذه"
            return sql, message, []

        try:
            # Execute SQL using SQLAlchemy
            with engine.connect() as conn:
                result = conn.execute(text(sql))
                columns = result.keys()
                rows_as_dict = [
                    {col: float(val) if isinstance(val, Decimal) else val
                     for col, val in zip(columns, row)}
                    for row in result.fetchall()
                ]

            break  # success

        except Exception as e:
            error_msg = str(e)
            print("Execution failed:\n", error_msg)

            if attempt == 0:
                # Attempt repair once using LLM
                repair_prompt = f"""
You wrote this SQL:

{sql}

The Oracle database returned this error:
{error_msg}

Rewrite the query to fix the error.
Return ONLY valid Oracle SELECT SQL.
Do NOT use semicolons at the end.
"""
                sql = chat_once(repair_prompt, client).strip().rstrip(";")
            else:
                message = "حدث خطأ أثناء تنفيذ الاستعلام"
                return sql, message, []

    #  Empty result handling
    if not rows_as_dict:
        message = "لا توجد بيانات متاحة لهذا الطلب"
        return sql, message, []

    # Return SQL, message (empty string if OK), and rows
    return sql, message, rows_as_dict





# اعرض اسم القسم ومتوسط الرواتب فيه، لكن بس للأقسام اللي متوسط الرواتب أعلى من متوسط رواتب الشركة كلها
# مين الموظفين اللي اشتغلوا في نفس الوظيفة لمدة أطول من متوسط مدة الوظيفة لكل الموظفين؟
# اعرض الموظفين اللي تم تعيينهم قبل مديرهم
# اعرض الأقسام اللي ما فيهاش أي موظف راتبه أعلى من متوسط راتب الشركة
# اعرض متوسط المرتبات لكل Department، ورتّبهم من الأعلى للأقل.
#مين الموظفين اللي مرتباتهم أعلى من متوسط المرتبات في القسم بتاعهم؟


In [63]:
def get_engine():
    return create_engine(
        "oracle+cx_oracle://hr:hr@localhost:1521/?service_name=XEPDB1",
        pool_pre_ping=True
    )

engine = get_engine()
schema = extract_oracle_schema(engine=engine, schema="HR")
client = build_client()
translator_client = build_translate_client()

sql, message, rows_as_dict = ask_db("#مين الموظفين اللي مرتباتهم أعلى من متوسط المرتبات في القسم بتاعهم؟",engine,schema=schema,client=client)

Raw SQL from model:
 SELECT e.EMPLOYEE_ID, e.FIRST_NAME, e.LAST_NAME, e.SALARY, d.DEPARTMENT_NAME
FROM EMPLOYEES e
JOIN DEPARTMENTS d ON e.DEPARTMENT_ID = d.DEPARTMENT_ID
WHERE e.SALARY > (SELECT AVG(SALARY) FROM EMPLOYEES)


In [64]:
rows_as_dict

[{'employee_id': 201,
  'first_name': 'Michael',
  'last_name': 'Hartstein',
  'salary': 13000.0,
  'department_name': 'Marketing'},
 {'employee_id': 114,
  'first_name': 'Den',
  'last_name': 'Raphaely',
  'salary': 11000.0,
  'department_name': 'Purchasing'},
 {'employee_id': 203,
  'first_name': 'Susan',
  'last_name': 'Mavris',
  'salary': 6500.0,
  'department_name': 'Human Resources'},
 {'employee_id': 120,
  'first_name': 'Matthew',
  'last_name': 'Weiss',
  'salary': 8000.0,
  'department_name': 'Shipping'},
 {'employee_id': 123,
  'first_name': 'Shanta',
  'last_name': 'Vollman',
  'salary': 6500.0,
  'department_name': 'Shipping'},
 {'employee_id': 122,
  'first_name': 'Payam',
  'last_name': 'Kaufling',
  'salary': 7900.0,
  'department_name': 'Shipping'},
 {'employee_id': 121,
  'first_name': 'Adam',
  'last_name': 'Fripp',
  'salary': 8200.0,
  'department_name': 'Shipping'},
 {'employee_id': 103,
  'first_name': 'Alexander',
  'last_name': 'Hunold',
  'salary': 9000.0,
  

In [None]:
sql = """
SELECT e.EMPLOYEE_ID, e.FIRST_NAME, e.LAST_NAME, e.SALARY, d.DEPARTMENT_NAME
FROM EMPLOYEES e
JOIN DEPARTMENTS d ON e.DEPARTMENT_ID = d.DEPARTMENT_ID
WHERE e.SALARY > (SELECT AVG(SALARY) FROM EMPLOYEES)
"""

with engine.connect() as conn:
    result = conn.execute(text(sql))

    columns = result.keys()

    rows_as_dict = [
        {col: float(val) if isinstance(val, Decimal) else val for col, val in zip(columns, row)}
        for row in result.fetchall()
    ]

print("Columns:", columns)
print("First 3 rows as dict:")
for row in rows_as_dict[:3]:
    print(row)

Columns: RMKeyView(['employee_id', 'first_name', 'last_name', 'salary', 'department_name'])
First 3 rows as dict:
{'employee_id': 201, 'first_name': 'Michael', 'last_name': 'Hartstein', 'salary': 13000.0, 'department_name': 'Marketing'}
{'employee_id': 114, 'first_name': 'Den', 'last_name': 'Raphaely', 'salary': 11000.0, 'department_name': 'Purchasing'}
{'employee_id': 203, 'first_name': 'Susan', 'last_name': 'Mavris', 'salary': 6500.0, 'department_name': 'Human Resources'}


In [66]:
type(rows_as_dict)

list