Skip to content

Commit

Permalink
Merge pull request #164 from alitrack/main
Browse files Browse the repository at this point in the history
DuckDB support
  • Loading branch information
zainhoda committed Jan 22, 2024
2 parents 38d7239 + aa01dd7 commit 933528b
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ docs/*.html
.ipynb_checkpoints/
.tox/
notebooks/chroma.sqlite3
dist
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ dependencies = [
postgres = ["psycopg2", "db-dtypes"]
bigquery = ["google-cloud-bigquery"]
snowflake = ["snowflake-connector-python"]
all = ["psycopg2", "db-dtypes", "google-cloud-bigquery", "snowflake-connector-python"]
duckdb = ["duckdb"]
all = ["psycopg2", "db-dtypes", "google-cloud-bigquery", "snowflake-connector-python","duckdb"]
test = ["tox"]
chromadb = ["chromadb"]
openai = ["openai"]
Expand Down
47 changes: 47 additions & 0 deletions src/vanna/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@
import warnings
from dataclasses import dataclass
from typing import Callable, List, Tuple, Union
from urllib.parse import urlparse

import pandas as pd
import plotly
Expand Down Expand Up @@ -2120,3 +2121,49 @@ def run_sql_bigquery(sql: str) -> Union[pd.DataFrame, None]:

global run_sql
run_sql = run_sql_bigquery

def connect_to_duckdb(url: str = ":memory:", init_sql: str = None) -> None:
    """
    Connect to a DuckDB database. This is just a helper function to set [`vn.run_sql`][vanna.run_sql]

    Args:
        url (str): Location of the database. `":memory:"` or `""` opens an
            in-memory database. A path that exists on disk is opened directly.
            Anything else is treated as a URL: the file is downloaded into the
            current working directory under the URL's file name (the download
            is skipped if a file with that name is already there).
            Defaults to `":memory:"`.
        init_sql (str, optional): SQL to run when connecting to the database. Defaults to None.

    Returns:
        None

    Raises:
        DependencyError: If the `duckdb` package is not installed.
        ValueError: If `url` has to be downloaded but its path contains no file name.
    """
    try:
        import duckdb
    except ImportError:
        raise DependencyError(
            "You need to install required dependencies to execute this method,"
            " run command: \npip install vanna[duckdb]"
        )

    if url in ("", ":memory:"):
        path = ":memory:"
    elif os.path.exists(url):
        # Existing local database file: open it in place.
        path = url
    else:
        # Remote database: cache it locally under the URL's file name.
        path = os.path.basename(urlparse(url).path)
        if not path:
            # Without a file name there is nowhere to save the download.
            raise ValueError(
                f"Cannot derive a local file name from DuckDB URL: {url!r}"
            )
        # Download the database only if it has not been downloaded already.
        if not os.path.exists(path):
            response = requests.get(url)
            response.raise_for_status()  # Check that the request was successful
            with open(path, "wb") as f:
                f.write(response.content)

    # Connect to the database
    conn = duckdb.connect(path)
    if init_sql:
        conn.query(init_sql)

    def run_sql_duckdb(sql: str):
        # The closure keeps `conn` alive for as long as run_sql is in use.
        return conn.query(sql).to_df()

    global run_sql
    run_sql = run_sql_duckdb
44 changes: 44 additions & 0 deletions src/vanna/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,51 @@ def run_sql_bigquery(sql: str) -> Union[pd.DataFrame, None]:

self.run_sql_is_set = True
self.run_sql = run_sql_bigquery
def connect_to_duckdb(self, url: str, init_sql: str = None):
    """
    Connect to a DuckDB database. This is just a helper function to set [`vn.run_sql`][vanna.run_sql]

    Args:
        url (str): Location of the database. `":memory:"` or `""` opens an
            in-memory database. A path that exists on disk is opened directly.
            Anything else is treated as a URL: the file is downloaded into the
            current working directory under the URL's file name (the download
            is skipped if a file with that name is already there).
        init_sql (str, optional): SQL to run when connecting to the database. Defaults to None.

    Returns:
        None

    Raises:
        DependencyError: If the `duckdb` package is not installed.
        ValueError: If `url` has to be downloaded but its path contains no file name.
    """
    try:
        import duckdb
    except ImportError:
        raise DependencyError(
            "You need to install required dependencies to execute this method,"
            " run command: \npip install vanna[duckdb]"
        )

    if url in ("", ":memory:"):
        path = ":memory:"
    elif os.path.exists(url):
        # Existing local database file: open it in place.
        path = url
    else:
        # Remote database: cache it locally under the URL's file name.
        path = os.path.basename(urlparse(url).path)
        if not path:
            # Without a file name there is nowhere to save the download.
            raise ValueError(
                f"Cannot derive a local file name from DuckDB URL: {url!r}"
            )
        # Download the database only if it has not been downloaded already.
        if not os.path.exists(path):
            response = requests.get(url)
            response.raise_for_status()  # Check that the request was successful
            with open(path, "wb") as f:
                f.write(response.content)

    # Connect to the database
    conn = duckdb.connect(path)
    if init_sql:
        conn.query(init_sql)

    def run_sql_duckdb(sql: str):
        # The closure keeps `conn` alive for as long as self.run_sql is in use.
        return conn.query(sql).to_df()

    self.run_sql = run_sql_duckdb
    self.run_sql_is_set = True
def run_sql(sql: str, **kwargs) -> pd.DataFrame:
raise NotImplementedError(
"You need to connect_to_snowflake or other database first."
Expand Down

0 comments on commit 933528b

Please sign in to comment.