# 04 - Few-Shot Examples : La méthode la plus efficace

**Baseline** : 25% EX → **Objectif : 50%+**

In [1]:
import sqlite3
import pandas as pd
import json
import requests
from pathlib import Path

# Project data locations, relative to the notebook's working directory.
DATA_DIR = Path('../data')
DB_PATH = DATA_DIR / 'database' / 'ecommerce.db'

# sqlite3.connect() silently creates an empty database when the file does not
# exist, which would make every later query fail with "no such table" (and
# execute_sql() would swallow that). Fail fast with a clear message instead.
if not DB_PATH.is_file():
    raise FileNotFoundError(f'SQLite database not found: {DB_PATH}')
conn = sqlite3.connect(DB_PATH)

# Read the benchmark explicitly as UTF-8: the questions contain accented
# characters and the platform default encoding is not UTF-8 everywhere.
with open(DATA_DIR / 'results' / 'test_questions.json', 'r', encoding='utf-8') as f:
    TEST_QUESTIONS = json.load(f)

print(f'✓ {len(TEST_QUESTIONS)} questions chargées')

✓ 12 questions chargées


In [2]:
# Condensed schema description (one line per table, listing its columns).
# generate_sql() substitutes it for the {schema} placeholder of the prompt.
SIMPLE_SCHEMA = '''Tables:
- customers(customer_id, customer_city, customer_state)
- orders(order_id, customer_id, order_status, order_purchase_timestamp, order_delivered_timestamp)
- order_items(order_id, order_item_id, product_id, price, freight_value)
- products(product_id, product_category, product_weight_g)
- payments(order_id, payment_sequential, payment_type, payment_installments, payment_value)
- reviews(review_id, order_id, review_score, review_comment_title, review_comment_message)'''

print(SIMPLE_SCHEMA)

Tables:
- customers(customer_id, customer_city, customer_state)
- orders(order_id, customer_id, order_status, order_purchase_timestamp, order_delivered_timestamp)
- order_items(order_id, order_item_id, product_id, price, freight_value)
- products(product_id, product_category, product_weight_g)
- payments(order_id, payment_sequential, payment_type, payment_installments, payment_value)
- reviews(review_id, order_id, review_score, review_comment_title, review_comment_message)


In [3]:
# Five worked question -> SQL pairs shown to the model before the real question
# (substituted for the {examples} placeholder of the prompt by generate_sql()).
# They go from a plain COUNT, to a WHERE filter, a GROUP BY, a two-table join
# and a three-table join, and they use the aliases the prompt rules ask for.
# The examples are written in English.
FEW_SHOT_EXAMPLES = '''
Example 1:
Q: How many orders are there?
SQL: SELECT COUNT(*) as total FROM orders;

Example 2:
Q: How many delivered orders?
SQL: SELECT COUNT(*) as total FROM orders WHERE order_status = 'delivered';

Example 3:
Q: How many orders per status?
SQL: SELECT order_status, COUNT(*) as count FROM orders GROUP BY order_status ORDER BY count DESC;

Example 4:
Q: Total revenue by product category?
SQL: SELECT p.product_category, SUM(oi.price) as revenue FROM order_items oi JOIN products p ON oi.product_id = p.product_id GROUP BY p.product_category ORDER BY revenue DESC;

Example 5:
Q: Average review score by city?
SQL: SELECT c.customer_city, ROUND(AVG(r.review_score), 2) as avg_score FROM customers c JOIN orders o ON c.customer_id = o.customer_id JOIN reviews r ON o.order_id = r.order_id GROUP BY c.customer_city ORDER BY avg_score DESC;
'''

print('Few-shot examples définis')

Few-shot examples définis


In [4]:
# Prompt template filled with str.format() in generate_sql(). It has three
# placeholders: {schema}, {examples} and {question}. The template ends on
# "SQL:" so that the model's completion is the query itself.
# Any literal brace added to this template must be doubled ({{ }}) because of
# str.format().
FEW_SHOT_PROMPT = '''You are a SQLite expert for an e-commerce database.

{schema}

Rules:
- Use table aliases (c for customers, o for orders, oi for order_items, p for products, r for reviews)
- SQLite: use julianday() for dates, not DATEDIFF

{examples}

Q: {question}
SQL:'''

