
Commit 0350c0f

[QEff. Finetune]: Added logger and its test cases. (#644)
- Added a logger that logs to both console and file. This code is similar to the existing QEff finetuning logger code.
- Added dist_utils, which serves as utility code for distributed training.
- Added logger test cases for sanity checks.

---------

Signed-off-by: meetkuma <meetkuma@qti.qualcomm.com>
1 parent ea26341 commit 0350c0f

File tree

3 files changed: +436 -0 lines changed

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------


import logging
import sys
from pathlib import Path
from typing import Optional

from transformers.utils.logging import get_logger as hf_get_logger

from QEfficient.finetune.experimental.core.utils.dist_utils import get_local_rank

# -----------------------------------------------------------------------------
# Logger usage:
# Initialize logger:
#     logger = Logger("my_logger", log_file="logs/output.log", level=logging.DEBUG)
# Log messages:
#     logger.info("This is an info message")
#     logger.error("This is an error message")
#     logger.log_rank_zero("This message is logged only on rank 0")
#     logger.log_exception("An error occurred", exception, raise_exception=False)
# Attach file handler later if needed:
#     logger.prepare_for_logs(output_dir="logs", log_level="DEBUG")
# -----------------------------------------------------------------------------


class Logger:
    """Custom logger with console and file logging capabilities."""

    def __init__(
        self,
        name: str = "transformers",  # We use "transformers" as the default to align with HF logs
        log_file: Optional[str] = None,
        level: int = logging.INFO,
    ):
        """
        Initialize the logger.

        Args:
            name: Logger name
            log_file: Path to log file (if None, log only to console)
            level: Logging level
        """
        self.logger = hf_get_logger(name)
        self.logger.setLevel(level)

        # Clear any existing handlers
        self.logger.handlers.clear()

        # Create formatter
        self.formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")

        # Console handler
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(level)
        console_handler.setFormatter(self.formatter)
        self.logger.addHandler(console_handler)

        # File handler (if log_file is provided)
        if log_file:
            # Create directory if it doesn't exist
            log_path = Path(log_file)
            log_path.parent.mkdir(parents=True, exist_ok=True)

            file_handler = logging.FileHandler(log_file)
            file_handler.setLevel(level)
            file_handler.setFormatter(self.formatter)
            self.logger.addHandler(file_handler)

    def debug(self, message: str) -> None:
        """Log debug message."""
        self.logger.debug(message)

    def info(self, message: str) -> None:
        """Log info message."""
        self.logger.info(message)

    def warning(self, message: str) -> None:
        """Log warning message."""
        self.logger.warning(message)

    def error(self, message: str) -> None:
        """Log error message."""
        self.logger.error(message)

    def critical(self, message: str) -> None:
        """Log critical message."""
        self.logger.critical(message)

    def log_rank_zero(self, message: str, level: int = logging.INFO) -> None:
        """
        Log message only on the rank 0 process.

        Args:
            message: Message to log
            level: Logging level
        """
        if get_local_rank() == 0:
            self.logger.log(level, message)

    def log_exception(self, message: str, exception: Exception, raise_exception: bool = True) -> None:
        """
        Log an exception message and optionally re-raise the exception.

        Args:
            message: Custom message to log
            exception: Exception to log
            raise_exception: Whether to raise the exception after logging
        """
        error_message = f"{message}: {str(exception)}"
        self.logger.error(error_message)

        if raise_exception:
            raise exception

    def prepare_for_logs(self, output_dir: Optional[str] = None, log_level: str = "INFO") -> None:
        """
        Prepare the existing logger to log to both console and file with the
        specified output directory and log level.

        Args:
            output_dir: Output directory for logs
            log_level: Logging level as string
        """
        # Convert string log level to logging constant
        level = getattr(logging, log_level.upper(), logging.INFO)
        self.logger.setLevel(level)

        # Update existing handlers' levels
        for handler in self.logger.handlers:
            handler.setLevel(level)

        # Add a file handler if an output directory is given
        if output_dir:
            log_file = Path(output_dir) / "training.log"
            log_file.parent.mkdir(parents=True, exist_ok=True)

            # Check if a file handler already exists
            file_handler_exists = any(isinstance(handler, logging.FileHandler) for handler in self.logger.handlers)

            if not file_handler_exists:
                file_handler = logging.FileHandler(log_file)
                file_handler.setLevel(level)
                file_handler.setFormatter(self.formatter)
                self.logger.addHandler(file_handler)


# Global logger instance
_logger: Optional[Logger] = None


def get_logger(log_file: Optional[str] = None) -> Logger:
    """
    Get or create a logger instance.

    Args:
        log_file: Path to log file (if None, log only to console)

    Returns:
        Logger instance
    """
    global _logger
    if _logger is None:
        _logger = Logger(log_file=log_file)
    return _logger
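
A brief usage sketch tying the pieces together. The module name `logging_utils` in the import is an assumption for illustration, since this excerpt does not show the new file's name; the paths and messages are likewise illustrative, not part of the commit.

import logging

# NOTE: module name "logging_utils" is assumed; the diff excerpt omits the filename.
from QEfficient.finetune.experimental.core.utils.logging_utils import get_logger

# First call creates a console-only singleton; later calls return the same instance.
logger = get_logger()
logger.info("Starting finetuning run")

# Once the output directory is known, attach a file handler and raise verbosity.
logger.prepare_for_logs(output_dir="logs", log_level="DEBUG")
logger.debug("Also written to logs/training.log from now on")

# In multi-process training, emit a message once instead of once per rank.
logger.log_rank_zero("Epoch 0 complete", level=logging.INFO)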

QEfficient/finetune/experimental/core/utils/dist_utils.py

Lines changed: 33 additions & 0 deletions
@@ -4,3 +4,36 @@
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import torch.distributed as dist


def is_dist_available_and_initialized() -> bool:
    """Check if distributed training is available and initialized."""
    return dist.is_available() and dist.is_initialized()


def get_rank() -> int:
    """Return the global rank of the current process, else 0."""
    if not is_dist_available_and_initialized():
        return 0
    return dist.get_rank()


def get_local_rank() -> int:
    """Return the local rank of the current process on its node, else 0."""
    if not is_dist_available_and_initialized():
        return 0
    return dist.get_node_local_rank()


def get_world_size() -> int:
    """Get the total number of processes in distributed training."""
    if not is_dist_available_and_initialized():
        return 1
    return dist.get_world_size()


def is_main_process() -> bool:
    """Check if the current process is the main process (rank 0)."""
    return get_rank() == 0
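
A minimal sketch of how these helpers behave, both under a launcher such as torchrun and in a plain single-process run; the print statements are illustrative only.

from QEfficient.finetune.experimental.core.utils.dist_utils import (
    get_local_rank,
    get_rank,
    get_world_size,
    is_main_process,
)

# Under torchrun these reflect the initialized process group; in a plain
# single-process run they fall back to rank 0 / world size 1 without error.
print(f"rank={get_rank()} local_rank={get_local_rank()} world_size={get_world_size()}")

# Gate side effects (checkpoint writes, metric files) to a single process.
if is_main_process():
    print("Only the main process reaches this branch.")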
