In [None]:
import contextlib
import io
import os
import pickle
import re
import timeit

import matplotlib.pyplot as plt
import mysql.connector
import mysql.connector.errors
import numpy as np
from antlr4 import ParseTreeWalker, CommonTokenStream
from antlr4.InputStream import InputStream

import text2sql.cfg as cfg
from config import CONFIG
from llm.azure_client import AzureClient
from text2sql.cfg.parser import text2sqlParser, text2sqlListener, text2sqlLexer

In [None]:
def read_dataset(path: str) -> dict:
    """Parse a text2sql dataset file into parallel question/query lists.

    Lines starting with ``'Vraag: '`` hold a natural-language question and
    lines starting with ``'SQL: '`` hold its reference query; every other line
    (blank separators, etc.) is ignored.

    Args:
        path: path to the dataset file, read as UTF-8.

    Returns:
        ``{'questions': [...], 'queries': [...]}`` in file order, with the
        line prefix and trailing newline removed from every entry.

    Raises:
        ValueError: if the file holds a different number of questions and
            queries; downstream code pairs them by position, so a mismatch
            would silently misalign results and ground truth.
    """
    questions = []
    queries = []
    # Explicit UTF-8 so the Dutch text decodes the same on every platform
    # (the locale default is not UTF-8 everywhere, e.g. on Windows).
    # NOTE(review): assumes the dataset files are saved as UTF-8 — confirm.
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            # removeprefix (rather than replace) strips only the leading
            # marker, so a literal 'Vraag: ' / 'SQL: ' inside the text survives.
            if line.startswith('Vraag: '):
                questions.append(line.removeprefix('Vraag: ').rstrip('\n'))
            elif line.startswith('SQL: '):
                queries.append(line.removeprefix('SQL: ').rstrip('\n'))
    if len(questions) != len(queries):
        raise ValueError(
            f'{path}: found {len(questions)} questions but {len(queries)} queries'
        )
    return {'questions': questions, 'queries': queries}

In [None]:
# Load both splits; each is a dict of parallel 'questions' / 'queries' lists.
train, test = (
    read_dataset('./text2sql/datasets/text2sql-trainset.txt'),
    read_dataset('./text2sql/datasets/text2sql-testset.txt'),
)

In [None]:
def run_experiments(dataset: dict, method, **kwargs) -> tuple[list[str], np.ndarray]:
    """Run ``method`` on every question in ``dataset`` and capture its stdout.

    ``method`` is called as ``method(question, **kwargs)`` and must *print*
    its SQL translation. Whatever it writes to stdout is captured, stripped of
    surrounding newlines, and stored as that question's result.

    Args:
        dataset: dict with parallel ``'questions'`` and ``'queries'`` lists,
            as returned by ``read_dataset``. Only the questions are passed to
            ``method``; the queries are used to check that both lists line up.
        method: translator callable, e.g. ``rule_based`` or ``llm_based``.
        **kwargs: forwarded unchanged to ``method`` (e.g. ``client=...``).

    Returns:
        ``(results, runtime)``: the captured output per question, and a numpy
        array with each call's duration in seconds (``timeit.default_timer``),
        both in dataset order.

    Raises:
        ValueError: if ``questions`` and ``queries`` differ in length, which
            would misalign results with their ground truth later on.
    """
    questions, queries = dataset['questions'], dataset['queries']
    if len(questions) != len(queries):
        raise ValueError(
            f'dataset has {len(questions)} questions but {len(queries)} queries'
        )
    results = []
    runtime = []
    for question in questions:
        with contextlib.redirect_stdout(io.StringIO()) as captured:
            start = timeit.default_timer()
            method(question, **kwargs)
            end = timeit.default_timer()
        # The StringIO stays readable after the redirect is undone.
        results.append(captured.getvalue().strip('\n'))
        runtime.append(end - start)
    return results, np.array(runtime)

In [None]:
def rule_based(question: str, **kwargs):
    """Translate ``question`` with the ANTLR grammar from ``text2sql.cfg``.

    The question is first reduced by ``cfg.vocabulary.filter_question_words``,
    then lexed, parsed from the ``prog`` rule, and walked with a
    ``text2sqlListener``. Nothing is returned; the SQL presumably is printed by
    the listener during the walk (``run_experiments`` captures stdout).
    ``kwargs`` is accepted only so the signature matches ``llm_based``.
    """
    filtered = cfg.vocabulary.filter_question_words(question)
    token_stream = CommonTokenStream(text2sqlLexer(InputStream(filtered)))
    parse_tree = text2sqlParser(token_stream).prog()
    listener = text2sqlListener()
    ParseTreeWalker().walk(listener, parse_tree)

