Skip to content

Commit

Permalink
Install preferred version only if package is not already installed.
Browse files Browse the repository at this point in the history
  • Loading branch information
rohinb2 committed Aug 8, 2024
1 parent 00a675b commit 7ea23ec
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 12 deletions.
36 changes: 32 additions & 4 deletions runhouse/resources/packages/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,11 @@ class Package(Resource):

def __init__(
self,
name: str = None,
install_method: str = None,
install_target: Union[str, "Folder"] = None,
install_args: str = None,
name: Optional[str] = None,
install_method: Optional[str] = None,
install_target: Optional[Union[str, "Folder"]] = None,
install_args: Optional[str] = None,
preferred_version: Optional[str] = None,
dryrun: bool = False,
**kwargs, # We have this here to ignore extra arguments when calling from from_config
):
Expand All @@ -65,6 +66,7 @@ def __init__(
self.install_method = install_method
self.install_target = install_target
self.install_args = install_args
self.preferred_version = preferred_version

def config(self, condensed=True):
# If the package is just a simple Package.from_string string, no
Expand All @@ -77,6 +79,7 @@ def config(self, condensed=True):
self.install_target, condensed
)
config["install_args"] = self.install_args
config["preferred_version"] = self.preferred_version
return config

def __str__(self):
Expand Down Expand Up @@ -233,6 +236,25 @@ def _install(self, env: Union[str, "Env"] = None, cluster: "Cluster" = None):
return

if self.install_method == "pip":

# If this is a generic pip package, with no version pinned, we want to check if there is a version
# already installed. If there is, then we ignore preferred version and leave the existing version.
# The user can always force a version install by doing `numpy==2.0.0` for example. Else, we install
# the preferred version, that matches their local.
if (
is_python_package_string(self.install_target)
and self.preferred_version is not None
):
# Check if this is installed
retcode = run_setup_command(
f"python -c \"import importlib.util; exit(0) if importlib.util.find_spec('{self.install_target}') else exit(1)\"",
cluster=cluster,
)[0]
if retcode != 0:
self.install_target = (
f"{self.install_target}=={self.preferred_version}"
)

install_cmd = self._pip_install_cmd(env=env, cluster=cluster)
logger.info(f"Running via install_method pip: {install_cmd}")
retcode = run_setup_command(install_cmd, cluster=cluster)[0]
Expand Down Expand Up @@ -501,6 +523,7 @@ def from_string(specifier: str, dryrun=False):
# If we are just defaulting to pip, attempt to install the same version of the package
# that is already installed locally
# Check if the target is only letters, nothing else. This means its a string like 'numpy'.
preferred_version = None
if install_method == "pip" and is_python_package_string(target):
locally_installed_version = find_locally_installed_version(target)
if locally_installed_version:
Expand All @@ -513,6 +536,10 @@ def from_string(specifier: str, dryrun=False):
path=local_install_path, system=Folder.DEFAULT_FS, dryrun=True
)

else:
# We want to preferrably install this version of the package server-side
preferred_version = locally_installed_version

# "Local" install method is a special case where we just copy a local folder and add to path
if install_method == "local":
return Package(
Expand All @@ -524,6 +551,7 @@ def from_string(specifier: str, dryrun=False):
install_target=target,
install_args=args,
install_method=install_method,
preferred_version=preferred_version,
dryrun=dryrun,
)
elif install_method == "rh":
Expand Down
15 changes: 7 additions & 8 deletions tests/test_resources/test_data/test_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from runhouse.utils import run_with_logs


def get_plotly_version():
import plotly
def get_bs4_version():
import bs4

return plotly.__version__
return bs4.__version__


class TestPackage(tests.test_resources.test_resource.TestResource):
Expand Down Expand Up @@ -152,13 +152,12 @@ def test_local_reqs_on_cluster(self, cluster, local_package):
assert remote_package.install_target.system == cluster

@pytest.mark.level("local")
@pytest.mark.skip("Feature deprecated for now")
def test_local_package_version_gets_installed(self, cluster):
run_with_logs("pip install plotly==5.9.0")
env = rh.env(name="temp_env", reqs=["plotly"])
run_with_logs("pip install beautifulsoup4==4.11.1")
env = rh.env(name="temp_env", reqs=["beautifulsoup4"])

remote_fn = rh.function(get_plotly_version, env=env).to(cluster)
assert remote_fn() == "5.9.0"
remote_fn = rh.function(get_bs4_version, env=env).to(cluster)
assert remote_fn() == "4.11.1"

# --------- basic torch index-url testing ---------
@pytest.mark.level("unit")
Expand Down

0 comments on commit 7ea23ec

Please sign in to comment.