In [50]:
from datasets import load_dataset

train_dataset = load_dataset("xlangai/spider", split='train')
validation_dataset = load_dataset("xlangai/spider", split='validation')
db_schema = load_dataset("richardr1126/spider-schema", split='train')

In [9]:
print(train_dataset)
print(validation_dataset)

Dataset({
    features: ['db_id', 'query', 'question', 'query_toks', 'query_toks_no_value', 'question_toks'],
    num_rows: 7000
})
Dataset({
    features: ['db_id', 'query', 'question', 'query_toks', 'query_toks_no_value', 'question_toks'],
    num_rows: 1034
})


In [34]:
print(db_schema)

Dataset({
    features: ['db_id', 'Schema (values (type))', 'Primary Keys', 'Foreign Keys'],
    num_rows: 166
})


In [15]:
train_dataset[0]

{'db_id': 'department_management',
 'query': 'SELECT count(*) FROM head WHERE age  >  56',
 'question': 'How many heads of the departments are older than 56 ?',
 'query_toks': ['SELECT',
  'count',
  '(',
  '*',
  ')',
  'FROM',
  'head',
  'WHERE',
  'age',
  '>',
  '56'],
 'query_toks_no_value': ['select',
  'count',
  '(',
  '*',
  ')',
  'from',
  'head',
  'where',
  'age',
  '>',
  'value'],
 'question_toks': ['How',
  'many',
  'heads',
  'of',
  'the',
  'departments',
  'are',
  'older',
  'than',
  '56',
  '?']}

In [40]:
filtered_data = db_schema.filter(lambda row: row['db_id'] == "department_management")
print(filtered_data[0])


{'db_id': 'department_management', 'Schema (values (type))': 'department : Department_ID (number) , Name (text) , Creation (text) , Ranking (number) , Budget_in_Billions (number) , Num_Employees (number) | head : head_ID (number) , name (text) , born_state (text) , age (number) | management : department_ID (number) , head_ID (number) , temporary_acting (text)', 'Primary Keys': 'department : Department_ID | head : head_ID | management : department_ID', 'Foreign Keys': 'management : head_ID equals head : head_ID | management : department_ID equals department : Department_ID'}


# Use llama for generation (0-shot)

In [16]:
!ollama pull llama3

