Skip to content

Commit

Permalink
Add support for inherited tables.
Browse files Browse the repository at this point in the history
 * pyrseas/dbobject/__init__.py (split_schema_table): New utility
   function split schema-qualified names.  (DbSchemaObject.unqualify):
   Use new function.
 * pyrseas/dbobject/column.py (Column.to_map): Map inherited columns.
   (ColumnDict.query): Retrieve attinhcount as 'inherited'.
 * pyrseas/dbobject/constraint.py (ConstraintDict._from_catalog):
   Recall code by call to split_schema_table.
 * pyrseas/dbobject/table.py (Sequence.get_owner): Replace code by
   call to split_schema_table.  (Table.to_map): Map inherited tables.
   (Table.create): Add INHERITS clause and deal with inherited
   columns.  (ClassDict.inhquery): New query from pg_inherits.
   (ClassDict._from_catalog): Fetch inheritance info.
   (ClassDict.from_map): Map 'inherits' info.  (ClassDict.link_refs):
   Add links for descendant tables.  (ClassDict.diff_map): Deal with
   inherited tables, both for create and drop.
 * tests/dbobject/test_table.py:  Add tests for inherited tables.
 * tests/dbobject/utils.py (PostgresDb.clear): Use IF EXISTS for
   dropping tables.
  • Loading branch information
jmafc committed May 6, 2011
1 parent 197a1a7 commit ffa144c
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 32 deletions.
21 changes: 17 additions & 4 deletions pyrseas/dbobject/__init__.py
Expand Up @@ -9,6 +9,22 @@
"""


def split_schema_table(tbl, sch=None):
"""Return a (schema, table) tuple given a possibly schema-qualified name
:param tbl: table name or schema.table
:return: tuple
"""
qualsch = sch
if sch == None:
qualsch = 'public'
if '.' in tbl:
(qualsch, tbl) = tbl.split('.')
if sch != qualsch:
sch = qualsch
return (sch, tbl)


class DbObject(object):
"A single object in a database catalog, e.g., a schema, a table, a column"

Expand Down Expand Up @@ -55,10 +71,7 @@ def qualname(self):
def unqualify(self):
"""Adjust the schema and table name if the latter is qualified"""
if hasattr(self, 'table') and '.' in self.table:
tbl = self.table
dot = tbl.index('.')
if self.schema == tbl[:dot]:
self.table = tbl[dot + 1:]
(sch, self.table) = split_schema_table(self.table, self.schema)

def comment(self):
"""Return a SQL COMMENT statement for the object
Expand Down
5 changes: 4 additions & 1 deletion pyrseas/dbobject/column.py
Expand Up @@ -23,6 +23,8 @@ def to_map(self):
for k in self.keylist:
del dct[k]
del dct['number'], dct['name'], dct['_table']
if hasattr(self, 'inherited'):
dct['inherited'] = (self.inherited != 0)
return {self.name: dct}

def add(self):
Expand Down Expand Up @@ -121,7 +123,8 @@ class ColumnDict(DbObjectDict):
query = \
"""SELECT nspname AS schema, relname AS table, attname AS name,
attnum AS number, format_type(atttypid, atttypmod) AS type,
attnotnull AS not_null, adsrc AS default, description
attnotnull AS not_null, attinhcount AS inherited,
adsrc AS default, description
FROM pg_attribute JOIN pg_class ON (attrelid = pg_class.oid)
JOIN pg_namespace ON (relnamespace = pg_namespace.oid)
JOIN pg_roles ON (nspowner = pg_roles.oid)
Expand Down
10 changes: 3 additions & 7 deletions pyrseas/dbobject/constraint.py
Expand Up @@ -8,7 +8,7 @@
UniqueConstraint derived from Constraint, and ConstraintDict
derived from DbObjectDict.
"""
from pyrseas.dbobject import DbObjectDict, DbSchemaObject
from pyrseas.dbobject import DbObjectDict, DbSchemaObject, split_schema_table

