@@ -1,4 +1,5 @@
 import json
+import os
 from dataclasses import asdict
 from typing import Any, AsyncIterable, Dict, Optional
 
@@ -193,9 +194,40 @@ async def create_text_generation_inference_bundle(
         command = []
         if checkpoint_path is not None:
             if checkpoint_path.startswith("s3://"):
-                command = ["bash", "launch_s3_model.sh", checkpoint_path, str(num_shards)]
+                base_path = checkpoint_path.split("/")[-1]
+                final_weights_folder = "model_files"
+                subcommands = []
+
+                s5cmd = "s5cmd"
+                subcommands.append(
+                    f"{s5cmd} > /dev/null || conda install -c conda-forge -y {s5cmd}"
+                )
+
+                if base_path.endswith(".tar"):
+                    # If the checkpoint file is a tar file, extract it into final_weights_folder
+                    subcommands.extend(
+                        [
+                            f"{s5cmd} cp {checkpoint_path} .",
+                            f"mkdir -p {final_weights_folder}",
+                            f"tar --no-same-owner -xf {base_path} -C {final_weights_folder}",
+                        ]
+                    )
+                else:
+                    subcommands.append(
+                        f"{s5cmd} --numworkers 512 cp --concurrency 10 {os.path.join(checkpoint_path, '*')} {final_weights_folder}"
+                    )
+
+                subcommands.append(
+                    f"text-generation-launcher --hostname :: --model-id ./{final_weights_folder} --num-shard {num_shards} --port 5005"
+                )
+
                 if quantize:
-                    command = command + [f"'--quantize {str(quantize)}'"]
+                    subcommands[-1] = subcommands[-1] + f" --quantize {quantize}"
+                command = [
+                    "/bin/bash",
+                    "-c",
+                    ";".join(subcommands),
+                ]
             else:
                 raise ObjectHasInvalidValueException(
                     f"Not able to load checkpoint path {checkpoint_path}."
@@ -599,7 +631,8 @@ async def execute(

         inference_request = EndpointPredictV1Request(args=args)
         predict_result = await inference_gateway.predict(
-            topic=model_endpoint.record.destination, predict_request=inference_request
+            topic=model_endpoint.record.destination,
+            predict_request=inference_request,
         )
 
         if predict_result.status == TaskStatus.SUCCESS and predict_result.result is not None:
@@ -632,7 +665,8 @@ async def execute(
         }
         inference_request = EndpointPredictV1Request(args=tgi_args)
         predict_result = await inference_gateway.predict(
-            topic=model_endpoint.record.destination, predict_request=inference_request
+            topic=model_endpoint.record.destination,
+            predict_request=inference_request,
         )
 
         if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None:
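
Both execute hunks are pure formatting changes to the same call pattern: build an EndpointPredictV1Request, await inference_gateway.predict, then branch on the task status. A self-contained sketch of that pattern follows; the classes defined here are stand-ins for types this diff only references, not the real definitions:

# Sketch only: EndpointPredictV1Request, TaskStatus, and PredictResult below
# are hypothetical stand-ins for the real classes elsewhere in the codebase.
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, Optional

class TaskStatus(Enum):
    SUCCESS = "SUCCESS"
    FAILURE = "FAILURE"

@dataclass
class EndpointPredictV1Request:
    args: Dict[str, Any] = field(default_factory=dict)

@dataclass
class PredictResult:
    status: TaskStatus
    result: Optional[Dict[str, Any]] = None

async def predict_and_check(inference_gateway: Any, topic: str, args: Dict[str, Any]) -> Dict[str, Any]:
    inference_request = EndpointPredictV1Request(args=args)
    predict_result: PredictResult = await inference_gateway.predict(
        topic=topic,
        predict_request=inference_request,
    )
    # Mirrors the checks above: anything but SUCCESS with a result is a failure.
    if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None:
        raise RuntimeError("prediction failed")  # the real code raises a domain exception
    return predict_result.result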