Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ Given a few tens or hundreds of representative _inputs_ of your task and a _metr

```python linenums="1"
import dspy
dspy.configure(lm=dspy.LM('gpt-4o-mini-2024-07-18'))
dspy.configure(lm=dspy.LM('openai/gpt-4o-mini-2024-07-18'))

# Define the DSPy module for classification. It will use the hint at training time, if available.
signature = dspy.Signature("text -> label").with_updated_fields('label', type_=Literal[tuple(CLASSES)])
Expand All @@ -394,7 +394,7 @@ Given a few tens or hundreds of representative _inputs_ of your task and a _metr
optimizer = dspy.BootstrapFinetune(metric=(lambda x, y, trace=None: x.label == y.label), num_threads=24)
optimized = optimizer.compile(classify, trainset=trainset)

optimized_classifier(text="What does a pending cash withdrawal mean?")
optimized(text="What does a pending cash withdrawal mean?")
```

**Possible Output (from the last line):**
Expand Down
16 changes: 8 additions & 8 deletions dspy/clients/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,21 +172,21 @@ def finetune(
train_data_format: Optional[Union[TrainDataFormat, str]] = "chat",
train_kwargs: Optional[Dict[str, Any]] = None,
) -> str:
if isinstance(data_format, str):
if data_format == "chat":
data_format = TrainDataFormat.CHAT
elif data_format == "completion":
data_format = TrainDataFormat.COMPLETION
if isinstance(train_data_format, str):
if train_data_format == "chat":
train_data_format = TrainDataFormat.CHAT
elif train_data_format == "completion":
train_data_format = TrainDataFormat.COMPLETION
else:
raise ValueError(
f"String `train_data_format` must be one of 'chat' or 'completion', but received: {data_format}."
f"String `train_data_format` must be one of 'chat' or 'completion', but received: {train_data_format}."
)

if "train_data_path" not in train_kwargs:
raise ValueError("The `train_data_path` must be provided to finetune on Databricks.")
# Add the file name to the directory path.
train_kwargs["train_data_path"] = DatabricksProvider.upload_data(
train_data, train_kwargs["train_data_path"], data_format
train_data, train_kwargs["train_data_path"], train_data_format
)

