Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Refactor utilities

This moves several helper functions out of upgradedb and into utils
for future use (and API support outside of the manage.py command).

It also introduces the beginnings of a test suite.
  • Loading branch information...
commit 77b6b717d271d275507cdf4725a5bb340b230e0c 1 parent bf38d07
@dcramer dcramer authored
View
2  nashvegas/exceptions.py
@@ -0,0 +1,2 @@
+class MigrationError(Exception):
+ pass
View
141 nashvegas/management/commands/upgradedb.py
@@ -3,11 +3,10 @@
import sys
import traceback
-from collections import defaultdict
from optparse import make_option
from subprocess import Popen, PIPE
-from django.db import connections, router, transaction, DEFAULT_DB_ALIAS
+from django.db import connections, transaction, DEFAULT_DB_ALIAS
from django.db.models import get_model
from django.conf import settings
from django.core.management import call_command
@@ -15,8 +14,10 @@
from django.core.management.sql import emit_post_sync_signal
from django.utils.importlib import import_module
+from nashvegas.exceptions import MigrationError
from nashvegas.models import Migration
-from nashvegas.utils import get_sql_for_new_models
+from nashvegas.utils import get_sql_for_new_models, get_capable_databases, \
+ get_pending_migrations
sys.path.append("migrations")
@@ -39,10 +40,6 @@ def __exit__(self, exc_type, exc_value, traceback):
transaction.leave_transaction_management(using=db)
-class MigrationError(Exception):
- pass
-
-
class Command(BaseCommand):
option_list = BaseCommand.option_list + (
@@ -61,7 +58,7 @@ class Command(BaseCommand):
make_option("-s", "--seed", action="store_true",
dest="do_seed", default=False,
help="Seed nashvegas with migrations that have previously been applied in another manner."),
- make_option("-d", "--database", action="store", dest="database",
+ make_option("-d", "--database", action="append", dest="databases",
help="Nominates a database to synchronize."),
make_option("--noinput", action="store_false", dest="interactive", default=False,
help="Tells Django to NOT prompt the user for input of any kind."),
@@ -71,99 +68,6 @@ class Command(BaseCommand):
help = "Upgrade database."
- def _get_capable_databases(self):
- """
- Returns a list of databases which are capable of supporting
- Nashvegas (based on their routing configuration).
- """
- for database in connections:
- if router.allow_syncdb(database, Migration):
- yield database
-
- def _get_file_list(self, path, max_depth=1, cur_depth=0):
- """
- Recursively returns a list of all files up to ``max_depth``
- in a directory.
- """
- for name in os.listdir(path):
- if name.startswith('.'):
- continue
-
- full_path = os.path.join(path, name)
- if os.path.isdir(full_path):
- if cur_depth == max_depth:
- continue
-
- for result in self._get_file_list(full_path, max_depth, cur_depth + 1):
- yield result
-
- else:
- yield full_path
-
- def _get_applied_migrations(self):
- """
- Returns a dictionary containing lists of all applied migrations
- where the key is the database alias.
- """
- results = defaultdict(list)
- for database in self._get_capable_databases():
- for x in Migration.objects.using(database).order_by("migration_label"):
- results[database].append(x.migration_label)
- return results
-
- def _filter_down(self, stop_at=None):
- if stop_at is None:
- stop_at = float("inf")
-
- # database: [(number, full_path)]
- possible_migrations = defaultdict(list)
- # database: [full_path]
- applied_migrations = self._get_applied_migrations()
- # database: [full_path]
- to_execute = defaultdict(list)
-
- try:
- in_directory = sorted(self._get_file_list(self.path))
- except OSError:
- print "An error occurred while reading migrations from %r:" % self.path
- traceback.print_exc()
- return to_execute
-
- # Iterate through our results and discover which migrations are actually runnable
- for full_path in in_directory:
- path, script = os.path.split(full_path)
- name, ext = os.path.splitext(script)
-
- # the database component is default if this is in the root directory
- # is <directory> if in a subdirectory
- if path == self.path:
- database = DEFAULT_DB_ALIAS
- else:
- database = os.path.split(path)[-1]
-
- # filter by database if set
- if self.db and database != self.db:
- continue
-
- match = MIGRATION_NAME_RE.match(name)
- if match is None:
- raise MigrationError("Invalid migration file prefix %r "
- "(must begin with a number)" % name)
-
- number = int(match.group(1))
- if ext in [".sql", ".py"]:
- possible_migrations[database].append((number, full_path))
-
- for database, scripts in possible_migrations.iteritems():
- applied = applied_migrations[database]
- pending = to_execute[database]
- for number, migration in scripts:
- path, script = os.path.split(migration)
- if script not in applied and number <= stop_at:
- pending.append(script)
-
- return dict((k, v) for k, v in to_execute.iteritems() if v)
-
def _get_rev(self, fpath):
"""
Get an SCM version number. Try svn and git.
@@ -286,7 +190,8 @@ def init_nashvegas(self):
raise
# @@@ make cleaner / check explicitly for model instead of looping over and doing string comparisons
- for database in self._get_capable_databases():
+ databases = self.databases or get_capable_databases()
+ for database in databases:
connection = connections[database]
cursor = connection.cursor()
all_new = get_sql_for_new_models(['nashvegas'], using=database)
@@ -300,7 +205,7 @@ def init_nashvegas(self):
transaction.commit_unless_managed(using=database)
def create_all_migrations(self):
- for database in self._get_capable_databases():
+ for database in get_capable_databases():
statements = get_sql_for_new_models(using=database)
if len(statements) == 0:
continue
@@ -332,7 +237,7 @@ def execute_migrations(self, show_traceback=True):
Executes all pending migrations across all capable
databases
"""
- all_migrations = self._filter_down()
+ all_migrations = get_pending_migrations(self.path, self.databases)
if not len(all_migrations):
sys.stdout.write("There are no migrations to apply.\n")
@@ -380,7 +285,7 @@ def seed_migrations(self, stop_at=None):
except IndexError:
raise CommandError("Usage: ./manage.py upgradedb --seed [stop_at]")
- all_migrations = self._filter_down(stop_at=stop_at)
+ all_migrations = get_pending_migrations(self.path, self.databases, stop_at=stop_at)
for db, migrations in all_migrations.iteritems():
for migration in migrations:
migration_path = self._get_migration_path(db, migration)
@@ -398,7 +303,7 @@ def seed_migrations(self, stop_at=None):
print "%s:%s was already applied" % (db, m.migration_label)
def list_migrations(self):
- all_migrations = self._filter_down()
+ all_migrations = get_pending_migrations(self.path, self.databases)
if len(all_migrations) == 0:
print "There are no migrations to apply."
return
@@ -408,6 +313,13 @@ def list_migrations(self):
for script in migrations:
print "\t%s: %s" % (database, script)
+ def _get_default_migration_path(self):
+ try:
+ path = os.path.dirname(os.path.normpath(os.sys.modules[settings.SETTINGS_MODULE].__file__))
+ except KeyError:
+ path = os.getcwd()
+ return os.path.join(path, "migrations")
+
def handle(self, *args, **options):
"""
Upgrades the database.
@@ -425,23 +337,16 @@ def handle(self, *args, **options):
if options.get("path"):
self.path = options.get("path")
else:
- default_path = os.path.join(
- os.path.dirname(
- os.path.normpath(
- os.sys.modules[settings.SETTINGS_MODULE].__file__
- )
- ),
- "migrations"
- )
+ default_path = self._get_default_migration_path()
self.path = getattr(settings, "NASHVEGAS_MIGRATIONS_DIRECTORY", default_path)
self.verbosity = int(options.get("verbosity", 1))
self.interactive = options.get("interactive")
- self.db = options.get("database")
+ self.databases = options.get("databases")
# We only use the default alias in creation scenarios (upgrades default to all databases)
- if self.do_create and not self.db:
- self.db = DEFAULT_DB_ALIAS
+ if self.do_create and not self.databases:
+ self.databases = [DEFAULT_DB_ALIAS]
if self.do_create and self.do_create_all:
raise CommandError("You cannot combine --create and --create-all")
@@ -451,7 +356,7 @@ def handle(self, *args, **options):
if self.do_create_all:
self.create_all_migrations()
elif self.do_create:
- self.create_migrations(self.db)
+ self.create_migrations(self.databases)
if self.do_execute:
self.execute_migrations()
View
131 nashvegas/utils.py
@@ -1,7 +1,16 @@
+import itertools
+import os.path
+import re
+
+from collections import defaultdict
from django.core.management.color import no_style
from django.core.management.sql import custom_sql_for_model
from django.db import connections, router, models, DEFAULT_DB_ALIAS
from django.utils.datastructures import SortedDict
+from nashvegas.exceptions import MigrationError
+from nashvegas.models import Migration
+
+MIGRATION_NAME_RE = re.compile(r"(\d+)(.*)")
def get_sql_for_new_models(apps=None, using=DEFAULT_DB_ALIAS):
@@ -108,3 +117,125 @@ def model_installed(model):
statements.extend(index_sql)
return statements
+
+
+def get_capable_databases():
+ """
+ Returns a list of databases which are capable of supporting
+ Nashvegas (based on their routing configuration).
+ """
+ for database in connections:
+ if router.allow_syncdb(database, Migration):
+ yield database
+
+
+def get_file_list(path, max_depth=1, cur_depth=0):
+ """
+ Recursively returns a list of all files up to ``max_depth``
+ in a directory.
+ """
+ if os.path.exists(path):
+ for name in os.listdir(path):
+ if name.startswith('.'):
+ continue
+
+ full_path = os.path.join(path, name)
+ if os.path.isdir(full_path):
+ if cur_depth == max_depth:
+ continue
+
+ for result in get_file_list(full_path, max_depth, cur_depth + 1):
+ yield result
+
+ else:
+ yield full_path
+
+
+def get_applied_migrations(databases=None):
+ """
+ Returns a dictionary containing lists of all applied migrations
+ where the key is the database alias.
+ """
+ if not databases:
+ databases = get_capable_databases()
+ else:
+ # We only loop through databases that are listed as "capable"
+ all_databases = list(get_capable_databases())
+ databases = list(itertools.ifilter(lambda x: x in all_databases, databases))
+
+ results = defaultdict(list)
+ for db in databases:
+ for x in Migration.objects.using(db).order_by("migration_label"):
+ results[db].append(x.migration_label)
+
+ return results
+
+
+def get_all_migrations(path, databases=None):
+ """
+ Returns a dictionary of database => [migrations] representing all
+ migrations contained in ``path``.
+ """
+ # database: [(number, full_path)]
+ possible_migrations = defaultdict(list)
+
+ try:
+ in_directory = sorted(get_file_list(path))
+ except OSError:
+ import traceback
+ print "An error occurred while reading migrations from %r:" % path
+ traceback.print_exc()
+ return {}
+
+ # Iterate through our results and discover which migrations are actually runnable
+ for full_path in in_directory:
+ child_path, script = os.path.split(full_path)
+ name, ext = os.path.splitext(script)
+
+ # the database component is default if this is in the root directory
+ # is <directory> if in a subdirectory
+ if path == child_path:
+ db = DEFAULT_DB_ALIAS
+ else:
+ db = os.path.split(child_path)[-1]
+
+ # filter by database if set
+ if databases and db not in databases:
+ continue
+
+ match = MIGRATION_NAME_RE.match(name)
+ if match is None:
+ raise MigrationError("Invalid migration file prefix %r "
+ "(must begin with a number)" % name)
+
+ number = int(match.group(1))
+ if ext in [".sql", ".py"]:
+ possible_migrations[db].append((number, full_path))
+
+ return possible_migrations
+
+
+def get_pending_migrations(path, databases=None, stop_at=None):
+ """
+ Returns a dictionary of database => [migrations] representing all pending
+ migrations.
+ """
+ if stop_at is None:
+ stop_at = float("inf")
+
+ # database: [(number, full_path)]
+ possible_migrations = get_all_migrations(path, databases)
+ # database: [full_path]
+ applied_migrations = get_applied_migrations(databases)
+ # database: [full_path]
+ to_execute = defaultdict(list)
+
+ for database, scripts in possible_migrations.iteritems():
+ applied = applied_migrations[database]
+ pending = to_execute[database]
+ for number, migration in scripts:
+ path, script = os.path.split(migration)
+ if script not in applied and number <= stop_at:
+ pending.append(script)
+
+ return dict((k, v) for k, v in to_execute.iteritems() if v)
View
56 runtests.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env python
+import sys
+from os.path import dirname, abspath
+
+sys.path.insert(0, dirname(abspath(__file__)))
+
+from django.conf import settings
+
+if not settings.configured:
+ settings.configure(
+ DATABASES={
+ 'default': {
+ 'ENGINE': 'django.db.backends.sqlite3',
+ 'NAME': ':memory:',
+ },
+ 'other': {
+ 'ENGINE': 'django.db.backends.sqlite3',
+ 'NAME': ':memory:',
+ },
+ },
+ INSTALLED_APPS=[
+ 'nashvegas',
+ 'tests',
+ ],
+ ROOT_URLCONF='',
+ DEBUG=False,
+ SITE_ID=1,
+ TEMPLATE_DEBUG=True,
+ )
+
+from django_nose import NoseTestSuiteRunner
+
+
+def runtests(*test_args, **kwargs):
+ if 'south' in settings.INSTALLED_APPS:
+ from south.management.commands import patch_for_test_db_setup
+ patch_for_test_db_setup()
+
+ if not test_args:
+ test_args = ['tests']
+
+ kwargs.setdefault('interactive', False)
+
+ test_runner = NoseTestSuiteRunner(**kwargs)
+
+ failures = test_runner.run_tests(test_args)
+ sys.exit(failures)
+
+if __name__ == '__main__':
+ from optparse import OptionParser
+ parser = OptionParser()
+ parser.add_option('--verbosity', dest='verbosity', action='store', default=1, type=int)
+ parser.add_options(NoseTestSuiteRunner.options)
+ (options, args) = parser.parse_args()
+
+ runtests(*args, **options.__dict__)
View
8 setup.py
@@ -4,9 +4,15 @@
VERSION = __import__("nashvegas").__version__
+
def read(*path):
return open(os.path.join(os.path.abspath(os.path.dirname(__file__)), *path)).read()
+tests_require = [
+ 'nose>=1.1.2',
+ 'django-nose>=0.1.3',
+]
+
setup(
name="nashvegas",
@@ -19,6 +25,8 @@ def read(*path):
maintainer_email="paltman@gmail.com",
url="http://github.com/paltman/nashvegas/",
packages=find_packages(),
+ tests_require=tests_require,
+ test_suite='runtests.runtests',
zip_safe=False,
classifiers=[
"Development Status :: 4 - Beta",
View
0  tests/__init__.py
No changes.
View
0  tests/fixtures/__init__.py
No changes.
View
0  tests/fixtures/migrations/__init__.py
No changes.
View
0  tests/fixtures/migrations/legacy/0001.sql
No changes.
View
0  tests/fixtures/migrations/legacy/0002.sql
No changes.
View
0  tests/fixtures/migrations/legacy/0003.py
No changes.
View
0  tests/fixtures/migrations/legacy/0004_with_label.sql
No changes.
View
0  tests/fixtures/migrations/multidb/default/0001.sql
No changes.
View
0  tests/fixtures/migrations/multidb/default/0002_foo.py
No changes.
View
0  tests/fixtures/migrations/multidb/other/0001.sql
No changes.
View
0  tests/fixtures/migrations/multidb/other/0002_bar.sql
No changes.
View
0  tests/nashvegas/__init__.py
No changes.
View
0  tests/nashvegas/utils/__init__.py
No changes.
View
45 tests/nashvegas/utils/tests.py
@@ -0,0 +1,45 @@
+from django.test import TestCase
+from nashvegas.utils import get_capable_databases, get_all_migrations, \
+ get_file_list
+
+from os.path import join, dirname
+
+mig_root = join(dirname(__import__('tests', {}, {}, [], -1).__file__), 'fixtures', 'migrations')
+
+
+class GetCapableDatabasesTest(TestCase):
+ def test_default_routing(self):
+ results = list(get_capable_databases())
+ self.assertEquals(len(results), 2)
+ self.assertTrue('default' in results)
+ self.assertTrue('other' in results)
+
+
+class GetFileListTest(TestCase):
+ def test_recursion(self):
+ path = join(mig_root, 'multidb')
+ results = list(get_file_list(path))
+ self.assertEquals(len(results), 4)
+ self.assertTrue(join(path, 'default', '0001.sql') in results)
+ self.assertTrue(join(path, 'default', '0002_foo.py') in results)
+ self.assertTrue(join(path, 'other', '0001.sql') in results)
+ self.assertTrue(join(path, 'other', '0002_bar.sql') in results)
+
+
+class GetAllMigrationsTest(TestCase):
+ def test_multidb(self):
+ path = join(mig_root, 'multidb')
+ results = dict(get_all_migrations(path))
+ self.assertEquals(len(results), 2)
+ self.assertTrue('default' in results)
+ self.assertTrue('other' in results)
+
+ default = results['default']
+ self.assertEquals(len(default), 2)
+ self.assertTrue((1, join(path, 'default', '0001.sql')) in default)
+ self.assertTrue((2, join(path, 'default', '0002_foo.py')) in default)
+
+ other = results['other']
+ self.assertEquals(len(other), 2)
+ self.assertTrue((1, join(path, 'other', '0001.sql')) in other)
+ self.assertTrue((2, join(path, 'other', '0002_bar.sql')) in other)
Please sign in to comment.
Something went wrong with that request. Please try again.