Skip to content

Commit

Permalink
Merge pull request #3057 from spotify/fix-not-bound-to-session
Browse files Browse the repository at this point in the history
Keep orm session open during unit test
  • Loading branch information
narape committed Mar 31, 2021
2 parents dba82cc + 6a18a50 commit 9f1c44f
Showing 1 changed file with 54 additions and 45 deletions.
99 changes: 54 additions & 45 deletions test/db_task_history_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,51 +45,58 @@ def test_task_list(self):
self.run_task(DummyTask())
self.run_task(DummyTask(foo='bar'))

tasks = list(self.history.find_all_by_name('DummyTask'))
with self.history._session() as session:
tasks = list(self.history.find_all_by_name('DummyTask', session))

self.assertEqual(len(tasks), 2)
for task in tasks:
self.assertEqual(task.name, 'DummyTask')
self.assertEqual(task.host, 'hostname')
self.assertEqual(len(tasks), 2)
for task in tasks:
self.assertEqual(task.name, 'DummyTask')
self.assertEqual(task.host, 'hostname')

def test_task_events(self):
self.run_task(DummyTask())
tasks = list(self.history.find_all_by_name('DummyTask'))
self.assertEqual(len(tasks), 1)
[task] = tasks
self.assertEqual(task.name, 'DummyTask')
self.assertEqual(len(task.events), 3)
for (event, name) in zip(task.events, [DONE, RUNNING, PENDING]):
self.assertEqual(event.event_name, name)

with self.history._session() as session:
tasks = list(self.history.find_all_by_name('DummyTask', session))
self.assertEqual(len(tasks), 1)
[task] = tasks
self.assertEqual(task.name, 'DummyTask')
self.assertEqual(len(task.events), 3)
for (event, name) in zip(task.events, [DONE, RUNNING, PENDING]):
self.assertEqual(event.event_name, name)

def test_task_by_params(self):
task1 = ParamTask('foo', 'bar')
task2 = ParamTask('bar', 'foo')

self.run_task(task1)
self.run_task(task2)
task1_record = self.history.find_all_by_parameters(task_name='ParamTask', param1='foo', param2='bar')
task2_record = self.history.find_all_by_parameters(task_name='ParamTask', param1='bar', param2='foo')
for task, records in zip((task1, task2), (task1_record, task2_record)):
records = list(records)
self.assertEqual(len(records), 1)
[record] = records
self.assertEqual(task.task_family, record.name)
for param_name, param_value in task.param_kwargs.items():
self.assertTrue(param_name in record.parameters)
self.assertEqual(str(param_value), record.parameters[param_name].value)
with self.history._session() as session:
self.run_task(task1)
self.run_task(task2)
task1_record = self.history.find_all_by_parameters(task_name='ParamTask', session=session,
param1='foo', param2='bar')
task2_record = self.history.find_all_by_parameters(task_name='ParamTask', session=session,
param1='bar', param2='foo')
for task, records in zip((task1, task2), (task1_record, task2_record)):
records = list(records)
self.assertEqual(len(records), 1)
[record] = records
self.assertEqual(task.task_family, record.name)
for param_name, param_value in task.param_kwargs.items():
self.assertTrue(param_name in record.parameters)
self.assertEqual(str(param_value), record.parameters[param_name].value)

def test_task_blank_param(self):
self.run_task(DummyTask(foo=""))

tasks = list(self.history.find_all_by_name('DummyTask'))
with self.history._session() as session:
tasks = list(self.history.find_all_by_name('DummyTask', session))

self.assertEqual(len(tasks), 1)
task_record = tasks[0]
self.assertEqual(task_record.name, 'DummyTask')
self.assertEqual(task_record.host, 'hostname')
self.assertIn('foo', task_record.parameters)
self.assertEqual(task_record.parameters['foo'].value, '')
self.assertEqual(len(tasks), 1)
task_record = tasks[0]
self.assertEqual(task_record.name, 'DummyTask')
self.assertEqual(task_record.host, 'hostname')
self.assertIn('foo', task_record.parameters)
self.assertEqual(task_record.parameters['foo'].value, '')

def run_task(self, task):
task2 = luigi.scheduler.Task(task.task_id, PENDING, [], family=task.task_family,
Expand All @@ -110,26 +117,28 @@ def setUp(self):
raise unittest.SkipTest('DBTaskHistory cannot be created: probably no MySQL available')

def test_subsecond_timestamp(self):
# Add 2 events in <1s
task = DummyTask()
self.run_task(task)
with self.history._session() as session:
# Add 2 events in <1s
task = DummyTask()
self.run_task(task)

task_record = next(self.history.find_all_by_name('DummyTask'))
print(task_record.events)
self.assertEqual(task_record.events[0].event_name, DONE)
task_record = next(self.history.find_all_by_name('DummyTask', session))
print(task_record.events)
self.assertEqual(task_record.events[0].event_name, DONE)

def test_utc_conversion(self):
from luigi.server import from_utc

task = DummyTask()
self.run_task(task)
with self.history._session() as session:
task = DummyTask()
self.run_task(task)

task_record = next(self.history.find_all_by_name('DummyTask'))
last_event = task_record.events[0]
try:
print(from_utc(str(last_event.ts)))
except ValueError:
self.fail("Failed to convert timestamp {} to UTC".format(last_event.ts))
task_record = next(self.history.find_all_by_name('DummyTask', session))
last_event = task_record.events[0]
try:
print(from_utc(str(last_event.ts)))
except ValueError:
self.fail("Failed to convert timestamp {} to UTC".format(last_event.ts))

def run_task(self, task):
task2 = luigi.scheduler.Task(task.task_id, PENDING, [],
Expand Down

0 comments on commit 9f1c44f

Please sign in to comment.