-
Notifications
You must be signed in to change notification settings - Fork 97
/
cmd_log.py
211 lines (181 loc) · 5.96 KB
/
cmd_log.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import argparse
import logging
import re
import sys
import threading
import time
from queue import Queue
from typing import List, Optional, TextIO, Tuple
from torchx import specs
from torchx.cli.cmd_base import SubCommand
from torchx.cli.colors import ENDC, GREEN
from torchx.runner import get_runner, Runner
from torchx.schedulers.api import Stream
from torchx.specs.api import is_started
from torchx.specs.builders import make_app_handle
from torchx.util.types import none_throws
# Module logger used by every helper in this file.
logger: logging.Logger = logging.getLogger(__name__)

# Human-readable shape of a job identifier; used as the CLI metavar and in
# validation error messages.
ID_FORMAT: str = "SCHEDULER://[SESSION_NAME]/APP_ID/[ROLE_NAME/[REPLICA_IDS,...]]"
def validate(job_identifier: str) -> None:
    """
    Exit the process unless ``job_identifier`` has the shape ``ID_FORMAT``.

    Accepted shape: ``SCHEDULER://[SESSION_NAME]/APP_ID[/ROLE_NAME[/REPLICA_IDS]]``
    where ``REPLICA_IDS`` is a comma-separated list of digit runs (a single
    trailing comma is tolerated, consecutive commas are not).

    Raises:
        SystemExit: with code 1, after logging an error, when the identifier
            does not match.
    """
    # The replica-id list is matched with ``\d+(?:,\d+)*,?`` instead of the
    # equivalent ``(\d+,?)+``. Because the comma is optional in the latter, a run
    # of digits can be split into groups in exponentially many ways, so a long
    # non-matching input (many digits followed by a letter) backtracks
    # catastrophically. Both patterns accept exactly the same strings.
    if not re.match(
        r"^\w+://[^/]*/[^/]+(/[^/]+(/\d+(?:,\d+)*,?)?)?$", job_identifier
    ):
        logger.error(
            f"{job_identifier} is not of the form {ID_FORMAT}",
        )
        sys.exit(1)
def _prefix_line(prefix: str, line: str) -> str:
"""
_prefix_line ensure the prefix is still present even when dealing with return characters
"""
if "\r" in line:
line = line.replace("\r", f"\r{prefix}")
if "\n" in line[:-1]:
line = line[:-1].replace("\n", f"\n{prefix}") + line[-1:]
if not line.startswith("\r"):
line = f"{prefix}{line}"
return line
def print_log_lines(
    file: TextIO,
    runner: Runner,
    app_handle: str,
    role_name: str,
    replica_id: int,
    regex: str,
    should_tail: bool,
    exceptions: "Queue[Exception]",
    streams: Optional[Stream],
) -> None:
    """
    Write one replica's log lines to ``file``, tagging each with the replica.

    Used as a ``threading.Thread`` target by ``get_logs``. Any exception raised
    while fetching or writing the logs is pushed onto ``exceptions``, so the
    spawning thread can surface it, and is then re-raised in this thread.

    Args:
        file: destination stream; each line is flushed as soon as it is written.
        runner: runner whose ``log_lines`` provides the lines.
        app_handle: handle of the app whose logs are read.
        role_name: role of the replica to read.
        replica_id: index of the replica within ``role_name``.
        regex: filter forwarded to ``runner.log_lines``.
        should_tail: forwarded to ``runner.log_lines``.
        exceptions: queue that collects failures for the caller.
        streams: forwarded to ``runner.log_lines``.
    """
    try:
        # The prefix depends only on the replica, so build it once, not per line.
        prefix = f"{GREEN}{role_name}/{replica_id}{ENDC} "
        for line in runner.log_lines(
            app_handle,
            role_name,
            replica_id,
            regex,
            should_tail=should_tail,
            streams=streams,
        ):
            print(_prefix_line(prefix, line), file=file, end="", flush=True)
    except Exception as e:
        exceptions.put(e)
        raise
def _wait_until_started(runner: Runner, app_handle: str) -> None:
    """Block, polling ``runner.status`` once per second, until the app has started."""
    display_waiting = True
    while True:
        status = runner.status(app_handle)
        if status and is_started(status.state):
            return
        if display_waiting:
            # Announce the wait once rather than on every poll.
            logger.info("Waiting for app state response before fetching logs...")
            display_waiting = False
        time.sleep(1)


def _raise_thread_exceptions(exceptions: "Queue[Exception]") -> None:
    """Log every queued exception except the first, then re-raise the first one."""
    errors: List[Exception] = []
    while not exceptions.empty():
        errors.append(exceptions.get())
    if errors:
        for error in errors[1:]:
            logger.error(error)
        raise errors[0]


