Skip to content
Merged
42 changes: 38 additions & 4 deletions src/together/cli/api/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
DownloadCheckpointType,
FinetuneEventType,
FinetuneTrainingLimits,
FullTrainingType,
LoRATrainingType,
)
from together.utils import (
finetune_price_to_dollars,
Expand All @@ -29,13 +31,21 @@

_CONFIRMATION_MESSAGE = (
"You are about to create a fine-tuning job. "
"The cost of your job will be determined by the model size, the number of tokens "
"The estimated price of this job is {price}. "
"The actual cost of your job will be determined by the model size, the number of tokens "
"in the training file, the number of tokens in the validation file, the number of epochs, and "
"the number of evaluations. Visit https://www.together.ai/pricing to get a price estimate.\n"
"the number of evaluations. Visit https://www.together.ai/pricing to learn more about fine-tuning pricing.\n"
"{warning}"
"You can pass `-y` or `--confirm` to your command to skip this message.\n\n"
"Do you want to proceed?"
)

_WARNING_MESSAGE_INSUFFICIENT_FUNDS = (
"The estimated price of this job is significantly greater than your current credit limit and balance combined. "
"It will likely get cancelled due to insufficient funds. "
"Consider increasing your credit limit at https://api.together.xyz/settings/profile\n"
)


class DownloadCheckpointTypeChoice(click.Choice):
def __init__(self) -> None:
Expand Down Expand Up @@ -357,12 +367,36 @@ def create(
"You have specified a number of evaluation loops but no validation file."
)

if confirm or click.confirm(_CONFIRMATION_MESSAGE, default=True, show_default=True):
finetune_price_estimation_result = client.fine_tuning.estimate_price(
training_file=training_file,
validation_file=validation_file,
model=model,
n_epochs=n_epochs,
n_evals=n_evals,
training_type="lora" if lora else "full",
training_method=training_method,
)

price = click.style(
f"${finetune_price_estimation_result.estimated_total_price:.2f}",
bold=True,
)

if not finetune_price_estimation_result.allowed_to_proceed:
warning = click.style(_WARNING_MESSAGE_INSUFFICIENT_FUNDS, fg="red", bold=True)
else:
warning = ""

confirmation_message = _CONFIRMATION_MESSAGE.format(
price=price,
warning=warning,
)

if confirm or click.confirm(confirmation_message, default=True, show_default=True):
response = client.fine_tuning.create(
**training_args,
verbose=True,
)

report_string = f"Successfully submitted a fine-tuning job {response.id}"
if response.created_at is not None:
created_time = datetime.strptime(
Expand Down
205 changes: 204 additions & 1 deletion src/together/resources/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
FinetuneLRScheduler,
FinetuneRequest,
FinetuneResponse,
FinetunePriceEstimationRequest,
FinetunePriceEstimationResponse,
FinetuneTrainingLimits,
FullTrainingType,
LinearLRScheduler,
Expand All @@ -31,7 +33,7 @@
TrainingMethodSFT,
TrainingType,
)
from together.types.finetune import DownloadCheckpointType
from together.types.finetune import DownloadCheckpointType, TrainingMethod
from together.utils import log_warn_once, normalize_key


Expand All @@ -42,6 +44,12 @@
TrainingMethodSFT().method,
TrainingMethodDPO().method,
}
_WARNING_MESSAGE_INSUFFICIENT_FUNDS = (
"The estimated price of the fine-tuning job is {} which is significantly "
"greater than your current credit limit and balance combined. "
"It will likely get cancelled due to insufficient funds. "
"Proceed at your own risk."
)


