diff --git a/src/oumi/cli/launch.py b/src/oumi/cli/launch.py
index 588961109..e137b7346 100644
--- a/src/oumi/cli/launch.py
+++ b/src/oumi/cli/launch.py
@@ -13,11 +13,19 @@
 # limitations under the License.
 
 import io
+import threading
 import time
 from collections import defaultdict
-from multiprocessing.pool import Pool
 from pathlib import Path
-from typing import TYPE_CHECKING, Annotated, Callable, Optional
+from typing import (
+    TYPE_CHECKING,
+    Annotated,
+    Any,
+    Callable,
+    Optional,
+    TypeVar,
+    cast,
+)
 
 import typer
 from rich.columns import Columns
@@ -34,6 +42,8 @@
 if TYPE_CHECKING:
     from oumi.core.launcher import BaseCluster, JobStatus
 
+T = TypeVar("T")
+
 
 def _get_working_dir(current: Optional[str]) -> Optional[str]:
     """Prompts the user to select the working directory, if relevant."""
@@ -52,24 +62,39 @@
 
 
 def _print_and_wait(
-    message: str, task: Callable[..., bool], asynchronous=True, **kwargs
-) -> None:
+    message: str, task: Callable[..., T], asynchronous: bool = True, **kwargs
+) -> T:
     """Prints a message with a loading spinner until the provided task is done."""
     with cli_utils.CONSOLE.status(message):
         if asynchronous:
-            with Pool(processes=1) as worker_pool:
-                task_done = False
-                while not task_done:
-                    worker_result = worker_pool.apply_async(task, kwds=kwargs)
-                    worker_result.wait()
-                    # Call get() to reraise any exceptions that occurred in the worker.
-                    task_done = worker_result.get()
+            result_container: dict[str, Any] = {}
+            exception_container: dict[str, Exception] = {}
+
+            def _worker():
+                try:
+                    result_container["value"] = task(**kwargs)
+                except Exception as e:
+                    exception_container["error"] = e
+
+            worker_thread = threading.Thread(target=_worker)
+            worker_thread.start()
+
+            while worker_thread.is_alive():
+                time.sleep(0.1)
+
+            worker_thread.join()
+
+            # Reraise any exceptions that occurred in the worker thread.
+            if "error" in exception_container:
+                raise exception_container["error"]
+            return cast(T, result_container.get("value"))
         else:
             # Synchronous tasks should be atomic and not block for a significant amount
             # of time. If a task is blocking, it should be run asynchronously.
            while not task(**kwargs):
                 sleep_duration = 0.1
                 time.sleep(sleep_duration)
+            return cast(T, None)
 
 
 def _is_job_done(id: str, cloud: str, cluster: str) -> bool:
diff --git a/src/oumi/infer.py b/src/oumi/infer.py
index 578ed009d..0e8ed0d21 100644
--- a/src/oumi/infer.py
+++ b/src/oumi/infer.py
@@ -15,6 +15,7 @@
 from typing import Optional
 
 from oumi.builders.inference_engines import build_inference_engine
+from oumi.cli.launch import _print_and_wait
 from oumi.core.configs import InferenceConfig, InferenceEngineType
 from oumi.core.inference import BaseInferenceEngine
 from oumi.core.types.conversation import (
@@ -55,15 +56,22 @@
         except (EOFError, KeyboardInterrupt):  # Triggered by Ctrl+D/Ctrl+C
             print("\nExiting...")
             return
-        model_response = infer(
-            config=config,
-            inputs=[
-                input_text,
-            ],
-            system_prompt=system_prompt,
-            input_image_bytes=input_image_bytes,
-            inference_engine=inference_engine,
+
+        def _task_to_run():
+            return infer(
+                config=config,
+                inputs=[input_text],
+                inference_engine=inference_engine,
+                input_image_bytes=input_image_bytes,
+                system_prompt=system_prompt,
+            )
+
+        model_response = _print_and_wait(
+            "Running inference...",
+            _task_to_run,
+            asynchronous=True,
         )
+
         for g in model_response:
             print("------------")
             print(repr(g))
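
For reviewers, the pattern the launcher change introduces (and that `infer_interactive` now reuses) is: run the task on a background thread, poll it under a `rich` status spinner, stash the result or exception in a dict, and return or reraise on the main thread. Using a thread rather than a `multiprocessing.Pool` worker means the callable and its return value need not be picklable, which matters for closures like `_task_to_run` that capture an inference engine. Below is a minimal standalone sketch of that pattern, assuming only the standard library and `rich`; `run_with_spinner` and `slow_add` are illustrative names, not part of oumi.

```python
import threading
import time
from typing import Any, Callable, TypeVar, cast

from rich.console import Console

T = TypeVar("T")
_console = Console()


def run_with_spinner(message: str, task: Callable[..., T], **kwargs: Any) -> T:
    """Runs `task` in a background thread while a spinner is displayed.

    Returns the task's result, or reraises any exception the task raised.
    """
    result: dict[str, Any] = {}
    error: dict[str, Exception] = {}

    def _worker() -> None:
        try:
            result["value"] = task(**kwargs)
        except Exception as e:
            # Captured here and reraised on the main thread below.
            error["exc"] = e

    worker = threading.Thread(target=_worker, daemon=True)
    with _console.status(message):
        worker.start()
        while worker.is_alive():
            time.sleep(0.1)
        worker.join()

    if "exc" in error:
        raise error["exc"]
    return cast(T, result.get("value"))


if __name__ == "__main__":

    def slow_add(a: int, b: int) -> int:
        time.sleep(1.0)  # Stand-in for a slow model call.
        return a + b

    print(run_with_spinner("Running inference...", slow_add, a=2, b=3))  # -> 5
```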