Skip to content

Commit 056d6c5

Browse files
authored
gh-141999: Handle KeyboardInterrupt when sampling in the new tachyon profiler (#142000)
1 parent ea51e74 commit 056d6c5

File tree

3 files changed

+89
-41
lines changed

3 files changed

+89
-41
lines changed

Lib/profiling/sampling/sample.py

Lines changed: 47 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -57,50 +57,56 @@ def sample(self, collector, duration_sec=10):
5757
last_sample_time = start_time
5858
realtime_update_interval = 1.0 # Update every second
5959
last_realtime_update = start_time
60+
interrupted = False
6061

61-
while running_time < duration_sec:
62-
# Check if live collector wants to stop
63-
if hasattr(collector, 'running') and not collector.running:
64-
break
65-
66-
current_time = time.perf_counter()
67-
if next_time < current_time:
68-
try:
69-
stack_frames = self.unwinder.get_stack_trace()
70-
collector.collect(stack_frames)
71-
except ProcessLookupError:
72-
duration_sec = current_time - start_time
62+
try:
63+
while running_time < duration_sec:
64+
# Check if live collector wants to stop
65+
if hasattr(collector, 'running') and not collector.running:
7366
break
74-
except (RuntimeError, UnicodeDecodeError, MemoryError, OSError):
75-
collector.collect_failed_sample()
76-
errors += 1
77-
except Exception as e:
78-
if not self._is_process_running():
79-
break
80-
raise e from None
81-
82-
# Track actual sampling intervals for real-time stats
83-
if num_samples > 0:
84-
actual_interval = current_time - last_sample_time
85-
self.sample_intervals.append(
86-
1.0 / actual_interval
87-
) # Convert to Hz
88-
self.total_samples += 1
89-
90-
# Print real-time statistics if enabled
91-
if (
92-
self.realtime_stats
93-
and (current_time - last_realtime_update)
94-
>= realtime_update_interval
95-
):
96-
self._print_realtime_stats()
97-
last_realtime_update = current_time
98-
99-
last_sample_time = current_time
100-
num_samples += 1
101-
next_time += sample_interval_sec
10267

68+
current_time = time.perf_counter()
69+
if next_time < current_time:
70+
try:
71+
stack_frames = self.unwinder.get_stack_trace()
72+
collector.collect(stack_frames)
73+
except ProcessLookupError:
74+
duration_sec = current_time - start_time
75+
break
76+
except (RuntimeError, UnicodeDecodeError, MemoryError, OSError):
77+
collector.collect_failed_sample()
78+
errors += 1
79+
except Exception as e:
80+
if not self._is_process_running():
81+
break
82+
raise e from None
83+
84+
# Track actual sampling intervals for real-time stats
85+
if num_samples > 0:
86+
actual_interval = current_time - last_sample_time
87+
self.sample_intervals.append(
88+
1.0 / actual_interval
89+
) # Convert to Hz
90+
self.total_samples += 1
91+
92+
# Print real-time statistics if enabled
93+
if (
94+
self.realtime_stats
95+
and (current_time - last_realtime_update)
96+
>= realtime_update_interval
97+
):
98+
self._print_realtime_stats()
99+
last_realtime_update = current_time
100+
101+
last_sample_time = current_time
102+
num_samples += 1
103+
next_time += sample_interval_sec
104+
105+
running_time = time.perf_counter() - start_time
106+
except KeyboardInterrupt:
107+
interrupted = True
103108
running_time = time.perf_counter() - start_time
109+
print("Interrupted by user.")
104110

105111
# Clear real-time stats line if it was being displayed
106112
if self.realtime_stats and len(self.sample_intervals) > 0:
@@ -121,7 +127,7 @@ def sample(self, collector, duration_sec=10):
121127
collector.set_stats(self.sample_interval_usec, running_time, sample_rate, error_rate, mode=self.mode)
122128

123129
expected_samples = int(duration_sec / sample_interval_sec)
124-
if num_samples < expected_samples and not is_live_mode:
130+
if num_samples < expected_samples and not is_live_mode and not interrupted:
125131
print(
126132
f"Warning: missed {expected_samples - num_samples} samples "
127133
f"from the expected total of {expected_samples} "

Lib/test/test_profiling/test_sampling_profiler/test_profiler.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,46 @@ def test_sample_profiler_missed_samples_warning(self):
224224
self.assertIn("Warning: missed", result)
225225
self.assertIn("samples from the expected total", result)
226226

227+
def test_sample_profiler_keyboard_interrupt(self):
228+
mock_unwinder = mock.MagicMock()
229+
mock_unwinder.get_stack_trace.side_effect = [
230+
[
231+
(
232+
1,
233+
[
234+
mock.MagicMock(
235+
filename="test.py", lineno=10, funcname="test_func"
236+
)
237+
],
238+
)
239+
],
240+
KeyboardInterrupt(),
241+
]
242+
243+
with mock.patch(
244+
"_remote_debugging.RemoteUnwinder"
245+
) as mock_unwinder_class:
246+
mock_unwinder_class.return_value = mock_unwinder
247+
profiler = SampleProfiler(
248+
pid=12345, sample_interval_usec=10000, all_threads=False
249+
)
250+
mock_collector = mock.MagicMock()
251+
times = [0.0, 0.01, 0.02, 0.03, 0.04]
252+
with mock.patch("time.perf_counter", side_effect=times):
253+
with io.StringIO() as output:
254+
with mock.patch("sys.stdout", output):
255+
try:
256+
profiler.sample(mock_collector, duration_sec=1.0)
257+
except KeyboardInterrupt:
258+
self.fail(
259+
"KeyboardInterrupt was not handled by the profiler"
260+
)
261+
result = output.getvalue()
262+
self.assertIn("Interrupted by user.", result)
263+
self.assertIn("Captured", result)
264+
self.assertIn("samples", result)
265+
self.assertNotIn("Warning: missed", result)
266+
227267

228268
@force_not_colorized_test_class
229269
class TestPrintSampledStats(unittest.TestCase):
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Correctly allow :exc:`KeyboardInterrupt` to stop the process when using
2+
:mod:`!profiling.sampling`.

0 commit comments

Comments
 (0)