In [5]:
# Endpoint that call_llm() POSTs to (a server on localhost, port 11434;
# an Ollama instance judging by the constant's name) and the model name
# sent in the request payload.
OLLAMA_URL = 'http://localhost:11434/api/generate'
MODEL = 'mistral'

def call_llm(prompt: str) -> str:
    """Send `prompt` to the LLM endpoint and return the generated text.

    The request is a non-streaming POST to OLLAMA_URL for model MODEL, with
    generation options temperature=0 and num_predict=200, and a 60 s timeout.

    Args:
        prompt: Full prompt text to complete.

    Returns:
        The stripped 'response' field of the JSON reply ('' if the field is
        absent), or a string starting with 'ERROR: ' if the request fails for
        any reason. This function never raises; callers get the error as text.
    """
    payload = {
        'model': MODEL,
        'prompt': prompt,
        'stream': False,
        'options': {'temperature': 0, 'num_predict': 200},
    }
    try:
        response = requests.post(OLLAMA_URL, json=payload, timeout=60)
        # Route HTTP 4xx/5xx replies through the ERROR path; without this an
        # error reply lacking a 'response' field came back as an empty string.
        response.raise_for_status()
        return response.json().get('response', '').strip()
    except Exception as e:
        # Deliberate best effort: the evaluation loop must keep running.
        return f'ERROR: {e}'

def extract_sql(response: str) -> str:
    """Extract a single one-line SQL statement from a raw LLM response.

    Steps:
      1. If the response contains a Markdown code fence, keep only the content
         of the fenced block. A fence tagged sql / sqlite / sqlite3 is
         preferred and matched case-insensitively; otherwise the first
         untagged fence is used.
      2. Drop blank lines and lines that start with `--`, then join the
         remaining lines with single spaces.
      3. If a ';' is present, keep only the text up to the first one (the
         first statement) and re-append the ';'.

    Args:
        response: Raw text returned by the model.

    Returns:
        The cleaned SQL string (possibly empty).
    """
    import re  # local import: the notebook's import cell does not provide `re`

    sql = response.strip()
    # Case-insensitive tag match: a fence such as ```SQL or ```sqlite used to
    # fall through to the generic branch (or be cut after "sql"), leaving the
    # tag glued to the statement, e.g. "SQL SELECT ...".
    tagged_fence = re.search(r'```sql(?:ite3?)?', sql, flags=re.IGNORECASE)
    if tagged_fence:
        sql = sql[tagged_fence.end():].split('```')[0].strip()
    elif '```' in sql:
        sql = sql.split('```')[1].split('```')[0].strip()
    # Remove empty lines and SQL line comments, then flatten to one line.
    lines = [l for l in sql.split('\n') if l.strip() and not l.strip().startswith('--')]
    sql = ' '.join(lines)
    # Take first statement only
    if ';' in sql:
        sql = sql.split(';')[0] + ';'
    return sql

def generate_sql(question: str) -> str:
    """Generate SQL for `question` with the few-shot prompt.

    Fills FEW_SHOT_PROMPT with the schema, the worked examples and the
    question, sends it through call_llm() and returns the statement that
    extract_sql() pulls out of the raw answer.
    """
    placeholders = {
        'schema': SIMPLE_SCHEMA,
        'examples': FEW_SHOT_EXAMPLES,
        'question': question,
    }
    raw_answer = call_llm(FEW_SHOT_PROMPT.format(**placeholders))
    return extract_sql(raw_answer)

# Smoke test: run the full prompt -> LLM -> extraction pipeline on one question
# (this performs a real call to the LLM endpoint) and print the resulting SQL.
test = generate_sql('Combien de commandes au total ?')
print(f'Test: {test}')

Test: SELECT COUNT(*) as total FROM orders;


In [6]:
def execute_sql(sql: str, conn) -> tuple:
    """Run a read-only SQL query and report the outcome without raising.

    Args:
        sql: Statement to execute. Only statements starting with SELECT or
            WITH (case-insensitive, leading whitespace ignored) are run.
        conn: Database connection passed through to pandas.read_sql.

    Returns:
        (True, DataFrame) when the query ran, otherwise (False, message) where
        message is the error text or the reason the statement was refused.
    """
    statement = sql.strip() if isinstance(sql, str) else ''
    # The SQL is produced by an LLM, so treat it as untrusted. A write statement
    # (DELETE, UPDATE, DROP, ...) would run against the shared connection and
    # silently change what the following questions are compared against.
    if not statement.upper().startswith(('SELECT', 'WITH')):
        return False, f'Refused non-SELECT statement: {statement[:60]}'
    try:
        result = pd.read_sql(sql, conn)
        return True, result
    except Exception as e:
        # Best effort by design: a failing query is a scored outcome, not a crash.
        return False, str(e)

