22 changes: 7 additions & 15 deletions src/together/lib/cli/api/fine_tuning.py
@@ -20,7 +20,7 @@
from together.lib.cli.api.utils import INT_WITH_MAX, BOOL_WITH_AUTO
from together.lib.resources.files import DownloadManager
from together.lib.utils.serializer import datetime_serializer
from together.lib.resources.fine_tuning import get_model_limits, create_fine_tuning_request
from together.lib.resources.fine_tuning import get_model_limits

_CONFIRMATION_MESSAGE = (
"You are about to create a fine-tuning job. "
@@ -287,13 +287,7 @@ def create(
if from_checkpoint is not None:
model_name = from_checkpoint.split(":")[0]

if model_name is None:
raise click.BadParameter("You must specify a model or a checkpoint")

model_limits = get_model_limits(
client,
model=model_name,
)
model_limits = get_model_limits(client, str(model_name))

if lora:
if model_limits.lora_training is None:
@@ -304,9 +298,9 @@
}

for arg in default_values:
arg_source = ctx.get_parameter_source("arg")
arg_source = ctx.get_parameter_source("arg") # type: ignore[attr-defined]
if arg_source == ParameterSource.DEFAULT:
training_args[str(arg)] = default_values[str(arg_source)]
training_args[arg] = default_values[str(arg_source)]

if ctx.get_parameter_source("lora_alpha") == ParameterSource.DEFAULT: # type: ignore[attr-defined]
training_args["lora_alpha"] = training_args["lora_r"] * 2
@@ -330,18 +324,16 @@
raise click.BadParameter("You have specified a number of evaluation loops but no validation file.")

if confirm or click.confirm(_CONFIRMATION_MESSAGE, default=True, show_default=True):
params = create_fine_tuning_request(
model_limits=model_limits,
response = client.fine_tuning.create(
**training_args,
verbose=True,
)
rprint("Submitting a fine-tuning job with the following parameters:", params)
response = client.fine_tuning.create(**params)

report_string = f"Successfully submitted a fine-tuning job {response.id}"
# created_at reports UTC time, we use .astimezone() to convert to local time
formatted_time = response.created_at.astimezone().strftime("%m/%d/%Y, %H:%M:%S")
report_string += f" at {formatted_time}"
click.echo(report_string)
rprint(report_string)
else:
click.echo("No confirmation received, stopping job launch")

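For orientation, here is a minimal sketch of the CLI's call path after this change: the collected training arguments go straight to the SDK's fine_tuning.create (with verbose=True echoing the submitted parameters) instead of being pre-assembled through create_fine_tuning_request. The client setup, model name, and file ID below are illustrative assumptions, not values taken from this PR.

from together import Together

client = Together()  # assumes TOGETHER_API_KEY is set in the environment

# Hypothetical arguments; the real CLI collects these from click options.
training_args = {
    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",
    "training_file": "file-abc123",
    "n_epochs": 1,
}

# After this PR the CLI hands the kwargs to the SDK directly;
# verbose=True asks the SDK to print the parameters it submits.
response = client.fine_tuning.create(**training_args, verbose=True)
print(f"Submitted fine-tuning job {response.id}")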
167 changes: 80 additions & 87 deletions src/together/lib/resources/fine_tuning.py
@@ -1,35 +1,35 @@
from __future__ import annotations

from typing import Any, Dict, Literal
from typing_extensions import Union, TypeAlias
from typing import TYPE_CHECKING, Literal

from rich import print as rprint

from together import Together
from together.types import (
LrSchedulerParam,
FullTrainingTypeParam,
LoRaTrainingTypeParam,
FineTuningCreateParams,
TrainingMethodDpoParam,
TrainingMethodSftParam,
CosineLrSchedulerArgsParam,
LinearLrSchedulerArgsParam,
)
from together.lib.utils import log_warn_once
from together.lib.types.fine_tuning import FinetuneTrainingLimits

TrainingMethod: TypeAlias = Union[TrainingMethodSftParam, TrainingMethodDpoParam]

TrainingType: TypeAlias = Union[FullTrainingTypeParam, LoRaTrainingTypeParam]
if TYPE_CHECKING:
from together import Together, AsyncTogether
from together.lib.types.fine_tuning import (
TrainingType,
FinetuneRequest,
FullTrainingType,
LoRATrainingType,
CosineLRScheduler,
LinearLRScheduler,
TrainingMethodDPO,
TrainingMethodSFT,
FinetuneLRScheduler,
CosineLRSchedulerArgs,
LinearLRSchedulerArgs,
FinetuneTrainingLimits,
)

AVAILABLE_TRAINING_METHODS = {
"sft",
"dpo",
}


def create_fine_tuning_request(
def create_finetune_request(
model_limits: FinetuneTrainingLimits,
training_file: str,
model: str | None = None,
@@ -40,11 +40,11 @@ def create_fine_tuning_request(
batch_size: int | Literal["max"] = "max",
learning_rate: float | None = 0.00001,
lr_scheduler_type: Literal["linear", "cosine"] = "cosine",
min_lr_ratio: float = 0.0,
min_lr_ratio: float | None = 0.0,
scheduler_num_cycles: float = 0.5,
warmup_ratio: float | None = None,
max_grad_norm: float = 1.0,
weight_decay: float = 0.0,
weight_decay: float | None = 0.0,
lora: bool = False,
lora_r: int | None = None,
lora_dropout: float | None = 0,
@@ -66,7 +66,7 @@ def create_fine_tuning_request(
hf_model_revision: str | None = None,
hf_api_token: str | None = None,
hf_output_repo_name: str | None = None,
) -> FineTuningCreateParams:
) -> FinetuneRequest:
if model is not None and from_checkpoint is not None:
raise ValueError("You must specify either a model or a checkpoint to start a job from, not both")

@@ -87,7 +87,7 @@ def create_fine_tuning_request(
if warmup_ratio is None:
warmup_ratio = 0.0

training_type: TrainingType = FullTrainingTypeParam(type="Full")
training_type: TrainingType = FullTrainingType()
if lora:
if model_limits.lora_training is None:
raise ValueError(f"LoRA adapters are not supported for the selected model ({model_or_checkpoint}).")
@@ -98,15 +98,12 @@

lora_r = lora_r if lora_r is not None else model_limits.lora_training.max_rank
lora_alpha = lora_alpha if lora_alpha is not None else lora_r * 2
training_type = LoRaTrainingTypeParam(
type="Lora",
training_type = LoRATrainingType(
lora_r=lora_r,
lora_alpha=int(lora_alpha),
lora_dropout=lora_dropout or 0.0,
lora_trainable_modules=lora_trainable_modules or "all-linear",
)
if lora_dropout is not None:
training_type["lora_dropout"] = lora_dropout
if lora_trainable_modules is not None:
training_type["lora_trainable_modules"] = lora_trainable_modules

max_batch_size = model_limits.lora_training.max_batch_size
min_batch_size = model_limits.lora_training.min_batch_size
@@ -139,13 +136,13 @@ def create_fine_tuning_request(
if warmup_ratio > 1 or warmup_ratio < 0:
raise ValueError(f"Warmup ratio should be between 0 and 1 (got {warmup_ratio})")

if min_lr_ratio > 1 or min_lr_ratio < 0:
if min_lr_ratio is not None and (min_lr_ratio > 1 or min_lr_ratio < 0):
raise ValueError(f"Min learning rate ratio should be between 0 and 1 (got {min_lr_ratio})")

if max_grad_norm < 0:
raise ValueError(f"Max gradient norm should be non-negative (got {max_grad_norm})")

if weight_decay < 0:
if weight_decay is not None and (weight_decay < 0):
raise ValueError(f"Weight decay should be non-negative (got {weight_decay})")

if training_method not in AVAILABLE_TRAINING_METHODS:
@@ -154,10 +151,6 @@ def create_fine_tuning_request(
if train_on_inputs is not None and training_method != "sft":
raise ValueError("train_on_inputs is only supported for SFT training")

if train_on_inputs is None and training_method == "sft":
log_warn_once("train_on_inputs is not set for SFT training, it will be set to 'auto'")
train_on_inputs = "auto"

if dpo_beta is not None and training_method != "dpo":
raise ValueError("dpo_beta is only supported for DPO training")
if dpo_normalize_logratios_by_length and training_method != "dpo":
@@ -174,24 +167,25 @@ def create_fine_tuning_request(
if not simpo_gamma >= 0.0:
raise ValueError(f"simpo_gamma should be non-negative (got {simpo_gamma})")

lr_scheduler: LrSchedulerParam
lr_scheduler: FinetuneLRScheduler
if lr_scheduler_type == "cosine":
if scheduler_num_cycles <= 0.0:
raise ValueError(f"Number of cycles should be greater than 0 (got {scheduler_num_cycles})")

lr_scheduler = LrSchedulerParam(
lr_scheduler_type="cosine",
lr_scheduler_args=CosineLrSchedulerArgsParam(min_lr_ratio=min_lr_ratio, num_cycles=scheduler_num_cycles),
lr_scheduler = CosineLRScheduler(
lr_scheduler_args=CosineLRSchedulerArgs(min_lr_ratio=min_lr_ratio, num_cycles=scheduler_num_cycles),
)
else:
lr_scheduler = LrSchedulerParam(
lr_scheduler_type="linear",
lr_scheduler_args=LinearLrSchedulerArgsParam(min_lr_ratio=min_lr_ratio),
lr_scheduler = LinearLRScheduler(
lr_scheduler_args=LinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
)

training_method_cls: TrainingMethod | None = None
training_method_cls: TrainingMethodSFT | TrainingMethodDPO
if training_method == "sft":
training_method_cls = TrainingMethodSftParam(method="sft", train_on_inputs=train_on_inputs or "auto")
if train_on_inputs is None:
log_warn_once("train_on_inputs is not set for SFT training, it will be set to 'auto'")
train_on_inputs = "auto"
training_method_cls = TrainingMethodSFT(train_on_inputs=train_on_inputs)
elif training_method == "dpo":
if simpo_gamma is not None and simpo_gamma > 0:
dpo_reference_free = True
@@ -204,59 +198,40 @@ def create_fine_tuning_request(
else:
dpo_reference_free = False

training_method_cls = TrainingMethodDpoParam(
method="dpo",
training_method_cls = TrainingMethodDPO(
dpo_beta=dpo_beta,
dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
dpo_reference_free=dpo_reference_free,
rpo_alpha=rpo_alpha,
simpo_gamma=simpo_gamma,
)
if dpo_beta is not None:
training_method_cls["dpo_beta"] = dpo_beta
if rpo_alpha is not None:
training_method_cls["rpo_alpha"] = rpo_alpha
if simpo_gamma is not None:
training_method_cls["simpo_gamma"] = simpo_gamma

finetune_request = FineTuningCreateParams(
model=model or "",

finetune_request = FinetuneRequest(
model=model,
training_file=training_file,
validation_file=validation_file,
n_epochs=n_epochs,
n_evals=n_evals,
n_checkpoints=n_checkpoints,
batch_size=batch_size,
learning_rate=learning_rate or 0.00001,
lr_scheduler=lr_scheduler,
warmup_ratio=warmup_ratio,
max_grad_norm=max_grad_norm,
weight_decay=weight_decay,
weight_decay=weight_decay or 0.0,
training_type=training_type,
suffix=suffix,
wandb_key=wandb_api_key,
wandb_base_url=wandb_base_url,
wandb_project_name=wandb_project_name,
wandb_name=wandb_name,
training_method=training_method_cls, # pyright: ignore[reportPossiblyUnboundVariable]
from_checkpoint=from_checkpoint,
from_hf_model=from_hf_model,
hf_model_revision=hf_model_revision,
hf_api_token=hf_api_token,
hf_output_repo_name=hf_output_repo_name,
)
if validation_file is not None:
finetune_request["validation_file"] = validation_file
if n_evals is not None:
finetune_request["n_evals"] = n_evals
if n_checkpoints is not None:
finetune_request["n_checkpoints"] = n_checkpoints
if learning_rate is not None:
finetune_request["learning_rate"] = learning_rate
if suffix is not None:
finetune_request["suffix"] = suffix
if wandb_api_key is not None:
finetune_request["wandb_api_key"] = wandb_api_key
if wandb_base_url is not None:
finetune_request["wandb_base_url"] = wandb_base_url
if wandb_project_name is not None:
finetune_request["wandb_project_name"] = wandb_project_name
if wandb_name is not None:
finetune_request["wandb_name"] = wandb_name
if training_method_cls is not None:
finetune_request["training_method"] = training_method_cls
if from_checkpoint is not None:
finetune_request["from_checkpoint"] = from_checkpoint
if from_hf_model is not None:
finetune_request["from_hf_model"] = from_hf_model
if hf_model_revision is not None:
finetune_request["hf_model_revision"] = hf_model_revision
if hf_api_token is not None:
finetune_request["hf_api_token"] = hf_api_token
if hf_output_repo_name is not None:
finetune_request["hf_output_repo_name"] = hf_output_repo_name

return finetune_request

Expand All @@ -283,5 +258,23 @@ def get_model_limits(client: Together, model: str) -> FinetuneTrainingLimits:
return response


def not_none_kwargs(**kwargs: Any) -> Dict[str, Any]:
return {k: v for k, v in kwargs.items() if v is not None}
async def async_get_model_limits(client: AsyncTogether, model: str) -> FinetuneTrainingLimits:
"""
Requests training limits for a specific model

Args:
model_name (str): Name of the model to get limits for

Returns:
FinetuneTrainingLimits: Object containing training limits for the model
"""

response = await client.get(
"/fine-tunes/models/limits",
cast_to=FinetuneTrainingLimits,
options={
"params": {"model_name": model},
},
)

return response
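A minimal usage sketch for the new async helper, assuming the module path mirrors the file location above and that an AsyncTogether client picks up TOGETHER_API_KEY from the environment; the model ID is a hypothetical placeholder.

import asyncio

from together import AsyncTogether
from together.lib.resources.fine_tuning import async_get_model_limits


async def main() -> None:
    client = AsyncTogether()  # reads TOGETHER_API_KEY from the environment
    # Hypothetical model ID; substitute any fine-tunable model.
    limits = await async_get_model_limits(client, "meta-llama/Meta-Llama-3.1-8B-Instruct-Reference")
    if limits.lora_training is not None:
        print("Max LoRA rank:", limits.lora_training.max_rank)


asyncio.run(main())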