In [None]:
# Rule-based system on both splits: (captured SQL strings, runtimes in seconds).
results_train_rule_based, runtime_train_rule_based = run_experiments(dataset=train, method=rule_based)
results_test_rule_based, runtime_test_rule_based = run_experiments(dataset=test, method=rule_based)

In [None]:
def llm_based(question: str, **kwargs):
    """Ask the LLM client in ``kwargs['client']`` to translate ``question``.

    The question is sent as a user message appended to a copy of
    ``client.conversation`` (the stored conversation itself is not extended).
    Markdown code fences are removed from the reply, newlines become spaces,
    and the cleaned SQL is printed so ``run_experiments`` can capture it.
    """
    client = kwargs['client']
    messages = client.conversation + [{'role': 'user', 'content': question}]
    reply = client.generate(messages)
    # Remove the opening fence first so the language tag goes with it.
    for fence in ('```sql\n', '```'):
        reply = reply.replace(fence, '')
    print(reply.replace('\n', ' ').strip())

In [None]:
# Zero-shot setup: a fresh client whose conversation is seeded with one
# instruction message (task description + table schema, in Dutch) and no
# examples. llm_based sends each question appended to client.conversation.
# NOTE(review): the instruction is sent with role='user', not a 'system' role.
# NOTE(review): prompt typo "veronstschuldig" should read "verontschuldig".
# Both are left verbatim because editing the prompt changes the experiment.
client_zero_shot = AzureClient(**dict(CONFIG['AZURE'].items()))
client_zero_shot.append_conversation(role='user', content='''
                           Jij bent een expert in SQL en jouw taak is om vragen geschreven in natuurlijke taal te vertalen naar geldige SQL queries. 
                           De database waarmee je werkt bevat de volgende SQL tabel:
                           ```sql
                           CREATE TABLE IF NOT EXISTS vulnerability (
                            cve VARCHAR(255),
                            title TEXT,
                            confidence INT,
                            severity VARCHAR(50), -- een van "informational", "low", "medium", "high" or "critical"
                            cvss DECIMAL(4,2), -- een decimaal getal tussen 0 en 10
                            epss DECIMAL(4,2),
                            cwe VARCHAR(255),
                            age INT, -- leeftijd van kwetsbaarheid in dagen
                            kev BOOLEAN
                           );
                           ```
                           Als je een vraag krijgt moet je die naar een SQL query vertalen die met de bovenstaande tabel moet werken. Als de vraag onbestaande kolomnamen bevat of als je de vraag niet kan vertalen, geef dat gewoon aan.
                           Maak geen nieuwe kolommen aan en veronstschuldig je niet.
                           ''')

# The same client (and seeded conversation) is used for both splits; each
# call issues one generate request per question.
results_train_llm_based_zero_shot, runtime_train_llm_based_zero_shot = run_experiments(train, llm_based, client=client_zero_shot)
results_test_llm_based_zero_shot, runtime_test_llm_based_zero_shot = run_experiments(test, llm_based, client=client_zero_shot)

