Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Task id hashing #1444

Merged
merged 6 commits into from Dec 16, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 1 addition & 8 deletions luigi/contrib/sge.py
Expand Up @@ -110,13 +110,6 @@ def output(self):
POLL_TIME = 5 # decided to hard-code rather than configure here


def _clean_task_id(task_id):
"""Clean the task ID so qsub allows it as a "name" string."""
for c in ['\n', '\t', '\r', '/', ':', '@', '\\', '*', '?', ',', '=', ' ', '(', ')']:
task_id = task_id.replace(c, '-')
return task_id


def _parse_qstat_state(qstat_out, job_id):
"""Parse "state" column from `qstat` output for given job_id

Expand Down Expand Up @@ -200,7 +193,7 @@ def _init_local(self):
# Set up temp folder in shared directory (trim to max filename length)
base_tmp_dir = self.shared_tmp_dir
random_id = '%016x' % random.getrandbits(64)
folder_name = _clean_task_id(self.task_id) + '-' + random_id
folder_name = self.task_id + '-' + random_id
self.tmp_dir = os.path.join(base_tmp_dir, folder_name)
max_filename_length = os.fstatvfs(0).f_namemax
self.tmp_dir = self.tmp_dir[:max_filename_length]
Expand Down
2 changes: 1 addition & 1 deletion luigi/contrib/simulate.py
Expand Up @@ -98,7 +98,7 @@ def done(self):
"""
Creates temporary file to mark the task as `done`
"""
logger.info('Marking %s as done', self.task_id)
logger.info('Marking %s as done', self)

fn = self.get_path()
os.makedirs(os.path.dirname(fn), exist_ok=True)
Expand Down
24 changes: 23 additions & 1 deletion luigi/db_task_history.py
Expand Up @@ -49,6 +49,7 @@
import sqlalchemy.ext.declarative
import sqlalchemy.orm
import sqlalchemy.orm.collections
from sqlalchemy.engine import reflection
Base = sqlalchemy.ext.declarative.declarative_base()

logger = logging.getLogger('luigi-interface')
Expand All @@ -59,6 +60,8 @@ class DbTaskHistory(task_history.TaskHistory):
Task History that writes to a database using sqlalchemy.
Also has methods for useful db queries.
"""
CURRENT_SOURCE_VERSION = 1

@contextmanager
def _session(self, session=None):
if session:
Expand All @@ -81,6 +84,8 @@ def __init__(self):
Base.metadata.create_all(self.engine)
self.tasks = {} # task_id -> TaskRecord

_upgrade_schema(self.engine)

