# 05 - Self-Correction : Corriger les erreurs SQL

**Progression** : Baseline 25% → Few-shot 58% → **Objectif 70%+**

**Approche** : Si le SQL échoue, demander au LLM de le corriger avec le message d'erreur

In [11]:
import json
import re
import sqlite3
from pathlib import Path

import pandas as pd
import requests

# Project data locations, resolved relative to the notebook's working directory.
DATA_DIR = Path('../data')
DB_PATH = DATA_DIR / 'database' / 'ecommerce.db'

# sqlite3.connect() creates an empty database when the file is missing, which
# would make every later query fail with "no such table". Fail early instead.
if not DB_PATH.exists():
    raise FileNotFoundError(f'SQLite database not found: {DB_PATH}')

# Single connection shared by every cell below (passed to execute_sql).
conn = sqlite3.connect(DB_PATH)

# Decode explicitly as UTF-8: the questions contain accented characters and the
# platform default encoding is not UTF-8 everywhere.
# NOTE(review): assumes the JSON file is stored as UTF-8 - confirm.
with open(DATA_DIR / 'results' / 'test_questions.json', 'r', encoding='utf-8') as f:
    TEST_QUESTIONS = json.load(f)

print(f'{len(TEST_QUESTIONS)} questions')

12 questions


In [2]:
# Compact description of the database tables and their columns. It is inserted
# verbatim into both prompts through their {schema} placeholder.
SIMPLE_SCHEMA = '''Tables:
- customers(customer_id, customer_city, customer_state)
- orders(order_id, customer_id, order_status, order_purchase_timestamp, order_delivered_timestamp)
- order_items(order_id, order_item_id, product_id, price, freight_value)
- products(product_id, product_category, product_weight_g)
- payments(order_id, payment_sequential, payment_type, payment_installments, payment_value)
- reviews(review_id, order_id, review_score, review_comment_title, review_comment_message)'''

# Question/SQL pairs shown to the model as few-shot examples. They are inserted
# into GENERATE_PROMPT through its {examples} placeholder.
FEW_SHOT_EXAMPLES = '''
Example 1:
Q: How many orders are there?
SQL: SELECT COUNT(*) as total FROM orders;

Example 2:
Q: How many delivered orders?
SQL: SELECT COUNT(*) as total FROM orders WHERE order_status = 'delivered';

Example 3:
Q: Total revenue by product category?
SQL: SELECT p.product_category, SUM(oi.price) as revenue FROM order_items oi JOIN products p ON oi.product_id = p.product_id GROUP BY p.product_category ORDER BY revenue DESC;

Example 4:
Q: Average review score by city?
SQL: SELECT c.customer_city, ROUND(AVG(r.review_score), 2) as avg_score FROM customers c JOIN orders o ON c.customer_id = o.customer_id JOIN reviews r ON o.order_id = r.order_id GROUP BY c.customer_city ORDER BY avg_score DESC;
'''

In [3]:
# Template for the first generation attempt. It is filled with str.format() and
# expects the placeholders {schema}, {examples} and {question}. The template
# ends on "SQL:" so that the completion is the query itself.
GENERATE_PROMPT = '''You are a SQLite expert for an e-commerce database.

{schema}

Rules:
- Use table aliases (c, o, oi, p, r)
- SQLite: use julianday() for dates

{examples}

Q: {question}
SQL:'''

# Template used after a failed execution. It is filled with str.format() and
# expects {schema}, {question}, {sql} (the query that failed) and {error}
# (the error message produced when running it).
CORRECTION_PROMPT = '''The following SQL query failed with an error.

Schema:
{schema}

Original question: {question}

Failed SQL:
{sql}

Error message:
{error}

Fix the SQL query. Return ONLY the corrected SQL, nothing else.
Corrected SQL:'''

In [4]:
import threading

# URL that call_llm() posts generation requests to (a server on localhost).
OLLAMA_URL = 'http://localhost:11434/api/generate'
# Model name sent in the 'model' field of every request.
MODEL = 'mistral'

