|
| 1 | +# ----------------------------------------------------------------------------- |
| 2 | +# |
| 3 | +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. |
| 4 | +# SPDX-License-Identifier: BSD-3-Clause |
| 5 | +# |
| 6 | +# ----------------------------------------------------------------------------- |
| 7 | + |
| 8 | + |
| 9 | +import logging |
| 10 | +import sys |
| 11 | +from pathlib import Path |
| 12 | +from typing import Optional |
| 13 | + |
| 14 | +from transformers.utils.logging import get_logger as hf_get_logger |
| 15 | + |
| 16 | +from QEfficient.finetune.experimental.core.utils.dist_utils import get_local_rank |
| 17 | + |
| 18 | +# ----------------------------------------------------------------------------- |
| 19 | +# Logger usage: |
| 20 | +# Initialize logger: |
| 21 | +# logger = Logger("my_logger", log_file="logs/output.log", level=logging.DEBUG) |
| 22 | +# Log messages: |
| 23 | +# logger.info("This is an info message") |
| 24 | +# logger.error("This is an error message") |
| 25 | +# logger.log_rank_zero("This message is logged only on rank 0") |
| 26 | +# logger.log_exception("An error occurred", exception, raise_exception=False) |
| 27 | +# Attach file handler later if needed: |
| 28 | +# logger.prepare_for_logs(output_dir="logs", log_level="DEBUG") |
| 29 | +# ----------------------------------------------------------------------------- |
| 30 | + |
| 31 | + |
| 32 | +class Logger: |
| 33 | + """Custom logger with console and file logging capabilities.""" |
| 34 | + |
| 35 | + def __init__( |
| 36 | + self, |
| 37 | + name: str = "transformers", # We are using "transformers" as default to align with HF logs |
| 38 | + log_file: Optional[str] = None, |
| 39 | + level: int = logging.INFO, |
| 40 | + ): |
| 41 | + """ |
| 42 | + Initialize the logger. |
| 43 | +
|
| 44 | + Args: |
| 45 | + name: Logger name |
| 46 | + log_file: Path to log file (if None, log only to console) |
| 47 | + level: Logging level |
| 48 | + """ |
| 49 | + self.logger = hf_get_logger(name) |
| 50 | + self.logger.setLevel(level) |
| 51 | + |
| 52 | + # Clear any existing handlers |
| 53 | + self.logger.handlers.clear() |
| 54 | + |
| 55 | + # Create formatter |
| 56 | + self.formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") |
| 57 | + |
| 58 | + # Console handler |
| 59 | + console_handler = logging.StreamHandler(sys.stdout) |
| 60 | + console_handler.setLevel(level) |
| 61 | + console_handler.setFormatter(self.formatter) |
| 62 | + self.logger.addHandler(console_handler) |
| 63 | + |
| 64 | + # File handler (if log_file is provided) |
| 65 | + if log_file: |
| 66 | + # Create directory if it doesn't exist |
| 67 | + log_path = Path(log_file) |
| 68 | + log_path.parent.mkdir(parents=True, exist_ok=True) |
| 69 | + |
| 70 | + file_handler = logging.FileHandler(log_file) |
| 71 | + file_handler.setLevel(level) |
| 72 | + file_handler.setFormatter(self.formatter) |
| 73 | + self.logger.addHandler(file_handler) |
| 74 | + |
| 75 | + def debug(self, message: str) -> None: |
| 76 | + """Log debug message.""" |
| 77 | + self.logger.debug(message) |
| 78 | + |
| 79 | + def info(self, message: str) -> None: |
| 80 | + """Log info message.""" |
| 81 | + self.logger.info(message) |
| 82 | + |
| 83 | + def warning(self, message: str) -> None: |
| 84 | + """Log warning message.""" |
| 85 | + self.logger.warning(message) |
| 86 | + |
| 87 | + def error(self, message: str) -> None: |
| 88 | + """Log error message.""" |
| 89 | + self.logger.error(message) |
| 90 | + |
| 91 | + def critical(self, message: str) -> None: |
| 92 | + """Log critical message.""" |
| 93 | + self.logger.critical(message) |
| 94 | + |
| 95 | + def log_rank_zero(self, message: str, level: int = logging.INFO) -> None: |
| 96 | + """ |
| 97 | + Log message only on rank 0 process. |
| 98 | +
|
| 99 | + Args: |
| 100 | + message: Message to log |
| 101 | + level: Logging level |
| 102 | + """ |
| 103 | + if get_local_rank() == 0: |
| 104 | + self.logger.log(level, message) |
| 105 | + |
| 106 | + def log_exception(self, message: str, exception: Exception, raise_exception: bool = True) -> None: |
| 107 | + """ |
| 108 | + Log exception message and optionally raise the exception. |
| 109 | +
|
| 110 | + Args: |
| 111 | + message: Custom message to log |
| 112 | + exception: Exception to log |
| 113 | + raise_exception: Whether to raise the exception after logging |
| 114 | + """ |
| 115 | + error_message = f"{message}: {str(exception)}" |
| 116 | + self.logger.error(error_message) |
| 117 | + |
| 118 | + if raise_exception: |
| 119 | + raise exception |
| 120 | + |
| 121 | + def prepare_for_logs(self, output_dir: Optional[str] = None, log_level: str = "INFO") -> None: |
| 122 | + """ |
| 123 | + Prepare existing logger to log to both console and file with specified |
| 124 | + output directory and log level. |
| 125 | +
|
| 126 | + Args: |
| 127 | + output_dir: Output directory for logs |
| 128 | + log_level: Logging level as string |
| 129 | + """ |
| 130 | + # Convert string log level to logging constant |
| 131 | + level = getattr(logging, log_level.upper(), logging.INFO) |
| 132 | + self.logger.setLevel(level) |
| 133 | + |
| 134 | + # Update existing handlers' levels |
| 135 | + for handler in self.logger.handlers: |
| 136 | + handler.setLevel(level) |
| 137 | + |
| 138 | + # Add file handler if saving metrics |
| 139 | + if output_dir: |
| 140 | + log_file = Path(output_dir) / "training.log" |
| 141 | + log_file.parent.mkdir(parents=True, exist_ok=True) |
| 142 | + |
| 143 | + # Check if file handler already exists |
| 144 | + file_handler_exists = any(isinstance(handler, logging.FileHandler) for handler in self.logger.handlers) |
| 145 | + |
| 146 | + if not file_handler_exists: |
| 147 | + file_handler = logging.FileHandler(log_file) |
| 148 | + file_handler.setLevel(level) |
| 149 | + file_handler.setFormatter(self.formatter) |
| 150 | + self.logger.addHandler(file_handler) |
| 151 | + |
| 152 | + |
| 153 | +# Global logger instance |
| 154 | +_logger: Optional[Logger] = None |
| 155 | + |
| 156 | + |
| 157 | +def get_logger(log_file: Optional[str] = None) -> Logger: |
| 158 | + """ |
| 159 | + Get or create a logger instance. |
| 160 | +
|
| 161 | + Args: |
| 162 | + log_file: Path to log file (if None, log only to console) |
| 163 | +
|
| 164 | + Returns: |
| 165 | + Logger instance |
| 166 | + """ |
| 167 | + global _logger |
| 168 | + if _logger is None: |
| 169 | + _logger = Logger(log_file=log_file) |
| 170 | + return _logger |
0 commit comments