def task_scheduled(self, task):
htask = self._get_task(task, status=PENDING)
self._add_task_event(htask, TaskEvent(event_name=PENDING, ts=datetime.datetime.now()))
Expand Down Expand Up @@ -117,7 +122,7 @@ def _find_or_create_task(self, task):
raise Exception("Task with record_id, but no matching Task record!")
yield (task_record, session)
else:
task_record = TaskRecord(name=task.task_family, host=task.host)
task_record = TaskRecord(task_id=task._task.id, name=task.task_family, host=task.host)
for (k, v) in six.iteritems(task.parameters):
task_record.parameters[k] = TaskParameter(name=k, value=v)
session.add(task_record)
Expand Down Expand Up @@ -219,6 +224,7 @@ class TaskRecord(Base):
"""
__tablename__ = 'tasks'
id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True)
task_id = sqlalchemy.Column(sqlalchemy.String(200), index=True)
name = sqlalchemy.Column(sqlalchemy.String(128), index=True)
host = sqlalchemy.Column(sqlalchemy.String(128))
parameters = sqlalchemy.orm.relationship(
Expand All @@ -232,3 +238,19 @@ class TaskRecord(Base):

def __repr__(self):
return "TaskRecord(name=%s, host=%s)" % (self.name, self.host)


def _upgrade_schema(engine):
    """
    Ensure the database schema is up to date with the codebase.

    Currently performs a single upgrade step: if the ``tasks`` table has no
    ``task_id`` column, the column and an index on it are added.

    :param engine: SQLAlchemy engine of the underlying database.
    """
    inspector = reflection.Inspector.from_engine(engine)

    # Upgrade 1. Add task_id column and index to tasks
    existing_columns = set(col['name'] for col in inspector.get_columns('tasks'))
    if 'task_id' in existing_columns:
        # Schema already current; do not open a connection at all.
        return

    logger.warning('Upgrading DbTaskHistory schema: Adding tasks.task_id')
    conn = engine.connect()
    try:
        conn.execute('ALTER TABLE tasks ADD COLUMN task_id VARCHAR(200)')
        conn.execute('CREATE INDEX ix_task_id ON tasks (task_id)')
    finally:
        # Always hand the connection back, even if a statement fails.
        conn.close()
10 changes: 4 additions & 6 deletions luigi/scheduler.py
Expand Up @@ -912,12 +912,10 @@ def _traverse_graph(self, root_task_id, seen=None, dep_func=None):
if task is None or not task.family:
logger.warn('Missing task for id [%s]', task_id)

# try to infer family and params from task_id
try:
family, _, param_str = task_id.rstrip(')').partition('(')
params = dict(param.split('=') for param in param_str.split(', '))
except BaseException:
family, params = '', {}
# NOTE : If a dependency is missing from self._state there is no way to deduce the
# task family and parameters.

family, params = UNKNOWN, {}
serialized[task_id] = {
'deps': [],
'status': UNKNOWN,
Expand Down
2 changes: 1 addition & 1 deletion luigi/static/visualiser/index.html
Expand Up @@ -467,7 +467,7 @@ <h3 class="box-title">{{name}}</h3>
<div class="container-fluid">
<div class="form-group col-md-6 col-sm-4">
<form class="form-inline" id="loadTaskForm">
<input type="text" class="search-query form-control" placeholder="TaskId(param1=val1,param2=val2)">
<input type="text" class="search-query form-control" placeholder="TaskId">
<button type="submit" class="btn btn-default form-control">Show task details</button>
</form>
</div>
Expand Down
5 changes: 2 additions & 3 deletions luigi/static/visualiser/js/visualiserApp.js
Expand Up @@ -58,9 +58,8 @@ function visualiserApp(luigi) {
}

function taskToDisplayTask(task) {
var taskIdParts = /([A-Za-z0-9_]*)\(([\s\S]*)\)/.exec(task.taskId);
var taskName = taskIdParts[1];
var taskParams = taskIdParts[2];
var taskName = task.name;
var taskParams = JSON.stringify(task.params);
var displayTime = new Date(Math.floor(task.start_time*1000)).toLocaleString();
var time_running = -1;
if (task.status == "RUNNING" && "time_running" in task) {
Expand Down
48 changes: 38 additions & 10 deletions luigi/task.py
Expand Up @@ -27,6 +27,9 @@
import logging
import traceback
import warnings
import json
import hashlib
import re

from luigi import six

Expand Down Expand Up @@ -253,14 +256,23 @@ def __init__(self, *args, **kwargs):
self.param_args = tuple(value for key, value in param_values)
self.param_kwargs = dict(param_values)

# Build up task id
task_id_parts = []
param_objs = dict(params)
for param_name, param_value in param_values:
if param_objs[param_name].significant:
task_id_parts.append('%s=%s' % (param_name, param_objs[param_name].serialize(param_value)))
# task_id is a concatenation of the task family, the values of the first 3 significant
# parameters (sorted by parameter name, each truncated to 16 characters) and a truncated
# md5 hash of the significant parameters serialised as canonicalised JSON.
TASK_ID_INCLUDE_PARAMS = 3
TASK_ID_TRUNCATE_PARAMS = 16
TASK_ID_TRUNCATE_HASH = 10
TASK_ID_INVALID_CHAR_REGEX = r'[^A-Za-z0-9_]'

params = self.to_str_params(only_significant=True)
param_str = json.dumps(params, separators=(',', ':'), sort_keys=True)
param_hash = hashlib.md5(param_str.encode('utf-8')).hexdigest()

param_summary = '_'.join(p[:TASK_ID_TRUNCATE_PARAMS]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should probably remove non-alphanumeric characters from param_summary

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, all non-alphanumeric characters (other than "_") are converted to "_".

for p in (params[p] for p in sorted(params)[:TASK_ID_INCLUDE_PARAMS]))
param_summary = re.sub(TASK_ID_INVALID_CHAR_REGEX, '_', param_summary)

self.task_id = '{}_{}_{}'.format(self.task_family, param_summary, param_hash[:TASK_ID_TRUNCATE_HASH])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so this doesn't include a # but some other code implies we're using a # (eg your change to luigi/static/visualiser/index.html)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've removed refs in index.html and sge.py. Shout if you see any more.


self.task_id = '%s(%s)' % (self.task_family, ', '.join(task_id_parts))
self.__hash = hash(self.task_id)

def initialized(self):
Expand All @@ -283,14 +295,15 @@ def from_str_params(cls, params_str):

return cls(**kwargs)

def to_str_params(self):
def to_str_params(self, only_significant=False):
"""
Convert all parameters to a str->str hash.
"""
params_str = {}
params = dict(self.get_params())
for param_name, param_value in six.iteritems(self.param_kwargs):
params_str[param_name] = params[param_name].serialize(param_value)
if (not only_significant) or params[param_name].significant:
params_str[param_name] = params[param_name].serialize(param_value)

return params_str

Expand Down Expand Up @@ -324,7 +337,22 @@ def __hash__(self):
return self.__hash

def __repr__(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps a docstring:

Build a task representation like "MyTask(param1=5, param2='5')"

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've tried this but the results are inconsistent in the tests. I'm seeing a mixture of param='1' and param=1 for IntParameters. I'm not sure what's going on but I think keeping str(task) the same for backward compatibility is the pragmatic choice.

return self.task_id
"""
Build a task representation like `MyTask(param1=1.5, param2='5')`
"""
params = self.get_params()
param_values = self.get_param_values(params, [], self.param_kwargs)

# Build up the "name=value" parts of the repr string
repr_parts = []
param_objs = dict(params)
for param_name, param_value in param_values:
if param_objs[param_name].significant:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that we're not sending this to the scheduler any more, maybe we should include the insignificant parameters too?

(What I'm suggesting is to remove this one line)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm undecided. Maybe people will want to keep repr() short by marking configuration parameters as insignificant? I don't use this feature so I don't really have an opinion.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

__repr__ is only used to assist debugging right? yeah in that case let's include all params

repr_parts.append('%s=%s' % (param_name, param_objs[param_name].serialize(param_value)))

task_str = '{}({})'.format(self.task_family, ', '.join(repr_parts))

return task_str

def __eq__(self, other):
return self.__class__ == other.__class__ and self.param_args == other.param_args
Expand Down
2 changes: 1 addition & 1 deletion luigi/util.py
Expand Up @@ -236,7 +236,7 @@ def get_previous_completed(task, max_steps=10):
prev = task
for _ in xrange(max_steps):
prev = previous(prev)
logger.debug("Checking if %s is complete", prev.task_id)
logger.debug("Checking if %s is complete", prev)
if prev.complete():
return prev
return None
8 changes: 4 additions & 4 deletions luigi/worker.py
Expand Up @@ -133,7 +133,7 @@ def _run_get_new_deps(self):
return new_deps

def run(self):
logger.info('[pid %s] Worker %s running %s', os.getpid(), self.worker_id, self.task.task_id)
logger.info('[pid %s] Worker %s running %s', os.getpid(), self.worker_id, self.task)

if self.random_seed:
# Need to have different random seeds if running in separate processes
Expand Down Expand Up @@ -165,13 +165,13 @@ def run(self):
if new_deps:
logger.info(
'[pid %s] Worker %s new requirements %s',
os.getpid(), self.worker_id, self.task.task_id)
os.getpid(), self.worker_id, self.task)
elif status == DONE:
self.task.trigger_event(
Event.PROCESSING_TIME, self.task, time.time() - t0)
expl = json.dumps(self.task.on_success())
logger.info('[pid %s] Worker %s done %s', os.getpid(),
self.worker_id, self.task.task_id)
self.worker_id, self.task)
self.task.trigger_event(Event.SUCCESS, self.task)

except KeyboardInterrupt:
Expand Down Expand Up @@ -581,7 +581,7 @@ def _add(self, task, is_complete):
task.trigger_event(Event.DEPENDENCY_MISSING, task)
logger.warning('Data for %s does not exist (yet?). The task is an '
'external data depedency, so it can not be run from'
' this luigi process.', task.task_id)
' this luigi process.', task)

else:
try:
Expand Down
1 change: 1 addition & 0 deletions setup.py
Expand Up @@ -70,6 +70,7 @@ def get_static_files(path):
'luigid = luigi.cmdline:luigid',
'luigi-grep = luigi.tools.luigi_grep:main',
'luigi-deps = luigi.tools.deps:main',
'luigi-migrate = luigi.tools.migrate:main'
]
},
install_requires=install_requires,
Expand Down
8 changes: 4 additions & 4 deletions test/contrib/esindex_test.py
Expand Up @@ -259,7 +259,7 @@ def will_raise():
result = self.es.search(index=MARKER_INDEX, doc_type=MARKER_DOC_TYPE,
body={'query': {'match_all': {}}})
marker_doc = result.get('hits').get('hits')[0].get('_source')
self.assertEqual('IndexingTask1()', marker_doc.get('update_id'))
self.assertEqual(task1.task_id, marker_doc.get('update_id'))
self.assertEqual(INDEX, marker_doc.get('target_index'))
self.assertEqual(DOC_TYPE, marker_doc.get('target_doc_type'))
self.assertTrue('date' in marker_doc)
Expand All @@ -286,8 +286,8 @@ def will_raise():
first = next(it)
second = next(it)
self.assertTrue(first.date < second.date)
self.assertEqual(first.update_id, 'IndexingTask1()')
self.assertEqual(second.update_id, 'IndexingTask2()')
self.assertEqual(first.update_id, task1.task_id)
self.assertEqual(second.update_id, task2.task_id)


class IndexingTask4(CopyToTestIndex):
Expand Down Expand Up @@ -333,5 +333,5 @@ def test_limited_history(self):
marker_index_document_id = task4_3.output().marker_index_document_id()
result = self.es.get(id=marker_index_document_id, index=MARKER_INDEX,
doc_type=MARKER_DOC_TYPE)
self.assertEqual('IndexingTask4(date=2002-01-01)',
self.assertEqual(task4_3.task_id,
result.get('_source').get('update_id'))
4 changes: 2 additions & 2 deletions test/contrib/redshift_test.py
Expand Up @@ -101,7 +101,7 @@ def test_s3_copy_to_table(self, mock_redshift_target, mock_copy):
# returned by S3CopyToTable.output(self).
mock_redshift_target.assert_called_with(database=task.database,
host=task.host,
update_id='DummyS3CopyToTable(table=dummy_table)',
update_id=task.task_id,
user=task.user,
table=task.table,
password=task.password)
Expand Down Expand Up @@ -139,7 +139,7 @@ def test_s3_copy_to_temp_table(self, mock_redshift_target, mock_copy):
mock_redshift_target.assert_called_once_with(
database=task.database,
host=task.host,
update_id='DummyS3CopyToTempTable(table=stage_dummy_table)',
update_id=task.task_id,
user=task.user,
table=task.table,
password=task.password,
Expand Down
2 changes: 1 addition & 1 deletion test/customized_run_test.py
Expand Up @@ -37,7 +37,7 @@ def complete(self):
return self.has_run

def run(self):
logging.debug("%s - setting has_run", self.task_id)
logging.debug("%s - setting has_run", self)
self.has_run = True


Expand Down