In [None]:
# Few-shot setup: the same instruction text as the zero-shot prompt, followed
# by three worked question -> SQL examples, seeded into a fresh client.
# NOTE(review): the instruction is sent with role='user', not a 'system' role,
# and contains the typo "veronstschuldig" (should be "verontschuldig").
# NOTE(review): example 1 labels its answer "Query:" while examples 2-3 use
# "SQL:"; consider making the labels consistent. Left verbatim here because
# editing the prompt changes the experiment.
# NOTE(review): the instruction block is duplicated from the zero-shot cell;
# a shared constant would keep the two prompts in sync.
client_few_shot = AzureClient(**dict(CONFIG['AZURE'].items()))
client_few_shot.append_conversation(role='user', content='''
                           Jij bent een expert in SQL en jouw taak is om vragen geschreven in natuurlijke taal te vertalen naar geldige SQL queries. 
                           De database waarmee je werkt bevat de volgende SQL tabel:
                           ```sql
                           CREATE TABLE IF NOT EXISTS vulnerability (
                            cve VARCHAR(255),
                            title TEXT,
                            confidence INT,
                            severity VARCHAR(50), -- een van "informational", "low", "medium", "high" or "critical"
                            cvss DECIMAL(4,2), -- een decimaal getal tussen 0 en 10
                            epss DECIMAL(4,2),
                            cwe VARCHAR(255),
                            age INT, -- leeftijd van kwetsbaarheid in dagen
                            kev BOOLEAN
                           );
                           ```
                           Als je een vraag krijgt moet je die naar een SQL query vertalen die met de bovenstaande tabel moet werken. Als de vraag onbestaande kolomnamen bevat of als je de vraag niet kan vertalen, geef dat gewoon aan.
                           Maak geen nieuwe kolommen aan en veronstschuldig je niet.
                           Hieronder vind je een aantal voorbeelden:
                           
                           Voorbeeld 1:
                           Vraag: Geef de titel en ernst weer van alle kwetsbaarheden.
                           Query: SELECT title, severity FROM vulnerability;
                                    
                           Voorbeeld 2:
                           Vraag: Geef de 3 oudste kwetsbaarheden weer.
                           SQL: SELECT * FROM vulnerability ORDER BY age DESC LIMIT 3;
                                    
                           Voorbeeld 3:
                           Vraag: Geef de CVE en titel weer van alle kwetsbaarheden waarvan de titel eindigt op 'Vulnerability'.
                           SQL: SELECT cve, title FROM vulnerability WHERE title LIKE '%Vulnerability';
                           ''')

# The same few-shot client is used for both splits.
results_train_llm_based_few_shot, runtime_train_llm_based_few_shot = run_experiments(train, llm_based, client=client_few_shot)
results_test_llm_based_few_shot, runtime_test_llm_based_few_shot = run_experiments(test, llm_based, client=client_few_shot)

In [None]:
def dump_pickle(obj, file: str):
    """Pickle ``obj`` to ``file``, creating missing parent directories.

    Args:
        obj: any picklable object (here: result lists and runtime arrays).
        file: destination path; an existing file is overwritten.
    """
    # Nothing else in this notebook creates the output folder
    # (e.g. 'text2sql_results/'), so a fresh checkout would otherwise fail
    # with FileNotFoundError.
    parent = os.path.dirname(file)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(file, 'wb') as f:
        pickle.dump(obj, f)

def load_pickle(file: str) -> object:
    """Load and return the object pickled in ``file``.

    In this notebook it reads back both result lists and numpy runtime arrays
    written by ``dump_pickle``, so the return type is not always a list.

    Security: ``pickle.load`` can execute arbitrary code while unpickling.
    Only use it on files this notebook produced itself, never on untrusted
    or downloaded files.
    """
    with open(file, 'rb') as fh:
        return pickle.load(fh)

In [None]:
# Persist the generated SQL for every method/split to its own pickle file.
for _obj, _path in (
    (results_train_rule_based, 'text2sql_results/results_train_rule_based.pkl'),
    (results_train_llm_based_zero_shot, 'text2sql_results/results_train_llm_based_zero_shot.pkl'),
    (results_train_llm_based_few_shot, 'text2sql_results/results_train_llm_based_few_shot.pkl'),
    (results_test_rule_based, 'text2sql_results/results_test_rule_based.pkl'),
    (results_test_llm_based_zero_shot, 'text2sql_results/results_test_llm_based_zero_shot.pkl'),
    (results_test_llm_based_few_shot, 'text2sql_results/results_test_llm_based_few_shot.pkl'),
):
    dump_pickle(_obj, _path)
del _obj, _path

In [None]:
# Persist the per-question runtime arrays for every method/split.
for _obj, _path in (
    (runtime_train_rule_based, 'text2sql_results/runtime_train_rule_based.pkl'),
    (runtime_train_llm_based_zero_shot, 'text2sql_results/runtime_train_llm_based_zero_shot.pkl'),
    (runtime_train_llm_based_few_shot, 'text2sql_results/runtime_train_llm_based_few_shot.pkl'),
    (runtime_test_rule_based, 'text2sql_results/runtime_test_rule_based.pkl'),
    (runtime_test_llm_based_zero_shot, 'text2sql_results/runtime_test_llm_based_zero_shot.pkl'),
    (runtime_test_llm_based_few_shot, 'text2sql_results/runtime_test_llm_based_few_shot.pkl'),
):
    dump_pickle(_obj, _path)
