Skip to content

Commit

Permalink
Prefix all async ClusterServlet functions with a.
Browse files Browse the repository at this point in the history
  • Loading branch information
rohinb2 committed Mar 28, 2024
1 parent 818107b commit 76bbf5e
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 39 deletions.
44 changes: 22 additions & 22 deletions runhouse/servers/cluster_servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,22 @@ async def __init__(
##############################################
# Cluster config state storage methods
##############################################
async def get_cluster_config(self) -> Dict[str, Any]:
async def aget_cluster_config(self) -> Dict[str, Any]:
return self.cluster_config

async def set_cluster_config(self, cluster_config: Dict[str, Any]):
async def aset_cluster_config(self, cluster_config: Dict[str, Any]):
self.cluster_config = cluster_config

async def set_cluster_config_value(self, key: str, value: Any):
async def aset_cluster_config_value(self, key: str, value: Any):
self.cluster_config[key] = value

##############################################
# Auth cache internal functions
##############################################
async def add_user_to_auth_cache(self, token: str, refresh_cache: bool = True):
async def aadd_user_to_auth_cache(self, token: str, refresh_cache: bool = True):
self._auth_cache.add_user(token, refresh_cache)

async def resource_access_level(
async def aresource_access_level(
self, token: str, resource_uri: str
) -> Union[str, None]:
# If the token in this request matches that of the owner of the cluster,
Expand All @@ -57,13 +57,13 @@ async def resource_access_level(
return ResourceAccess.WRITE
return self._auth_cache.lookup_access_level(token, resource_uri)

async def user_resources(self, token: str) -> dict:
async def auser_resources(self, token: str) -> dict:
return self._auth_cache.get_user_resources(token)

async def get_username(self, token: str) -> str:
async def aget_username(self, token: str) -> str:
return self._auth_cache.get_username(token)

async def has_resource_access(self, token: str, resource_uri=None) -> bool:
async def ahas_resource_access(self, token: str, resource_uri=None) -> bool:
"""Checks whether user has read or write access to a given module saved on the cluster."""
from runhouse.rns.utils.api import ResourceAccess

Expand All @@ -72,7 +72,7 @@ async def has_resource_access(self, token: str, resource_uri=None) -> bool:
return False

cluster_uri = self.cluster_config["name"]
cluster_access = self.resource_access_level(token, cluster_uri)
cluster_access = await self.aresource_access_level(token, cluster_uri)
if cluster_access == ResourceAccess.WRITE:
# if user has write access to cluster will have access to all resources
return True
Expand All @@ -84,52 +84,52 @@ async def has_resource_access(self, token: str, resource_uri=None) -> bool:
# If module does not have a name, must have access to the cluster
return False

resource_access_level = self.resource_access_level(token, resource_uri)
resource_access_level = await self.aresource_access_level(token, resource_uri)
if resource_access_level not in [ResourceAccess.WRITE, ResourceAccess.READ]:
return False

return True

async def clear_auth_cache(self, token: str = None):
async def aclear_auth_cache(self, token: str = None):
self._auth_cache.clear_cache(token)

##############################################
# Key to servlet where it is stored mapping
##############################################
async def mark_env_servlet_name_as_initialized(self, env_servlet_name: str):
async def amark_env_servlet_name_as_initialized(self, env_servlet_name: str):
self._initialized_env_servlet_names.add(env_servlet_name)

async def is_env_servlet_name_initialized(self, env_servlet_name: str) -> bool:
async def ais_env_servlet_name_initialized(self, env_servlet_name: str) -> bool:
return env_servlet_name in self._initialized_env_servlet_names

async def get_all_initialized_env_servlet_names(self) -> Set[str]:
async def aget_all_initialized_env_servlet_names(self) -> Set[str]:
return self._initialized_env_servlet_names

async def get_key_to_env_servlet_name_dict_keys(self) -> List[Any]:
async def aget_key_to_env_servlet_name_dict_keys(self) -> List[Any]:
return list(self._key_to_env_servlet_name.keys())

async def get_key_to_env_servlet_name_dict(self) -> Dict[Any, str]:
async def aget_key_to_env_servlet_name_dict(self) -> Dict[Any, str]:
return self._key_to_env_servlet_name

async def get_env_servlet_name_for_key(self, key: Any) -> str:
async def aget_env_servlet_name_for_key(self, key: Any) -> str:
return self._key_to_env_servlet_name.get(key, None)

async def put_env_servlet_name_for_key(self, key: Any, env_servlet_name: str):
if not await self.is_env_servlet_name_initialized(env_servlet_name):
async def aput_env_servlet_name_for_key(self, key: Any, env_servlet_name: str):
if not await self.ais_env_servlet_name_initialized(env_servlet_name):
raise ValueError(
f"Env servlet name {env_servlet_name} not initialized, and you tried to mark a resource as in it."
)
self._key_to_env_servlet_name[key] = env_servlet_name

async def pop_env_servlet_name_for_key(self, key: Any, *args) -> str:
async def apop_env_servlet_name_for_key(self, key: Any, *args) -> str:
# *args allows us to pass default or not
return self._key_to_env_servlet_name.pop(key, *args)

async def clear_key_to_env_servlet_name_dict(self):
async def aclear_key_to_env_servlet_name_dict(self):
self._key_to_env_servlet_name = {}

##############################################
# Remove Env Servlet
##############################################
async def remove_env_servlet_name(self, env_servlet_name: str):
async def aremove_env_servlet_name(self, env_servlet_name: str):
self._initialized_env_servlet_names.remove(env_servlet_name)
34 changes: 17 additions & 17 deletions runhouse/servers/obj_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def get_cluster_servlet(create_if_not_exists: bool = False):
)

# Make sure cluster servlet is actually initialized
ray.get(cluster_servlet.get_cluster_config.remote())
ray.get(cluster_servlet.aget_cluster_config.remote())

return cluster_servlet

Expand Down Expand Up @@ -354,7 +354,7 @@ async def aget_cluster_config(self):
# TODO: Potentially add caching here
if self.cluster_servlet is not None:
return await self.acall_actor_method(
self.cluster_servlet, "get_cluster_config"
self.cluster_servlet, "aget_cluster_config"
)
else:
return {}
Expand All @@ -364,12 +364,12 @@ def get_cluster_config(self):

async def aset_cluster_config(self, config: Dict[str, Any]):
return await self.acall_actor_method(
self.cluster_servlet, "set_cluster_config", config
self.cluster_servlet, "aset_cluster_config", config
)

async def aset_cluster_config_value(self, key: str, value: Any):
return await self.acall_actor_method(
self.cluster_servlet, "set_cluster_config_value", key, value
self.cluster_servlet, "aset_cluster_config_value", key, value
)

def set_cluster_config_value(self, key: str, value: Any):
Expand All @@ -380,7 +380,7 @@ def set_cluster_config_value(self, key: str, value: Any):
##############################################
async def aadd_user_to_auth_cache(self, token: str, refresh_cache: bool = True):
return await self.acall_actor_method(
self.cluster_servlet, "add_user_to_auth_cache", token, refresh_cache
self.cluster_servlet, "aadd_user_to_auth_cache", token, refresh_cache
)

def add_user_to_auth_cache(self, token: str, refresh_cache: bool = True):
Expand All @@ -389,7 +389,7 @@ def add_user_to_auth_cache(self, token: str, refresh_cache: bool = True):
async def aresource_access_level(self, token: str, resource_uri: str):
return await self.acall_actor_method(
self.cluster_servlet,
"resource_access_level",
"aresource_access_level",
token,
resource_uri,
)
Expand All @@ -399,15 +399,15 @@ def resource_access_level(self, token: str, resource_uri: str):

async def auser_resources(self, token: str):
return await self.acall_actor_method(
self.cluster_servlet, "user_resources", token
self.cluster_servlet, "auser_resources", token
)

def user_resources(self, token: str):
return sync_function(self.auser_resources)(token)

async def aget_username(self, token: str):
return await self.acall_actor_method(
self.cluster_servlet, "get_username", token
self.cluster_servlet, "aget_username", token
)

async def ahas_resource_access(self, token: str, resource_uri=None) -> bool:
Expand Down Expand Up @@ -452,7 +452,7 @@ async def ahas_resource_access(self, token: str, resource_uri=None) -> bool:

async def aclear_auth_cache(self, token: str = None):
return await self.acall_actor_method(
self.cluster_servlet, "clear_auth_cache", token
self.cluster_servlet, "aclear_auth_cache", token
)

##############################################
Expand All @@ -461,20 +461,20 @@ async def aclear_auth_cache(self, token: str = None):
async def amark_env_servlet_name_as_initialized(self, env_servlet_name: str):
return await self.acall_actor_method(
self.cluster_servlet,
"mark_env_servlet_name_as_initialized",
"amark_env_servlet_name_as_initialized",
env_servlet_name,
)

async def ais_env_servlet_name_initialized(self, env_servlet_name: str) -> bool:
return await self.acall_actor_method(
self.cluster_servlet, "is_env_servlet_name_initialized", env_servlet_name
self.cluster_servlet, "ais_env_servlet_name_initialized", env_servlet_name
)

async def aget_all_initialized_env_servlet_names(self) -> Set[str]:
return list(
await self.acall_actor_method(
self.cluster_servlet,
"get_all_initialized_env_servlet_names",
"aget_all_initialized_env_servlet_names",
)
)

Expand All @@ -483,28 +483,28 @@ def get_all_initialized_env_servlet_names(self) -> Set[str]:

async def aget_env_servlet_name_for_key(self, key: Any):
return await self.acall_actor_method(
self.cluster_servlet, "get_env_servlet_name_for_key", key
self.cluster_servlet, "aget_env_servlet_name_for_key", key
)

def get_env_servlet_name_for_key(self, key: Any):
return sync_function(self.aget_env_servlet_name_for_key)(key)

async def _aput_env_servlet_name_for_key(self, key: Any, env_servlet_name: str):
return await self.acall_actor_method(
self.cluster_servlet, "put_env_servlet_name_for_key", key, env_servlet_name
self.cluster_servlet, "aput_env_servlet_name_for_key", key, env_servlet_name
)

async def _apop_env_servlet_name_for_key(self, key: Any, *args) -> str:
return await self.acall_actor_method(
self.cluster_servlet, "pop_env_servlet_name_for_key", key, *args
self.cluster_servlet, "apop_env_servlet_name_for_key", key, *args
)

##############################################
# Remove Env Servlet
##############################################
async def aremove_env_servlet_name(self, env_servlet_name: str):
return await self.acall_actor_method(
self.cluster_servlet, "remove_env_servlet_name", env_servlet_name
self.cluster_servlet, "aremove_env_servlet_name", env_servlet_name
)

##############################################
Expand All @@ -528,7 +528,7 @@ def keys_local(self) -> List[Any]:
async def akeys(self) -> List[Any]:
# Return keys across the cluster, not only in this process
return await self.acall_actor_method(
self.cluster_servlet, "get_key_to_env_servlet_name_dict_keys"
self.cluster_servlet, "aget_key_to_env_servlet_name_dict_keys"
)

def keys(self) -> List[Any]:
Expand Down

0 comments on commit 76bbf5e

Please sign in to comment.