Skip to content

Commit

Permalink
probe: block signals around USB transfers
Browse files Browse the repository at this point in the history
  • Loading branch information
flit committed May 20, 2023
1 parent 1a61989 commit 6a6a443
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 39 deletions.
9 changes: 6 additions & 3 deletions pyocd/probe/picoprobe.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# pyOCD debugger
# Copyright (c) 2021 Federico Zuccardi Merli
# Copyright (c) 2021 Chris Reed
# Copyright (c) 2021-2022 Chris Reed
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -32,6 +32,7 @@
from ..core.options import OptionInfo
from ..core.plugin import Plugin
from ..utility.mask import parity32_high
from ..utility.signals import ThreadSignalBlocker

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -135,7 +136,8 @@ def flush_queue(self):
self._queue[:self.PKT_HDR_LEN] = array(
'B', self._qulen.to_bytes(4, 'little'))
try:
self._wr_ep.write(self._queue)
with ThreadSignalBlocker():
self._wr_ep.write(self._queue)
except Exception:
# Anything from the USB layer assumes probe is no longer connected
raise exceptions.ProbeDisconnected(
Expand All @@ -150,7 +152,8 @@ def get_bits(self):
try:
# A single read is enough, as the 8 kB buffer in the Picoprobe can
# contain about 454 ACKs+Register reads, and I never queue more than 256
received = self._rd_ep.read(self._bits)
with ThreadSignalBlocker():
received = self._rd_ep.read(self._bits)
except Exception:
# Anything from the USB layer assumes probe is no longer connected
raise exceptions.ProbeDisconnected(
Expand Down
9 changes: 8 additions & 1 deletion pyocd/probe/pydapaccess/interface/hidapi_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from ..dap_access_api import DAPAccessIntf
from ....utility.compatibility import to_str_safe
from ....utility.timeout import Timeout
from ....utility.signals import ThreadSignalBlocker

LOG = logging.getLogger(__name__)
TRACE = LOG.getChild("trace")
Expand Down Expand Up @@ -95,6 +96,9 @@ def open(self):
self.thread.start()

def rx_task(self):
# Block all signals on this thread.
ThreadSignalBlocker()

try:
while not self.closed_event.is_set():
self.read_sem.acquire()
Expand Down Expand Up @@ -158,12 +162,15 @@ def write(self, data):
data.extend([0] * (self.packet_size - len(data)))
if not _IS_WINDOWS:
self.read_sem.release()
self.device.write([0] + data)
with ThreadSignalBlocker():
self.device.write([0] + data)

def read(self):
"""@brief Read data on the IN endpoint associated to the HID interface"""
# Windows doesn't use the read thread, so read directly.
if _IS_WINDOWS:
# Note that we don't use ThreadSignalBlocker here because signals cannot be blocked
# at the thread level on Windows (at least via the python API ThreadSignalBlocker uses).
read_data = self.device.read(self.packet_size)

if TRACE.isEnabledFor(logging.DEBUG):
Expand Down
12 changes: 9 additions & 3 deletions pyocd/probe/pydapaccess/interface/pyusb_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
)
from ..dap_access_api import DAPAccessIntf
from ....utility.timeout import Timeout
from ....utility.signals import ThreadSignalBlocker

LOG = logging.getLogger(__name__)
TRACE = LOG.getChild("trace")
Expand Down Expand Up @@ -144,6 +145,9 @@ def start_rx(self):
self.thread.start()

def rx_task(self):
# Block all signals on this thread.
ThreadSignalBlocker()

try:
while not self.closed:
self.read_sem.acquire()
Expand Down Expand Up @@ -204,11 +208,13 @@ def write(self, data):
bmRequest = 0x09 #Set_REPORT (HID class-specific request for transferring data over EP0)
wValue = 0x200 #Issuing an OUT report
wIndex = self.intf_number #mBed Board interface number for HID
self.dev.ctrl_transfer(bmRequestType, bmRequest, wValue, wIndex, data,
timeout=self.DEFAULT_USB_TIMEOUT_MS)
with ThreadSignalBlocker():
self.dev.ctrl_transfer(bmRequestType, bmRequest, wValue, wIndex, data,
timeout=self.DEFAULT_USB_TIMEOUT_MS)
return

self.ep_out.write(data, timeout=self.DEFAULT_USB_TIMEOUT_MS)
with ThreadSignalBlocker():
self.ep_out.write(data, timeout=self.DEFAULT_USB_TIMEOUT_MS)

def read(self):
"""@brief Read data on the IN endpoint associated to the HID interface"""
Expand Down
10 changes: 9 additions & 1 deletion pyocd/probe/pydapaccess/interface/pyusb_v2_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from ..dap_access_api import DAPAccessIntf
from ... import common
from ....utility.timeout import Timeout
from ....utility.signals import ThreadSignalBlocker

LOG = logging.getLogger(__name__)
TRACE = LOG.getChild("trace")
Expand Down Expand Up @@ -162,6 +163,9 @@ def stop_swo(self):
self.is_swo_running = False

def rx_task(self):
# Block all signals on this thread to prevent broken USB transfers.
ThreadSignalBlocker()

try:
while not self.rx_stop_event.is_set():
self.read_sem.acquire()
Expand All @@ -177,6 +181,9 @@ def rx_task(self):
self.rcv_data.append(None)

def swo_rx_task(self):
# Block all signals on this thread to prevent broken USB transfers.
ThreadSignalBlocker()

try:
while not self.swo_stop_event.is_set():
try:
Expand Down Expand Up @@ -218,7 +225,8 @@ def write(self, data):

self.read_sem.release()

self.ep_out.write(data, timeout=self.DEFAULT_USB_TIMEOUT_MS)
with ThreadSignalBlocker():
self.ep_out.write(data, timeout=self.DEFAULT_USB_TIMEOUT_MS)

def read(self):
"""@brief Read data on the IN endpoint."""
Expand Down
70 changes: 39 additions & 31 deletions pyocd/probe/stlink/usb.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# pyOCD debugger
# Copyright (c) 2018-2019 Arm Limited
# Copyright (c) 2021 Chris Reed
# Copyright (c) 2021-2022 Chris Reed
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -26,6 +26,7 @@

from ...core import exceptions
from .. import common
from ...utility.signals import ThreadSignalBlocker

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -235,41 +236,48 @@ def transfer(self, cmd, writeData=None, readSize=None, timeout=1000):
paddedCmd = bytearray(self.CMD_SIZE)
paddedCmd[0:len(cmd)] = cmd

try:
# Command phase.
if TRACE.isEnabledFor(logging.DEBUG):
TRACE.debug(" USB CMD> (%d) %s", len(paddedCmd), ' '.join([f'{i:02x}' for i in paddedCmd]))
count = self._ep_out.write(paddedCmd, timeout)
assert count == len(paddedCmd)

# Optional data out phase.
if writeData is not None:
if TRACE.isEnabledFor(logging.DEBUG):
TRACE.debug(" USB OUT> (%d) %s", len(writeData), ' '.join([f'{i:02x}' for i in writeData]))
count = self._ep_out.write(writeData, timeout)
assert count == len(writeData)

# Optional data in phase.
if readSize is not None:
if TRACE.isEnabledFor(logging.DEBUG):
TRACE.debug(" USB IN < (req %d bytes)", readSize)
data = self._read(readSize)
# Block signals while we transfer.
with ThreadSignalBlocker():
try:
# Command phase.
if TRACE.isEnabledFor(logging.DEBUG):
TRACE.debug(" USB IN < (%d) %s", len(data), ' '.join([f'{i:02x}' for i in data]))

# Verify we got all requested data.
if len(data) < readSize:
raise exceptions.ProbeError("received incomplete command response from STLink "
f"(got {len(data)}, expected {readSize}")

return data
except usb.core.USBError as exc:
raise exceptions.ProbeError("USB Error: %s" % exc) from exc
TRACE.debug(" USB CMD> (%d) %s", len(paddedCmd), ' '.join([f'{i:02x}' for i in paddedCmd]))
count = self._ep_out.write(paddedCmd, timeout)
assert count == len(paddedCmd)

# Optional data out phase.
if writeData is not None:
if TRACE.isEnabledFor(logging.DEBUG):
TRACE.debug(" USB OUT> (%d) %s", len(writeData), ' '.join([f'{i:02x}' for i in writeData]))
count = self._ep_out.write(writeData, timeout)
assert count == len(writeData)

# Optional data in phase.
if readSize is not None:
if TRACE.isEnabledFor(logging.DEBUG):
TRACE.debug(" USB IN < (req %d bytes)", readSize)
data = self._read(readSize)
if TRACE.isEnabledFor(logging.DEBUG):
TRACE.debug(" USB IN < (%d) %s", len(data), ' '.join([f'{i:02x}' for i in data]))

# Verify we got all requested data.
if len(data) < readSize:
raise exceptions.ProbeError("received incomplete command response from STLink "
f"(got {len(data)}, expected {readSize}")

return data
except usb.core.USBError as exc:
raise exceptions.ProbeError("USB Error: %s" % exc) from exc
return None

def read_swv(self, size, timeout=1000):
assert self._ep_swv
return bytearray(self._ep_swv.read(size, timeout))
# Block signals while we transfer.
with ThreadSignalBlocker():
try:
return bytearray(self._ep_swv.read(size, timeout))
except usb.core.USBError as exc:
raise exceptions.ProbeError("USB Error: %s" % exc) from exc

def __repr__(self):
return "<{} @ {:#x} vid={:#06x} pid={:#06x} sn={} version={}>".format(
Expand Down
60 changes: 60 additions & 0 deletions pyocd/utility/signals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# pyOCD debugger
# Copyright (c) 2022 Chris Reed
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import signal
from typing import (Any, Iterable, TYPE_CHECKING)

if TYPE_CHECKING:
from types import TracebackType

class ThreadSignalBlocker:
"""@brief Context manager class to block all signals on the current thread.
Can be used either as a context manager or simply by instantiating the class. All signals are blocked
on the current thread when the class is instantiated (not when entering a context). If used as a context
manager, those signals blocked in the constructor will be restored on context exit.
The ThreadSignalBlocked object is returned as the value for the _with_ statement when entering
a context. Usually it is not needed, but allows for calling restore() to restore blocked signals early
if necessary.
This class can be used on Windows too, but does nothing.
"""

def __init__(self) -> None:
if hasattr(signal, 'pthread_sigmask'):
self._saved_mask = signal.pthread_sigmask(signal.SIG_BLOCK, signal.valid_signals())
else:
self._saved_mask = set()

@property
def saved_signal_mask(self) -> Iterable[int]:
return self._saved_mask

def __enter__(self) -> "ThreadSignalBlocker":
return self

def __exit__(self, exc_type: type, value: Any, traceback: "TracebackType") -> None:
self.restore()

def restore(self) -> None:
"""@brief Restore signals that were blocked in the constructor."""
if hasattr(signal, 'pthread_sigmask'):
signal.pthread_sigmask(signal.SIG_SETMASK, self._saved_mask)

# Prevent restoring a second time on context exit, in case the caller has modified
# the signal mask after we return.
self._saved_mask = set()

0 comments on commit 6a6a443

Please sign in to comment.