diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py index 5a20dcff32..a484186e08 100644 --- a/dspy/adapters/chat_adapter.py +++ b/dspy/adapters/chat_adapter.py @@ -67,6 +67,9 @@ def parse(self, signature, completion, _parse_values=True): return fields + def format_turn(self, signature, values, role, incomplete=False): + return format_turn(signature, values, role, incomplete) + def format_blob(blob): if "\n" not in blob and "«" not in blob and "»" not in blob: diff --git a/dspy/clients/anyscale.py b/dspy/clients/anyscale.py new file mode 100644 index 0000000000..2c599ec7f9 --- /dev/null +++ b/dspy/clients/anyscale.py @@ -0,0 +1,324 @@ +from typing import Any, Dict, List, Optional +import json +import yaml +import os + +from dspy.utils.logging import logger +from dspy.clients.finetune import ( + FinetuneJob, + TrainingMethod, + save_data, +) +from dspy.clients.openai import openai_data_validation + +try: + # AnyScale fine-tuning requires the following additional imports + import anyscale + from anyscale.job import JobConfig +except ImportError: + anyscale = None + + +# List of training methods supported by AnyScale +TRAINING_METHODS_ANYSCALE = [ + TrainingMethod.SFT, +] + +PROVIDER_ANYSCALE = "anyscale" + + +def is_anyscale_model(model: str) -> bool: + """Check if the model is an AnyScale model.""" + # TODO: This needs to be implemented to support fine-tuning + logger.info("Is AnyScale model is not implemented, returning False as a default to not break lm.py") + return False + + +class FinetuneJobAnyScale(FinetuneJob): + + def __init__(self, *args, **kwargs): + self.job_id = None + self.model_names = None + super().__init__(*args, **kwargs) + + +def finetune_anyscale( + job: FinetuneJobAnyScale, + model: str, + train_data: List[Dict[str, Any]], + train_kwargs: Optional[Dict[str, Any]]=None, + train_method: TrainingMethod = TrainingMethod.SFT, + ) -> str: + """Start the finetune job.""" + train_kwargs = train_kwargs or {} + assert "model" not in train_kwargs, "Model should not be in the train_kwargs" + train_kwargs_copy = train_kwargs.copy() + train_kwargs_copy["model"] = model + + logger.info("[Finetune] Starting training process...") + if train_method not in TRAINING_METHODS_ANYSCALE: + raise NotImplementedError(f"AnyScale can only support {TRAINING_METHODS_ANYSCALE} for the time being") + + logger.info("[Finetune] Validating the dataset format...") + if not verify_dataset(train_data): + # TODO: Does AnyScale support text completion models? + err = "[Finetune] Error: Unable to verify that the dataset is in the correct format." + logger.error(err) + raise RuntimeError(err) + + logger.info("[Finetune] Converting data to JSONL format...") + train_data_path = save_data(train_data, provider_name=PROVIDER_ANYSCALE) + logger.info("[Finetune] Submitting data to remote storage...") + remote_train_path, _ = submit_data(train_path=train_data_path) + logger.info(f"[Finetune] Data submitted. Remote train path: {remote_train_path}") + + logger.info("[Finetune] Generating configuration files...") + _, compute_config = generate_config_files(train_path=remote_train_path, **train_kwargs_copy) + + logger.info("[Finetune] Starting remote training...") + job_id = start_remote_training(compute_config=compute_config, **train_kwargs_copy) + job.job_id = job_id + logger.info(f"[Finetune] Remote training started. 
Job ID: {job_id}") + + logger.info("[Finetune] Waiting for training to complete...") + wait_for_training(job.job_id) + logger.info("[Finetune] Training completed.") + + logger.info("[Finetune] Retrieving model information...") + model_info = get_model_info(job.job_id) + logger.info(f"[Finetune] Model info retrieved: {model_info}") + + storage_uri = model_info["storage_uri"] + logger.info(f"[Finetune] Copying LoRA weights from {storage_uri}...") + model_names, lora_dynamic_path = copy_lora_weights(storage_uri, model_info, job.job_id) + logger.info(f"[Finetune] LoRA weights copied. Model names: {model_names}") + + + logger.info("[Finetune] Setting result in future object...") + model_step_pairs = sorted([(model_name, int(model_name.split("-")[-1])) for model_name in model_names], key=lambda x: x[1]) + last_model_checkpoint = model_step_pairs[-1][0] + logger.info("[Finetune] Training process completed successfully.") + + logger.info("[Finetune] Updating model config with the proper dynamic path") + serve_config_path = train_kwargs.pop("serve_config_path", "serve_1B.yaml") + update_model_config(lora_dynamic_path, serve_config_path, job_id) + job.model_names = model_names + + return last_model_checkpoint + +def wait_for_training(job_id): + """Wait for the training to complete.""" + anyscale.job.wait(id=job_id, timeout_s=18000) + + +def update_model_config(lora_dynamic_path: str, serve_config_path: str, job_id: str): + """Update the model config storage location with the job_id.""" + with open(serve_config_path, "r") as f: + serve_config = yaml.safe_load(f) + + model_config_location = serve_config["applications"][0]["args"]["llm_configs"][0] + + with open(model_config_location, "r") as f: + model_config = yaml.safe_load(f) + + dynamic_path_until_job_id = lora_dynamic_path.split(job_id)[0] + job_id + model_config["lora_config"]["dynamic_lora_loading_path"] = dynamic_path_until_job_id + + with open(model_config_location, "w") as f: + yaml.safe_dump(model_config, f) + + +def verify_dataset(dataset: List[dict[str, Any]]) -> bool: + """Verify the training arguments before starting training.""" + dataset_validation = openai_data_validation(dataset) + + if dataset_validation: + logger.error(f"Dataset validation failed: {dataset_validation}") + return False + + return True + + +def submit_data(train_path: str): + """Upload the data to the Workspace cloud storage.""" + storage = os.environ['ANYSCALE_ARTIFACT_STORAGE'] + + datasets = {"train": train_path} + + fine_tuning_file_ids = {} + for name, path in datasets.items(): + num_items = len(read_jsonl(path)) + logger.info(f"Number of items in {name} data: {num_items}") + + remote_path = os.path.join(storage, path.split("/")[-1]) + logger.info(f"Uploading {name} data to S3 at {remote_path}") + if remote_path[:2] == "s3": + os.system(f"aws s3 cp {path} {remote_path}") + elif remote_path[:2] == "gs": + os.system(f"gcloud storage cp {path} {remote_path}") + else: + os.system(f"cp {path} {remote_path}") + logger.info(f"Copied {path} to {remote_path}") + fine_tuning_file_ids[name] = remote_path + + return fine_tuning_file_ids["train"], fine_tuning_file_ids.get("val", None) + + +def generate_config_files(train_path: str, **kwargs): + base_model_yaml_path = kwargs.get("train_config_yaml", None) + assert kwargs["model"] is not None, "Model is required to generate the config files" + + use_lora = kwargs.get("use_lora", False) + example_dir = "" + lora_path = "configs/training/lora" if use_lora else "configs/training/full_param" + + + if not base_model_yaml_path: + 
def get_yaml_config(model_name): + if "llama" in model_name.lower(): + if "70b" in model_name: + return "llama-3-70b.yaml" + elif "13b" in model_name: + return "llama-3-70b.yaml" + else: + return "llama-3-8b.yaml" + elif "mistral" in model_name.lower(): + if "mixtral" in model_name.lower(): + return "mixtral-8x7b.yaml" + else: + return "mistral-7b.yaml" + else: + raise RuntimeError("No default yaml found for the model") + + default_model_yaml_path = get_yaml_config(kwargs["model"]) + base_model_yaml_path = os.path.join(example_dir, lora_path, default_model_yaml_path) + logger.info(f"Using default yaml template for model: {base_model_yaml_path}") + + model_config_data = yaml.safe_load(open(base_model_yaml_path, "r")) + model_config_data.update(kwargs.get("hyperparameters", {})) + + model_config_data["model_id"] = kwargs["model"] + + custom_modifications = { + "model_id": kwargs["model"], + "train_path": train_path, + "logger": { + "provider": "wandb", + }, + "num_checkpoints_to_keep": 10 + } + if kwargs.get("output_dir", None): + custom_modifications["output_dir"] = kwargs["output_dir"] + + model_config_data.update(custom_modifications) + model_config_data = {k: v for k, v in model_config_data.items() if v is not None} + + def freeze(d): + if isinstance(d, dict): + return tuple(sorted((key, freeze(value)) for key, value in d.items())) + elif isinstance(d, list): + return tuple(freeze(value) for value in sorted(d)) + elif isinstance(d, set): + return tuple(freeze(value) for value in sorted(d)) + return d + + def hash_dict(d): + return hash(freeze(d)) + dict_sorted_hash = hash_dict(model_config_data) + if dict_sorted_hash < 0: + dict_sorted_hash = -dict_sorted_hash + filename = f"model_config_dspy_{dict_sorted_hash}.yaml" + logger.info(f"Model config data: {model_config_data}") + yaml.safe_dump(model_config_data, open(filename, "w")) + + ft_path = os.path.join("utils", "ft.py") + + compute_config_dict = { + "name": "dspy-llmforge-fine-tuning-job", + "entrypoint": f"llmforge anyscale finetune {filename}", + "working_dir": ".", + "image_uri": "localhost:5555/anyscale/llm-forge:0.5.6", + "requirements": [ + "wandb", + ], + "env_vars": { + "WANDB_API_KEY": os.environ.get("WANDB_API_KEY", ""), + "HF_TOKEN": os.environ.get("HF_TOKEN", ""), + "HF_HOME": os.environ.get("HF_HOME", ""), + } + } + compute_config_kwargs = kwargs.get("compute_config", {}) + compute_config_dict.update(compute_config_kwargs) + compute_config = JobConfig(**compute_config_dict) + + job_runner_config_path = kwargs.get("compute_yaml_path", "job_runner_config.yaml") + + return job_runner_config_path, compute_config + + +def start_remote_training(compute_config, **kwargs) -> str: + job_id: str = anyscale.job.submit(compute_config) + return job_id + + +def wait_for_training(job_id): + logger.info("Waiting for training to complete") + anyscale.job.wait(id=job_id, timeout_s=18000) + + +def get_model_info(job_id): + return anyscale.llm.model.get(job_id=job_id).to_dict() + + +def copy_lora_weights(storage_uri, model_info, job_id): + try: + from google.cloud import storage + + storage_client = storage.Client() + + bucket_name = storage_uri.split('/')[2] + source_folder = '/'.join(storage_uri.split('/')[3:-1]) + logger.info(f"Source folder: {source_folder}") + + bucket = storage_client.bucket(bucket_name) + + blobs = bucket.list_blobs(prefix=source_folder) + + subfolders = set() + for blob in blobs: + if '/' in blob.name[len(source_folder):]: + subfolder = blob.name.split('/')[:-1] + subfolders.add('/'.join(subfolder)) + + base_model_id 
= model_info["base_model_id"] + lora_dynamic_path = f"dspy/lora_weights/{job_id}/{base_model_id}" + + model_names = [] + for subfolder in subfolders: + subfolder_name = subfolder.split('/')[-1] + destination_folder = f"{lora_dynamic_path}:{subfolder_name}" + if subfolder_name.startswith("epoch"): + model_names.append("/".join(destination_folder.split("/")[-2:])) + else: + continue + + subfolder_blobs = bucket.list_blobs(prefix=subfolder) + + for blob in subfolder_blobs: + source_blob = bucket.blob(blob.name) + destination_blob_name = f"{destination_folder}/{blob.name.split('/')[-1]}" + bucket.copy_blob(source_blob, bucket, destination_blob_name) + logger.info(f"Copied {source_blob.name} to {destination_blob_name}") + + logger.info(f"All subfolders copied to: gs://{bucket_name}/{lora_dynamic_path}") + completed_path = f"gs://{bucket_name}/{lora_dynamic_path}" + return model_names, completed_path + + except Exception as e: + logger.error(f"An error occurred: {str(e)}") + raise e + + +def read_jsonl(filename): + with open(filename, "r") as f: + return [json.loads(line) for line in f] diff --git a/dspy/clients/finetune.py b/dspy/clients/finetune.py new file mode 100644 index 0000000000..d9d0df8e8e --- /dev/null +++ b/dspy/clients/finetune.py @@ -0,0 +1,129 @@ +import os +from abc import abstractmethod +from concurrent.futures import Future +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional + +import ujson +from datasets.fingerprint import Hasher + + +def get_finetune_directory() -> str: + """Get the directory to save the fine-tuned models.""" + # TODO: Move to a centralized location with all the other env variables + dspy_cachedir = os.environ.get("DSPY_CACHEDIR") + dspy_cachedir = dspy_cachedir or os.path.join(Path.home(), ".dspy_cache") + finetune_dir = os.path.join(dspy_cachedir, "finetune") + finetune_dir = os.path.abspath(finetune_dir) + return finetune_dir + + +FINETUNE_DIRECTORY = get_finetune_directory() + + +class TrainingMethod(str, Enum): + """Enum class for training methods. + + When comparing enums, Python checks for object IDs, which means that the + enums can't be compared directly. Subclassing the Enum class along with the + str class allows for direct comparison of the enums. 
+ """ + + SFT = "SFT" + Preference = "Preference" + + +class TrainingStatus(str, Enum): + """Enum class for remote training status.""" + + not_started = "not_started" + pending = "pending" + running = "running" + succeeded = "succeeded" + failed = "failed" + cancelled = "cancelled" + + +"""Dictionary mapping training methods to the data keys they require.""" +TRAINING_METHOD_TO_DATA_KEYS = { + TrainingMethod.SFT: ["prompt", "completion"], + TrainingMethod.Preference: ["prompt", "chosen", "rejected"], +} + + +class FinetuneJob(Future): + def __init__( + self, + model: str, + train_data: List[Dict[str, Any]], + train_kwargs: Optional[Dict[str, Any]] = None, + train_method: TrainingMethod = TrainingMethod.SFT, + provider: str = "openai", + ): + self.model = model + self.train_data = train_data + self.train_kwargs: Dict[str, Any] = train_kwargs or {} + self.train_method = train_method + self.provider = provider + super().__init__() + + def get_kwargs(self): + return dict( + model=self.model, + train_data=self.train_data, + train_kwargs=self.train_kwargs, + train_method=self.train_method, + provider=self.provider, + ) + + def __repr__(self): + return str(self) + + # Subclasses should override the cancel method to cancel the finetune job; + # then call the super's cancel method so that the future can be cancelled. + def cancel(self): + """Cancel the finetune job.""" + super().cancel() + + @abstractmethod + def status(self): + """Get the status of the finetune job.""" + raise NotImplementedError("Method `status` is not implemented.") + + +def validate_finetune_data(data: List[Dict[str, Any]], train_method: TrainingMethod): + """Validate the finetune data based on the training method.""" + # Get the required data keys for the training method + required_keys = TRAINING_METHOD_TO_DATA_KEYS[train_method] + + # Check if the training data has the required keys + for ind, data_dict in enumerate(data): + if not all([key in data_dict for key in required_keys]): + raise ValueError( + f"The datapoint at index {ind} is missing the keys required for {train_method} training. 
Expected: " + f"{required_keys}, Found: {data_dict.keys()}" + ) + + +def save_data( + data: List[Dict[str, Any]], + provider_name: Optional[str] = None, +) -> str: + """Save the fine-tuning data to a file.""" + # Construct the file name based on the data hash + hash = Hasher.hash(data) + file_name = f"{hash}.jsonl" + file_name = f"{provider_name}_{file_name}" if provider_name else file_name + + # Find the directory to save the fine-tuning data + finetune_parent_dir = get_finetune_directory() + os.makedirs(finetune_parent_dir, exist_ok=True) + + # Save the data to a file + file_path = os.path.join(finetune_parent_dir, file_name) + file_path = os.path.abspath(file_path) + with open(file_path, "w") as f: + for item in data: + f.write(ujson.dumps(item) + "\n") + return file_path diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index 186c622b9f..f2949372ac 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -1,15 +1,25 @@ +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime import functools import os -import uuid -from datetime import datetime from pathlib import Path +from typing import Any, Dict, List, Optional +import ujson +import uuid + +from dspy.utils.logging import logger +from dspy.clients.finetune import FinetuneJob, TrainingMethod +from dspy.clients.lm_finetune_utils import ( + get_provider_finetune_job_class, + execute_finetune_job, +) import litellm -import ujson from litellm.caching import Cache -disk_cache_dir = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache") -litellm.cache = Cache(disk_cache_dir=disk_cache_dir, type="disk") + +DISK_CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache") +litellm.cache = Cache(disk_cache_dir=DISK_CACHE_DIR, type="disk") litellm.telemetry = False if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ: @@ -17,13 +27,26 @@ class LM: - def __init__(self, model, model_type="chat", temperature=0.0, max_tokens=1000, cache=True, **kwargs): + def __init__( + self, + model, + model_type='chat', + temperature=0.0, + max_tokens=1000, + cache=True, + launch_kwargs=None, + **kwargs + ): + # Remember to update LM.copy() if you modify the constructor! self.model = model self.model_type = model_type self.cache = cache + self.launch_kwargs = launch_kwargs or {} self.kwargs = dict(temperature=temperature, max_tokens=max_tokens, **kwargs) self.history = [] + # TODO: Arbitrary model strings could include the substring "o1-". We + # should find a more robust way to check for the "o1-" family models. if "o1-" in model: assert ( max_tokens >= 5000 and temperature == 1.0 @@ -59,10 +82,56 @@ def __call__(self, prompt=None, messages=None, **kwargs): self.history.append(entry) return outputs - + def inspect_history(self, n: int = 1): _inspect_history(self, n) + def launch(self): + """Send a request to the provider to launch the model, if needed.""" + msg = f"`launch()` is called for the auto-launched model {self.model}" + msg += " -- no action is taken!" + logger.info(msg) + + def kill(self): + """Send a request to the provider to kill the model, if needed.""" + msg = f"`kill()` is called for the auto-launched model {self.model}" + msg += " -- no action is taken!" 
+ logger.info(msg) + + def finetune( + self, + train_data: List[Dict[str, Any]], + train_kwargs: Optional[Dict[str, Any]]=None, + train_method: TrainingMethod = TrainingMethod.SFT, + provider: str = "openai", + cache_finetune: bool = True, + ) -> FinetuneJob: + """Start model fine-tuning, if supported.""" + from dspy import settings as settings + err = "Fine-tuning is an experimental feature." + err += " Set `dspy.settings.experimental` to `True` to use it." + assert settings.experimental, err + + FinetuneJobClass = get_provider_finetune_job_class(provider=provider) + finetune_job = FinetuneJobClass( + model=self.model, + train_data=train_data, + train_kwargs=train_kwargs, + train_method=train_method, + provider=provider + ) + + executor = ThreadPoolExecutor(max_workers=1) + executor.submit( + execute_finetune_job, + finetune_job, + lm=self, + cache_finetune=cache_finetune + ) + executor.shutdown(wait=False) + + return finetune_job + def copy(self, **kwargs): """Returns a copy of the language model with possibly updated parameters.""" @@ -99,6 +168,7 @@ def litellm_text_completion(request, cache={"no-cache": True, "no-store": True}) kwargs = ujson.loads(request) # Extract the provider and model from the model string. + # TODO: Not all the models are in the format of "provider/model" model = kwargs.pop("model").split("/", 1) provider, model = model[0] if len(model) > 1 else "openai", model[-1] diff --git a/dspy/clients/lm_finetune_utils.py b/dspy/clients/lm_finetune_utils.py new file mode 100644 index 0000000000..c9b4863d03 --- /dev/null +++ b/dspy/clients/lm_finetune_utils.py @@ -0,0 +1,91 @@ +from typing import Any, Dict, List, Optional, Type, Union + +from dspy.clients.anyscale import FinetuneJobAnyScale, finetune_anyscale +from dspy.clients.finetune import FinetuneJob, TrainingMethod +from dspy.clients.openai import FinetuneJobOpenAI, finetune_openai +from dspy.utils.logging import logger + +_PROVIDER_ANYSCALE = "anyscale" +_PROVIDER_OPENAI = "openai" + + +def get_provider_finetune_job_class(provider: str) -> Type[FinetuneJob]: + """Get the FinetuneJob class for the provider.""" + provider_to_job_class = { + _PROVIDER_ANYSCALE: FinetuneJobAnyScale, + _PROVIDER_OPENAI: FinetuneJobOpenAI, + } + return provider_to_job_class[provider] + + +def get_provider_finetune_function(provider: str) -> callable: + """Return the finetune function for the given model.""" + provider_to_finetune_function = { + _PROVIDER_ANYSCALE: finetune_anyscale, + _PROVIDER_OPENAI: finetune_openai, + } + return provider_to_finetune_function[provider] + + +# Note: Type of LM should be LM. We aren't importing it here to avoid +# circular imports. 
+def execute_finetune_job(job: FinetuneJob, lm: Any, cache_finetune: bool = True): + """Execute the finetune job in a blocking manner.""" + try: + job_kwargs = job.get_kwargs() + if cache_finetune: + model = cached_finetune(job=job, **job_kwargs) + else: + model = finetune(job=job, **job_kwargs) + lm = lm.copy(model=model) + job.set_result(lm) + except Exception as err: + logger.error(err) + job.set_result(err) + + +# TODO: Add DiskCache, ignore job +def cached_finetune( + job, + model: str, + train_data: List[Dict[str, Any]], + train_kwargs: Optional[Dict[str, Any]] = None, + train_method: TrainingMethod = TrainingMethod.SFT, + provider: str = "openai", +) -> Union[str, Exception]: + return finetune( + job=job, + model=model, + train_data=train_data, + train_kwargs=train_kwargs, + train_method=train_method, + provider=provider, + ) + + +def finetune( + job, + model: str, + train_data: List[Dict[str, Any]], + train_kwargs: Optional[Dict[str, Any]] = None, + train_method: TrainingMethod = TrainingMethod.SFT, + provider: str = "openai", +) -> Union[str, Exception]: + """Fine-tune a new model based on the given model.""" + # Get the fine-tuning provider + try: + # Get the finetune function + provider_finetune_function = get_provider_finetune_function(provider) + + # Fine-tune a new model based on the given model + model = provider_finetune_function( + job=job, + model=model, + train_data=train_data, + train_kwargs=train_kwargs, + train_method=train_method, + ) + except Exception as err: + raise err + + return model diff --git a/dspy/clients/openai.py b/dspy/clients/openai.py new file mode 100644 index 0000000000..ec084fe5f5 --- /dev/null +++ b/dspy/clients/openai.py @@ -0,0 +1,358 @@ +import re +import time +from collections import defaultdict +from typing import Any, Dict, List, Optional, Union + +import openai + +from dspy.clients.finetune import ( + FinetuneJob, + TrainingMethod, + TrainingStatus, + save_data, + validate_finetune_data, +) +from dspy.utils.logging import logger + +# Provider name +PROVIDER_OPENAI = "openai" + + +def is_openai_model(model: str) -> bool: + """Check if the model is an OpenAI model.""" + # Filter the provider_prefix, if exists + provider_prefix = f"{PROVIDER_OPENAI}/" + if model.startswith(provider_prefix): + model = model[len(provider_prefix) :] + + client = openai.OpenAI() + valid_model_names = [model.id for model in client.models.list().data] + # Check if the model is a base OpenAI model + if model in valid_model_names: + return True + + # Check if the model is a fine-tuned OpneAI model. Fine-tuned OpenAI models + # have the prefix "ft::", followed by a string specifying + # the fine-tuned model. The following RegEx pattern is used to match the + # base model name. + # TODO: This part can be updated to match the actual fine-tuned model names + # by making a call to the OpenAI API to be more exact, but this might + # require an API key with the right permissions. + match = re.match(r"ft:([^:]+):", model) + if match and match.group(1) in valid_model_names: + return True + + return False + + +class FinetuneJobOpenAI(FinetuneJob): + def __init__(self, *args, **kwargs): + self.provider_file_id = None # TODO: Can we get this using the job_id? + self.provider_job_id = None + super().__init__(*args, **kwargs) + + def cancel(self): + # Cancel the provider job + if _does_job_exist(self.provider_job_id): + status = _get_training_status(self.provider_job_id) + if _is_terminal_training_status(status): + err_msg = "Jobs that are complete cannot be canceled." 
+ err_msg += f" Job with ID {self.provider_job_id} is done." + raise Exception(err_msg) + openai.fine_tuning.jobs.cancel(self.provider_job_id) + self.provider_job_id = None + + # Delete the provider file + # TODO: Should there be a separate clean method? + if self.provider_file_id is not None: + if _does_file_exist(self.provider_file_id): + openai.files.delete(self.provider_file_id) + self.provider_file_id = None + + # Call the super's cancel method after the custom cancellation logic + super().cancel() + + def status(self) -> TrainingStatus: + status = _get_training_status(self.provider_job_id) + return status + + +def finetune_openai( + job: FinetuneJobOpenAI, + model: str, + train_data: List[Dict[str, Any]], + train_kwargs: Optional[Dict[str, Any]] = None, + train_method: TrainingMethod = TrainingMethod.SFT, +) -> str: + train_kwargs = train_kwargs or {} + train_method = TrainingMethod.SFT # Note: This could be an argument; ignoring method + + # Validate train data and method + logger.info("[Finetune] Validating the formatting of the data") + _validate_data(train_data, train_method) + logger.info("[Finetune] Done!") + + # Convert to the OpenAI format + logger.info("[Finetune] Converting the data to the OpenAI format") + # TODO: Should we use the system prompt? + train_data = _convert_data(train_data) + logger.info("[Finetune] Done!") + + # Save to a file + logger.info("[Finetune] Saving the data to a file") + data_path = save_data(train_data, provider_name=PROVIDER_OPENAI) + logger.info("[Finetune] Done!") + + # Upload the data to the cloud + logger.info("[Finetune] Uploading the data to the provider") + provider_file_id = _upload_data(data_path) + job.provider_file_id = provider_file_id + logger.info("[Finetune] Done!") + + logger.info("[Finetune] Start remote training") + # We utilize model and train_kwargs here + provider_job_id = _start_remote_training( + train_file_id=job.provider_file_id, + model=model, + train_kwargs=train_kwargs, + ) + job.provider_job_id = provider_job_id + # job.provider_job_id = "ftjob-ZdEL1mUDk0dwdDuZJQOng8Vv" + logger.info("[Finetune] Done!") + + logger.info("[Finetune] Wait for training to complete") + # TODO: Would it be possible to stream the logs? + _wait_for_job(job) + logger.info("[Finetune] Done!") + + logger.info("[Finetune] Get trained model if the run was a success") + model = _get_trained_model(job) + logger.info("[Finetune] Done!") + + return model + + +_SUPPORTED_TRAINING_METHODS = [ + TrainingMethod.SFT, +] + + +def _get_training_status(job_id: str) -> Union[TrainingStatus, Exception]: + # TODO: Should this type be shared across all fine-tune clients? + provider_status_to_training_status = { + "validating_files": TrainingStatus.pending, + "queued": TrainingStatus.pending, + "running": TrainingStatus.running, + "succeeded": TrainingStatus.succeeded, + "failed": TrainingStatus.failed, + "cancelled": TrainingStatus.cancelled, + } + + # Check if there is an active job + if job_id is None: + logger.info("There is no active job.") + return TrainingStatus.not_started + + err_msg = f"Job with ID {job_id} does not exist." 
+ assert _does_job_exist(job_id), err_msg + + # Retrieve the provider's job and report the status + provider_job = openai.fine_tuning.jobs.retrieve(job_id) + provider_status = provider_job.status + status = provider_status_to_training_status[provider_status] + + return status + + +def _does_job_exist(job_id: str) -> bool: + try: + # TODO: Error handling is vague + openai.fine_tuning.jobs.retrieve(job_id) + return True + except Exception: + return False + + +def _does_file_exist(file_id: str) -> bool: + try: + # TODO: Error handling is vague + openai.files.retrieve(file_id) + return True + except Exception: + return False + + +def _is_terminal_training_status(status: TrainingStatus) -> bool: + return status in [ + TrainingStatus.succeeded, + TrainingStatus.failed, + TrainingStatus.cancelled, + ] + + +def _validate_data(data: Dict[str, str], train_method: TrainingMethod) -> Optional[Exception]: + # Check if this train method is supported + if train_method not in _SUPPORTED_TRAINING_METHODS: + err_msg = f"OpenAI does not support the training method {train_method}." + raise ValueError(err_msg) + + validate_finetune_data(data, train_method) + + +def _convert_data( + data: List[Dict[str, str]], + system_prompt: Optional[str] = None, +) -> Union[List[Dict[str, Any]], Exception]: + # Item-wise conversion function + def _row_converter(d): + messages = [{"role": "user", "content": d["prompt"]}, {"role": "assistant", "content": d["completion"]}] + if system_prompt: + messages.insert(0, {"role": "system", "content": system_prompt}) + messages_dict = {"messages": messages} + return messages_dict + + # Convert the data to the OpenAI format; validate the converted data + converted_data = list(map(_row_converter, data)) + openai_data_validation(converted_data) + return converted_data + + +def _upload_data(data_path: str) -> str: + # Upload the data to the provider + provider_file = openai.files.create( + file=open(data_path, "rb"), + purpose="fine-tune", + ) + return provider_file.id + + +def _start_remote_training(train_file_id: str, model: id, train_kwargs: Optional[Dict[str, Any]] = None) -> str: + train_kwargs = train_kwargs or {} + provider_job = openai.fine_tuning.jobs.create( + model=model, + training_file=train_file_id, + hyperparameters=train_kwargs, + ) + return provider_job.id + + +def _wait_for_job( + job: FinetuneJobOpenAI, + poll_frequency: int = 60, +): + while not _is_terminal_training_status(job.status()): + time.sleep(poll_frequency) + + +def _get_trained_model(job): + status = job.status() + if status != TrainingStatus.succeeded: + err_msg = f"Job status is {status}." + err_msg += f" Must be {TrainingStatus.succeeded} to retrieve the model." 
+ logger.error(err_msg) + raise Exception(err_msg) + + provider_job = openai.fine_tuning.jobs.retrieve(job.provider_job_id) + finetuned_model = provider_job.fine_tuned_model + return finetuned_model + + +# Adapted from https://cookbook.openai.com/examples/chat_finetuning_data_prep +def openai_data_validation(dataset: List[dict[str, Any]]): + format_errors = defaultdict(int) + for ex in dataset: + if not isinstance(ex, dict): + format_errors["data_type"] += 1 + continue + + messages = ex.get("messages", None) + if not messages: + format_errors["missing_messages_list"] += 1 + continue + + for message in messages: + if "role" not in message or "content" not in message: + format_errors["message_missing_key"] += 1 + + if any(k not in ("role", "content", "name", "function_call", "weight") for k in message): + format_errors["message_unrecognized_key"] += 1 + + if message.get("role", None) not in ("system", "user", "assistant", "function"): + format_errors["unrecognized_role"] += 1 + + content = message.get("content", None) + function_call = message.get("function_call", None) + + if (not content and not function_call) or not isinstance(content, str): + format_errors["missing_content"] += 1 + + if not any(message.get("role", None) == "assistant" for message in messages): + format_errors["example_missing_assistant_message"] += 1 + + # Raise an error if there are any format errors + if format_errors: + err_msg = "Found errors in the dataset format using the OpenAI API." + err_msg += " Here are the number of datapoints for each error type:" + for k, v in format_errors.items(): + err_msg += "\n {k}: {v}" + raise ValueError(err_msg) + + +def check_message_lengths(dataset: List[dict[str, Any]]) -> list[int]: + n_missing_system = 0 + n_missing_user = 0 + n_messages = [] + convo_lens = [] + assistant_message_lens = [] + + for ex in dataset: + messages = ex["messages"] + if not any(message["role"] == "system" for message in messages): + n_missing_system += 1 + if not any(message["role"] == "user" for message in messages): + n_missing_user += 1 + n_messages.append(len(messages)) + convo_lens.append(num_tokens_from_messages(messages)) + assistant_message_lens.append(num_assistant_tokens_from_messages(messages)) + n_too_long = sum([length > 16385 for length in convo_lens]) + + if n_too_long > 0: + logger.info( + f"There are {n_too_long} examples that may be over the 16,385 token limit, they will be truncated during fine-tuning." 
+ ) + + if n_missing_system > 0: + logger.info(f"There are {n_missing_system} examples that are missing a system message.") + + if n_missing_user > 0: + logger.info(f"There are {n_missing_user} examples that are missing a user message.") + + return convo_lens + + +def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1): + import tiktoken + + encoding = tiktoken.get_encoding("cl100k_base") + + num_tokens = 0 + for message in messages: + num_tokens += tokens_per_message + for key, value in message.items(): + num_tokens += len(encoding.encode(value)) + if key == "name": + num_tokens += tokens_per_name + num_tokens += 3 + return num_tokens + + +def num_assistant_tokens_from_messages(messages): + import tiktoken + + encoding = tiktoken.get_encoding("cl100k_base") + + num_tokens = 0 + for message in messages: + if message["role"] == "assistant": + num_tokens += len(encoding.encode(message["content"])) + return num_tokens diff --git a/dspy/predict/chain_of_thought.py b/dspy/predict/chain_of_thought.py index 74656697ca..b6b9cb575e 100644 --- a/dspy/predict/chain_of_thought.py +++ b/dspy/predict/chain_of_thought.py @@ -29,7 +29,7 @@ def __init__(self, signature, rationale_type=None, activated=True, **config): rationale_type = rationale_type or dspy.OutputField(prefix=prefix, desc=desc) # Add "rationale" field to the output signature. - if isinstance(dspy.settings.lm, dspy.LM): + if isinstance(dspy.settings.lm, dspy.LM) or dspy.settings.experimental: extended_signature = signature.prepend("reasoning", rationale_type, type_=str) else: extended_signature = signature.prepend("rationale", rationale_type, type_=str) diff --git a/dspy/primitives/program.py b/dspy/primitives/program.py index ac3646275f..3cf6220b6c 100644 --- a/dspy/primitives/program.py +++ b/dspy/primitives/program.py @@ -1,18 +1,12 @@ import magicattr +import dspy from dspy.primitives.assertions import * from dspy.primitives.module import BaseModule class ProgramMeta(type): pass - # def __call__(cls, *args, **kwargs): - # obj = super(ProgramMeta, cls).__call__(*args, **kwargs) - - # if issubclass(cls, Program) and not getattr(obj, "_program_init_called", False): - # obj._base_init() - # obj._program_init_called = True - # return obj class Module(BaseModule, metaclass=ProgramMeta): @@ -32,23 +26,29 @@ def named_predictors(self): def predictors(self): return [param for _, param in self.named_predictors()] - + def set_lm(self, lm): - import dspy - assert dspy.settings.experimental, "Setting the lm is an experimental feature." + if not dspy.settings.experimental: + raise ValueError( + "Setting or getting the LM of a program is an experimental feature. Please enable the " + "'dspy.settings.experimental' flag to use these features." + ) for _, param in self.named_predictors(): param.lm = lm def get_lm(self): - import dspy - assert dspy.settings.experimental, "Getting the lm is an experimental feature." + if not dspy.settings.experimental: + raise ValueError( + "Setting or getting the LM of a program is an experimental feature. Please enable the " + "'dspy.settings.experimental' flag to use these features." 
+ ) all_used_lms = [param.lm for _, param in self.named_predictors()] if len(set(all_used_lms)) == 1: return all_used_lms[0] - + raise ValueError("Multiple LMs are being used in the module.") def __repr__(self): @@ -95,4 +95,5 @@ def activate_assertions(self, handler=backtrack_handler, **handler_args): def set_attribute_by_name(obj, name, value): magicattr.set(obj, name, value) -Program = Module \ No newline at end of file + +Program = Module diff --git a/dspy/teleprompt/finetune_teleprompter.py b/dspy/teleprompt/finetune_teleprompter.py new file mode 100644 index 0000000000..c519035d08 --- /dev/null +++ b/dspy/teleprompt/finetune_teleprompter.py @@ -0,0 +1,146 @@ +from typing import Any, Callable, Dict, List, Optional, Union + +import dspy +from dspy.evaluate.evaluate import Evaluate +from dspy.primitives.example import Example +from dspy.primitives.prediction import Prediction +from dspy.primitives.program import Program +from dspy.utils.logging import logger + + +# TODO: Shared below are useful functions. Similar procedures are implemented +# separately and used by other DSPy teleprompters. These can be moved to shared +# locations. +def prepare_teacher(student: Program, teacher: Program = None) -> Program: + """Prepare the teacher program with respect to the student program. + Args: + student: The student program. + teacher: The teacher program. If `None`, a copy of the student program + is used as the teacher. Defaults to `None`. + """ + # If teacher is None, use a copy of the student program as the teacher + if teacher is None: + logger.info("No teacher provided. Using a copy of the student program as the teacher.") + teacher = student.deepcopy() + else: + teacher = teacher.deepcopy() + + # Ensure that the student and teacher programs have the same structure + logger.info("Ensuring that the student and teacher are are structurally equivalent.") + student._assert_structural_equivalency(teacher) + + # Ensure that the predictors of the programs point to different objects + logger.info("Ensuring that the student and teacher programs do not share predictors.") + student._assert_no_shared_predictor(teacher) + + # Ensure that the LM consistency property is satisfied + logger.info("Ensuring that the teacher program satisfies the LM consistency property.") + teacher._assert_lm_consistency() + + # If the global LM is being used, set it to the LMs of the copied teacher + # program predictors to to avoid handling the same edge cases later + if dspy.settings.lm: + teacher._set_all_predictor_lms(dspy.settings.lm) + + return teacher + + +def convert_to_module_level_message_data( + data: List[Dict], + keep_data_keys: bool = False, + exclude_demos: bool = False, + try_to_record_lm_kwargs: bool = False, + program: Program = None, +) -> List[Dict]: + """Wrapper around the function + `build_messages_from_trace`, calling it on the "trace" field + of each dictionary in the input data list and combiningin the results into + a list of prompt-completion data dictionaries.""" + + prompt_completion_data = [] + for data_dict in data: + trace = data_dict["trace"] + trace_prompt_comletion_data = build_messages_from_trace( + trace=trace, exclude_demos=exclude_demos, try_to_record_lm_kwargs=try_to_record_lm_kwargs, program=program + ) + for prompt_completion_dict in trace_prompt_comletion_data: + if keep_data_keys: + prompt_completion_dict = {**data_dict, **prompt_completion_dict} + prompt_completion_data.append(prompt_completion_dict) + return prompt_completion_data + + +def build_messages_from_trace( + trace: 
List[Dict], + exclude_demos: bool = False, + try_to_record_lm_kwargs: bool = False, + program: Program = None, +) -> Dict[str, List[Dict[str, Any]]]: + messages = [] + # If the program is provided, build the predictor index to name mapping + if program: + pred_ind_to_name = {ind: name for ind, (name, _) in enumerate(program.named_predictors())} + + # Build the prompt-completion data + + adapter = dspy.settings.adapter or dspy.ChatAdapter() + data = [] + + # TODO: Make sure that this works for multi-stage pipelines + for pred_ind, (pred, inputs, outputs) in enumerate(trace): + # Get the demos from the predictor if exclude_demos is False + demos = [] if exclude_demos else pred.demos + messages = adapter.format(pred.signature, demos, inputs) + messages.append( + adapter.format_turn(signature=pred.signature, values=outputs, role="assistant", incomplete=False) + ) + data.append(messages) + + return data + + +def bootstrap_data( + program: Program, + dataset: List[Example], + metric: Optional[Callable[[Example, Prediction, Optional[List]], Union[bool, int, float]]] = None, + num_threads=1, + max_errors: int = 0, +) -> List[Dict[str, Any]]: + """Bootstrap example, prediction, trace, example_ind, score data for the program using the dataset.""" + data = [] + + # Use Evaluate to call the program have the responses cached + cname = program.__class__.__name__ + info = f"Bootstrapping data on {len(dataset)} examples with the program {cname}, with {num_threads} threads" + logger.info(info) + evaluator = Evaluate( + devset=dataset, + num_threads=num_threads, + display_progress=True, + max_errors=max_errors, + provide_traceback=True, + ) + evaluator(program, metric=metric) + + data = [] + for example_ind, example in enumerate(dataset): + data_dict = bootstrap_one_example(example=example, example_ind=example_ind, program=program, metric=metric) + if data_dict is not None: + data.append(data_dict) + + return data + + +def bootstrap_one_example( + example: Any, example_ind: int, program: Program, metric: Optional[Callable] = None +) -> Dict[str, Any]: + with dspy.context(trace=[]): + prediction = program(**example.inputs()) + trace = dspy.settings.trace + score = metric(example, prediction, trace) if metric else None + + data_dict = {"example": example, "prediction": prediction, "trace": trace, "example_ind": example_ind} + if metric: + data_dict["score"] = score + + return data_dict diff --git a/dspy/teleprompt/random_search.py b/dspy/teleprompt/random_search.py index 57bd38150c..4fa8ba23fd 100644 --- a/dspy/teleprompt/random_search.py +++ b/dspy/teleprompt/random_search.py @@ -85,6 +85,7 @@ def compile(self, student, *, teacher=None, trainset, valset=None, restrict=None max_labeled_demos=self.max_labeled_demos, teacher_settings=self.teacher_settings, max_rounds=self.max_rounds, + max_errors=self.max_errors, ) program = optimizer.compile(student, teacher=teacher, trainset=trainset_copy) @@ -101,6 +102,7 @@ def compile(self, student, *, teacher=None, trainset, valset=None, restrict=None max_labeled_demos=self.max_labeled_demos, teacher_settings=self.teacher_settings, max_rounds=self.max_rounds, + max_errors=self.max_errors, ) program = optimizer.compile(student, teacher=teacher, trainset=trainset_copy)
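
A minimal end-to-end usage sketch (not part of the patch) of the experimental fine-tuning surface introduced above, assuming the APIs land as shown. The model id, training rows, and n_epochs value are illustrative placeholders; note that on failure, execute_finetune_job stores the exception as the future's result rather than raising it.

import dspy
from dspy.clients.finetune import TrainingMethod

# LM.finetune asserts dspy.settings.experimental, so enable the flag first.
dspy.settings.configure(experimental=True)

lm = dspy.LM("gpt-4o-mini-2024-07-18")  # placeholder: any fine-tunable OpenAI model id

# SFT rows must carry the keys listed in TRAINING_METHOD_TO_DATA_KEYS[TrainingMethod.SFT].
train_data = [
    {"prompt": "What is 2 + 2?", "completion": "4"},
    {"prompt": "Name the capital of France.", "completion": "Paris"},
]

# finetune() returns a FinetuneJob (a concurrent.futures.Future); training runs in a
# background thread via execute_finetune_job, and the result is a copy of this LM
# pointing at the fine-tuned model.
job = lm.finetune(
    train_data=train_data,
    train_kwargs={"n_epochs": 1},  # forwarded to OpenAI as the job hyperparameters
    train_method=TrainingMethod.SFT,
    provider="openai",
)

finetuned_lm = job.result()  # blocks until the background thread finishes
print(finetuned_lm.model)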
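
For reference, a sketch of what one SFT row looks like after _convert_data reshapes it into the chat schema that openai_data_validation checks. The content strings are the placeholder rows from the sketch above; a system message is present only when a system_prompt is supplied.

# One converted record (the shape written out by save_data):
record = {
    "messages": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ]
}
# save_data() writes one such JSON object per line to
# <DSPY_CACHEDIR>/finetune/openai_<data-hash>.jsonl, which _upload_data then uploads
# with purpose="fine-tune" before _start_remote_training creates the provider job.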
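
A sketch of how the new finetune_teleprompter helpers could produce that training data from a DSPy program, assuming a configured LM and a small labeled devset; the program signature, devset rows, and exact_match metric below are hypothetical.

import dspy
from dspy.teleprompt.finetune_teleprompter import (
    bootstrap_data,
    convert_to_module_level_message_data,
)

dspy.settings.configure(lm=dspy.LM("gpt-4o-mini-2024-07-18"))  # placeholder model id

program = dspy.ChainOfThought("question -> answer")
devset = [
    dspy.Example(question="What is 2 + 2?", answer="4").with_inputs("question"),
    dspy.Example(question="Name the capital of France.", answer="Paris").with_inputs("question"),
]

def exact_match(example, prediction, trace=None):
    # Placeholder metric: exact string match on the answer field.
    return example.answer == prediction.answer

# bootstrap_data runs the program over the devset (caching responses via Evaluate) and
# records {"example", "prediction", "trace", "example_ind", "score"} for each row.
trace_data = bootstrap_data(program, devset, metric=exact_match, num_threads=2)

# convert_to_module_level_message_data turns each trace into chat transcripts using the
# format_turn hook added to ChatAdapter in this patch; each element holds the formatted
# prompt turns for one predictor call followed by the assistant turn.
messages_data = convert_to_module_level_message_data(trace_data, program=program)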
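
Finally, a small note on naming: fine-tuned OpenAI model ids carry an "ft:" prefix, which is what the regex in is_openai_model parses to recover the base model name. The example id below is made up.

import re

example_id = "ft:gpt-4o-mini-2024-07-18:my-org::abc123"  # hypothetical fine-tuned model id
match = re.match(r"ft:([^:]+):", example_id)
assert match and match.group(1) == "gpt-4o-mini-2024-07-18"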