From 1fd5848ff36c41b9dddfdeb87a26fada1ea22a2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20S=C3=A1nchez-Gallego?= Date: Wed, 13 May 2020 22:35:49 -0700 Subject: [PATCH] Simplify code by assuming both SQLA and Peewee --- CHANGELOG.rst | 1 + python/sdssdb/__init__.py | 22 +- python/sdssdb/connection.py | 523 +++++++++++++-------------- python/sdssdb/peewee/__init__.py | 13 +- python/sdssdb/sqlalchemy/__init__.py | 6 - setup.cfg | 22 +- 6 files changed, 288 insertions(+), 299 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a9be0852..bb5a691f 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -19,6 +19,7 @@ This document records the main changes to the ``sdssdb`` code. * Add `.PeeweeDatabaseConnection.get_model` to retrieve the model for a given table. * :bug:`28` Temporarily remove SQLAlchemy implementation of ``sds5db`` since it's not maintained. We may reintroduce it later once the schema is stable. * Use ``host=localhost`` when a profile is being used on its own domain. +* :support:`32` Assume that both SQLAlchemy and Peewee will be installed and simplify code. * :release:`0.3.2 <2020-03-10>` * Change ``operations-test`` profile to ``operations`` using the new machine hostname. diff --git a/python/sdssdb/__init__.py b/python/sdssdb/__init__.py index f34b4cca..2431597b 100644 --- a/python/sdssdb/__init__.py +++ b/python/sdssdb/__init__.py @@ -19,23 +19,5 @@ config = get_config(NAME, user_path='~/.sdssdb/sdssdb.yml') -try: - import peewee # noqa - _peewee = True -except ImportError: - _peewee = False - -try: - import sqlalchemy # noqa - _sqla = True -except ImportError: - _sqla = False - if _peewee is False: - raise ImportError('neither SQLAlchemy nor Peewee are installed. 
' - 'Install at least one of them to use sdssdb.') - - -if _peewee: - from .connection import PeeweeDatabaseConnection # noqa -if _sqla: - from .connection import SQLADatabaseConnection # noqa +from .connection import PeeweeDatabaseConnection # noqa +from .connection import SQLADatabaseConnection # noqa diff --git a/python/sdssdb/connection.py b/python/sdssdb/connection.py index 41b4829f..cbb86f92 100644 --- a/python/sdssdb/connection.py +++ b/python/sdssdb/connection.py @@ -14,21 +14,18 @@ import pgpasslib import six -from sdssdb import _peewee, _sqla, config, log -from sdssdb.utils.internals import get_database_columns - +from sqlalchemy import MetaData, create_engine +from sqlalchemy.engine import url +from sqlalchemy.exc import OperationalError as OpError +from sqlalchemy.orm import scoped_session, sessionmaker -if _peewee: - import peewee - from peewee import OperationalError, PostgresqlDatabase - from playhouse.postgres_ext import ArrayField - from playhouse.reflection import Introspector, UnknownField +import peewee +from peewee import OperationalError, PostgresqlDatabase +from playhouse.postgres_ext import ArrayField +from playhouse.reflection import Introspector, UnknownField -if _sqla: - from sqlalchemy import create_engine, MetaData - from sqlalchemy.engine import url - from sqlalchemy.exc import OperationalError as OpError - from sqlalchemy.orm import sessionmaker, scoped_session +from sdssdb import config, log +from sdssdb.utils.internals import get_database_columns __all__ = ['DatabaseConnection', 'PeeweeDatabaseConnection', 'SQLADatabaseConnection'] @@ -326,334 +323,328 @@ def post_connect(self): pass -if _peewee: - - class PeeweeDatabaseConnection(DatabaseConnection, PostgresqlDatabase): - """Peewee database connection implementation. - - Attributes - ---------- - models : list - Models bound to this database. Only models that are bound using - `~sdssdb.peewee.BaseModel` are handled. 
+class PeeweeDatabaseConnection(DatabaseConnection, PostgresqlDatabase): + """Peewee database connection implementation. - """ - - def __init__(self, *args, **kwargs): + Attributes + ---------- + models : list + Models bound to this database. Only models that are bound using + `~sdssdb.peewee.BaseModel` are handled. - self.models = {} - self.introspector = {} + """ - self._metadata = {} + def __init__(self, *args, **kwargs): - PostgresqlDatabase.__init__(self, None) - DatabaseConnection.__init__(self, *args, **kwargs) + self.models = {} + self.introspector = {} - @property - def connected(self): - """Reports whether the connection is active.""" + self._metadata = {} - return self.is_connection_usable() + PostgresqlDatabase.__init__(self, None) + DatabaseConnection.__init__(self, *args, **kwargs) - @property - def connection_params(self): - """Returns a dictionary with the connection parameters.""" + @property + def connected(self): + """Reports whether the connection is active.""" - if self.connected: - return self.connect_params.copy() + return self.is_connection_usable() - return None + @property + def connection_params(self): + """Returns a dictionary with the connection parameters.""" - def _conn(self, dbname, silent_on_fail=False, **params): - """Connects to the DB and tests the connection.""" + if self.connected: + return self.connect_params.copy() - if 'password' not in params: - try: - params['password'] = pgpasslib.getpass(dbname=dbname, **params) - except pgpasslib.FileNotFound: - params['password'] = None + return None - PostgresqlDatabase.init(self, dbname, **params) + def _conn(self, dbname, silent_on_fail=False, **params): + """Connects to the DB and tests the connection.""" + if 'password' not in params: try: - PostgresqlDatabase.connect(self) - self.dbname = dbname - except OperationalError as ee: - if not silent_on_fail: - log.warning(f'failed to connect to database {self.database!r}: {ee}') - PostgresqlDatabase.init(self, None) + 
params['password'] = pgpasslib.getpass(dbname=dbname, **params) + except pgpasslib.FileNotFound: + params['password'] = None + + PostgresqlDatabase.init(self, dbname, **params) + + try: + PostgresqlDatabase.connect(self) + self.dbname = dbname + except OperationalError as ee: + if not silent_on_fail: + log.warning(f'failed to connect to database {self.database!r}: {ee}') + PostgresqlDatabase.init(self, None) + + if self.is_connection_usable() and self.auto_reflect: + with self.atomic(): + for model in self.models.values(): + if getattr(model._meta, 'use_reflection', False): + if hasattr(model, 'reflect'): + model.reflect() + + if self.connected: + self.post_connect() - if self.is_connection_usable() and self.auto_reflect: - with self.atomic(): - for model in self.models.values(): - if getattr(model._meta, 'use_reflection', False): - if hasattr(model, 'reflect'): - model.reflect() - - if self.connected: - self.post_connect() - - return self.connected - - def get_model(self, table_name, schema=None): - """Returns the model for a table. - - Parameters - ---------- - table_name : str - The name of the table whose model will be returned. - schema : str - The schema for the table. If `None`, the first model that - matches the table name will be returned. - - Returns - ------- - :class:`peewee:Model` or `None` - The model associated with the table, or `None` if no model - was found. - - """ - - for model in self.models: - if schema and model._meta.schema != schema: - continue - if model._meta.table_name == table_name: - return model - - return None + return self.connected - def get_introspector(self, schema=None): - """Gets a Peewee database :class:`peewee:Introspector`.""" + def get_model(self, table_name, schema=None): + """Returns the model for a table. - schema_key = schema or '' + Parameters + ---------- + table_name : str + The name of the table whose model will be returned. + schema : str + The schema for the table. 
If `None`, the first model that + matches the table name will be returned. - if schema_key not in self.introspector: - self.introspector[schema_key] = Introspector.from_database( - self, schema=schema) + Returns + ------- + :class:`peewee:Model` or `None` + The model associated with the table, or `None` if no model + was found. - return self.introspector[schema_key] + """ - def get_fields(self, table_name, schema=None, cache=True): - """Returns a list of Peewee fields for a table.""" + for model in self.models: + if schema and model._meta.schema != schema: + continue + if model._meta.table_name == table_name: + return model - schema = schema or 'public' + return None - if schema not in self._metadata or not cache: - self._metadata[schema] = get_database_columns(self, - schema=schema) + def get_introspector(self, schema=None): + """Gets a Peewee database :class:`peewee:Introspector`.""" - if table_name not in self._metadata[schema]: - return [] + schema_key = schema or '' - table_metadata = self._metadata[schema][table_name] + if schema_key not in self.introspector: + self.introspector[schema_key] = Introspector.from_database( + self, schema=schema) - pk = table_metadata['pk'] - composite_key = pk is not None and len(pk) > 1 + return self.introspector[schema_key] - columns = table_metadata['columns'] + def get_fields(self, table_name, schema=None, cache=True): + """Returns a list of Peewee fields for a table.""" - fields = [] - for col_name, field_type, array_type, nullable in columns: + schema = schema or 'public' - is_pk = True if (pk is not None and not composite_key and - pk[0] == col_name) else False + if schema not in self._metadata or not cache: + self._metadata[schema] = get_database_columns(self, schema=schema) - params = {'column_name': col_name, - 'null': nullable, - 'primary_key': is_pk, - 'unique': is_pk} + if table_name not in self._metadata[schema]: + return [] - if array_type: - field = ArrayField(array_type, **params) - elif array_type is False and 
field_type is UnknownField: - field = peewee.BareField(**params) - else: - field = field_type(**params) + table_metadata = self._metadata[schema][table_name] - fields.append(field) + pk = table_metadata['pk'] + composite_key = pk is not None and len(pk) > 1 - return fields + columns = table_metadata['columns'] - def get_primary_keys(self, table_name, schema=None, cache=True): - """Returns the primary keys for a table.""" + fields = [] + for col_name, field_type, array_type, nullable in columns: - schema = schema or 'public' + is_pk = True if (pk is not None and not composite_key and + pk[0] == col_name) else False - if schema not in self._metadata or not cache: - self._metadata[schema] = get_database_columns(self, - schema=schema) + params = {'column_name': col_name, + 'null': nullable, + 'primary_key': is_pk, + 'unique': is_pk} - if table_name not in self._metadata[schema]: - return [] + if array_type: + field = ArrayField(array_type, **params) + elif array_type is False and field_type is UnknownField: + field = peewee.BareField(**params) else: - return self._metadata[schema][table_name]['pk'] or [] + field = field_type(**params) + fields.append(field) -if _sqla: + return fields - class SQLADatabaseConnection(DatabaseConnection): - ''' SQLAlchemy database connection implementation ''' + def get_primary_keys(self, table_name, schema=None, cache=True): + """Returns the primary keys for a table.""" - engine = None - bases = [] - Session = None - metadata = None + schema = schema or 'public' - def __init__(self, *args, **kwargs): + if schema not in self._metadata or not cache: + self._metadata[schema] = get_database_columns(self, schema=schema) - #: Reports whether the connection is active. 
- self.connected = False + if table_name not in self._metadata[schema]: + return [] + else: + return self._metadata[schema][table_name]['pk'] or [] - self._connect_params = None - DatabaseConnection.__init__(self, *args, **kwargs) - @property - def connection_params(self): - """Returns a dictionary with the connection parameters.""" +class SQLADatabaseConnection(DatabaseConnection): + ''' SQLAlchemy database connection implementation ''' - return self._connect_params + engine = None + bases = [] + Session = None + metadata = None - def _get_password(self, **params): - ''' Get a db password from a pgpass file + def __init__(self, *args, **kwargs): - Parameters: - params (dict): - A dictionary of database connection parameters + #: Reports whether the connection is active. + self.connected = False - Returns: - The database password for a given set of connection parameters + self._connect_params = None + DatabaseConnection.__init__(self, *args, **kwargs) - ''' + @property + def connection_params(self): + """Returns a dictionary with the connection parameters.""" - password = params.get('password', None) - if not password: - try: - password = pgpasslib.getpass(params['host'], params['port'], - params['database'], params['username']) - except KeyError: - raise RuntimeError('ERROR: invalid server configuration') - return password + return self._connect_params - def _make_connection_string(self, dbname, **params): - ''' Build a db connection string + def _get_password(self, **params): + ''' Get a db password from a pgpass file - Parameters: - dbname (str): - The name of the database to connect to - params (dict): - A dictionary of database connection parameters + Parameters: + params (dict): + A dictionary of database connection parameters - Returns: - A database connection string + Returns: + The database password for a given set of connection parameters - ''' + ''' - db_params = params.copy() - db_params['drivername'] = 'postgresql+psycopg2' - db_params['database'] = 
dbname - db_params['username'] = db_params.pop('user', None) - db_params['host'] = db_params.pop('host', 'localhost') - db_params['port'] = db_params.pop('port', 5432) - if db_params['username']: - db_params['password'] = self._get_password(**db_params) - db_connection_string = url.URL(**db_params) - self._connect_params = params - return db_connection_string + password = params.get('password', None) + if not password: + try: + password = pgpasslib.getpass(params['host'], params['port'], + params['database'], params['username']) + except KeyError: + raise RuntimeError('ERROR: invalid server configuration') + return password - def _conn(self, dbname, silent_on_fail=False, **params): - '''Connects to the DB and tests the connection.''' + def _make_connection_string(self, dbname, **params): + ''' Build a db connection string - # get connection string - db_connection_string = self._make_connection_string(dbname, **params) + Parameters: + dbname (str): + The name of the database to connect to + params (dict): + A dictionary of database connection parameters - try: - self.create_engine(db_connection_string, echo=False, - pool_size=10, pool_recycle=1800) - self.engine.connect() - except OpError: - if not silent_on_fail: - log.warning('Failed to connect to database {0}'.format(dbname)) - self.engine.dispose() - self.engine = None - self.connected = False - self.Session = None - self.metadata = None - else: - self.connected = True - self.dbname = dbname - self.prepare_bases() + Returns: + A database connection string - if self.connected: - self.post_connect() + ''' - return self.connected + db_params = params.copy() + db_params['drivername'] = 'postgresql+psycopg2' + db_params['database'] = dbname + db_params['username'] = db_params.pop('user', None) + db_params['host'] = db_params.pop('host', 'localhost') + db_params['port'] = db_params.pop('port', 5432) + if db_params['username']: + db_params['password'] = self._get_password(**db_params) + db_connection_string = 
url.URL(**db_params) + self._connect_params = params + return db_connection_string + + def _conn(self, dbname, silent_on_fail=False, **params): + '''Connects to the DB and tests the connection.''' + + # get connection string + db_connection_string = self._make_connection_string(dbname, **params) + + try: + self.create_engine(db_connection_string, echo=False, + pool_size=10, pool_recycle=1800) + self.engine.connect() + except OpError: + if not silent_on_fail: + log.warning('Failed to connect to database {0}'.format(dbname)) + self.engine.dispose() + self.engine = None + self.connected = False + self.Session = None + self.metadata = None + else: + self.connected = True + self.dbname = dbname + self.prepare_bases() - def reset_engine(self): - ''' Reset the engine, metadata, and session ''' + if self.connected: + self.post_connect() - self.bases = [] - if self.engine: - self.engine.dispose() - self.engine = None - self.metadata = None - self.Session.close() - self.Session = None + return self.connected - def create_engine(self, db_connection_string=None, echo=False, pool_size=10, - pool_recycle=1800, expire_on_commit=True): - ''' Create a new database engine + def reset_engine(self): + ''' Reset the engine, metadata, and session ''' - Resets and creates a new sqlalchemy database engine. Also creates and binds - engine metadata and a new scoped session. + self.bases = [] + if self.engine: + self.engine.dispose() + self.engine = None + self.metadata = None + self.Session.close() + self.Session = None - ''' + def create_engine(self, db_connection_string=None, echo=False, pool_size=10, + pool_recycle=1800, expire_on_commit=True): + ''' Create a new database engine - self.reset_engine() + Resets and creates a new sqlalchemy database engine. Also creates and binds + engine metadata and a new scoped session. 
- if not db_connection_string: - dbname = self.dbname or self.DATABASE_NAME - db_connection_string = self._make_connection_string(dbname, - **self.connection_params) + ''' - self.engine = create_engine(db_connection_string, echo=echo, pool_size=pool_size, - pool_recycle=pool_recycle) - self.metadata = MetaData(bind=self.engine) - self.Session = scoped_session(sessionmaker(bind=self.engine, autocommit=True, - expire_on_commit=expire_on_commit)) + self.reset_engine() - def add_base(self, base, prepare=True): - """Binds a base to this connection.""" + if not db_connection_string: + dbname = self.dbname or self.DATABASE_NAME + db_connection_string = self._make_connection_string(dbname, + **self.connection_params) - if base not in self.bases: - self.bases.append(base) + self.engine = create_engine(db_connection_string, echo=echo, pool_size=pool_size, + pool_recycle=pool_recycle) + self.metadata = MetaData(bind=self.engine) + self.Session = scoped_session(sessionmaker(bind=self.engine, autocommit=True, + expire_on_commit=expire_on_commit)) - if prepare and self.connected: - self.prepare_bases(base=base) + def add_base(self, base, prepare=True): + """Binds a base to this connection.""" - def prepare_bases(self, base=None): - """Prepare a Model Base + if base not in self.bases: + self.bases.append(base) - Prepares a SQLalchemy Base for reflection. This binds a database - engine to a specific Base which maps to a set of ModelClasses. - If ``base`` is passed only that base will be prepared. Otherwise, - all the bases bound to this database connection will be prepared. + if prepare and self.connected: + self.prepare_bases(base=base) - """ + def prepare_bases(self, base=None): + """Prepare a Model Base - do_bases = [base] if base else self.bases + Prepares a SQLalchemy Base for reflection. This binds a database + engine to a specific Base which maps to a set of ModelClasses. + If ``base`` is passed only that base will be prepared. 
Otherwise, + all the bases bound to this database connection will be prepared. - for base in do_bases: - base.prepare(self.engine) + """ - # If the base has an attribute _relations that's the function - # to call to set up the relationships once the engine has been - # bound to the base. - if hasattr(base, '_relations'): - if isinstance(base._relations, str): - module = importlib.import_module(base.__module__) - relations_func = getattr(module, base._relations) - relations_func() - elif callable(base._relations): - base._relations() - else: - pass + do_bases = [base] if base else self.bases + + for base in do_bases: + base.prepare(self.engine) + + # If the base has an attribute _relations that's the function + # to call to set up the relationships once the engine has been + # bound to the base. + if hasattr(base, '_relations'): + if isinstance(base._relations, str): + module = importlib.import_module(base.__module__) + relations_func = getattr(module, base._relations) + relations_func() + elif callable(base._relations): + base._relations() + else: + pass diff --git a/python/sdssdb/peewee/__init__.py b/python/sdssdb/peewee/__init__.py index 5636afb1..cb14eb4f 100644 --- a/python/sdssdb/peewee/__init__.py +++ b/python/sdssdb/peewee/__init__.py @@ -1,13 +1,14 @@ -# isort:skip_file +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# @Author: José Sánchez-Gallego (gallegoj@uw.edu) +# @Date: 2018-09-22 +# @Filename: __init__.py +# @License: BSD 3-clause (http://www.opensource.org/licenses/BSD-3-Clause) # flake8: noqa import warnings -from sdssdb import _peewee - -if _peewee is False: - raise ImportError('Peewee must be installed to use this module.') - import peewee from peewee import Model, ModelBase, fn from playhouse.hybrid import hybrid_method diff --git a/python/sdssdb/sqlalchemy/__init__.py b/python/sdssdb/sqlalchemy/__init__.py index d3273c79..0a81e8e7 100644 --- a/python/sdssdb/sqlalchemy/__init__.py +++ b/python/sdssdb/sqlalchemy/__init__.py @@ -10,12 +10,6 @@ 
from __future__ import absolute_import, division, print_function -from sdssdb import _sqla - -if _sqla is False: - raise ImportError('SQLAlchemy must be installed to use this module.') - - from sqlalchemy.ext.hybrid import hybrid_method from sqlalchemy.sql.expression import func diff --git a/setup.cfg b/setup.cfg index 14a5c1bc..fe9e7e64 100644 --- a/setup.cfg +++ b/setup.cfg @@ -74,9 +74,29 @@ docs = [isort] line_length = 79 +sections = + FUTURE + STDLIB + THIRDPARTY + SQLA + PEEWEE + SDSS + FIRSTPARTY + LOCALFOLDER +default_section = THIRDPARTY +known_first_party = + sdssdb +known_sqla = + sqlalchemy +known_peewee = + peewee + playhouse +known_sdss = + sdsstools +balanced_wrapping = true +include_trailing_comma = false lines_after_imports = 2 use_parentheses = true -balanced_wrapping = true [flake8] ignore =