Skip to content

Commit

Permalink
Change EXTENSIONs so that they appear at the top level. Fixes #40.
Browse files Browse the repository at this point in the history
 * docs/extension.rst: Update derivation information.  Add
   ExtensionDict.to_map.
 * pyrseas/database.py (Database.Dicts.__init__): Initialize schemas
   first and extensions second.  (Database._link_refs): Exclude
   extensions from call to schemas.link_refs.  (Database.from_map):
   Add processing of extensions.  (Database.to_map): Add mapping of
   extensions. (Database.diff_map): Process extensions ahead of
   languages and schemas.
 * pyrseas/dbobject/extension.py (Extension.keylist): Limit it to just
   the name.  (ExtensionDict.query): Order by extname only.
   (ExtensionDict.from_map): Remove schema parameter.
   (ExtensionDict.to_map): New method.  (ExtensionDict.diff_map):
   Refactor to exclude schema as part of the key.
 * pyrseas/dbobject/language.py (LanguageDict.diff_map): Remove
   dbversion parameter and use self.dbconn.version instead.
 * pyrseas/dbobject/schema.py (Schema.to_map): Exclude extensions from
   object types list.  (SchemaDict.from_map): Exclude extensions from
   processing.  (SchemaDict.link_refs): Exclude extensions from
   parameters and subsequent processing.
 * pyrseas/testutils.py InputMapToSqlTestCase.std_map): Don't include
   schema pg_catalog.
 * tests/dbobject/test_extension.py: Change maps to move extensions at
   the same level as schemas.
  • Loading branch information
jmafc committed Jul 5, 2012
1 parent 8ec7ae0 commit 46bc738
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 61 deletions.
6 changes: 4 additions & 2 deletions docs/extension.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ Extensions
.. module:: pyrseas.dbobject.extension

The :mod:`extension` module defines two classes, :class:`Extension`
and :class:`ExtensionDict`, derived from :class:`DbSchemaObject` and
and :class:`ExtensionDict`, derived from :class:`DbObject` and
:class:`DbObjectDict`, respectively.

Extension
---------

:class:`Extension` is derived from
:class:`~pyrseas.dbobject.DbSchemaObject` and represents a `PostgreSQL
:class:`~pyrseas.dbobject.DbObject` and represents a `PostgreSQL
extension
<http://www.postgresql.org/docs/current/static/extend-extensions.html>`_.

Expand All @@ -31,4 +31,6 @@ represents the collection of extensions in a database.

.. automethod:: ExtensionDict.from_map

.. automethod:: ExtensionDict.to_map

