Permalink
Browse files

Change EXTENSIONs so that they appear at the top level. Fixes #40.

 * docs/extension.rst: Update derivation information.  Add
   ExtensionDict.to_map.
 * pyrseas/database.py (Database.Dicts.__init__): Initialize schemas
   first and extensions second.  (Database._link_refs): Exclude
   extensions from call to schemas.link_refs.  (Database.from_map):
   Add processing of extensions.  (Database.to_map): Add mapping of
   extensions. (Database.diff_map): Process extensions ahead of
   languages and schemas.
 * pyrseas/dbobject/extension.py (Extension.keylist): Limit it to just
   the name.  (ExtensionDict.query): Order by extname only.
   (ExtensionDict.from_map): Remove schema parameter.
   (ExtensionDict.to_map): New method.  (ExtensionDict.diff_map):
   Refactor to exclude schema as part of the key.
 * pyrseas/dbobject/language.py (LanguageDict.diff_map): Remove
   dbversion parameter and use self.dbconn.version instead.
 * pyrseas/dbobject/schema.py (Schema.to_map): Exclude extensions from
   object types list.  (SchemaDict.from_map): Exclude extensions from
   processing.  (SchemaDict.link_refs): Exclude extensions from
   parameters and subsequent processing.
 * pyrseas/testutils.py InputMapToSqlTestCase.std_map): Don't include
   schema pg_catalog.
 * tests/dbobject/test_extension.py: Change maps to move extensions at
   the same level as schemas.
  • Loading branch information...