def compare_results(r1: pd.DataFrame, r2: pd.DataFrame, tolerance: float = 1.0) -> bool:
    """Decide whether two query results are equivalent for scoring purposes.

    Only the LAST column of each frame is compared, ignoring row order:
      * two empty frames are equal;
      * frames with different row counts are different;
      * otherwise the sorted string forms of the last columns must match, or,
        failing that, the sorted numeric values must pairwise differ by less
        than `tolerance`.

    Args:
        r1: First result.
        r2: Second result.
        tolerance: Absolute tolerance for the numeric fallback. The default of
            1.0 keeps the historical behaviour; note that it is loose for
            small-valued metrics (e.g. 4.1 and 4.9 compare equal).

    Returns:
        True if the results are considered equivalent, False otherwise
        (including when the comparison itself fails).
    """
    try:
        if r1.empty and r2.empty:
            return True
        if r1.shape[0] != r2.shape[0]:
            return False
        values_1 = r1.iloc[:, -1].tolist()
        values_2 = r2.iloc[:, -1].tolist()
        # Exact match on the string representation, order-insensitive.
        if sorted(str(x) for x in values_1) == sorted(str(x) for x in values_2):
            return True
        # Numeric fallback (covers e.g. 10 vs 10.0 or rounding differences).
        try:
            numbers_1 = sorted(float(x) for x in values_1)
            numbers_2 = sorted(float(x) for x in values_2)
        except (TypeError, ValueError):
            # Non-numeric values that already failed the string comparison.
            return False
        return all(abs(a - b) < tolerance for a, b in zip(numbers_1, numbers_2))
    except Exception:
        # `except Exception` rather than a bare `except:` so that
        # KeyboardInterrupt / SystemExit still stop a long evaluation run.
        return False

In [7]:
# Run every benchmark question through the few-shot pipeline, execute both the
# generated and the reference SQL, and record one result row per question.
print('ÉVALUATION FEW-SHOT')
print('='*60)

results = []
for idx, item in enumerate(TEST_QUESTIONS, start=1):
    question, expected_sql, difficulty = item['question'], item['sql'], item['difficulty']

    generated_sql = generate_sql(question)
    gen_ok, gen_res = execute_sql(generated_sql, conn)
    exp_ok, exp_res = execute_sql(expected_sql, conn)

    # Results are only comparable when both queries actually ran.
    both_ran = gen_ok and exp_ok
    ex = compare_results(gen_res, exp_res) if both_ran else False

    record = {
        'question': question,
        'difficulty': difficulty,
        'expected_sql': expected_sql,
        'generated_sql': generated_sql,
        'exec_success': gen_ok,
        'exec_accuracy': ex,
        # On failure execute_sql() returns the error text in place of a frame.
        'error': gen_res if not gen_ok else None,
    }
    results.append(record)

    # ✓ correct result, ⚠ ran but wrong result, ✗ did not run.
    if ex:
        status = '✓'
    elif gen_ok:
        status = '⚠'
    else:
        status = '✗'
    print(f'[{idx:2d}] {status} [{difficulty:7s}] {question[:40]}...')

ÉVALUATION FEW-SHOT
[ 1] ✓ [simple ] Combien de commandes au total ?...
[ 2] ✓ [simple ] Combien de clients différents ?...
[ 3] ✓ [simple ] Combien de commandes livrées ?...
[ 4] ✓ [simple ] Quel est le nombre de produits dans le c...
[ 5] ✓ [medium ] Quelles sont les 5 villes avec le plus d...
[ 6] ⚠ [medium ] Quel est le revenue total par catégorie ...
[ 7] ✓ [medium ] Quelle est la note moyenne par catégorie...
[ 8] ⚠ [medium ] Combien de commandes par méthode de paie...
[ 9] ✓ [complex] Quels clients ont passé plus de 2 comman...
[10] ⚠ [complex] Quel est le panier moyen par ville ?...
[11] ✗ [complex] Quels produits ont une note moyenne infé...
[12] ✗ [complex] Quel est le délai moyen de livraison par...


In [8]:
# Aggregate the per-question records into the headline metrics.
BASELINE_EX = 0.25  # execution accuracy of the baseline prompt this run is compared to