ACTIONS = {'r': 'restrict', 'c': 'cascade', 'n': 'set null',
'd': 'set default'}
Expand Down Expand Up @@ -274,12 +274,8 @@ def _from_catalog(self):
else:
constr.on_delete = ACTIONS[constr.on_delete]
reftbl = constr.ref_table
if '.' in reftbl:
dot = reftbl.index('.')
constr.ref_table = reftbl[dot + 1:]
constr.ref_schema = reftbl[:dot]
else:
constr.ref_schema = 'public'
(constr.ref_schema, constr.ref_table) = split_schema_table(
reftbl)
self[(sch, tbl, cns)] = ForeignKey(**constr.__dict__)
elif constr_type == 'u':
self[(sch, tbl, cns)] = UniqueConstraint(**constr.__dict__)
Expand Down
90 changes: 71 additions & 19 deletions pyrseas/dbobject/table.py
Expand Up @@ -9,7 +9,7 @@
"""
import sys

from pyrseas.dbobject import DbObjectDict, DbSchemaObject
from pyrseas.dbobject import DbObjectDict, DbSchemaObject, split_schema_table
from constraint import CheckConstraint, PrimaryKey, ForeignKey, \
UniqueConstraint

Expand Down Expand Up @@ -72,12 +72,8 @@ def get_owner(self, dbconn):
WHERE objid = '%s'::regclass
AND refclassid = 'pg_class'::regclass""" % self.qualname())
if data:
self.owner_table = tbl = data[0]
(sch, self.owner_table) = split_schema_table(data[0], self.schema)
self.owner_column = data[1]
if '.' in tbl:
dot = tbl.index('.')
if self.schema == tbl[:dot]:
self.owner_table = tbl[dot + 1:]

def to_map(self):
"""Convert a sequence definition to a YAML-suitable format
Expand Down Expand Up @@ -225,6 +221,9 @@ def to_map(self, dbschemas):
for k in self.indexes.values():
tbl['indexes'].update(self.indexes[k.name].to_map(
self.column_names()))
if hasattr(self, 'inherits'):
if not 'inherits' in tbl:
tbl['inherits'] = self.inherits

return {self.extern_key(): tbl}

Expand All @@ -236,9 +235,13 @@ def create(self):
stmts = []
cols = []
for col in self.columns:
cols.append(" " + col.add())
stmts.append("CREATE TABLE %s (\n%s)" % (self.qualname(),
",\n".join(cols)))
if not (hasattr(col, 'inherited') and col.inherited):
cols.append(" " + col.add())
inhclause = ''
if hasattr(self, 'inherits'):
inhclause = " INHERITS (%s)" % ", ".join(t for t in self.inherits)
stmts.append("CREATE TABLE %s (\n%s)%s" % (
self.qualname(), ",\n".join(cols), inhclause))
if hasattr(self, 'description'):
stmts.append(self.comment())
for col in self.columns:
Expand Down Expand Up @@ -347,6 +350,12 @@ class ClassDict(DbObjectDict):
AND (nspname = 'public' OR rolname <> 'postgres')
ORDER BY nspname, relname"""

inhquery = \
"""SELECT inhrelid::regclass AS sub, inhparent::regclass AS parent,
inhseqno
FROM pg_inherits
ORDER BY 1, 3"""

def _from_catalog(self):
"""Initialize the dictionary of tables by querying the catalogs"""
for table in self.fetch():
Expand All @@ -361,6 +370,12 @@ def _from_catalog(self):
inst.get_owner(self.dbconn)
elif kind == 'v':
self[(sch, tbl)] = View(**table.__dict__)
for (tbl, partbl, num) in self.dbconn.fetchall(self.inhquery):
(sch, tbl) = split_schema_table(tbl)
table = self[(sch, tbl)]
if not hasattr(table, 'inherits'):
table.inherits = []
table.inherits.append(partbl)