jmafc committed Jul 5, 2012
1 parent 8ec7ae0 commit 46bc738609495021fc247d3b21f991b654d0e0f2
View
@@ -4,14 +4,14 @@ Extensions
.. module:: pyrseas.dbobject.extension
The :mod:`extension` module defines two classes, :class:`Extension`
-and :class:`ExtensionDict`, derived from :class:`DbSchemaObject` and
+and :class:`ExtensionDict`, derived from :class:`DbObject` and
:class:`DbObjectDict`, respectively.
Extension
---------
:class:`Extension` is derived from
-:class:`~pyrseas.dbobject.DbSchemaObject` and represents a `PostgreSQL
+:class:`~pyrseas.dbobject.DbObject` and represents a `PostgreSQL
extension
<http://www.postgresql.org/docs/current/static/extend-extensions.html>`_.
@@ -31,4 +31,6 @@ represents the collection of extensions in a database.
.. automethod:: ExtensionDict.from_map
+.. automethod:: ExtensionDict.to_map
+
.. automethod:: ExtensionDict.diff_map
View
@@ -76,9 +76,10 @@ def __init__(self, dbconn=None):
:param dbconn: a DbConnection object
"""
+ self.schemas = SchemaDict(dbconn)
+ self.extensions = ExtensionDict(dbconn)
self.languages = LanguageDict(dbconn)
self.casts = CastDict(dbconn)
- self.schemas = SchemaDict(dbconn)
self.types = TypeDict(dbconn)
self.tables = ClassDict(dbconn)
self.columns = ColumnDict(dbconn)
@@ -99,7 +100,6 @@ def __init__(self, dbconn=None):
self.servers = ForeignServerDict(dbconn)
self.usermaps = UserMappingDict(dbconn)
self.ftables = ForeignTableDict(dbconn)
- self.extensions = ExtensionDict(dbconn)
self.collations = CollationDict(dbconn)
def __init__(self, dbname, user=None, pswd=None, host=None, port=None):
@@ -120,8 +120,7 @@ def _link_refs(self, db):
db.schemas.link_refs(db.types, db.tables, db.functions, db.operators,
db.operfams, db.operclasses, db.conversions,
db.tsconfigs, db.tsdicts, db.tsparsers,
- db.tstempls, db.ftables, db.extensions,
- db.collations)
+ db.tstempls, db.ftables, db.collations)
db.tables.link_refs(db.columns, db.constraints, db.indexes,
db.rules, db.triggers)
db.fdwrappers.link_refs(db.servers)
@@ -176,13 +175,16 @@ def from_map(self, input_map):
"""
self.ndb = self.Dicts()
input_schemas = {}
+ input_extens = {}
input_langs = {}
input_casts = {}
input_fdws = {}
input_ums = {}
for key in list(input_map.keys()):
if key.startswith('schema '):
input_schemas.update({key: input_map[key]})
+ elif key.startswith('extension '):
+ input_extens.update({key: input_map[key]})
elif key.startswith('language '):
input_langs.update({key: input_map[key]})
elif key.startswith('cast '):
@@ -193,6 +195,7 @@ def from_map(self, input_map):
input_ums.update({key: input_map[key]})
else:
raise KeyError("Expected typed object, found '%s'" % key)
+ self.ndb.extensions.from_map(input_extens)
self.ndb.languages.from_map(input_langs)
self.ndb.schemas.from_map(input_schemas, self.ndb)
self.ndb.casts.from_map(input_casts, self.ndb)
@@ -215,7 +218,8 @@ def to_map(self, schemas=[], tables=[], exclude_schemas=[],
"""
if not self.db:
self.from_catalog()
- dbmap = self.db.languages.to_map()
+ dbmap = self.db.extensions.to_map()
+ dbmap.update(self.db.languages.to_map())
dbmap.update(self.db.casts.to_map())
dbmap.update(self.db.fdwrappers.to_map())
dbmap.update(self.db.schemas.to_map())
@@ -276,10 +280,9 @@ def diff_map(self, input_map, schemas=[]):
if schemas:
self._trim_objects(schemas)
self.from_map(input_map)
- stmts = self.db.languages.diff_map(self.ndb.languages,
- self.dbconn.version)
+ stmts = self.db.extensions.diff_map(self.ndb.extensions)
+ stmts.append(self.db.languages.diff_map(self.ndb.languages))
stmts.append(self.db.schemas.diff_map(self.ndb.schemas))
- stmts.append(self.db.extensions.diff_map(self.ndb.extensions))
stmts.append(self.db.types.diff_map(self.ndb.types))
stmts.append(self.db.functions.diff_map(self.ndb.functions))
stmts.append(self.db.operators.diff_map(self.ndb.operators))
@@ -3,16 +3,16 @@
pyrseas.dbobject.extension
~~~~~~~~~~~~~~~~~~~~~~~~~~
- This module defines two classes: Extension derived from
- DbSchemaObject, and ExtensionDict derived from DbObjectDict.
+ This module defines two classes: Extension derived from DbObject,
+ and ExtensionDict derived from DbObjectDict.
"""
-from pyrseas.dbobject import DbObjectDict, DbSchemaObject, quote_id
+from pyrseas.dbobject import DbObjectDict, DbObject, quote_id
-class Extension(DbSchemaObject):
+class Extension(DbObject):
"""An extension"""
- keylist = ['schema', 'name']
+ keylist = ['name']
objtype = "EXTENSION"
def create(self):
@@ -39,12 +39,12 @@ class ExtensionDict(DbObjectDict):
cls = Extension
query = \
- """SELECT nspname AS schema, extname AS name, extversion AS version,
+ """SELECT extname AS name, nspname AS schema, extversion AS version,
obj_description(e.oid, 'pg_extension') AS description
FROM pg_extension e
JOIN pg_namespace n ON (extnamespace = n.oid)
WHERE nspname != 'information_schema'
- ORDER BY 1, 2"""
+ ORDER BY extname"""
def _from_catalog(self):
"""Initialize the dictionary of extensions by querying the catalogs"""
@@ -53,26 +53,37 @@ def _from_catalog(self):
for ext in self.fetch():
self[ext.key()] = ext
- def from_map(self, schema, inexts):
+ def from_map(self, inexts):
"""Initalize the dictionary of extensions by converting the input map
- :param schema: schema owning the extensions
:param inexts: YAML map defining the extensions
"""
for key in list(inexts.keys()):
if not key.startswith('extension '):
raise KeyError("Unrecognized object type: %s" % key)
ext = key[10:]
inexten = inexts[key]
- self[(schema.name, ext)] = exten = Extension(
- schema=schema.name, name=ext)
+ self[ext] = exten = Extension(name=ext)
for attr, val in list(inexten.items()):
setattr(exten, attr, val)
if 'oldname' in inexten:
exten.oldname = inexten['oldname']
if 'description' in inexten:
exten.description = inexten['description']
+ def to_map(self):
+ """Convert the extension dictionary to a regular dictionary
+
+ :return: dictionary
+
+ Invokes the `to_map` method of each extension to construct a
+ dictionary of extensions.
+ """
+ extens = {}
+ for ext in list(self.keys()):
+ extens.update(self[ext].to_map())
+ return extens
+
def diff_map(self, inexts):
"""Generate SQL to transform existing extensions
@@ -85,24 +96,24 @@ def diff_map(self, inexts):
"""
stmts = []
# check input extensions
- for (sch, ext) in list(inexts.keys()):
- inexten = inexts[(sch, ext)]
+ for ext in list(inexts.keys()):
+ inexten = inexts[ext]
# does it exist in the database?
- if (sch, ext) not in self:
+ if ext not in self:
if not hasattr(inexten, 'oldname'):
# create new extension
stmts.append(inexten.create())
else:
- stmts.append(self[(sch, ext)].rename(inexten))
+ stmts.append(self[ext].rename(inexten))
else:
# check extension objects
- stmts.append(self[(sch, ext)].diff_map(inexten))
+ stmts.append(self[ext].diff_map(inexten))
# check existing extensions
- for (sch, ext) in list(self.keys()):
- exten = self[(sch, ext)]
+ for ext in list(self.keys()):
+ exten = self[ext]
# if missing, drop them
- if (sch, ext) not in inexts:
+ if ext not in inexts:
stmts.append(exten.drop())
return stmts
@@ -106,11 +106,10 @@ def to_map(self):
languages.update(self[lng].to_map())
return languages
- def diff_map(self, inlanguages, dbversion):
+ def diff_map(self, inlanguages):
"""Generate SQL to transform existing languages
:param input_map: a YAML map defining the new languages
- :param dbversion: DBMS version number
:return: list of SQL statements
Compares the existing language definitions, as fetched from the
@@ -143,7 +142,8 @@ def diff_map(self, inlanguages, dbversion):
# if missing, drop it
if lng not in inlanguages:
# special case: plpgsql is installed in 9.0
- if dbversion >= 90000 and self[lng].name == 'plpgsql':
+ if self.dbconn.version >= 90000 \
+ and self[lng].name == 'plpgsql':
continue
self[lng].dropped = True
return stmts
View
@@ -45,7 +45,7 @@ def mapper(schema, objtypes):
for objtypes in ['conversions', 'domains', 'ftables', 'functions',
'operators', 'operclasses', 'operfams', 'sequences',
'tsconfigs', 'tsdicts', 'tsparsers', 'tstempls',
- 'types', 'views', 'extensions', 'collations']:
+ 'types', 'views', 'collations']:
schema[key].update(mapper(self, objtypes))
if hasattr(self, 'description'):
@@ -104,7 +104,6 @@ def from_map(self, inmap, newdb):
intsps = {}
intsts = {}
inftbs = {}
- inexts = {}
incolls = {}
for key in list(inschema.keys()):
if key.startswith('domain '):
@@ -135,8 +134,6 @@ def from_map(self, inmap, newdb):
intscs.update({key: inschema[key]})
elif key.startswith('foreign table '):
inftbs.update({key: inschema[key]})
- elif key.startswith('extension '):
- inexts.update({key: inschema[key]})
elif key.startswith('collation '):
incolls.update({key: inschema[key]})
elif key == 'oldname':
@@ -145,7 +142,6 @@ def from_map(self, inmap, newdb):
schema.description = inschema[key]
else:
raise KeyError("Expected typed object, found '%s'" % key)
- newdb.extensions.from_map(schema, inexts)
newdb.types.from_map(schema, intypes, newdb)
newdb.tables.from_map(schema, intables, newdb)
newdb.functions.from_map(schema, infuncs)
@@ -162,7 +158,7 @@ def from_map(self, inmap, newdb):
def link_refs(self, dbtypes, dbtables, dbfunctions, dbopers, dbopfams,
dbopcls, dbconvs, dbtsconfigs, dbtsdicts, dbtspars,
- dbtstmpls, dbftables, dbexts, dbcolls):
+ dbtstmpls, dbftables, dbcolls):
"""Connect types, tables and functions to their respective schemas
:param dbtypes: dictionary of types and domains
@@ -177,7 +173,6 @@ def link_refs(self, dbtypes, dbtables, dbfunctions, dbopers, dbopfams,
:param dbtspars: dictionary of text search parsers
:param dbtstmpls: dictionary of text search templates
:param dbftables: dictionary of foreign tables
- :param dbexts: dictionary of extensions
:param dbcolls: dictionary of collations
Fills in the `domains` dictionary for each schema by
@@ -297,13 +292,6 @@ def link_refs(self, dbtypes, dbtables, dbfunctions, dbopers, dbopfams,
if not hasattr(schema, 'ftables'):
schema.ftables = {}
schema.ftables.update({ftb: ftbl})
- for (sch, ext) in list(dbexts.keys()):
- exten = dbexts[(sch, ext)]
- assert self[sch]
- schema = self[sch]
- if not hasattr(schema, 'extensions'):
- schema.extensions = {}
- schema.extensions.update({ext: exten})
for (sch, cll) in list(dbcolls.keys()):
coll = dbcolls[(sch, cll)]
assert self[sch]
View
@@ -403,7 +403,6 @@ def std_map(self, plpgsql_installed=False):
and self.db._version < 90100:
base.update({'language plpgsql': {'trusted': True}})
if self.db._version >= 90100:
- base.update({'schema pg_catalog': {'extension plpgsql': {
- 'description': "PL/pgSQL procedural language"},
- 'description': 'system catalog schema'}})
+ base.update({'extension plpgsql': {'schema': 'pg_catalog',
+ 'description': "PL/pgSQL procedural language"}})
return base
@@ -19,8 +19,9 @@ def test_map_extension(self):
if self.db.version < 90100:
self.skipTest('Only available on PG 9.1')
dbmap = self.to_map([CREATE_STMT])
- self.assertEqual(dbmap['schema public']['extension pg_trgm'],
- {'version': '1.0', 'description': TRGM_COMMENT})
+ self.assertEqual(dbmap['extension pg_trgm'], {
+ 'schema': 'public', 'version': '1.0',
+ 'description': TRGM_COMMENT})
def test_map_no_depends(self):
"Ensure no dependencies are included when mapping an extension"
@@ -36,18 +37,18 @@ def test_map_lang_extension(self):
if self.db.version < 90100:
self.skipTest('Only available on PG 9.1')
dbmap = self.to_map(["CREATE EXTENSION plperl"])
- self.assertEqual(dbmap['schema pg_catalog']['extension plperl'],
- {'version': '1.0',
- 'description': "PL/Perl procedural language"})
+ self.assertEqual(dbmap['extension plperl'], {
+ 'schema': 'pg_catalog', 'version': '1.0',
+ 'description': "PL/Perl procedural language"})
self.assertFalse('language plperl' in dbmap)
def test_map_extension_schema(self):
"Map an existing extension"
if self.db.version < 90100:
self.skipTest('Only available on PG 9.1')
dbmap = self.to_map(["CREATE SCHEMA s1", CREATE_STMT + " SCHEMA s1"])
- self.assertEqual(dbmap['schema s1']['extension pg_trgm'],
- {'version': '1.0', 'description': TRGM_COMMENT})
+ self.assertEqual(dbmap['extension pg_trgm'], {
+ 'schema': 's1', 'version': '1.0', 'description': TRGM_COMMENT})
class ExtensionToSqlTestCase(InputMapToSqlTestCase):
@@ -58,14 +59,14 @@ def test_create_extension(self):
if self.db.version < 90100:
self.skipTest('Only available on PG 9.1')
inmap = self.std_map()
- inmap['schema public'].update({'extension pg_trgm': {}})
+ inmap.update({'extension pg_trgm': {'schema': 'public'}})
sql = self.to_sql(inmap)
self.assertEqual(sql, [CREATE_STMT])
def test_bad_extension_map(self):
"Error creating a extension with a bad map"
inmap = self.std_map()
- inmap['schema public'].update({'pg_trgm': {}})
+ inmap.update({'pg_trgm': {'schema': 'public'}})
self.assertRaises(KeyError, self.to_sql, inmap)
def test_drop_extension(self):
@@ -80,7 +81,7 @@ def test_create_extension_schema(self):
if self.db.version < 90100:
self.skipTest('Only available on PG 9.1')
inmap = self.std_map()
- inmap.update({'schema s1': {'extension pg_trgm': {'version': '1.0'}}})
+ inmap.update({'extension pg_trgm': {'schema': 's1', 'version': '1.0'}})
sql = self.to_sql(inmap, ["CREATE SCHEMA s1"])
self.assertEqual(fix_indent(sql[0]),
"CREATE EXTENSION pg_trgm SCHEMA s1 VERSION '1.0'")
@@ -90,8 +91,8 @@ def test_comment_extension(self):
if self.db.version < 90100:
self.skipTest('Only available on PG 9.1')
inmap = self.std_map()
- inmap['schema public'].update({'extension pg_trgm': {
- 'description': "Trigram extension"}})
+ inmap.update({'extension pg_trgm': {
+ 'schema': 'public', 'description': "Trigram extension"}})
sql = self.to_sql(inmap, [CREATE_STMT])
self.assertEqual(sql, [
"COMMENT ON EXTENSION pg_trgm IS 'Trigram extension'"])

0 comments on commit 46bc738

Please sign in to comment.