Skip to content

Commit

Permalink
Use cachedir and rsync to speed-up remote read (#145)
Browse files Browse the repository at this point in the history
  • Loading branch information
tdegeus committed Jun 8, 2023
1 parent fe178a9 commit ee33e2a
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 9 deletions.
1 change: 1 addition & 0 deletions environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ dependencies:
- click
- furo
- numpy >=1.15.0
- platformdirs
- pre-commit
- prettytable
- python >=3.10
Expand Down
9 changes: 8 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@ requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"]
[project]
authors = [{name = "Tom de Geus", email = "tom@geus.me"}]
classifiers = ["License :: OSI Approved :: MIT License"]
dependencies = ["click", "numpy>=1.15.0", "prettytable", "pyyaml", "tqdm"]
dependencies = [
"click",
"numpy>=1.15.0",
"platformdirs",
"prettytable",
"pyyaml",
"tqdm"
]
description = "Simple dataset management"
dynamic = ["version"]
name = "shelephant"
Expand Down
5 changes: 3 additions & 2 deletions shelephant/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,7 @@ class MyFmt(
)
parser.add_argument("-i", "--info", action="store_true", help="Add information (sha256, size).")
parser.add_argument("-f", "--force", action="store_true", help="Force overwrite output.")
parser.add_argument("--verbose", action="store_true", help="Print commands (only SSH remote).")
parser.add_argument("--version", action="version", version=version)
parser.add_argument("path", type=pathlib.Path, help="Path to remote directory.")
return parser
Expand Down Expand Up @@ -566,9 +567,9 @@ def shelephant_hostinfo(args: list[str]):
if args.dump:
loc.dump = args.dump

loc.read()
loc.read(verbose=args.verbose)
if args.info:
loc.getinfo()
loc.getinfo(verbose=args.verbose)
loc.to_yaml(args.output, force=args.force)


Expand Down
26 changes: 20 additions & 6 deletions shelephant/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,23 @@
import numpy as np
import prettytable
import tqdm
from platformdirs import user_cache_dir

from . import cli
from . import compute_hash
from . import rsync
from . import scp
from . import search
from . import ssh
from . import yaml
from ._version import version
from .external import exec_cmd

if shutil.which("rsync") is not None:
_copyfunc = rsync.copy
else:
_copyfunc = scp.copy


def _force_absolute_path(root: pathlib.Path, path: pathlib.Path) -> pathlib.Path:
"""
Expand Down Expand Up @@ -484,18 +491,19 @@ def read(self, verbose: bool = False):
)

# search on SSH remote host for files (the sha256/size/mtime of 'new' files is set to None)
with ssh.tempdir(self.ssh) as remote, search.tempdir():
cache_dir = ssh._shelephant_cachdir(self.ssh, self.python)
with ssh._cachedir(self.ssh, cache_dir) as remote, search.tempdir():
shutil.copy(pathlib.Path(__file__).parent / "search.py", "script.py")
with open("settings.json", "w") as f:
json.dump(self.search, f)

host = f'{self.ssh:s}:"{str(remote):s}"'
scp.copy(".", host, ["script.py", "settings.json"], progress=False, verbose=verbose)
_copyfunc(".", host, ["script.py", "settings.json"], progress=False, verbose=verbose)
exec_cmd(
f'ssh {self.ssh:s} "cd {str(remote)} && {self.python} script.py {str(self.root)}"',
verbose=verbose,
)
scp.copy(host, ".", ["files.txt"], progress=False, verbose=verbose)
_copyfunc(host, ".", ["files.txt"], progress=False, verbose=verbose)
return self._prune(sorted(pathlib.Path("files.txt").read_text().splitlines()))

def has_info(self) -> bool:
Expand Down Expand Up @@ -538,21 +546,22 @@ def _get_info(self, paths: list[pathlib.Path], sha256: bool, progress: bool, ver
np.array(csum, dtype="U64"),
)

with ssh.tempdir(self.ssh) as remote, search.tempdir():
cache_dir = ssh._shelephant_cachdir(self.ssh, self.python)
with ssh._cachedir(self.ssh, cache_dir) as remote, search.tempdir():
files = [str(self.root / i) for i in paths]
pathlib.Path("files.txt").write_text("\n".join(files))
pathlib.Path("sha256.txt").write_text("")
shutil.copy(pathlib.Path(__file__).parent / "compute_hash.py", "script.py")

extra = ["sha256.txt"] if sha256 else []
hostpath = f'{self.ssh:s}:"{str(remote):s}"'
scp.copy(
_copyfunc(
".", hostpath, extra + ["script.py", "files.txt"], progress=False, verbose=verbose
)
exec_cmd(
f'ssh {self.ssh:s} "cd {str(remote)} && {self.python} script.py"', verbose=verbose
)
scp.copy(
_copyfunc(
hostpath, ".", extra + ["size.txt", "mtime.txt"], progress=False, verbose=verbose
)
size = np.array(
Expand Down Expand Up @@ -1663,6 +1672,7 @@ class MyFmt(
parser = argparse.ArgumentParser(formatter_class=MyFmt, description=desc)

parser.add_argument("--version", action="version", version=version)
parser.add_argument("--cachedir", action="store_true", help="Print cachedir.")
parser.add_argument(
"--basedir", action="store_true", help="Print basedir (that contain '.shelephant')."
)
Expand All @@ -1679,6 +1689,10 @@ def info(args: list[str]):
parser = _info_parser()
args = parser.parse_args(args)

if args.cachedir:
print(user_cache_dir("shelephant", "tdegeus"))
return

if args.basedir:
print(_search_upwards_dir(".shelephant").parent)
return
Expand Down
46 changes: 46 additions & 0 deletions shelephant/ssh.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,56 @@
import pathlib
import re
import subprocess
from contextlib import contextmanager

from .external import exec_cmd


def _shelephant_cachdir(hostname: str, python: str = "python3") -> str:
"""
Return the path to the shelephant cache directory or a tempdir on a remote host.
:param hostname: Hostname.
:param python: Python executable (on remote).
:return: Path to the shelephant cache directory or a tempdir on a remote host.
"""

script = [
"from platformdirs import user_cache_dir",
"from pathlib import Path",
"d = user_cache_dir('shelephant', 'tdegeus')",
"Path(d).mkdir(exist_ok=True)",
"print(d)",
]
cmd = f"{python:s} -c \\\"{';'.join(script):s}\\\" || mktemp -d"
cmd = f'ssh {hostname:s} "{cmd:s}"'
ret = subprocess.check_output(cmd, stderr=subprocess.DEVNULL, shell=True).decode("utf-8")
return ret.strip().splitlines()[0]


@contextmanager
def _cachedir(hostname: str, cache_dir: str):
"""
Do nothing if the cache directory is a shelephant cache directory.
Otherwise, remove the cache directory on exit.
with _cachedir(hostname, cache_dir) as cachdir_tempdir:
print(cachdir_tempdir)
"""

try:
if re.match(r".*shelephant.*", str(cache_dir)):
rm = None
yield pathlib.Path(cache_dir)
else:
rm = cache_dir
yield pathlib.Path(cache_dir.strip())
finally:
if rm is not None:
cmd = f"ssh {hostname:s} rm -rf {cache_dir:s}"
exec_cmd(cmd, verbose=False)


def has_keys_set(hostname: str) -> bool:
"""
Check if the ssh keys are set for a given host.
Expand Down

0 comments on commit ee33e2a

Please sign in to comment.