Skip to content

Commit

Permalink
Automatically check schema version when instantiating DbTaskHistory.
Browse files Browse the repository at this point in the history
Luigi will exit if schema version does not match source version.
  • Loading branch information
Stephen Pascoe authored and Stephen Pascoe committed Dec 14, 2015
1 parent d93b9ba commit 7c2873c
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 43 deletions.
47 changes: 47 additions & 0 deletions luigi/db_task_history.py
Expand Up @@ -59,6 +59,8 @@ class DbTaskHistory(task_history.TaskHistory):
Task History that writes to a database using sqlalchemy.
Also has methods for useful db queries.
"""
CURRENT_SOURCE_VERSION = 1

@contextmanager
def _session(self, session=None):
if session:
Expand All @@ -81,6 +83,15 @@ def __init__(self):
Base.metadata.create_all(self.engine)
self.tasks = {} # task_id -> TaskRecord

self._check_version()

def _check_version(self):
    """
    Refuse to continue when the database schema and this source disagree.

    Reads the stored schema version with ``get_schema_version`` and compares
    it with ``CURRENT_SOURCE_VERSION``.

    :raises SystemExit: if the two versions differ; the message tells the
        operator to run luigi-migrate.
    """
    expected = self.CURRENT_SOURCE_VERSION
    with self._session() as db:
        found = get_schema_version(db)
        if found == expected:
            return

        # Raised while the session block is still open, at the point of comparison.
        message = ("DbTaskHistory source version {} is not consistent with schema version {}. "
                   "Please run luigi-migrate.").format(expected, found)
        raise SystemExit(message)

def task_scheduled(self, task):
htask = self._get_task(task, status=PENDING)
self._add_task_event(htask, TaskEvent(event_name=PENDING, ts=datetime.datetime.now()))
Expand Down Expand Up @@ -235,6 +246,42 @@ def __repr__(self):
return "TaskRecord(name=%s, host=%s)" % (self.name, self.host)


# ---------------------------------------------------------------------------
# Database schema management
#
# Table with a single integer column, used to record the schema version of
# the task-history database. Registered on the same metadata as the models.
version_table = sqlalchemy.Table(
    'version',
    Base.metadata,
    sqlalchemy.Column('version', sqlalchemy.Integer),
)


def get_schema_version(session):
    """
    Return the schema version recorded in the ``version`` table.

    If the table holds no row yet, a row with version 0 is inserted and
    committed, and 0 is returned.

    :param session: SQLAlchemy session used to execute the statements.
    :return: the ``version`` value of the first row in the table.
    """
    version_row = session.execute(version_table.select()).first()
    if version_row is not None:
        return version_row[0]

    # No row yet: record version 0, the key of the first entry in version_funcs.
    # The generative .values() form is used because the ``values=`` constructor
    # keyword is deprecated in SQLAlchemy 1.4 and removed in 2.0.
    session.execute(version_table.insert().values(version=0))
    session.commit()
    return 0


def set_schema_version(version, session):
    """
    Store ``version`` in the ``version`` table.

    The UPDATE has no WHERE clause, so it rewrites every row of the table.
    The statement is executed but not committed here; committing is left to
    the caller.

    :param version: schema version number to record.
    :param session: SQLAlchemy session used to execute the statement.
    """
    # The generative .values() form is used because the ``values=`` constructor
    # keyword is deprecated in SQLAlchemy 1.4 and removed in 2.0.
    session.execute(version_table.update().values(version=version))


# ---------------------------------------------------------------------------
# Version migration functions

def version_1(session):
    """
    Add task_id column to tasks table. Required to make a robust connection between
    task_id and a TaskRecord.

    Also creates the ``ix_task_id`` index on the new column. Nothing is
    committed here.

    :param session: SQLAlchemy session used to execute the DDL statements.
    """
    # Wrap the raw SQL in text(): SQLAlchemy 2.0 no longer accepts plain
    # strings in execute(), while text() is accepted by older releases too.
    session.execute(sqlalchemy.text('ALTER TABLE tasks ADD COLUMN task_id VARCHAR(200)'))
    session.execute(sqlalchemy.text('CREATE INDEX ix_task_id ON tasks (task_id)'))


# Migration dispatch table, keyed by the schema version currently in the database.
# version_funcs[db_version](session) applies the migration that takes the schema
# from db_version to db_version + 1.
version_funcs = {0: version_1}
54 changes: 11 additions & 43 deletions luigi/tools/migrate.py
Expand Up @@ -10,29 +10,11 @@
"""

from __future__ import print_function
import sys
from luigi.db_task_history import DbTaskHistory, version_table
from luigi import configuration

CURRENT_SOURCE_VERSION = 1


# ---------------------------------------------------------------------------
# Version migration functions

def version_1(session):
"""
Add task_id column to tasks table. Required to make a robust connection between
task_id and a TaskRecord.
"""

session.execute('ALTER TABLE tasks ADD COLUMN task_id VARCHAR(200)')
session.execute('CREATE INDEX ix_task_id ON tasks (task_id)')

import sys

# version_func[db-version]() --> next version
version_funcs = {0: version_1}
from luigi import configuration
from luigi.db_task_history import DbTaskHistory, get_schema_version, set_schema_version, version_funcs


# ---------------------------------------------------------------------------
Expand All @@ -46,27 +28,28 @@ def main():
connection_string = config.get('task_history', 'db_connection')

print('Luigi db_task_history migration tool')
db_version = get_version(session)
db_version = get_schema_version(session)

if db_version == CURRENT_SOURCE_VERSION:
source_version = task_history.CURRENT_SOURCE_VERSION
if db_version == source_version:
print('Your schema version is up to date')
sys.exit(0)
elif db_version > CURRENT_SOURCE_VERSION:
elif db_version > source_version:
print('ERROR: Your schema version is greater than the source version ({}>{})'.format(db_version,
CURRENT_SOURCE_VERSION))
source_version))
sys.exit(1)

print('Migration required. '
'Your schema version is less than the source version ({}<{})'.format(db_version,
CURRENT_SOURCE_VERSION))
source_version))

print('******************************************************')
print('** WARNING Do not proceed without a database backup. ')
print('******************************************************')
print()

if query_yes_no('Do you want to migrate database {} now?'.format(connection_string), default='no'):
do_migration(db_version, CURRENT_SOURCE_VERSION, session)
do_migration(db_version, source_version, session)
else:
sys.exit(1)

Expand All @@ -79,24 +62,9 @@ def do_migration(from_version, to_version, session):
new_version = v + 1
print('Migrating version {} -> {}'.format(v, new_version))
version_funcs[v](session)
set_version(new_version, session)
session.commit()


def get_version(session):
version_row = session.execute(version_table.select()).first()
if version_row is None:
session.execute(version_table.insert(values={'version': 0}))
set_schema_version(new_version, session)
session.commit()

return 0
else:
return version_row[0]


def set_version(version, session):
session.execute(version_table.update(values={'version': version}))


def query_yes_no(question, default="yes"):
"""Ask a yes/no question via raw_input() and return their answer.
Expand Down

0 comments on commit 7c2873c

Please sign in to comment.