Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions src/tetra_rp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def remote(
dependencies: Optional[List[str]] = None,
system_dependencies: Optional[List[str]] = None,
accelerate_downloads: bool = True,
hf_models_to_cache: Optional[List[str]] = None,
**extra,
):
"""
Expand All @@ -33,8 +32,6 @@ def remote(
environment before executing the function. Defaults to None.
accelerate_downloads (bool, optional): Enable download acceleration for dependencies and models.
Defaults to True.
hf_models_to_cache (List[str], optional): List of HuggingFace model IDs to pre-cache using
download acceleration. Defaults to None.
extra (dict, optional): Additional parameters for the execution of the resource. Defaults to an empty dict.

Returns:
Expand All @@ -47,7 +44,6 @@ def remote(
resource_config=my_resource_config,
dependencies=["numpy", "pandas"],
accelerate_downloads=True,
hf_models_to_cache=["gpt2", "bert-base-uncased"]
)
async def my_function(data):
# Function logic here
Expand All @@ -64,7 +60,6 @@ def decorator(func_or_class):
dependencies,
system_dependencies,
accelerate_downloads,
hf_models_to_cache,
extra,
)
else:
Expand All @@ -82,7 +77,6 @@ async def wrapper(*args, **kwargs):
dependencies,
system_dependencies,
accelerate_downloads,
hf_models_to_cache,
*args,
**kwargs,
)
Expand Down
3 changes: 0 additions & 3 deletions src/tetra_rp/execute_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,6 @@ def create_remote_class(
dependencies: Optional[List[str]],
system_dependencies: Optional[List[str]],
accelerate_downloads: bool,
hf_models_to_cache: Optional[List[str]],
extra: dict,
):
"""
Expand All @@ -222,7 +221,6 @@ def __init__(self, *args, **kwargs):
self._dependencies = dependencies or []
self._system_dependencies = system_dependencies or []
self._accelerate_downloads = accelerate_downloads
self._hf_models_to_cache = hf_models_to_cache
self._extra = extra
self._constructor_args = args
self._constructor_kwargs = kwargs
Expand Down Expand Up @@ -307,7 +305,6 @@ async def method_proxy(*args, **kwargs):
dependencies=self._dependencies,
system_dependencies=self._system_dependencies,
accelerate_downloads=self._accelerate_downloads,
hf_models_to_cache=self._hf_models_to_cache,
instance_id=self._instance_id,
create_new_instance=not hasattr(
self, "_stub"
Expand Down
1 change: 0 additions & 1 deletion src/tetra_rp/protos/remote_execution.proto
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ message FunctionRequest {

// Download acceleration fields
optional bool accelerate_downloads = 19; // Enable download acceleration for dependencies and models (default: true)
repeated string hf_models_to_cache = 20; // List of HuggingFace model IDs to pre-cache using acceleration
}

// The response message containing the execution result or error
Expand Down
4 changes: 0 additions & 4 deletions src/tetra_rp/protos/remote_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,6 @@ class FunctionRequest(BaseModel):
default=True,
description="Enable download acceleration for dependencies and models",
)
hf_models_to_cache: Optional[List[str]] = Field(
default=None,
description="List of HuggingFace model IDs to pre-cache using acceleration",
)

@model_validator(mode="after")
def validate_execution_requirements(self) -> "FunctionRequest":
Expand Down
2 changes: 0 additions & 2 deletions src/tetra_rp/stubs/live_serverless.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def prepare_request(
dependencies,
system_dependencies,
accelerate_downloads,
hf_models_to_cache,
*args,
**kwargs,
):
Expand All @@ -79,7 +78,6 @@ def prepare_request(
"dependencies": dependencies,
"system_dependencies": system_dependencies,
"accelerate_downloads": accelerate_downloads,
"hf_models_to_cache": hf_models_to_cache,
}

