In [1]:
# -*- coding: utf-8 -*-
import random
import os
from dashscope import Generation
from torch import seed
from vanna.base import VannaBase
import random 

DEBUG_INFO=None

class QwenLLM(VannaBase):
  def __init__(self,config=None):
    self.model=config['model']
    self.api_key=config['api_key']
  
  def system_message(self,message: str):
    return {'role':'system','content':message}

  def user_message(self, message: str):
    return {'role':'user','content':message}

  def assistant_message(self, message: str):
    return {'role':'assistant','content':message}
  
  def submit_prompt(self,prompt,**kwargs):
    resp=Generation.call(
      model=self.model,
      messages=prompt,
      seed=random.randint(1, 10000),
      result_format='message',
      api_key=self.api_key)
    answer=resp.output.choices[0].message.content
    global DEBUG_INFO
    DEBUG_INFO=(prompt,answer)
    return answer




In [2]:
from vanna.chromadb import ChromaDB_VectorStore
class MyVanna(ChromaDB_VectorStore,QwenLLM):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self,config=config)
        QwenLLM.__init__(self,config=config)

In [None]:
vn=MyVanna({'api_key':'sk-252e6b83b1e6490f945bd740cf8c7cee','model':'qwen-max'})
print('model:',os.getenv('model'))

model: None


In [4]:
vn.connect_to_postgres(host='localhost',dbname='databasetest',user='postgres',password='123456',port=5432) # connect to postgres

In [5]:
DDL='''CREATE TABLE IF NOT EXISTS users (
    id INT PRIMARY KEY,
    name VARCHAR(100),
    age INT
);

-- 添加表注释
COMMENT ON TABLE users is
 '用户信息表';

-- 添加字段注释
COMMENT ON COLUMN users.id IS '用户ID';
COMMENT ON COLUMN users.name IS '姓名';
COMMENT ON COLUMN users.age IS '年龄';

'''


In [6]:
vn.train(ddl=DDL)

Adding ddl: CREATE TABLE IF NOT EXISTS users (
    id INT PRIMARY KEY,
    name VARCHAR(100),
    age INT
);

-- 添加表注释
COMMENT ON TABLE users is
 '用户信息表';

-- 添加字段注释
COMMENT ON COLUMN users.id IS '用户ID';
COMMENT ON COLUMN users.name IS '姓名';
COMMENT ON COLUMN users.age IS '年龄';




Add of existing embedding ID: ab0ac208-2f5e-50b0-9177-423427220940-ddl
Add of existing embedding ID: ab0ac208-2f5e-50b0-9177-423427220940-ddl
Add of existing embedding ID: 54cc1ab3-6dbb-54d1-bbd3-675304c87f9b-ddl
Delete of nonexisting embedding ID: ab0ac208-2f5e-50b0-9177-423427220940-ddl
Add of existing embedding ID: 54cc1ab3-6dbb-54d1-bbd3-675304c87f9b-ddl
Delete of nonexisting embedding ID: ab0ac208-2f5e-50b0-9177-423427220940-ddl
Add of existing embedding ID: 54cc1ab3-6dbb-54d1-bbd3-675304c87f9b-ddl
Delete of nonexisting embedding ID: ab0ac208-2f5e-50b0-9177-423427220940-ddl
Insert of existing embedding ID: 54cc1ab3-6dbb-54d1-bbd3-675304c87f9b-ddl
Add of existing embedding ID: 54cc1ab3-6dbb-54d1-bbd3-675304c87f9b-ddl


'54cc1ab3-6dbb-54d1-bbd3-675304c87f9b-ddl'

In [7]:
vn.train(documentation='"福报"是指age>=35岁，也就是可以向社会输送的人才')

Add of existing embedding ID: 8fc54ebe-8bb3-5fb7-88a6-5c98d817ed07-doc
Add of existing embedding ID: 8fc54ebe-8bb3-5fb7-88a6-5c98d817ed07-doc
Add of existing embedding ID: 8fc54ebe-8bb3-5fb7-88a6-5c98d817ed07-doc
Insert of existing embedding ID: 8fc54ebe-8bb3-5fb7-88a6-5c98d817ed07-doc
Add of existing embedding ID: 8fc54ebe-8bb3-5fb7-88a6-5c98d817ed07-doc