def call_llm(prompt: str) -> str:
    """Send ``prompt`` to the Ollama generate endpoint and return the completion.

    The request is non-streaming, uses temperature 0, limits the completion
    with ``num_predict`` = 250 and times out after 60 seconds.

    Args:
        prompt: Full prompt text sent to ``MODEL`` at ``OLLAMA_URL``.

    Returns:
        The stripped ``response`` field of the JSON reply. When the request
        fails, the server answers with an HTTP error status, or the body is not
        the expected JSON object, a string starting with ``'ERROR: '`` is
        returned instead; these failures are never raised to the caller.
    """
    payload = {
        'model': MODEL,
        'prompt': prompt,
        'stream': False,
        'options': {'temperature': 0, 'num_predict': 250},
    }
    try:
        response = requests.post(OLLAMA_URL, json=payload, timeout=60)
        # An HTTP error reply carries no 'response' key; without this check it
        # would be returned as an empty completion instead of an error.
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers invalid JSON on requests versions that do not wrap
        # the decoding error in a RequestException subclass.
        return f'ERROR: {e}'

    if not isinstance(data, dict):
        return f'ERROR: unexpected response payload: {data!r}'
    return str(data.get('response') or '').strip()

# Opening Markdown fence that carries an SQL language tag, in any letter case
# and including longer tags that start with "sql" (```sql, ```SQL, ```sqlite).
_SQL_FENCE = re.compile(r'```sql\w*', re.IGNORECASE)


def _find_unquoted(text: str, token: str) -> int:
    """Return the index of the first ``token`` outside quotes in ``text``, or -1."""
    quote = None
    for i, ch in enumerate(text):
        if quote is not None:
            # Inside a quoted section: only the matching quote ends it. A
            # doubled quote ('' or "") closes and immediately reopens it.
            if ch == quote:
                quote = None
        elif ch in ('"', "'"):
            quote = ch
        elif text.startswith(token, i):
            return i
    return -1


def extract_sql(response: str) -> str:
    """Extract a single one-line SQL statement from a raw LLM completion.

    Steps:
      1. If the text contains a Markdown code fence, keep only the content of
         the first SQL-tagged fence, or of the first fence when none is tagged.
         An unterminated fence runs to the end of the text.
      2. Drop blank lines and ``--`` comments. Trailing comments are removed as
         well, because joining the lines would otherwise comment out everything
         that follows them.
      3. Join the remaining lines with single spaces.
      4. Keep only the first statement, i.e. the text up to and including the
         first ``;`` that is not inside a quoted literal.

    Args:
        response: Raw text returned by the model.

    Returns:
        The extracted SQL, or an empty string when nothing is left. No trailing
        ``;`` is added when the text contains none.
    """
    sql = response.strip()

    tagged = _SQL_FENCE.search(sql)
    if tagged:
        sql = sql[tagged.end():].split('```')[0].strip()
    elif '```' in sql:
        sql = sql.split('```')[1].split('```')[0].strip()

    kept = []
    for line in sql.split('\n'):
        comment_at = _find_unquoted(line, '--')
        if comment_at != -1:
            line = line[:comment_at].rstrip()
        if line.strip():
            kept.append(line)
    sql = ' '.join(kept)

    end = _find_unquoted(sql, ';')
    if end == -1:
        # Unbalanced quotes can hide every ';'. Fall back to the first raw one
        # so the result is never worse than a plain split on ';'.
        end = sql.find(';')
    if end != -1:
        sql = sql[:end] + ';'
    return sql

def execute_sql(sql: str, conn, timeout: float = 10) -> tuple:
    """Run ``sql`` on ``conn`` and load the rows into a DataFrame, with a time limit.

    A background timer calls ``conn.interrupt()`` once ``timeout`` seconds have
    elapsed, which aborts the running query so a bad statement cannot block
    the notebook.

    Args:
        sql: SQL text to execute.
        conn: Open database connection; it must provide ``interrupt()`` (as
            ``sqlite3.Connection`` does) and be accepted by ``pd.read_sql``.
        timeout: Maximum run time in seconds before the query is interrupted.

    Returns:
        ``(True, dataframe, None)`` on success, otherwise
        ``(False, None, message)``. The message mentions a timeout only when
        the timer actually fired, so an ordinary SQL error is reported as such.
    """
    timed_out = threading.Event()

    def stop_query():
        # Record the reason before interrupting so the error can be labelled.
        timed_out.set()
        conn.interrupt()

    timer = threading.Timer(timeout, stop_query)
    timer.start()
    try:
        result = pd.read_sql(sql, conn)
    except Exception as e:
        # Deliberate catch-all: any failure must come back as a message that
        # can be handed to the correction prompt, not as an exception.
        if timed_out.is_set():
            return False, None, f"Error (query interrupted after {timeout}s timeout): {e}"
        return False, None, f"Error: {e}"
    else:
        return True, result, None
    finally:
        timer.cancel()


