From ebadbeb44651d866df23b39ec9b397c000472de3 Mon Sep 17 00:00:00 2001 From: Brian Rosner Date: Sat, 4 Sep 2010 12:45:13 -0600 Subject: [PATCH] changed interface to --seed a little to take a migration to stop at as opposed to a list of migrations to seed --- nashvegas/management/commands/upgradedb.py | 35 +++++++++++++++------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/nashvegas/management/commands/upgradedb.py b/nashvegas/management/commands/upgradedb.py index 1ae0abb..912a3f9 100644 --- a/nashvegas/management/commands/upgradedb.py +++ b/nashvegas/management/commands/upgradedb.py @@ -8,7 +8,7 @@ from django.db.models import get_model from django.conf import settings from django.core.management import call_command -from django.core.management.base import BaseCommand +from django.core.management.base import BaseCommand, CommandError from django.core.management.sql import emit_post_sync_signal from nashvegas.models import Migration @@ -18,6 +18,10 @@ sys.path.append("migrations") +class MigrationError(Exception): + pass + + class Command(BaseCommand): option_list = BaseCommand.option_list + ( @@ -44,7 +48,10 @@ class Command(BaseCommand): help="The path to the database migration scripts.")) help = "Upgrade database." - def _filter_down(self): + def _filter_down(self, stop_at=None): + + if stop_at is None: + stop_at = float("inf") applied = [] to_execute = [] @@ -61,11 +68,16 @@ def _filter_down(self): applied.sort() for script in in_directory: - if os.path.splitext(script)[-1] in [".sql", ".py"]: - scripts_in_directory.append(script) + name, ext = os.path.splitext(script) + try: + number = int(name.split("_")[0]) + except ValueError: + raise MigrationError("Invalid migration file prefix (must begin with a number)") + if ext in [".sql", ".py"]: + scripts_in_directory.append((number, script)) - for script in scripts_in_directory: - if script not in applied: + for number, script in scripts_in_directory: + if script not in applied and number < stop_at: to_execute.append(script) except OSError, e: print str(e) @@ -183,10 +195,13 @@ def execute_migrations(self): database=self.db ) - def seed_migrations(self): - migrations = [os.path.join(self.path, m) for m in self._filter_down()] - if len(self.args) > 0: - migrations = [arg for arg in self.args if not arg.endswith(".pyc")] + def seed_migrations(self, stop_at=None): + # @@@ the command-line interface needs to be re-thinked + try: + stop_at = int(self.args[0]) + except ValueError: + raise CommandError("Invalid --seed migration") + migrations = [os.path.join(self.path, m) for m in self._filter_down(stop_at=stop_at)] for migration in migrations: m, created = Migration.objects.get_or_create( migration_label=os.path.basename(migration),