def from_map(self, schema, inobjs, newdb):
"""Initalize the dictionary of tables by converting the input map
Expand Down Expand Up @@ -388,6 +403,8 @@ def from_map(self, schema, inobjs, newdb):
except KeyError, exc:
exc.args = ("Table '%s' has no columns" % key, )
raise
if 'inherits' in intable:
table.inherits = intable['inherits']
if 'oldname' in intable:
table.oldname = intable['oldname']
newdb.constraints.from_map(table, intable)
Expand Down Expand Up @@ -437,12 +454,19 @@ def link_refs(self, dbcolumns, dbconstrs, dbindexes):
for col in dbcolumns[(sch, tbl)]:
col._table = self[(sch, tbl)]
for (sch, tbl) in self.keys():
if isinstance(self[(sch, tbl)], Sequence):
seq = self[(sch, tbl)]
if hasattr(seq, 'owner_table'):
if isinstance(seq.owner_column, int):
seq.owner_column = self[(sch, seq.owner_table)]. \
column_names()[seq.owner_column - 1]
table = self[(sch, tbl)]
if isinstance(table, Sequence) and hasattr(table, 'owner_table'):
if isinstance(table.owner_column, int):
table.owner_column = self[(sch, table.owner_table)]. \
column_names()[table.owner_column - 1]
elif isinstance(table, Table) and hasattr(table, 'inherits'):
for partbl in table.inherits:
(parsch, partbl) = split_schema_table(partbl)
assert self[(parsch, partbl)]
parent = self[(parsch, partbl)]
if not hasattr(parent, 'descendants'):
parent.descendants = []
parent.descendants.append(table)
for (sch, tbl, cns) in dbconstrs.keys():
constr = dbconstrs[(sch, tbl, cns)]
if (sch, tbl) not in self: # check constraints on domains
Expand Down Expand Up @@ -513,17 +537,31 @@ def diff_map(self, intables):
stmts.append(inseq.create())

# check input tables
inhstack = []
for (sch, tbl) in intables.keys():
intable = intables[(sch, tbl)]
if not isinstance(intable, Table):
continue
# does it exist in the database?
if (sch, tbl) not in self:
if hasattr(intable, 'oldname'):
stmts.append(self._rename(intable, "table"))
else:
if not hasattr(intable, 'oldname'):
# create new table
stmts.append(intable.create())
if hasattr(intable, 'inherits'):
inhstack.append(intable)
else:
stmts.append(intable.create())
else:
stmts.append(self._rename(intable, "table"))
while len(inhstack):
intable = inhstack.pop()
createit = True
for partbl in intable.inherits:
if intables[split_schema_table(partbl)] in inhstack:
createit = False
if createit:
stmts.append(intable.create())
else:
inhstack.insert(0, intable)

# check input views
for (sch, tbl) in intables.keys():
Expand Down Expand Up @@ -577,6 +615,7 @@ def diff_map(self, intables):
if isinstance(table, View):
stmts.append(table.drop())

inhstack = []
for (sch, tbl) in self.keys():
table = self[(sch, tbl)]
if isinstance(table, Sequence) and hasattr(table, 'owner_table') \
Expand All @@ -599,7 +638,20 @@ def diff_map(self, intables):
stmts.append(table.referred_by.drop())
stmts.append(table.primary_key.drop())
# finally, drop the table itself
if hasattr(table, 'descendants'):
inhstack.append(table)
else:
stmts.append(table.drop())
while len(inhstack):
table = inhstack.pop()
dropit = True
for childtbl in table.descendants:
if self[(childtbl.schema, childtbl.name)] in inhstack:
dropit = False
if dropit:
stmts.append(table.drop())
else:
inhstack.insert(0, table)

# last pass to deal with nextval DEFAULTs
for (sch, tbl) in intables.keys():
Expand Down
47 changes: 47 additions & 0 deletions tests/dbobject/test_table.py
Expand Up @@ -172,6 +172,19 @@ def test_map_column_comments(self):
dbmap = self.db.execute_and_map(ddlstmt)
self.assertEqual(dbmap['schema public']['table t1'], expmap)

def test_map_inherit(self):
"Map a table that inherits from two other tables"
self.db.execute(CREATE_STMT)
self.db.execute("CREATE TABLE t2 (c3 integer)")
ddlstmt = "CREATE TABLE t3 (c4 text) INHERITS (t1, t2)"
expmap = {'columns': [{'c1': {'type': 'integer', 'inherited': True}},
{'c2': {'type': 'text', 'inherited': True}},
{'c3': {'type': 'integer', 'inherited': True}},
{'c4': {'type': 'text'}}],
'inherits': ['t1', 't2']}
dbmap = self.db.execute_and_map(ddlstmt)
self.assertEqual(dbmap['schema public']['table t3'], expmap)


class TableToSqlTestCase(PyrseasTestCase):
"""Test SQL generation of table statements from input schemas"""
Expand Down Expand Up @@ -381,12 +394,46 @@ def test_create_column_comments(self):
"COMMENT ON COLUMN t1.c2 IS 'Test column c2'")


class TableInheritToSqlTestCase(PyrseasTestCase):
"""Test SQL generation of table inheritance statements"""

def test_table_inheritance(self):
"Create a table that inherits from another"
self.db.execute_commit(DROP_STMT)
inmap = new_std_map()
inmap['schema public'].update({'table t1': {
'columns': [{'c1': {'type': 'integer'}},
{'c2': {'type': 'text'}}]}})
inmap['schema public'].update({'table t2': {
'columns': [{'c1': {'type': 'integer', 'inherited': True}},
{'c2': {'type': 'text', 'inherited': True}},
{'c3': {'type': 'numeric'}}],
'inherits': ['t1']}})
dbsql = self.db.process_map(inmap)
self.assertEqual(fix_indent(dbsql[0]), CREATE_STMT)
self.assertEqual(fix_indent(dbsql[1]), "CREATE TABLE t2 (c3 numeric) "
"INHERITS (t1)")

def test_drop_inherited(self):
"Drop tables that inherit from others"
self.db.execute(DROP_STMT)
self.db.execute(CREATE_STMT)
self.db.execute("CREATE TABLE t2 (c3 numeric) INHERITS (t1)")
self.db.execute_commit("CREATE TABLE t3 (c4 date) INHERITS (t2)")
inmap = new_std_map()
dbsql = self.db.process_map(inmap)
self.assertEqual(dbsql, ["DROP TABLE t3", "DROP TABLE t2",
"DROP TABLE t1"])


def suite():
tests = unittest.TestLoader().loadTestsFromTestCase(TableToMapTestCase)
tests.addTest(unittest.TestLoader().loadTestsFromTestCase(
TableToSqlTestCase))
tests.addTest(unittest.TestLoader().loadTestsFromTestCase(
TableCommentToSqlTestCase))
tests.addTest(unittest.TestLoader().loadTestsFromTestCase(
TableInheritToSqlTestCase))
return tests

if __name__ == '__main__':
Expand Down
3 changes: 2 additions & 1 deletion tests/dbobject/utils.py
Expand Up @@ -112,7 +112,8 @@ def clear(self):
self.conn.rollback()
for obj in objs:
if obj['relkind'] == 'r':
self.execute("DROP TABLE %s.%s CASCADE" % (obj[0], obj[1]))
self.execute("DROP TABLE IF EXISTS %s.%s CASCADE" % (
obj[0], obj[1]))
elif obj['relkind'] == 'S':
self.execute("DROP SEQUENCE %s.%s CASCADE" % (obj[0], obj[1]))
elif obj['relkind'] == 'v':
Expand Down

0 comments on commit ffa144c

Please sign in to comment.