[?25lpulling manifest ⠋ [?25h[?25l[2K[1Gpulling manifest ⠙ [?25h[?25l[2K[1Gpulling manifest ⠹ [?25h[?25l[2K[1Gpulling manifest ⠸ [?25h[?25l[2K[1Gpulling manifest ⠼ [?25h[?25l[2K[1Gpulling manifest ⠴ [?25h[?25l[2K[1Gpulling manifest ⠦ [?25h[?25l[2K[1Gpulling manifest ⠧ [?25h[?25l[2K[1Gpulling manifest 
pulling 6a0746a1ec1a... 100% ▕████████████████▏ 4.7 GB                         
pulling 4fa551d4f938... 100% ▕████████████████▏  12 KB                         
pulling 8ab4849b038c... 100% ▕████████████████▏  254 B                         
pulling 577073ffcc6c... 100% ▕████████████████▏  110 B                         [?25h[?25l[2K[1G[A[2K[1G[A[2K[1G[A[2K[1G[A[2K[1Gpulling manifest 
pulling 6a0746a1ec1a... 100% ▕████████████████▏ 4.7 GB                         
pulling 4fa551d4f938... 100% ▕████████████████▏  12 KB                         
pulling 8ab4849b038c... 100% ▕████████████████▏  254 B                         
pulling 577073ffcc6c...

In [18]:
!pip -qq install langchain
!pip -qq install langchain-core
!pip -qq install langchain-community

In [19]:
from langchain_community.llms import Ollama
llm = Ollama(model = "llama3.2")
answer = llm.invoke("what is the Meaning of life")
print(answer)

  llm = Ollama(model = "llama3.2")


The meaning of life is a complex and debated topic that has been explored by philosophers, theologians, scientists, and individuals from various cultures and backgrounds. There is no one definitive answer to this question, as it can vary greatly depending on individual perspectives, beliefs, and values.

Some possible answers to the meaning of life include:

1. **Happiness**: Many people believe that the meaning of life is to seek happiness and fulfillment. This can be achieved through relationships, personal growth, and pursuing one's passions.
2. **Self-actualization**: According to psychologist Abraham Maslow, the meaning of life is to realize one's full potential and become the best version of oneself.
3. **Love and connection**: For some, the meaning of life is found in loving and being loved by others, forming meaningful connections with family, friends, and community.
4. **Personal growth and self-improvement**: Some believe that the meaning of life is to continually learn, grow

## Trying to generate some SQL queries

In [None]:
nb_queries = 5
prompt = "Write the SQL query that answer the user's question. Answer only the SQL query. Question: {question}.\nSQL Query:"
classification_prompt = "Tell if these two SQL queries are giving the same result, answer yes or no only. If no, explain. Query 1: {query1}.\nQuery 2: {query2}.\nSame (correction if necessary):"

for i in range(nb_queries):
    print(f"\n--------\n")
    print(f"question: {question}")
    query1 = train_dataset[i]['query']
    question = train_dataset[i]['question']
    prompt_completed = prompt.format(question=question)
    query2 = llm.invoke(prompt_completed)
    print(f"\nAnswer: {query2}\n")
    print(f"Correct answer: {train_dataset[i]['query']}\n")

    correct = llm.invoke(classification_prompt.format(query1 = query1, query2 = query2))
    print(f"Correct: {correct}")
    



--------

question: What is the average number of employees of the departments whose rank is between 10 and 15?

Answer: SELECT COUNT(*) FROM employees WHERE age > 56 AND job_title IN ('CEO', 'Director', 'Manager')

Correct answer: SELECT count(*) FROM head WHERE age  >  56

Correct: No.

Query 1 is counting the number of people with an age greater than 56 in the "head" table, which may not be directly comparable to Query 2 because it only considers one table ("employees"). Additionally, Query 2 filters by job title, whereas Query 1 does not have any such filter.

--------

question: How many heads of the departments are older than 56 ?

Answer: SELECT name, birth_state, DATEDIFF(YEAR, birth_date, getdate()) as age FROM department_heads ORDER BY age

Correct answer: SELECT name ,  born_state ,  age FROM head ORDER BY age

Correct: No.

In Query 1, `age` is not included in the `SELECT` clause. In SQL Server, when you use an aggregate function like `MAX()` or `MIN()`, you need to includ

In [48]:
import re
from tqdm import tqdm

pattern = r'\b(yes|no)\b'
nb_queries = 100
prompt_schema = "Based on the SQL schema, write a SQL query that answer the user's question. Answer only the SQL query. Schema: {schema}.\nQuestion: {question}.\nSQL Query:"
classification_prompt = "Tell if these two SQL queries are giving the same result, answer yes or no only. Query 1: {query1}.\nQuery 2: {query2}.\nSame (correction if necessary):"
verbose = False
nb_correct = 0
list_incorrect = []

for i in tqdm(range(nb_queries), desc="Processing queries"):
    dataset_i = train_dataset[i]
    db_id = dataset_i['db_id']
    filtered_data = db_schema.filter(lambda row: row['db_id'] == db_id)
    schema = filtered_data[0]
    
    query1 = dataset_i['query']
    question = dataset_i['question']
    prompt_completed = prompt.format(question=question, schema=schema)
    query2 = llm.invoke(prompt_completed)
    correct = llm.invoke(classification_prompt.format(query1 = query1, query2 = query2))
    matches = re.findall(pattern, correct, flags=re.IGNORECASE)

    # Increment nb_yes for each "yes" found
    nb_correct += sum(1 for match in matches if match.lower() == 'yes')

    if verbose:
        print(f"\n--------")
        print(f"question: {question}")
        print(f"schema: {schema}")
        print(f"\nAnswer: {query2}\n")
        print(f"Correct answer: {dataset_i['query']}")
        print(f"Correct: {correct}")

print(f"Accuracy: {nb_correct/nb_queries}")
    


Filter: 100%|██████████| 166/166 [00:00<00:00, 27843.50 examples/s]]
Filter: 100%|██████████| 166/166 [00:00<00:00, 27048.46 examples/s]]
Processing queries: 100%|██████████| 100/100 [01:25<00:00,  1.17it/s]

Accuracy: 0.41





In [21]:
train_dataset[0]

{'db_id': 'department_management',
 'query': 'SELECT count(*) FROM head WHERE age  >  56',
 'question': 'How many heads of the departments are older than 56 ?',
 'query_toks': ['SELECT',
  'count',
  '(',
  '*',
  ')',
  'FROM',
  'head',
  'WHERE',
  'age',
  '>',
  '56'],
 'query_toks_no_value': ['select',
  'count',
  '(',
  '*',
  ')',
  'from',
  'head',
  'where',
  'age',
  '>',
  'value'],
 'question_toks': ['How',
  'many',
  'heads',
  'of',
  'the',
  'departments',
  'are',
  'older',
  'than',
  '56',
  '?']}