39 changes: 36 additions & 3 deletions src/together/lib/cli/api/fine_tuning.py
@@ -24,13 +24,21 @@

 _CONFIRMATION_MESSAGE = (
     "You are about to create a fine-tuning job. "
-    "The cost of your job will be determined by the model size, the number of tokens "
+    "The estimated price of this job is {price}. "
+    "The actual cost of your job will be determined by the model size, the number of tokens "
     "in the training file, the number of tokens in the validation file, the number of epochs, and "
-    "the number of evaluations. Visit https://www.together.ai/pricing to get a price estimate.\n"
+    "the number of evaluations. Visit https://www.together.ai/pricing to learn more about pricing.\n"
+    "{warning}"
     "You can pass `-y` or `--confirm` to your command to skip this message.\n\n"
     "Do you want to proceed?"
 )
+
+_WARNING_MESSAGE_INSUFFICIENT_FUNDS = (
+    "The estimated price of this job is significantly greater than your current credit limit and balance. "
+    "It will likely fail due to insufficient funds. "
+    "Please consider increasing your credit limit at https://api.together.xyz/settings/profile\n"
+)

 _FT_JOB_WITH_STEP_REGEX = r"^ft-[\dabcdef-]+:\d+$"

@@ -323,7 +331,32 @@ def create(
     elif n_evals > 0 and not validation_file:
         raise click.BadParameter("You have specified a number of evaluation loops but no validation file.")

-    if confirm or click.confirm(_CONFIRMATION_MESSAGE, default=True, show_default=True):
+    finetune_price_estimation_result = client.fine_tuning.estimate_price(
+        training_file=training_file,
+        validation_file=validation_file,
+        model=model,
+        n_epochs=n_epochs,
+        n_evals=n_evals,
+        training_type="lora" if lora else "full",
+        training_method=training_method,
+    )
+
+    price = click.style(
+        f"${finetune_price_estimation_result.estimated_total_price:.2f}",
+        bold=True,
+    )
+
+    if not finetune_price_estimation_result.allowed_to_proceed:
+        warning = click.style(_WARNING_MESSAGE_INSUFFICIENT_FUNDS, fg="red", bold=True)
+    else:
+        warning = ""
+
+    confirmation_message = _CONFIRMATION_MESSAGE.format(
+        price=price,
+        warning=warning,
+    )
+
+    if confirm or click.confirm(confirmation_message, default=True, show_default=True):
         response = client.fine_tuning.create(
             **training_args,
             verbose=True,
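For illustration, a minimal sketch of how the new prompt is assembled (it reuses `_CONFIRMATION_MESSAGE` and `_WARNING_MESSAGE_INSUFFICIENT_FUNDS` from the hunk above; the price and balance check are hypothetical stand-ins for the `estimate_price` call):

```python
import click

# Hypothetical values standing in for client.fine_tuning.estimate_price(...).
estimated_total_price = 12.34
allowed_to_proceed = False

# Bold price, plus a red warning only when the estimate exceeds the user's funds.
price = click.style(f"${estimated_total_price:.2f}", bold=True)
warning = "" if allowed_to_proceed else click.style(_WARNING_MESSAGE_INSUFFICIENT_FUNDS, fg="red", bold=True)

if click.confirm(_CONFIRMATION_MESSAGE.format(price=price, warning=warning), default=True, show_default=True):
    print("Proceeding with job submission...")
```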
48 changes: 48 additions & 0 deletions src/together/lib/resources/fine_tuning.py
@@ -21,6 +21,7 @@
     CosineLRSchedulerArgs,
     LinearLRSchedulerArgs,
     FinetuneTrainingLimits,
+    FinetunePriceEstimationRequest,
 )

 AVAILABLE_TRAINING_METHODS = {
@@ -236,6 +237,53 @@ def create_finetune_request(
     return finetune_request


+def create_finetune_price_estimation_request(
+    training_file: str,
+    validation_file: str | None = None,
+    model: str | None = None,
+    n_epochs: int = 1,
+    n_evals: int | None = 0,
+    training_type: str | None = "lora",
+    training_method: str | None = "sft",
+) -> FinetunePriceEstimationRequest:
+    """
+    Create a fine-tune price estimation request
+    """
+
+    training_method_cls: TrainingMethodSFT | TrainingMethodDPO
+    if training_method == "sft":
+        training_method_cls = TrainingMethodSFT(train_on_inputs="auto")
+    elif training_method == "dpo":
+        training_method_cls = TrainingMethodDPO(
+            dpo_beta=None,
+            dpo_normalize_logratios_by_length=False,
+            dpo_reference_free=False,
+            rpo_alpha=None,
+            simpo_gamma=None,
+        )
+    else:
+        raise ValueError(f"Invalid training method: {training_method}. Must be 'sft' or 'dpo'")
+
+    training_type_cls: FullTrainingType | LoRATrainingType
+    if training_type == "full":
+        training_type_cls = FullTrainingType(type="Full")
+    elif training_type == "lora":
+        # LoRA parameters do not matter for price estimation
+        training_type_cls = LoRATrainingType(type="Lora", lora_r=10, lora_alpha=10)
+    else:
+        raise ValueError(f"Invalid training type: {training_type}. Must be 'full' or 'lora'")
+
+    return FinetunePriceEstimationRequest(
+        training_file=training_file,
+        validation_file=validation_file,
+        model=model,
+        n_epochs=n_epochs,
+        n_evals=n_evals,
+        training_type=training_type_cls,
+        training_method=training_method_cls,
+    )
+
+
 def get_model_limits(client: Together, model: str) -> FinetuneTrainingLimits:
     """
     Requests training limits for a specific model
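A hedged sketch of what the new helper produces (the file ID and model name are made up; `model_dump(exclude_none=True)` mirrors how the resource layer serializes the request later in this PR):

```python
# Hypothetical call to the helper added above.
request = create_finetune_price_estimation_request(
    training_file="file-abc123",  # placeholder file ID
    model="example-org/example-model",  # placeholder model name
    n_epochs=3,
    training_type="lora",
    training_method="sft",
)

# Serialize the same way the resource layer does before POSTing.
payload = request.model_dump(exclude_none=True)
```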
33 changes: 33 additions & 0 deletions src/together/lib/types/fine_tuning.py
@@ -395,3 +395,36 @@ class FinetuneRequest(BaseModel):
     # hf related fields
     hf_api_token: Union[str, None] = None
     hf_output_repo_name: Union[str, None] = None
+
+
+class FinetunePriceEstimationRequest(BaseModel):
+    """
+    Fine-tune price estimation request type
+    """
+
+    # training file ID
+    training_file: str
+    # validation file ID
+    validation_file: Union[str, None] = None
+    # base model string
+    model: Union[str, None] = None
+    # number of epochs to train for
+    n_epochs: int
+    # number of evaluation loops to run
+    n_evals: Union[int, None] = None
+    # training type
+    training_type: Union[TrainingType, None] = None
+    # training method
+    training_method: Union[TrainingMethodSFT, TrainingMethodDPO] = Field(default_factory=TrainingMethodSFT)
+
+
+class FinetunePriceEstimationResponse(BaseModel):
+    """
+    Fine-tune price estimation response type
+    """
+
+    allowed_to_proceed: bool
+    estimated_train_token_count: int
+    estimated_eval_token_count: int
+    user_limit: float
+    estimated_total_price: float
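A small sketch of how the response model might be populated from an API payload (the numbers are illustrative; assumes pydantic v2, which the `model_dump` calls elsewhere in this PR imply):

```python
# Hypothetical payload as the estimate-price endpoint might return it.
payload = {
    "allowed_to_proceed": True,
    "estimated_train_token_count": 1_200_000,
    "estimated_eval_token_count": 40_000,
    "user_limit": 100.0,
    "estimated_total_price": 18.75,
}

result = FinetunePriceEstimationResponse.model_validate(payload)
assert result.allowed_to_proceed
print(f"Estimated price: ${result.estimated_total_price:.2f}")
```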
114 changes: 112 additions & 2 deletions src/together/resources/fine_tuning.py
@@ -27,15 +27,31 @@
     async_to_custom_streamed_response_wrapper,
 )
 from .._base_client import make_request_options
-from ..lib.types.fine_tuning import FinetuneResponse as FinetuneResponseLib, FinetuneTrainingLimits
+from ..lib.types.fine_tuning import (
+    FinetuneResponse as FinetuneResponseLib,
+    FinetuneTrainingLimits,
+    FinetunePriceEstimationResponse,
+)
 from ..types.finetune_response import FinetuneResponse
-from ..lib.resources.fine_tuning import get_model_limits, async_get_model_limits, create_finetune_request
+from ..lib.resources.fine_tuning import (
+    get_model_limits,
+    async_get_model_limits,
+    create_finetune_request,
+    create_finetune_price_estimation_request,
+)
 from ..types.fine_tuning_list_response import FineTuningListResponse
 from ..types.fine_tuning_cancel_response import FineTuningCancelResponse
 from ..types.fine_tuning_delete_response import FineTuningDeleteResponse
 from ..types.fine_tuning_list_events_response import FineTuningListEventsResponse
 from ..types.fine_tuning_list_checkpoints_response import FineTuningListCheckpointsResponse

+_WARNING_MESSAGE_INSUFFICIENT_FUNDS = (
+    "The estimated price of the fine-tuning job is ${}, which is significantly "
+    "greater than your current credit limit and balance. "
+    "It will likely fail due to insufficient funds. "
+    "Please proceed at your own risk."
+)
+
 __all__ = ["FineTuningResource", "AsyncFineTuningResource"]

@@ -218,11 +234,27 @@ def create(
             hf_output_repo_name=hf_output_repo_name,
         )

+        price_estimation_result = self.estimate_price(
+            training_file=training_file,
+            validation_file=validation_file,
+            model=model,
+            n_epochs=n_epochs,
+            n_evals=n_evals,
+            training_type="lora" if lora else "full",
+            training_method=training_method,
+        )
+
         if verbose:
             rprint(
                 "Submitting a fine-tuning job with the following parameters:",
                 finetune_request,
             )
+            if not price_estimation_result.allowed_to_proceed:
+                rprint(
+                    "[red]"
+                    + _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format(price_estimation_result.estimated_total_price)
+                    + "[/red]",
+                )
         parameter_payload = finetune_request.model_dump(exclude_none=True)

         return self._client.post(
@@ -231,6 +263,37 @@
             cast_to=FinetuneResponseLib,
         )

+    def estimate_price(
+        self,
+        *,
+        training_file: str,
+        validation_file: str | None = None,
+        model: str | None = None,
+        n_epochs: int = 1,
+        n_evals: int | None = 0,
+        training_type: str | None = "lora",
+        training_method: str | None = "sft",
+    ) -> FinetunePriceEstimationResponse:
+        """
+        Estimate the price of a fine-tuning job
+        """
+
+        finetune_price_estimation_request = create_finetune_price_estimation_request(
+            training_file=training_file,
+            validation_file=validation_file,
+            model=model,
+            n_epochs=n_epochs,
+            n_evals=n_evals,
+            training_type=training_type,
+            training_method=training_method,
+        )
+        parameter_payload = finetune_price_estimation_request.model_dump(exclude_none=True)
+        return self._client.post(
+            "/fine-tunes/estimate-price",
+            body=parameter_payload,
+            cast_to=FinetunePriceEstimationResponse,
+        )
+
     def retrieve(
         self,
         id: str,
@@ -659,11 +722,27 @@ async def create(
             hf_output_repo_name=hf_output_repo_name,
         )

+        price_estimation_result = await self.estimate_price(
+            training_file=training_file,
+            validation_file=validation_file,
+            model=model,
+            n_epochs=n_epochs,
+            n_evals=n_evals,
+            training_type="lora" if lora else "full",
+            training_method=training_method,
+        )
+
         if verbose:
             rprint(
                 "Submitting a fine-tuning job with the following parameters:",
                 finetune_request,
             )
+            if not price_estimation_result.allowed_to_proceed:
+                rprint(
+                    "[red]"
+                    + _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format(price_estimation_result.estimated_total_price)
+                    + "[/red]",
+                )
         parameter_payload = finetune_request.model_dump(exclude_none=True)

         return await self._client.post(
@@ -672,6 +751,37 @@
             cast_to=FinetuneResponse,
         )

+    async def estimate_price(
+        self,
+        *,
+        training_file: str,
+        validation_file: str | None = None,
+        model: str | None = None,
+        n_epochs: int = 1,
+        n_evals: int | None = 0,
+        training_type: str | None = "lora",
+        training_method: str | None = "sft",
+    ) -> FinetunePriceEstimationResponse:
+        """
+        Estimate the price of a fine-tuning job
+        """
+
+        finetune_price_estimation_request = create_finetune_price_estimation_request(
+            training_file=training_file,
+            validation_file=validation_file,
+            model=model,
+            n_epochs=n_epochs,
+            n_evals=n_evals,
+            training_type=training_type,
+            training_method=training_method,
+        )
+        parameter_payload = finetune_price_estimation_request.model_dump(exclude_none=True)
+        return await self._client.post(
+            "/fine-tunes/estimate-price",
+            body=parameter_payload,
+            cast_to=FinetunePriceEstimationResponse,
+        )
+
     async def retrieve(
         self,
         id: str,
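Taken together, a hedged end-to-end sketch of the new surface (the client construction, file ID, and model name are hypothetical; the method signature and response fields match the diff above):

```python
from together import Together

client = Together()  # assumes TOGETHER_API_KEY is set in the environment

estimate = client.fine_tuning.estimate_price(
    training_file="file-abc123",  # placeholder for an uploaded training file ID
    model="example-org/example-model",  # placeholder model name
    n_epochs=3,
    training_type="lora",
    training_method="sft",
)

print(f"Estimated total price: ${estimate.estimated_total_price:.2f}")
if not estimate.allowed_to_proceed:
    print("Warning: the estimate exceeds your credit limit and balance.")
```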