Adding documentation....


'8fc54ebe-8bb3-5fb7-88a6-5c98d817ed07-doc'

In [8]:
'''
1，通过LLM根据SQL构造一个question
2，按question-SQL的JSON入库
            {
                "question": question,
                "sql": sql,
            }
'''

vn.train(sql='select name from users where age between 10 and 50')

Add of existing embedding ID: d6bea7e9-0c59-54cb-aebe-d4e39689cab9-sql
Add of existing embedding ID: d60fb50b-2ff8-51d5-9ef5-f4e79c7ea4f0-sql
Add of existing embedding ID: 34bc30ff-3f11-5dcb-8b29-26a580fccb9b-sql
Add of existing embedding ID: d6bea7e9-0c59-54cb-aebe-d4e39689cab9-sql
Add of existing embedding ID: 34bc30ff-3f11-5dcb-8b29-26a580fccb9b-sql
Add of existing embedding ID: d6bea7e9-0c59-54cb-aebe-d4e39689cab9-sql
Add of existing embedding ID: 34bc30ff-3f11-5dcb-8b29-26a580fccb9b-sql
Insert of existing embedding ID: 34bc30ff-3f11-5dcb-8b29-26a580fccb9b-sql


Question generated with sql: What are the names of users whose age is between 10 and 50? 
Adding SQL...


'34bc30ff-3f11-5dcb-8b29-26a580fccb9b-sql'

In [9]:
Q,A=DEBUG_INFO
print('PROMPT:',Q[0]['content'])
print('ANSWER:',A)

PROMPT: The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question.
ANSWER: What are the names of users whose age is between 10 and 50?


In [10]:
'''
按question-SQL的JSON入库
            {
                "question": question,
                "sql": sql,
            }
'''
vn.train(question='小鱼儿的年龄',sql='select age from users where name="小鱼儿"')

Add of existing embedding ID: d6bea7e9-0c59-54cb-aebe-d4e39689cab9-sql
Insert of existing embedding ID: d6bea7e9-0c59-54cb-aebe-d4e39689cab9-sql


'd6bea7e9-0c59-54cb-aebe-d4e39689cab9-sql'

In [31]:

vn.remove_training_data(id='4d634f6e-a9c5-5673-ab4c-20964575fc2a-sql')
vn.get_training_data()

Delete of nonexisting embedding ID: 4d634f6e-a9c5-5673-ab4c-20964575fc2a-sql
Delete of nonexisting embedding ID: 4d634f6e-a9c5-5673-ab4c-20964575fc2a-sql