try:
Expand Down Expand Up @@ -236,7 +236,7 @@ def finetune(
model_to_deploy = train_kwargs.get("register_to")
job.endpoint_name = model_to_deploy.replace(".", "_")
DatabricksProvider.deploy_finetuned_model(
model_to_deploy, data_format, databricks_host, databricks_token, deploy_timeout
model_to_deploy, train_data_format, databricks_host, databricks_token, deploy_timeout
)
job.launch_completed = True
# The finetuned model name should be in the format: "databricks/<endpoint_name>".
Expand Down
8 changes: 6 additions & 2 deletions dspy/clients/lm_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
from typing import Any, Dict, List, Optional
from dspy.clients.provider import TrainingJob, Provider
from dspy.clients.utils_finetune import TrainDataFormat, save_data
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from dspy.clients.lm import LM

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -118,7 +122,7 @@ def get_logs() -> str:


@staticmethod
def kill(lm: 'LM', launch_kwargs: Optional[Dict[str, Any]] = None):
def kill(lm: "LM", launch_kwargs: Optional[Dict[str, Any]] = None):
from sglang.utils import terminate_process
if not hasattr(lm, "process"):
logger.info("No running server to kill.")
Expand Down Expand Up @@ -227,7 +231,7 @@ def train_sft_locally(model_name, train_data, train_kwargs):

hf_dataset = Dataset.from_list(train_data)
def tokenize_function(example):
return encode_sft_example(example, tokenizer, max_seq_length)
return encode_sft_example(example, tokenizer, train_kwargs["max_seq_length"])
tokenized_dataset = hf_dataset.map(tokenize_function, batched=False)
tokenized_dataset.set_format(type="torch")
tokenized_dataset = tokenized_dataset.filter(lambda example: (example["labels"] != -100).any())
Expand Down
23 changes: 13 additions & 10 deletions dspy/clients/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,17 +92,13 @@ def __init__(self):

@staticmethod
def is_provider_model(model: str) -> bool:
# Filter the provider_prefix, if exists
provider_prefix = "openai/"
if model.startswith(provider_prefix):
model = model[len(provider_prefix):]
model = OpenAIProvider._remove_provider_prefix(model)

# Check if the model is a base OpenAI model
# TODO(enhance) The following list can be replaced with
# openai.models.list(), but doing so might require a key. Is there a
# way to get the list of models without a key?
valid_model_names = _OPENAI_MODELS
if model in valid_model_names:
if model in _OPENAI_MODELS:
return True

# Check if the model is a fine-tuned OpenAI model. Fine-tuned OpenAI
Expand All @@ -113,10 +109,15 @@ def is_provider_model(model: str) -> bool:
# model names by making a call to the OpenAI API to be more exact, but
# this might require an API key with the right permissions.
match = re.match(r"ft:([^:]+):", model)
if match and match.group(1) in valid_model_names:
if match and match.group(1) in _OPENAI_MODELS:
return True

return False

@staticmethod
def _remove_provider_prefix(model: str) -> str:
provider_prefix = "openai/"
return model.replace(provider_prefix, "")

@staticmethod
def finetune(
Expand All @@ -126,6 +127,8 @@ def finetune(
train_data_format: Optional[TrainDataFormat],
train_kwargs: Optional[Dict[str, Any]] = None,
) -> str:
model = OpenAIProvider._remove_provider_prefix(model)

print("[OpenAI Provider] Validating the data format")
OpenAIProvider.validate_data_format(train_data_format)

Expand All @@ -138,7 +141,7 @@ def finetune(
job.provider_file_id = provider_file_id

print("[OpenAI Provider] Starting remote training")
provider_job_id = OpenAIProvider.start_remote_training(
provider_job_id = OpenAIProvider._start_remote_training(
train_file_id=job.provider_file_id,
model=model,
train_kwargs=train_kwargs,
Expand Down Expand Up @@ -231,9 +234,9 @@ def upload_data(data_path: str) -> str:
return provider_file.id

@staticmethod
def start_remote_training(
def _start_remote_training(
train_file_id: str,
model: id,
model: str,
train_kwargs: Optional[Dict[str, Any]] = None
) -> str:
train_kwargs = train_kwargs or {}
Expand Down
7 changes: 5 additions & 2 deletions dspy/clients/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
from typing import Any, Dict, List, Optional, Union

from dspy.clients.utils_finetune import TrainDataFormat
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from dspy.clients.lm import LM

class TrainingJob(Future):
def __init__(
Expand Down Expand Up @@ -44,14 +47,14 @@ def is_provider_model(model: str) -> bool:
return False

@staticmethod
def launch(lm: 'LM', launch_kwargs: Optional[Dict[str, Any]] = None):
def launch(lm: "LM", launch_kwargs: Optional[Dict[str, Any]] = None):
# Note that the "launch" and "kill" methods might be called even if there
# is already a launched LM, or no launched LM to kill. These methods should
# be resilient to such cases.
pass

@staticmethod
def kill(lm: 'LM', launch_kwargs: Optional[Dict[str, Any]] = None):
def kill(lm: "LM", launch_kwargs: Optional[Dict[str, Any]] = None):
# We assume that LM.launch_kwargs dictionary will contain the necessary
# information for a provider to launch and/or kill an LM. This is the
# reason why the argument here is named launch_kwargs and not
Expand Down
2 changes: 1 addition & 1 deletion dspy/teleprompt/bootstrap_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def finetune_lms(finetune_dict) -> Dict[Any, LM]:

key_to_job = {}
for key, finetune_kwargs in finetune_dict.items():
lm = finetune_kwargs.pop("lm")
lm: LM = finetune_kwargs.pop("lm")
# TODO: The following line is a hack. We should re-think how to free
# up resources for fine-tuning. This might mean introducing a new
# provider method (e.g. prepare_for_finetune) that can be called
Expand Down
Loading