diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index ddb0ee3e80..3467ca31fb 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -10,5 +10,5 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True): value = self.parse(signature, output, _parse_values=_parse_values) assert set(value.keys()) == set(signature.output_fields.keys()), f"Expected {signature.output_fields.keys()} but got {value.keys()}" values.append(value) - + return values diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py index 5a20dcff32..4b6186c384 100644 --- a/dspy/adapters/chat_adapter.py +++ b/dspy/adapters/chat_adapter.py @@ -62,11 +62,28 @@ def parse(self, signature, completion, _parse_values=True): f"Error parsing field {k}: {e}.\n\n\t\tOn attempting to parse the value\n```\n{v}\n```" ) - if fields.keys() != signature.output_fields.keys(): - raise ValueError(f"Expected {signature.output_fields.keys()} but got {fields.keys()}") + if list(fields.keys()) != list(signature.output_fields.keys()): + raise ValueError(f"Expected {list(signature.output_fields.keys())} but got {fields.keys()}") return fields + def format_completion(self, signature, outputs): + reconstructed = [] + + fields_dict = signature.output_fields + + field_name_output_map = {field: outputs[field] for field in fields_dict.keys()} + + for field, value in field_name_output_map.items(): + reconstructed.append(f"[[ ## {field} ## ]]") + reconstructed.append(str(value)) + reconstructed.append("") # Add an empty line for separation + + reconstructed.append("[[ ## completed ## ]]") + + result = "\n".join(reconstructed).strip() + + return result def format_blob(blob): if "\n" not in blob and "«" not in blob and "»" not in blob: diff --git a/dspy/clients/anyscale.py b/dspy/clients/anyscale.py index dea3ca4c5e..5844abe55b 100644 --- a/dspy/clients/anyscale.py +++ b/dspy/clients/anyscale.py @@ -1,13 +1,27 @@ -import time -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union +import json +import yaml +import os +from concurrent.futures import ThreadPoolExecutor +from collections import defaultdict +import functools from dspy.utils.logging import logger from dspy.clients.finetune import ( FinetuneJob, TrainingMethod, - get_finetune_directory, - validate_finetune_data, + save_data, ) +from dspy.clients.openai import openai_data_validation, check_message_lengths +import asyncio + +try: + # Importing the AnyScale library for users in the AnyScale workspace, where + # the library is already installed. 
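For context on the `format_completion` method added to `ChatAdapter` above, the sketch below reproduces the string it builds from a signature's output fields. It is illustrative only: the field names (`reasoning`, `answer`) are invented, and the real method reads field names from `signature.output_fields`.

```python
# Minimal sketch of the completion string built by ChatAdapter.format_completion.
# Field names here are hypothetical; the real method takes them from the signature.
def format_completion_sketch(output_values: dict) -> str:
    lines = []
    for field, value in output_values.items():
        lines.append(f"[[ ## {field} ## ]]")
        lines.append(str(value))
        lines.append("")  # blank line between fields
    lines.append("[[ ## completed ## ]]")
    return "\n".join(lines).strip()

print(format_completion_sketch({"reasoning": "2 + 2 equals 4.", "answer": "4"}))
# [[ ## reasoning ## ]]
# 2 + 2 equals 4.
#
# [[ ## answer ## ]]
# 4
#
# [[ ## completed ## ]]
```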
+ from anyscale.job import JobConfig + import anyscale +except ImportError: + anyscale = None #------------------------------------------------------------------------------- # Variables @@ -18,6 +32,9 @@ TrainingMethod.SFT, ] +PROVIDER_ANYSCALE = "anyscale" + + #------------------------------------------------------------------------------- # Launching and killing LMs #------------------------------------------------------------------------------- @@ -38,7 +55,19 @@ def anyscale_model_kill(model: str, launch_kwargs: Dict[str, Any]): # Function and classes required for the fine-tune interface #------------------------------------------------------------------------------- + +def is_anyscale_model(model: str) -> bool: + """Check if the model is an AnyScale model.""" + # TODO: This needs to be implemented to support fine-tuning + logger.info("Is AnyScale model is not implemented, returning False as a default to not break lm.py") + return False + + class FinetuneJobAnyScale(FinetuneJob): + + def __init__(self, *args, **kwargs): + self.job_id = None + super().__init__(*args, **kwargs) def cancel(self): """Cancel the finetune job.""" @@ -52,56 +81,304 @@ def status(self): raise NotImplementedError("Method `status` is not implemented.") -def is_anyscale_model(model: str) -> bool: - """Check if the model is an AnyScale model.""" - logger.info("Is AnyScale model is not implemented, returning False as a default to not break lm.py") - return False - - def finetune_anyscale( job: FinetuneJobAnyScale, model: str, - message_completion_pairs: List[Dict[str, str]], + train_data: List[Dict[str, Any]], train_kwargs: Optional[Dict[str, Any]]=None, - ) -> str: - """Fine-tune with AnyScale.""" - # Fake fine-tuning - logger.info("[Finetune] Fake fine-tuning") - train_kwargs = train_kwargs or {} + ) -> Union[str, Exception]: + """Start the finetune job.""" + try: + train_kwargs = train_kwargs or {} + assert "model" not in train_kwargs, "Model should not be in the train_kwargs" + train_kwargs["model"] = model + train_method = TrainingMethod.SFT # Note: This could be an argument + + logger.info("[Finetune] Starting training process...") + if train_method not in TRAINING_METHODS_ANYSCALE: + raise NotImplementedError(f"AnyScale can only support {TRAINING_METHODS_ANYSCALE} for the time being") + + logger.info("[Finetune] Retrieving the `serve_config_path` if provided, otherwise defaulting to `serve_1B.yaml`...") + serve_config_path = train_kwargs.pop("serve_config_path", "serve_1B.yaml") + + logger.info("[Finetune] Validating the dataset format...") + if not verify_dataset(train_data): + # TODO: Does AnyScale support text completion models? + err = "[Finetune] Error: Unable to verify that the dataset is in the correct format." + logger.error(err) + raise RuntimeError(err) + + logger.info("[Finetune] Converting data to JSONL format...") + data_path = save_data(train_data, provider_name=PROVIDER_ANYSCALE) + + logger.info("[Finetune] Submitting data to remote storage...") + remote_train_path, _ = submit_data(train_path=data_path, eval_path=None) + logger.info(f"[Finetune] Data submitted. Remote train path: {remote_train_path}") + + logger.info("[Finetune] Generating configuration files...") + _, compute_config = generate_config_files(train_path=remote_train_path, eval_path=None, **train_kwargs) + + logger.info("[Finetune] Starting remote training...") + job_id = start_remote_training(compute_config=compute_config, **train_kwargs) + job.job_id = job_id + logger.info(f"[Finetune] Remote training started. 
Job ID: {job.job_id}") + + logger.info("[Finetune] Waiting for training to complete...") + wait_for_training(job.job_id) + logger.info("[Finetune] Training completed.") + + logger.info("[Finetune] Retrieving model information...") + model_info = get_model_info(job.job_id) + logger.info(f"[Finetune] Model info retrieved: {model_info}") + + storage_uri = model_info["storage_uri"] + logger.info(f"[Finetune] Copying LoRA weights from {storage_uri}...") + model_names, lora_dynamic_path = copy_lora_weights(storage_uri, model_info, job.job_id) + logger.info(f"[Finetune] LoRA weights copied. Model names: {model_names}") + + logger.info("[Finetune] Setting result in future object...") + last_model_checkpoint = model_names[-1] # TODO: Is this correct? + logger.info("[Finetune] Training process completed successfully.") + + # TODO: Is this the right place to call update_model_config? + logger.info("[Finetune] Updating model config with the proper dynamic path") + update_model_config(lora_dynamic_path, serve_config_path) + + return last_model_checkpoint + + except Exception as e: + logger.error(f"[Finetune] Error occurred during training: {str(e)}") + raise e + + +#------------------------------------------------------------------------------- +# Custom functions to support the finetune_* function +#------------------------------------------------------------------------------- + +def wait_for_training(job_id): + """Wait for the training to complete.""" + anyscale.job.wait(id=job_id, timeout_s=18000) + + +def update_model_config(lora_dynamic_path: str, serve_config_path: str): + """Update the model config storage location with the job_id.""" + with open(serve_config_path, "r") as f: + serve_config = yaml.safe_load(f) + + model_config_location = serve_config["applications"][0]["args"]["llm_configs"][0] + + with open(model_config_location, "r") as f: + model_config = yaml.safe_load(f) + + model_config["lora_config"]["dynamic_lora_loading_path"] = lora_dynamic_path + + with open(model_config_location, "w") as f: + yaml.safe_dump(model_config, f) + +def verify_dataset(dataset: List[dict[str, Any]]) -> bool: + """Verify the training arguments before starting training.""" + dataset_validation = openai_data_validation(dataset) + + if dataset_validation: + logger.error(f"Dataset validation failed: {dataset_validation}") + return False + + # TODO: Is this still useful? 
+ convo_lens = check_message_lengths(dataset) + return True + + +def submit_data(train_path: str, eval_path: Optional[str]): + """Upload the data to the Workspace cloud storage.""" + storage = os.environ['ANYSCALE_ARTIFACT_STORAGE'] + + datasets = {"train": train_path} + if eval_path: + datasets["val"] = eval_path + + fine_tuning_file_ids = {} + for name, path in datasets.items(): + num_items = len(read_jsonl(path)) + logger.info(f"Number of items in {name} data: {num_items}") + + remote_path = os.path.join(storage, path.split("/")[-1]) + logger.info(f"Uploading {name} data to S3 at {remote_path}") + if remote_path[:2] == "s3": + os.system(f"aws s3 cp {path} {remote_path}") + elif remote_path[:2] == "gs": + os.system(f"gcloud storage cp {path} {remote_path}") + else: + os.system(f"cp {path} {remote_path}") + logger.info(f"Copied {path} to {remote_path}") + fine_tuning_file_ids[name] = remote_path + + return fine_tuning_file_ids["train"], fine_tuning_file_ids.get("val", None) + + +def generate_config_files(train_path: str, eval_path: Optional[str], **kwargs): + base_model_yaml_path = kwargs.get("train_config_yaml", None) + assert kwargs["model"] is not None, "Model is required to generate the config files" + + use_lora = kwargs.get("use_lora", False) + example_dir = "" + lora_path = "configs/training/lora" if use_lora else "configs/training/full_param" + + + if not base_model_yaml_path: + def get_yaml_config(model_name): + if "llama" in model_name.lower(): + if "70b" in model_name: + return "llama-3-70b.yaml" + elif "13b" in model_name: + return "llama-3-70b.yaml" + else: + return "llama-3-8b.yaml" + elif "mistral" in model_name.lower(): + if "mixtral" in model_name.lower(): + return "mixtral-8x7b.yaml" + else: + return "mistral-7b.yaml" + else: + raise RuntimeError("No default yaml found for the model") + + default_model_yaml_path = get_yaml_config(kwargs["model"]) + base_model_yaml_path = os.path.join(example_dir, lora_path, default_model_yaml_path) + logger.info(f"Using default yaml template for model: {base_model_yaml_path}") + + model_config_data = yaml.safe_load(open(base_model_yaml_path, "r")) + model_config_data.update(kwargs.get("hyperparameters", {})) + + model_config_data["model_id"] = kwargs["model"] + + custom_modifications = { + "model_id": kwargs["model"], + "train_path": train_path, + "valid_path": eval_path, + "logger": { + "provider": "wandb", + }, + "num_checkpoints_to_keep": 10 + } + if kwargs.get("output_dir", None): + custom_modifications["output_dir"] = kwargs["output_dir"] + + model_config_data.update(custom_modifications) + model_config_data = {k: v for k, v in model_config_data.items() if v is not None} + + def freeze(d): + if isinstance(d, dict): + return tuple(sorted((key, freeze(value)) for key, value in d.items())) + elif isinstance(d, list): + return tuple(freeze(value) for value in sorted(d)) + elif isinstance(d, set): + return tuple(freeze(value) for value in sorted(d)) + return d + + def hash_dict(d): + return hash(freeze(d)) + dict_sorted_hash = hash_dict(model_config_data) + if dict_sorted_hash < 0: + dict_sorted_hash = -dict_sorted_hash + filename = f"model_config_dspy_{dict_sorted_hash}.yaml" + logger.info(f"Model config data: {model_config_data}") + yaml.safe_dump(model_config_data, open(filename, "w")) + + ft_path = os.path.join("utils", "ft.py") + + compute_config_dict = { + "name": "dspy-llmforge-fine-tuning-job", + "entrypoint": f"llmforge anyscale finetune {filename}", + "working_dir": ".", + "image_uri": "localhost:5555/anyscale/llm-forge:0.5.6", 
+ "requirements": [ + "wandb", + ], + "env_vars": { + "WANDB_API_KEY": os.environ.get("WANDB_API_KEY", ""), + "HF_TOKEN": os.environ.get("HF_TOKEN", ""), + "HF_HOME": os.environ.get("HF_HOME", ""), + } + } + compute_config_kwargs = kwargs.get("compute_config", {}) + compute_config_dict.update(compute_config_kwargs) + compute_config = JobConfig(**compute_config_dict) + + job_runner_config_path = kwargs.get("compute_yaml_path", "job_runner_config.yaml") + + return job_runner_config_path, compute_config + + +def start_remote_training(compute_config, **kwargs) -> str: + job_id: str = anyscale.job.submit(compute_config) + return job_id + + +def wait_for_training(job): + logger.info("Waiting for training to complete") + anyscale.job.wait(id=job.job_id, timeout_s=18000) + + +def get_model_info(job_id): + return anyscale.llm.model.get(job_id=job_id).to_dict() + + +def copy_lora_weights(storage_uri, model_info, job_id): try: - logger.info("[Finetune] Validate the formatting of the fine-tuning data") - training_method = TrainingMethod.SFT # Hardcoded - validate_finetune_data(message_completion_pairs, training_method) - time.sleep(5) - logger.info("[Finetune] Done!") - - logger.info("[Finetune] Saving the data to a file") - finetune_parent_dir = get_finetune_directory() - # We utilize message_completion_pairs here - time.sleep(1) - logger.info("[Finetune] Done!") - - logger.info("[Finetune] Uploading the data to the cloud") - time.sleep(1) - logger.info("[Finetune] Done!") - - logger.info("[Finetune] Launch training") - # We utilize model and train_kwargs here - time.sleep(1) - logger.info("[Finetune] Done!") - - logger.info("[Finetune] Wait for training to complete") - time.sleep(1) - logger.info("[Finetune] Done!") - - logger.info("[Finetune] Get trained model client") - model = "anyscale_model" # Hardcoded - time.sleep(1) - logger.info("[Finetune] Done!") - - return model + from google.cloud import storage + + storage_client = storage.Client() + + bucket_name = storage_uri.split('/')[2] + source_folder = '/'.join(storage_uri.split('/')[3:-1]) + logger.info(f"Source folder: {source_folder}") + bucket = storage_client.bucket(bucket_name) + + blobs = bucket.list_blobs(prefix=source_folder) + + subfolders = set() + for blob in blobs: + if '/' in blob.name[len(source_folder):]: + subfolder = blob.name.split('/')[:-1] + subfolders.add('/'.join(subfolder)) + + base_model_id = model_info["base_model_id"] + lora_dynamic_path = f"dspy/lora_weights/{job_id}/{base_model_id}" + + model_names = [] + for subfolder in subfolders: + subfolder_name = subfolder.split('/')[-1] + destination_folder = f"{lora_dynamic_path}:{subfolder_name}" + if subfolder_name.startswith("epoch"): + model_names.append("/".join(destination_folder.split("/")[-2:])) + else: + continue + + subfolder_blobs = bucket.list_blobs(prefix=subfolder) + + for blob in subfolder_blobs: + source_blob = bucket.blob(blob.name) + destination_blob_name = f"{destination_folder}/{blob.name.split('/')[-1]}" + bucket.copy_blob(source_blob, bucket, destination_blob_name) + logger.info(f"Copied {source_blob.name} to {destination_blob_name}") + + logger.info(f"All subfolders copied to: gs://{bucket_name}/{lora_dynamic_path}") + completed_path = f"gs://{bucket_name}/{lora_dynamic_path}" + return model_names, completed_path + except Exception as e: - logger.error(f"[Finetune] Error: {e}") - raise e + logger.error(f"An error occurred: {str(e)}") + raise e + + +def read_jsonl(filename): + with open(filename, "r") as f: + return [json.loads(line) for line in f] + + 
+# TODO: Is it true that we don't need this function anymore? +def write_jsonl(filename, data): + with open(filename, "w") as f: + for item in data: + f.write(json.dumps(item) + "\n") diff --git a/dspy/clients/finetune.py b/dspy/clients/finetune.py index 3c14be3630..c4464fbf30 100644 --- a/dspy/clients/finetune.py +++ b/dspy/clients/finetune.py @@ -2,6 +2,7 @@ from concurrent.futures import Future from enum import Enum import os +from pathlib import Path from typing import List, Dict, Any, Optional import ujson @@ -12,8 +13,12 @@ # Set the directory to save the fine-tuned models def get_finetune_directory() -> str: """Get the directory to save the fine-tuned models.""" - alternative_path = os.path.join(os.getcwd(), '.dspy_finetune') - return os.environ.get('DSPY_FINETUNEDIR') or alternative_path + # TODO: Move to a centralized location with all the other env variables + dspy_cachedir = os.environ.get("DSPY_CACHEDIR") + dspy_cachedir = dspy_cachedir or os.path.join(Path.home(), ".dspy_cache") + finetune_dir = os.path.join(dspy_cachedir, 'finetune') + finetune_dir = os.path.abspath(finetune_dir) + return finetune_dir FINETUNE_DIRECTORY = get_finetune_directory() @@ -51,18 +56,18 @@ class FinetuneJob(Future): def __init__(self, model: str, - message_completion_pairs: List[Dict[str, str]], + train_data: List[Dict[str, Any]], train_kwargs: Optional[Dict[str, Any]]=None, ): self.model = model - self.message_completion_pairs = message_completion_pairs + self.train_data = train_data self.train_kwargs: Dict[str, Any] = train_kwargs or {} super().__init__() def get_kwargs(self): return dict( model=self.model, - message_completion_pairs=self.message_completion_pairs, + train_data=self.train_data, train_kwargs=self.train_kwargs, ) diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py index 014025efc9..e6cfcff3df 100644 --- a/dspy/clients/lm.py +++ b/dspy/clients/lm.py @@ -24,6 +24,7 @@ import litellm from litellm.caching import Cache + DISK_CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache") litellm.cache = Cache(disk_cache_dir=DISK_CACHE_DIR, type="disk") litellm.telemetry = False @@ -32,11 +33,6 @@ os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" - -#------------------------------------------------------------------------------- -# LiteLLM Client -#------------------------------------------------------------------------------- - class LM: def __init__(self, model, @@ -128,7 +124,7 @@ def kill(self): logger.info(msg) def finetune(self, - message_completion_pairs: List[Dict[str, str]], + train_data: List[Dict[str, Any]], train_kwargs: Optional[Dict[str, Any]]=None, cache_finetune: bool = True, ) -> FinetuneJob: @@ -149,7 +145,7 @@ def finetune(self, FinetuneJobClass = get_provider_finetune_job_class(provider=provider) finetune_job = FinetuneJobClass( model=self.model, - message_completion_pairs=message_completion_pairs, + train_data=train_data, train_kwargs=train_kwargs, ) @@ -164,6 +160,14 @@ def finetune(self, executor.shutdown(wait=False) return finetune_job + + def copy(self, **kwargs): + """Returns a copy of the language model with the same parameters.""" + kwargs = {**self.kwargs, **kwargs} + # model = kwargs.pop("model") or self.model + init_kwargs = dict(model=self.model, model_type=self.model_type, cache=self.cache, temperature=self.kwargs["temperature"], max_tokens=self.kwargs["max_tokens"]) + init_kwargs = {**init_kwargs, **kwargs} + return self.__class__(**init_kwargs) @functools.lru_cache(maxsize=None) @@ -248,12 +252,12 @@ def _inspect_history(lm, n: 
int = 1): from dspy.clients.openai import ( FinetuneJobOpenAI, is_openai_model, - finetune_openai + finetune_openai, ) from dspy.clients.anyscale import ( FinetuneJobAnyScale, is_anyscale_model, - finetune_anyscale + finetune_anyscale, ) @@ -333,13 +337,13 @@ def execute_finetune_job( def cached_finetune( job, model: str, - message_completion_pairs: List[Dict[str, str]], + train_data: List[Dict[str, Any]], train_kwargs: Optional[Dict[str, Any]]=None, ) -> Union[str, ValueError]: return finetune( job=job, model=model, - message_completion_pairs=message_completion_pairs, + train_data=train_data, train_kwargs=train_kwargs, ) @@ -347,7 +351,7 @@ def cached_finetune( def finetune( job, model: str, - message_completion_pairs: List[Dict[str, str]], + train_data: List[Dict[str, Any]], train_kwargs: Optional[Dict[str, Any]]=None, ) -> Union[str, Exception]: """Fine-tune a new model based on the given model.""" @@ -362,7 +366,7 @@ def finetune( model = provider_finetune_function( job=job, model=model, - message_completion_pairs=message_completion_pairs, + train_data=train_data, train_kwargs=train_kwargs, ) except Exception as err: diff --git a/dspy/clients/openai.py b/dspy/clients/openai.py index f7feb961a1..dae20b9d09 100644 --- a/dspy/clients/openai.py +++ b/dspy/clients/openai.py @@ -132,11 +132,10 @@ def status(self) -> TrainingStatus: def finetune_openai( job: FinetuneJobOpenAI, model: str, - message_completion_pairs: List[Dict[str, str]], + train_data: List[Dict[str, Any]], train_kwargs: Optional[Dict[str, Any]]=None, ) -> Union[str, Exception]: train_kwargs = train_kwargs or {} - train_data = message_completion_pairs train_method = TrainingMethod.SFT # Note: This could be an argument # Validate train data and method @@ -365,3 +364,59 @@ def openai_data_validation(dataset: List[dict[str, Any]]): for k, v in format_errors.items(): err_msg += "\n {k}: {v}" raise ValueError(err_msg) + + +def check_message_lengths(dataset: List[dict[str, Any]]) -> list[int]: + n_missing_system = 0 + n_missing_user = 0 + n_messages = [] + convo_lens = [] + assistant_message_lens = [] + + for ex in dataset: + messages = ex["messages"] + if not any(message["role"] == "system" for message in messages): + n_missing_system += 1 + if not any(message["role"] == "user" for message in messages): + n_missing_user += 1 + n_messages.append(len(messages)) + convo_lens.append(num_tokens_from_messages(messages)) + assistant_message_lens.append(num_assistant_tokens_from_messages(messages)) + n_too_long = sum([length > 16385 for length in convo_lens]) + + if n_too_long > 0: + logger.info(f"There are {n_too_long} examples that may be over the 16,385 token limit, they will be truncated during fine-tuning.") + + if n_missing_system > 0: + logger.info(f"There are {n_missing_system} examples that are missing a system message.") + + if n_missing_user > 0: + logger.info(f"There are {n_missing_user} examples that are missing a user message.") + + return convo_lens + + +def num_tokens_from_messages(messages, tokens_per_message=3, tokens_per_name=1): + import tiktoken + encoding = tiktoken.get_encoding("cl100k_base") + + num_tokens = 0 + for message in messages: + num_tokens += tokens_per_message + for key, value in message.items(): + num_tokens += len(encoding.encode(value)) + if key == "name": + num_tokens += tokens_per_name + num_tokens += 3 + return num_tokens + + +def num_assistant_tokens_from_messages(messages): + import tiktoken + encoding = tiktoken.get_encoding("cl100k_base") + + num_tokens = 0 + for message in messages: + if 
message["role"] == "assistant": + num_tokens += len(encoding.encode(message["content"])) + return num_tokens diff --git a/dspy/teleprompt/finetune_teleprompter.py b/dspy/teleprompt/finetune_teleprompter.py index d5c5f563a9..4bddd815c7 100644 --- a/dspy/teleprompt/finetune_teleprompter.py +++ b/dspy/teleprompt/finetune_teleprompter.py @@ -6,8 +6,8 @@ from dspy.primitives.example import Example from dspy.primitives.program import Program from dspy.primitives.prediction import Prediction -from dspy.signatures.signature import signature_to_template +import concurrent.futures #------------------------------------------------------------------------------- # Templates for the user-facing strings used by this module @@ -79,101 +79,8 @@ def prepare_teacher( return teacher -def build_prompt_completion_data_from_trace( - trace: List[Dict], - exclude_demos: bool=False, - try_to_record_lm_kwargs: bool = False, - program: Program = None, - ) -> Dict[str, List[Dict[str, Any]]]: - """Build prompt completion data from a given trace. - - Args: - trace: The trace from which the prompt-completion data will be built. - exclude_demos: Exclude the demos from the prompts even if they are - present in the trace. Defaults to `False`. - try_to_record_lm_kwargs: Whether to record the LM kwargs in the data. - Defaults to `False`. If set, the `lm_kwargs` field of the LM used to - generate the prompt-completion pair is included in the data. To - find the LM, we first check if the predictor that generated the - prompt-completion pair has an LM field set (`predictor.lm`). If it - does, we record it's kwargs. If it doesn't, we get the kwargs from - `dspy.settings.lm`. If `dspy.settings.lm` is not set either, this - function will not record the LM kwargs. - program: Optional argument used to infer the name of the predictor that - generated the prompt-completion pair. If provided, the returned data - will include the `predictor_name` field. If not provided, the - `predictor_name` field is not included in the data, but the caller - of this function can recover this information by using the - `predictor_ind` field that' included in the data by default, by - building the following dictionary: - - {ind: n for ind, (n, _) in enumerate(program.named_predictors())} - - where `ind` is the index of the predictor in the list returned by - the `named_predictors()` method of the program. Defaults to `None`. - - Returns: - Data as a list of dictionaries with the keys `prompt`, `completion` and - optionally with the keys `predictor_name` and `lm_kwargs`. For a given - prompt-completion pair: - - The `prompt` field corresponds to the prompt. - - The `completion` field corresponds to the completion. - - The `predictor_ind` field corresponds to the index of the predictor - in the predictor list - - The `predictor_name` field corresponds to the index of the predictor - in the list returned by the the named_predictor() method of the - program used to generate the trace. This field is included only if - the `pred_ind_to_name` argument is provided. - - The `lm_kwargs` field corresponds to the LM kwargs that generated the - prompt-completion pair. Included only if the `record_lm_kwargs` is - set and there is an active LM. 
- """ - # If the program is provided, build the predictor index to name mapping - if program: - pred_ind_to_name = { - ind: name for ind, (name, _) in enumerate(program.named_predictors()) - } - - # Build the prompt-completion data - data = [] - for pred_ind, (pred, inputs, outputs) in enumerate(trace): - # Get the demos from the predictor if exclude_demos is False - demos = [] if exclude_demos else pred.demos - - # Build prompt and completion strings - template = signature_to_template(pred.signature) - prompt = template(Example(demos=demos, **inputs)) - completion = template.query(Example(**outputs)) - - # TODO: This part of the code could be improved. - # The method we use to build the completion (template.query) is meant to - # be used for creating, well, queries, and hence contains field prefixes - # (e.g. "Reasoning: Let's think step by step in order to"), which are - # also contained in the last piece of the prompt (separated with a new - # line) We remove this piece from the completion. This is a hacky - # solution since it assumes a particular template format. - prompt_last = prompt.split("\n")[-1] - completion = completion[len(prompt_last):] - - # Create prompt-completion dictionary and add it to the data; optionally - # add the predictor_name key as well as the lm_kwargs. - data_dict = dict(prompt=prompt, completion=completion) - - # Record the predictor index and optionally, name - data_dict['predictor_ind'] = pred_ind - if program: - data_dict['predictor_name'] = pred_ind_to_name[pred_ind] - - # Optionally, record the LM kwargs - lm = pred.lm or dspy.settings.lm - if try_to_record_lm_kwargs and lm: - data_dict['lm_kwargs'] = lm.kwargs - data.append(data_dict) - - return data - - -def convert_to_module_level_prompt_completion_data( +# TODO: fix docstring +def convert_to_module_level_message_data( data: List[Dict], keep_data_keys: bool = False, exclude_demos: bool = False, @@ -214,7 +121,7 @@ def convert_to_module_level_prompt_completion_data( prompt_completion_data = [] for data_dict in data: trace = data_dict["trace"] - trace_prompt_comletion_data = build_prompt_completion_data_from_trace( + trace_prompt_comletion_data = build_messages_from_trace( trace=trace, exclude_demos=exclude_demos, try_to_record_lm_kwargs=try_to_record_lm_kwargs, program=program ) @@ -225,13 +132,51 @@ def convert_to_module_level_prompt_completion_data( return prompt_completion_data +# TODO: fix docstring +def build_messages_from_trace( + trace: List[Dict], + exclude_demos: bool=False, + try_to_record_lm_kwargs: bool = False, + program: Program = None, + ) -> Dict[str, List[Dict[str, Any]]]: + """Build messages from a given trace. 
+ """ + messages = [] + # If the program is provided, build the predictor index to name mapping + if program: + pred_ind_to_name = { + ind: name for ind, (name, _) in enumerate(program.named_predictors()) + } + + # Build the prompt-completion data + + adapter = dspy.settings.adapter or dspy.ChatAdapter() + data = [] + + # TODO: Make sure that this works for multi-stage pipelines + for pred_ind, (pred, inputs, outputs) in enumerate(trace): + # Get the demos from the predictor if exclude_demos is False + demos = [] if exclude_demos else pred.demos + + messages = adapter.format(pred.signature, demos, inputs) + + formatted_completion = adapter.format_completion(pred.signature, outputs) + messages.append({"role": "assistant", "content": formatted_completion}) + data.append(messages) + + return data + +def dummy_metric(example, pred, trace=None, frac=1.0): + return 1 + def bootstrap_data( program: Program, dataset: List[Example], metric: Optional[Callable[ [Example, Prediction, Optional[List]], Union[bool, int, float] - ]] = None, + ]] = dummy_metric, num_threads = 1, + max_errors: int = 0 ) -> List[Dict[str, Any]]: """Bootstrap prediction and trace data for the program using the dataset. @@ -266,29 +211,44 @@ def bootstrap_data( info = _INFO_BOOTSTRAP_DATA.format(len(dataset), cname, num_threads) logger.info(info) evaluator = Evaluate( - devset=dataset, num_threads=num_threads, display_progress=True + devset=dataset, num_threads=num_threads, display_progress=True, max_errors=max_errors, provide_traceback=True ) evaluator(program, metric=metric) - # Re-iterate over the dataset to build the cached prompt-completion data - for example_ind, example in enumerate(dataset): - - # Run the program on the example - with dspy.context(trace=[]): - prediction = program(**example.inputs()) - trace = dspy.settings.trace - score = metric(example, prediction, trace) if metric else None - - # Build the data dictionary and extend the data list - data_dict = dict(example=example, prediction=prediction, trace=trace) - data_dict['example_ind'] = example_ind - if metric: - data_dict['score'] = score - data.append(data_dict) + data = [] + for example in dataset: + data_dict = process_example(example, 0, program, metric) + if data_dict is not None: + data.append(data_dict) return data +def process_example(example: Any, example_ind: int, program: Callable, metric: Optional[Callable] = None) -> Dict[str, Any]: + # print("Processing example:", example_ind) + with dspy.context(trace=[]): + # print("Running program...", example_ind) + try: + prediction = program(**example.inputs()) + except Exception as e: + print(f"Error processing example {example_ind}: {e}") + return None + # print("Getting trace...", example_ind) + trace = dspy.settings.trace + # print("Getting score...", example_ind) + score = metric(example, prediction, trace) if metric else None + + data_dict = { + 'example': example, + 'prediction': prediction, + 'trace': trace, + 'example_ind': example_ind + } + if metric: + data_dict['score'] = score + + return data_dict + # TODO: If we can ensure to pass the "round" information every time a call is # issued to an LM, we can make repetitive un-cached calls to the same LM without # modifying it's temperature. This function can be removed then. @@ -302,6 +262,7 @@ def bootstrap_data_for_round( sampling_round: int = 0, sampling_temperature: Optional[float] = 0.9, sampling_temperature_delta: float = 0.001, + max_errors: int = 0 ) -> Union[List[Dict], AssertionError]: """ Bootstrap data for the given sampling round. 
@@ -388,7 +349,8 @@ def copy_model_with_updated_temp(lm): # Ensure that the LM consistency is satisfied, which ensures that either (1) # the global LM is set or (2) all the predictors have an LM set. - program._assert_lm_consistency() + # TODO(isaac): Uncomment this line after the LM consistency property is + # program._assert_lm_consistency() # Deepcopy the program and copy the dataset to avoid modifying the original program = program.deepcopy() @@ -404,8 +366,9 @@ def copy_model_with_updated_temp(lm): # Collect the data for the given round with dspy.context(lm=context_lm): + # print(context_lm.kwargs) data = bootstrap_data( - program, dataset, metric=metric, num_threads=num_threads + program, dataset, metric=metric, num_threads=num_threads, max_errors=max_errors ) # Add the round information to the data diff --git a/dspy/teleprompt/random_search.py b/dspy/teleprompt/random_search.py index 57bd38150c..4fa8ba23fd 100644 --- a/dspy/teleprompt/random_search.py +++ b/dspy/teleprompt/random_search.py @@ -85,6 +85,7 @@ def compile(self, student, *, teacher=None, trainset, valset=None, restrict=None max_labeled_demos=self.max_labeled_demos, teacher_settings=self.teacher_settings, max_rounds=self.max_rounds, + max_errors=self.max_errors, ) program = optimizer.compile(student, teacher=teacher, trainset=trainset_copy) @@ -101,6 +102,7 @@ def compile(self, student, *, teacher=None, trainset, valset=None, restrict=None max_labeled_demos=self.max_labeled_demos, teacher_settings=self.teacher_settings, max_rounds=self.max_rounds, + max_errors=self.max_errors, ) program = optimizer.compile(student, teacher=teacher, trainset=trainset_copy) diff --git a/examples/finetuning/finetune_test.ipynb b/examples/finetuning/finetune_test.ipynb index f74fd4a91b..71c43bbcfa 100644 --- a/examples/finetuning/finetune_test.ipynb +++ b/examples/finetuning/finetune_test.ipynb @@ -18,7 +18,6 @@ "%load_ext autoreload\n", "%autoreload 2\n", "\n", - "import tqdm\n", "import os\n", "\n", "os.environ[\"DSPY_CACHEDIR\"] = \"\"\n", diff --git a/requirements.txt b/requirements.txt index ae69ac3de3..01ff8c349b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -# DSPy requirements backoff datasets joblib~=1.3
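Taken together, the rename of `message_completion_pairs` to `train_data` means a fine-tuning call against this branch would look roughly like the sketch below. The model name and data are placeholders; for an AnyScale model the job additionally assumes the `anyscale` SDK, cloud storage credentials, and a `serve_config_path` are available.

```python
import dspy

lm = dspy.LM(model="gpt-4o-mini")  # placeholder model name

# train_data replaces the old message_completion_pairs argument: a list of
# chat-format records, each holding a "messages" list of role/content dicts.
train_data = [
    {
        "messages": [
            {"role": "user", "content": "What is the capital of France?"},
            {"role": "assistant", "content": "Paris"},
        ]
    }
]

job = lm.finetune(train_data=train_data, train_kwargs={})  # returns a FinetuneJob (a concurrent.futures.Future)
finetuned_model = job.result()  # blocks until the provider-side job resolves
```

The `copy()` method added to `LM` in the same patch can then be used to derive a variant of a model handle with different sampling settings, for example `lm.copy(temperature=0.9)`.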