fix: make checkpoint updates synchronous (#2871)
This PR makes checkpoint-driven DAG updates synchronous: `DAG.postprocess`, `update_incomplete_input_expand_jobs`, `update_queue_input_jobs`, `update_checkpoint_dependencies`, `postprocess_after_update`, and `finish` become regular (non-`async`) methods that drive their remaining coroutines to completion via `async_run`, job replacements are collected and applied in one step through a new `JobReplacer` helper, and the scheduler and workflow call these methods directly instead of wrapping them in `async_run` themselves (see the sketch below).

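For context, here is a minimal, self-contained sketch of the pattern the diff applies — a synchronous entry point that batches job replacements and runs the awaiting internally. This is not the actual Snakemake implementation: `FakeDAG` is a hypothetical stand-in for `snakemake.dag.DAG`, and `async_run` is assumed to behave roughly like `asyncio.run`.

```python
import asyncio
from typing import Dict, Optional


def async_run(coro):
    # Assumption: snakemake.common.async_run drives a coroutine to
    # completion from synchronous code, roughly like asyncio.run.
    return asyncio.run(coro)


class FakeDAG:
    # Hypothetical stand-in for snakemake.dag.DAG, just enough to run the sketch.
    async def replace_job(self, job, newjob, recursive=False):
        print(f"replacing {job} with {newjob}")


class JobReplacer:
    """Collect job replacements and apply them in one synchronous step."""

    def __init__(self, dag, jobs: Optional[Dict[str, str]] = None):
        self.dag = dag
        self.jobs = jobs or dict()

    def add(self, job, newjob):
        self.jobs[job] = newjob

    def process(self) -> bool:
        # Mirrors the diff: returns True if anything was replaced, so the
        # caller knows whether another DAG postprocess pass is needed.
        if not self.jobs:
            return False

        async def replace():
            for job, newjob in self.jobs.items():
                await self.dag.replace_job(job, newjob, recursive=False)

        async_run(replace())
        return True


replacer = JobReplacer(FakeDAG())
replacer.add("job_a", "job_a_updated")
updated = replacer.process()  # synchronous call; all awaiting happens inside async_run
```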
### QC

* [x] The PR contains a test case for the changes or the changes are
already covered by an existing test case.
* [x] The documentation (`docs/`) is updated to reflect the changes or
this is not necessary (e.g. if the change neither modifies the
language nor the behavior or functionalities of Snakemake).
johanneskoester committed May 11, 2024
1 parent 25a361b commit b0e7ebd
Showing 3 changed files with 117 additions and 80 deletions.
108 changes: 71 additions & 37 deletions snakemake/dag.py
@@ -13,7 +13,7 @@
import tarfile
import textwrap
import time
from typing import Iterable, Optional, Set, Union
from typing import Dict, Iterable, Optional, Set, Union
import uuid
import subprocess
from collections import Counter, defaultdict, deque, namedtuple
@@ -31,6 +31,7 @@
from snakemake import workflow as _workflow
from snakemake.common import (
ON_WINDOWS,
async_run,
group_into_chunks,
is_local_file,
)
@@ -1506,18 +1507,20 @@ def dfs(job, group, visited, outside_jobs, outside_jobs_all, skip_this):
rule=job.rule,
)

async def update_incomplete_input_expand_jobs(self):
def update_incomplete_input_expand_jobs(self):
"""Update (re-evaluate) all jobs which have incomplete input file expansions.
only filled in the second pass of postprocessing.
"""
updated = False
for job in list(self.jobs):
if job.incomplete_input_expand:
newjob = job.updated()
await self.replace_job(job, newjob, recursive=False)
updated = True
return updated
replacer = JobReplacer(
self,
jobs={
job: job.updated()
for job in list(self.jobs)
if job.incomplete_input_expand
},
)
return replacer.process()

def update_ready(self, jobs=None):
"""Update information whether a job is ready to execute.
@@ -1566,7 +1569,7 @@ def get_jobs_or_groups(self):
visited_groups.add(group)
yield group

async def postprocess(
def postprocess(
self, update_needrun=True, update_incomplete_input_expand_jobs=True
):
"""Postprocess the DAG. This has to be invoked after any change to the
@@ -1576,19 +1579,19 @@
if update_needrun:
self.update_container_imgs()
self.update_conda_envs()
await self.update_needrun()
async_run(self.update_needrun())
self.update_priority()
self.handle_pipes_and_services()
self.handle_update_flags()
self.update_groups()

if update_incomplete_input_expand_jobs:
updated = await self.update_incomplete_input_expand_jobs()
updated = self.update_incomplete_input_expand_jobs()
if updated:
# run a second pass, some jobs have been updated
# with potentially new input files that have depended
# on group ids.
await self.postprocess(
self.postprocess(
update_needrun=True, update_incomplete_input_expand_jobs=False
)

@@ -1722,23 +1725,24 @@ def _ready(self, job):
(self._n_until_ready[job] - n_internal_deps(job)) == 0 for job in group
)

async def update_queue_input_jobs(self):
def update_queue_input_jobs(self):
updated = False
if self.has_unfinished_queue_input_jobs():
logger.info("Updating jobs with queue input...")
replacer = JobReplacer(self)
for job in self.queue_input_jobs:
if (
job.has_queue_input()
and job not in self._jobs_with_finished_queue_input
):
newjob = job.updated()
if newjob.input != job.input:
await self.replace_job(job, newjob, recursive=False)
updated = True
if updated and not job.has_unfinished_queue_input():
self._jobs_with_finished_queue_input.add(job)
replacer.add(job, newjob)
if not job.has_unfinished_queue_input():
self._jobs_with_finished_queue_input.add(job)
updated = replacer.process()
if updated:
await self.postprocess_after_update()
self.postprocess_after_update()
# reset queue_input_jobs such that it is recomputed next time
self._queue_input_jobs = None
return updated
@@ -1756,7 +1760,7 @@ def queue_input_jobs(self):
def has_unfinished_queue_input_jobs(self):
return any(job.has_unfinished_queue_input() for job in self.queue_input_jobs)

async def update_checkpoint_dependencies(self, jobs=None):
def update_checkpoint_dependencies(self, jobs=None):
"""Update dependencies of checkpoints."""
updated = False
self.update_checkpoint_outputs()
@@ -1769,32 +1773,39 @@ async def update_checkpoint_dependencies(self, jobs=None):
# re-evaluate depending jobs, replace and update DAG
# Note: even for touch, this needs retrieval from storage!
if depending:
try:
async with asyncio.TaskGroup() as tg:
for f in job.output:
if f.is_storage:
tg.create_task(f.retrieve_from_storage())
except ExceptionGroup as e:
raise WorkflowError("Failed to retrieve checkpoint output.", e)

async def retrieve():
try:
async with asyncio.TaskGroup() as tg:
for f in job.output:
if f.is_storage:
tg.create_task(f.retrieve_from_storage())
except ExceptionGroup as e:
raise WorkflowError(
"Failed to retrieve checkpoint output.", e
)

async_run(retrieve())
all_depending.extend(depending)
replacer = JobReplacer(self)
for j in all_depending:
logger.debug(f"Updating job {j}.")
newjob = j.updated()
await self.replace_job(j, newjob, recursive=False)
updated = True
replacer.add(j, newjob)
updated = replacer.process()
if updated:
await self.postprocess_after_update()
self.postprocess_after_update()
return updated

async def postprocess_after_update(self):
await self.postprocess()
def postprocess_after_update(self):
self.postprocess()
shared_input_output = (
SharedFSUsage.INPUT_OUTPUT in self.workflow.storage_settings.shared_fs_usage
)
if (
self.workflow.is_main_process and shared_input_output
) or self.workflow.remote_exec:
await self.retrieve_storage_inputs()
async_run(self.retrieve_storage_inputs())

def register_running(self, jobs):
self._running.update(jobs)
@@ -1806,7 +1817,7 @@ def register_running(self, jobs):
# already gone
pass

async def finish(self, job, update_checkpoint_dependencies=True):
def finish(self, job, update_checkpoint_dependencies=True):
"""Finish a given job (e.g. remove from ready jobs, mark depending jobs
as ready)."""

@@ -1833,7 +1844,7 @@ async def finish(self, job, update_checkpoint_dependencies=True):

updated_dag = False
if update_checkpoint_dependencies:
updated_dag = await self.update_checkpoint_dependencies(jobs)
updated_dag = self.update_checkpoint_dependencies(jobs)

depending = [
j
@@ -1871,8 +1882,11 @@ async def finish(self, job, update_checkpoint_dependencies=True):
# temp files.
# TODO: we maybe could be more accurate and determine whether there is a
# checkpoint that depends on the temp file.
for job in jobs:
await self.handle_temp(job)
async def handle_temp():
for job in jobs:
await self.handle_temp(job)

async_run(handle_temp())

return potential_new_ready_jobs

@@ -2852,3 +2866,23 @@ def __hash__(self):

def merge(self, other):
self.id = other.id


class JobReplacer:
def __init__(self, dag, jobs: Optional[Dict[Job, Job]] = None):
self.jobs = jobs or dict()
self.dag = dag

def add(self, job, newjob):
self.jobs[job] = newjob

def process(self):
if self.jobs:

async def replace():
for job, newjob in self.jobs.items():
await self.dag.replace_job(job, newjob, recursive=False)

async_run(replace())
return True
return False
85 changes: 44 additions & 41 deletions snakemake/scheduler.py
@@ -326,73 +326,76 @@ def schedule(self):
def _finish_jobs(self):
# must be called from within lock
# clear the global tofinish such that parallel calls do not interfere
async def postprocess():
for job in self._tofinish:
if not self.workflow.dryrun:
try:
if self.workflow.exec_mode == ExecMode.DEFAULT:
await job.postprocess(
for job in self._tofinish:
if not self.workflow.dryrun:
try:
if self.workflow.exec_mode == ExecMode.DEFAULT:
async_run(
job.postprocess(
store_in_storage=not self.touch,
handle_log=True,
handle_touch=not self.touch,
ignore_missing_output=self.touch,
)
elif self.workflow.exec_mode == ExecMode.SUBPROCESS:
await job.postprocess(
)
elif self.workflow.exec_mode == ExecMode.SUBPROCESS:
async_run(
job.postprocess(
store_in_storage=False,
handle_log=True,
handle_touch=True,
)
else:
await job.postprocess(
)
else:
async_run(
job.postprocess(
# storage upload will be done after all jobs of
# this remote job (e.g. in case of group) are finished
# DAG.store_storage_outputs()
store_in_storage=False,
handle_log=True,
handle_touch=True,
)
except (RuleException, WorkflowError) as e:
# if an error occurs while processing job output,
# we do the same as in case of errors during execution
print_exception(e, self.workflow.linemaps)
await job.postprocess(error=True)
self._handle_error(job, postprocess_job=False)
continue

if self.handle_job_success:
self.get_executor(job).handle_job_success(job)

if self.update_resources:
# normal jobs have len=1, group jobs have len>1
self.finished_jobs += len(job)
logger.debug(
f"jobs registered as running before removal {self.running}"
)
self.running.remove(job)
self._free_resources(job)
)
except (RuleException, WorkflowError) as e:
# if an error occurs while processing job output,
# we do the same as in case of errors during execution
print_exception(e, self.workflow.linemaps)
async_run(job.postprocess(error=True))
self._handle_error(job, postprocess_job=False)
continue

if self.print_progress:
if job.is_group():
for j in job:
logger.job_finished(jobid=j.jobid)
else:
logger.job_finished(jobid=job.jobid)
self.progress()
if self.handle_job_success:
self.get_executor(job).handle_job_success(job)

await self.workflow.dag.finish(
job,
update_checkpoint_dependencies=self.update_checkpoint_dependencies,
if self.update_resources:
# normal jobs have len=1, group jobs have len>1
self.finished_jobs += len(job)
logger.debug(
f"jobs registered as running before removal {self.running}"
)
self.running.remove(job)
self._free_resources(job)

async_run(postprocess())
if self.print_progress:
if job.is_group():
for j in job:
logger.job_finished(jobid=j.jobid)
else:
logger.job_finished(jobid=job.jobid)
self.progress()

self.workflow.dag.finish(
job,
update_checkpoint_dependencies=self.update_checkpoint_dependencies,
)
self._tofinish.clear()

def update_queue_input_jobs(self):
currtime = time.time()
if currtime - self._last_update_queue_input_jobs >= 10:
self._last_update_queue_input_jobs = currtime
async_run(self.workflow.dag.update_queue_input_jobs())
self.workflow.dag.update_queue_input_jobs()

def _error_jobs(self):
# must be called from within lock
4 changes: 2 additions & 2 deletions snakemake/workflow.py
@@ -1035,7 +1035,7 @@ def container_cleanup_images(self):
def _build_dag(self):
logger.info("Building DAG of jobs...")
async_run(self.dag.init())
async_run(self.dag.update_checkpoint_dependencies())
self.dag.update_checkpoint_dependencies()

def execute(
self,
@@ -1081,7 +1081,7 @@ def execute(
self._build_dag()

with self.persistence.lock():
async_run(self.dag.postprocess(update_needrun=False))
self.dag.postprocess(update_needrun=False)
if not self.dryrun:
# deactivate IOCache such that from now on we always get updated
# size, existence and mtime information
