diff --git a/server/llm_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/server/llm_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py
index 8de7ef72..08b25354 100644
--- a/server/llm_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py
+++ b/server/llm_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py
@@ -1,4 +1,5 @@
 import json
+import os
 from dataclasses import asdict
 from typing import Any, AsyncIterable, Dict, Optional
 
@@ -193,9 +194,40 @@ async def create_text_generation_inference_bundle(
         command = []
         if checkpoint_path is not None:
             if checkpoint_path.startswith("s3://"):
-                command = ["bash", "launch_s3_model.sh", checkpoint_path, str(num_shards)]
+                base_path = checkpoint_path.split("/")[-1]
+                final_weights_folder = "model_files"
+                subcommands = []
+
+                s5cmd = "s5cmd"
+                subcommands.append(
+                    f"{s5cmd} > /dev/null || conda install -c conda-forge -y {s5cmd}"
+                )
+
+                if base_path.endswith(".tar"):
+                    # If the checkpoint file is a tar file, extract it into final_weights_folder
+                    subcommands.extend(
+                        [
+                            f"{s5cmd} cp {checkpoint_path} .",
+                            f"mkdir -p {final_weights_folder}",
+                            f"tar --no-same-owner -xf {base_path} -C {final_weights_folder}",
+                        ]
+                    )
+                else:
+                    subcommands.append(
+                        f"{s5cmd} --numworkers 512 cp --concurrency 10 {os.path.join(checkpoint_path, '*')} {final_weights_folder}"
+                    )
+
+                subcommands.append(
+                    f"text-generation-launcher --hostname :: --model-id ./{final_weights_folder} --num-shard {num_shards} --port 5005"
+                )
+
                 if quantize:
-                    command = command + [f"'--quantize {str(quantize)}'"]
+                    subcommands[-1] = subcommands[-1] + f" --quantize {quantize}"
+                command = [
+                    "/bin/bash",
+                    "-c",
+                    ";".join(subcommands),
+                ]
             else:
                 raise ObjectHasInvalidValueException(
                     f"Not able to load checkpoint path {checkpoint_path}."
@@ -599,7 +631,8 @@ async def execute(
         }
         inference_request = EndpointPredictV1Request(args=args)
         predict_result = await inference_gateway.predict(
-            topic=model_endpoint.record.destination, predict_request=inference_request
+            topic=model_endpoint.record.destination,
+            predict_request=inference_request,
         )
 
         if predict_result.status == TaskStatus.SUCCESS and predict_result.result is not None:
@@ -632,7 +665,8 @@
         }
         inference_request = EndpointPredictV1Request(args=tgi_args)
         predict_result = await inference_gateway.predict(
-            topic=model_endpoint.record.destination, predict_request=inference_request
+            topic=model_endpoint.record.destination,
+            predict_request=inference_request,
         )
 
         if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None:
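
For reference, a minimal sketch of the command this branch now assembles, reproducing the string-building logic above outside the use case class. The checkpoint path, shard count, and quantize value are hypothetical placeholders, not values from the diff:

```python
# Sketch only: mirrors the subcommand assembly in the diff above.
# "s3://my-bucket/llama-7b.tar", num_shards=2, and quantize=None are
# illustrative inputs, not values taken from the PR.
import os

checkpoint_path = "s3://my-bucket/llama-7b.tar"  # hypothetical
num_shards = 2                                   # hypothetical
quantize = None                                  # hypothetical

base_path = checkpoint_path.split("/")[-1]       # -> "llama-7b.tar"
final_weights_folder = "model_files"

# Install s5cmd only if it is not already on the PATH.
subcommands = ["s5cmd > /dev/null || conda install -c conda-forge -y s5cmd"]

if base_path.endswith(".tar"):
    # Tar checkpoint: download the archive, then extract into final_weights_folder.
    subcommands += [
        f"s5cmd cp {checkpoint_path} .",
        f"mkdir -p {final_weights_folder}",
        f"tar --no-same-owner -xf {base_path} -C {final_weights_folder}",
    ]
else:
    # Unpacked checkpoint: copy all weight files down in parallel.
    subcommands.append(
        f"s5cmd --numworkers 512 cp --concurrency 10 "
        f"{os.path.join(checkpoint_path, '*')} {final_weights_folder}"
    )

subcommands.append(
    f"text-generation-launcher --hostname :: --model-id ./{final_weights_folder} "
    f"--num-shard {num_shards} --port 5005"
)
if quantize:
    subcommands[-1] += f" --quantize {quantize}"

command = ["/bin/bash", "-c", ";".join(subcommands)]
print(command[2])  # the single shell string the container entrypoint runs
```

One consequence of joining the steps with ";" rather than "&&" is that later steps still run even if an earlier download step fails, so a failed copy surfaces as a launcher error rather than an aborted shell.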