QEfficient/finetune/experimental/core/callbacks.py (199 additions, 0 deletions)
@@ -4,3 +4,202 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import json
import os
from typing import Any, Dict, Optional

from transformers import (
DefaultFlowCallback,
EarlyStoppingCallback,
PrinterCallback,
ProgressCallback,
TrainingArguments,
)
from transformers.integrations.integration_utils import TensorBoardCallback
from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState

from QEfficient.finetune.experimental.core.component_registry import registry
from QEfficient.finetune.experimental.core.utils.profiler_utils import (
get_op_verifier_ctx,
init_qaic_profiling,
stop_qaic_profiling,
)

registry.callback("early_stopping")(EarlyStoppingCallback)
registry.callback("printer")(PrinterCallback)
registry.callback("default_flow")(DefaultFlowCallback)
registry.callback("tensorboard")(TensorBoardCallback)


@registry.callback("enhanced_progressbar")
class EnhancedProgressCallback(ProgressCallback):
"""
A [`TrainerCallback`] that displays the progress of training or evaluation.
You can modify `max_str_len` to control how long strings are truncated when logging.
"""

def __init__(self, *args, **kwargs):
"""
Initialize the callback with optional max_str_len parameter to control string truncation length.

Args:
max_str_len (`int`):
Maximum length of strings to display in logs.
Longer strings will be truncated with a message.
"""
super().__init__(*args, **kwargs)

def on_train_begin(self, args, state, control, **kwargs):
"""Set progress bar description at the start of training."""
super().on_train_begin(args, state, control, **kwargs)
if self.training_bar is not None:
self.training_bar.set_description("Training Progress")

def on_log(self, args, state, control, logs=None, **kwargs):
"""
Override the default `on_log` behavior during training to display
the current epoch number, loss, and learning rate in the logs.
"""
if state.is_world_process_zero and self.training_bar is not None:
            # Make a shallow copy of logs so we can rewrite individual fields
            # without mutating the caller's dict (values are not deep-copied).
shallow_logs = {}
for k, v in logs.items():
if isinstance(v, str) and len(v) > self.max_str_len:
shallow_logs[k] = (
f"[String too long to display, length: {len(v)} > {self.max_str_len}. "
"Consider increasing `max_str_len` if needed.]"
)
else:
shallow_logs[k] = v
_ = shallow_logs.pop("total_flos", None)
            # Round the epoch value so it displays cleanly in the console.
if "epoch" in shallow_logs:
shallow_logs["epoch"] = round(shallow_logs["epoch"], 2)

updated_dict = {}
if "epoch" in shallow_logs:
updated_dict["epoch"] = shallow_logs["epoch"]
if "loss" in shallow_logs:
updated_dict["loss"] = shallow_logs["loss"]
if "learning_rate" in shallow_logs:
updated_dict["lr"] = shallow_logs["learning_rate"]
self.training_bar.set_postfix(updated_dict)
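A minimal wiring sketch, assuming a standard `transformers.Trainer` setup; `model` and `train_dataset` are placeholders, not part of this PR:

```python
# Hypothetical usage; model and train_dataset are assumed to exist.
from transformers import Trainer, TrainingArguments

trainer = Trainer(
    model=model,                  # placeholder: any Trainer-compatible model
    args=TrainingArguments(output_dir="out", logging_steps=10),
    train_dataset=train_dataset,  # placeholder dataset
    callbacks=[EnhancedProgressCallback(max_str_len=100)],
)
# The Trainer adds a stock ProgressCallback by default; removing the first
# ProgressCallback instance avoids duplicate progress bars.
trainer.remove_callback(ProgressCallback)
trainer.train()
```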


@registry.callback("json_logger")
class JSONLoggerCallback(TrainerCallback):
"""
A [`TrainerCallback`] that logs training and evaluation metrics to a JSON file.
"""

def __init__(self, log_path=None, *args, **kwargs):
"""
Initialize the callback with the path to the JSON log file.

Args:
log_path (`str`):
Path to the jsonl file where logs will be saved.
"""
super().__init__(*args, **kwargs)
if log_path is None:
log_path = os.path.join(os.environ.get("OUTPUT_DIR", "./"), "training_logs.jsonl")
self.log_path = log_path
# Ensure the log file is created and empty
with open(self.log_path, "w") as _:
pass

def on_log(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
logs: Optional[Dict] = None,
**kwargs,
):
"""Append sanitized log metrics (including global_step) to a JSONL file."""
        if logs is None:
            return
        # Work on a copy so the Trainer's log dict is not mutated in place.
        logs = dict(logs)
        logs.pop("entropy", None)
        logs.pop("mean_token_accuracy", None)
        if state.global_step:
            logs["global_step"] = state.global_step
        with open(self.log_path, "a") as f:
            json_line = json.dumps(logs, separators=(",", ":"))
            f.write(json_line + "\n")
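Since each record lands on its own line, the file can be consumed incrementally; a small sketch of reading the log back (the path assumes the default location):

```python
import json

# Parse the JSONL log written by JSONLoggerCallback, one record per line.
with open("training_logs.jsonl") as f:
    records = [json.loads(line) for line in f if line.strip()]

# Example: collect a loss curve keyed by global_step; the available keys
# depend on what the Trainer actually logged.
loss_curve = {r["global_step"]: r["loss"] for r in records if "loss" in r and "global_step" in r}
```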


@registry.callback("qaic_profiler_callback")
class QAICProfilerCallback(TrainerCallback):
"""Callback to profile QAIC devices over a specified training step range."""

def __init__(self, *args, **kwargs):
"""
Initialize QAIC profiler settings (start/end steps and target device IDs).
"""

self.start_step = kwargs.get("start_step", -1)
self.end_step = kwargs.get("end_step", -1)
self.device_ids = kwargs.get("device_ids", [0])

def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called at the beginning of a training step. If using gradient accumulation, one training step might take
several inputs.
"""
if state.global_step == self.start_step:
for device_id in self.device_ids:
init_qaic_profiling(True, f"qaic:{device_id}")
elif state.global_step == self.end_step:
for device_id in self.device_ids:
stop_qaic_profiling(True, f"qaic:{device_id}")
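Because both hooks compare against `state.global_step`, profiling covers the half-open step window `[start_step, end_step)`; a hypothetical construction:

```python
# Profile qaic devices 0 and 1 from global step 100 up to, but not
# including, step 110 (the values here are illustrative).
profiler_cb = QAICProfilerCallback(start_step=100, end_step=110, device_ids=[0, 1])
```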


@registry.callback("qaic_op_by_op_verifier_callback")
class QAICOpByOpVerifierCallback(TrainerCallback):
"""Callback to verify QAIC operations step-by-step during a specified training range."""

def __init__(self, *args, **kwargs):
""" "
Initialize QAIC Op-by-Op verifier callback with profiling and tolerance settings.
"""
self.start_step = kwargs.get("start_step", -1)
self.end_step = kwargs.get("end_step", -1)
self.trace_dir = kwargs.get("trace_dir", "qaic_op_by_op_traces")
self.atol = kwargs.get("atol", 1e-1)
self.rtol = kwargs.get("rtol", 1e-5)

def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called at the beginning of a training step. If using gradient accumulation, one training step might take
several inputs.
"""
if self.start_step <= state.global_step < self.end_step:
self.op_verifier_ctx_step = get_op_verifier_ctx(
use_op_by_op_verifier=True,
device_type="qaic",
dump_dir=self.trace_dir,
step=state.global_step,
atol=self.atol,
rtol=self.rtol,
)
self.op_verifier_ctx_step.__enter__()

def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
"""
Event called at the end of a training step. If using gradient accumulation, one training step might take
several inputs.
"""
        if self.start_step <= state.global_step < self.end_step:
            if self.op_verifier_ctx_step is not None:
                self.op_verifier_ctx_step.__exit__(None, None, None)
                # Clear the handle so the context cannot be exited twice.
                self.op_verifier_ctx_step = None


def create_callbacks(name: str, **kwargs) -> Any:
"""Create a callback instance."""
callback_class = registry.get_callback(name)
if callback_class is None:
raise ValueError(f"Unknown callback: {name}. Available: {registry.list_callbacks()}")
return callback_class(**kwargs)
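A usage sketch of the factory with the names registered above; the kwargs shown are illustrative:

```python
# Resolve registered callbacks by name; unknown names raise ValueError.
json_logger = create_callbacks("json_logger", log_path="out/training_logs.jsonl")
early_stop = create_callbacks("early_stopping", early_stopping_patience=3)

callbacks = [json_logger, early_stop]  # e.g. pass as Trainer(callbacks=...)
```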
QEfficient/finetune/experimental/core/utils/profiler_utils.py (88 additions, 0 deletions)
@@ -4,3 +4,91 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------


from contextlib import nullcontext
from typing import ContextManager

import torch


def get_op_verifier_ctx(
use_op_by_op_verifier: bool,
device_type: str,
dump_dir: str,
step: int,
ref_device: str = "cpu",
ref_dtype: torch.dtype = torch.float32,
atol: float = 1e-1,
rtol: float = 1e-5,
use_ref_output_on_mismatch: bool = True,
) -> ContextManager:
"""Get the op-by-op verifier context manager when op-by-op verification is
enabled. It helps in debuging operator related issues by matching the
operator execution on qaic v/s cpu. This is meant only for qaic backend.

Args:
use_op_by_op_verifier (bool): Boolean flag to enable op-by-op verifier.
device_type (str): Device on which the model is being executed.
dump_dir (str): Directory to dump the op-by-op verification results.
step (int): Step number for which the op-by-op verification is to be performed.
ref_device (str, optional): Device to use as reference for verification.
Defaults to "cpu".
ref_dtype (torch.dtype, optional): Data type to use as reference
datatype for verification. Defaults to torch.float32.
atol (float, optional): Absolute tolerance to match the results. Defaults to 1e-1.
rtol (float, optional): Relative tolerance to match the results. Defaults to 1e-5.
use_ref_output_on_mismatch (bool, optional): If an operator has a
mismatch with respect to the reference device, use the reference
device outputs and continue rest of the verification. Defaults to True.

Returns:
ContextManager: Instance of context manager used to verify the operators.
"""
    if (not use_op_by_op_verifier) or ("qaic" not in device_type):
return nullcontext()

    # Lazily import qaic_debug only when it is actually needed.
import torch_qaic.debug as qaic_debug

filter_config = qaic_debug.DispatchFilterConfig.default(device_type)
    dump_dir = f"{dump_dir}/mismatches/step_{step}"
return qaic_debug.OpByOpVerifierMode(
ref_device=ref_device,
ref_dtype=ref_dtype,
atol=atol,
rtol=rtol,
use_ref_output_on_mismatch=use_ref_output_on_mismatch,
filter_config=filter_config,
dump_root_dir=dump_dir,
)
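The returned object is an ordinary context manager; a manual-usage sketch, assuming a qaic device with `torch_qaic` installed (`train_step` is a placeholder):

```python
# Hypothetical manual use outside the callback; train_step is a placeholder
# for whatever work should be verified op by op.
ctx = get_op_verifier_ctx(
    use_op_by_op_verifier=True,
    device_type="qaic:0",
    dump_dir="qaic_op_by_op_traces",
    step=0,
)
with ctx:
    train_step()  # ops executed here are compared against the cpu reference
```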


def init_qaic_profiling(use_profiler: bool, device_type: str) -> None:
"""Initialize the qaic profiling tool. Note: The profiler is only works
for qaic backend.

Args:
use_profiler (bool): Boolean flag to enable profiler.
device_type (str): Device on which the model is being executed.
"""
if (use_profiler) and ("qaic" in device_type):
        # Lazily import qaic_profile only when it is actually needed.
import torch_qaic.profile as qaic_profile

qaic_profile.start_profiling(device_type, 1)


def stop_qaic_profiling(use_profiler: bool, device_type: str) -> None:
"""Stop the qaic profiling tool. Note: The profiler is only works
for qaic backend.

Args:
use_profiler (bool): Boolean flag to enable profiler.
device_type (str): Device on which the model is being executed.
"""
if (use_profiler) and ("qaic" in device_type):
        # Lazily import qaic_profile only when it is actually needed.
import torch_qaic.profile as qaic_profile

qaic_profile.stop_profiling(device_type)
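A minimal start/stop sketch mirroring how `QAICProfilerCallback` drives these helpers; the device string and `train_step` are assumptions:

```python
# Profile a short region on the first qaic device; both calls are no-ops
# when the device string does not contain "qaic".
init_qaic_profiling(use_profiler=True, device_type="qaic:0")
for _ in range(10):
    train_step()  # placeholder for the work being profiled
stop_qaic_profiling(use_profiler=True, device_type="qaic:0")
```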
QEfficient/finetune/experimental/tests/test_callback.py (63 additions, 0 deletions)
@@ -0,0 +1,63 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import pytest
from transformers import TrainerCallback

from QEfficient.finetune.experimental.core.callbacks import create_callbacks
from QEfficient.finetune.experimental.core.component_registry import registry


class ModelSummaryCallback(TrainerCallback):
def __init__(self):
pass


# Setup test data
CALLBACK_CONFIGS = {
"early_stopping": {
"name": "early_stopping",
"early_stopping_patience": 3,
"early_stopping_threshold": 0.001,
},
"tensorboard": {"name": "tensorboard", "tb_writer": "SummaryWriter"},
"model_summary": {
"name": "model_summary",
"max_depth": 1,
},
}

REGISTRY_CALLBACK_CONFIGS = {
"model_summary": {
"name": "model_summary",
"max_depth": 1,
"callback_class": ModelSummaryCallback,
},
}


@pytest.mark.parametrize("callback_name", CALLBACK_CONFIGS.keys())
def test_callbacks(callback_name):
"""Test that registered callbacks that can be created with their configs."""
# Create callbacks using the factory
config = CALLBACK_CONFIGS[callback_name]
try:
callback_inst = create_callbacks(**config)
except ValueError as e:
assert "Unknown callback" in str(e)
return
assert callback_inst is not None
assert isinstance(callback_inst, TrainerCallback)


@pytest.mark.parametrize("callback_name,callback_config", REGISTRY_CALLBACK_CONFIGS.items())
def test_callbacks_registry(callback_name, callback_config):
    """Test that a custom callback is registered and retrieved correctly."""
    callback_class = callback_config["callback_class"]
    registry.callback(callback_name)(callback_class)
    callback = registry.get_callback(callback_name)
    assert callback is not None
    assert callback == callback_class