In [102]:
from llama_cpp import Llama
import json
import graphviz
import numpy as np

In [4]:
schema_path = "schemas/schema_en.sql"
schema = open(schema_path, "r").read()

In [6]:
template = open("templates/gemini_fix_template.txt", "r").read()
print(template.format(schema=schema, 
                      question="List the name and capital of companies registered in the state of 'SP' with a capital greater than 1,000,000", 
                      predicted="select * from carro",
                      error="no such table: carro"))

Instruction:
You will be provided with a user question, the schema of a database along with two sample rows from each table, the SQL code previous produced, and finally the sqlite error.
Your task is to fix the provided SQL query using a chain of thought, with the final step being the Corrected SQL query itself. This chain of thought should clearly explain the reasoning process step-by-step.

Input:
User Question: List the name and capital of companies registered in the state of 'SP' with a capital greater than 1,000,000
Schema:
CREATE TABLE company (	
	basic_cnpj TEXT,
	name TEXT,
	capital DOUBLE,
	responsible_federal_entity TEXT,
	legal_nature_code INTEGER,
	responsible_qualification_code INTEGER,
	company_size_code INTEGER,
	PRIMARY KEY (basic_cnpj),
	FOREIGN KEY(basic_cnpj) REFERENCES establishment (basic_cnpj),
	FOREIGN KEY(legal_nature_code) REFERENCES legal_nature (code),
	FOREIGN KEY(responsible_qualification_code) REFERENCES qualification (code),
	FOREIGN KEY(company_size_code

In [104]:
def softmax(x, axis=0):
    """
    Compute the softmax of vector x.
    
    Parameters:
    x (ndarray): Input array or matrix.
    
    Returns:
    ndarray: Softmax of input.
    """
    e_x = np.exp(x - np.max(x))  # subtracting max for numerical stability
    return e_x / e_x.sum(axis=axis)

In [None]:
model_path = "models/stable_code_Q8_0.gguf"
model = Llama(model_path=model_path, n_ctx=4096, logits_all=True)

In [156]:
schema_path = "schemas/schema_en.sql"
schema = open(schema_path, "r").read()

question = "List name and number of companies the are in the top 10 companies by number of establishments in the city of limas. "
messages = [
        {
            "role": "schema",
            "content": schema
        },
        {
            "role": "user",
            "content": question
        }]


In [157]:
ans = model.create_chat_completion(messages, logprobs=True, top_logprobs=5, temperature=-1)

Llama.generate: prefix-match hit

llama_print_timings:        load time =   30260.37 ms
llama_print_timings:      sample time =     316.92 ms /     1 runs   (  316.92 ms per token,     3.16 tokens per second)
llama_print_timings: prompt eval time =       0.00 ms /     0 tokens (    -nan ms per token,     -nan tokens per second)
llama_print_timings:        eval time =   11959.07 ms /    63 runs   (  189.83 ms per token,     5.27 tokens per second)
llama_print_timings:       total time =   12335.19 ms /    63 tokens


In [158]:
from pprint import pprint

In [159]:
print(ans['choices'][0]['message']['content'])

SELECT T2.name,
       COUNT (*)
FROM establishment AS T1
JOIN city AS T2 ON T1.city_code = T2.code
GROUP BY T2.name
ORDER BY COUNT (*) DESC
LIMIT 10;<|end|>


In [160]:
print(ans['choices'][0]['message']['content'])

SELECT T2.name,
       COUNT (*)
FROM establishment AS T1
JOIN city AS T2 ON T1.city_code = T2.code
GROUP BY T2.name
ORDER BY COUNT (*) DESC
LIMIT 10;<|end|>


In [161]:
pprint(ans['choices'][0]['logprobs'])

{'text_offset': [1177,
                 1177,
                 1183,
                 1185,
                 1186,
                 1187,
                 1191,
                 1192,
                 1193,
                 1200,
                 1205,
                 1208,
                 1209,
                 1210,
                 1214,
                 1228,
                 1231,
                 1233,
                 1234,
                 1235,
                 1237,
                 1239,
                 1244,
                 1247,
                 1249,
                 1250,
                 1253,
                 1255,
                 1256,
                 1257,
                 1261,
                 1262,
                 1266,
                 1268,
                 1270,
                 1271,
                 1272,
                 1276,
                 1277,
                 1282,
                 1285,
                 1287,
                 1288,
           

In [162]:
top_logprobs = ans['choices'][0]['logprobs']['top_logprobs']
tokens = ans['choices'][0]['logprobs']['tokens']

In [163]:
tokens = tokens[:15]

In [164]:
import numpy as np

In [165]:
logprobs

{'|>': 0.9999999,
 '|=': 1.2837685e-07,
 '|<': 2.3619894e-08,
 '|': 6.625713e-09,
 'inee': 2.785301e-09}

In [166]:
values = softmax(list(logprobs.values()))
logprobs = {k: v for k, v in zip(logprobs.keys(), values)}

In [167]:
dot = graphviz.Digraph(comment='Token Probabilities')

dot.attr(rankdir='same')
ith = 0
# Adicionar nós e arestas com pesos
for (i, token), logprobs in zip(enumerate(tokens[:-1]), top_logprobs[1:]):
    ith += 1
    next_token = tokens[i + 1]
    # logprobs = top_logprobs[i]
    values = softmax(list(logprobs.values()))
    logprobs = {k: v for k, v in zip(logprobs.keys(), values)}
    chosen_logprob = logprobs.get(next_token, None)
    
    length = len(logprobs)
    for j, (other_token, logprob) in enumerate(logprobs.items()):
        if other_token != next_token:
            dot.node(f'{other_token}_{i+1}', repr(other_token))
            # Adicionar aresta com outras probabilidades
            dot.edge(f'{token}_{i}', f'{other_token}_{i+1}', label=f'{logprob:.2f}', color='blue')

        if chosen_logprob is not None and length // 2  == j:
            # Adicionar aresta com a probabilidade do token escolhido
            dot.node(f'{next_token}_{i+1}', repr(next_token))
            dot.edge(f'{token}_{i}', f'{next_token}_{i+1}', label=f'{chosen_logprob:.2f}', color='red')
            chosen_logprob = None


# Renderizar o grafo e salvar em um arquivo
dot.render('token_probabilities_graph', format='png', cleanup=True)

# Mostrar o caminho do arquivo gerado
print("O grafo foi salvo como 'token_probabilities_graph.png'")


O grafo foi salvo como 'token_probabilities_graph.png'
