diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index c9bbe2ac014351..d2ee49dd88f8cf 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -58,6 +58,7 @@ def __init__(self, selector=None): def _make_socket_transport(self, sock, protocol, waiter=None, *, extra=None, server=None): + self._ensure_fd_no_transport(sock) return _SelectorSocketTransport(self, sock, protocol, waiter, extra, server) @@ -68,6 +69,7 @@ def _make_ssl_transport( ssl_handshake_timeout=constants.SSL_HANDSHAKE_TIMEOUT, ssl_shutdown_timeout=constants.SSL_SHUTDOWN_TIMEOUT, ): + self._ensure_fd_no_transport(rawsock) ssl_protocol = sslproto.SSLProtocol( self, protocol, sslcontext, waiter, server_side, server_hostname, @@ -80,6 +82,7 @@ def _make_ssl_transport( def _make_datagram_transport(self, sock, protocol, address=None, waiter=None, extra=None): + self._ensure_fd_no_transport(sock) return _SelectorDatagramTransport(self, sock, protocol, address, waiter, extra) diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py index 22dcfb23083522..796037bcf59c49 100644 --- a/Lib/test/test_asyncio/test_selector_events.py +++ b/Lib/test/test_asyncio/test_selector_events.py @@ -61,8 +61,10 @@ def setUp(self): def test_make_socket_transport(self): m = mock.Mock() self.loop.add_reader = mock.Mock() + self.loop._ensure_fd_no_transport = mock.Mock() transport = self.loop._make_socket_transport(m, asyncio.Protocol()) self.assertIsInstance(transport, _SelectorSocketTransport) + self.assertEqual(self.loop._ensure_fd_no_transport.call_count, 1) # Calling repr() must not fail when the event loop is closed self.loop.close() @@ -78,8 +80,10 @@ def test_make_ssl_transport_without_ssl_error(self): self.loop.add_writer = mock.Mock() self.loop.remove_reader = mock.Mock() self.loop.remove_writer = mock.Mock() + self.loop._ensure_fd_no_transport = mock.Mock() with self.assertRaises(RuntimeError): self.loop._make_ssl_transport(m, m, m, m) + self.assertEqual(self.loop._ensure_fd_no_transport.call_count, 1) def test_close(self): class EventLoop(BaseSelectorEventLoop):