# Thread-safe cache access
Expand Down
4 changes: 0 additions & 4 deletions src/tetra_rp/stubs/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ async def stubbed_resource(
dependencies,
system_dependencies,
accelerate_downloads,
hf_models_to_cache,
*args,
**kwargs,
) -> dict:
Expand All @@ -43,7 +42,6 @@ async def stubbed_resource(
dependencies,
system_dependencies,
accelerate_downloads,
hf_models_to_cache,
*args,
**kwargs,
)
Expand Down Expand Up @@ -78,7 +76,6 @@ async def stubbed_resource(
dependencies,
system_dependencies,
accelerate_downloads,
hf_models_to_cache,
*args,
**kwargs,
) -> dict:
Expand All @@ -103,7 +100,6 @@ async def stubbed_resource(
dependencies,
system_dependencies,
accelerate_downloads,
hf_models_to_cache,
*args,
**kwargs,
) -> dict:
Expand Down
18 changes: 7 additions & 11 deletions tests/integration/test_class_execution_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def get_state(self):
}

RemoteCounter = create_remote_class(
StatefulCounter, self.mock_resource_config, [], [], True, None, {}
StatefulCounter, self.mock_resource_config, [], [], True, {}
)

counter = RemoteCounter(5)
Expand Down Expand Up @@ -276,7 +276,7 @@ def get_completed_count(self):
return self.tasks_completed

RemoteWorker = create_remote_class(
AsyncWorker, self.mock_resource_config, [], [], True, None, {}
AsyncWorker, self.mock_resource_config, [], [], True, {}
)

worker = RemoteWorker()
Expand Down Expand Up @@ -376,7 +376,6 @@ def process_with_config(self, input_data):
["scikit-learn", "pandas"],
[], # system_dependencies
True, # accelerate_downloads
None, # hf_models_to_cache
{}, # extra
)

Expand Down Expand Up @@ -478,7 +477,7 @@ def get_service_info(self):
api_keys = ["key1", "key2", "key3"]

RemoteDataService = create_remote_class(
DataService, self.mock_resource_config, ["psycopg2"], [], True, None, {}
DataService, self.mock_resource_config, ["psycopg2"], [], True, {}
)

service = RemoteDataService(db_conn, cache_conf, api_keys=api_keys)
Expand Down Expand Up @@ -549,7 +548,7 @@ def safe_method(self):
return "This always works"

RemoteErrorProneClass = create_remote_class(
ErrorProneClass, self.mock_resource_config, [], [], True, None, {}
ErrorProneClass, self.mock_resource_config, [], [], True, {}
)

error_instance = RemoteErrorProneClass(should_fail=True)
Expand Down Expand Up @@ -585,7 +584,7 @@ def simple_method(self):
return "hello"

RemoteSimpleClass = create_remote_class(
SimpleClass, self.mock_resource_config, [], [], True, None, {}
SimpleClass, self.mock_resource_config, [], [], True, {}
)

instance = RemoteSimpleClass()
Expand Down Expand Up @@ -621,7 +620,7 @@ def process_file(self):

with tempfile.NamedTemporaryFile() as temp_file:
RemoteUnserializableClass = create_remote_class(
UnserializableClass, self.mock_resource_config, [], [], True, None, {}
UnserializableClass, self.mock_resource_config, [], [], True, {}
)

# This should not fail during initialization (lazy serialization)
Expand Down Expand Up @@ -669,7 +668,6 @@ def slow_method(self, duration):
[],
[],
True,
None,
{"timeout": 5}, # 5 second timeout
)

Expand Down Expand Up @@ -705,7 +703,6 @@ def test_invalid_class_type_error(self):
[],
[],
True,
None,
{},
)

Expand All @@ -715,7 +712,7 @@ def not_a_class():

with pytest.raises(TypeError, match="Expected a class"):
create_remote_class(
not_a_class, self.mock_resource_config, [], [], True, None, {}
not_a_class, self.mock_resource_config, [], [], True, {}
)

