diff --git a/QEfficient/finetune/experimental/core/logger.py b/QEfficient/finetune/experimental/core/logger.py
new file mode 100644
index 000000000..a1b9c771f
--- /dev/null
+++ b/QEfficient/finetune/experimental/core/logger.py
@@ -0,0 +1,193 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+
+import logging
+import sys
+from pathlib import Path
+from typing import Optional
+
+from transformers.utils.logging import get_logger as hf_get_logger
+
+from QEfficient.finetune.experimental.core.utils.dist_utils import get_local_rank
+
+# -----------------------------------------------------------------------------
+# Logger usage:
+# Initialize logger:
+#     logger = Logger("my_logger", log_file="logs/output.log", level=logging.DEBUG)
+# Log messages:
+#     logger.info("This is an info message")
+#     logger.error("This is an error message")
+#     logger.log_rank_zero("This message is logged only on rank 0")
+#     logger.log_exception("An error occurred", exception, raise_exception=False)
+# Attach file handler later if needed:
+#     logger.prepare_for_logs(output_dir="logs", log_level="DEBUG")
+# -----------------------------------------------------------------------------
+
+
+class Logger:
+    """Custom logger with console and file logging capabilities."""
+
+    def __init__(
+        self,
+        name: str = "transformers",  # We are using "transformers" as default to align with HF logs
+        log_file: Optional[str] = None,
+        level: int = logging.INFO,
+    ):
+        """
+        Initialize the logger.
+
+        Args:
+            name: Logger name
+            log_file: Path to log file (if None, log only to console)
+            level: Logging level
+        """
+        self.logger = hf_get_logger(name)
+        self.logger.setLevel(level)
+
+        # Detach and close any existing handlers. Clearing the list alone would
+        # drop file handlers without closing the log files they hold open.
+        for stale_handler in list(self.logger.handlers):
+            self.logger.removeHandler(stale_handler)
+            stale_handler.close()
+
+        # Create formatter
+        self.formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+
+        # Console handler
+        console_handler = logging.StreamHandler(sys.stdout)
+        console_handler.setLevel(level)
+        console_handler.setFormatter(self.formatter)
+        self.logger.addHandler(console_handler)
+
+        # File handler (if log_file is provided)
+        if log_file:
+            self._attach_file_handler(Path(log_file), level)
+
+    def _attach_file_handler(self, log_path: Path, level: int) -> None:
+        """
+        Attach a file handler writing to ``log_path``.
+
+        The parent directory is created if it doesn't exist.
+
+        Args:
+            log_path: Path to log file
+            level: Logging level for the new handler
+        """
+        log_path.parent.mkdir(parents=True, exist_ok=True)
+
+        file_handler = logging.FileHandler(log_path)
+        file_handler.setLevel(level)
+        file_handler.setFormatter(self.formatter)
+        self.logger.addHandler(file_handler)
+
+    def debug(self, message: str) -> None:
+        """Log debug message."""
+        self.logger.debug(message)
+
+    def info(self, message: str) -> None:
+        """Log info message."""
+        self.logger.info(message)
+
+    def warning(self, message: str) -> None:
+        """Log warning message."""
+        self.logger.warning(message)
+
+    def error(self, message: str) -> None:
+        """Log error message."""
+        self.logger.error(message)
+
+    def critical(self, message: str) -> None:
+        """Log critical message."""
+        self.logger.critical(message)
+
+    def log_rank_zero(self, message: str, level: int = logging.INFO) -> None:
+        """
+        Log message only on the local rank 0 process.
+
+        The check uses the node-local rank, so a multi-node run logs the
+        message once per node.
+
+        Args:
+            message: Message to log
+            level: Logging level
+        """
+        if get_local_rank() == 0:
+            self.logger.log(level, message)
+
+    def log_exception(self, message: str, exception: Exception, raise_exception: bool = True) -> None:
+        """
+        Log exception message and optionally raise the exception.
+
+        Args:
+            message: Custom message to log
+            exception: Exception to log
+            raise_exception: Whether to raise the exception after logging
+
+        Raises:
+            Exception: ``exception`` itself, when ``raise_exception`` is True
+        """
+        error_message = f"{message}: {str(exception)}"
+        self.logger.error(error_message)
+
+        if raise_exception:
+            raise exception
+
+    def prepare_for_logs(self, output_dir: Optional[str] = None, log_level: str = "INFO") -> None:
+        """
+        Prepare existing logger to log to both console and file with specified
+        output directory and log level.
+
+        Args:
+            output_dir: Output directory for logs (if None, no file handler is added)
+            log_level: Logging level as string (unknown names fall back to INFO)
+        """
+        # Convert string log level to logging constant. getattr can also return
+        # non-level attributes (e.g. "BASIC_FORMAT"), so only accept integers.
+        level = getattr(logging, log_level.upper(), None)
+        if not isinstance(level, int):
+            level = logging.INFO
+        self.logger.setLevel(level)
+
+        # Update existing handlers' levels
+        for handler in self.logger.handlers:
+            handler.setLevel(level)
+
+        # Add file handler if saving metrics
+        if output_dir:
+            # The output directory is created even when a file handler already exists
+            log_file = Path(output_dir) / "training.log"
+            log_file.parent.mkdir(parents=True, exist_ok=True)
+
+            # Check if file handler already exists
+            file_handler_exists = any(isinstance(handler, logging.FileHandler) for handler in self.logger.handlers)
+
+            if not file_handler_exists:
+                self._attach_file_handler(log_file, level)
+
+
+# Global logger instance
+_logger: Optional[Logger] = None
+
+
+def get_logger(log_file: Optional[str] = None) -> Logger:
+    """
+    Get or create a logger instance.
+
+    The instance is created on the first call only; ``log_file`` is ignored
+    by later calls, which return the existing instance unchanged.
+
+    Args:
+        log_file: Path to log file (if None, log only to console)
+
+    Returns:
+        Logger instance
+    """
+    global _logger
+    if _logger is None:
+        _logger = Logger(log_file=log_file)
+    return _logger
diff --git a/QEfficient/finetune/experimental/core/utils/dist_utils.py b/QEfficient/finetune/experimental/core/utils/dist_utils.py
index d647b73a6..aed88862d 100644
--- a/QEfficient/finetune/experimental/core/utils/dist_utils.py
+++ b/QEfficient/finetune/experimental/core/utils/dist_utils.py
@@ -4,3 +4,38 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
+
+import torch.distributed as dist
+
+
+def is_dist_available_and_initialized() -> bool:
+    """Check if distributed training is available and initialized."""
+    return dist.is_available() and dist.is_initialized()
+
+
+def get_rank() -> int:
+    """Return the global rank of the current process, else 0."""
+    if not is_dist_available_and_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def get_local_rank() -> int:
+    """Return the local rank of the current process on its node, else 0."""
+    if not is_dist_available_and_initialized():
+        return 0
+    # LOCAL_RANK is exported by the launcher (e.g. torchrun). Without a fallback,
+    # get_node_local_rank() raises when the variable is missing.
+    return dist.get_node_local_rank(fallback_rank=0)
+
+
+def get_world_size() -> int:
+    """Get the total number of processes in distributed training."""
+    if not is_dist_available_and_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+def is_main_process() -> bool:
+    """Check if the current process is the main process (rank 0)."""
+    return get_rank() == 0
diff --git a/QEfficient/finetune/experimental/tests/test_logger.py b/QEfficient/finetune/experimental/tests/test_logger.py
new file mode 100644
index 000000000..0af0c8b51
--- /dev/null
+++ b/QEfficient/finetune/experimental/tests/test_logger.py
@@ -0,0 +1,233 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import logging
+from unittest.mock import patch
+
+import pytest
+
+from QEfficient.finetune.experimental.core.logger import Logger, get_logger
+
+
+class TestLogger:
+    def setup_method(self):
+        """Reset the global logger before each test method"""
+        import QEfficient.finetune.experimental.core.logger as logger_module
+
+        logger_module._logger = None
+
+    def test_init_console_only(self):
+        """Test logger initialization with console-only output"""
+        logger = Logger("test_logger")
+
+        # Check logger attributes
+        assert logger.logger.name == "test_logger"
+        assert logger.logger.level == logging.INFO
+
+        # Check handlers - should have console handler only
+        assert len(logger.logger.handlers) == 1  # Only console handler
+        assert isinstance(logger.logger.handlers[0], logging.StreamHandler)
+
+    def test_init_with_file(self, tmp_path):
+        """Test logger initialization with file output"""
+        log_file = tmp_path / "test.log"
+        logger = Logger("file_test_logger", str(log_file))
+
+        # Check handlers - should have both console and file handlers
+        assert len(logger.logger.handlers) == 2  # Console + file handler
+        assert isinstance(logger.logger.handlers[0], logging.StreamHandler)
+        assert isinstance(logger.logger.handlers[1], logging.FileHandler)
+
+        # Check file creation
+        assert log_file.exists()
+
+    def test_log_levels(self, caplog):
+        """Test all log levels work correctly"""
+        logger = Logger("level_test_logger", level=logging.DEBUG)
+
+        with caplog.at_level(logging.DEBUG):
+            logger.debug("Debug message")
+            logger.info("Info message")
+            logger.warning("Warning message")
+            logger.error("Error message")
+            logger.critical("Critical message")
+
+        # Check all messages were logged
+        assert "Debug message" in caplog.text
+        assert "Info message" in caplog.text
+        assert "Warning message" in caplog.text
+        assert "Error message" in caplog.text
+        assert "Critical message" in caplog.text
+
+    @patch("QEfficient.finetune.experimental.core.logger.get_local_rank")
+    def test_log_rank_zero_positive_case(self, mock_get_local_rank, caplog):
+        """Test rank zero logging functionality"""
+        mock_get_local_rank.return_value = 0
+        logger = Logger("rank_test_logger")
+
+        with caplog.at_level(logging.INFO):
+            logger.log_rank_zero("Rank zero message")
+
+        assert "Rank zero message" in caplog.text
+
+    @patch("QEfficient.finetune.experimental.core.logger.get_local_rank")
+    def test_log_rank_zero_negative_case(self, mock_get_local_rank, caplog):
+        """Test to verify that only rank-zero messages are logged"""
+        mock_get_local_rank.return_value = 1
+        logger = Logger("rank_test_logger")
+
+        with caplog.at_level(logging.INFO):
+            logger.log_rank_zero("Should not appear")
+
+        assert "Should not appear" not in caplog.text
+
+    def test_log_exception_raise(self, caplog):
+        """Test exception logging with raising"""
+        logger = Logger("exception_test_logger")
+
+        with pytest.raises(ValueError), caplog.at_level(logging.ERROR):
+            logger.log_exception("Custom error", ValueError("Test exception"), raise_exception=True)
+
+        # The actual logged message is "Custom error: Test exception"
+        # But the exception itself contains just "Test exception"
+        assert "Custom error: Test exception" in caplog.text
+
+    def test_log_exception_no_raise(self, caplog):
+        """Test exception logging without raising"""
+        logger = Logger("exception_test_logger")
+
+        with caplog.at_level(logging.ERROR):
+            logger.log_exception("Custom error", ValueError("Test exception"), raise_exception=False)
+
+        # Check that the formatted message was logged
+        assert "Custom error: Test exception" in caplog.text
+
+    def test_prepare_for_logs(self, tmp_path):
+        """Test preparing logger for training logs"""
+        output_dir = tmp_path / "output"
+        logger = Logger("prepare_test_logger")
+
+        # Prepare for logs
+        logger.prepare_for_logs(str(output_dir), log_level="DEBUG")
+
+        # Check file handler was added
+        file_handlers = [h for h in logger.logger.handlers if isinstance(h, logging.FileHandler)]
+        assert len(file_handlers) == 1
+
+        # Check file exists
+        log_file = output_dir / "training.log"
+        assert log_file.exists()
+
+        # Check log level was updated
+        assert logger.logger.level == logging.DEBUG
+
+    def test_prepare_for_logs_no_file_handler(self):
+        """Test preparing logger without saving to file"""
+        logger = Logger("prepare_test_logger")
+
+        # Prepare for logs without saving metrics
+        logger.prepare_for_logs(log_level="INFO")
+
+        # Check no file handler was added
+        file_handlers = [h for h in logger.logger.handlers if isinstance(h, logging.FileHandler)]
+        assert len(file_handlers) == 0
+
+    def test_prepare_for_logs_already_has_file_handler(self, tmp_path):
+        """Test preparing logger when file handler already exists"""
+        output_dir = tmp_path / "output"
+        logger = Logger("prepare_test_logger")
+
+        # Add a file handler manually first
+        log_file = output_dir / "manual.log"
+        log_file.parent.mkdir(parents=True, exist_ok=True)
+        file_handler = logging.FileHandler(str(log_file))
+        logger.logger.addHandler(file_handler)
+
+        # Prepare for logs again
+        logger.prepare_for_logs(str(output_dir), log_level="INFO")
+
+        # Should still have only one file handler
+        file_handlers = [h for h in logger.logger.handlers if isinstance(h, logging.FileHandler)]
+        assert len(file_handlers) == 1
+
+    def test_get_logger_singleton(self):
+        """Test that get_logger returns the same instance"""
+        logger1 = get_logger()
+        logger2 = get_logger()
+
+        assert logger1 is logger2
+
+    def test_get_logger_with_file(self, tmp_path):
+        """Test get_logger with file parameter"""
+        log_file = tmp_path / "get_logger_test.log"
+        logger = get_logger(str(log_file))
+
+        # Check that we have 2 handlers (console + file)
+        assert len(logger.logger.handlers) == 2  # Console + file
+        assert isinstance(logger.logger.handlers[1], logging.FileHandler)
+
+        # Check file exists
+        assert log_file.exists()
+
+
+class TestLoggerIntegration:
+    """Integration tests for logger functionality"""
+
+    def setup_method(self):
+        """Reset the global logger before each test method"""
+        import QEfficient.finetune.experimental.core.logger as logger_module
+
+        logger_module._logger = None
+
+    def test_complete_workflow(self, tmp_path, caplog):
+        """Test complete logger workflow"""
+        # Setup
+        log_file = tmp_path / "workflow.log"
+        logger = Logger("workflow_test", str(log_file), logging.DEBUG)
+
+        # Test all methods (capture at DEBUG while the messages are emitted, not afterwards)
+        with caplog.at_level(logging.DEBUG):
+            logger.debug("Debug test")
+            logger.info("Info test")
+            logger.warning("Warning test")
+            logger.error("Error test")
+            logger.critical("Critical test")
+
+            # Test exception handling
+            try:
+                raise ValueError("Test exception")
+            except ValueError as e:
+                logger.log_exception("Caught exception", e, raise_exception=False)
+
+            # Test rank zero logging
+            with patch("QEfficient.finetune.experimental.core.logger.get_local_rank") as mock_rank:
+                mock_rank.return_value = 0
+                logger.log_rank_zero("Rank zero test")
+
+        # Verify all messages were logged
+        assert "Debug test" in caplog.text
+        assert "Info test" in caplog.text
+        assert "Warning test" in caplog.text
+        assert "Error test" in caplog.text
+        assert "Critical test" in caplog.text
+        assert "Caught exception: Test exception" in caplog.text
+        assert "Rank zero test" in caplog.text
+
+        # Check file was written to
+        assert log_file.exists()
+        content = log_file.read_text()
+        assert "Debug test" in content
+        assert "Info test" in content
+        assert "Warning test" in content
+        assert "Error test" in content
+        assert "Critical test" in content
+        assert "Caught exception: Test exception" in content
+        assert "Rank zero test" in content
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])