# Generating SQL for Postgres using Ollama and ChromaDB
This notebook walks through using the `vanna` Python package to generate SQL with AI (RAG + LLMs), including connecting to a database and training.

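## Setup

The cells below assume an Ollama server is running locally (Ollama listens on http://localhost:11434 by default).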

In [13]:
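# psycopg2-binary provides psycopg2, the PostgreSQL driver vanna uses for connect_to_postgres
%pip install pandas ollama vanna chromadb psycopg2-binary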

Note: you may need to restart the kernel to use updated packages.
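After restarting the kernel, you can confirm that the packages are importable with a quick check like the one below. It is a small standard-library sketch; the helper name `missing_modules` is hypothetical. Note that vanna's `connect_to_postgres` relies on the `psycopg2` driver, which is imported as `psycopg2` whether you installed the `psycopg2` or `psycopg2-binary` package.

```python
import importlib.util

# Top-level module names to check; psycopg2 is the PostgreSQL driver used by
# vanna's connect_to_postgres (installed by psycopg2 or psycopg2-binary).
REQUIRED_MODULES = ("pandas", "ollama", "vanna", "chromadb", "psycopg2")


def missing_modules(names=REQUIRED_MODULES):
    """Return the names from `names` that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    print("All packages found" if not missing else f"Missing: {missing}")
```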


In [2]:
from vanna.ollama import Ollama
from vanna.chromadb import ChromaDB_VectorStore


In [4]:

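# Combine the ChromaDB vector store (stores and retrieves training data)
# with the Ollama LLM (generates SQL) into a single Vanna class.
class MyVanna(ChromaDB_VectorStore, Ollama):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        Ollama.__init__(self, config=config)

# 'model' selects the Ollama model used for generation
vn = MyVanna(config={'model': 'mistral'})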
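`MyVanna` gets its storage and retrieval methods from `ChromaDB_VectorStore` and its generation methods from `Ollama` through multiple inheritance. The sketch below shows the same composition pattern with two hypothetical stub classes (`VectorStoreStub`, `LLMStub`), so it runs without ChromaDB or an Ollama server; the class and method bodies are illustrative only, not vanna's actual implementation.

```python
class VectorStoreStub:
    """Hypothetical stand-in for a vector store (plays the role of ChromaDB_VectorStore)."""

    def __init__(self, config=None):
        self.documents = []

    def add_documentation(self, text):
        # Store the text and return its position as a simple ID
        self.documents.append(text)
        return len(self.documents) - 1

    def get_related_documentation(self, question):
        # Naive retrieval: return every stored document
        return list(self.documents)


class LLMStub:
    """Hypothetical stand-in for an LLM client (plays the role of Ollama)."""

    def __init__(self, config=None):
        config = config or {}
        self.model = config.get("model", "unknown")

    def submit_prompt(self, prompt):
        # A real client would call the model here
        return f"[{self.model}] would answer: {prompt}"


class MyStubVanna(VectorStoreStub, LLMStub):
    def __init__(self, config=None):
        # Initialise both parents explicitly, mirroring MyVanna
        VectorStoreStub.__init__(self, config=config)
        LLMStub.__init__(self, config=config)

    def ask(self, question):
        # Retrieval comes from one parent, generation from the other
        context = "\n".join(self.get_related_documentation(question))
        return self.submit_prompt(f"{context}\n{question}")


stub = MyStubVanna(config={"model": "mistral"})
stub.add_documentation("The film table stores films.")
print(stub.ask("How many films?"))
```

Calling each parent's `__init__` explicitly, as `MyVanna` does, guarantees that both initializers run even when the parent classes do not cooperate through `super()`.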


## Postgres Setup
- Install PostgreSQL on your machine.
- Import the dvdrental sample database: https://www.postgresqltutorial.com/postgresql-getting-started/postgresql-sample-database/
- Change the user name and password in the next cell to match your setup.

In [5]:
# Set these connection properties to match your database setup
vn.connect_to_postgres(host='localhost', dbname='dvdrental', user='postgres', password='password', port=5432)
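Hard-coding credentials is fine for a local tutorial, but you may prefer to read them from the environment. The sketch below uses the standard libpq variable names (`PGHOST`, `PGDATABASE`, `PGUSER`, `PGPASSWORD`, `PGPORT`) and falls back to the values used above; the helper `postgres_kwargs_from_env` is hypothetical and not part of vanna.

```python
import os


def postgres_kwargs_from_env(env=None):
    """Hypothetical helper: build connection kwargs from libpq-style env vars.

    Falls back to the tutorial's local defaults when a variable is unset.
    """
    env = os.environ if env is None else env
    return {
        "host": env.get("PGHOST", "localhost"),
        "dbname": env.get("PGDATABASE", "dvdrental"),
        "user": env.get("PGUSER", "postgres"),
        "password": env.get("PGPASSWORD", ""),
        "port": int(env.get("PGPORT", "5432")),
    }
```

In the notebook this could be used as `vn.connect_to_postgres(**postgres_kwargs_from_env())`.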


## Training
Training data is stored in the ChromaDB vector store, so you only need to train once. Run the training cells again only if you want to add more training data.

In [6]:

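# Gather the column definitions for every table the connection can see.
# Note: this includes system schemas such as pg_catalog; add
# "WHERE table_schema = 'public'" to restrict it to the dvdrental tables.
df_information_schema = vn.run_sql("SELECT * FROM INFORMATION_SCHEMA.COLUMNS")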

# This will break up the information schema into bite-sized chunks that can be referenced by the LLM
plan = vn.get_training_plan_generic(df_information_schema)

# train the model with the schema information
vn.train(plan=plan)
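`get_training_plan_generic` splits the information schema into per-table pieces, which show up in the training data listing as documentation entries such as "The following columns are in the pg_config table...". The sketch below illustrates that chunking idea in plain Python and also skips PostgreSQL's system schemas, which the unfiltered query above includes. It is not vanna's implementation: the row format and the exact wording of each chunk are assumptions.

```python
from collections import defaultdict

# Schemas that hold PostgreSQL's own catalog tables rather than your data.
SYSTEM_SCHEMAS = {"pg_catalog", "information_schema"}


def chunk_information_schema(rows, skip_system=True):
    """Group information_schema.columns rows into one text chunk per table.

    Each row is assumed to be a dict with table_catalog, table_schema,
    table_name, column_name and data_type keys (a subset of the real columns).
    """
    columns_by_table = defaultdict(list)
    for row in rows:
        if skip_system and row["table_schema"] in SYSTEM_SCHEMAS:
            continue
        key = (row["table_catalog"], row["table_schema"], row["table_name"])
        columns_by_table[key].append(f"{row['column_name']} ({row['data_type']})")

    # One chunk per table, sorted for a stable order
    return [
        f"The following columns are in the {table} table in the {catalog} "
        f"database (schema {schema}): " + ", ".join(columns)
        for (catalog, schema, table), columns in sorted(columns_by_table.items())
    ]


rows = [
    {"table_catalog": "dvdrental", "table_schema": "public", "table_name": "film",
     "column_name": "film_id", "data_type": "integer"},
    {"table_catalog": "dvdrental", "table_schema": "public", "table_name": "film",
     "column_name": "title", "data_type": "character varying"},
    {"table_catalog": "dvdrental", "table_schema": "pg_catalog", "table_name": "pg_config",
     "column_name": "name", "data_type": "text"},
]
for chunk in chunk_information_schema(rows):
    print(chunk)
```

With the DataFrame from the cell above, `df_information_schema.to_dict("records")` yields rows in this shape.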



In [7]:

# The following are methods for adding training data. Make sure you modify the examples to match your database.

# DDL statements are powerful because they specify table names, column names, types, and potentially relationships
vn.train(ddl="""
CREATE TABLE public.film (
	film_id serial4 NOT NULL,
	title varchar(255) NOT NULL,
	description text NULL,
	release_year public."year" NULL,
	language_id int2 NOT NULL,
	rental_duration int2 DEFAULT 3 NOT NULL,
	rental_rate numeric(4, 2) DEFAULT 4.99 NOT NULL,
	length int2 NULL,
	replacement_cost numeric(5, 2) DEFAULT 19.99 NOT NULL,
	rating public.mpaa_rating DEFAULT 'G'::mpaa_rating NULL,
	last_update timestamp DEFAULT now() NOT NULL,
	special_features _text NULL,
	fulltext tsvector NOT NULL,
	CONSTRAINT film_pkey PRIMARY KEY (film_id)
);
CREATE INDEX film_fulltext_idx ON public.film USING gist (fulltext);
CREATE INDEX idx_fk_language_id ON public.film USING btree (language_id);
CREATE INDEX idx_title ON public.film USING btree (title);


create trigger film_fulltext_trigger before
insert
    or
update
    on
    public.film for each row execute function tsvector_update_trigger('fulltext',
    'pg_catalog.english',
    'title',
    'description');
create trigger last_updated before
update
    on
    public.film for each row execute function last_updated();

ALTER TABLE public.film ADD CONSTRAINT film_language_id_fkey FOREIGN KEY (language_id) REFERENCES public."language"(language_id) ON DELETE RESTRICT ON UPDATE CASCADE;
""")

# Sometimes you may want to add documentation about your business terminology or definitions.
vn.train(documentation="The film table stores film data such as title, release year, length, rating, etc.")

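# You can also add SQL queries to your training data. This is useful if you already have some queries lying around: copy and paste them from your editor to start generating new SQL.
# If you don't pass question=..., Vanna asks the LLM to generate a matching question (see the question column in the training data below).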
vn.train(sql="SELECT * FROM public.film WHERE rental_duration >= 1")



'56b820f8-537c-5839-9f8d-012ff4a28f9e-sql'
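`vn.train` returns the ID of the entry it stored, shown above. The third group of that ID starts with `5`, the version digit of a name-based (version 5) UUID, and the suffix (`-sql`, `-ddl`, `-doc`) marks the kind of training data. This suggests the ID is derived from the content, so the same content maps to the same ID. The sketch below shows how such IDs can be built with Python's `uuid.uuid5`; the namespace and the helper `training_id` are assumptions, so it will not reproduce vanna's exact IDs.

```python
import uuid

# Arbitrary namespace chosen for this sketch; vanna's own scheme may differ.
NAMESPACE = uuid.NAMESPACE_URL


def training_id(content: str, kind: str) -> str:
    """Return a deterministic ID: a version-5 UUID of the content plus a type suffix."""
    return f"{uuid.uuid5(NAMESPACE, content)}-{kind}"


sql = "SELECT * FROM public.film WHERE rental_duration >= 1"
print(training_id(sql, "sql"))  # the same value on every run
```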

In [8]:
# At any time you can inspect what training data the package is able to reference
training_data = vn.get_training_data()
training_data

|  | id | question | content | training_data_type |
|---|---|---|---|---|
| 0 | 56b820f8-537c-5839-9f8d-012ff4a28f9e-sql | What film(s) have a rental duration of at lea... | SELECT * FROM public.film WHERE rental_duratio... | sql |
| 0 | 29047e10-cecc-56d2-88f9-63b199afd4aa-ddl |  | \nCREATE TABLE public.film (\n\tfilm_id serial... | ddl |
| 0 | 012d26b1-8219-52c1-bfb0-c9f283dda62c-doc |  | The following columns are in the pg_stat_sys_t... | documentation |
| 1 | 022c5082-e52f-5d26-b752-1d816166f3bd-doc |  | The following columns are in the element_types... | documentation |
| 2 | 02e69209-e3fe-58c0-a83e-fa80ae7536cf-doc |  | The following columns are in the pg_foreign_da... | documentation |
| ... | ... | ... | ... | ... |
| 227 | fc5c1a59-85da-5a12-be90-7ca9ab31129f-doc |  | The following columns are in the pg_statistic ... | documentation |
| 228 | fd544066-638e-557d-b35c-fda1a9e4d7ef-doc |  | The following columns are in the pg_config tab... | documentation |
| 229 | fdadba53-1bb0-55d8-ab8b-b3375dfbccd5-doc |  | The following columns are in the pg_user table... | documentation |
| 230 | fe2486f0-113a-53ea-bc3e-85db6f869871-doc |  | The following columns are in the pg_rewrite ta... | documentation |
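Many of the documentation entries listed above describe PostgreSQL system tables (pg_statistic, pg_config, pg_user, ...) that came from the unfiltered information schema query. The sketch below counts entries per `training_data_type` and collects the IDs of documentation entries about `pg_` tables. It works on plain dicts shaped like the DataFrame rows (for example from `training_data.to_dict("records")`); the helper names are hypothetical.

```python
from collections import Counter


def count_by_type(rows):
    """Count training entries per training_data_type."""
    return Counter(row["training_data_type"] for row in rows)


def system_table_doc_ids(rows, prefix="The following columns are in the pg_"):
    """Return IDs of documentation entries whose content starts with `prefix`.

    The default prefix matches the pg_* system-table chunks shown above.
    """
    return [
        row["id"]
        for row in rows
        if row["training_data_type"] == "documentation"
        and row["content"].startswith(prefix)
    ]


rows = [
    {"id": "a-sql", "content": "SELECT * FROM public.film", "training_data_type": "sql"},
    {"id": "b-ddl", "content": "CREATE TABLE public.film (...)", "training_data_type": "ddl"},
    {"id": "c-doc", "content": "The following columns are in the pg_config table ...",
     "training_data_type": "documentation"},
    {"id": "d-doc", "content": "The following columns are in the film table ...",
     "training_data_type": "documentation"},
]
print(count_by_type(rows))
print(system_table_doc_ids(rows))
```

The resulting IDs could then be passed one at a time to `vn.remove_training_data(id=...)`. The `pg_` prefix only catches pg_catalog tables; information_schema views such as `element_types` need a separate check.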


## Asking the AI
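Whenever you ask a new question, Vanna retrieves up to 10 of the most relevant pieces of each kind of training data (question/SQL pairs, DDL, and documentation) and includes them in the LLM prompt used to generate the SQL. When a collection holds fewer than 10 entries, ChromaDB lowers the number of results and prints the "Number of requested results 10 is greater than number of elements in index ..." message seen below.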
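The retrieval step can be illustrated with a toy example. ChromaDB ranks entries by embedding similarity; the sketch below substitutes bag-of-words vectors and cosine similarity for real embeddings (an assumption made to keep it dependency-free) and caps the number of results at the collection size, which is what the n_results message reports. The function names are hypothetical.

```python
import math
import re
from collections import Counter


def embed(text):
    """Toy 'embedding': lower-cased word counts (stand-in for a real embedding model)."""
    return Counter(re.findall(r"[a-z0-9_]+", text.lower()))


def cosine(a, b):
    """Cosine similarity between two word-count vectors."""
    dot = sum(a[word] * b[word] for word in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


def top_k(question, documents, n_results=10):
    """Return up to n_results documents, most similar to the question first."""
    if n_results > len(documents):
        # Mirrors the ChromaDB message: fewer elements than requested
        n_results = len(documents)
    q = embed(question)
    ranked = sorted(documents, key=lambda doc: cosine(q, embed(doc)), reverse=True)
    return ranked[:n_results]


docs = [
    "The film table stores film data such as rating",
    "The payment table stores payments",
    "The actor table lists actors",
]
print(top_k("how many film titles have rating pg-13", docs))
```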

In [9]:
vn.ask(question='how many movies are there with rating pg-13')

Number of requested results 10 is greater than number of elements in index 1, updating n_results = 1
Number of requested results 10 is greater than number of elements in index 1, updating n_results = 1


SQL Prompt: [{'role': 'system', 'content': 'You are a PostgreSQL expert. Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. \n===Tables \n\nCREATE TABLE public.film (\n\tfilm_id serial4 NOT NULL,\n\ttitle varchar(255) NOT NULL,\n\tdescription text NULL,\n\trelease_year public."year" NULL,\n\tlanguage_id int2 NOT NULL,\n\trental_duration int2 DEFAULT 3 NOT NULL,\n\trental_rate numeric(4, 2) DEFAULT 4.99 NOT NULL,\n\tlength int2 NULL,\n\treplacement_cost numeric(5, 2) DEFAULT 19.99 NOT NULL,\n\trating public.mpaa_rating DEFAULT \'G\'::mpaa_rating NULL,\n\tlast_update timestamp DEFAULT now() NOT NULL,\n\tspecial_features _text NULL,\n\tfulltext tsvector NOT NULL,\n\tCONSTRAINT film_pkey PRIMARY KEY (film_id)\n);\nCREATE INDEX film_fulltext_idx ON public.film USING gist (fulltext);\nCREATE INDEX idx_fk_language_id ON public.film USING btree (language_id);\nCREATE IND
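The prompt printed above is a list of chat messages whose system message starts with an instruction followed by a `===Tables` section holding the retrieved DDL. The sketch below assembles a message list in that shape. Because the printed prompt is truncated, everything after the `===Tables` section is an assumption: the hypothetical `===Additional Context` header for documentation, and passing example question/SQL pairs as user/assistant turns. The function `build_sql_prompt` is hypothetical, not vanna's API.

```python
def build_sql_prompt(question, ddl_list, doc_list=(), example_pairs=()):
    """Assemble a chat-style prompt: system context, few-shot examples, then the question.

    example_pairs is a sequence of (question, sql) tuples used as few-shot turns.
    """
    system = (
        "You are a PostgreSQL expert. Please help to generate a SQL query to answer "
        "the question. Your response should ONLY be based on the given context.\n"
    )
    if ddl_list:
        system += "===Tables \n" + "\n\n".join(ddl_list) + "\n"
    if doc_list:
        # Hypothetical section header for retrieved documentation
        system += "===Additional Context \n" + "\n\n".join(doc_list) + "\n"

    messages = [{"role": "system", "content": system}]
    for example_question, example_sql in example_pairs:
        messages.append({"role": "user", "content": example_question})
        messages.append({"role": "assistant", "content": example_sql})
    messages.append({"role": "user", "content": question})
    return messages


msgs = build_sql_prompt(
    "how many movies are there with rating pg-13",
    ["CREATE TABLE public.film (film_id serial4 NOT NULL);"],
    example_pairs=[(
        "What film(s) have a rental duration of at least 1?",
        "SELECT * FROM public.film WHERE rental_duration >= 1",
    )],
)
print([m["role"] for m in msgs])
```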

## Launch the User Interface

In [11]:
from vanna.flask import VannaFlaskApp
app = VannaFlaskApp(vn, allow_llm_to_see_data=True)  # lets the LLM see query results, e.g. to summarize them
app.run()

Your app is running at:
http://localhost:8084
 * Serving Flask app 'vanna.flask'
 * Debug mode: on


Number of requested results 10 is greater than number of elements in index 3, updating n_results = 3
Number of requested results 10 is greater than number of elements in index 1, updating n_results = 1
Number of requested results 10 is greater than number of elements in index 4, updating n_results = 4

## Next Steps
Using Vanna from a Jupyter notebook is a good way to get started, but you can also try these more customizable interfaces:
- [Streamlit app](https://github.com/vanna-ai/vanna-streamlit)
- [Flask app](https://github.com/vanna-ai/vanna-flask)
- [Slackbot](https://github.com/vanna-ai/vanna-slack)