del _obj, _path

In [None]:
# Round-trip check: every pickled result list reads back unchanged.
for _expected, _path in (
    (results_train_rule_based, 'text2sql_results/results_train_rule_based.pkl'),
    (results_train_llm_based_zero_shot, 'text2sql_results/results_train_llm_based_zero_shot.pkl'),
    (results_train_llm_based_few_shot, 'text2sql_results/results_train_llm_based_few_shot.pkl'),
    (results_test_rule_based, 'text2sql_results/results_test_rule_based.pkl'),
    (results_test_llm_based_zero_shot, 'text2sql_results/results_test_llm_based_zero_shot.pkl'),
    (results_test_llm_based_few_shot, 'text2sql_results/results_test_llm_based_few_shot.pkl'),
):
    assert _expected == load_pickle(_path)
del _expected, _path

In [None]:
# Round-trip check: every pickled runtime array reads back element-wise equal.
for _expected, _path in (
    (runtime_train_rule_based, 'text2sql_results/runtime_train_rule_based.pkl'),
    (runtime_train_llm_based_zero_shot, 'text2sql_results/runtime_train_llm_based_zero_shot.pkl'),
    (runtime_train_llm_based_few_shot, 'text2sql_results/runtime_train_llm_based_few_shot.pkl'),
    (runtime_test_rule_based, 'text2sql_results/runtime_test_rule_based.pkl'),
    (runtime_test_llm_based_zero_shot, 'text2sql_results/runtime_test_llm_based_zero_shot.pkl'),
    (runtime_test_llm_based_few_shot, 'text2sql_results/runtime_test_llm_based_few_shot.pkl'),
):
    assert np.array_equal(_expected, load_pickle(_path))
del _expected, _path

# Analysis

In [None]:
def check_select_columns(hypothesis: str, reference: str) -> bool:
    """Return True if two SELECT queries differ only in their column list.

    Each query is split at ``SELECT [DISTINCT]`` and the *first* ``FROM``
    keyword into a projection (the column list; a ``DISTINCT`` modifier is not
    part of it) and the remainder starting at ``FROM``. Both parts are
    compared case- and whitespace-insensitively, like ``evaluate_outputs``.

    Returns:
        True when the projections differ but the remainders are equal; False
        otherwise, including when either string has no ``SELECT ... FROM``.
    """
    # \s+ around DISTINCT is required: the old ``SELECT(?:DISTINCT)?`` could
    # never match "SELECT DISTINCT", so DISTINCT leaked into the columns.
    # The lazy ``.*?`` stops at the first FROM, so a subquery's FROM is not
    # mistaken for the outer one.
    pattern = re.compile(
        r'\bSELECT\s+(?:DISTINCT\s+)?(?P<columns>.*?)\s*(?P<rest>\bFROM\b.*)',
        re.IGNORECASE | re.DOTALL,
    )
    matched_hypothesis = pattern.search(hypothesis)
    matched_reference = pattern.search(reference)
    if not (matched_hypothesis and matched_reference):
        return False

    def normalize(fragment: str) -> str:
        return ''.join(fragment.split()).lower()

    columns_hypothesis, rest_hypothesis = map(normalize, matched_hypothesis.group('columns', 'rest'))
    columns_reference, rest_reference = map(normalize, matched_reference.group('columns', 'rest'))
    return columns_hypothesis != columns_reference and rest_hypothesis == rest_reference