df = pd.DataFrame(results)
total = len(df)
success = df['exec_success'].sum()    # generated queries that executed
accuracy = df['exec_accuracy'].sum()  # generated queries whose result matched

print('\n' + '='*60)
print('RÉSULTATS FEW-SHOT')
print('='*60)
print(f'Execution Success: {success}/{total} ({success/total:.0%})')
print(f'Execution Accuracy: {accuracy}/{total} ({accuracy/total:.0%})')

print('\nPar difficulté:')
for diff in ['simple', 'medium', 'complex']:
    sub = df[df['difficulty'] == diff]
    n = len(sub)
    if n == 0:
        # No question at this level: avoid the 0/0 division (crash or nan).
        print(f'  {diff:8s}: 0/0 (n/a)')
        continue
    ex = sub['exec_accuracy'].sum()
    print(f'  {diff:8s}: {ex}/{n} ({ex/n:.0%})')

print('\n' + '='*60)
print('PROGRESSION')
print('='*60)
print(f'Baseline:  {BASELINE_EX:.0%} EX')
print(f'Few-shot:  {accuracy/total:.0%} EX')
print(f'Gain: {(accuracy/total - BASELINE_EX)*100:+.0f} points')


RÉSULTATS FEW-SHOT
Execution Success: 10/12 (83%)
Execution Accuracy: 7/12 (58%)

Par difficulté:
  simple  : 4/4 (100%)
  medium  : 2/4 (50%)
  complex : 1/4 (25%)

PROGRESSION
Baseline:  25% EX
Few-shot:  58% EX
Gain: +33 points


In [10]:
# Inspect the failures: every question whose result did not match the reference.
print('\nDÉTAIL DES ERREURS')
print('='*60)

for _, row in df.loc[~df['exec_accuracy']].iterrows():
    print(f"\nQ: {row['question']}")
    # Both queries are truncated to their first 80 characters.
    for label, column in (('Expected', 'expected_sql'), ('Generated', 'generated_sql')):
        print(f"{label}: {row[column][:80]}...")
    error = row['error']
    # 'error' is only set when the generated query failed to execute.
    if pd.notna(error) and error:
        print(f"Error: {str(error)[:60]}")


DÉTAIL DES ERREURS

Q: Quel est le revenue total par catégorie de produit ?
Expected: SELECT p.product_category, SUM(oi.price) as revenue
                  FROM order...
Generated: SELECT p.product_category, SUM(oi.price + oi.freight_value) as total_revenue FRO...

Q: Combien de commandes par méthode de paiement ?
Expected: SELECT payment_type, COUNT(DISTINCT order_id) as num_orders
                  FR...
Generated: SELECT payment_type, COUNT(*) as count FROM payments GROUP BY payment_type ORDER...

Q: Quel est le panier moyen par ville ?
Expected: SELECT c.customer_city, ROUND(AVG(oi.price), 2) as avg_basket
                  ...
Generated: SELECT c.customer_city, ROUND(AVG(oi.price + oi.freight_value), 2) as average_ba...

Q: Quels produits ont une note moyenne inférieure à 3 ?
Expected: SELECT p.product_id, p.product_category, ROUND(AVG(r.review_score), 2) as avg_sc...
Generated: SELECT p.product_id, p.product_category FROM products p JOIN (     SELECT order_...
Error: Execution f

In [11]:
# Save the run's metrics as JSON, then release the database connection.
difficulty_accuracy = {}
for level in ['simple', 'medium', 'complex']:
    level_rows = df[df['difficulty'] == level]
    difficulty_accuracy[level] = float(level_rows['exec_accuracy'].mean())

fewshot_accuracy = accuracy / total

fewshot_results = {
    'model': MODEL,
    'prompt_type': 'few-shot',
    'metrics': {
        'execution_success': success / total,
        'execution_accuracy': fewshot_accuracy
    },
    'by_difficulty': difficulty_accuracy,
    'progression': {'baseline': 0.25, 'fewshot': fewshot_accuracy}
}

output_path = DATA_DIR / 'results' / 'fewshot_results.json'
with output_path.open('w') as f:
    json.dump(fewshot_results, f, indent=2)

print('✓ Sauvegardé')
# Closing here means the evaluation cells cannot be re-run without
# re-executing the setup cell first.
conn.close()

✓ Sauvegardé
