diff --git a/src/together/lib/cli/api/fine_tuning.py b/src/together/lib/cli/api/fine_tuning.py
index d1215359..c7190d75 100644
--- a/src/together/lib/cli/api/fine_tuning.py
+++ b/src/together/lib/cli/api/fine_tuning.py
@@ -24,13 +24,21 @@
 _CONFIRMATION_MESSAGE = (
     "You are about to create a fine-tuning job. "
-    "The cost of your job will be determined by the model size, the number of tokens "
+    "The estimated price of this job is {price}. "
+    "The actual cost of your job will be determined by the model size, the number of tokens "
     "in the training file, the number of tokens in the validation file, the number of epochs, and "
-    "the number of evaluations. Visit https://www.together.ai/pricing to get a price estimate.\n"
+    "the number of evaluations. Visit https://www.together.ai/pricing to learn more about pricing.\n"
+    "{warning}"
     "You can pass `-y` or `--confirm` to your command to skip this message.\n\n"
     "Do you want to proceed?"
 )
 
+_WARNING_MESSAGE_INSUFFICIENT_FUNDS = (
+    "The estimated price of this job is significantly greater than your current credit limit and balance. "
+    "It will likely fail due to insufficient funds. "
+    "Please consider increasing your credit limit at https://api.together.xyz/settings/profile\n"
+)
+
 _FT_JOB_WITH_STEP_REGEX = r"^ft-[\dabcdef-]+:\d+$"
 