def validate_sql(query: str) -> 'mysql.connector.errors.Error | None':
    """Execute ``query`` against the database in ``CONFIG['DB']`` as a check.

    A new connection is opened per call and the query runs through a prepared
    cursor. Any result rows are fetched and discarded; they are not returned.

    Returns:
        ``None`` if the query executes, otherwise the connector's
        ``InterfaceError`` or ``ProgrammingError`` instance (callers inspect
        its ``msg`` / ``errno``).

    Raises:
        Connection failures and any other connector error propagate.
    """
    # NOTE(review): this runs untrusted (LLM-generated) SQL. Nothing is
    # committed here, so DML should be rolled back when the connection closes
    # if autocommit is off in CONFIG['DB'] — confirm; DDL would still apply.
    with mysql.connector.connect(**CONFIG['DB']) as connection:
        with connection.cursor(dictionary=True, prepared=True) as cursor:
            try:
                cursor.execute(query)
                # Drain the result set so no unread rows are left on the
                # connection; the rows themselves are not needed (the former
                # debug print of every row is removed).
                cursor.fetchall()
            except (mysql.connector.errors.InterfaceError,
                    mysql.connector.errors.ProgrammingError) as e:
                # Depending on connector implementation/version, server-side
                # errors such as syntax errors surface as either class.
                return e
    return None

In [None]:
# Ad-hoc check of a bracket pattern in LIKE; the returned error (or None) is
# displayed as this cell's output.
validate_sql(query="SELECT title FROM vulnerability WHERE title LIKE '[a-z]%';")

In [None]:
def _normalize_sql(sql: str) -> str:
    """Lower-case ``sql`` and drop all whitespace, for a lenient textual match."""
    return ''.join(sql.split()).lower()


def _classify_sql_error(sql_error) -> str:
    """Map an error returned by ``validate_sql`` to an evaluation category."""
    # Prefer the MySQL error code (1064 = parse error, 1054 = unknown column)
    # and fall back to the message prefix, since the connector may not
    # populate ``errno`` on every code path.
    errno = getattr(sql_error, 'errno', None)
    msg = getattr(sql_error, 'msg', None) or str(sql_error)
    if errno == 1064 or msg.startswith('You have an error in your SQL syntax'):
        return 'syntax_error'
    if errno == 1054 or msg.startswith('Unknown column'):
        return 'invalid_column'
    return 'other_error'


def evaluate_outputs(output: list[str], ground_truth: list[str]) -> dict:
    """Bucket each generated query by how it compares to its reference.

    Every bucket is a list of ``(hypothesis, index)`` tuples, where ``index``
    is the position in ``output`` / ``ground_truth``:

    - ``'correct'``: equal to the reference ignoring case and whitespace
      (decided without touching the database).
    - ``'different_valid'``: textually different, but ``validate_sql``
      reported no error.
    - ``'syntax_error'``: rejected by MySQL as a syntax error.
    - ``'invalid_column'``: rejected because of an unknown column.
    - ``'other_error'``: rejected for any other reason (previously such
      queries were counted as ``'different_valid'``).

    Pairs are formed positionally with ``zip``; extra items in the longer
    list are ignored.
    """
    result = {'correct': [], 'different_valid': [], 'syntax_error': [], 'invalid_column': [], 'other_error': []}
    for i, (hypothesis, reference) in enumerate(zip(output, ground_truth)):
        if _normalize_sql(hypothesis) == _normalize_sql(reference):
            result['correct'].append((hypothesis, i))
            continue
        sql_error = validate_sql(hypothesis)
        category = 'different_valid' if sql_error is None else _classify_sql_error(sql_error)
        result[category].append((hypothesis, i))
    return result

In [None]:
# Classify every generated query against the reference SQL of its split
# (evaluated in the same order as listed: train before test per method).
evaluation_rule_based_train, evaluation_rule_based_test = (
    evaluate_outputs(results_train_rule_based, train['queries']),
    evaluate_outputs(results_test_rule_based, test['queries']),
)
evaluation_zero_shot_train, evaluation_zero_shot_test = (
    evaluate_outputs(results_train_llm_based_zero_shot, train['queries']),
    evaluate_outputs(results_test_llm_based_zero_shot, test['queries']),
)
evaluation_few_shot_train, evaluation_few_shot_test = (
    evaluate_outputs(results_train_llm_based_few_shot, train['queries']),
    evaluate_outputs(results_test_llm_based_few_shot, test['queries']),
)