def get_logs(
    file: TextIO,
    identifier: str,
    regex: Optional[str],
    should_tail: bool = False,
    runner: Optional[Runner] = None,
    streams: Optional[Stream] = None,
) -> None:
    """
    Print the logs of the replicas selected by ``identifier`` to ``file``.

    ``identifier`` has the shape ``ID_FORMAT``. Behavior depends on how much of
    it is given:

    * Explicit replica ids: only those replicas of the role are read.
    * Otherwise: this polls once per second until the app reports a started
      state. It then reads every replica of ``ROLE_NAME``, or of every role
      when no role is given.

    Each replica is read on its own daemon thread via ``print_log_lines``. This
    function returns once all of those threads have finished.

    Args:
        file: stream the prefixed log lines are written to.
        identifier: job identifier, see ``ID_FORMAT``.
        regex: optional filter forwarded to ``Runner.log_lines``.
        should_tail: forwarded to ``Runner.log_lines``.
        runner: runner to use; ``get_runner()`` is called when omitted.
        streams: forwarded to ``Runner.log_lines``.

    Raises:
        SystemExit: if ``identifier`` is malformed or no replica matches the role.
        Exception: the first exception raised by a log-reading thread; any
            further ones are logged.
    """
    validate(identifier)
    scheduler_backend, _, path_str = identifier.partition("://")

    # path is of the form ["session_name", "app_id", "role_name", "0,1,..."];
    # the session may be empty and the trailing segments may be absent.
    path = path_str.split("/")
    session_name = path[0] or "default"
    app_id = path[1]
    role_name = path[2] if len(path) > 2 else None

    if not runner:
        runner = get_runner()

    app_handle = make_app_handle(scheduler_backend, session_name, app_id)

    if len(path) == 4:
        # Explicit replica ids; empty entries (from a trailing comma) are skipped.
        replica_ids = [(path[2], int(rid)) for rid in path[3].split(",") if rid]
    else:
        _wait_until_started(runner, app_handle)
        app = none_throws(runner.describe(app_handle))
        # print all replicas for the role (or for every role when none is given)
        replica_ids = find_role_replicas(app, role_name)

        if not replica_ids:
            # Suggest identifiers that actually parse: the session segment must
            # be present, otherwise app_id would be read as the session name.
            valid_ids = "\n".join(
                f" {idx}: {scheduler_backend}://{session_name}/{app_id}/{role.name}"
                for idx, role in enumerate(app.roles)
            )
            logger.error(
                f"No role [{role_name}] found for app: {app.name}."
                f" Did you mean one of the following:\n{valid_ids}",
            )
            sys.exit(1)

    threads: List[threading.Thread] = []
    exceptions: "Queue[Exception]" = Queue()
    for replica_role, replica_id in replica_ids:
        thread = threading.Thread(
            target=print_log_lines,
            args=(
                file,
                runner,
                app_handle,
                replica_role,
                replica_id,
                regex,
                should_tail,
                exceptions,
                streams,
            ),
            daemon=True,
        )
        thread.start()
        threads.append(thread)

    for thread in threads:
        thread.join()

    _raise_thread_exceptions(exceptions)
def find_role_replicas(
    app: specs.AppDef, role_name: Optional[str]
) -> List[Tuple[str, int]]:
    """
    List ``(role name, replica index)`` pairs for the roles of ``app``.

    Only roles whose name equals ``role_name`` are included; when ``role_name``
    is ``None``, every role is included. Indices run from ``0`` to
    ``role.num_replicas - 1``, in the order the roles appear in ``app.roles``.
    """
    return [
        (role.name, replica_idx)
        for role in app.roles
        if role_name is None or role_name == role.name
        for replica_idx in range(role.num_replicas)
    ]
class CmdLog(SubCommand):
    """CLI subcommand that prints the logs of the replicas selected by an identifier."""

    def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
        """Register the ``--regex``, ``--tail`` and ``--streams`` options and the identifier."""
        stream_choices = [stream.value for stream in Stream]
        subparser.add_argument("--regex", type=str, help="regex filter")
        subparser.add_argument("-t", "--tail", action="store_true", help="Tail logs")
        subparser.add_argument(
            "--streams",
            type=Stream,
            choices=stream_choices,
            default=None,
            help="IO streams to use. Default is scheduler specific.",
        )
        subparser.add_argument(
            "identifier",
            type=str,
            metavar=ID_FORMAT,
            help="identifiers for the roles and replicas to log",
        )

    def run(self, args: argparse.Namespace) -> None:
        """Print the logs selected by the parsed ``args`` to stdout."""
        get_logs(
            file=sys.stdout,
            identifier=args.identifier,
            regex=args.regex,
            should_tail=args.tail,
            streams=args.streams,
        )