In [5]:
def generate_sql_with_correction(question: str, max_retries: int = 2) -> tuple:
    """Generate SQL for ``question`` and let the LLM repair it when execution fails.

    The first query comes from ``GENERATE_PROMPT``. Whenever the query fails on
    the shared ``conn``, the failed SQL and its error message are sent back
    through ``CORRECTION_PROMPT``, at most ``max_retries`` times.

    Args:
        question: Natural-language question to translate.
        max_retries: Maximum number of correction rounds after the first attempt.

    Returns:
        ``(sql, result, corrections)``. On success ``result`` is the DataFrame
        and ``corrections`` the number of correction rounds that were needed
        (0 when the first query worked). When every attempt fails, ``result``
        is ``None``, ``sql`` is the last query tried and ``corrections`` equals
        ``max_retries``.
    """
    prompt = GENERATE_PROMPT.format(
        schema=SIMPLE_SCHEMA,
        examples=FEW_SHOT_EXAMPLES,
        question=question,
    )
    rounds_used = 0

    while True:
        sql = extract_sql(call_llm(prompt))
        ok, frame, failure = execute_sql(sql, conn)
        if ok:
            return sql, frame, rounds_used

        # Correction budget exhausted: report the last query as a failure.
        if rounds_used >= max_retries:
            return sql, None, max_retries

        rounds_used += 1
        prompt = CORRECTION_PROMPT.format(
            schema=SIMPLE_SCHEMA,
            question=question,
            sql=sql,
            error=failure,
        )

# Smoke test: run the generate / execute / correct pipeline on one question and
# show the final SQL together with the number of correction rounds it needed.
test_sql, test_res, corrections = generate_sql_with_correction('Combien de commandes ?')
print(f'SQL: {test_sql}')
print(f'Corrections: {corrections}')

SQL: SELECT COUNT(*) as total FROM orders;
Corrections: 0


In [6]:
def compare_results(r1: pd.DataFrame, r2: pd.DataFrame, tolerance: float = 1) -> bool:
    """Decide whether two query results count as the same answer.

    Only the LAST column of each frame is compared, and row order is ignored
    (values are sorted first). Column names and every other column are not
    looked at.

    Args:
        r1: First result, or ``None``.
        r2: Second result, or ``None``.
        tolerance: Numeric slack. When the string forms differ, two values
            still match if their absolute difference is strictly below this.

    Returns:
        ``True`` when both frames are empty, or when they have the same number
        of rows and their last columns match as strings or as numbers within
        ``tolerance``. ``False`` otherwise, including when either input is
        ``None`` or cannot be inspected.
    """
    if r1 is None or r2 is None:
        return False

    try:
        if r1.empty and r2.empty:
            return True
        if r1.shape[0] != r2.shape[0]:
            return False
        last1 = r1.iloc[:, -1].tolist()
        last2 = r2.iloc[:, -1].tolist()
    except Exception:
        # Best-effort scoring: a malformed frame (for example one without
        # columns) is a mismatch, not a crash. Unlike a bare ``except``, this
        # lets KeyboardInterrupt and SystemExit through.
        return False

    if sorted(str(x) for x in last1) == sorted(str(x) for x in last2):
        return True

    # The string forms differ; retry numerically so that differences such as
    # 10 vs 10.0 or small rounding gaps are still accepted.
    try:
        n1 = sorted(float(x) for x in last1)
        n2 = sorted(float(x) for x in last2)
    except (TypeError, ValueError, OverflowError):
        return False
    return all(abs(a - b) < tolerance for a, b in zip(n1, n2))

In [12]:
# Evaluation loop: for every benchmark question, generate SQL (with
# self-correction), run the reference SQL, and compare the two results.
print('ÉVALUATION AVEC SELF-CORRECTION')
print('='*60)

results = []
total_corrections = 0

for i, q in enumerate(TEST_QUESTIONS):
    question = q['question']
    expected_sql = q['sql']
    difficulty = q['difficulty']

    # Generate with correction
    generated_sql, gen_result, num_corrections = generate_sql_with_correction(question)
    total_corrections += num_corrections

    # Execute expected
    exp_ok, exp_result, exp_error = execute_sql(expected_sql, conn)
    if not exp_ok:
        # A broken reference query would otherwise be scored silently as a
        # model error; report it so the benchmark itself can be fixed.
        print(f'[{i+1:2d}] reference SQL failed, question cannot be scored: {exp_error}')

    # Compare: gen_result is None when every attempt failed to execute.
    gen_ok = gen_result is not None
    ex = compare_results(gen_result, exp_result) if gen_ok else False

    results.append({
        'question': question,
        'difficulty': difficulty,
        'expected_sql': expected_sql,
        'generated_sql': generated_sql,
        'exec_success': gen_ok,
        'exec_accuracy': ex,
        'corrections': num_corrections,
        'reference_ok': exp_ok,
    })

    # OK = same result as the reference, WARN = ran but result differs,
    # FAIL = no executable SQL after all corrections.
    status = 'OK' if ex else ('WARN' if gen_ok else 'FAIL')
    corr_str = f' (corrected x{num_corrections})' if num_corrections > 0 else ''
    print(f'[{i+1:2d}] {status} [{difficulty:7s}] {question[:35]}...{corr_str}')

