# 03 - Schema RAG : Améliorer avec le contexte

**Objectif** : Améliorer l'accuracy en donnant plus de contexte au LLM

**Améliorations** :
1. Schéma enrichi avec exemples de données
2. Règles SQLite spécifiques
3. Instructions plus strictes

**Baseline** : 25% EX → **Objectif : 50%+**

In [1]:
import sqlite3
import pandas as pd
import json
import requests
from pathlib import Path

DATA_DIR = Path('../data')
DB_PATH = DATA_DIR / 'database' / 'ecommerce.db'
conn = sqlite3.connect(DB_PATH)

# Charger test set
with open(DATA_DIR / 'results' / 'test_questions.json', 'r') as f:
    TEST_QUESTIONS = json.load(f)

print(f'✓ Database connectée')
print(f'✓ {len(TEST_QUESTIONS)} questions chargées')

✓ Database connectée
✓ 12 questions chargées


## 1. Schéma enrichi avec exemples

In [2]:
def get_enriched_schema(conn) -> str:
    """Créer un schéma enrichi avec exemples de données."""
    
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = [t[0] for t in cursor.fetchall()]
    
    schema_parts = []
    
    for table in tables:
        # Get columns
        cursor.execute(f"PRAGMA table_info({table})")
        columns = cursor.fetchall()
        
        # Get sample data
        sample = pd.read_sql(f"SELECT * FROM {table} LIMIT 3", conn)
        
        # Get row count
        cursor.execute(f"SELECT COUNT(*) FROM {table}")
        count = cursor.fetchone()[0]
        
        # Build schema
        cols_str = []
        for col in columns:
            col_name = col[1]
            col_type = col[2] or 'TEXT'
            # Get sample values
            sample_vals = sample[col_name].dropna().head(2).tolist()
            sample_str = ', '.join([repr(v)[:30] for v in sample_vals])
            cols_str.append(f"  {col_name} {col_type}  -- ex: {sample_str}")
        
        table_schema = f"""-- {table} ({count} rows)
CREATE TABLE {table} (
{chr(10).join(cols_str)}
);"""
        schema_parts.append(table_schema)
    
    return "\n\n".join(schema_parts)

ENRICHED_SCHEMA = get_enriched_schema(conn)
print(ENRICHED_SCHEMA)

-- customers (99441 rows)
CREATE TABLE customers (
  customer_id TEXT  -- ex: '06b8999e2fba1a1fbc88172c00ba8, '18955e83d337fd6b2def6b18a428a
  customer_city TEXT  -- ex: 'franca', 'sao bernardo do campo'
  customer_state TEXT  -- ex: 'SP', 'SP'
);

-- products (32951 rows)
CREATE TABLE products (
  product_id TEXT  -- ex: '1e9e8ef04dbcff4541ed26657ea51, '3aa071139cb16b67ca9e5dea641aa
  product_category TEXT  -- ex: 'perfumaria', 'artes'
  product_weight_g REAL  -- ex: 225.0, 1000.0
  product_length_cm REAL  -- ex: 16.0, 30.0
  product_height_cm REAL  -- ex: 10.0, 18.0
  product_width_cm REAL  -- ex: 14.0, 20.0
);