# Note: Testing class without __name__ is not practically possible
Expand All @@ -738,7 +735,6 @@ def use_dependency(self):
["nonexistent-package==999.999.999"], # Invalid package
[],
True,
None,
{},
)

Expand Down
22 changes: 11 additions & 11 deletions tests/unit/test_class_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def __init__(self, value):
self.value = value

RemoteCacheTestClass = create_remote_class(
CacheTestClass, self.mock_resource_config, [], [], True, None, {}
CacheTestClass, self.mock_resource_config, [], [], True, {}
)

# First instance - should be cache miss
Expand Down Expand Up @@ -177,7 +177,7 @@ def __init__(self, x, y=None):
self.y = y

RemoteMultiArgClass = create_remote_class(
MultiArgClass, self.mock_resource_config, [], [], True, None, {}
MultiArgClass, self.mock_resource_config, [], [], True, {}
)

# Different args should create different cache entries
Expand All @@ -198,7 +198,7 @@ def __init__(self, file_handle, name="default"):
self.name = name

RemoteFileHandlerClass = create_remote_class(
FileHandlerClass, self.mock_resource_config, [], [], True, None, {}
FileHandlerClass, self.mock_resource_config, [], [], True, {}
)

with tempfile.NamedTemporaryFile() as temp_file:
Expand All @@ -224,7 +224,7 @@ def __init__(self, value):
self.value = value

RemoteOptimizationTestClass = create_remote_class(
OptimizationTestClass, self.mock_resource_config, [], [], True, None, {}
OptimizationTestClass, self.mock_resource_config, [], [], True, {}
)

with patch("tetra_rp.execute_class.extract_class_code_simple") as mock_extract:
Expand All @@ -250,7 +250,7 @@ def get_value(self):
return self.value

RemoteConsistencyTestClass = create_remote_class(
ConsistencyTestClass, self.mock_resource_config, [], [], True, None, {}
ConsistencyTestClass, self.mock_resource_config, [], [], True, {}
)

instance1 = RemoteConsistencyTestClass(1)
Expand All @@ -273,7 +273,7 @@ def __init__(self, file_handle):
self.file_handle = file_handle

RemoteUUIDFallbackClass = create_remote_class(
UUIDFallbackClass, self.mock_resource_config, [], [], True, None, {}
UUIDFallbackClass, self.mock_resource_config, [], [], True, {}
)

with (
Expand All @@ -299,7 +299,7 @@ def __init__(self, value):
self.value = value

RemoteMemoryTestClass = create_remote_class(
MemoryTestClass, self.mock_resource_config, [], [], True, None, {}
MemoryTestClass, self.mock_resource_config, [], [], True, {}
)

# Create many instances with same args - should only create one cache entry
Expand All @@ -323,10 +323,10 @@ def __init__(self, value):
self.value = value

RemoteClassTypeA = create_remote_class(
ClassTypeA, self.mock_resource_config, [], [], True, None, {}
ClassTypeA, self.mock_resource_config, [], [], True, {}
)
RemoteClassTypeB = create_remote_class(
ClassTypeB, self.mock_resource_config, [], [], True, None, {}
ClassTypeB, self.mock_resource_config, [], [], True, {}
)

instanceA = RemoteClassTypeA(42)
Expand Down Expand Up @@ -358,7 +358,7 @@ def __init__(self, value, config=None):
)

RemoteStructureTestClass = create_remote_class(
StructureTestClass, resource_config, [], [], True, None, {}
StructureTestClass, resource_config, [], [], True, {}
)

instance = RemoteStructureTestClass(42, config={"key": "value"})
Expand Down Expand Up @@ -401,7 +401,7 @@ def __init__(self, data):
)

RemoteSerializationTestClass = create_remote_class(
SerializationTestClass, resource_config, [], [], True, None, {}
SerializationTestClass, resource_config, [], [], True, {}
)

test_data = {"test": [1, 2, 3]}
Expand Down
Loading
Loading