diff --git a/QEfficient/finetune/experimental/core/callbacks.py b/QEfficient/finetune/experimental/core/callbacks.py
index d647b73a6..30659e3bb 100644
--- a/QEfficient/finetune/experimental/core/callbacks.py
+++ b/QEfficient/finetune/experimental/core/callbacks.py
@@ -4,3 +4,212 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
+
+import json
+import os
+from typing import Any, Dict, Optional
+
+from transformers import (
+    DefaultFlowCallback,
+    EarlyStoppingCallback,
+    PrinterCallback,
+    ProgressCallback,
+    TrainingArguments,
+)
+from transformers.integrations.integration_utils import TensorBoardCallback
+from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
+
+from QEfficient.finetune.experimental.core.component_registry import registry
+from QEfficient.finetune.experimental.core.utils.profiler_utils import (
+    get_op_verifier_ctx,
+    init_qaic_profiling,
+    stop_qaic_profiling,
+)
+
+registry.callback("early_stopping")(EarlyStoppingCallback)
+registry.callback("printer")(PrinterCallback)
+registry.callback("default_flow")(DefaultFlowCallback)
+registry.callback("tensorboard")(TensorBoardCallback)
+
+
+@registry.callback("enhanced_progressbar")
+class EnhancedProgressCallback(ProgressCallback):
+    """
+    A [`TrainerCallback`] that displays the progress of training or evaluation.
+    You can modify `max_str_len` to control how long strings are truncated when logging.
+    """
+
+    def __init__(self, *args, **kwargs):
+        """
+        Initialize the callback with optional max_str_len parameter to control string truncation length.
+
+        Args:
+            max_str_len (`int`):
+                Maximum length of strings to display in logs.
+                Longer strings will be truncated with a message.
+        """
+        super().__init__(*args, **kwargs)
+
+    def on_train_begin(self, args, state, control, **kwargs):
+        """Set progress bar description at the start of training."""
+        super().on_train_begin(args, state, control, **kwargs)
+        if self.training_bar is not None:
+            self.training_bar.set_description("Training Progress")
+
+    def on_log(self, args, state, control, logs=None, **kwargs):
+        """
+        Override the default `on_log` behavior during training to display
+        the current epoch number, loss, and learning rate in the logs.
+        """
+        if logs is not None and state.is_world_process_zero and self.training_bar is not None:
+            # make a shallow copy of logs so we can mutate the fields copied
+            # but avoid doing any value pickling.
+            shallow_logs = {}
+            for k, v in logs.items():
+                if isinstance(v, str) and len(v) > self.max_str_len:
+                    shallow_logs[k] = (
+                        f"[String too long to display, length: {len(v)} > {self.max_str_len}. "
+                        "Consider increasing `max_str_len` if needed.]"
+                    )
+                else:
+                    shallow_logs[k] = v
+            _ = shallow_logs.pop("total_flos", None)
+            # round numbers so that it looks better in console
+            if "epoch" in shallow_logs:
+                shallow_logs["epoch"] = round(shallow_logs["epoch"], 2)
+
+            updated_dict = {}
+            if "epoch" in shallow_logs:
+                updated_dict["epoch"] = shallow_logs["epoch"]
+            if "loss" in shallow_logs:
+                updated_dict["loss"] = shallow_logs["loss"]
+            if "learning_rate" in shallow_logs:
+                updated_dict["lr"] = shallow_logs["learning_rate"]
+            self.training_bar.set_postfix(updated_dict)
+
+
+@registry.callback("json_logger")
+class JSONLoggerCallback(TrainerCallback):
+    """
+    A [`TrainerCallback`] that logs training and evaluation metrics to a JSON file.
+    """
+
+    def __init__(self, log_path=None, *args, **kwargs):
+        """
+        Initialize the callback with the path to the JSON log file.
+
+        Args:
+            log_path (`str`):
+                Path to the jsonl file where logs will be saved.
+        """
+        super().__init__(*args, **kwargs)
+        if log_path is None:
+            log_path = os.path.join(os.environ.get("OUTPUT_DIR", "./"), "training_logs.jsonl")
+        self.log_path = log_path
+        # Create the parent directory if it is missing so the open() below cannot fail on a fresh run.
+        log_dir = os.path.dirname(self.log_path)
+        if log_dir:
+            os.makedirs(log_dir, exist_ok=True)
+        # Ensure the log file is created and empty
+        with open(self.log_path, "w") as _:
+            pass
+
+    def on_log(
+        self,
+        args: TrainingArguments,
+        state: TrainerState,
+        control: TrainerControl,
+        logs: Optional[Dict] = None,
+        **kwargs,
+    ):
+        """Append sanitized log metrics (including global_step) to a JSONL file."""
+        if logs is None:
+            return
+        # Build a filtered copy instead of mutating the caller's dict, which
+        # other callbacks may also receive.
+        record = {k: v for k, v in logs.items() if k not in ("entropy", "mean_token_accuracy")}
+        if state.global_step:
+            record["global_step"] = state.global_step
+        with open(self.log_path, "a") as f:
+            f.write(json.dumps(record, separators=(",", ":")) + "\n")
+
+
+@registry.callback("qaic_profiler_callback")
+class QAICProfilerCallback(TrainerCallback):
+    """Callback to profile QAIC devices over a specified training step range."""
+
+    def __init__(self, *args, **kwargs):
+        """
+        Initialize QAIC profiler settings (start/end steps and target device IDs).
+        """
+
+        self.start_step = kwargs.get("start_step", -1)
+        self.end_step = kwargs.get("end_step", -1)
+        self.device_ids = kwargs.get("device_ids", [0])
+
+    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Event called at the beginning of a training step. If using gradient accumulation, one training step might take
+        several inputs.
+        """
+        if state.global_step == self.start_step:
+            for device_id in self.device_ids:
+                init_qaic_profiling(True, f"qaic:{device_id}")
+        elif state.global_step == self.end_step:
+            for device_id in self.device_ids:
+                stop_qaic_profiling(True, f"qaic:{device_id}")
+
+
+@registry.callback("qaic_op_by_op_verifier_callback")
+class QAICOpByOpVerifierCallback(TrainerCallback):
+    """Callback to verify QAIC operations step-by-step during a specified training range."""
+
+    def __init__(self, *args, **kwargs):
+        """
+        Initialize QAIC Op-by-Op verifier callback with profiling and tolerance settings.
+        """
+        self.start_step = kwargs.get("start_step", -1)
+        self.end_step = kwargs.get("end_step", -1)
+        self.trace_dir = kwargs.get("trace_dir", "qaic_op_by_op_traces")
+        self.atol = kwargs.get("atol", 1e-1)
+        self.rtol = kwargs.get("rtol", 1e-5)
+        # Context manager entered for the step in flight; None when no step is being verified.
+        self.op_verifier_ctx_step = None
+
+    def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Enter the op-by-op verifier context when the current step lies in
+        ``[start_step, end_step)``.
+        """
+        if self.start_step <= state.global_step < self.end_step:
+            self.op_verifier_ctx_step = get_op_verifier_ctx(
+                use_op_by_op_verifier=True,
+                device_type="qaic",
+                dump_dir=self.trace_dir,
+                step=state.global_step,
+                atol=self.atol,
+                rtol=self.rtol,
+            )
+            self.op_verifier_ctx_step.__enter__()
+
+    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+        """
+        Exit the verifier context opened by `on_step_begin`, if any.
+
+        The step range is deliberately not re-checked here: `state.global_step`
+        may already have advanced by the time this hook runs, so the decision
+        made in `on_step_begin` is the one that counts.
+        """
+        ctx = self.op_verifier_ctx_step
+        if ctx is not None:
+            # Clear first so a failing __exit__ cannot leave a stale context behind.
+            self.op_verifier_ctx_step = None
+            ctx.__exit__(None, None, None)
+
+
+def create_callbacks(name: str, **kwargs) -> Any:
+    """Create a callback instance."""
+    callback_class = registry.get_callback(name)
+    if callback_class is None:
+        raise ValueError(f"Unknown callback: {name}. Available: {registry.list_callbacks()}")
+    return callback_class(**kwargs)
diff --git a/QEfficient/finetune/experimental/core/utils/profiler_utils.py b/QEfficient/finetune/experimental/core/utils/profiler_utils.py
index d647b73a6..e24508e83 100644
--- a/QEfficient/finetune/experimental/core/utils/profiler_utils.py
+++ b/QEfficient/finetune/experimental/core/utils/profiler_utils.py
@@ -4,3 +4,92 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
+
+
+from contextlib import nullcontext
+from typing import ContextManager
+
+import torch
+
+
+def get_op_verifier_ctx(
+    use_op_by_op_verifier: bool,
+    device_type: str,
+    dump_dir: str,
+    step: int,
+    ref_device: str = "cpu",
+    ref_dtype: torch.dtype = torch.float32,
+    atol: float = 1e-1,
+    rtol: float = 1e-5,
+    use_ref_output_on_mismatch: bool = True,
+) -> ContextManager:
+    """Get the op-by-op verifier context manager when op-by-op verification is
+    enabled. It helps in debugging operator related issues by matching the
+    operator execution on qaic v/s cpu. This is meant only for qaic backend.
+
+    Args:
+        use_op_by_op_verifier (bool): Boolean flag to enable op-by-op verifier.
+        device_type (str): Device on which the model is being executed.
+        dump_dir (str): Directory to dump the op-by-op verification results.
+        step (int): Step number for which the op-by-op verification is to be performed.
+        ref_device (str, optional): Device to use as reference for verification.
+            Defaults to "cpu".
+        ref_dtype (torch.dtype, optional): Data type to use as reference
+            datatype for verification. Defaults to torch.float32.
+        atol (float, optional): Absolute tolerance to match the results. Defaults to 1e-1.
+        rtol (float, optional): Relative tolerance to match the results. Defaults to 1e-5.
+        use_ref_output_on_mismatch (bool, optional): If an operator has a
+            mismatch with respect to the reference device, use the reference
+            device outputs and continue rest of the verification. Defaults to True.
+
+    Returns:
+        ContextManager: Instance of context manager used to verify the operators.
+    """
+    # Verification only applies to qaic devices; anything else gets a no-op context.
+    if (not use_op_by_op_verifier) or ("qaic" not in device_type):
+        return nullcontext()
+
+    # Lazily imported qaic_debug when it is actually needed.
+    import torch_qaic.debug as qaic_debug
+
+    filter_config = qaic_debug.DispatchFilterConfig.default(device_type)
+    dump_dir = dump_dir + "/mismatches/step_" + str(step)
+    return qaic_debug.OpByOpVerifierMode(
+        ref_device=ref_device,
+        ref_dtype=ref_dtype,
+        atol=atol,
+        rtol=rtol,
+        use_ref_output_on_mismatch=use_ref_output_on_mismatch,
+        filter_config=filter_config,
+        dump_root_dir=dump_dir,
+    )
+
+
+def init_qaic_profiling(use_profiler: bool, device_type: str) -> None:
+    """Initialize the qaic profiling tool. Note: The profiler only works
+    for qaic backend.
+
+    Args:
+        use_profiler (bool): Boolean flag to enable profiler.
+        device_type (str): Device on which the model is being executed.
+    """
+    if (use_profiler) and ("qaic" in device_type):
+        # Lazily imported qaic's qaic_profile when it is actually needed.
+        import torch_qaic.profile as qaic_profile
+
+        qaic_profile.start_profiling(device_type, 1)
+
+
+def stop_qaic_profiling(use_profiler: bool, device_type: str) -> None:
+    """Stop the qaic profiling tool. Note: The profiler only works
+    for qaic backend.
+
+    Args:
+        use_profiler (bool): Boolean flag to enable profiler.
+        device_type (str): Device on which the model is being executed.
+    """
+    if (use_profiler) and ("qaic" in device_type):
+        # Lazily imported qaic's qaic_profile when it is actually needed.
+        import torch_qaic.profile as qaic_profile
+
+        qaic_profile.stop_profiling(device_type)
diff --git a/QEfficient/finetune/experimental/tests/test_callback.py b/QEfficient/finetune/experimental/tests/test_callback.py
new file mode 100644
index 000000000..59ff4d117
--- /dev/null
+++ b/QEfficient/finetune/experimental/tests/test_callback.py
@@ -0,0 +1,66 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import pytest
+from transformers import TrainerCallback
+
+from QEfficient.finetune.experimental.core.callbacks import create_callbacks
+from QEfficient.finetune.experimental.core.component_registry import registry
+
+
+class ModelSummaryCallback(TrainerCallback):
+    def __init__(self):
+        pass
+
+
+# Setup test data
+CALLBACK_CONFIGS = {
+    "early_stopping": {
+        "name": "early_stopping",
+        "early_stopping_patience": 3,
+        "early_stopping_threshold": 0.001,
+    },
+    "tensorboard": {"name": "tensorboard", "tb_writer": "SummaryWriter"},
+    "model_summary": {
+        "name": "model_summary",
+        "max_depth": 1,
+    },
+}
+
+REGISTRY_CALLBACK_CONFIGS = {
+    "model_summary": {
+        "name": "model_summary",
+        "max_depth": 1,
+        "callback_class": ModelSummaryCallback,
+    },
+}
+
+
+@pytest.mark.parametrize("callback_name", CALLBACK_CONFIGS.keys())
+def test_callbacks(callback_name):
+    """Test that registered callbacks can be created with their configs."""
+    # Create callbacks using the factory
+    config = CALLBACK_CONFIGS[callback_name]
+    try:
+        callback_inst = create_callbacks(**config)
+    except ValueError as e:
+        assert "Unknown callback" in str(e)
+        return
+    assert callback_inst is not None
+    assert isinstance(callback_inst, TrainerCallback)
+
+
+@pytest.mark.parametrize(
+    "callback_name,callback_class",
+    [(name, config["callback_class"]) for name, config in REGISTRY_CALLBACK_CONFIGS.items()],
+)
+def test_callbacks_registery(callback_name, callback_class):
+    """Test that a callback is registered correctly."""
+    registry.callback(callback_name)(callback_class)
+    callback = registry.get_callback(callback_name)
+    assert callback is not None
+    assert callback == callback_class