From cc9179b44dd5563f53771a5ae366004bfe7b6237 Mon Sep 17 00:00:00 2001
From: Isaac Miller <17116851+isaacbmiller@users.noreply.github.com>
Date: Fri, 18 Oct 2024 15:06:01 -0700
Subject: [PATCH] Revert "Refactor finetuning implementation to be 2.5
 compatible"

---
 dspy/adapters/chat_adapter.py            |   3 -
 dspy/clients/anyscale.py                 | 324 --------------------
 dspy/clients/finetune.py                 | 129 --------
 dspy/clients/lm.py                       |  84 +-----
 dspy/clients/lm_finetune_utils.py        |  91 ------
 dspy/clients/openai.py                   | 358 -----------------------
 dspy/predict/chain_of_thought.py         |   2 +-
 dspy/primitives/program.py               |  29 +-
 dspy/teleprompt/finetune_teleprompter.py | 146 ---------
 dspy/teleprompt/random_search.py         |   2 -
 10 files changed, 22 insertions(+), 1146 deletions(-)
 delete mode 100644 dspy/clients/anyscale.py
 delete mode 100644 dspy/clients/finetune.py
 delete mode 100644 dspy/clients/lm_finetune_utils.py
 delete mode 100644 dspy/clients/openai.py
 delete mode 100644 dspy/teleprompt/finetune_teleprompter.py

diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py
index a484186e08..5a20dcff32 100644
--- a/dspy/adapters/chat_adapter.py
+++ b/dspy/adapters/chat_adapter.py
@@ -67,9 +67,6 @@ def parse(self, signature, completion, _parse_values=True):
 
         return fields
 
-    def format_turn(self, signature, values, role, incomplete=False):
-        return format_turn(signature, values, role, incomplete)
-
 
 def format_blob(blob):
     if "\n" not in blob and "«" not in blob and "»" not in blob:
