Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
fix: improve robustness when retrieving remote source files, fixed us…
…age of local git repos as wrapper prefixes (in collaboration with @cokelaer and @Smeds) (#1495)

* fix: fixed usage of local git repos as wrapper prefixes

* implement retry mechanism for non-local source file access

* logging fixes

* avoid retry import
  • Loading branch information
johanneskoester committed Mar 18, 2022
1 parent 5cf275a commit e16531d
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 52 deletions.
22 changes: 15 additions & 7 deletions snakemake/deployment/conda.py
Expand Up @@ -6,7 +6,12 @@
import os
from pathlib import Path
import re
from snakemake.sourcecache import LocalGitFile, LocalSourceFile, infer_source_file
from snakemake.sourcecache import (
LocalGitFile,
LocalSourceFile,
SourceFile,
infer_source_file,
)
import subprocess
import tempfile
from urllib.request import urlopen
Expand Down Expand Up @@ -498,9 +503,7 @@ def create(self, dryrun=False):

logger.debug(out)
logger.info(
"Environment for {} created (location: {})".format(
os.path.relpath(env_file), os.path.relpath(env_path)
)
f"Environment for {self.file.get_path_or_uri()} created (location: {os.path.relpath(env_path)})"
)
except subprocess.CalledProcessError as e:
# remove potential partially installed environment
Expand Down Expand Up @@ -706,8 +709,10 @@ def __eq__(self, other):


class CondaEnvFileSpec(CondaEnvSpec):
def __init__(self, filepath: str, rule=None):
if isinstance(filepath, _IOFile):
def __init__(self, filepath, rule=None):
if isinstance(filepath, SourceFile):
self.file = IOFile(str(filepath.get_path_or_uri()), rule=rule)
elif isinstance(filepath, _IOFile):
self.file = filepath
else:
self.file = IOFile(filepath, rule=rule)
Expand Down Expand Up @@ -777,5 +782,8 @@ def __eq__(self, other):
return self.name == other.name


def is_conda_env_file(spec: str):
def is_conda_env_file(spec):
if isinstance(spec, SourceFile):
spec = spec.get_filename()

return spec.endswith(".yaml") or spec.endswith(".yml")
2 changes: 1 addition & 1 deletion snakemake/logging.py
Expand Up @@ -349,7 +349,7 @@ def set_level(self, level):
def logfile_hint(self):
if self.mode == Mode.default:
logfile = self.get_logfile()
self.info("Complete log: {}".format(logfile))
self.info("Complete log: {}".format(os.path.relpath(logfile)))

def location(self, msg):
callerframerecord = inspect.stack()[1]
Expand Down
4 changes: 3 additions & 1 deletion snakemake/shell.py
Expand Up @@ -205,7 +205,9 @@ def __new__(
)
logger.info("Activating singularity image {}".format(container_img))
if conda_env:
logger.info("Activating conda environment: {}".format(conda_env))
logger.info(
"Activating conda environment: {}".format(os.path.relpath(conda_env))
)

tmpdir_resource = resources.get("tmpdir", None)
# environment variable lists for linear algebra libraries taken from:
Expand Down
73 changes: 61 additions & 12 deletions snakemake/sourcecache.py
Expand Up @@ -5,6 +5,7 @@

import hashlib
from pathlib import Path
import posixpath
import re
import os
import shutil
Expand All @@ -14,7 +15,6 @@
from abc import ABC, abstractmethod
from datetime import datetime


from snakemake.common import (
ON_WINDOWS,
is_local_file,
Expand Down Expand Up @@ -66,6 +66,11 @@ def mtime(self):
"""If possible, return mtime of the file. Otherwise, return None."""
return None

@property
@abstractmethod
def is_local(self):
...

def __hash__(self):
return self.get_path_or_uri().__hash__()

Expand Down Expand Up @@ -94,6 +99,10 @@ def get_filename(self):
def is_persistently_cacheable(self):
return False

@property
def is_local(self):
return False


class LocalSourceFile(SourceFile):
def __init__(self, path):
Expand Down Expand Up @@ -123,6 +132,10 @@ def mtime(self):
def __fspath__(self):
return self.path

@property
def is_local(self):
return True


class LocalGitFile(SourceFile):
def __init__(
Expand All @@ -136,7 +149,7 @@ def __init__(
self.path = path

def get_path_or_uri(self):
return "git+{}/{}@{}".format(self.repo_path, self.path, self.ref)
return "git+file://{}/{}@{}".format(self.repo_path, self.path, self.ref)

def join(self, path):
return LocalGitFile(
Expand All @@ -147,16 +160,29 @@ def join(self, path):
commit=self.commit,
)

def get_basedir(self):
return self.__class__(
repo_path=self.repo_path,
path=os.path.dirname(self.path),
tag=self.tag,
commit=self.commit,
ref=self.ref,
)

def is_persistently_cacheable(self):
return False

def get_filename(self):
return os.path.basename(self.path)
return posixpath.basename(self.path)

@property
def ref(self):
return self.tag or self.commit or self._ref

@property
def is_local(self):
return True


class HostingProviderFile(SourceFile):
"""Marker for denoting github source files from releases."""
Expand Down Expand Up @@ -229,6 +255,10 @@ def join(self, path):
branch=self.branch,
)

@property
def is_local(self):
return False


class GithubFile(HostingProviderFile):
def get_path_or_uri(self):
Expand Down Expand Up @@ -276,7 +306,12 @@ def infer_source_file(path_or_uri, basedir: SourceFile = None):
return basedir.join(path_or_uri)
return LocalSourceFile(path_or_uri)
if path_or_uri.startswith("git+file:"):
root_path, file_path, ref = split_git_path(path_or_uri)
try:
root_path, file_path, ref = split_git_path(path_or_uri)
except Exception as e:
raise WorkflowError(
f"Failed to read source {path_or_uri} from git repo.", e
)
return LocalGitFile(root_path, file_path, ref=ref)
# something else
return GenericSourceFile(path_or_uri)
Expand Down Expand Up @@ -311,7 +346,7 @@ def runtime_cache_path(self):

def open(self, source_file, mode="r"):
cache_entry = self._cache(source_file)
return self._open(cache_entry, mode)
return self._open(LocalSourceFile(cache_entry), mode)

def exists(self, source_file):
try:
Expand Down Expand Up @@ -343,7 +378,7 @@ def _cache(self, source_file):

def _do_cache(self, source_file, cache_entry):
# open from origin
with self._open(source_file.get_path_or_uri(), "rb") as source:
with self._open(source_file, "rb") as source:
tmp_source = tempfile.NamedTemporaryFile(
prefix=str(cache_entry),
delete=False, # no need to delete since we move it below
Expand All @@ -362,20 +397,34 @@ def _do_cache(self, source_file, cache_entry):
# as mtime.
os.utime(cache_entry, times=(mtime, mtime))

def _open(self, path_or_uri, mode):
def _open_local_or_remote(self, source_file, mode):
from retry import retry_call

if source_file.is_local:
return self._open(source_file, mode)
else:
return retry_call(
self._open,
[source_file, mode],
tries=3,
delay=3,
backoff=2,
logger=logger,
)

def _open(self, source_file, mode):
from smart_open import open

if isinstance(path_or_uri, LocalGitFile):
if isinstance(source_file, LocalGitFile):
import git

return io.BytesIO(
git.Repo(path_or_uri.repo_path)
.git.show("{}:{}".format(path_or_uri.ref, path_or_uri.path))
git.Repo(source_file.repo_path)
.git.show("{}:{}".format(source_file.ref, source_file.path))
.encode()
)

if isinstance(path_or_uri, SourceFile):
path_or_uri = path_or_uri.get_path_or_uri()
path_or_uri = source_file.get_path_or_uri()

try:
return open(path_or_uri, mode)
Expand Down
55 changes: 25 additions & 30 deletions snakemake/wrapper.py
Expand Up @@ -12,17 +12,20 @@

from snakemake.exceptions import WorkflowError
from snakemake.script import script
from snakemake.sourcecache import SourceCache, infer_source_file
from snakemake.sourcecache import LocalGitFile, SourceCache, infer_source_file


PREFIX = "https://github.com/snakemake/snakemake-wrappers/raw/"

EXTENSIONS = [".py", ".R", ".Rmd", ".jl"]

def is_script(path):

def is_script(source_file):
filename = source_file.get_filename()
return (
path.endswith("wrapper.py")
or path.endswith("wrapper.R")
or path.endswith("wrapper.jl")
filename.endswith("wrapper.py")
or filename.endswith("wrapper.R")
or filename.endswith("wrapper.jl")
)


Expand All @@ -34,7 +37,7 @@ def get_path(path, prefix=None):
parts = path.split("/")
path = "/" + "/".join(parts[1:]) + "@" + parts[0]
path = prefix + path
return path
return infer_source_file(path)


def is_url(path):
Expand All @@ -45,24 +48,13 @@ def is_url(path):
)


def is_local(path):
return path.startswith("file:")


def is_git_path(path):
return path.startswith("git+file:")


def find_extension(
path, sourcecache: SourceCache, extensions=[".py", ".R", ".Rmd", ".jl"]
):
for ext in extensions:
if path.endswith("wrapper{}".format(ext)):
return path
def find_extension(source_file, sourcecache: SourceCache):
for ext in EXTENSIONS:
if source_file.get_filename().endswith("wrapper{}".format(ext)):
return source_file

path = infer_source_file(path)
for ext in extensions:
script = path.join("wrapper{}".format(ext))
for ext in EXTENSIONS:
script = source_file.join("wrapper{}".format(ext))

if sourcecache.exists(script):
return script
Expand All @@ -77,11 +69,8 @@ def get_conda_env(path, prefix=None):
path = get_path(path, prefix=prefix)
if is_script(path):
# URLs and posixpaths share the same separator. Hence use posixpath here.
path = posixpath.dirname(path)
if is_git_path(path):
path, version = path.split("@")
return os.path.join(path, "environment.yaml") + "@" + version
return path + "/environment.yaml"
path = path.get_basedir()
return path.join("environment.yaml")


def wrapper(
Expand Down Expand Up @@ -112,11 +101,17 @@ def wrapper(
Load a wrapper from https://github.com/snakemake/snakemake-wrappers under
the given path + wrapper.(py|R|Rmd) and execute it.
"""
path = get_script(
assert path is not None
script_source = get_script(
path, SourceCache(runtime_cache_path=runtime_sourcecache_path), prefix=prefix
)
if script_source is None:
raise WorkflowError(
f"Unable to locate wrapper script for wrapper {path}. "
"This can be a network issue or a mistake in the wrapper URL."
)
script(
path,
script_source.get_path_or_uri(),
"",
input,
output,
Expand Down
20 changes: 19 additions & 1 deletion tests/tests.py
Expand Up @@ -526,11 +526,29 @@ def test_conda_cmd_exe():
run(dpath("test_conda_cmd_exe"), use_conda=True)


@skip_on_windows # Conda support is partly broken on Win
@skip_on_windows # wrappers are for linux and macos only
def test_wrapper():
run(dpath("test_wrapper"), use_conda=True)


@skip_on_windows # wrappers are for linux and macos only
def test_wrapper_local_git_prefix():
import git

with tempfile.TemporaryDirectory() as tmpdir:
print("Cloning wrapper repo...")
repo = git.Repo.clone_from(
"https://github.com/snakemake/snakemake-wrappers", tmpdir
)
print("Cloning complete.")

run(
dpath("test_wrapper"),
use_conda=True,
wrapper_prefix=f"git+file://{tmpdir}",
)


def test_get_log_none():
run(dpath("test_get_log_none"))

Expand Down

0 comments on commit e16531d

Please sign in to comment.