.. automethod:: ExtensionDict.diff_map
19 changes: 11 additions & 8 deletions pyrseas/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,10 @@ def __init__(self, dbconn=None):
:param dbconn: a DbConnection object
"""
self.schemas = SchemaDict(dbconn)
self.extensions = ExtensionDict(dbconn)
self.languages = LanguageDict(dbconn)
self.casts = CastDict(dbconn)
self.schemas = SchemaDict(dbconn)
self.types = TypeDict(dbconn)
self.tables = ClassDict(dbconn)
self.columns = ColumnDict(dbconn)
Expand All @@ -99,7 +100,6 @@ def __init__(self, dbconn=None):
self.servers = ForeignServerDict(dbconn)
self.usermaps = UserMappingDict(dbconn)
self.ftables = ForeignTableDict(dbconn)
self.extensions = ExtensionDict(dbconn)
self.collations = CollationDict(dbconn)

def __init__(self, dbname, user=None, pswd=None, host=None, port=None):
Expand All @@ -120,8 +120,7 @@ def _link_refs(self, db):
db.schemas.link_refs(db.types, db.tables, db.functions, db.operators,
db.operfams, db.operclasses, db.conversions,
db.tsconfigs, db.tsdicts, db.tsparsers,
db.tstempls, db.ftables, db.extensions,
db.collations)
db.tstempls, db.ftables, db.collations)
db.tables.link_refs(db.columns, db.constraints, db.indexes,
db.rules, db.triggers)
db.fdwrappers.link_refs(db.servers)
Expand Down Expand Up @@ -176,13 +175,16 @@ def from_map(self, input_map):
"""
self.ndb = self.Dicts()
input_schemas = {}
input_extens = {}
input_langs = {}
input_casts = {}
input_fdws = {}
input_ums = {}
for key in list(input_map.keys()):
if key.startswith('schema '):
input_schemas.update({key: input_map[key]})
elif key.startswith('extension '):
input_extens.update({key: input_map[key]})
elif key.startswith('language '):
input_langs.update({key: input_map[key]})
elif key.startswith('cast '):
Expand All @@ -193,6 +195,7 @@ def from_map(self, input_map):
input_ums.update({key: input_map[key]})
else:
raise KeyError("Expected typed object, found '%s'" % key)
self.ndb.extensions.from_map(input_extens)
self.ndb.languages.from_map(input_langs)
self.ndb.schemas.from_map(input_schemas, self.ndb)
self.ndb.casts.from_map(input_casts, self.ndb)
Expand All @@ -215,7 +218,8 @@ def to_map(self, schemas=[], tables=[], exclude_schemas=[],
"""
if not self.db:
self.from_catalog()
dbmap = self.db.languages.to_map()
dbmap = self.db.extensions.to_map()
dbmap.update(self.db.languages.to_map())
dbmap.update(self.db.casts.to_map())
dbmap.update(self.db.fdwrappers.to_map())
dbmap.update(self.db.schemas.to_map())
Expand Down Expand Up @@ -276,10 +280,9 @@ def diff_map(self, input_map, schemas=[]):
if schemas:
self._trim_objects(schemas)
self.from_map(input_map)
stmts = self.db.languages.diff_map(self.ndb.languages,
self.dbconn.version)
stmts = self.db.extensions.diff_map(self.ndb.extensions)
stmts.append(self.db.languages.diff_map(self.ndb.languages))
stmts.append(self.db.schemas.diff_map(self.ndb.schemas))
stmts.append(self.db.extensions.diff_map(self.ndb.extensions))
stmts.append(self.db.types.diff_map(self.ndb.types))
stmts.append(self.db.functions.diff_map(self.ndb.functions))
stmts.append(self.db.operators.diff_map(self.ndb.operators))
Expand Down
49 changes: 30 additions & 19 deletions pyrseas/dbobject/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
pyrseas.dbobject.extension
~~~~~~~~~~~~~~~~~~~~~~~~~~
This module defines two classes: Extension derived from
DbSchemaObject, and ExtensionDict derived from DbObjectDict.
This module defines two classes: Extension derived from DbObject,
and ExtensionDict derived from DbObjectDict.
"""
from pyrseas.dbobject import DbObjectDict, DbSchemaObject, quote_id
from pyrseas.dbobject import DbObjectDict, DbObject, quote_id


class Extension(DbSchemaObject):
class Extension(DbObject):
"""An extension"""

keylist = ['schema', 'name']
keylist = ['name']
objtype = "EXTENSION"

def create(self):
Expand All @@ -39,12 +39,12 @@ class ExtensionDict(DbObjectDict):

cls = Extension
query = \
"""SELECT nspname AS schema, extname AS name, extversion AS version,
"""SELECT extname AS name, nspname AS schema, extversion AS version,
obj_description(e.oid, 'pg_extension') AS description
FROM pg_extension e
JOIN pg_namespace n ON (extnamespace = n.oid)
WHERE nspname != 'information_schema'
ORDER BY 1, 2"""
ORDER BY extname"""

def _from_catalog(self):
"""Initialize the dictionary of extensions by querying the catalogs"""
Expand All @@ -53,26 +53,37 @@ def _from_catalog(self):
for ext in self.fetch():
self[ext.key()] = ext

def from_map(self, schema, inexts):
def from_map(self, inexts):
"""Initalize the dictionary of extensions by converting the input map
:param schema: schema owning the extensions
:param inexts: YAML map defining the extensions
"""
for key in list(inexts.keys()):
if not key.startswith('extension '):
raise KeyError("Unrecognized object type: %s" % key)
ext = key[10:]
inexten = inexts[key]
self[(schema.name, ext)] = exten = Extension(
schema=schema.name, name=ext)
self[ext] = exten = Extension(name=ext)
for attr, val in list(inexten.items()):
setattr(exten, attr, val)
if 'oldname' in inexten:
exten.oldname = inexten['oldname']
if 'description' in inexten:
exten.description = inexten['description']

def to_map(self):
"""Convert the extension dictionary to a regular dictionary
:return: dictionary
Invokes the `to_map` method of each extension to construct a
dictionary of extensions.
"""
extens = {}
for ext in list(self.keys()):
extens.update(self[ext].to_map())
return extens

def diff_map(self, inexts):
"""Generate SQL to transform existing extensions
Expand All @@ -85,24 +96,24 @@ def diff_map(self, inexts):
"""
stmts = []
# check input extensions
for (sch, ext) in list(inexts.keys()):
inexten = inexts[(sch, ext)]
for ext in list(inexts.keys()):
inexten = inexts[ext]
# does it exist in the database?
if (sch, ext) not in self:
if ext not in self:
if not hasattr(inexten, 'oldname'):
# create new extension
stmts.append(inexten.create())
else:
stmts.append(self[(sch, ext)].rename(inexten))
stmts.append(self[ext].rename(inexten))
else:
# check extension objects
stmts.append(self[(sch, ext)].diff_map(inexten))
stmts.append(self[ext].diff_map(inexten))

# check existing extensions
for (sch, ext) in list(self.keys()):
exten = self[(sch, ext)]
for ext in list(self.keys()):
exten = self[ext]
# if missing, drop them
if (sch, ext) not in inexts:
if ext not in inexts:
stmts.append(exten.drop())

return stmts
Expand Down
6 changes: 3 additions & 3 deletions pyrseas/dbobject/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,10 @@ def to_map(self):
languages.update(self[lng].to_map())
return languages

def diff_map(self, inlanguages, dbversion):
def diff_map(self, inlanguages):
"""Generate SQL to transform existing languages
:param input_map: a YAML map defining the new languages
:param dbversion: DBMS version number
:return: list of SQL statements
Compares the existing language definitions, as fetched from the
Expand Down Expand Up @@ -143,7 +142,8 @@ def diff_map(self, inlanguages, dbversion):
# if missing, drop it
if lng not in inlanguages:
# special case: plpgsql is installed in 9.0
if dbversion >= 90000 and self[lng].name == 'plpgsql':
if self.dbconn.version >= 90000 \
and self[lng].name == 'plpgsql':
continue
self[lng].dropped = True
return stmts
Expand Down
16 changes: 2 additions & 14 deletions pyrseas/dbobject/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def mapper(schema, objtypes):
for objtypes in ['conversions', 'domains', 'ftables', 'functions',
'operators', 'operclasses', 'operfams', 'sequences',
'tsconfigs', 'tsdicts', 'tsparsers', 'tstempls',
'types', 'views', 'extensions', 'collations']:
'types', 'views', 'collations']:
schema[key].update(mapper(self, objtypes))

if hasattr(self, 'description'):
Expand Down Expand Up @@ -104,7 +104,6 @@ def from_map(self, inmap, newdb):
intsps = {}
intsts = {}
inftbs = {}
inexts = {}
incolls = {}
for key in list(inschema.keys()):
if key.startswith('domain '):
Expand Down Expand Up @@ -135,8 +134,6 @@ def from_map(self, inmap, newdb):
intscs.update({key: inschema[key]})
elif key.startswith('foreign table '):
inftbs.update({key: inschema[key]})
elif key.startswith('extension '):
inexts.update({key: inschema[key]})
elif key.startswith('collation '):
incolls.update({key: inschema[key]})
elif key == 'oldname':
Expand All @@ -145,7 +142,6 @@ def from_map(self, inmap, newdb):
schema.description = inschema[key]
else:
raise KeyError("Expected typed object, found '%s'" % key)
newdb.extensions.from_map(schema, inexts)
newdb.types.from_map(schema, intypes, newdb)
newdb.tables.from_map(schema, intables, newdb)
newdb.functions.from_map(schema, infuncs)
Expand All @@ -162,7 +158,7 @@ def from_map(self, inmap, newdb):

def link_refs(self, dbtypes, dbtables, dbfunctions, dbopers, dbopfams,
dbopcls, dbconvs, dbtsconfigs, dbtsdicts, dbtspars,
dbtstmpls, dbftables, dbexts, dbcolls):
dbtstmpls, dbftables, dbcolls):
"""Connect types, tables and functions to their respective schemas
:param dbtypes: dictionary of types and domains
Expand All @@ -177,7 +173,6 @@ def link_refs(self, dbtypes, dbtables, dbfunctions, dbopers, dbopfams,
:param dbtspars: dictionary of text search parsers
:param dbtstmpls: dictionary of text search templates
:param dbftables: dictionary of foreign tables
:param dbexts: dictionary of extensions
:param dbcolls: dictionary of collations
Fills in the `domains` dictionary for each schema by
Expand Down Expand Up @@ -297,13 +292,6 @@ def link_refs(self, dbtypes, dbtables, dbfunctions, dbopers, dbopfams,
if not hasattr(schema, 'ftables'):
schema.ftables = {}
schema.ftables.update({ftb: ftbl})
for (sch, ext) in list(dbexts.keys()):
exten = dbexts[(sch, ext)]
assert self[sch]
schema = self[sch]
if not hasattr(schema, 'extensions'):
schema.extensions = {}
schema.extensions.update({ext: exten})
for (sch, cll) in list(dbcolls.keys()):
coll = dbcolls[(sch, cll)]
assert self[sch]
Expand Down
5 changes: 2 additions & 3 deletions pyrseas/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,6 @@ def std_map(self, plpgsql_installed=False):
and self.db._version < 90100:
base.update({'language plpgsql': {'trusted': True}})
if self.db._version >= 90100:
base.update({'schema pg_catalog': {'extension plpgsql': {
'description': "PL/pgSQL procedural language"},
'description': 'system catalog schema'}})
base.update({'extension plpgsql': {'schema': 'pg_catalog',
'description': "PL/pgSQL procedural language"}})
return base
25 changes: 13 additions & 12 deletions tests/dbobject/test_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ def test_map_extension(self):
if self.db.version < 90100:
self.skipTest('Only available on PG 9.1')
dbmap = self.to_map([CREATE_STMT])
self.assertEqual(dbmap['schema public']['extension pg_trgm'],
{'version': '1.0', 'description': TRGM_COMMENT})
self.assertEqual(dbmap['extension pg_trgm'], {
'schema': 'public', 'version': '1.0',
'description': TRGM_COMMENT})

def test_map_no_depends(self):
"Ensure no dependencies are included when mapping an extension"
Expand All @@ -36,18 +37,18 @@ def test_map_lang_extension(self):
if self.db.version < 90100:
self.skipTest('Only available on PG 9.1')
dbmap = self.to_map(["CREATE EXTENSION plperl"])
self.assertEqual(dbmap['schema pg_catalog']['extension plperl'],
{'version': '1.0',
'description': "PL/Perl procedural language"})
self.assertEqual(dbmap['extension plperl'], {
'schema': 'pg_catalog', 'version': '1.0',
'description': "PL/Perl procedural language"})
self.assertFalse('language plperl' in dbmap)

def test_map_extension_schema(self):
"Map an existing extension"
if self.db.version < 90100:
self.skipTest('Only available on PG 9.1')
dbmap = self.to_map(["CREATE SCHEMA s1", CREATE_STMT + " SCHEMA s1"])
self.assertEqual(dbmap['schema s1']['extension pg_trgm'],
{'version': '1.0', 'description': TRGM_COMMENT})
self.assertEqual(dbmap['extension pg_trgm'], {
'schema': 's1', 'version': '1.0', 'description': TRGM_COMMENT})


class ExtensionToSqlTestCase(InputMapToSqlTestCase):
Expand All @@ -58,14 +59,14 @@ def test_create_extension(self):
if self.db.version < 90100:
self.skipTest('Only available on PG 9.1')
inmap = self.std_map()
inmap['schema public'].update({'extension pg_trgm': {}})
inmap.update({'extension pg_trgm': {'schema': 'public'}})
sql = self.to_sql(inmap)
self.assertEqual(sql, [CREATE_STMT])

def test_bad_extension_map(self):
"Error creating a extension with a bad map"
inmap = self.std_map()
inmap['schema public'].update({'pg_trgm': {}})
inmap.update({'pg_trgm': {'schema': 'public'}})
self.assertRaises(KeyError, self.to_sql, inmap)

def test_drop_extension(self):
Expand All @@ -80,7 +81,7 @@ def test_create_extension_schema(self):
if self.db.version < 90100:
self.skipTest('Only available on PG 9.1')
inmap = self.std_map()
inmap.update({'schema s1': {'extension pg_trgm': {'version': '1.0'}}})
inmap.update({'extension pg_trgm': {'schema': 's1', 'version': '1.0'}})
sql = self.to_sql(inmap, ["CREATE SCHEMA s1"])
self.assertEqual(fix_indent(sql[0]),
"CREATE EXTENSION pg_trgm SCHEMA s1 VERSION '1.0'")
Expand All @@ -90,8 +91,8 @@ def test_comment_extension(self):
if self.db.version < 90100:
self.skipTest('Only available on PG 9.1')
inmap = self.std_map()
inmap['schema public'].update({'extension pg_trgm': {
'description': "Trigram extension"}})
inmap.update({'extension pg_trgm': {
'schema': 'public', 'description': "Trigram extension"}})
sql = self.to_sql(inmap, [CREATE_STMT])
self.assertEqual(sql, [
"COMMENT ON EXTENSION pg_trgm IS 'Trigram extension'"])
Expand Down

0 comments on commit 46bc738

Please sign in to comment.