def create_finetune_request(
Expand Down Expand Up @@ -473,12 +481,34 @@ def create(
hf_api_token=hf_api_token,
hf_output_repo_name=hf_output_repo_name,
)
if from_checkpoint is None and from_hf_model is None:
price_estimation_result = self.estimate_price(
training_file=training_file,
validation_file=validation_file,
model=model_name,
n_epochs=finetune_request.n_epochs,
n_evals=finetune_request.n_evals,
training_type="lora" if lora else "full",
training_method=training_method,
)
price_limit_passed = price_estimation_result.allowed_to_proceed
else:
# unsupported case
price_limit_passed = True

if verbose:
rprint(
"Submitting a fine-tuning job with the following parameters:",
finetune_request,
)
if not price_limit_passed:
rprint(
"[red]"
+ _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format(
price_estimation_result.estimated_total_price
)
+ "[/red]",
)
parameter_payload = finetune_request.model_dump(exclude_none=True)

response, _, _ = requestor.request(
Expand All @@ -493,6 +523,81 @@ def create(

return FinetuneResponse(**response.data)

def estimate_price(
self,
*,
training_file: str,
model: str,
validation_file: str | None = None,
n_epochs: int | None = 1,
n_evals: int | None = 0,
training_type: str = "lora",
training_method: str = "sft",
) -> FinetunePriceEstimationResponse:
"""
Estimates the price of a fine-tuning job

Args:
training_file (str): File-ID of a file uploaded to the Together API
model (str): Name of the base model to run fine-tune job on
validation_file (str, optional): File ID of a file uploaded to the Together API for validation.
n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1.
n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
training_type (str, optional): Training type. Defaults to "lora".
training_method (str, optional): Training method. Defaults to "sft".

Returns:
FinetunePriceEstimationResponse: Object containing the price estimation result.
"""
training_type_cls: TrainingType
training_method_cls: TrainingMethod

if training_method == "sft":
training_method_cls = TrainingMethodSFT(method="sft")
elif training_method == "dpo":
training_method_cls = TrainingMethodDPO(method="dpo")
else:
raise ValueError(f"Unknown training method: {training_method}")

if training_type.lower() == "lora":
# parameters of lora are unused in price estimation
# but we need to set them to valid values
training_type_cls = LoRATrainingType(
type="Lora",
lora_r=16,
lora_alpha=16,
lora_dropout=0.0,
lora_trainable_modules="all-linear",
)
elif training_type.lower() == "full":
training_type_cls = FullTrainingType(type="Full")
else:
raise ValueError(f"Unknown training type: {training_type}")

request = FinetunePriceEstimationRequest(
training_file=training_file,
validation_file=validation_file,
model=model,
n_epochs=n_epochs,
n_evals=n_evals,
training_type=training_type_cls,
training_method=training_method_cls,
)
parameter_payload = request.model_dump(exclude_none=True)
requestor = api_requestor.APIRequestor(
client=self._client,
)

response, _, _ = requestor.request(
options=TogetherRequest(
method="POST", url="fine-tunes/estimate-price", params=parameter_payload
),
stream=False,
)
assert isinstance(response, TogetherResponse)

return FinetunePriceEstimationResponse(**response.data)

def list(self) -> FinetuneList:
"""
Lists fine-tune job history
Expand Down Expand Up @@ -941,11 +1046,34 @@ async def create(
hf_output_repo_name=hf_output_repo_name,
)

if from_checkpoint is None and from_hf_model is None:
price_estimation_result = await self.estimate_price(
training_file=training_file,
validation_file=validation_file,
model=model_name,
n_epochs=finetune_request.n_epochs,
n_evals=finetune_request.n_evals,
training_type="lora" if lora else "full",
training_method=training_method,
)
price_limit_passed = price_estimation_result.allowed_to_proceed
else:
# unsupported case
price_limit_passed = True

if verbose:
rprint(
"Submitting a fine-tuning job with the following parameters:",
finetune_request,
)
if not price_limit_passed:
rprint(
"[red]"
+ _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format(
price_estimation_result.estimated_total_price
)
+ "[/red]",
)
parameter_payload = finetune_request.model_dump(exclude_none=True)

response, _, _ = await requestor.arequest(
Expand All @@ -961,6 +1089,81 @@ async def create(

return FinetuneResponse(**response.data)

async def estimate_price(
self,
*,
training_file: str,
model: str,
validation_file: str | None = None,
n_epochs: int | None = 1,
n_evals: int | None = 0,
training_type: str = "lora",
training_method: str = "sft",
) -> FinetunePriceEstimationResponse:
"""
Estimates the price of a fine-tuning job

Args:
training_file (str): File-ID of a file uploaded to the Together API
model (str): Name of the base model to run fine-tune job on
validation_file (str, optional): File ID of a file uploaded to the Together API for validation.
n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1.
n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
training_type (str, optional): Training type. Defaults to "lora".
training_method (str, optional): Training method. Defaults to "sft".

Returns:
FinetunePriceEstimationResponse: Object containing the price estimation result.
"""
training_type_cls: TrainingType
training_method_cls: TrainingMethod

if training_method == "sft":
training_method_cls = TrainingMethodSFT(method="sft")
elif training_method == "dpo":
training_method_cls = TrainingMethodDPO(method="dpo")
else:
raise ValueError(f"Unknown training method: {training_method}")

if training_type.lower() == "lora":
# parameters of lora are unused in price estimation
# but we need to set them to valid values
training_type_cls = LoRATrainingType(
type="Lora",
lora_r=16,
lora_alpha=16,
lora_dropout=0.0,
lora_trainable_modules="all-linear",
)
elif training_type.lower() == "full":
training_type_cls = FullTrainingType(type="Full")
else:
raise ValueError(f"Unknown training type: {training_type}")

request = FinetunePriceEstimationRequest(
training_file=training_file,
validation_file=validation_file,
model=model,
n_epochs=n_epochs,
n_evals=n_evals,
training_type=training_type_cls,
training_method=training_method_cls,
)
parameter_payload = request.model_dump(exclude_none=True)
requestor = api_requestor.APIRequestor(
client=self._client,
)

response, _, _ = await requestor.arequest(
options=TogetherRequest(
method="POST", url="fine-tunes/estimate-price", params=parameter_payload
),
stream=False,
)
assert isinstance(response, TogetherResponse)

return FinetunePriceEstimationResponse(**response.data)

async def list(self) -> FinetuneList:
"""
Async method to list fine-tune job history
Expand Down
4 changes: 4 additions & 0 deletions src/together/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
FinetuneListEvents,
FinetuneRequest,
FinetuneResponse,
FinetunePriceEstimationRequest,
FinetunePriceEstimationResponse,
FinetuneDeleteResponse,
FinetuneTrainingLimits,
FullTrainingType,
Expand Down Expand Up @@ -103,6 +105,8 @@
"FinetuneDeleteResponse",
"FinetuneDownloadResult",
"FinetuneLRScheduler",
"FinetunePriceEstimationRequest",
"FinetunePriceEstimationResponse",
"LinearLRScheduler",
"LinearLRSchedulerArgs",
"CosineLRScheduler",
Expand Down
26 changes: 26 additions & 0 deletions src/together/types/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,32 @@ def validate_training_type(cls, v: TrainingType) -> TrainingType:
raise ValueError("Unknown training type")


class FinetunePriceEstimationRequest(BaseModel):
"""
Fine-tune price estimation request type
"""

training_file: str
validation_file: str | None = None
model: str
n_epochs: int
n_evals: int
training_type: TrainingType
training_method: TrainingMethod


class FinetunePriceEstimationResponse(BaseModel):
"""
Fine-tune price estimation response type
"""

estimated_total_price: float
user_limit: float
estimated_train_token_count: int
estimated_eval_token_count: int
allowed_to_proceed: bool


class FinetuneList(BaseModel):
# object type
object: Literal["list"] | None = None
Expand Down
Loading