diff --git a/dspy/clients/anyscale.py b/dspy/clients/anyscale.py
deleted file mode 100644
index 2c599ec7f9..0000000000
--- a/dspy/clients/anyscale.py
+++ /dev/null
@@ -1,324 +0,0 @@
-from typing import Any, Dict, List, Optional
-import json
-import yaml
-import os
-
-from dspy.utils.logging import logger
-from dspy.clients.finetune import (
-    FinetuneJob,
-    TrainingMethod,
-    save_data,
-)
-from dspy.clients.openai import openai_data_validation
-
-try:
-    # AnyScale fine-tuning requires the following additional imports
-    import anyscale
-    from anyscale.job import JobConfig
-except ImportError:
-    anyscale = None
-
-
-# List of training methods supported by AnyScale
-TRAINING_METHODS_ANYSCALE = [
-    TrainingMethod.SFT,
-]
-
-PROVIDER_ANYSCALE = "anyscale"
-
-
-def is_anyscale_model(model: str) -> bool:
-    """Check if the model is an AnyScale model."""
-    # TODO: This needs to be implemented to support fine-tuning
-    logger.info("is_anyscale_model is not implemented; returning False as a default to not break lm.py")
-    return False
-
-
-class FinetuneJobAnyScale(FinetuneJob):
-
-    def __init__(self, *args, **kwargs):
-        self.job_id = None
-        self.model_names = None
-        super().__init__(*args, **kwargs)
-
-
-def finetune_anyscale(
-    job: FinetuneJobAnyScale,
-    model: str,
-    train_data: List[Dict[str, Any]],
-    train_kwargs: Optional[Dict[str, Any]]=None,
-    train_method: TrainingMethod = TrainingMethod.SFT,
-    ) -> str:
-    """Start the finetune job."""
-    train_kwargs = train_kwargs or {}
-    assert "model" not in train_kwargs, "Model should not be in the train_kwargs"
-    train_kwargs_copy = train_kwargs.copy()
-    train_kwargs_copy["model"] = model
-
-    logger.info("[Finetune] Starting training process...")
-    if train_method not in TRAINING_METHODS_ANYSCALE:
-        raise NotImplementedError(f"AnyScale can only support {TRAINING_METHODS_ANYSCALE} for the time being")
-
-    logger.info("[Finetune] Validating the dataset format...")
-    if not verify_dataset(train_data):
-        # TODO: Does AnyScale support text completion models?
- err = "[Finetune] Error: Unable to verify that the dataset is in the correct format." - logger.error(err) - raise RuntimeError(err) - - logger.info("[Finetune] Converting data to JSONL format...") - train_data_path = save_data(train_data, provider_name=PROVIDER_ANYSCALE) - logger.info("[Finetune] Submitting data to remote storage...") - remote_train_path, _ = submit_data(train_path=train_data_path) - logger.info(f"[Finetune] Data submitted. Remote train path: {remote_train_path}") - - logger.info("[Finetune] Generating configuration files...") - _, compute_config = generate_config_files(train_path=remote_train_path, **train_kwargs_copy) - - logger.info("[Finetune] Starting remote training...") - job_id = start_remote_training(compute_config=compute_config, **train_kwargs_copy) - job.job_id = job_id - logger.info(f"[Finetune] Remote training started. Job ID: {job_id}") - - logger.info("[Finetune] Waiting for training to complete...") - wait_for_training(job.job_id) - logger.info("[Finetune] Training completed.") - - logger.info("[Finetune] Retrieving model information...") - model_info = get_model_info(job.job_id) - logger.info(f"[Finetune] Model info retrieved: {model_info}") - - storage_uri = model_info["storage_uri"] - logger.info(f"[Finetune] Copying LoRA weights from {storage_uri}...") - model_names, lora_dynamic_path = copy_lora_weights(storage_uri, model_info, job.job_id) - logger.info(f"[Finetune] LoRA weights copied. Model names: {model_names}") - - - logger.info("[Finetune] Setting result in future object...") - model_step_pairs = sorted([(model_name, int(model_name.split("-")[-1])) for model_name in model_names], key=lambda x: x[1]) - last_model_checkpoint = model_step_pairs[-1][0] - logger.info("[Finetune] Training process completed successfully.") - - logger.info("[Finetune] Updating model config with the proper dynamic path") - serve_config_path = train_kwargs.pop("serve_config_path", "serve_1B.yaml") - update_model_config(lora_dynamic_path, serve_config_path, job_id) - job.model_names = model_names - - return last_model_checkpoint - -def wait_for_training(job_id): - """Wait for the training to complete.""" - anyscale.job.wait(id=job_id, timeout_s=18000) - - -def update_model_config(lora_dynamic_path: str, serve_config_path: str, job_id: str): - """Update the model config storage location with the job_id.""" - with open(serve_config_path, "r") as f: - serve_config = yaml.safe_load(f) - - model_config_location = serve_config["applications"][0]["args"]["llm_configs"][0] - - with open(model_config_location, "r") as f: - model_config = yaml.safe_load(f) - - dynamic_path_until_job_id = lora_dynamic_path.split(job_id)[0] + job_id - model_config["lora_config"]["dynamic_lora_loading_path"] = dynamic_path_until_job_id - - with open(model_config_location, "w") as f: - yaml.safe_dump(model_config, f) - - -def verify_dataset(dataset: List[dict[str, Any]]) -> bool: - """Verify the training arguments before starting training.""" - dataset_validation = openai_data_validation(dataset) - - if dataset_validation: - logger.error(f"Dataset validation failed: {dataset_validation}") - return False - - return True - - -def submit_data(train_path: str): - """Upload the data to the Workspace cloud storage.""" - storage = os.environ['ANYSCALE_ARTIFACT_STORAGE'] - - datasets = {"train": train_path} - - fine_tuning_file_ids = {} - for name, path in datasets.items(): - num_items = len(read_jsonl(path)) - logger.info(f"Number of items in {name} data: {num_items}") - - remote_path = os.path.join(storage, 
path.split("/")[-1]) - logger.info(f"Uploading {name} data to S3 at {remote_path}") - if remote_path[:2] == "s3": - os.system(f"aws s3 cp {path} {remote_path}") - elif remote_path[:2] == "gs": - os.system(f"gcloud storage cp {path} {remote_path}") - else: - os.system(f"cp {path} {remote_path}") - logger.info(f"Copied {path} to {remote_path}") - fine_tuning_file_ids[name] = remote_path - - return fine_tuning_file_ids["train"], fine_tuning_file_ids.get("val", None) - - -def generate_config_files(train_path: str, **kwargs): - base_model_yaml_path = kwargs.get("train_config_yaml", None) - assert kwargs["model"] is not None, "Model is required to generate the config files" - - use_lora = kwargs.get("use_lora", False) - example_dir = "" - lora_path = "configs/training/lora" if use_lora else "configs/training/full_param" - - - if not base_model_yaml_path: - def get_yaml_config(model_name): - if "llama" in model_name.lower(): - if "70b" in model_name: - return "llama-3-70b.yaml" - elif "13b" in model_name: - return "llama-3-70b.yaml" - else: - return "llama-3-8b.yaml" - elif "mistral" in model_name.lower(): - if "mixtral" in model_name.lower(): - return "mixtral-8x7b.yaml" - else: - return "mistral-7b.yaml" - else: - raise RuntimeError("No default yaml found for the model") - - default_model_yaml_path = get_yaml_config(kwargs["model"]) - base_model_yaml_path = os.path.join(example_dir, lora_path, default_model_yaml_path) - logger.info(f"Using default yaml template for model: {base_model_yaml_path}") - - model_config_data = yaml.safe_load(open(base_model_yaml_path, "r")) - model_config_data.update(kwargs.get("hyperparameters", {})) - - model_config_data["model_id"] = kwargs["model"] - - custom_modifications = { - "model_id": kwargs["model"], - "train_path": train_path, - "logger": { - "provider": "wandb", - }, - "num_checkpoints_to_keep": 10 - } - if kwargs.get("output_dir", None): - custom_modifications["output_dir"] = kwargs["output_dir"] - - model_config_data.update(custom_modifications) - model_config_data = {k: v for k, v in model_config_data.items() if v is not None} - - def freeze(d): - if isinstance(d, dict): - return tuple(sorted((key, freeze(value)) for key, value in d.items())) - elif isinstance(d, list): - return tuple(freeze(value) for value in sorted(d)) - elif isinstance(d, set): - return tuple(freeze(value) for value in sorted(d)) - return d - - def hash_dict(d): - return hash(freeze(d)) - dict_sorted_hash = hash_dict(model_config_data) - if dict_sorted_hash < 0: - dict_sorted_hash = -dict_sorted_hash - filename = f"model_config_dspy_{dict_sorted_hash}.yaml" - logger.info(f"Model config data: {model_config_data}") - yaml.safe_dump(model_config_data, open(filename, "w")) - - ft_path = os.path.join("utils", "ft.py") - - compute_config_dict = { - "name": "dspy-llmforge-fine-tuning-job", - "entrypoint": f"llmforge anyscale finetune {filename}", - "working_dir": ".", - "image_uri": "localhost:5555/anyscale/llm-forge:0.5.6", - "requirements": [ - "wandb", - ], - "env_vars": { - "WANDB_API_KEY": os.environ.get("WANDB_API_KEY", ""), - "HF_TOKEN": os.environ.get("HF_TOKEN", ""), - "HF_HOME": os.environ.get("HF_HOME", ""), - } - } - compute_config_kwargs = kwargs.get("compute_config", {}) - compute_config_dict.update(compute_config_kwargs) - compute_config = JobConfig(**compute_config_dict) - - job_runner_config_path = kwargs.get("compute_yaml_path", "job_runner_config.yaml") - - return job_runner_config_path, compute_config - - -def start_remote_training(compute_config, **kwargs) -> str: - 
job_id: str = anyscale.job.submit(compute_config) - return job_id - - -def wait_for_training(job_id): - logger.info("Waiting for training to complete") - anyscale.job.wait(id=job_id, timeout_s=18000) - - -def get_model_info(job_id): - return anyscale.llm.model.get(job_id=job_id).to_dict() - - -def copy_lora_weights(storage_uri, model_info, job_id): - try: - from google.cloud import storage - - storage_client = storage.Client() - - bucket_name = storage_uri.split('/')[2] - source_folder = '/'.join(storage_uri.split('/')[3:-1]) - logger.info(f"Source folder: {source_folder}") - - bucket = storage_client.bucket(bucket_name) - - blobs = bucket.list_blobs(prefix=source_folder) - - subfolders = set() - for blob in blobs: - if '/' in blob.name[len(source_folder):]: - subfolder = blob.name.split('/')[:-1] - subfolders.add('/'.join(subfolder)) - - base_model_id = model_info["base_model_id"] - lora_dynamic_path = f"dspy/lora_weights/{job_id}/{base_model_id}" - - model_names = [] - for subfolder in subfolders: - subfolder_name = subfolder.split('/')[-1] - destination_folder = f"{lora_dynamic_path}:{subfolder_name}" - if subfolder_name.startswith("epoch"): - model_names.append("/".join(destination_folder.split("/")[-2:])) - else: - continue - - subfolder_blobs = bucket.list_blobs(prefix=subfolder) - - for blob in subfolder_blobs: - source_blob = bucket.blob(blob.name) - destination_blob_name = f"{destination_folder}/{blob.name.split('/')[-1]}" - bucket.copy_blob(source_blob, bucket, destination_blob_name) - logger.info(f"Copied {source_blob.name} to {destination_blob_name}") - - logger.info(f"All subfolders copied to: gs://{bucket_name}/{lora_dynamic_path}") - completed_path = f"gs://{bucket_name}/{lora_dynamic_path}" - return model_names, completed_path - - except Exception as e: - logger.error(f"An error occurred: {str(e)}") - raise e - - -def read_jsonl(filename): - with open(filename, "r") as f: - return [json.loads(line) for line in f] diff --git a/dspy/clients/finetune.py b/dspy/clients/finetune.py deleted file mode 100644 index d9d0df8e8e..0000000000 --- a/dspy/clients/finetune.py +++ /dev/null @@ -1,129 +0,0 @@ -import os -from abc import abstractmethod -from concurrent.futures import Future -from enum import Enum -from pathlib import Path -from typing import Any, Dict, List, Optional - -import ujson -from datasets.fingerprint import Hasher - - -def get_finetune_directory() -> str: - """Get the directory to save the fine-tuned models.""" - # TODO: Move to a centralized location with all the other env variables - dspy_cachedir = os.environ.get("DSPY_CACHEDIR") - dspy_cachedir = dspy_cachedir or os.path.join(Path.home(), ".dspy_cache") - finetune_dir = os.path.join(dspy_cachedir, "finetune") - finetune_dir = os.path.abspath(finetune_dir) - return finetune_dir - - -FINETUNE_DIRECTORY = get_finetune_directory() - - -class TrainingMethod(str, Enum): - """Enum class for training methods. - - When comparing enums, Python checks for object IDs, which means that the - enums can't be compared directly. Subclassing the Enum class along with the - str class allows for direct comparison of the enums. 
- """ - - SFT = "SFT" - Preference = "Preference" - - -class TrainingStatus(str, Enum): - """Enum class for remote training status.""" - - not_started = "not_started" - pending = "pending" - running = "running" - succeeded = "succeeded" - failed = "failed" - cancelled = "cancelled" - - -"""Dictionary mapping training methods to the data keys they require.""" -TRAINING_METHOD_TO_DATA_KEYS = { - TrainingMethod.SFT: ["prompt", "completion"], - TrainingMethod.Preference: ["prompt", "chosen", "rejected"], -} - - -class FinetuneJob(Future): - def __init__( - self, - model: str, - train_data: List[Dict[str, Any]], - train_kwargs: Optional[Dict[str, Any]] = None, - train_method: TrainingMethod = TrainingMethod.SFT, - provider: str = "openai", - ): - self.model = model - self.train_data = train_data - self.train_kwargs: Dict[str, Any] = train_kwargs or {} - self.train_method = train_method - self.provider = provider - super().__init__() - - def get_kwargs(self): - return dict( - model=self.model, - train_data=self.train_data, - train_kwargs=self.train_kwargs, - train_method=self.train_method, - provider=self.provider, - ) - - def __repr__(self): - return str(self) - - # Subclasses should override the cancel method to cancel the finetune job; - # then call the super's cancel method so that the future can be cancelled. - def cancel(self): - """Cancel the finetune job.""" - super().cancel() - - @abstractmethod - def status(self): - """Get the status of the finetune job.""" - raise NotImplementedError("Method `status` is not implemented.") - - -def validate_finetune_data(data: List[Dict[str, Any]], train_method: TrainingMethod): - """Validate the finetune data based on the training method.""" - # Get the required data keys for the training method - required_keys = TRAINING_METHOD_TO_DATA_KEYS[train_method] - - # Check if the training data has the required keys - for ind, data_dict in enumerate(data): - if not all([key in data_dict for key in required_keys]): - raise ValueError( - f"The datapoint at index {ind} is missing the keys required for {train_method} training. 
Expected: " - f"{required_keys}, Found: {data_dict.keys()}" - ) - - -def save_data( - data: List[Dict[str, Any]], - provider_name: Optional[str] = None, -) -> str: - """Save the fine-tuning data to a file.""" - # Construct the file name based on the data hash - hash = Hasher.hash(data) - file_name = f"{hash}.jsonl" - file_name = f"{provider_name}_{file_name}" if provider_name else file_name - - # Find the directory to save the fine-tuning data - finetune_parent_dir = get_finetune_directory() - os.makedirs(finetune_parent_dir, exist_ok=True) - - # Save the data to a file - file_path = os.path.join(finetune_parent_dir, file_name) - file_path = os.path.abspath(file_path) - with open(file_path, "w") as f: - for item in data: - f.write(ujson.dumps(item) + "\n") - return file_path diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index f2949372ac..186c622b9f 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -1,25 +1,15 @@ -from concurrent.futures import ThreadPoolExecutor -from datetime import datetime import functools import os -from pathlib import Path -from typing import Any, Dict, List, Optional -import ujson import uuid - -from dspy.utils.logging import logger -from dspy.clients.finetune import FinetuneJob, TrainingMethod -from dspy.clients.lm_finetune_utils import ( - get_provider_finetune_job_class, - execute_finetune_job, -) +from datetime import datetime +from pathlib import Path import litellm +import ujson from litellm.caching import Cache - -DISK_CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache") -litellm.cache = Cache(disk_cache_dir=DISK_CACHE_DIR, type="disk") +disk_cache_dir = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache") +litellm.cache = Cache(disk_cache_dir=disk_cache_dir, type="disk") litellm.telemetry = False if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ: @@ -27,26 +17,13 @@ class LM: - def __init__( - self, - model, - model_type='chat', - temperature=0.0, - max_tokens=1000, - cache=True, - launch_kwargs=None, - **kwargs - ): - # Remember to update LM.copy() if you modify the constructor! + def __init__(self, model, model_type="chat", temperature=0.0, max_tokens=1000, cache=True, **kwargs): self.model = model self.model_type = model_type self.cache = cache - self.launch_kwargs = launch_kwargs or {} self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs) self.history = [] - # TODO: Arbitrary model strings could include the substring "o1-". We - # should find a more robust way to check for the "o1-" family models. if "o1-" in model: assert ( max_tokens >= 5000 and temperature == 1.0 @@ -82,56 +59,10 @@ def __call__(self, prompt=None, messages=None, **kwargs): self.history.append(entry) return outputs - + def inspect_history(self, n: int = 1): _inspect_history(self, n) - def launch(self): - """Send a request to the provider to launch the model, if needed.""" - msg = f"`launch()` is called for the auto-launched model {self.model}" - msg += " -- no action is taken!" - logger.info(msg) - - def kill(self): - """Send a request to the provider to kill the model, if needed.""" - msg = f"`kill()` is called for the auto-launched model {self.model}" - msg += " -- no action is taken!" 
- logger.info(msg) - - def finetune( - self, - train_data: List[Dict[str, Any]], - train_kwargs: Optional[Dict[str, Any]]=None, - train_method: TrainingMethod = TrainingMethod.SFT, - provider: str = "openai", - cache_finetune: bool = True, - ) -> FinetuneJob: - """Start model fine-tuning, if supported.""" - from dspy import settings as settings - err = "Fine-tuning is an experimental feature." - err += " Set `dspy.settings.experimental` to `True` to use it." - assert settings.experimental, err - - FinetuneJobClass = get_provider_finetune_job_class(provider=provider) - finetune_job = FinetuneJobClass( - model=self.model, - train_data=train_data, - train_kwargs=train_kwargs, - train_method=train_method, - provider=provider - ) - - executor = ThreadPoolExecutor(max_workers=1) - executor.submit( - execute_finetune_job, - finetune_job, - lm=self, - cache_finetune=cache_finetune - ) - executor.shutdown(wait=False) - - return finetune_job - def copy(self, **kwargs): """Returns a copy of the language model with possibly updated parameters.""" @@ -168,7 +99,6 @@ def litellm_text_completion(request, cache={"no-cache": True, "no-store": True}) kwargs = ujson.loads(request) # Extract the provider and model from the model string. - # TODO: Not all the models are in the format of "provider/model" model = kwargs.pop("model").split("/", 1) provider, model = model[0] if len(model) > 1 else "openai", model[-1] diff --git a/dspy/clients/lm_finetune_utils.py b/dspy/clients/lm_finetune_utils.py deleted file mode 100644 index c9b4863d03..0000000000 --- a/dspy/clients/lm_finetune_utils.py +++ /dev/null @@ -1,91 +0,0 @@ -from typing import Any, Dict, List, Optional, Type, Union - -from dspy.clients.anyscale import FinetuneJobAnyScale, finetune_anyscale -from dspy.clients.finetune import FinetuneJob, TrainingMethod -from dspy.clients.openai import FinetuneJobOpenAI, finetune_openai -from dspy.utils.logging import logger - -_PROVIDER_ANYSCALE = "anyscale" -_PROVIDER_OPENAI = "openai" - - -def get_provider_finetune_job_class(provider: str) -> Type[FinetuneJob]: - """Get the FinetuneJob class for the provider.""" - provider_to_job_class = { - _PROVIDER_ANYSCALE: FinetuneJobAnyScale, - _PROVIDER_OPENAI: FinetuneJobOpenAI, - } - return provider_to_job_class[provider] - - -def get_provider_finetune_function(provider: str) -> callable: - """Return the finetune function for the given model.""" - provider_to_finetune_function = { - _PROVIDER_ANYSCALE: finetune_anyscale, - _PROVIDER_OPENAI: finetune_openai, - } - return provider_to_finetune_function[provider] - - -# Note: Type of LM should be LM. We aren't importing it here to avoid -# circular imports. 
-def execute_finetune_job(job: FinetuneJob, lm: Any, cache_finetune: bool = True):
-    """Execute the finetune job in a blocking manner."""
-    try:
-        job_kwargs = job.get_kwargs()
-        if cache_finetune:
-            model = cached_finetune(job=job, **job_kwargs)
-        else:
-            model = finetune(job=job, **job_kwargs)
-        lm = lm.copy(model=model)
-        job.set_result(lm)
-    except Exception as err:
-        logger.error(err)
-        job.set_result(err)
-
-
-# TODO: Add DiskCache, ignore job
-def cached_finetune(
-    job,
-    model: str,
-    train_data: List[Dict[str, Any]],
-    train_kwargs: Optional[Dict[str, Any]] = None,
-    train_method: TrainingMethod = TrainingMethod.SFT,
-    provider: str = "openai",
-) -> Union[str, Exception]:
-    return finetune(
-        job=job,
-        model=model,
-        train_data=train_data,
-        train_kwargs=train_kwargs,
-        train_method=train_method,
-        provider=provider,
-    )
-
-
-def finetune(
-    job,
-    model: str,
-    train_data: List[Dict[str, Any]],
-    train_kwargs: Optional[Dict[str, Any]] = None,
-    train_method: TrainingMethod = TrainingMethod.SFT,
-    provider: str = "openai",
-) -> Union[str, Exception]:
-    """Fine-tune a new model based on the given model."""
-    # Get the fine-tuning provider
-    try:
-        # Get the finetune function
-        provider_finetune_function = get_provider_finetune_function(provider)
-
-        # Fine-tune a new model based on the given model
-        model = provider_finetune_function(
-            job=job,
-            model=model,
-            train_data=train_data,
-            train_kwargs=train_kwargs,
-            train_method=train_method,
-        )
-    except Exception as err:
-        raise err
-
-    return model
diff --git a/dspy/clients/openai.py b/dspy/clients/openai.py
deleted file mode 100644
index ec084fe5f5..0000000000
--- a/dspy/clients/openai.py
+++ /dev/null
@@ -1,358 +0,0 @@
-import re
-import time
-from collections import defaultdict
-from typing import Any, Dict, List, Optional, Union
-
-import openai
-
-from dspy.clients.finetune import (
-    FinetuneJob,
-    TrainingMethod,
-    TrainingStatus,
-    save_data,
-    validate_finetune_data,
-)
-from dspy.utils.logging import logger
-
-# Provider name
-PROVIDER_OPENAI = "openai"
-
-
-def is_openai_model(model: str) -> bool:
-    """Check if the model is an OpenAI model."""
-    # Filter the provider_prefix, if exists
-    provider_prefix = f"{PROVIDER_OPENAI}/"
-    if model.startswith(provider_prefix):
-        model = model[len(provider_prefix) :]
-
-    client = openai.OpenAI()
-    valid_model_names = [model.id for model in client.models.list().data]
-    # Check if the model is a base OpenAI model
-    if model in valid_model_names:
-        return True
-
-    # Check if the model is a fine-tuned OpenAI model. Fine-tuned OpenAI models
-    # have the prefix "ft:<BASE_MODEL_NAME>:", followed by a string specifying
-    # the fine-tuned model. The following RegEx pattern is used to match the
-    # base model name.
-    # TODO: This part can be updated to match the actual fine-tuned model names
-    # by making a call to the OpenAI API to be more exact, but this might
-    # require an API key with the right permissions.
-    match = re.match(r"ft:([^:]+):", model)
-    if match and match.group(1) in valid_model_names:
-        return True
-
-    return False
-
-
-class FinetuneJobOpenAI(FinetuneJob):
-    def __init__(self, *args, **kwargs):
-        self.provider_file_id = None  # TODO: Can we get this using the job_id?
-        self.provider_job_id = None
-        super().__init__(*args, **kwargs)
-
-    def cancel(self):
-        # Cancel the provider job
-        if _does_job_exist(self.provider_job_id):
-            status = _get_training_status(self.provider_job_id)
-            if _is_terminal_training_status(status):
-                err_msg = "Jobs that are complete cannot be canceled."
- err_msg += f" Job with ID {self.provider_job_id} is done." - raise Exception(err_msg) - openai.fine_tuning.jobs.cancel(self.provider_job_id) - self.provider_job_id = None - - # Delete the provider file - # TODO: Should there be a separate clean method? - if self.provider_file_id is not None: - if _does_file_exist(self.provider_file_id): - openai.files.delete(self.provider_file_id) - self.provider_file_id = None - - # Call the super's cancel method after the custom cancellation logic - super().cancel() - - def status(self) -> TrainingStatus: - status = _get_training_status(self.provider_job_id) - return status - - -def finetune_openai( - job: FinetuneJobOpenAI, - model: str, - train_data: List[Dict[str, Any]], - train_kwargs: Optional[Dict[str, Any]] = None, - train_method: TrainingMethod = TrainingMethod.SFT, -) -> str: - train_kwargs = train_kwargs or {} - train_method = TrainingMethod.SFT # Note: This could be an argument; ignoring method - - # Validate train data and method - logger.info("[Finetune] Validating the formatting of the data") - _validate_data(train_data, train_method) - logger.info("[Finetune] Done!") - - # Convert to the OpenAI format - logger.info("[Finetune] Converting the data to the OpenAI format") - # TODO: Should we use the system prompt? - train_data = _convert_data(train_data) - logger.info("[Finetune] Done!") - - # Save to a file - logger.info("[Finetune] Saving the data to a file") - data_path = save_data(train_data, provider_name=PROVIDER_OPENAI) - logger.info("[Finetune] Done!") - - # Upload the data to the cloud - logger.info("[Finetune] Uploading the data to the provider") - provider_file_id = _upload_data(data_path) - job.provider_file_id = provider_file_id - logger.info("[Finetune] Done!") - - logger.info("[Finetune] Start remote training") - # We utilize model and train_kwargs here - provider_job_id = _start_remote_training( - train_file_id=job.provider_file_id, - model=model, - train_kwargs=train_kwargs, - ) - job.provider_job_id = provider_job_id - # job.provider_job_id = "ftjob-ZdEL1mUDk0dwdDuZJQOng8Vv" - logger.info("[Finetune] Done!") - - logger.info("[Finetune] Wait for training to complete") - # TODO: Would it be possible to stream the logs? - _wait_for_job(job) - logger.info("[Finetune] Done!") - - logger.info("[Finetune] Get trained model if the run was a success") - model = _get_trained_model(job) - logger.info("[Finetune] Done!") - - return model - - -_SUPPORTED_TRAINING_METHODS = [ - TrainingMethod.SFT, -] - - -def _get_training_status(job_id: str) -> Union[TrainingStatus, Exception]: - # TODO: Should this type be shared across all fine-tune clients? - provider_status_to_training_status = { - "validating_files": TrainingStatus.pending, - "queued": TrainingStatus.pending, - "running": TrainingStatus.running, - "succeeded": TrainingStatus.succeeded, - "failed": TrainingStatus.failed, - "cancelled": TrainingStatus.cancelled, - } - - # Check if there is an active job - if job_id is None: - logger.info("There is no active job.") - return TrainingStatus.not_started - - err_msg = f"Job with ID {job_id} does not exist." 
-    assert _does_job_exist(job_id), err_msg
-
-    # Retrieve the provider's job and report the status
-    provider_job = openai.fine_tuning.jobs.retrieve(job_id)
-    provider_status = provider_job.status
-    status = provider_status_to_training_status[provider_status]
-
-    return status
-
-
-def _does_job_exist(job_id: str) -> bool:
-    try:
-        # TODO: Error handling is vague
-        openai.fine_tuning.jobs.retrieve(job_id)
-        return True
-    except Exception:
-        return False
-
-
-def _does_file_exist(file_id: str) -> bool:
-    try:
-        # TODO: Error handling is vague
-        openai.files.retrieve(file_id)
-        return True
-    except Exception:
-        return False
-
-
-def _is_terminal_training_status(status: TrainingStatus) -> bool:
-    return status in [
-        TrainingStatus.succeeded,
-        TrainingStatus.failed,
-        TrainingStatus.cancelled,
-    ]
-
-
-def _validate_data(data: List[Dict[str, Any]], train_method: TrainingMethod) -> Optional[Exception]:
-    # Check if this train method is supported
-    if train_method not in _SUPPORTED_TRAINING_METHODS:
-        err_msg = f"OpenAI does not support the training method {train_method}."
-        raise ValueError(err_msg)
-
-    validate_finetune_data(data, train_method)
-
-
-def _convert_data(
-    data: List[Dict[str, str]],
-    system_prompt: Optional[str] = None,
-) -> Union[List[Dict[str, Any]], Exception]:
-    # Item-wise conversion function
-    def _row_converter(d):
-        messages = [{"role": "user", "content": d["prompt"]}, {"role": "assistant", "content": d["completion"]}]
-        if system_prompt:
-            messages.insert(0, {"role": "system", "content": system_prompt})
-        messages_dict = {"messages": messages}
-        return messages_dict
-
-    # Convert the data to the OpenAI format; validate the converted data
-    converted_data = list(map(_row_converter, data))
-    openai_data_validation(converted_data)
-    return converted_data
-
-
-def _upload_data(data_path: str) -> str:
-    # Upload the data to the provider
-    provider_file = openai.files.create(
-        file=open(data_path, "rb"),
-        purpose="fine-tune",
-    )
-    return provider_file.id
-
-
-def _start_remote_training(train_file_id: str, model: str, train_kwargs: Optional[Dict[str, Any]] = None) -> str:
-    train_kwargs = train_kwargs or {}
-    provider_job = openai.fine_tuning.jobs.create(
-        model=model,
-        training_file=train_file_id,
-        hyperparameters=train_kwargs,
-    )
-    return provider_job.id
-
-
-def _wait_for_job(
-    job: FinetuneJobOpenAI,
-    poll_frequency: int = 60,
-):
-    while not _is_terminal_training_status(job.status()):
-        time.sleep(poll_frequency)
-
-
-def _get_trained_model(job):
-    status = job.status()
-    if status != TrainingStatus.succeeded:
-        err_msg = f"Job status is {status}."
-        err_msg += f" Must be {TrainingStatus.succeeded} to retrieve the model."
-        logger.error(err_msg)
-        raise Exception(err_msg)
-
-    provider_job = openai.fine_tuning.jobs.retrieve(job.provider_job_id)
-    finetuned_model = provider_job.fine_tuned_model
-    return finetuned_model
-
-
-# Adapted from https://cookbook.openai.com/examples/chat_finetuning_data_prep
-def openai_data_validation(dataset: List[dict[str, Any]]):
-    format_errors = defaultdict(int)
-    for ex in dataset:
-        if not isinstance(ex, dict):
-            format_errors["data_type"] += 1
-            continue
-
-        messages = ex.get("messages", None)
-        if not messages:
-            format_errors["missing_messages_list"] += 1
-            continue
-
-        for message in messages:
-            if "role" not in message or "content" not in message:
-                format_errors["message_missing_key"] += 1
-
-            if any(k not in ("role", "content", "name", "function_call", "weight") for k in message):
-                format_errors["message_unrecognized_key"] += 1
-
-            if message.get("role", None) not in ("system", "user", "assistant", "function"):
-                format_errors["unrecognized_role"] += 1
-
-            content = message.get("content", None)
-            function_call = message.get("function_call", None)
-
-            if (not content and not function_call) or not isinstance(content, str):
-                format_errors["missing_content"] += 1
-
-        if not any(message.get("role", None) == "assistant" for message in messages):
-            format_errors["example_missing_assistant_message"] += 1
-
-    # Raise an error if there are any format errors
-    if format_errors:
-        err_msg = "Found errors in the dataset format using the OpenAI API."
-        err_msg += " Here are the number of datapoints for each error type:"
-        for k, v in format_errors.items():
-            err_msg += f"\n {k}: {v}"
-        raise ValueError(err_msg)
-
-
-def check_message_lengths(dataset: List[dict[str, Any]]) -> list[int]:
-    n_missing_system = 0
-    n_missing_user = 0
-    n_messages = []
-    convo_lens = []
-    assistant_message_lens = []
-
-    for ex in dataset:
-        messages = ex["messages"]
-        if not any(message["role"] == "system" for message in messages):
-            n_missing_system += 1
-        if not any(message["role"] == "user" for message in messages):
-            n_missing_user += 1
-        n_messages.append(len(messages))
-        convo_lens.append(num_tokens_from_messages(messages))
-        assistant_message_lens.append(num_assistant_tokens_from_messages(messages))
-    n_too_long = sum([length > 16385 for length in convo_lens])
-
-    if n_too_long > 0:
-        logger.info(
-            f"There are {n_too_long} examples that may be over the 16,385 token limit, they will be truncated during fine-tuning."
- ) - - if n_missing_system > 0: - logger.info(f"There are {n_missing_system} examples that are missing a system message.") - - if n_missing_user > 0: - logger.info(f"There are {n_missing_user} examples that are missing a user message.") - - return convo_lens - - -def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1): - import tiktoken - - encoding = tiktoken.get_encoding("cl100k_base") - - num_tokens = 0 - for message in messages: - num_tokens += tokens_per_message - for key, value in message.items(): - num_tokens += len(encoding.encode(value)) - if key == "name": - num_tokens += tokens_per_name - num_tokens += 3 - return num_tokens - - -def num_assistant_tokens_from_messages(messages): - import tiktoken - - encoding = tiktoken.get_encoding("cl100k_base") - - num_tokens = 0 - for message in messages: - if message["role"] == "assistant": - num_tokens += len(encoding.encode(message["content"])) - return num_tokens diff --git a/dspy/predict/chain_of_thought.py b/dspy/predict/chain_of_thought.py index b6b9cb575e..74656697ca 100644 --- a/dspy/predict/chain_of_thought.py +++ b/dspy/predict/chain_of_thought.py @@ -29,7 +29,7 @@ def __init__(self, signature, rationale_type=None, activated=True, **config): rationale_type = rationale_type or dspy.OutputField(prefix=prefix, desc=desc) # Add "rationale" field to the output signature. - if isinstance(dspy.settings.lm, dspy.LM) or dspy.settings.experimental: + if isinstance(dspy.settings.lm, dspy.LM): extended_signature = signature.prepend("reasoning", rationale_type, type_=str) else: extended_signature = signature.prepend("rationale", rationale_type, type_=str) diff --git a/dspy/primitives/program.py b/dspy/primitives/program.py index 3cf6220b6c..ac3646275f 100644 --- a/dspy/primitives/program.py +++ b/dspy/primitives/program.py @@ -1,12 +1,18 @@ import magicattr -import dspy from dspy.primitives.assertions import * from dspy.primitives.module import BaseModule class ProgramMeta(type): pass + # def __call__(cls, *args, **kwargs): + # obj = super(ProgramMeta, cls).__call__(*args, **kwargs) + + # if issubclass(cls, Program) and not getattr(obj, "_program_init_called", False): + # obj._base_init() + # obj._program_init_called = True + # return obj class Module(BaseModule, metaclass=ProgramMeta): @@ -26,29 +32,23 @@ def named_predictors(self): def predictors(self): return [param for _, param in self.named_predictors()] - + def set_lm(self, lm): - if not dspy.settings.experimental: - raise ValueError( - "Setting or getting the LM of a program is an experimental feature. Please enable the " - "'dspy.settings.experimental' flag to use these features." - ) + import dspy + assert dspy.settings.experimental, "Setting the lm is an experimental feature." for _, param in self.named_predictors(): param.lm = lm def get_lm(self): - if not dspy.settings.experimental: - raise ValueError( - "Setting or getting the LM of a program is an experimental feature. Please enable the " - "'dspy.settings.experimental' flag to use these features." - ) + import dspy + assert dspy.settings.experimental, "Getting the lm is an experimental feature." 
         all_used_lms = [param.lm for _, param in self.named_predictors()]
 
         if len(set(all_used_lms)) == 1:
             return all_used_lms[0]
-
+        
         raise ValueError("Multiple LMs are being used in the module.")
 
     def __repr__(self):
@@ -95,5 +95,4 @@ def activate_assertions(self, handler=backtrack_handler, **handler_args):
 def set_attribute_by_name(obj, name, value):
     magicattr.set(obj, name, value)
 
-
-Program = Module
+Program = Module
\ No newline at end of file
diff --git a/dspy/teleprompt/finetune_teleprompter.py b/dspy/teleprompt/finetune_teleprompter.py
deleted file mode 100644
index c519035d08..0000000000
--- a/dspy/teleprompt/finetune_teleprompter.py
+++ /dev/null
@@ -1,146 +0,0 @@
-from typing import Any, Callable, Dict, List, Optional, Union
-
-import dspy
-from dspy.evaluate.evaluate import Evaluate
-from dspy.primitives.example import Example
-from dspy.primitives.prediction import Prediction
-from dspy.primitives.program import Program
-from dspy.utils.logging import logger
-
-
-# TODO: Shared below are useful functions. Similar procedures are implemented
-# separately and used by other DSPy teleprompters. These can be moved to shared
-# locations.
-def prepare_teacher(student: Program, teacher: Program = None) -> Program:
-    """Prepare the teacher program with respect to the student program.
-    Args:
-        student: The student program.
-        teacher: The teacher program. If `None`, a copy of the student program
-            is used as the teacher. Defaults to `None`.
-    """
-    # If teacher is None, use a copy of the student program as the teacher
-    if teacher is None:
-        logger.info("No teacher provided. Using a copy of the student program as the teacher.")
-        teacher = student.deepcopy()
-    else:
-        teacher = teacher.deepcopy()
-
-    # Ensure that the student and teacher programs have the same structure
-    logger.info("Ensuring that the student and teacher are structurally equivalent.")
-    student._assert_structural_equivalency(teacher)
-
-    # Ensure that the predictors of the programs point to different objects
-    logger.info("Ensuring that the student and teacher programs do not share predictors.")
-    student._assert_no_shared_predictor(teacher)
-
-    # Ensure that the LM consistency property is satisfied
-    logger.info("Ensuring that the teacher program satisfies the LM consistency property.")
-    teacher._assert_lm_consistency()
-
-    # If the global LM is being used, set it to the LMs of the copied teacher
-    # program predictors to avoid handling the same edge cases later
-    if dspy.settings.lm:
-        teacher._set_all_predictor_lms(dspy.settings.lm)
-
-    return teacher
-
-
-def convert_to_module_level_message_data(
-    data: List[Dict],
-    keep_data_keys: bool = False,
-    exclude_demos: bool = False,
-    try_to_record_lm_kwargs: bool = False,
-    program: Program = None,
-) -> List[Dict]:
-    """Wrapper around the function `build_messages_from_trace`, calling it on
-    the "trace" field of each dictionary in the input data list and combining
-    the results into a list of prompt-completion data dictionaries."""
-
-    prompt_completion_data = []
-    for data_dict in data:
-        trace = data_dict["trace"]
-        trace_prompt_completion_data = build_messages_from_trace(
-            trace=trace, exclude_demos=exclude_demos, try_to_record_lm_kwargs=try_to_record_lm_kwargs, program=program
-        )
-        for prompt_completion_dict in trace_prompt_completion_data:
-            if keep_data_keys:
-                prompt_completion_dict = {**data_dict, **prompt_completion_dict}
-            prompt_completion_data.append(prompt_completion_dict)
-    return prompt_completion_data
-
-
-def build_messages_from_trace(
-    trace: List[Dict],
-    exclude_demos: bool = False,
-    try_to_record_lm_kwargs: bool = False,
-    program: Program = None,
-) -> List[List[Dict[str, Any]]]:
-    messages = []
-    # If the program is provided, build the predictor index to name mapping
-    if program:
-        pred_ind_to_name = {ind: name for ind, (name, _) in enumerate(program.named_predictors())}
-
-    # Build the prompt-completion data
-
-    adapter = dspy.settings.adapter or dspy.ChatAdapter()
-    data = []
-
-    # TODO: Make sure that this works for multi-stage pipelines
-    for pred_ind, (pred, inputs, outputs) in enumerate(trace):
-        # Get the demos from the predictor if exclude_demos is False
-        demos = [] if exclude_demos else pred.demos
-        messages = adapter.format(pred.signature, demos, inputs)
-        messages.append(
-            adapter.format_turn(signature=pred.signature, values=outputs, role="assistant", incomplete=False)
-        )
-        data.append(messages)
-
-    return data
-
-
-def bootstrap_data(
-    program: Program,
-    dataset: List[Example],
-    metric: Optional[Callable[[Example, Prediction, Optional[List]], Union[bool, int, float]]] = None,
-    num_threads=1,
-    max_errors: int = 0,
-) -> List[Dict[str, Any]]:
-    """Bootstrap example, prediction, trace, example_ind, score data for the program using the dataset."""
-    data = []
-
-    # Use Evaluate to call the program and have the responses cached
-    cname = program.__class__.__name__
-    info = f"Bootstrapping data on {len(dataset)} examples with the program {cname}, with {num_threads} threads"
-    logger.info(info)
-    evaluator = Evaluate(
-        devset=dataset,
-        num_threads=num_threads,
-        display_progress=True,
-        max_errors=max_errors,
-        provide_traceback=True,
-    )
-    evaluator(program, metric=metric)
-
-    data = []
-    for example_ind, example in enumerate(dataset):
-        data_dict = bootstrap_one_example(example=example, example_ind=example_ind, program=program, metric=metric)
-        if data_dict is not None:
-            data.append(data_dict)
-
-    return data
-
-
-def bootstrap_one_example(
-    example: Any, example_ind: int, program: Program, metric: Optional[Callable] = None
-) -> Dict[str, Any]:
-    with dspy.context(trace=[]):
-        prediction = program(**example.inputs())
-        trace = dspy.settings.trace
-        score = metric(example, prediction, trace) if metric else None
-
-    data_dict = {"example": example, "prediction": prediction, "trace": trace, "example_ind": example_ind}
-    if metric:
-        data_dict["score"] = score
-
-    return data_dict
diff --git a/dspy/teleprompt/random_search.py b/dspy/teleprompt/random_search.py
index 4fa8ba23fd..57bd38150c 100644
--- a/dspy/teleprompt/random_search.py
+++ b/dspy/teleprompt/random_search.py
@@ -85,7 +85,6 @@ def compile(self, student, *, teacher=None, trainset, valset=None, restrict=None
             max_labeled_demos=self.max_labeled_demos,
             teacher_settings=self.teacher_settings,
             max_rounds=self.max_rounds,
-            max_errors=self.max_errors,
         )
 
         program = optimizer.compile(student, teacher=teacher, trainset=trainset_copy)
@@ -102,7 +101,6 @@ def compile(self, student, *, teacher=None, trainset, valset=None, restrict=None
             max_labeled_demos=self.max_labeled_demos,
             teacher_settings=self.teacher_settings,
             max_rounds=self.max_rounds,
-            max_errors=self.max_errors,
         )
 
         program = optimizer.compile(student, teacher=teacher, trainset=trainset_copy)