In [None]:
def print_statistics(evaluation_dict, title):
    """Print how many queries fell into each evaluation category.

    Args:
        evaluation_dict: mapping of category name -> list of entries, as
            returned by ``evaluate_outputs``. Every category present is
            reported, in the dict's order, followed by the overall total.
        title: label shown in the header line.
    """
    # Human-readable labels; unknown categories are printed by their key.
    # 'invalid_column' used to be skipped, so the counts did not add up.
    labels = {
        'correct': 'Correct',
        'different_valid': 'Different but valid',
        'syntax_error': 'Syntax error',
        'invalid_column': 'Invalid column',
        'other_error': 'Other SQL error',
    }
    print(f'Statistics for {title}:')
    total = 0
    for category, entries in evaluation_dict.items():
        print(f'{labels.get(category, category)}: {len(entries)}')
        total += len(entries)
    print(f'Total: {total}')
    print('*'*35)

In [None]:
# Summary table per method and split.
for _evaluation, _title in (
    (evaluation_rule_based_train, 'Rule-based train'),
    (evaluation_rule_based_test, 'Rule-based test'),
    (evaluation_zero_shot_train, 'LLM-based zero-shot train'),
    (evaluation_zero_shot_test, 'LLM-based zero-shot test'),
    (evaluation_few_shot_train, 'LLM-based few-shot train'),
    (evaluation_few_shot_test, 'LLM-based few-shot test'),
):
    print_statistics(_evaluation, _title)
del _evaluation, _title

In [None]:
# Side-by-side view of zero-shot train outputs that differ textually from the
# reference but were not flagged as syntax/column errors, for manual review.
for produced, ground_truth_index in evaluation_zero_shot_train['different_valid']:
    print(f'Ground truth: {train["queries"][ground_truth_index]}\nProduced: {produced}\n')

In [None]:
def plot_runtime(runtime: np.ndarray, title):
    """Show a histogram of per-query runtimes (seconds) under ``title``."""
    _, ax = plt.subplots(figsize=(12, 9))
    ax.hist(runtime)
    ax.set(title=title, xlabel='Runtime (seconds)', ylabel='Number of queries')
    plt.show()

In [None]:
# Runtime histograms for the rule-based and zero-shot systems.
for _runtime, _title in (
    (runtime_train_rule_based, 'Runtime of the 50 training questions on the rule-based system'),
    (runtime_test_rule_based, 'Runtime of the 50 test questions on the rule-based system'),
    (runtime_train_llm_based_zero_shot, 'Runtime of the 50 train questions on the zero-shot LLM-based system'),
    (runtime_test_llm_based_zero_shot, 'Runtime of the 50 test questions on the zero-shot LLM-based system'),
):
    plot_runtime(_runtime, _title)
del _runtime, _title

In [None]:
def scatter_plot_runtime(runtime: np.ndarray, title=None, yticks=None, fontsize: int = 16):
    """Plot per-sample runtimes as a dashed line with circle markers.

    Args:
        runtime: 1-D sequence of runtimes in seconds, one per sample; samples
            are numbered from 1 on the x axis.
        title: optional axes title.
        yticks: optional explicit y tick positions.
        fontsize: font size for the axis labels and tick labels (default 16,
            as before).
    """
    _, ax = plt.subplots(figsize=(12, 9))
    # Derive the x range from the data instead of assuming exactly 50
    # samples (np.arange(1, 51) raised for any other length).
    ax.plot(np.arange(1, len(runtime) + 1), runtime, linestyle='--', marker='o')
    ax.set_title(title)
    ax.set_xlabel('Sample number', fontsize=fontsize)
    ax.set_ylabel('Runtime (seconds)', fontsize=fontsize)
    if yticks is not None:
        ax.set_yticks(yticks)
    # tick_params restyles the existing tick labels. Re-setting them via
    # set_*ticklabels(get_*ticklabels()) pins a FixedFormatter to a
    # non-fixed locator (matplotlib warns) and can desync labels from ticks.
    ax.tick_params(axis='both', labelsize=fontsize)
    ax.grid()
    plt.show()

In [None]:
# Per-sample runtime plots; the zero-shot runs get explicit 0-45 s y ticks.
for _runtime, _yticks in (
    (runtime_train_rule_based, None),
    (runtime_train_llm_based_zero_shot, np.arange(0, 50, 5)),
    (runtime_train_llm_based_few_shot, None),
    (runtime_test_rule_based, None),
    (runtime_test_llm_based_zero_shot, np.arange(0, 50, 5)),
    (runtime_test_llm_based_few_shot, None),
):
    scatter_plot_runtime(_runtime, yticks=_yticks)
del _runtime, _yticks