Unnamed: 0,id,question,content,training_data_type
0,9ccf7bcd-5091-5b97-bf72-af9d41e526a5-sql,What are the names of users whose age is betwe...,select name from user where age between 10 and 20,sql
1,ce40ba37-b3e8-5b0a-ae6e-6e9404652319-sql,What are the names of users whose age is betwe...,select name from user where age between 10 and 50,sql
2,d6bea7e9-0c59-54cb-aebe-d4e39689cab9-sql,小鱼儿的年龄,"select age from users where name=""小鱼儿""",sql
3,d60fb50b-2ff8-51d5-9ef5-f4e79c7ea4f0-sql,用户的平均年龄,select avg(age) from users,sql
4,34bc30ff-3f11-5dcb-8b29-26a580fccb9b-sql,What are the names of users whose age is betwe...,select name from users where age between 10 an...,sql
5,5c785b40-dfdd-582a-90e2-3e2f20113b6d-sql,用户的平均年龄,SELECT AVG(age) FROM users;,sql
6,3192ce10-f704-56b7-9da7-e5272bec739a-sql,打算给一批员工送福报，把他们的名字过滤出来,SELECT name FROM users WHERE age >= 35;,sql
7,7acdb5f3-93e3-52b8-a1bc-46bd2c809db6-sql,统计一下各年龄段的用户数量,"SELECT age, COUNT(*) FROM users GROUP BY age;",sql
0,54cc1ab3-6dbb-54d1-bbd3-675304c87f9b-ddl,,CREATE TABLE IF NOT EXISTS users (\n id INT...,ddl
0,8fc54ebe-8bb3-5fb7-88a6-5c98d817ed07-doc,,"""福报""是指age>=35岁，也就是可以向社会输送的人才",documentation


In [12]:
result=vn.generate_sql('用户的平均年龄')
print('SQL:',result)

Q,A=DEBUG_INFO
print('PROMPT:',Q[0]['content'])
print('ANSWER:',A)

Number of requested results 10 is greater than number of elements in index 7, updating n_results = 7
Number of requested results 10 is greater than number of elements in index 1, updating n_results = 1
Number of requested results 10 is greater than number of elements in index 1, updating n_results = 1


SQL Prompt: [{'role': 'system', 'content': 'You are a PostgreSQL expert. Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. \n===Tables \nCREATE TABLE IF NOT EXISTS users (\n    id INT PRIMARY KEY,\n    name VARCHAR(100),\n    age INT\n);\n\n-- 添加表注释\nCOMMENT ON TABLE users is\n \'用户信息表\';\n\n-- 添加字段注释\nCOMMENT ON COLUMN users.id IS \'用户ID\';\nCOMMENT ON COLUMN users.name IS \'姓名\';\nCOMMENT ON COLUMN users.age IS \'年龄\';\n\n\n\n\n===Additional Context \n\n"福报"是指age>=35岁，也就是可以向社会输送的人才\n\n===Response Guidelines \n1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment 

In [13]:
# vn.ask('用户的平均年龄')

In [14]:
result=vn.generate_sql('打算给一批员工送福报，把他们的名字过滤出来')

Number of requested results 10 is greater than number of elements in index 7, updating n_results = 7
Number of requested results 10 is greater than number of elements in index 1, updating n_results = 1
Number of requested results 10 is greater than number of elements in index 1, updating n_results = 1


SQL Prompt: [{'role': 'system', 'content': 'You are a PostgreSQL expert. Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. \n===Tables \nCREATE TABLE IF NOT EXISTS users (\n    id INT PRIMARY KEY,\n    name VARCHAR(100),\n    age INT\n);\n\n-- 添加表注释\nCOMMENT ON TABLE users is\n \'用户信息表\';\n\n-- 添加字段注释\nCOMMENT ON COLUMN users.id IS \'用户ID\';\nCOMMENT ON COLUMN users.name IS \'姓名\';\nCOMMENT ON COLUMN users.age IS \'年龄\';\n\n\n\n\n===Additional Context \n\n"福报"是指age>=35岁，也就是可以向社会输送的人才\n\n===Response Guidelines \n1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment 

In [15]:
vn.ask('打算给一批员工送福报，把他们的名字过滤出来',visualize=False)

Number of requested results 10 is greater than number of elements in index 7, updating n_results = 7
Number of requested results 10 is greater than number of elements in index 1, updating n_results = 1
Number of requested results 10 is greater than number of elements in index 1, updating n_results = 1


SQL Prompt: [{'role': 'system', 'content': 'You are a PostgreSQL expert. Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. \n===Tables \nCREATE TABLE IF NOT EXISTS users (\n    id INT PRIMARY KEY,\n    name VARCHAR(100),\n    age INT\n);\n\n-- 添加表注释\nCOMMENT ON TABLE users is\n \'用户信息表\';\n\n-- 添加字段注释\nCOMMENT ON COLUMN users.id IS \'用户ID\';\nCOMMENT ON COLUMN users.name IS \'姓名\';\nCOMMENT ON COLUMN users.age IS \'年龄\';\n\n\n\n\n===Additional Context \n\n"福报"是指age>=35岁，也就是可以向社会输送的人才\n\n===Response Guidelines \n1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment 

Add of existing embedding ID: 3192ce10-f704-56b7-9da7-e5272bec739a-sql
Insert of existing embedding ID: 3192ce10-f704-56b7-9da7-e5272bec739a-sql


LLM Response: SELECT name FROM users WHERE age >= 35;
Extracted SQL: SELECT name FROM users WHERE age >= 35;
SELECT name FROM users WHERE age >= 35;
  name
0  yhq


('SELECT name FROM users WHERE age >= 35;',
   name
 0  yhq,
 None)

In [16]:
Q,A=DEBUG_INFO
print('PROMPT:',Q[0]['content'])
print('ANSWER:',A)

PROMPT: You are a PostgreSQL expert. Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. 
===Tables 
CREATE TABLE IF NOT EXISTS users (
    id INT PRIMARY KEY,
    name VARCHAR(100),
    age INT
);

-- 添加表注释
COMMENT ON TABLE users is
 '用户信息表';

-- 添加字段注释
COMMENT ON COLUMN users.id IS '用户ID';
COMMENT ON COLUMN users.name IS '姓名';
COMMENT ON COLUMN users.age IS '年龄';




===Additional Context 

"福报"是指age>=35岁，也就是可以向社会输送的人才

===Response Guidelines 
1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. 
2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql 
3. If the provided context is insufficient, p

In [29]:
vn.generate_sql('统计一下各年龄段的用户数量,年龄段是指0-10,10-20,20-30,30-40,40-50,50-60,60-70,70-80...左闭右开区间')

Number of requested results 10 is greater than number of elements in index 7, updating n_results = 7
Number of requested results 10 is greater than number of elements in index 1, updating n_results = 1
Number of requested results 10 is greater than number of elements in index 1, updating n_results = 1


SQL Prompt: [{'role': 'system', 'content': 'You are a PostgreSQL expert. Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. \n===Tables \nCREATE TABLE IF NOT EXISTS users (\n    id INT PRIMARY KEY,\n    name VARCHAR(100),\n    age INT\n);\n\n-- 添加表注释\nCOMMENT ON TABLE users is\n \'用户信息表\';\n\n-- 添加字段注释\nCOMMENT ON COLUMN users.id IS \'用户ID\';\nCOMMENT ON COLUMN users.name IS \'姓名\';\nCOMMENT ON COLUMN users.age IS \'年龄\';\n\n\n\n\n===Additional Context \n\n"福报"是指age>=35岁，也就是可以向社会输送的人才\n\n===Response Guidelines \n1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment 

"SELECT CASE \n         WHEN age BETWEEN 0 AND 10 THEN '[0-10)'\n         WHEN age BETWEEN 11 AND 20 THEN '[11-20)'\n         WHEN age BETWEEN 21 AND 30 THEN '[21-30)'\n         WHEN age BETWEEN 31 AND 40 THEN '[31-40)'\n         WHEN age BETWEEN 41 AND 50 THEN '[41-50)'\n         WHEN age BETWEEN 51 AND 60 THEN '[51-60)'\n         WHEN age BETWEEN 61 AND 70 THEN '[61-70)'\n         WHEN age BETWEEN 71 AND 80 THEN '[71-80)'\n         ELSE 'Others'\n       END AS age_group, COUNT(*) \nFROM users \nGROUP BY age_group;"

In [34]:
vn.ask('统计一下各年龄段的用户数量',visualize=False)

Number of requested results 10 is greater than number of elements in index 9, updating n_results = 9
Number of requested results 10 is greater than number of elements in index 1, updating n_results = 1
Number of requested results 10 is greater than number of elements in index 1, updating n_results = 1


SQL Prompt: [{'role': 'system', 'content': 'You are a PostgreSQL expert. Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. \n===Tables \nCREATE TABLE IF NOT EXISTS users (\n    id INT PRIMARY KEY,\n    name VARCHAR(100),\n    age INT\n);\n\n-- 添加表注释\nCOMMENT ON TABLE users is\n \'用户信息表\';\n\n-- 添加字段注释\nCOMMENT ON COLUMN users.id IS \'用户ID\';\nCOMMENT ON COLUMN users.name IS \'姓名\';\nCOMMENT ON COLUMN users.age IS \'年龄\';\n\n\n\n\n===Additional Context \n\n"福报"是指age>=35岁，也就是可以向社会输送的人才\n\n===Response Guidelines \n1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment 

Add of existing embedding ID: 7acdb5f3-93e3-52b8-a1bc-46bd2c809db6-sql
Insert of existing embedding ID: 7acdb5f3-93e3-52b8-a1bc-46bd2c809db6-sql


LLM Response: SELECT age, COUNT(*) FROM users GROUP BY age;
Extracted SQL: SELECT age, COUNT(*) FROM users GROUP BY age;
SELECT age, COUNT(*) FROM users GROUP BY age;
   age  count
0   40      1
1   15      1
2   18      1


('SELECT age, COUNT(*) FROM users GROUP BY age;',
    age  count
 0   40      1
 1   15      1
 2   18      1,
 None)