@@ -323,7 +331,32 @@ def create(
     elif n_evals > 0 and not validation_file:
         raise click.BadParameter("You have specified a number of evaluation loops but no validation file.")
 
-    if confirm or click.confirm(_CONFIRMATION_MESSAGE, default=True, show_default=True):
+    finetune_price_estimation_result = client.fine_tuning.estimate_price(
+        training_file=training_file,
+        validation_file=validation_file,
+        model=model,
+        n_epochs=n_epochs,
+        n_evals=n_evals,
+        training_type="lora" if lora else "full",
+        training_method=training_method,
+    )
+
+    price = click.style(
+        f"${finetune_price_estimation_result.estimated_total_price:.2f}",
+        bold=True,
+    )
+
+    if not finetune_price_estimation_result.allowed_to_proceed:
+        warning = click.style(_WARNING_MESSAGE_INSUFFICIENT_FUNDS, fg="red", bold=True)
+    else:
+        warning = ""
+
+    confirmation_message = _CONFIRMATION_MESSAGE.format(
+        price=price,
+        warning=warning,
+    )
+
+    if confirm or click.confirm(confirmation_message, default=True, show_default=True):
         response = client.fine_tuning.create(
             **training_args,
             verbose=True,
diff --git a/src/together/lib/resources/fine_tuning.py b/src/together/lib/resources/fine_tuning.py
index f4779191..002f9bfc 100644
--- a/src/together/lib/resources/fine_tuning.py
+++ b/src/together/lib/resources/fine_tuning.py
@@ -21,6 +21,7 @@
     CosineLRSchedulerArgs,
     LinearLRSchedulerArgs,
     FinetuneTrainingLimits,
+    FinetunePriceEstimationRequest,
 )
 
 AVAILABLE_TRAINING_METHODS = {
@@ -236,6 +237,53 @@ def create_finetune_request(
     return finetune_request
 
 
+def create_finetune_price_estimation_request(
+    training_file: str,
+    validation_file: str | None = None,
+    model: str | None = None,
+    n_epochs: int = 1,
+    n_evals: int | None = 0,
+    training_type: str | None = "lora",
+    training_method: str | None = "sft",
+) -> FinetunePriceEstimationRequest:
+    """
+    Create a fine-tune price estimation request
+    """
+
+    training_method_cls: TrainingMethodSFT | TrainingMethodDPO
+    if training_method == "sft":
+        training_method_cls = TrainingMethodSFT(train_on_inputs="auto")
+    elif training_method == "dpo":
+        training_method_cls = TrainingMethodDPO(
+            dpo_beta=None,
+            dpo_normalize_logratios_by_length=False,
+            dpo_reference_free=False,
+            rpo_alpha=None,
+            simpo_gamma=None,
+        )
+    else:
+        raise ValueError(f"Invalid training method: {training_method}. Must be 'sft' or 'dpo'")
+
+    training_type_cls: FullTrainingType | LoRATrainingType
+    if training_type == "full":
+        training_type_cls = FullTrainingType(type="Full")
+    elif training_type == "lora":
+        # LoRA parameters do not matter for price estimation
+        training_type_cls = LoRATrainingType(type="Lora", lora_r=10, lora_alpha=10)
+    else:
+        raise ValueError(f"Invalid training type: {training_type}. Must be 'full' or 'lora'")
+
+    return FinetunePriceEstimationRequest(
+        training_file=training_file,
+        validation_file=validation_file,
+        model=model,
+        n_epochs=n_epochs,
+        n_evals=n_evals,
+        training_type=training_type_cls,
+        training_method=training_method_cls,
+    )
+
+
 def get_model_limits(client: Together, model: str) -> FinetuneTrainingLimits:
     """
     Requests training limits for a specific model
diff --git a/src/together/lib/types/fine_tuning.py b/src/together/lib/types/fine_tuning.py
index 55327e5a..4173009e 100644
--- a/src/together/lib/types/fine_tuning.py
+++ b/src/together/lib/types/fine_tuning.py
@@ -395,3 +395,36 @@ class FinetuneRequest(BaseModel):
     # hf related fields
     hf_api_token: Union[str, None] = None
     hf_output_repo_name: Union[str, None] = None
+
+
+class FinetunePriceEstimationRequest(BaseModel):
+    """
+    Fine-tune price estimation request type
+    """
+
+    # training file ID
+    training_file: str
+    # validation file ID
+    validation_file: Union[str, None] = None
+    # base model string
+    model: Union[str, None] = None
+    # number of epochs to train for
+    n_epochs: int
+    # number of evaluation loops to run
+    n_evals: Union[int, None] = None
+    # training type
+    training_type: Union[TrainingType, None] = None
+    # training method
+    training_method: Union[TrainingMethodSFT, TrainingMethodDPO] = Field(default_factory=TrainingMethodSFT)
+
+
+class FinetunePriceEstimationResponse(BaseModel):
+    """
+    Fine-tune price estimation response type
+    """
+
+    allowed_to_proceed: bool
+    estimated_train_token_count: int
+    estimated_eval_token_count: int
+    user_limit: float
+    estimated_total_price: float
diff --git a/src/together/resources/fine_tuning.py b/src/together/resources/fine_tuning.py
index 47362d31..7b1df510 100644
--- a/src/together/resources/fine_tuning.py
+++ b/src/together/resources/fine_tuning.py
@@ -27,15 +27,31 @@
     async_to_custom_streamed_response_wrapper,
 )
 from .._base_client import make_request_options
-from ..lib.types.fine_tuning import FinetuneResponse as FinetuneResponseLib, FinetuneTrainingLimits
+from ..lib.types.fine_tuning import (
+    FinetuneResponse as FinetuneResponseLib,
+    FinetuneTrainingLimits,
+    FinetunePriceEstimationResponse,
+)
 from ..types.finetune_response import FinetuneResponse
-from ..lib.resources.fine_tuning import get_model_limits, async_get_model_limits, create_finetune_request
+from ..lib.resources.fine_tuning import (
+    get_model_limits,
+    async_get_model_limits,
+    create_finetune_request,
+    create_finetune_price_estimation_request,
+)
 from ..types.fine_tuning_list_response import FineTuningListResponse
 from ..types.fine_tuning_cancel_response import FineTuningCancelResponse
 from ..types.fine_tuning_delete_response import FineTuningDeleteResponse
 from ..types.fine_tuning_list_events_response import FineTuningListEventsResponse
 from ..types.fine_tuning_list_checkpoints_response import FineTuningListCheckpointsResponse
 
+_WARNING_MESSAGE_INSUFFICIENT_FUNDS = (
+    "The estimated price of the fine-tuning job is {}, which is significantly "
+    "greater than your current credit limit and balance. "
" + "It will likely fail due to insufficient funds. " + "Please proceed at your own risk." +) + __all__ = ["FineTuningResource", "AsyncFineTuningResource"] @@ -218,11 +234,27 @@ def create( hf_output_repo_name=hf_output_repo_name, ) + price_estimation_result = self.estimate_price( + training_file=training_file, + validation_file=validation_file, + model=model, + n_epochs=n_epochs, + n_evals=n_evals, + training_type="lora" if lora else "full", + training_method=training_method, + ) + if verbose: rprint( "Submitting a fine-tuning job with the following parameters:", finetune_request, ) + if not price_estimation_result.allowed_to_proceed: + rprint( + "[red]" + + _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format(price_estimation_result.estimated_total_price) + + "[/red]", + ) parameter_payload = finetune_request.model_dump(exclude_none=True) return self._client.post( @@ -231,6 +263,37 @@ def create( cast_to=FinetuneResponseLib, ) + def estimate_price( + self, + *, + training_file: str, + validation_file: str | None = None, + model: str | None = None, + n_epochs: int = 1, + n_evals: int | None = 0, + training_type: str | None = "lora", + training_method: str | None = "sft", + ) -> FinetunePriceEstimationResponse: + """ + Estimate the price of a fine-tuning job + """ + + finetune_price_estimation_request = create_finetune_price_estimation_request( + training_file=training_file, + validation_file=validation_file, + model=model, + n_epochs=n_epochs, + n_evals=n_evals, + training_type=training_type, + training_method=training_method, + ) + parameter_payload = finetune_price_estimation_request.model_dump(exclude_none=True) + return self._client.post( + "/fine-tunes/estimate-price", + body=parameter_payload, + cast_to=FinetunePriceEstimationResponse, + ) + def retrieve( self, id: str, @@ -659,11 +722,27 @@ async def create( hf_output_repo_name=hf_output_repo_name, ) + price_estimation_result = await self.estimate_price( + training_file=training_file, + validation_file=validation_file, + model=model, + n_epochs=n_epochs, + n_evals=n_evals, + training_type="lora" if lora else "full", + training_method=training_method, + ) + if verbose: rprint( "Submitting a fine-tuning job with the following parameters:", finetune_request, ) + if not price_estimation_result.allowed_to_proceed: + rprint( + "[red]" + + _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format(price_estimation_result.estimated_total_price) + + "[/red]", + ) parameter_payload = finetune_request.model_dump(exclude_none=True) return await self._client.post( @@ -672,6 +751,37 @@ async def create( cast_to=FinetuneResponse, ) + async def estimate_price( + self, + *, + training_file: str, + validation_file: str | None = None, + model: str | None = None, + n_epochs: int = 1, + n_evals: int | None = 0, + training_type: str | None = "lora", + training_method: str | None = "sft", + ) -> FinetunePriceEstimationResponse: + """ + Estimate the price of a fine-tuning job + """ + + finetune_price_estimation_request = create_finetune_price_estimation_request( + training_file=training_file, + validation_file=validation_file, + model=model, + n_epochs=n_epochs, + n_evals=n_evals, + training_type=training_type, + training_method=training_method, + ) + parameter_payload = finetune_price_estimation_request.model_dump(exclude_none=True) + return await self._client.post( + "/fine-tunes/estimate-price", + body=parameter_payload, + cast_to=FinetunePriceEstimationResponse, + ) + async def retrieve( self, id: str,