ÉVALUATION AVEC SELF-CORRECTION
[ 1] OK [simple ] Combien de commandes au total ?...
[ 2] OK [simple ] Combien de clients différents ?...
[ 3] OK [simple ] Combien de commandes livrées ?...
[ 4] OK [simple ] Quel est le nombre de produits dans...
[ 5] OK [medium ] Quelles sont les 5 villes avec le p...
[ 6] WARN [medium ] Quel est le revenue total par catég...
[ 7] OK [medium ] Quelle est la note moyenne par caté...
[ 8] WARN [medium ] Combien de commandes par méthode de...
[ 9] OK [complex] Quels clients ont passé plus de 2 c...
[10] WARN [complex] Quel est le panier moyen par ville ...
[11] FAIL [complex] Quels produits ont une note moyenne... (corrected x2)
[12] FAIL [complex] Quel est le délai moyen de livraiso... (corrected x2)


In [8]:
# Aggregate the per-question records collected by the evaluation loop.
df = pd.DataFrame(results)
total = len(df)
# Both columns hold booleans, so the sums are counts of True values.
success = df['exec_success'].sum()    # generated SQL executed without error
accuracy = df['exec_accuracy'].sum()  # generated result matched the reference

print('\n' + '='*60)
print('RÉSULTATS AVEC SELF-CORRECTION')
print('='*60)
print(f'Execution Success: {success}/{total} ({success/total:.0%})')
print(f'Execution Accuracy: {accuracy}/{total} ({accuracy/total:.0%})')
print(f'Total corrections: {total_corrections}')

# Accuracy broken down by the 'difficulty' label of each question.
# NOTE(review): a difficulty level with no question makes n == 0 in the ratio.
print('\nPar difficulté:')
for diff in ['simple', 'medium', 'complex']:
    sub = df[df['difficulty'] == diff]
    n = len(sub)
    ex = sub['exec_accuracy'].sum()
    print(f'  {diff:8s}: {ex}/{n} ({ex/n:.0%})')

# Comparison with the previous approaches. The baseline and few-shot figures
# are hard-coded literals; only the self-correction figure is computed here.
print('\n' + '='*60)
print('PROGRESSION')
print('='*60)
print(f'Baseline:        25% EX')
print(f'Few-shot:        58% EX')
print(f'Self-correction: {accuracy/total:.0%} EX')


RÉSULTATS AVEC SELF-CORRECTION
Execution Success: 10/12 (83%)
Execution Accuracy: 7/12 (58%)
Total corrections: 4

Par difficulté:
  simple  : 4/4 (100%)
  medium  : 2/4 (50%)
  complex : 1/4 (25%)

PROGRESSION
Baseline:        25% EX
Few-shot:        58% EX
Self-correction: 58% EX


In [None]:
# Save the metrics of this run as JSON next to the other result files.
correction_results = {
    'model': MODEL,
    'prompt_type': 'few-shot + self-correction',
    # Literal value; it equals the default max_retries of
    # generate_sql_with_correction, which the evaluation loop relies on.
    'max_retries': 2,
    'metrics': {
        'execution_success': success / total,
        'execution_accuracy': accuracy / total,
        'total_corrections': total_corrections
    },
    # Share of matching results per difficulty label (mean of a boolean column).
    'by_difficulty': {d: float(df[df['difficulty']==d]['exec_accuracy'].mean()) 
                     for d in ['simple', 'medium', 'complex']},
    # Baseline and few-shot scores are hard-coded literals
    # (presumably taken from earlier experiments - verify before reuse).
    'progression': {
        'baseline': 0.25,
        'fewshot': 0.58,
        'self_correction': accuracy / total
    }
}

with open(DATA_DIR / 'results' / 'correction_results.json', 'w') as f:
    json.dump(correction_results, f, indent=2)

print('Sauvegardé')
# Closes the shared connection: cells that query the database cannot be re-run
# after this point without reconnecting.
conn.close()

✓ Sauvegardé
