fix: fix opening of checkpoint output files from remote storage (#2873)

### QC

* [x] The PR contains a test case for the changes or the changes are
already covered by an existing test case.
* [x] The documentation (`docs/`) is updated to reflect the changes or
this is not necessary (e.g. if the change modifies neither the language nor the
behavior or functionality of Snakemake).
johanneskoester committed May 14, 2024
1 parent 21ec649 commit e7cb7fb
Showing 11 changed files with 182 additions and 145 deletions.
143 changes: 52 additions & 91 deletions snakemake/dag.py
@@ -13,7 +13,7 @@
import tarfile
import textwrap
import time
from typing import Dict, Iterable, Optional, Set, Union
from typing import Iterable, Optional, Set, Union
import uuid
import subprocess
from collections import Counter, defaultdict, deque, namedtuple
@@ -31,7 +31,6 @@
from snakemake import workflow as _workflow
from snakemake.common import (
ON_WINDOWS,
async_run,
group_into_chunks,
is_local_file,
)
@@ -176,13 +175,13 @@ def batch(self):

async def init(self, progress=False):
"""Initialise the DAG."""
for job in map(self.rule2job, self.targetrules):
for job in [await self.rule2job(rule) for rule in self.targetrules]:
job = await self.update([job], progress=progress, create_inventory=True)
self.targetjobs.add(job)

for file in self.targetfiles:
job = await self.update(
self.file2jobs(file),
await self.file2jobs(file),
file=file,
progress=progress,
create_inventory=True,
@@ -192,7 +191,7 @@ async def init(self, progress=False):
for spec in self.workflow.dag_settings.target_jobs:
job = await self.update(
[
self.new_job(
await self.new_job(
self.workflow.get_rule(spec.rulename),
wildcards_dict=spec.wildcards_dict,
)
@@ -1040,9 +1039,12 @@ async def update_(
known_producers = dict()
visited.add(job)
dependencies = self._dependencies[job]
potential_dependencies = self.collect_potential_dependencies(
job, known_producers=known_producers
)
potential_dependencies = [
res
async for res in self.collect_potential_dependencies(
job, known_producers=known_producers
)
]

missing_input = set()
producer = dict()
@@ -1523,20 +1525,18 @@ def dfs(job, group, visited, outside_jobs, outside_jobs_all, skip_this):
rule=job.rule,
)

def update_incomplete_input_expand_jobs(self):
async def update_incomplete_input_expand_jobs(self):
"""Update (re-evaluate) all jobs which have incomplete input file expansions.
only filled in the second pass of postprocessing.
"""
replacer = JobReplacer(
self,
jobs={
job: job.updated()
for job in list(self.jobs)
if job.incomplete_input_expand
},
)
return replacer.process()
updated = False
for job in list(self.jobs):
if job.incomplete_input_expand:
newjob = await job.updated()
await self.replace_job(job, newjob, recursive=False)
updated = True
return updated

def update_ready(self, jobs=None):
"""Update information whether a job is ready to execute.
@@ -1585,7 +1585,7 @@ def get_jobs_or_groups(self):
visited_groups.add(group)
yield group

def postprocess(
async def postprocess(
self, update_needrun=True, update_incomplete_input_expand_jobs=True
):
"""Postprocess the DAG. This has to be invoked after any change to the
@@ -1595,20 +1595,20 @@ def postprocess(
if update_needrun:
self.update_container_imgs()
self.update_conda_envs()
async_run(self.update_needrun())
await self.update_needrun()
self.update_priority()
self.handle_pipes_and_services()
self.handle_update_flags()
self.update_groups()
self.update_storage_inputs()

if update_incomplete_input_expand_jobs:
updated = self.update_incomplete_input_expand_jobs()
updated = await self.update_incomplete_input_expand_jobs()
if updated:
# run a second pass, some jobs have been updated
# with potentially new input files that have depended
# on group ids.
self.postprocess(
await self.postprocess(
update_needrun=True, update_incomplete_input_expand_jobs=False
)

@@ -1742,24 +1742,23 @@ def _ready(self, job):
(self._n_until_ready[job] - n_internal_deps(job)) == 0 for job in group
)

def update_queue_input_jobs(self):
async def update_queue_input_jobs(self):
updated = False
if self.has_unfinished_queue_input_jobs():
logger.info("Updating jobs with queue input...")
replacer = JobReplacer(self)
for job in self.queue_input_jobs:
if (
job.has_queue_input()
and job not in self._jobs_with_finished_queue_input
):
newjob = job.updated()
newjob = await job.updated()
if newjob.input != job.input:
replacer.add(job, newjob)
if not job.has_unfinished_queue_input():
self._jobs_with_finished_queue_input.add(job)
updated = replacer.process()
await self.replace_job(job, newjob, recursive=False)
updated = True
if updated and not job.has_unfinished_queue_input():
self._jobs_with_finished_queue_input.add(job)
if updated:
self.postprocess_after_update()
await self.postprocess_after_update()
# reset queue_input_jobs such that it is recomputed next time
self._queue_input_jobs = None
return updated
@@ -1777,7 +1776,7 @@ def queue_input_jobs(self):
def has_unfinished_queue_input_jobs(self):
return any(job.has_unfinished_queue_input() for job in self.queue_input_jobs)

def update_checkpoint_dependencies(self, jobs=None):
async def update_checkpoint_dependencies(self, jobs=None):
"""Update dependencies of checkpoints."""
updated = False
self.update_checkpoint_outputs()
@@ -1787,42 +1786,25 @@ def update_checkpoint_dependencies(self, jobs=None):
for job in jobs:
if job.is_checkpoint:
depending = list(self.depending[job])
# re-evaluate depending jobs, replace and update DAG
# Note: even for touch, this needs retrieval from storage!
if depending:

async def retrieve():
try:
async with asyncio.TaskGroup() as tg:
for f in job.output:
if f.is_storage:
tg.create_task(f.retrieve_from_storage())
except ExceptionGroup as e:
raise WorkflowError(
"Failed to retrieve checkpoint output.", e
)

async_run(retrieve())
all_depending.extend(depending)
replacer = JobReplacer(self)
for j in all_depending:
logger.debug(f"Updating job {j}.")
newjob = j.updated()
replacer.add(j, newjob)
updated = replacer.process()
newjob = await j.updated()
await self.replace_job(j, newjob, recursive=False)
updated = True
if updated:
self.postprocess_after_update()
await self.postprocess_after_update()
return updated

def postprocess_after_update(self):
self.postprocess()
async def postprocess_after_update(self):
await self.postprocess()
shared_input_output = (
SharedFSUsage.INPUT_OUTPUT in self.workflow.storage_settings.shared_fs_usage
)
if (
self.workflow.is_main_process and shared_input_output
) or self.workflow.remote_exec:
async_run(self.retrieve_storage_inputs())
await self.retrieve_storage_inputs()

def register_running(self, jobs):
self._running.update(jobs)
@@ -1834,7 +1816,7 @@ def register_running(self, jobs):
# already gone
pass

def finish(self, job, update_checkpoint_dependencies=True):
async def finish(self, job, update_checkpoint_dependencies=True):
"""Finish a given job (e.g. remove from ready jobs, mark depending jobs
as ready)."""

@@ -1861,7 +1843,7 @@ def finish(self, job, update_checkpoint_dependencies=True):

updated_dag = False
if update_checkpoint_dependencies:
updated_dag = self.update_checkpoint_dependencies(jobs)
updated_dag = await self.update_checkpoint_dependencies(jobs)

depending = [
j
@@ -1899,15 +1881,12 @@ def finish(self, job, update_checkpoint_dependencies=True):
# temp files.
# TODO: we maybe could be more accurate and determine whether there is a
# checkpoint that depends on the temp file.
async def handle_temp():
for job in jobs:
await self.handle_temp(job)

async_run(handle_temp())
for job in jobs:
await self.handle_temp(job)

return potential_new_ready_jobs

def new_job(
async def new_job(
self, rule, targetfile=None, format_wildcards=None, wildcards_dict=None
):
"""Create new job for given rule and (optional) targetfile.
@@ -1928,7 +1907,7 @@ def new_job(
assert targetfile is not None
return self.job_cache[key]
wildcards_dict = rule.get_wildcards(targetfile, wildcards_dict=wildcards_dict)
job = self.job_factory.new(
job = await self.job_factory.new(
rule,
self,
wildcards_dict=wildcards_dict,
@@ -2043,7 +2022,7 @@ def handle_update_flags(self):
update_job.add_aux_resource(mutex, 1)
self.workflow.register_resource(mutex, 1)

def collect_potential_dependencies(self, job, known_producers):
async def collect_potential_dependencies(self, job, known_producers):
"""Collect all potential dependencies of a job. These might contain
ambiguities. The keys of the returned dict represent the files to be considered.
"""
@@ -2078,7 +2057,7 @@ def collect_potential_dependencies(self, job, known_producers):
yield PotentialDependency(
file,
[
self.new_job(
await self.new_job(
job.dependencies[file],
targetfile=file,
wildcards_dict=job.wildcards_dict,
@@ -2089,7 +2068,9 @@ def collect_potential_dependencies(self, job, known_producers):
else:
yield PotentialDependency(
file,
file2jobs(file, wildcards_dict=job.wildcards_dict),
await file2jobs(
file, wildcards_dict=job.wildcards_dict
),
False,
)
except MissingRuleException as ex:
@@ -2161,7 +2142,7 @@ def new_wildcards(self, job):
new_wildcards.discard(wildcard)
return new_wildcards

def rule2job(self, targetrule):
async def rule2job(self, targetrule):
"""Generate a new job from a given rule."""
if targetrule.has_wildcards():
raise WorkflowError(
@@ -2170,17 +2151,17 @@ def rule2job(self, targetrule):
"or have a rule without wildcards at the very top of your workflow (e.g. the typical "
'"rule all" which just collects all results you want to generate in the end).'
)
return self.new_job(targetrule)
return await self.new_job(targetrule)

def file2jobs(self, targetfile, wildcards_dict=None):
async def file2jobs(self, targetfile, wildcards_dict=None):
rules = self.output_index.match(targetfile)
jobs = []
exceptions = list()
for rule in rules:
if rule.is_producer(targetfile):
try:
jobs.append(
self.new_job(
await self.new_job(
rule, targetfile=targetfile, wildcards_dict=wildcards_dict
)
)
@@ -2883,23 +2864,3 @@ def __hash__(self):

def merge(self, other):
self.id = other.id


class JobReplacer:
def __init__(self, dag, jobs: Optional[Dict[Job, Job]] = None):
self.jobs = jobs or dict()
self.dag = dag

def add(self, job, newjob):
self.jobs[job] = newjob

def process(self):
if self.jobs:

async def replace():
for job, newjob in self.jobs.items():
await self.dag.replace_job(job, newjob, recursive=False)

async_run(replace())
return True
return False
6 changes: 6 additions & 0 deletions snakemake/exceptions.py
@@ -541,6 +541,12 @@ def __init__(self, rule, targetfile):
self.targetfile = checkpoint_target(targetfile)


class InputOpenException(Exception):
def __init__(self, iofile):
self.iofile = iofile
self.rule = None


class CacheMissException(Exception):
pass

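The new `InputOpenException` only records the offending `iofile`; its `rule` attribute starts out as `None`, suggesting that a caller with rule context fills it in before the error is surfaced to the user. A hedged sketch of such a handler — the `open_input` helper, the attribute assignment, and the wrapping into `WorkflowError` are assumptions for illustration, not code shown in this diff:

```python
from snakemake.exceptions import InputOpenException, WorkflowError


def open_input(iofile, rule):
    """Open an input file, turning a missing local copy into a rule-aware error."""
    try:
        return iofile.open()
    except InputOpenException as e:
        e.rule = rule  # attach the rule so the message can name it
        raise WorkflowError(
            f"Cannot open input file {e.iofile} of rule {e.rule}: "
            "the file is not present locally."
        ) from e
```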
5 changes: 3 additions & 2 deletions snakemake/io.py
@@ -47,6 +47,7 @@
is_namedtuple_instance,
)
from snakemake.exceptions import (
InputOpenException,
MissingOutputException,
WildcardError,
WorkflowError,
@@ -327,8 +328,8 @@ def open(self, mode="r", buffering=-1, encoding=None, errors=None, newline=None)
This can (and should) be used in a `with`-statement.
If the file is a remote storage file, retrieve it first if necessary.
"""
if self.is_storage and not async_run(self.exists_local()):
async_run(self.retrieve_from_storage())
if not os.path.exists(self):
raise InputOpenException(self)
f = open(
self,
mode=mode,
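With this change, `_IOFile.open()` no longer retrieves remote storage files on demand; if the file is absent locally it raises the new `InputOpenException` instead. Retrieval therefore has to happen explicitly before opening — for checkpoint outputs this is what the reworked DAG code above arranges (e.g. `retrieve_storage_inputs()` awaited in `postprocess_after_update`). A minimal sketch of the calling pattern, assuming a file object that exposes the async storage attributes used elsewhere in this diff (`is_storage`, `exists_local()`, `retrieve_from_storage()`):

```python
async def read_checkpoint_output(outfile):
    # Ensure the file is present locally first; open() itself no longer
    # pulls it from storage and would raise InputOpenException otherwise.
    if outfile.is_storage and not await outfile.exists_local():
        await outfile.retrieve_from_storage()
    with outfile.open() as f:
        return f.read()
```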