
Commit

limit schemas. adding in redshift/s3 stuff
glamp committed Nov 6, 2014
1 parent e265b92 commit 410644d
Showing 3 changed files with 70 additions and 5 deletions.
57 changes: 52 additions & 5 deletions db/db.py
@@ -589,6 +589,41 @@ def __repr__(self):
def _repr_html_(self):
return self._tablify().get_html_string()

class S3(object):
def __init__(self, access_key=None, secret_key=None, profile=None):

if profile:
self.load_credentials(profile)
else:
self.access_key = access_key
self.secret_key = secret_key

def save_credentials(self, profile):
"""
Saves credentials to a dotfile
"""
home = os.path.expanduser("~")
filename = os.path.join(home, ".db.py_s3_" + profile)
creds = {
'access_key': self.access_key,
'secret_key': self.secret_key
}
with open(filename, 'wb') as f:
f.write(base64.encodestring(json.dumps(creds)))

def load_credentials(self, profile):
user = os.path.expanduser("~")
f = os.path.join(user, ".db.py_s3_" + profile)
if os.path.exists(f):
creds = json.loads(base64.decodestring(open(f, 'rb').read()))
if 'access_key' not in creds:
raise Exception("`access_key` not found in s3 profile '%s'" % profile)
self.access_key = creds['access_key']
if 'secret_key' not in creds:
raise Exception("`secret_key` not found in s3 profile '%s'" % profile)
self.secret_key = creds['secret_key']
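
A minimal usage sketch of the new S3 credentials helper, assuming it is importable from the package alongside DB; the key values and the "prod" profile name below are placeholders, not real credentials:

>>> from db import DB, S3  # assumes the package exports S3
>>> s3 = S3(access_key="MY_ACCESS_KEY", secret_key="MY_SECRET_KEY")  # placeholder keys
>>> s3.save_credentials(profile="prod")  # writes ~/.db.py_s3_prod
>>> s3 = S3(profile="prod")  # later: reload the saved keys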


class DB(object):
"""
Utility for exploring and querying a database.
@@ -612,6 +647,8 @@ class DB(object):
path to sqlite database
dbname: str
Name of the database
schemas: list
List of schemas to include. Defaults to all.
profile: str
Preconfigured database credentials / profile for how you like your queries
exclude_system_tables: bool
@@ -638,8 +675,8 @@ class DB(object):
>>> db = DB(filename="/path/to/mydb.sqlite", dbtype="sqlite")
"""
def __init__(self, username=None, password=None, hostname="localhost",
port=None, filename=None, dbname=None, dbtype=None, profile="default",
exclude_system_tables=True, limit=1000):
port=None, filename=None, dbname=None, dbtype=None, schemas=None,
profile="default", exclude_system_tables=True, limit=1000):

if port is None:
if dbtype=="postgres":
@@ -669,6 +706,7 @@ def __init__(self, username=None, password=None, hostname="localhost",
self.filename = filename
self.dbname = dbname
self.dbtype = dbtype
self.schemas = schemas
self.limit = limit

if self.dbtype is None:
@@ -746,6 +784,7 @@ def load_credentials(self, profile="default"):
self.filename = creds.get('filename')
self.dbname = creds.get('dbname')
self.dbtype = creds.get('dbtype')
self.schemas = creds.get('schemas')
self.limit = creds.get('limit')
else:
raise Exception("Credentials not configured!")
@@ -782,6 +821,7 @@ def save_credentials(self, profile="default"):
"filename": db_filename,
"dbname": self.dbname,
"dbtype": self.dbtype,
"schemas": self.schemas,
"limit": self.limit,
}
with open(f, 'wb') as credentials_file:
@@ -1108,7 +1148,9 @@ def refresh_schema(self, exclude_system_tables=True):
"""

sys.stderr.write("Refreshing schema. Please wait...")
if exclude_system_tables==True:
if self.schemas is not None and isinstance(self.schemas, list) and 'schema_specified' in self._query_templates['system']:
schemas_str = ", ".join("'%s'" % s for s in self.schemas)
q = self._query_templates['system']['schema_specified'] % schemas_str
elif exclude_system_tables==True:
q = self._query_templates['system']['schema_no_system']
else:
q = self._query_templates['system']['schema_with_system']
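
For reference, a small sketch of how a list of schemas is meant to be rendered into the in (...) clause used by the schema_specified template; the schema names here are just examples:

>>> schemas = ["public", "analytics"]
>>> ", ".join("'%s'" % s for s in schemas)
"'public', 'analytics'"

which yields a query ending in: where table_schema in ('public', 'analytics');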
@@ -1135,7 +1177,8 @@ def _try_command(self, cmd):
self.con.rollback()

def to_redshift(self, name, df, drop_if_exists=False, chunk_size=10000,
AWS_ACCESS_KEY=None, AWS_SECRET_KEY=None, print_sql=False):
AWS_ACCESS_KEY=None, AWS_SECRET_KEY=None, s3=None,
print_sql=False):
"""
Upload a dataframe to redshift via s3.
@@ -1158,6 +1201,8 @@ def to_redshift(self, name, df, drop_if_exists=False, chunk_size=10000,
AWS_SECRET_KEY: str
your aws secret key. if this is None, the function will try
and grab AWS_SECRET_KEY from your environment variables
s3: S3
an S3 object holding your AWS credentials; an alternative to passing the keys directly
print_sql: bool (False)
option for printing sql statement that will be executed
@@ -1172,6 +1217,9 @@ def to_redshift(self, name, df, drop_if_exists=False, chunk_size=10000,
except ImportError:
raise Exception("Couldn't find boto library. Please ensure it is installed")

if s3 is not None:
AWS_ACCESS_KEY = s3.access_key
AWS_SECRET_KEY = s3.secret_key
if AWS_ACCESS_KEY is None:
AWS_ACCESS_KEY = os.environ.get('AWS_ACCESS_KEY')
if AWS_SECRET_KEY is None:
@@ -1310,4 +1358,3 @@ def DemoDB():
_ROOT = os.path.abspath(os.path.dirname(__file__))
chinook = os.path.join(_ROOT, 'data', "chinook.sqlite")
return DB(filename=chinook, dbtype="sqlite")
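
Taken together, a hedged sketch of how the new schemas and s3 options might be combined; the hostname, database, schema names, table name, and profile below are illustrative, and df stands for any pandas DataFrame:

>>> db = DB(username="me", password="secret", hostname="example-cluster.redshift.amazonaws.com",
...         dbname="warehouse", dbtype="redshift", schemas=["public", "analytics"])
>>> s3 = S3(profile="prod")  # credentials saved earlier with save_credentials
>>> db.to_redshift("events_copy", df, s3=s3, drop_if_exists=True)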

9 changes: 9 additions & 0 deletions db/queries/mysql.py
@@ -31,6 +31,15 @@
from
information_schema.columns;
""",
"schema_specified": """
select
table_name
, column_name
, udt_name
from
information_schema.columns
where table_schema in (%s);
""",
"foreign_keys_for_table": """
select
column_name
9 changes: 9 additions & 0 deletions db/queries/postgres.py
@@ -31,6 +31,15 @@
from
information_schema.columns;
""",
"schema_specified": """
select
table_name
, column_name
, udt_name
from
information_schema.columns
where table_schema in (%s);
""",
"foreign_keys_for_table": """
SELECT
kcu.column_name