Skip to content

Commit

Permalink
Change pointers to store absolute path and work with .to.
Browse files Browse the repository at this point in the history
  • Loading branch information
rohinb2 committed Jul 12, 2024
1 parent 2dde9db commit 41ab1af
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 52 deletions.
12 changes: 8 additions & 4 deletions runhouse/resources/functions/aws_lambda_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,15 @@ def aws_lambda_fn(

# extract function pointers, path to code and arg names from callable function.
handler_function_name = fn.__name__
fn_pointers, req_to_add = Function._extract_pointers(
fn, reqs=[] if env is None else env.reqs
fn_pointers = Function._extract_pointers(fn)
(
local_path_containing_function,
should_add,
) = Function._get_local_path_containing_module(
fn_pointers[0], reqs=[] if env is None else env.reqs
)
if req_to_add and env is not None:
env.reqs = [req_to_add] + env.reqs
if should_add and env is not None:
env.reqs = [str(local_path_containing_function)] + env.reqs
paths_to_code = [extract_module_path(fn)]
if name is None:
name = fn.__name__
Expand Down
10 changes: 7 additions & 3 deletions runhouse/resources/functions/function_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,13 @@ def function(

fn_pointers = None
if callable(fn):
fn_pointers, req_to_add = Function._extract_pointers(fn, reqs=env.reqs)
if req_to_add:
env.reqs = [req_to_add] + env.reqs
fn_pointers = Function._extract_pointers(fn)
(
local_path_containing_module,
should_add,
) = Function._get_local_path_containing_module(fn_pointers[0], env.reqs)
if should_add:
env.reqs = [str(local_path_containing_module)] + env.reqs
if fn_pointers[1] == "notebook":
fn_pointers = Function._handle_nb_fn(
fn,
Expand Down
127 changes: 82 additions & 45 deletions runhouse/resources/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,13 @@ def __init__(
self._env = self._system.default_env if self._system else Env()
# If we're creating pointers, we're also local to the class definition and package, so it should be
# set as the workdir (we can do this in a fancier way later)
pointers, req_to_add = Module._extract_pointers(
self.__class__, reqs=self._env.reqs
)
if req_to_add:
self._env.reqs = [req_to_add] + self._env.reqs
pointers = Module._extract_pointers(self.__class__)
(
local_path_containing_module,
should_add,
) = Module._get_local_path_containing_module(pointers[0], self._env.reqs)
if should_add:
self._env.reqs = [str(local_path_containing_module)] + self._env.reqs
self._pointers = pointers
self._endpoint = endpoint
self._signature = signature
Expand Down Expand Up @@ -445,9 +447,28 @@ def to(
)

env = self.env if not env else env

env = _get_env_from(env)

# We need to change the pointers to the remote import path if we're sending this module to a remote cluster,
# and we need to add the local path to the module to the requirements if it's not already there.
remote_import_path = None
if env and (self._pointers or getattr(self, "fn_pointers", None)):
pointers = self._pointers if self._pointers else self.fn_pointers

# Update the envs reqs with the local path to the module if it's not already there
(
local_path_containing_module,
should_add,
) = Module._get_local_path_containing_module(pointers[0], env.reqs)
if should_add:
env.reqs = [str(local_path_containing_module)] + env.reqs

# Figure out what the import path would be on the remote system
remote_import_path = str(
local_path_containing_module.name
/ Path(pointers[0]).relative_to(local_path_containing_module)
)

if system:
system.check_server()
if isinstance(env, Env):
Expand All @@ -471,6 +492,21 @@ def to(
)
new_module.dryrun = True

# Set remote import path
if remote_import_path:
if new_module._pointers:
new_module._pointers = (
remote_import_path,
new_module._pointers[1],
new_module._pointers[2],
)
else:
new_module.fn_pointers = (
remote_import_path,
new_module.fn_pointers[1],
new_module.fn_pointers[2],
)

if isinstance(system, Cluster):
new_name = name or self.name or self.default_name()
if self.rns_address:
Expand Down Expand Up @@ -971,12 +1007,35 @@ def _is_running_in_notebook(module_path: Union[str, None]) -> bool:
return False

@staticmethod
def _find_req_containing_module_root_path(
def _extract_pointers(raw_cls_or_fn: Union[Type, Callable]):
"""Get the path to the module, module name, and function name to be able to import it on the server"""
if not (isinstance(raw_cls_or_fn, type) or isinstance(raw_cls_or_fn, Callable)):
raise TypeError(
f"Expected Type or Callable but received {type(raw_cls_or_fn)}"
)

root_path, module_name, cls_or_fn_name = get_module_import_info(raw_cls_or_fn)

return (
root_path,
module_name,
cls_or_fn_name,
)

@staticmethod
def _get_local_path_containing_module(
root_path: str, reqs: List[str]
) -> Optional[str]:
"""Find the req containing the module root path"""
) -> Tuple[Path, bool]:
"""
Find the directory containing the module root path in the reqs list.
If it is not found, find the working directory that contains the module root path and return that,
along with a flag indicating whether the module should be added to the reqs list.
"""

# First, check if the module is already included in one of the directories in reqs
local_path = None
local_path_containing_module = None
for req in reqs:
if isinstance(req, str):
req = Package.from_string(req)
Expand All @@ -986,52 +1045,27 @@ def _find_req_containing_module_root_path(
and not isinstance(req.install_target, str)
and req.install_target.is_local()
):
local_path = Path(req.install_target.local_path)
local_path_containing_module = Path(req.install_target.local_path)

if local_path:
if local_path_containing_module:
try:
# Module path relative to package
Path(root_path).relative_to(local_path)
Path(root_path).relative_to(local_path_containing_module)
break
except ValueError: # Not a subdirectory
local_path = None
local_path_containing_module = None
pass

return local_path

@staticmethod
def _extract_pointers(raw_cls_or_fn: Union[Type, Callable], reqs: List[str]):
"""Get the path to the module, module name, and function name to be able to import it on the server"""
if not (isinstance(raw_cls_or_fn, type) or isinstance(raw_cls_or_fn, Callable)):
raise TypeError(
f"Expected Type or Callable but received {type(raw_cls_or_fn)}"
)

root_path, module_name, cls_or_fn_name = get_module_import_info(raw_cls_or_fn)

local_path_containing_module = Module._find_req_containing_module_root_path(
root_path, reqs
)
# Only add this to the env's reqs if the module is not in one of the directories in reqs
add = local_path_containing_module is None

if not local_path_containing_module:
# If the module is not in one of the directories in reqs, we just use the full path,
# then we'll create a new "req" containing the Module
# TODO: should this "req" be a Package? I'll just start with a string for now
local_path_containing_module = Path(locate_working_dir(root_path))
req_to_add = str(local_path_containing_module)
else:
req_to_add = None

remote_import_path = str(
local_path_containing_module.name
/ Path(root_path).relative_to(local_path_containing_module)
)

return (
remote_import_path,
module_name,
cls_or_fn_name,
), req_to_add
return local_path_containing_module, add

def openapi_spec(self, spec_name: Optional[str] = None):
"""Generate an OpenAPI spec for the module.
Expand Down Expand Up @@ -1331,9 +1365,12 @@ class (e.g. ``to``, ``fetch``, etc.). Properties and private methods are not int
if not env:
env = Env()

cls_pointers, working_dir_to_add = Module._extract_pointers(cls, env.reqs)
if working_dir_to_add is not None:
env.reqs = [str(working_dir_to_add)] + env.reqs
cls_pointers = Module._extract_pointers(cls)
local_path_containing_module, should_add = Module._get_local_path_containing_module(
cls_pointers[0], env.reqs
)
if should_add:
env.reqs = [str(local_path_containing_module)] + env.reqs

name = name or (
cls_pointers[2] if cls_pointers else _generate_default_name(prefix="module")
Expand Down

0 comments on commit 41ab1af

Please sign in to comment.