# 02 - Baseline LLM : Zero-Shot Text-to-SQL

**Objectif** : Établir une baseline en testant le LLM sans optimisation

**Approche** : Zero-shot = on donne juste le schéma + la question

**Métriques** :
- **Execution Accuracy (EX)** : Le SQL s'exécute et donne le bon résultat
- **Exact Match (EM)** : Le SQL généré = SQL attendu (après normalisation)

In [1]:
import sqlite3
import pandas as pd
import json
import requests
import re
from pathlib import Path

# Project paths, resolved relative to the current working directory.
DATA_DIR = Path('../data')
DB_PATH = DATA_DIR / 'database' / 'ecommerce.db'

# Fail fast when the database file is missing. An explicit raise is used
# instead of `assert` because assertions are stripped when Python runs with -O.
if not DB_PATH.exists():
    raise FileNotFoundError(f'DB not found: {DB_PATH}')
print(f'✓ Database: {DB_PATH}')

✓ Database: ../data/database/ecommerce.db


In [2]:
# Load the DDL schema that is later injected verbatim into the LLM prompt.
# The encoding is pinned so the read does not depend on the platform default.
SCHEMA = (DATA_DIR / 'database' / 'schema.sql').read_text(encoding='utf-8')

print('SCHÉMA DDL')
print('='*60)
print(SCHEMA)

SCHÉMA DDL
CREATE TABLE customers (
  customer_id TEXT,
  customer_city TEXT,
  customer_state TEXT
);

CREATE TABLE products (
  product_id TEXT,
  product_category TEXT,
  product_weight_g REAL,
  product_length_cm REAL,
  product_height_cm REAL,
  product_width_cm REAL
);

CREATE TABLE orders (
  order_id TEXT,
  customer_id TEXT,
  order_status TEXT,
  order_purchase_timestamp TEXT,
  order_delivered_timestamp TEXT
);

CREATE TABLE order_items (
  order_id TEXT,
  order_item_id INTEGER,
  product_id TEXT,
  price REAL,
  freight_value REAL
);

CREATE TABLE payments (
  order_id TEXT,
  payment_sequential INTEGER,
  payment_type TEXT,
  payment_installments INTEGER,
  payment_value REAL
);

CREATE TABLE reviews (
  review_id TEXT,
  order_id TEXT,
  review_score INTEGER,
  review_comment_title TEXT,
  review_comment_message TEXT
);


In [3]:
# Load the evaluation set. The questions contain accented French text, so the
# encoding is pinned to UTF-8 instead of relying on the platform default.
with open(DATA_DIR / 'results' / 'test_questions.json', 'r', encoding='utf-8') as f:
    TEST_QUESTIONS = json.load(f)

print(f'Test set: {len(TEST_QUESTIONS)} questions')
# Preview the first three entries (difficulty + question text).
for q in TEST_QUESTIONS[:3]:
    print(f'  [{q["difficulty"]}] {q["question"]}')

Test set: 12 questions
  [simple] Combien de commandes au total ?
  [simple] Combien de clients différents ?
  [simple] Combien de commandes livrées ?


In [4]:
# Open the database read-only: the SQL executed on this connection is written
# by an LLM, so it must not be able to modify or drop the evaluation data.
conn = sqlite3.connect(f'{DB_PATH.resolve().as_uri()}?mode=ro', uri=True)
print('✓ Connecté à la base')

✓ Connecté à la base


## 1. Client LLM (Ollama)

In [5]:
# Ollama endpoint and model used for every generation in this notebook.
OLLAMA_URL = 'http://localhost:11434/api/generate'
MODEL = 'mistral'

def call_llm(prompt: str, temperature: float = 0.0) -> str:
    """Send a prompt to the Ollama generate endpoint and return the completion.

    Args:
        prompt: Full prompt text sent to the model.
        temperature: Sampling temperature forwarded in the request options.

    Returns:
        The stripped ``response`` field of the JSON reply (an empty string
        when the field is absent). When the request fails, the server answers
        with an HTTP error status, or the body is not valid JSON, a string
        starting with ``'ERROR: '`` is returned instead of raising.
    """
    payload = {
        'model': MODEL,
        'prompt': prompt,
        'stream': False,
        'options': {
            'temperature': temperature,
            'num_predict': 300,
        },
    }
    try:
        response = requests.post(OLLAMA_URL, json=payload, timeout=60)
        # Without this check an HTTP error reply (no 'response' field) would
        # be returned as an empty completion and the failure would go unseen.
        response.raise_for_status()
        return response.json().get('response', '').strip()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers a body that cannot be decoded as JSON.
        return f'ERROR: {e}'

# Smoke test: send one prompt through call_llm and print at most the first
# 100 characters of the string it returns (completion or 'ERROR: ...').
test_response = call_llm('Say hello in SQL style')
print(f'LLM Test: {test_response[:100]}')

LLM Test: In SQL, we don't have a direct command to say "hello", but you can create a table and insert a row w


## 2. Prompt Zero-Shot (Baseline)

In [6]:
# Minimal zero-shot prompt template: schema + question, no examples.
BASELINE_PROMPT = '''Given the following SQL schema:

{schema}

Write a SQL query to answer this question:
{question}

Return ONLY the SQL query, nothing else.'''

# Fenced block carrying a SQL-dialect tag, matched case-insensitively so that
# ```SQL and ```sqlite are handled like ```sql. An unclosed fence runs to the
# end of the text.
_TAGGED_FENCE_RE = re.compile(
    r'```(?:sqlite3?|sql|mysql|postgres(?:ql)?|plsql|tsql)\b(.*?)(?:```|\Z)',
    re.IGNORECASE | re.DOTALL,
)
# Any fenced block, used only when no tagged fence is present.
_PLAIN_FENCE_RE = re.compile(r'```(.*?)(?:```|\Z)', re.DOTALL)

def generate_sql_baseline(question: str) -> str:
    """Generate SQL for a question with the zero-shot baseline prompt.

    The prompt is built from ``BASELINE_PROMPT`` and ``SCHEMA`` and sent
    through ``call_llm``. If the reply contains a markdown code fence, only
    the contents of the first SQL-tagged fence (or, failing that, the first
    fence of any kind) are kept; otherwise the stripped reply is returned
    unchanged, including an ``'ERROR: ...'`` string from ``call_llm``.

    Args:
        question: Natural-language question to translate into SQL.

    Returns:
        The extracted SQL text, stripped of surrounding whitespace.
    """
    prompt = BASELINE_PROMPT.format(schema=SCHEMA, question=question)
    response = call_llm(prompt).strip()

    # A tagged fence wins over a plain one, as in the original precedence.
    fenced = _TAGGED_FENCE_RE.search(response) or _PLAIN_FENCE_RE.search(response)
    if fenced:
        return fenced.group(1).strip()
    return response

# Smoke test: generate SQL for one question and print the extracted query.
test_sql = generate_sql_baseline('Combien de commandes au total ?')
print(f'Generated SQL:\n{test_sql}')

Generated SQL:
SELECT COUNT(DISTINCT order_id) FROM orders;


## 3. Métriques d'évaluation

In [7]:
def normalize_sql(sql: str) -> str:
    """Normalize a SQL string for textual comparison.

    The whole text is lowercased (string literals included), every run of
    whitespace is collapsed to a single space, and trailing semicolons are
    removed together with any space that precedes them.

    Args:
        sql: Raw SQL text.

    Returns:
        The normalized SQL string.
    """
    collapsed = ' '.join(sql.lower().split())
    # Strip semicolons and spaces together: stripping ';' alone would leave
    # 'select 1 ' for 'select 1 ;' and produce a false exact-match mismatch.
    return collapsed.rstrip('; ')

def execute_sql(sql: str, conn) -> tuple:
    """Run a query and report whether it succeeded.

    Args:
        sql: SQL text to execute.
        conn: Database connection handed to ``pd.read_sql``.

    Returns:
        ``(True, DataFrame)`` when the query ran, otherwise
        ``(False, error_message)`` with the exception rendered as a string.
    """
    try:
        frame = pd.read_sql(sql, conn)
    except Exception as exc:
        # Any failure is reported to the caller as text rather than raised.
        outcome = (False, str(exc))
    else:
        outcome = (True, frame)
    return outcome

def _canonical_frame(frame: pd.DataFrame) -> pd.DataFrame:
    """Return a copy with positional column labels and rows sorted on all columns."""
    canon = frame.copy()
    # Positional labels make the comparison independent of column aliases and
    # keep sort_values usable when the query returned duplicate column names.
    canon.columns = range(canon.shape[1])
    if canon.shape[1] == 0:
        return canon.reset_index(drop=True)
    # Sorting on every column gives tied rows a deterministic order; sorting
    # on the first column alone leaves ties in arbitrary order.
    return canon.sort_values(by=list(canon.columns)).reset_index(drop=True)


def compare_results(result1: pd.DataFrame, result2: pd.DataFrame) -> bool:
    """Compare two query results, ignoring row order and column names.

    Args:
        result1: First result set.
        result2: Second result set.

    Returns:
        True when both frames have the same shape and the same rows once
        sorted on all columns; False otherwise, including when the comparison
        itself fails (for example on values that cannot be sorted).
    """
    try:
        if result1.shape != result2.shape:
            return False

        r1 = _canonical_frame(result1)
        r2 = _canonical_frame(result2)

        # `equals` treats NaNs in matching positions as equal but is strict on
        # dtypes; the list comparison accepts e.g. int 1 versus float 1.0.
        return bool(r1.equals(r2) or r1.values.tolist() == r2.values.tolist())
    except Exception:
        # Best effort: a result that cannot be compared counts as a mismatch.
        return False

def exact_match(sql1: str, sql2: str) -> bool:
    """Tell whether two SQL strings are equal once passed through normalize_sql."""
    left, right = (normalize_sql(text) for text in (sql1, sql2))
    return left == right

print('✓ Fonctions de métriques définies')

✓ Fonctions de métriques définies


## 4. Évaluation Baseline

In [8]:
def evaluate_question(question_data: dict, conn) -> dict:
    """Evaluate the baseline on a single test question.

    Args:
        question_data: Test entry providing the 'question', 'sql' (reference
            query) and 'difficulty' keys.
        conn: Database connection on which both queries are executed.

    Returns:
        A dict holding the question, its difficulty, the expected and
        generated SQL, whether the generated SQL ran, the exact-match and
        execution-accuracy flags, and the error text when the generated SQL
        failed (None otherwise).
    """
    question, expected_sql, difficulty = (
        question_data[key] for key in ('question', 'sql', 'difficulty')
    )

    generated_sql = generate_sql_baseline(question)

    # Run the model's query first, then the reference query.
    gen_ok, gen_payload = execute_sql(generated_sql, conn)
    exp_ok, exp_payload = execute_sql(expected_sql, conn)

    is_exact = exact_match(generated_sql, expected_sql)
    # Result sets are compared only when both queries actually ran.
    is_exec_correct = (
        compare_results(gen_payload, exp_payload) if (gen_ok and exp_ok) else False
    )

    return {
        'question': question,
        'difficulty': difficulty,
        'expected_sql': expected_sql,
        'generated_sql': generated_sql,
        'execution_success': gen_ok,
        'exact_match': is_exact,
        'execution_accuracy': is_exec_correct,
        'error': gen_payload if not gen_ok else None
    }

print('✓ Fonction d\'évaluation définie')

✓ Fonction d'évaluation définie


In [9]:
# Run the baseline over the whole test set, logging one status block per question.
banner = '=' * 60
n_questions = len(TEST_QUESTIONS)

print('ÉVALUATION BASELINE')
print(banner)
print(f'Model: {MODEL}')
print('Prompt: Zero-shot (schéma + question)')
print(f'Questions: {n_questions}')
print(banner)

def _mark(flag):
    """Render a boolean outcome as a check mark or a cross."""
    return '✓' if flag else '✗'

results = []
for idx, q in enumerate(TEST_QUESTIONS, start=1):
    result = evaluate_question(q, conn)
    results.append(result)

    run_mark, ex_mark, em_mark = (
        _mark(result[key])
        for key in ('execution_success', 'execution_accuracy', 'exact_match')
    )

    print(f'[{idx}/{n_questions}] [{result["difficulty"]}] {run_mark} run | {ex_mark} EX | {em_mark} EM')
    print(f'   Q: {result["question"][:50]}...')
    # The error text is shown only for queries that failed to execute.
    if not result['execution_success']:
        print(f'   Error: {result["error"][:80]}...')
    print()

ÉVALUATION BASELINE
Model: mistral
Prompt: Zero-shot (schéma + question)
Questions: 12
[1/12] [simple] ✓ run | ✓ EX | ✗ EM
   Q: Combien de commandes au total ?...

[2/12] [simple] ✓ run | ✗ EX | ✗ EM
   Q: Combien de clients différents ?...

[3/12] [simple] ✓ run | ✗ EX | ✗ EM
   Q: Combien de commandes livrées ?...

[4/12] [simple] ✓ run | ✓ EX | ✗ EM
   Q: Quel est le nombre de produits dans le catalogue ?...

[5/12] [medium] ✗ run | ✗ EX | ✗ EM
   Q: Quelles sont les 5 villes avec le plus de clients ...
   Error: Execution failed on sql 'SELECT customer_city, COUNT(customer_id) as number_of_c...

[6/12] [medium] ✓ run | ✓ EX | ✗ EM
   Q: Quel est le revenue total par catégorie de produit...

[7/12] [medium] ✓ run | ✗ EX | ✗ EM
   Q: Quelle est la note moyenne par catégorie de produi...

[8/12] [medium] ✓ run | ✗ EX | ✗ EM
   Q: Combien de commandes par méthode de paiement ?...

[9/12] [complex] ✓ run | ✗ EX | ✗ EM
   Q: Quels clients ont passé plus de 2 commandes ?...

[10/12] [com

In [10]:
# Aggregate the per-question outcomes into global counts.
df_results = pd.DataFrame(results)

total = len(df_results)
exec_success, ex_correct, em_correct = (
    df_results[column].sum()
    for column in ('execution_success', 'execution_accuracy', 'exact_match')
)

rule = '=' * 60
print('\n' + rule)
print('RÉSULTATS BASELINE')
print(rule)
# One line per metric: count, total and percentage.
for label, count in (
    ('Execution Success', exec_success),
    ('Execution Accuracy (EX)', ex_correct),
    ('Exact Match (EM)', em_correct),
):
    print(f'{label}: {count}/{total} ({count/total:.1%})')


RÉSULTATS BASELINE
Execution Success: 9/12 (75.0%)
Execution Accuracy (EX): 3/12 (25.0%)
Exact Match (EM): 0/12 (0.0%)


In [11]:
# Break EX and EM down by difficulty level.
print('\nPar difficulté:')
print('-' * 40)

for diff in ('simple', 'medium', 'complex'):
    subset = df_results.loc[df_results['difficulty'].eq(diff)]
    n = len(subset)
    if n == 0:
        # Levels absent from the results are not printed.
        continue
    ex, em = (subset[column].sum() for column in ('execution_accuracy', 'exact_match'))
    print(f'{diff.upper():8} | EX: {ex}/{n} ({ex/n:.0%}) | EM: {em}/{n} ({em/n:.0%})')


Par difficulté:
----------------------------------------
SIMPLE   | EX: 2/4 (50%) | EM: 0/4 (0%)
MEDIUM   | EX: 1/4 (25%) | EM: 0/4 (0%)
COMPLEX  | EX: 0/4 (0%) | EM: 0/4 (0%)


## 5. Analyse des erreurs

In [12]:
# List the questions whose generated SQL failed to execute.
errors = df_results.loc[~df_results['execution_success']]
print(f"Erreurs d'exécution: {len(errors)}")

rule = '=' * 60
for _, row in errors.iterrows():
    # One report per failure: question, difficulty, generated SQL, error text.
    report = [
        f'\n{rule}',
        f'Question: {row["question"]}',
        f'Difficulty: {row["difficulty"]}',
        '\nGenerated SQL:',
        str(row['generated_sql']),
        f'\nError: {row["error"]}',
    ]
    print('\n'.join(report))

Erreurs d'exécution: 3

Question: Quelles sont les 5 villes avec le plus de clients ?
Difficulty: medium

Generated SQL:
SELECT customer_city, COUNT(customer_id) as number_of_customers
FROM customers
JOIN orders ON customers.customer_id = orders.customer_id
GROUP BY customer_city
ORDER BY number_of_customers DESC
LIMIT 5;

Error: Execution failed on sql 'SELECT customer_city, COUNT(customer_id) as number_of_customers
FROM customers
JOIN orders ON customers.customer_id = orders.customer_id
GROUP BY customer_city
ORDER BY number_of_customers DESC
LIMIT 5;': ambiguous column name: customer_id

Question: Quel est le panier moyen par ville ?
Difficulty: complex

Generated SQL:
SELECT customer_city, AVG(total_price) as average_cart
FROM (
  SELECT o.customer_id, SUM(oi.price + oi.freight_value) as total_price
  FROM orders o
  JOIN order_items oi ON o.order_id = oi.order_id
  GROUP BY o.customer_id, o.customer_city
) as carts
GROUP BY customer_city;

Error: Execution failed on sql 'SELECT cu

In [13]:
# Queries that executed but returned the wrong result (valid SQL, wrong answer).
ran_ok = df_results['execution_success']
is_correct = df_results['execution_accuracy']
wrong_results = df_results[ran_ok & ~is_correct]
print(f'Résultats incorrects: {len(wrong_results)}')

rule = '=' * 60
for _, row in wrong_results.iterrows():
    # Show the reference query next to the generated one.
    print(f'\n{rule}')
    print(f'Question: {row["question"]}')
    print('\nExpected SQL:')
    print(row['expected_sql'])
    print('\nGenerated SQL:')
    print(row['generated_sql'])

Résultats incorrects: 6

Question: Combien de clients différents ?

Expected SQL:
SELECT COUNT(DISTINCT customer_id) as total FROM customers

Generated SQL:
SELECT DISTINCT customer_id FROM customers;

Question: Combien de commandes livrées ?

Expected SQL:
SELECT COUNT(*) as total FROM orders WHERE order_status = 'delivered'

Generated SQL:
SELECT COUNT(DISTINCT orders.order_id) FROM orders INNER JOIN order_items ON orders.order_id = order_items.order_id INNER JOIN payments ON orders.order_id = payments.order_id WHERE orders.order_delivered_timestamp IS NOT NULL;

Question: Quelle est la note moyenne par catégorie de produit ?

Expected SQL:
SELECT p.product_category, ROUND(AVG(r.review_score), 2) as avg_score
                  FROM reviews r
                  JOIN orders o ON r.order_id = o.order_id
                  JOIN order_items oi ON o.order_id = oi.order_id
                  JOIN products p ON oi.product_id = p.product_id
                  GROUP BY p.product_category
         

## 6. Sauvegarder les résultats

In [14]:
# Persist the global metrics, the per-difficulty breakdown and the raw details.
output_path = DATA_DIR / 'results' / 'baseline_results.json'

# Per-difficulty rates; levels with no question are left out.
by_difficulty = {}
for diff in ('simple', 'medium', 'complex'):
    subset = df_results[df_results['difficulty'] == diff]
    n = len(subset)
    if n == 0:
        continue
    by_difficulty[diff] = {
        'total': n,
        'execution_accuracy': subset['execution_accuracy'].sum() / n,
        'exact_match': subset['exact_match'].sum() / n,
    }

baseline_results = {
    'model': MODEL,
    'prompt_type': 'zero-shot',
    'total_questions': total,
    'metrics': {
        'execution_success': exec_success / total,
        'execution_accuracy': ex_correct / total,
        'exact_match': em_correct / total,
    },
    'by_difficulty': by_difficulty,
    'details': results,
}

with open(output_path, 'w') as f:
    # default=str stringifies any value the JSON encoder cannot handle natively.
    json.dump(baseline_results, f, indent=2, default=str)

print(f'✓ Résultats sauvegardés: {output_path}')

✓ Résultats sauvegardés: ../data/results/baseline_results.json


## Résumé

### Baseline établie :
- **Prompt** : Zero-shot (schéma + question)
- **Execution Accuracy** : 25,0 % (3/12)
- **Exact Match** : 0,0 % (0/12)

### Observations :
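- Même les requêtes simples ne réussissent qu'une fois sur deux (EX 2/4) : agrégat oublié, JOINs superflus
- Les JOINs complexes posent problème : aucune requête complexe n'est correcte (EX 0/4)
- Erreurs d'exécution (3/12) : noms de colonnes ambigus ou inexistants, syntaxe SQLite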

### Prochaine étape :
**Notebook 03** : Ajouter RAG pour le schema linking (trouver les bonnes tables/colonnes)

In [None]:
# Release the SQLite connection opened earlier in the notebook.
conn.close()
print('Connexion fermée.')