From 3c0554ca43855d8eaeb8d712def545da07ee4bce Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Wed, 6 Nov 2024 16:22:20 -0800 Subject: [PATCH 1/2] Databricks finetuning small fix some fixes --- dspy/clients/databricks.py | 330 +++++++++++++++++++++++++++++++ dspy/clients/lm.py | 50 ++--- dspy/clients/openai.py | 2 +- dspy/clients/provider.py | 22 +-- tests/clients/test_databricks.py | 88 +++++++++ 5 files changed, 451 insertions(+), 41 deletions(-) create mode 100644 dspy/clients/databricks.py create mode 100644 tests/clients/test_databricks.py diff --git a/dspy/clients/databricks.py b/dspy/clients/databricks.py new file mode 100644 index 0000000000..ff21a12e55 --- /dev/null +++ b/dspy/clients/databricks.py @@ -0,0 +1,330 @@ +import logging +import os +import re +import time +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +import requests +import ujson + +from dspy.clients.provider import Provider, TrainingJob +from dspy.clients.utils_finetune import DataFormat, get_finetune_directory + +if TYPE_CHECKING: + from databricks.sdk import WorkspaceClient + +logger = logging.getLogger(__name__) + + +class TrainingJobDatabricks(TrainingJob): + def __init__(self, finetuning_run=None, *args, **kwargs): + super().__init__(*args, **kwargs) + self.finetuning_run = finetuning_run + self.launch_started = False + self.launch_completed = False + self.endpoint_name = None + + def status(self): + if not self.finetuning_run: + return None + try: + from databricks.model_training import foundation_model as fm + except ImportError: + raise ImportError( + "To use Databricks finetuning, please install the databricks_genai package via " + "`pip install databricks_genai`." + ) + run = fm.get(self.finetuning_run) + return run.status + + +class DatabricksProvider(Provider): + finetunable = True + TrainingJob = TrainingJobDatabricks + + @staticmethod + def is_provider_model(model: str) -> bool: + # We don't automatically infer Databricks models because Databricks is not a proprietary model provider. + return False + + @staticmethod + def kill(model: str, launch_kwargs: Optional[Dict[str, Any]] = None): + pass + + @staticmethod + def deploy_finetuned_model( + model: str, + data_format: Optional[DataFormat] = None, + databricks_host: Optional[str] = None, + databricks_token: Optional[str] = None, + ): + workspace_client = _get_workspace_client() + model_version = next(workspace_client.model_versions.list(model)).version + + # Allow users to override the host and token. This is useful on Databricks hosted runtime. + databricks_host = databricks_host or workspace_client.config.host + databricks_token = databricks_token or workspace_client.config.token + + headers = {"Context-Type": "text/json", "Authorization": f"Bearer {databricks_token}"} + + optimizable_info = requests.get( + url=f"{databricks_host}/api/2.0/serving-endpoints/get-model-optimization-info/{model}/{model_version}", + headers=headers, + ).json() + + if "optimizable" not in optimizable_info or not optimizable_info["optimizable"]: + raise ValueError(f"Model is not eligible for provisioned throughput: {optimizable_info}") + + chunk_size = optimizable_info["throughput_chunk_size"] + + # Minimum desired provisioned throughput + min_provisioned_throughput = 0 + + # Maximum desired provisioned throughput + max_provisioned_throughput = chunk_size + + # Databricks serving endpoint names cannot contain ".". + model_name = model.replace(".", "_") + + get_endpoint_response = requests.get( + url=f"{databricks_host}/api/2.0/serving-endpoints/{model_name}", json={"name": model_name}, headers=headers + ) + if get_endpoint_response.status_code == 200: + # The serving endpoint already exists, we will update it instead of creating a new one. + data = { + "served_entities": [ + { + "entity_name": model_name, + "entity_version": model_version, + "min_provisioned_throughput": min_provisioned_throughput, + "max_provisioned_throughput": max_provisioned_throughput, + } + ] + } + + response = requests.put( + url=f"{databricks_host}/api/2.0/serving-endpoints/{model_name}/config", + json=data, + headers=headers, + ) + else: + # Send the POST request to create the serving endpoint + data = { + "name": model_name, + "config": { + "served_entities": [ + { + "entity_name": model, + "entity_version": model_version, + "min_provisioned_throughput": min_provisioned_throughput, + "max_provisioned_throughput": max_provisioned_throughput, + } + ] + }, + } + + response = requests.post(url=f"{databricks_host}/api/2.0/serving-endpoints", json=data, headers=headers) + + logger.info(response) + if response.status_code != 200: + raise ValueError(f"Failed to create serving endpoint: {response.json()}.") + + from openai import OpenAI + + client = OpenAI( + api_key=databricks_token, + base_url=f"{databricks_host}/serving-endpoints", + ) + # Wait for the deployment to be ready. + while True: + try: + if data_format == DataFormat.chat: + client.chat.completions.create( + messages=[{"role": "user", "content": "hi"}], model=model_name, max_tokens=1 + ) + elif data_format == DataFormat.completion: + client.completions.create(prompt="hi", model=model_name, max_tokens=1) + return + except Exception: + time.sleep(60) + + @staticmethod + def finetune( + job: TrainingJobDatabricks, + model: str, + train_data: List[Dict[str, Any]], + train_kwargs: Optional[Dict[str, Any]] = None, + data_format: Optional[Union[DataFormat, str]] = None, + ) -> str: + if isinstance(data_format, str): + if data_format == "chat": + data_format = DataFormat.chat + elif data_format == "completion": + data_format = DataFormat.completion + else: + raise ValueError( + f"String `data_format` must be one of 'chat' or 'completion', but received: {data_format}." + ) + + if "train_data_path" not in train_kwargs: + raise ValueError("The `train_data_path` must be provided to finetune on Databricks.") + # Add the file name to the directory path. + train_kwargs["train_data_path"] = DatabricksProvider.upload_data( + train_data, train_kwargs["train_data_path"], data_format + ) + + try: + from databricks.model_training import foundation_model as fm + except ImportError: + raise ImportError( + "To use Databricks finetuning, please install the databricks_genai package via " + "`pip install databricks_genai`." + ) + + if "register_to" not in train_kwargs: + raise ValueError("The `register_to` must be provided to finetune on Databricks.") + + # Allow users to override the host and token. This is useful on Databricks hosted runtime. + databricks_host = train_kwargs.pop("databricks_host", None) + databricks_token = train_kwargs.pop("databricks_token", None) + + skip_deploy = train_kwargs.pop("skip_deploy", False) + + finetuning_run = fm.create( + model=model, + **train_kwargs, + ) + + job.run = finetuning_run + + # Wait for the finetuning run to be ready. + while True: + job.run = fm.get(job.run) + if job.run.status.display_name == "Completed": + break + elif job.run.status.display_name == "Failed": + raise ValueError( + f"Finetuning run failed with status: {job.run.status.display_name}. Please check the Databricks " + f"workspace for more details. Finetuning job's metadata: {job.run}。" + ) + else: + time.sleep(60) + if skip_deploy: + return None + + job.launch_started = True + model_to_deploy = train_kwargs.get("register_to") + job.endpoint_name = model_to_deploy.replace(".", "_") + DatabricksProvider.deploy_finetuned_model(model_to_deploy, data_format, databricks_host, databricks_token) + job.launch_completed = True + # The finetuned model name should be in the format: "databricks/". + return f"databricks/{job.endpoint_name}" + + @staticmethod + def upload_data(train_data: List[Dict[str, Any]], databricks_unity_catalog_path: str, data_format: DataFormat): + file_path = _save_data_to_local_file(train_data, data_format) + + w = _get_workspace_client() + _create_directory_in_databricks_unity_catalog(w, databricks_unity_catalog_path) + + with open(file_path, "rb") as f: + target_path = os.path.join(databricks_unity_catalog_path, os.path.basename(file_path)) + w.files.upload(target_path, f, overwrite=True) + return target_path + + +def _get_workspace_client() -> "WorkspaceClient": + try: + from databricks.sdk import WorkspaceClient + except ImportError: + raise ImportError( + "To use Databricks finetuning, please install the databricks-sdk package via " + "`pip install databricks-sdk`." + ) + return WorkspaceClient() + + +def _create_directory_in_databricks_unity_catalog(w: "WorkspaceClient", databricks_unity_catalog_path: str): + pattern = r"^/Volumes/(?P[^/]+)/(?P[^/]+)/(?P[^/]+)(/[^/]+)+$" + match = re.match(pattern, databricks_unity_catalog_path) + if not match: + raise ValueError( + f"Databricks Unity Catalog path must be in the format '/Volumes////...', but " + f"received: {databricks_unity_catalog_path}." + ) + + catalog = match.group("catalog") + schema = match.group("schema") + volume = match.group("volume") + + try: + volume_path = f"{catalog}.{schema}.{volume}" + w.volumes.read(volume_path) + except Exception: + raise ValueError( + f"Databricks Unity Catalog volume does not exist: {volume_path}, please create it on the Databricks " + "workspace." + ) + + try: + w.files.get_directory_metadata(databricks_unity_catalog_path) + logger.info(f"Directory {databricks_unity_catalog_path} already exists, skip creating it.") + except Exception: + # Create the directory if it doesn't exist, we don't raise an error because this is a common case. + logger.info(f"Creating directory {databricks_unity_catalog_path} in Databricks Unity Catalog...") + w.files.create_directory(databricks_unity_catalog_path) + logger.info(f"Successfully created directory {databricks_unity_catalog_path} in Databricks Unity Catalog!") + + +def _save_data_to_local_file(train_data: List[Dict[str, Any]], data_format: DataFormat): + import uuid + + file_name = f"finetuning_{uuid.uuid4()}.jsonl" + + finetune_dir = get_finetune_directory() + file_path = os.path.join(finetune_dir, file_name) + file_path = os.path.abspath(file_path) + with open(file_path, "w") as f: + for item in train_data: + if data_format == DataFormat.chat: + _validate_chat_data(item) + elif data_format == DataFormat.completion: + _validate_completion_data(item) + + f.write(ujson.dumps(item) + "\n") + return file_path + + +def _validate_chat_data(data: Dict[str, Any]): + if "messages" not in data: + raise ValueError( + "Each finetuning data must be a dict with a 'messages' key when `task=CHAT_COMPLETION`, but " + f"received: {data}" + ) + + if not isinstance(data["messages"], list): + raise ValueError( + "The value of the 'messages' key in each finetuning data must be a list of dicts with keys 'role' and " + f"'content' when `task=CHAT_COMPLETION`, but received: {data['messages']}" + ) + + for message in data["messages"]: + if "role" not in message: + raise ValueError(f"Each message in the 'messages' list must contain a 'role' key, but received: {message}.") + if "content" not in message: + raise ValueError( + f"Each message in the 'messages' list must contain a 'content' key, but received: {message}." + ) + + +def _validate_completion_data(data: Dict[str, Any]): + if "prompt" not in data: + raise ValueError( + "Each finetuning data must be a dict with a 'prompt' key when `task=INSTRUCTION_FINETUNE`, but " + f"received: {data}" + ) + if "response" not in data and "completion" not in data: + raise ValueError( + "Each finetuning data must be a dict with a 'response' or 'completion' key when " + f"`task=INSTRUCTION_FINETUNE`, but received: {data}" + ) diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index b73d272ff8..1b08fcb17e 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -1,26 +1,22 @@ import functools -from .base_lm import BaseLM import logging import os +import threading import uuid from datetime import datetime -import threading from typing import Any, Dict, List, Literal, Optional -import dspy + import litellm import ujson +import dspy from dspy.adapters.base import Adapter -from dspy.clients.provider import Provider, TrainingJob from dspy.clients.openai import OpenAIProvider -from dspy.clients.utils_finetune import ( - DataFormat, - validate_data_format, - infer_data_format -) - +from dspy.clients.provider import Provider, TrainingJob +from dspy.clients.utils_finetune import DataFormat, infer_data_format, validate_data_format from dspy.utils.callback import BaseCallback, with_callbacks +from .base_lm import BaseLM logger = logging.getLogger(__name__) @@ -63,7 +59,6 @@ def __init__( self.model = model self.model_type = model_type self.cache = cache - self.launch_kwargs = launch_kwargs or {} self.provider = provider or self.infer_provider() self.callbacks = callbacks or [] self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs) @@ -117,18 +112,18 @@ def __call__(self, prompt=None, messages=None, **kwargs): return outputs - def launch(self): - self.provider.launch(self.model, self.launch_kwargs) + def launch(self, launch_kwargs: Optional[Dict[str, Any]] = None): + self.provider.launch(self.model, **launch_kwargs) - def kill(self): - self.provider.kill(self.model, self.launch_kwargs) + def kill(self, kill_kwargs: Optional[Dict[str, Any]] = None): + self.provider.kill(self.model, **kill_kwargs) def finetune( - self, - train_data: List[Dict[str, Any]], - train_kwargs: Optional[Dict[str, Any]]=None, - data_format: Optional[DataFormat] = None, - ) -> TrainingJob: + self, + train_data: List[Dict[str, Any]], + train_kwargs: Optional[Dict[str, Any]] = None, + data_format: Optional[DataFormat] = None, + ) -> TrainingJob: from dspy import settings as settings err = "Fine-tuning is an experimental feature." @@ -153,11 +148,7 @@ def thread_function_wrapper(): thread = threading.Thread(target=thread_function_wrapper) job = self.provider.TrainingJob( - thread=thread, - model=self.model, - train_data=train_data, - train_kwargs=train_kwargs, - data_format=data_format + thread=thread, model=self.model, train_data=train_data, train_kwargs=train_kwargs, data_format=data_format ) thread.start() @@ -172,23 +163,24 @@ def _run_finetune_job(self, job: TrainingJob): model=job.model, train_data=job.train_data, train_kwargs=job.train_kwargs, - data_format=job.data_format + data_format=job.data_format, ) lm = self.copy(model=model) job.set_result(lm) except Exception as err: logger.error(err) job.set_result(err) - + def infer_provider(self) -> Provider: if OpenAIProvider.is_provider_model(self.model): return OpenAIProvider() # TODO(PR): Keeping this function here will require us to import all # providers in this file. Is this okay? return Provider() - + def infer_adapter(self) -> Adapter: import dspy + if dspy.settings.adapter: return dspy.settings.adapter @@ -197,7 +189,7 @@ def infer_adapter(self) -> Adapter: } model_type = self.model_type return model_type_to_adapter[model_type] - + def copy(self, **kwargs): """Returns a copy of the language model with possibly updated parameters.""" diff --git a/dspy/clients/openai.py b/dspy/clients/openai.py index 77a540a127..4e16bc25c3 100644 --- a/dspy/clients/openai.py +++ b/dspy/clients/openai.py @@ -83,7 +83,7 @@ def status(self) -> TrainingStatus: class OpenAIProvider(Provider): - + def __init__(self): super().__init__() self.finetunable = True diff --git a/dspy/clients/provider.py b/dspy/clients/provider.py index 9eb02be1fb..0f9f4e5172 100644 --- a/dspy/clients/provider.py +++ b/dspy/clients/provider.py @@ -1,7 +1,7 @@ -from concurrent.futures import Future from abc import abstractmethod +from concurrent.futures import Future from threading import Thread -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from dspy.clients.utils_finetune import DataFormat @@ -9,9 +9,9 @@ class TrainingJob(Future): def __init__( self, - thread: Thread, - model: str, - train_data: List[Dict[str, Any]], + thread: Optional[Thread] = None, + model: Optional[str] = None, + train_data: Optional[List[Dict[str, Any]]] = None, train_kwargs: Optional[Dict[str, Any]] = None, data_format: Optional[DataFormat] = None, ): @@ -33,7 +33,6 @@ def status(self): class Provider: - def __init__(self): self.finetunable = False self.TrainingJob = TrainingJob @@ -45,23 +44,24 @@ def is_provider_model(model: str) -> bool: return False @staticmethod - def launch(model: str, launch_kwargs: Optional[Dict[str, Any]]=None): + def launch(model: str, launch_kwargs: Optional[Dict[str, Any]] = None): msg = f"`launch()` is called for the auto-launched model `{model}`" msg += " -- no action is taken!" print(msg) - + @staticmethod - def kill(model: str, launch_kwargs: Optional[Dict[str, Any]]=None): + def kill(model: str, launch_kwargs: Optional[Dict[str, Any]] = None): msg = f"`kill()` is called for the auto-launched model `{model}`" msg += " -- no action is taken!" print(msg) - + @staticmethod def finetune( job: TrainingJob, model: str, train_data: List[Dict[str, Any]], train_kwargs: Optional[Dict[str, Any]] = None, - data_format: Optional[DataFormat] = None, + data_format: Optional[Union[DataFormat, str]] = None, + **kwargs, ) -> str: raise NotImplementedError diff --git a/tests/clients/test_databricks.py b/tests/clients/test_databricks.py new file mode 100644 index 0000000000..6f08191ce7 --- /dev/null +++ b/tests/clients/test_databricks.py @@ -0,0 +1,88 @@ +from dspy.clients.databricks import ( + DatabricksProvider, + _create_directory_in_databricks_unity_catalog, + TrainingJobDatabricks, +) + +import pytest +import dspy + +try: + from databricks.sdk import WorkspaceClient + + WorkspaceClient() +except (ImportError, Exception): + # Skip the test if the Databricks SDK is not configured or credentials are not available. + pytestmark = pytest.mark.skip(reason="Databricks SDK not configured or credentials not available") + + +def test_create_directory_in_databricks_unity_catalog(): + from databricks.sdk import WorkspaceClient + + w = WorkspaceClient() + + with pytest.raises( + ValueError, + match=( + "Databricks Unity Catalog path must be in the format '/Volumes////...', " + "but received: /badstring/whatever" + ), + ): + _create_directory_in_databricks_unity_catalog(w, "/badstring/whatever") + + _create_directory_in_databricks_unity_catalog(w, "/Volumes/main/chenmoney/testing/dspy_testing") + # Check that the directory was created successfully, otherwise `get_directory_metadata` will raise an exception. + w.files.get_directory_metadata("/Volumes/main/chenmoney/testing/dspy_testing") + + +def test_create_finetuning_job(): + fake_training_data = [ + { + "messages": [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing great, thank you!"}, + ] + }, + { + "messages": [ + {"role": "user", "content": "What is the capital of France?"}, + {"role": "assistant", "content": "Paris!"}, + ] + }, + { + "messages": [ + {"role": "user", "content": "What is the capital of Germany?"}, + {"role": "assistant", "content": "Berlin!"}, + ] + }, + ] + dspy.settings.experimental = True + + job = TrainingJobDatabricks() + + finetuned_model = DatabricksProvider.finetune( + job=job, + model="meta-llama/Llama-3.2-1B", + train_data=fake_training_data, + data_format="chat", + train_kwargs={ + "train_data_path": "/Volumes/main/chenmoney/testing/dspy_testing", + "register_to": "main.chenmoney.finetuned_model", + "task_type": "CHAT_COMPLETION", + "skip_deploy": True, + }, + ) + assert job.finetuning_run.status.display_name is not None + + +def test_deploy_finetuned_model(): + dspy.settings.experimental = True + model_to_deploy = "main.chenmoney.finetuned_model" + + DatabricksProvider.deploy_finetuned_model( + model=model_to_deploy, + data_format="chat", + ) + + lm = dspy.LM(model="databricks/main_chenmoney_finetuned_model") + lm("what is 2 + 2?") From 882e9ec9722138522a3352cc7490047b3df8712f Mon Sep 17 00:00:00 2001 From: chenmoneygithub Date: Thu, 7 Nov 2024 16:34:35 -0800 Subject: [PATCH 2/2] add examples --- dspy/clients/databricks.py | 59 +++++++++++++----- dspy/clients/lm.py | 9 ++- dspy/clients/openai.py | 2 +- dspy/clients/provider.py | 3 +- dspy/teleprompt/bootstrap_finetune.py | 5 +- examples/finetune/databricks_finetuning.py | 70 ++++++++++++++++++++++ tests/clients/test_databricks.py | 6 ++ 7 files changed, 133 insertions(+), 21 deletions(-) create mode 100644 examples/finetune/databricks_finetuning.py diff --git a/dspy/clients/databricks.py b/dspy/clients/databricks.py index ff21a12e55..a56e91ac13 100644 --- a/dspy/clients/databricks.py +++ b/dspy/clients/databricks.py @@ -47,16 +47,13 @@ def is_provider_model(model: str) -> bool: # We don't automatically infer Databricks models because Databricks is not a proprietary model provider. return False - @staticmethod - def kill(model: str, launch_kwargs: Optional[Dict[str, Any]] = None): - pass - @staticmethod def deploy_finetuned_model( model: str, data_format: Optional[DataFormat] = None, databricks_host: Optional[str] = None, databricks_token: Optional[str] = None, + deploy_timeout: int = 900, ): workspace_client = _get_workspace_client() model_version = next(workspace_client.model_versions.list(model)).version @@ -89,12 +86,17 @@ def deploy_finetuned_model( get_endpoint_response = requests.get( url=f"{databricks_host}/api/2.0/serving-endpoints/{model_name}", json={"name": model_name}, headers=headers ) + if get_endpoint_response.status_code == 200: + logger.info( + f"Serving endpoint {model_name} already exists, updating it instead of creating a new one." + ) # The serving endpoint already exists, we will update it instead of creating a new one. data = { "served_entities": [ { - "entity_name": model_name, + "name": model_name, + "entity_name": model, "entity_version": model_version, "min_provisioned_throughput": min_provisioned_throughput, "max_provisioned_throughput": max_provisioned_throughput, @@ -108,12 +110,16 @@ def deploy_finetuned_model( headers=headers, ) else: + logger.info( + f"Creating serving endpoint {model_name} on Databricks model serving!" + ) # Send the POST request to create the serving endpoint data = { "name": model_name, "config": { "served_entities": [ { + "name": model_name, "entity_name": model, "entity_version": model_version, "min_provisioned_throughput": min_provisioned_throughput, @@ -125,10 +131,17 @@ def deploy_finetuned_model( response = requests.post(url=f"{databricks_host}/api/2.0/serving-endpoints", json=data, headers=headers) - logger.info(response) - if response.status_code != 200: + if response.status_code == 200: + logger.info( + f"Successfully started creating/updating serving endpoint {model_name} on Databricks model serving!" + ) + else: raise ValueError(f"Failed to create serving endpoint: {response.json()}.") + logger.info( + f"Waiting for serving endpoint {model_name} to be ready, this might take a few minutes... You can check " + f"the status of the endpoint at {databricks_host}/ml/endpoints/{model_name}" + ) from openai import OpenAI client = OpenAI( @@ -136,7 +149,8 @@ def deploy_finetuned_model( base_url=f"{databricks_host}/serving-endpoints", ) # Wait for the deployment to be ready. - while True: + num_retries = deploy_timeout // 60 + for _ in range(num_retries): try: if data_format == DataFormat.chat: client.chat.completions.create( @@ -144,10 +158,16 @@ def deploy_finetuned_model( ) elif data_format == DataFormat.completion: client.completions.create(prompt="hi", model=model_name, max_tokens=1) + logger.info(f"Databricks model serving endpoint {model_name} is ready!") return except Exception: time.sleep(60) + raise ValueError( + f"Failed to create serving endpoint {model_name} on Databricks model serving platform within " + f"{deploy_timeout} seconds." + ) + @staticmethod def finetune( job: TrainingJobDatabricks, @@ -189,7 +209,9 @@ def finetune( databricks_token = train_kwargs.pop("databricks_token", None) skip_deploy = train_kwargs.pop("skip_deploy", False) + deploy_timeout = train_kwargs.pop("deploy_timeout", 900) + logger.info("Starting finetuning on Databricks... this might take a few minutes to finish.") finetuning_run = fm.create( model=model, **train_kwargs, @@ -201,36 +223,45 @@ def finetune( while True: job.run = fm.get(job.run) if job.run.status.display_name == "Completed": + logger.info("Finetuning run completed successfully!") break elif job.run.status.display_name == "Failed": raise ValueError( f"Finetuning run failed with status: {job.run.status.display_name}. Please check the Databricks " - f"workspace for more details. Finetuning job's metadata: {job.run}。" + f"workspace for more details. Finetuning job's metadata: {job.run}." ) else: time.sleep(60) + if skip_deploy: return None job.launch_started = True model_to_deploy = train_kwargs.get("register_to") job.endpoint_name = model_to_deploy.replace(".", "_") - DatabricksProvider.deploy_finetuned_model(model_to_deploy, data_format, databricks_host, databricks_token) + DatabricksProvider.deploy_finetuned_model( + model_to_deploy, data_format, databricks_host, databricks_token, deploy_timeout + ) job.launch_completed = True # The finetuned model name should be in the format: "databricks/". return f"databricks/{job.endpoint_name}" @staticmethod def upload_data(train_data: List[Dict[str, Any]], databricks_unity_catalog_path: str, data_format: DataFormat): + logger.info("Uploading finetuning data to Databricks Unity Catalog...") file_path = _save_data_to_local_file(train_data, data_format) w = _get_workspace_client() _create_directory_in_databricks_unity_catalog(w, databricks_unity_catalog_path) - with open(file_path, "rb") as f: - target_path = os.path.join(databricks_unity_catalog_path, os.path.basename(file_path)) - w.files.upload(target_path, f, overwrite=True) - return target_path + try: + with open(file_path, "rb") as f: + target_path = os.path.join(databricks_unity_catalog_path, os.path.basename(file_path)) + w.files.upload(target_path, f, overwrite=True) + logger.info("Successfully uploaded finetuning data to Databricks Unity Catalog!") + return target_path + except Exception as e: + raise ValueError(f"Failed to upload finetuning data to Databricks Unity Catalog: {e}") def _get_workspace_client() -> "WorkspaceClient": diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index 1b08fcb17e..ac891841c3 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -33,10 +33,10 @@ def __init__( temperature: float = 0.0, max_tokens: int = 1000, cache: bool = True, - launch_kwargs: Optional[Dict[str, Any]] = None, callbacks: Optional[List[BaseCallback]] = None, num_retries: int = 3, provider=None, + finetuning_model: Optional[str] = None, **kwargs, ): """ @@ -54,6 +54,9 @@ def __init__( num_retries: The number of times to retry a request if it fails transiently due to network error, rate limiting, etc. Requests are retried with exponential backoff. + provider: The provider to use. If not specified, the provider will be inferred from the model. + finetuning_model: The model to finetune. In some providers, the models available for finetuning is different + from the models available for inference. """ # Remember to update LM.copy() if you modify the constructor! self.model = model @@ -65,6 +68,7 @@ def __init__( self.history = [] self.callbacks = callbacks or [] self.num_retries = num_retries + self.finetuning_model = finetuning_model #turned off by default to avoid LiteLLM logging during every LM call litellm.suppress_debug_info = dspy.settings.suppress_debug_info @@ -147,8 +151,9 @@ def thread_function_wrapper(): return self._run_finetune_job(job) thread = threading.Thread(target=thread_function_wrapper) + model_to_finetune = self.finetuning_model or self.model job = self.provider.TrainingJob( - thread=thread, model=self.model, train_data=train_data, train_kwargs=train_kwargs, data_format=data_format + thread=thread, model=model_to_finetune, train_data=train_data, train_kwargs=train_kwargs, data_format=data_format ) thread.start() diff --git a/dspy/clients/openai.py b/dspy/clients/openai.py index 4e16bc25c3..77a540a127 100644 --- a/dspy/clients/openai.py +++ b/dspy/clients/openai.py @@ -83,7 +83,7 @@ def status(self) -> TrainingStatus: class OpenAIProvider(Provider): - + def __init__(self): super().__init__() self.finetunable = True diff --git a/dspy/clients/provider.py b/dspy/clients/provider.py index 0f9f4e5172..cc4e7147b6 100644 --- a/dspy/clients/provider.py +++ b/dspy/clients/provider.py @@ -50,7 +50,7 @@ def launch(model: str, launch_kwargs: Optional[Dict[str, Any]] = None): print(msg) @staticmethod - def kill(model: str, launch_kwargs: Optional[Dict[str, Any]] = None): + def kill(model: str, kill_kwargs: Optional[Dict[str, Any]] = None): msg = f"`kill()` is called for the auto-launched model `{model}`" msg += " -- no action is taken!" print(msg) @@ -62,6 +62,5 @@ def finetune( train_data: List[Dict[str, Any]], train_kwargs: Optional[Dict[str, Any]] = None, data_format: Optional[Union[DataFormat, str]] = None, - **kwargs, ) -> str: raise NotImplementedError diff --git a/dspy/teleprompt/bootstrap_finetune.py b/dspy/teleprompt/bootstrap_finetune.py index 418c8f0a4b..32f0ca58f2 100644 --- a/dspy/teleprompt/bootstrap_finetune.py +++ b/dspy/teleprompt/bootstrap_finetune.py @@ -6,8 +6,8 @@ from dspy.adapters.base import Adapter from dspy.clients.utils_finetune import infer_data_format from dspy.evaluate.evaluate import Evaluate -from dspy.primitives.example import Example from dspy.predict.predict import Predict +from dspy.primitives.example import Example from dspy.primitives.program import Program from dspy.teleprompt.teleprompt import Teleprompter @@ -235,7 +235,8 @@ def set_missing_predictor_lms(program: Program) -> Program: def prepare_student(student: Program) -> Program: print("Ensuring that the student is not compiled") - assert not student._compiled, "The student program should not be compiled" + if getattr(student, "_compiled", False): + raise ValueError("The student program should not be compiled.") # TODO: Should we use reset_copy here? How would it affect the student # program's predictor LMs, if they are set? diff --git a/examples/finetune/databricks_finetuning.py b/examples/finetune/databricks_finetuning.py new file mode 100644 index 0000000000..372c0121ec --- /dev/null +++ b/examples/finetune/databricks_finetuning.py @@ -0,0 +1,70 @@ +from typing import Literal + +from datasets import load_dataset + +import dspy +from dspy.clients.databricks import DatabricksProvider + +# Define the range as a tuple of valid integers +CLASSES = tuple(range(77)) + +ds = load_dataset("PolyAI/banking77") +trainset_hf = ds["train"][:100] +trainset = [] + +for text, label in zip(trainset_hf["text"], trainset_hf["label"]): + # Each example should have two fields, `inputs` and `answer`, with `inputs` as the input field, + # and `answer` as the output field. + trainset.append(dspy.Example(text=text, answer=label).with_inputs("text")) + +gold = {text: label for text, label in zip(trainset_hf["text"], trainset_hf["label"])} + +lm = dspy.LM( + model="databricks/databricks-meta-llama-3-1-70b-instruct", + provider=DatabricksProvider, + finetuning_model="meta-llama/Llama-3.2-3B", +) + +dspy.settings.configure(lm=lm) +dspy.settings.experimental = True + + +def accuracy(example, pred, trace=None): + return int(example.answer == int(pred.answer)) + + +class Classify(dspy.Signature): + """As a part of a banking issue traiging system, classify the intent of a natural language query.""" + + text = dspy.InputField() + answer: Literal[CLASSES] = dspy.OutputField() + + +class Program(dspy.Module): + def __init__(self, oracle=False): + self.oracle = oracle + self.classify = dspy.ChainOfThoughtWithHint(Classify) + + def forward(self, text): + if self.oracle and text in gold: + hint = f"the right label is {gold[text]}" + else: + hint = None + return self.classify(text=text, hint=hint) + + +model = Program(oracle=True) +print("Try the original model: ", model("I am still waiting on my card?")) + +train_kwargs = { + "train_data_path": "/Volumes/main/chenmoney/testing/dspy_testing/classification", + "register_to": "main.chenmoney.finetuned_model_classification", + "task_type": "CHAT_COMPLETION", +} + +optimized = dspy.BootstrapFinetune(metric=accuracy, num_threads=10, train_kwargs=train_kwargs).compile( + student=model, trainset=trainset +) +optimized.oracle = False + +print("Try the optimized model: ", optimized("I am still waiting on my card?")) diff --git a/tests/clients/test_databricks.py b/tests/clients/test_databricks.py index 6f08191ce7..eb209e9f68 100644 --- a/tests/clients/test_databricks.py +++ b/tests/clients/test_databricks.py @@ -1,3 +1,9 @@ +"""Test the Databricks finetuning and deployment. + +This test requires valid Databricks credentials, so it is skipped on github actions. Right now it is only used for +manual testing. +""" + from dspy.clients.databricks import ( DatabricksProvider, _create_directory_in_databricks_unity_catalog,