-- orders (99441 rows)
CREATE TABLE orders (
  order_id TEXT  -- ex: 'e481f51cbdc54678b7cc49136f2d6, '53cdb2fc8bc7dce0b6741e2150273
  customer_id TEXT  -- ex: '9ef432eb6251297304e76186b10a9, 'b0830fb4747a6c6d20dea0b8c802d
  order_status TEXT  -- ex: 'delivered', 'delivered'
  order_purchase_timestamp TEXT  -- ex: '2017-10-02 10:56:33', '2018-07-24 20:41:37'
  order_delivered_

## 2. Prompt amélioré avec règles SQLite

In [3]:
IMPROVED_PROMPT = '''You are a SQL expert. Generate SQLite queries for an e-commerce database.

DATABASE SCHEMA:
{schema}

RELATIONSHIPS:
- customers.customer_id -> orders.customer_id
- orders.order_id -> order_items.order_id
- orders.order_id -> payments.order_id
- orders.order_id -> reviews.order_id
- products.product_id -> order_items.product_id

SQLITE RULES (IMPORTANT):
- Use julianday() for date differences, NOT DATEDIFF
- Always use table aliases to avoid ambiguous columns
- Use ROUND() for decimal formatting
- String comparison is case-sensitive

QUESTION: {question}

Return ONLY the SQL query. No explanation.'''

print('Prompt template créé')
print(f'Longueur: {len(IMPROVED_PROMPT)} chars')

Prompt template créé
Longueur: 604 chars


In [4]:
# Configuration LLM
OLLAMA_URL = 'http://localhost:11434/api/generate'
MODEL = 'mistral'

def call_llm(prompt: str, temperature: float = 0.0) -> str:
    """Appeler le LLM via Ollama."""
    try:
        response = requests.post(
            OLLAMA_URL,
            json={
                'model': MODEL,
                'prompt': prompt,
                'stream': False,
                'options': {'temperature': temperature, 'num_predict': 300}
            },
            timeout=60
        )
        return response.json().get('response', '').strip()
    except Exception as e:
        return f'ERROR: {e}'

def extract_sql(response: str) -> str:
    """Extraire le SQL de la réponse."""
    sql = response.strip()
    if '```sql' in sql:
        sql = sql.split('```sql')[1].split('```')[0].strip()
    elif '```' in sql:
        sql = sql.split('```')[1].split('```')[0].strip()
    return sql

def generate_sql_improved(question: str) -> str:
    """Générer SQL avec prompt amélioré."""
    prompt = IMPROVED_PROMPT.format(schema=ENRICHED_SCHEMA, question=question)
    response = call_llm(prompt)
    return extract_sql(response)

# Test
test_sql = generate_sql_improved('Combien de commandes au total ?')
print(f'Test SQL:\n{test_sql}')

Test SQL:
SELECT COUNT(DISTINCT orders.order_id) FROM orders;


## 3. Fonctions d'évaluation

In [5]:
def execute_sql(sql: str, conn) -> tuple:
    """Exécuter SQL et retourner (success, result)."""
    try:
        result = pd.read_sql(sql, conn)
        return True, result
    except Exception as e:
        return False, str(e)

def compare_results(result1: pd.DataFrame, result2: pd.DataFrame) -> bool:
    """Comparer deux résultats."""
    try:
        if result1.shape != result2.shape:
            return False
        r1 = result1.sort_values(by=result1.columns[0]).reset_index(drop=True)
        r2 = result2.sort_values(by=result2.columns[0]).reset_index(drop=True)
        
        # Compare values with tolerance for floats
        for col in r1.columns:
            if r1[col].dtype in ['float64', 'float32']:
                if not all(abs(r1[col] - r2[col]) < 0.01):
                    return False
            else:
                if not r1[col].equals(r2[col]):
                    # Try string comparison
                    if not all(str(a) == str(b) for a, b in zip(r1[col], r2[col])):
                        return False
        return True
    except:
        return False

def evaluate_question(question_data: dict, generate_fn, conn) -> dict:
    """Évaluer une question."""
    question = question_data['question']
    expected_sql = question_data['sql']
    difficulty = question_data['difficulty']
    
    generated_sql = generate_fn(question)
    gen_success, gen_result = execute_sql(generated_sql, conn)
    exp_success, exp_result = execute_sql(expected_sql, conn)
    
    ex = False
    if gen_success and exp_success:
        ex = compare_results(gen_result, exp_result)
    
    return {
        'question': question,
        'difficulty': difficulty,
        'expected_sql': expected_sql,
        'generated_sql': generated_sql,
        'execution_success': gen_success,
        'execution_accuracy': ex,
        'error': None if gen_success else gen_result
    }

print('✓ Fonctions d\'évaluation prêtes')

✓ Fonctions d'évaluation prêtes


## 4. Évaluation avec prompt amélioré

In [6]:
print('ÉVALUATION - PROMPT AMÉLIORÉ')
print('='*60)
print(f'Améliorations: schéma enrichi + règles SQLite + relations')
print('='*60)

results_improved = []
for i, q in enumerate(TEST_QUESTIONS):
    result = evaluate_question(q, generate_sql_improved, conn)
    results_improved.append(result)
    
    status_ex = '✓' if result['execution_accuracy'] else '✗'
    status_run = '✓' if result['execution_success'] else '✗'
    
    print(f'[{i+1}/{len(TEST_QUESTIONS)}] [{result["difficulty"]}] {status_run} run | {status_ex} EX')
    print(f'   Q: {result["question"][:50]}...')
    if not result['execution_success']:
        print(f'   Error: {str(result["error"])[:60]}...')
    print()

ÉVALUATION - PROMPT AMÉLIORÉ
Améliorations: schéma enrichi + règles SQLite + relations
[1/12] [simple] ✗ run | ✗ EX
   Q: Combien de commandes au total ?...
   Error: Execution failed on sql 'SQL
SELECT COUNT(DISTINCT orders.or...

[2/12] [simple] ✗ run | ✗ EX
   Q: Combien de clients différents ?...
   Error: Execution failed on sql 'SQL
SELECT COUNT(DISTINCT c.custome...

[3/12] [simple] ✗ run | ✗ EX
   Q: Combien de commandes livrées ?...
   Error: Execution failed on sql 'SQL
SELECT COUNT(DISTINCT o.order_i...

[4/12] [simple] ✓ run | ✗ EX
   Q: Quel est le nombre de produits dans le catalogue ?...

[5/12] [medium] ✓ run | ✗ EX
   Q: Quelles sont les 5 villes avec le plus de clients ...

[6/12] [medium] ✓ run | ✗ EX
   Q: Quel est le revenue total par catégorie de produit...

[7/12] [medium] ✓ run | ✗ EX
   Q: Quelle est la note moyenne par catégorie de produi...

[8/12] [medium] ✓ run | ✗ EX
   Q: Combien de commandes par méthode de paiement ?...

[9/12] [complex] ✓ run | ✗ EX
   

In [7]:
# Résultats
df_improved = pd.DataFrame(results_improved)

total = len(df_improved)
exec_success = df_improved['execution_success'].sum()
ex_correct = df_improved['execution_accuracy'].sum()

print('\n' + '='*60)
print('RÉSULTATS - PROMPT AMÉLIORÉ')
print('='*60)
print(f'Execution Success: {exec_success}/{total} ({exec_success/total:.1%})')
print(f'Execution Accuracy (EX): {ex_correct}/{total} ({ex_correct/total:.1%})')

print('\nPar difficulté:')
for diff in ['simple', 'medium', 'complex']:
    subset = df_improved[df_improved['difficulty'] == diff]
    n = len(subset)
    if n > 0:
        ex = subset['execution_accuracy'].sum()
        print(f'{diff.upper():8} | EX: {ex}/{n} ({ex/n:.0%})')

# Comparaison avec baseline
print('\n' + '='*60)
print('COMPARAISON AVEC BASELINE')
print('='*60)
print(f'Baseline EX: 25% → Amélioré EX: {ex_correct/total:.0%}')
print(f'Amélioration: +{(ex_correct/total - 0.25)*100:.0f} points')


RÉSULTATS - PROMPT AMÉLIORÉ
Execution Success: 7/12 (58.3%)
Execution Accuracy (EX): 0/12 (0.0%)

Par difficulté:
SIMPLE   | EX: 0/4 (0%)
MEDIUM   | EX: 0/4 (0%)
COMPLEX  | EX: 0/4 (0%)

COMPARAISON AVEC BASELINE
Baseline EX: 25% → Amélioré EX: 0%
Amélioration: +-25 points


## 5. Analyse des erreurs restantes

In [8]:
# Erreurs d'exécution
errors = df_improved[~df_improved['execution_success']]
print(f'Erreurs d\'exécution: {len(errors)}')

for _, row in errors.iterrows():
    print(f'\n{"="*60}')
    print(f'Q: {row["question"]}')
    print(f'Difficulty: {row["difficulty"]}')
    print(f'Generated: {row["generated_sql"][:100]}...')
    print(f'Error: {str(row["error"])[:100]}')

Erreurs d'exécution: 5

Q: Combien de commandes au total ?
Difficulty: simple
Generated: SQL
SELECT COUNT(DISTINCT orders.order_id) FROM orders;...
Error: Execution failed on sql 'SQL
SELECT COUNT(DISTINCT orders.order_id) FROM orders;': near "SQL": synta

Q: Combien de clients différents ?
Difficulty: simple
Generated: SQL
SELECT COUNT(DISTINCT c.customer_id) FROM customers AS c;...
Error: Execution failed on sql 'SQL
SELECT COUNT(DISTINCT c.customer_id) FROM customers AS c;': near "SQL":

Q: Combien de commandes livrées ?
Difficulty: simple
Generated: SQL
SELECT COUNT(DISTINCT o.order_id) FROM orders AS o
JOIN order_items AS oi ON o.order_id = oi.ord...
Error: Execution failed on sql 'SQL
SELECT COUNT(DISTINCT o.order_id) FROM orders AS o
JOIN order_items AS 

Q: Quel est le panier moyen par ville ?
Difficulty: complex
Generated: SELECT c.customer_city, AVG(SUM(oi.price) + SUM(oi.freight_value)) as average_cart
FROM customers AS...
Error: Execution failed on sql 'SELECT c.customer_ci

In [9]:
# Mauvais résultats
wrong = df_improved[df_improved['execution_success'] & ~df_improved['execution_accuracy']]
print(f'\nRésultats incorrects: {len(wrong)}')

for _, row in wrong.iterrows():
    print(f'\n{"="*60}')
    print(f'Q: {row["question"]}')
    print(f'\nExpected: {row["expected_sql"][:80]}...')
    print(f'\nGenerated: {row["generated_sql"][:80]}...')
    
    # Montrer la différence de résultats
    _, exp_res = execute_sql(row['expected_sql'], conn)
    _, gen_res = execute_sql(row['generated_sql'], conn)
    print(f'\nExpected result: {exp_res.head(2).to_dict()}')
    print(f'Generated result: {gen_res.head(2).to_dict()}')


Résultats incorrects: 7

Q: Quel est le nombre de produits dans le catalogue ?

Expected: SELECT COUNT(*) as total FROM products...

Generated: SELECT COUNT(DISTINCT products.product_id) FROM products;...

Expected result: {'total': {0: 32951}}
Generated result: {'COUNT(DISTINCT products.product_id)': {0: 32951}}

Q: Quelles sont les 5 villes avec le plus de clients ?

Expected: SELECT customer_city, COUNT(*) as num_customers 
                  FROM customer...

Generated: SELECT customer_city, COUNT(*) as number_of_customers
FROM customers c
JOIN orde...

Expected result: {'customer_city': {0: 'sao paulo', 1: 'rio de janeiro'}, 'num_customers': {0: 15540, 1: 6882}}
Generated result: {'customer_city': {0: 'sao paulo', 1: 'rio de janeiro'}, 'number_of_customers': {0: 15540, 1: 6882}}

Q: Quel est le revenue total par catégorie de produit ?

Expected: SELECT p.product_category, SUM(oi.price) as revenue
                  FROM order...

Generated: SELECT
    p.product_category,
    SUM(oi

## 6. Sauvegarder

In [10]:
improved_results = {
    'model': MODEL,
    'prompt_type': 'improved (enriched schema + SQLite rules)',
    'total_questions': total,
    'metrics': {
        'execution_success': exec_success / total,
        'execution_accuracy': ex_correct / total
    },
    'by_difficulty': {},
    'comparison': {
        'baseline_ex': 0.25,
        'improved_ex': ex_correct / total,
        'improvement': (ex_correct / total) - 0.25
    }
}

for diff in ['simple', 'medium', 'complex']:
    subset = df_improved[df_improved['difficulty'] == diff]
    n = len(subset)
    if n > 0:
        improved_results['by_difficulty'][diff] = {
            'total': n,
            'execution_accuracy': subset['execution_accuracy'].sum() / n
        }

with open(DATA_DIR / 'results' / 'improved_results.json', 'w') as f:
    json.dump(improved_results, f, indent=2)

print(f'✓ Résultats sauvegardés')

✓ Résultats sauvegardés


## Résumé

| Métrique | Baseline | Amélioré |
|----------|----------|----------|
| Exec Success | 75% | ?% |
| Exec Accuracy | 25% | ?% |

### Prochaine étape :
**Notebook 04** : Ajouter des exemples few-shot pour encore améliorer

In [None]:
conn.close()
print('Connexion fermée.')