<a href="https://colab.research.google.com/github/ulrischa/Text2sql/blob/main/Text2sqlimproved.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers torch dateparser sqlalchemy




In [None]:
pip install sqlalchemy transformers torch



In [None]:
import os
import sqlite3
import re
import datetime
import logging
from typing import List, Optional, Dict, Tuple
from sqlalchemy import create_engine, text, inspect
from sqlalchemy.exc import SQLAlchemyError
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Path of the SQLite file used by the example application (handed to
# NaturalLanguageToSQLApp in the __main__ block).
DB_PATH = "example_database.db"
# Hugging Face model identifier; HuggingFaceQueryGenerator passes it to
# AutoTokenizer/AutoModelForSeq2SeqLM.from_pretrained.
MODEL_NAME = "gaussalgo/T5-LM-Large-text2sql-spider"

# Configure the root logger: INFO level, "<timestamp> [<LEVEL>] <message>" lines.
logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s')

class QueryIntent:
    """Mutable record of the hints derived from a single user question.

    A fresh instance carries neutral defaults (no flags set, no time range,
    empty question text); the attributes are filled in after construction.
    """

    def __init__(self) -> None:
        # Question text exactly as the user typed it.
        self.original_question: str = ""
        # (first day, last day) of the detected period, or None if none was found.
        self.time_range: Optional[Tuple[datetime.date, datetime.date]] = None
        # Set for "how many ..." style questions.
        self.is_count_query: bool = False
        # Set for "is there / was there ..." style questions.
        self.is_existence_query: bool = False

class DatabaseManager:
    """Creates the SQLite example database and exposes its reflected schema.

    The example data is written with the plain ``sqlite3`` module; schema
    reflection goes through the SQLAlchemy engine stored in ``self.engine``.
    """

    def __init__(self, db_path: str) -> None:
        """Store the database path and create a SQLAlchemy engine for it.

        Args:
            db_path: Filesystem path of the SQLite database file.
        """
        self.db_path = db_path
        self.engine = create_engine(f"sqlite:///{self.db_path}", echo=False)

    def create_example_database(self) -> None:
        """(Re)create the example database with four tables and sample rows.

        An existing file at ``self.db_path`` is deleted first, so calling this
        method repeatedly always yields the same content.

        Raises:
            sqlite3.Error: If creating the tables or inserting the rows fails.
                The connection is closed in that case as well, and the
                uncommitted inserts are discarded.
        """
        if os.path.exists(self.db_path):
            os.remove(self.db_path)

        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            self._create_example_tables(cursor)
            self._insert_example_rows(cursor)
            conn.commit()
        finally:
            # Always release the file handle, even when a statement failed.
            conn.close()
        logging.info("Beispieldatenbank wurde erstellt.")

    @staticmethod
    def _create_example_tables(cursor: sqlite3.Cursor) -> None:
        """Create the customers, orders, products and sales tables."""
        # Example tables
        cursor.execute("""
        CREATE TABLE customers (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT NOT NULL,
            city TEXT NOT NULL
        );
        """)

        cursor.execute("""
        CREATE TABLE orders (
            order_id INTEGER PRIMARY KEY AUTOINCREMENT,
            customer_id INTEGER NOT NULL,
            amount REAL NOT NULL,
            order_date TEXT NOT NULL,
            FOREIGN KEY(customer_id) REFERENCES customers(id)
        );
        """)

        cursor.execute("""
        CREATE TABLE products (
            product_id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT NOT NULL,
            price REAL NOT NULL,
            category TEXT NOT NULL
        );
        """)

        cursor.execute("""
        CREATE TABLE sales (
            sale_id INTEGER PRIMARY KEY AUTOINCREMENT,
            product_id INTEGER NOT NULL,
            customer_id INTEGER NOT NULL,
            quantity INTEGER NOT NULL,
            sale_date TEXT NOT NULL,
            FOREIGN KEY(product_id) REFERENCES products(product_id),
            FOREIGN KEY(customer_id) REFERENCES customers(id)
        );
        """)

    @staticmethod
    def _insert_example_rows(cursor: sqlite3.Cursor) -> None:
        """Insert the fixed sample rows into the four example tables."""
        customers_data = [
            ("Alice", "Berlin"),
            ("Bob", "München"),
            ("Charlie", "Hamburg"),
            ("Diana", "Berlin"),
        ]

        orders_data = [
            (1, 120.50, "2024-01-15"),
            (1, 89.99, "2024-02-10"),
            (2, 200.00, "2024-01-20"),
            (3, 35.20, "2024-03-05"),
            (3, 99.00, "2024-03-10"),
            (4, 55.55, "2024-02-25"),
            (1, 100.00, "2024-11-15"),
            (1, 100.00, "2024-12-05"),
        ]

        products_data = [
            ("Laptop", 999.99, "Electronics"),
            ("Smartphone", 499.99, "Electronics"),
            ("Chair", 89.99, "Furniture"),
            ("Table", 150.00, "Furniture"),
        ]

        sales_data = [
            (1, 2, 1, "2024-01-16"),
            (2, 3, 3, "2024-02-20"),
            (3, 1, 4, "2024-03-15"),
            (4, 4, 2, "2024-04-05"),
        ]

        # Parents (customers, products) are inserted before the rows that
        # reference them.
        cursor.executemany("INSERT INTO customers (name, city) VALUES (?, ?)", customers_data)
        cursor.executemany("INSERT INTO orders (customer_id, amount, order_date) VALUES (?, ?, ?)", orders_data)
        cursor.executemany("INSERT INTO products (name, price, category) VALUES (?, ?, ?)", products_data)
        cursor.executemany(
            "INSERT INTO sales (product_id, customer_id, quantity, sale_date) VALUES (?, ?, ?, ?)",
            sales_data,
        )

    def extract_foreign_keys(self) -> Dict[str, List[Dict]]:
        """Return the reflected foreign keys of every table.

        Returns:
            Mapping of table name to the list returned by the SQLAlchemy
            inspector's ``get_foreign_keys`` for that table.
        """
        inspector = inspect(self.engine)
        return {
            table_name: inspector.get_foreign_keys(table_name)
            for table_name in inspector.get_table_names()
        }

    @staticmethod
    def _simplify_type(col_type: object) -> str:
        """Map a reflected column type to one of ``"int"``, ``"real"``, ``"text"``.

        This is a substring heuristic on the lower-cased ``str()`` of the type;
        anything unrecognised falls back to ``"text"``.
        """
        type_name = str(col_type).lower()
        if "int" in type_name:
            return "int"
        if any(token in type_name for token in ("char", "text", "date")):
            return "text"
        if any(token in type_name for token in ("real", "float", "double", "numeric", "decimal")):
            return "real"
        return "text"  # fallback

    def extract_full_schema_details(self) -> Dict[str, Dict]:
        """
        Return a detailed schema description:
        {
          "table_name": {
             "columns": [
                {
                  "name": str,
                  "type": str ("int", "real" or "text"),
                  "primary_key": bool,
                  "nullable": bool
                }
             ],
             "foreign_keys": [
                {
                  "constrained_columns": [...],
                  "referred_table": "...",
                  "referred_columns": [...]
                }
             ]
          },
          ...
        }
        The "foreign_keys" entries are passed through unchanged from the
        SQLAlchemy inspector.
        """
        inspector = inspect(self.engine)
        schema_details = {}
        for table_name in inspector.get_table_names():
            col_details = [
                {
                    "name": col["name"],
                    "type": self._simplify_type(col["type"]),
                    # Coerce whatever the inspector reports to the documented
                    # bool; a missing key counts as "not a primary key".
                    "primary_key": bool(col.get("primary_key", False)),
                    "nullable": col["nullable"],
                }
                for col in inspector.get_columns(table_name)
            ]
            schema_details[table_name] = {
                "columns": col_details,
                "foreign_keys": inspector.get_foreign_keys(table_name),
            }
        return schema_details

class IntentExtractor:
    """Derives a ``QueryIntent`` from a German free-text question via regexes."""

    # German month names (plus the ASCII spelling "maerz") mapped to month numbers.
    _MONTHS = {
        "januar": 1, "februar": 2, "märz": 3, "maerz": 3, "april": 4,
        "mai": 5, "juni": 6, "juli": 7, "august": 8, "september": 9,
        "oktober": 10, "november": 11, "dezember": 12,
    }

    # "<month name> <4-digit year>"; the word boundaries keep the pattern from
    # matching inside longer words or longer digit runs (e.g. "20245").
    _MONTH_YEAR_RE = re.compile(
        r"\b(januar|februar|märz|maerz|april|mai|juni|juli|august|september|oktober|november|dezember)\s+(\d{4})\b"
    )

    def extract_intent(self, user_input: str) -> "QueryIntent":
        """Analyse ``user_input`` and return the detected intent.

        Args:
            user_input: The question as typed by the user.

        Returns:
            A ``QueryIntent`` holding the original question, the existence
            flag ("gibt es" / "gab es"), the count flag ("wie viele") and,
            when a "<month> <year>" phrase is present, the matching time range.
        """
        intent = QueryIntent()
        intent.original_question = user_input
        lowered = user_input.lower()

        if re.search(r"\bgibt es\b|\bgab es\b", lowered):
            intent.is_existence_query = True

        if re.search(r"wie\s+viele", lowered):
            intent.is_count_query = True

        intent.time_range = self._parse_month_range(lowered)
        return intent

    @classmethod
    def _parse_month_range(cls, text: str) -> Optional[Tuple[datetime.date, datetime.date]]:
        """Return (first day, last day) of the first "<month> <year>" in ``text``.

        Returns ``None`` when no such phrase is found or when the year is
        outside the range ``datetime.date`` supports (e.g. "0000"), instead of
        letting ``datetime.date`` raise ``ValueError``.
        """
        match = cls._MONTH_YEAR_RE.search(text.lower())
        if match is None:
            return None

        month_num = cls._MONTHS[match.group(1)]
        year_int = int(match.group(2))
        if not datetime.MINYEAR <= year_int <= datetime.MAXYEAR:
            return None

        start_date = datetime.date(year_int, month_num, 1)
        if month_num == 12:
            # Avoid constructing "January of year + 1", which would overflow
            # for the maximum supported year.
            end_date = datetime.date(year_int, 12, 31)
        else:
            # Last day of the month = day before the first of the next month.
            end_date = datetime.date(year_int, month_num + 1, 1) - datetime.timedelta(days=1)
        return (start_date, end_date)

class SQLPostProcessor:
    """Extension point for checking or repairing generated SQL.

    The schema information is kept on the instance, but the current
    implementation hands every query back untouched.
    """

    def __init__(self, schema_metadata: Dict[str, List[str]], foreign_keys: Dict[str, List[Dict]]) -> None:
        """Keep the table/column mapping and the foreign-key mapping."""
        self.foreign_keys = foreign_keys
        self.schema_metadata = schema_metadata

    def post_process(self, sql_query: str, intent: QueryIntent) -> str:
        """Return ``sql_query`` as is; ``intent`` is accepted but not used."""
        # Validation/correction could be added here.
        processed_query = sql_query
        return processed_query

class HuggingFaceQueryGenerator:
    """Generates SQL from a question with a Hugging Face seq2seq model."""

    def __init__(self, model_name: str, schema_details: Dict[str, Dict]) -> None:
        """Load tokenizer and model for ``model_name`` and keep the schema.

        Args:
            model_name: Identifier passed to ``from_pretrained``.
            schema_details: Per-table dict with "columns" and "foreign_keys"
                entries, as read by ``_format_schema``.
        """
        self.schema_details = schema_details
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    def generate_sql_query(self, intent: QueryIntent) -> str:
        """Build the prompt for ``intent`` and return the decoded model output.

        The prompt consists of optional hint lines (time range, existence,
        count), the question, the formatted schema and a fixed instruction
        sentence, in that order.
        """
        hint_lines = []
        if intent.time_range:
            range_start, range_end = intent.time_range
            start_date = range_start.isoformat()
            end_date = range_end.isoformat()
            hint_lines.append(f"Zeitraum: {start_date} bis {end_date}\n")
        if intent.is_existence_query:
            hint_lines.append("Existenzfrage: nutze COUNT(*)\n")
        if intent.is_count_query:
            hint_lines.append("Zählfrage: nutze COUNT(*)\n")
        hint_block = "".join(hint_lines).strip()

        prompt_sections = [
            hint_block,
            f"Question: {intent.original_question}",
            self._format_schema(),
            "Füge niemals einen Join ein, der nicht nötig ist. Erfinde keine Datenbankkonfigurationen, die nicht im Schema definiert wurden.",
        ]
        # The final strip drops the empty first line when there are no hints.
        prompt = "\n".join(prompt_sections).strip()

        token_ids = self.tokenizer.encode(prompt, return_tensors="pt")
        with torch.no_grad():
            generated = self.model.generate(token_ids, max_length=256, num_beams=4, early_stopping=True)
        return self.tokenizer.decode(generated[0], skip_special_tokens=True)

    def _format_schema(self) -> str:
        """
        Render the schema dynamically as a single line.
        Format: "table_name" "col" type , "col" type , foreign_key: ... primary key: ... [SEP]
        """
        table_defs = []
        for table_name, info in self.schema_details.items():
            columns = info["columns"]

            # Primary key: the first column flagged as such; if none is
            # flagged, the table's first column is used instead.
            flagged = [col["name"] for col in columns if col["primary_key"]]
            pk_name = flagged[0] if flagged else columns[0]['name']
            pk_str = f"primary key: \"{pk_name}\""

            column_text = ", ".join(f"\"{col['name']}\" {col['type']}" for col in columns)

            # Foreign keys are rendered as: "col" from "ref_table" "ref_col".
            # Only single-column foreign keys are emitted; others are skipped.
            fk_text = ", ".join(
                f"\"{fk['constrained_columns'][0]}\" from \"{fk['referred_table']}\" \"{fk['referred_columns'][0]}\""
                for fk in info["foreign_keys"]
                if len(fk['constrained_columns']) == 1 and len(fk['referred_columns']) == 1
            )
            fk_str = "foreign_key: " + fk_text

            table_defs.append(f"\"{table_name}\" {column_text} , {fk_str} {pk_str} [SEP]")

        return ("Schema: " + " ".join(table_defs)).strip()

class SQLExecutor:
    """Runs SQL text against the engine it was given."""

    def __init__(self, engine) -> None:
        """Keep a reference to the engine used for all queries."""
        self.engine = engine

    def execute_query(self, sql_query: str):
        """Execute ``sql_query`` and return ``(rows, columns)``.

        ``rows`` is the result of ``fetchall()`` and ``columns`` is the
        result's ``keys()``; both are read while the connection is still open.
        """
        with self.engine.connect() as connection:
            cursor_result = connection.execute(text(sql_query))
            fetched_rows = cursor_result.fetchall()
            column_keys = cursor_result.keys()
        return fetched_rows, column_keys

class ResultFormatter:
    """Turns raw query results into a German answer text for the user."""

    # Maximum number of result rows echoed in a tabular answer.
    _MAX_ROWS_TO_SHOW = 5

    def format_result(self, rows: List[tuple], columns: List[str], intent: "QueryIntent") -> str:
        """Format ``rows`` according to the kind of question that was asked.

        Args:
            rows: Result rows of the executed query.
            columns: Column names of the result, in row order.
            intent: Object exposing ``is_existence_query`` and
                ``is_count_query``.

        Returns:
            A yes/no sentence for existence questions, a count sentence for
            count questions, otherwise a short table with at most five rows
            (or a "no records" sentence when the result is empty).
        """
        if intent.is_existence_query:
            count = self._extract_count(rows)
            if count > 0:
                return f"Ja, es gab {count} passende Datensätze."
            return "Nein, es gab keine passenden Datensätze."

        if intent.is_count_query:
            count = self._extract_count(rows)
            return f"Es gibt insgesamt {count} passende Datensätze."

        row_count = len(rows)
        if row_count == 0:
            return "Es gibt keine passenden Datensätze."

        lines = [
            f"Es wurden {row_count} Datensätze gefunden.",
            "Beispiel:",
            " | ".join(columns),
        ]
        lines.extend(
            " | ".join(str(cell) for cell in row)
            for row in rows[:self._MAX_ROWS_TO_SHOW]
        )
        return "\n".join(lines) + "\n"

    @staticmethod
    def _extract_count(rows) -> int:
        """Return the number of matching records represented by ``rows``.

        A result consisting of exactly one integer cell is taken to be the
        output of ``COUNT(*)``. Any other shape means the generated SQL
        selected plain rows, so the rows themselves are counted; this avoids
        comparing a text cell with 0 or reporting an arbitrary id as a count.
        """
        if len(rows) == 1 and len(rows[0]) == 1:
            value = rows[0][0]
            if value is None:
                # A lone NULL cell (e.g. an aggregate over no rows): no match.
                return 0
            if isinstance(value, int) and not isinstance(value, bool):
                return value
        return len(rows)

class NaturalLanguageToSQLApp:
    """Interactive console app: German question -> SQL -> formatted answer."""

    def __init__(self, db_path: str, model_name: str) -> None:
        """Create the example database and wire up all pipeline components.

        Args:
            db_path: Path of the SQLite file; it is recreated on construction.
            model_name: Model identifier handed to ``HuggingFaceQueryGenerator``.
        """
        self.db_manager = DatabaseManager(db_path)
        # The database must exist before its schema can be reflected.
        self.db_manager.create_example_database()
        self.foreign_keys = self.db_manager.extract_foreign_keys()
        # Detailed schema (column types, primary keys, foreign keys).
        self.schema_details = self.db_manager.extract_full_schema_details()
        # Simplified view: table name -> list of column names.
        self.schema_metadata = {
            table: [column['name'] for column in details['columns']]
            for table, details in self.schema_details.items()
        }

        self.intent_extractor = IntentExtractor()
        self.query_generator = HuggingFaceQueryGenerator(model_name, self.schema_details)
        self.post_processor = SQLPostProcessor(self.schema_metadata, self.foreign_keys)
        self.sql_executor = SQLExecutor(self.db_manager.engine)
        self.result_formatter = ResultFormatter()

    def run(self) -> None:
        """Run the question/answer loop until the user quits.

        The loop ends on "exit", "quit" or "q" (case-insensitive), when
        standard input is closed, or when the user interrupts the prompt.
        Errors while answering a question are reported and the loop continues.
        """
        print("Willkommen! Stellen Sie natürlichsprachliche Fragen zu den Daten.")
        print("Beispiel: 'Gab es Bestellungen im November 2024?' oder 'Wie viele Kunden kommen aus Berlin?'")
        print("Beenden mit: exit\r\n")

        while True:
            try:
                user_input = input("Ihre Frage: ").strip()
            except (EOFError, KeyboardInterrupt):
                # Closed stdin or Ctrl-C at the prompt: leave the loop cleanly
                # instead of crashing with a traceback.
                print()
                print("Programm wird beendet.")
                break

            if user_input.lower() in ["exit", "quit", "q"]:
                print("Programm wird beendet.")
                break
            if not user_input:
                continue

            try:
                intent = self.intent_extractor.extract_intent(user_input)
                sql_query = self.query_generator.generate_sql_query(intent)
                sql_query = self.post_processor.post_process(sql_query, intent)
                logging.info("Generiertes SQL: %s", sql_query)

                rows, columns = self.sql_executor.execute_query(sql_query)
                answer = self.result_formatter.format_result(rows, columns, intent)
                print("Antwort:", answer)
            except SQLAlchemyError as e:
                logging.error("SQL-Fehler: %s", e)
                print("Fehler bei der Ausführung der SQL-Abfrage:", str(e))
            except Exception as e:
                # Top-level boundary: keep the session alive, but record the
                # traceback because the user only sees a generic message.
                logging.exception("Allgemeiner Fehler: %s", e)
                print("Es ist ein Fehler aufgetreten. Bitte versuchen Sie es erneut.")


if __name__ == "__main__":
    # Entry point: build the application from the module-level settings and
    # start the interactive question loop.
    app = NaturalLanguageToSQLApp(db_path=DB_PATH, model_name=MODEL_NAME)
    app.run()


Willkommen! Stellen Sie natürlichsprachliche Fragen zu den Daten.
Beispiel: 'Gab es Bestellungen im November 2024?' oder 'Wie viele Kunden kommen aus Berlin?'
Beenden mit: exit
Ihre Frage: Welche Kunden kommen aus Berlin? 
Antwort: Es wurden 2 Datensätze gefunden.
Beispiel:
name
Alice
Diana

Ihre Frage: Gab es im Dezember 2024 Bestellungen? 
Antwort: Nein, es gab keine passenden Datensätze.
Ihre Frage: Exit
Programm wird beendet.
