Skip to content

Commit

Permalink
Implement a more robust protection against the closing of standard st…
Browse files Browse the repository at this point in the history
…reams (issue #117, PR #118)
  • Loading branch information
vxgmichel committed May 1, 2024
2 parents efb531f + a2f6924 commit bc84ece
Show file tree
Hide file tree
Showing 10 changed files with 129 additions and 61 deletions.
19 changes: 10 additions & 9 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ jobs:
Quality:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.11'
- uses: pre-commit/action@v2.0.0
- uses: pre-commit/action@v3.0.1

Tests:
runs-on: ${{ matrix.os }}
Expand All @@ -26,9 +26,9 @@ jobs:
OS: ${{ matrix.os }}
PYTHON: ${{ matrix.python-version }}
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install setuptools
Expand All @@ -37,20 +37,21 @@ jobs:
- name: Install test requirements
run: pip install -r test-requirements.txt
- name: Run tests
run: python setup.py test --addopts "--cov-report xml"
run: python setup.py test
- name: Upload coverage
uses: codecov/codecov-action@v1
uses: codecov/codecov-action@v4
with:
env_vars: OS,PYTHON
token: ${{ secrets.CODECOV_TOKEN }}

Release:
runs-on: ubuntu-latest
needs: [Quality, Tests]
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: 3.11
- name: Build source distribution
Expand Down
51 changes: 32 additions & 19 deletions aioconsole/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,19 @@
from . import compat


class ProtectedPipe:
"""Wrapper to protect a pipe from being closed."""

def __init__(self, pipe):
self.pipe = pipe

def fileno(self):
return self.pipe.fileno()

def close(self):
pass


def is_pipe_transport_compatible(pipe):
if compat.platform == "win32":
return False
Expand Down Expand Up @@ -65,17 +78,6 @@ def daemon():
raise exc.args[0]


def protect_standard_streams(stream):
if stream._transport is None:
return
try:
fileno = stream._transport.get_extra_info("pipe").fileno()
except (ValueError, OSError):
return
if fileno < 3:
stream._transport._pipe = None


class StandardStreamReaderProtocol(asyncio.StreamReaderProtocol):
def connection_made(self, transport):
# The connection is already made
Expand All @@ -94,8 +96,6 @@ def connection_lost(self, exc):


class StandardStreamReader(asyncio.StreamReader):
__del__ = protect_standard_streams

async def readuntil(self, separator=b"\n"):
# Re-implement `readuntil` to work around self._limit.
# The limit is still useful to prevent the internal buffer
Expand All @@ -115,7 +115,18 @@ async def readuntil(self, separator=b"\n"):


class StandardStreamWriter(asyncio.StreamWriter):
__del__ = protect_standard_streams
def __del__(self):
# No `__del__` method for StreamWriter in Python 3.10 and before
try:
parent_del = super().__del__
except AttributeError:
return
# Do not attempt to close the transport if the loop is closed
try:
asyncio.get_running_loop()
except RuntimeError:
return
parent_del()

def write(self, data):
if isinstance(data, str):
Expand Down Expand Up @@ -224,16 +235,18 @@ async def open_standard_pipe_connection(pipe_in, pipe_out, pipe_err, *, loop=Non
# Reader
in_reader = StandardStreamReader(loop=loop)
protocol = StandardStreamReaderProtocol(in_reader, loop=loop)
await loop.connect_read_pipe(lambda: protocol, pipe_in)
await loop.connect_read_pipe(lambda: protocol, ProtectedPipe(pipe_in))

# Out writer
out_write_connect = loop.connect_write_pipe(lambda: protocol, pipe_out)
out_transport, _ = await out_write_connect
out_transport, _ = await loop.connect_write_pipe(
lambda: protocol, ProtectedPipe(pipe_out)
)
out_writer = StandardStreamWriter(out_transport, protocol, in_reader, loop)

# Err writer
err_write_connect = loop.connect_write_pipe(lambda: protocol, pipe_err)
err_transport, _ = await err_write_connect
err_transport, _ = await loop.connect_write_pipe(
lambda: protocol, ProtectedPipe(pipe_err)
)
err_writer = StandardStreamWriter(err_transport, protocol, in_reader, loop)

# Set the write buffer limits to zero
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"pytest-asyncio",
"pytest-cov",
"pytest-repeat",
'uvloop; python_implementation != "PyPy" and sys_platform != "win32"',
],
license="GPLv3",
python_requires=">=3.8",
Expand Down
1 change: 1 addition & 0 deletions test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ pytest
pytest-asyncio
pytest-cov
pytest-repeat
uvloop; python_implementation != "PyPy" and sys_platform != "win32"
36 changes: 36 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import pytest
import asyncio


def import_uvloop():
try:
import uvloop
except ImportError:
return None
else:
return uvloop


@pytest.fixture(
params=(
"default",
"uvloop",
),
)
def event_loop_policy(request):
if request.param == "default":
return asyncio.DefaultEventLoopPolicy()
elif request.param == "uvloop":
uvloop = import_uvloop()
if uvloop is None:
pytest.skip("uvloop is not installed")
return uvloop.EventLoopPolicy()
return request.param


@pytest.fixture
def is_uvloop(event_loop_policy):
uvloop = import_uvloop()
if uvloop is None:
return False
return isinstance(event_loop_policy, uvloop.EventLoopPolicy)
2 changes: 1 addition & 1 deletion tests/test_apython.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def test_apython_with_stdout_logs(capfd, use_readline):
assert err == "test\n>>> 7\n>>> \n"


def test_apython_server(capfd, event_loop, monkeypatch):
def test_apython_server(capfd):
def run_forever(self, orig=InteractiveEventLoop.run_forever):
if self.console_server is not None:
self.call_later(0, self.stop)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ async def say_hello(reader, writer, name=None):
"input_string, expected", list(testdata.values()), ids=list(testdata.keys())
)
@pytest.mark.asyncio
async def test_async_cli(event_loop, monkeypatch, input_string, expected):
async def test_async_cli(monkeypatch, input_string, expected):
monkeypatch.setattr("sys.ps1", "[Hello!] ", raising=False)
monkeypatch.setattr("sys.stdin", io.StringIO(input_string))
monkeypatch.setattr("sys.stderr", io.StringIO())
Expand Down
37 changes: 19 additions & 18 deletions tests/test_interact.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,19 @@


@contextmanager
def stdcontrol(event_loop, monkeypatch):
def stdcontrol(monkeypatch):
# PS1
monkeypatch.setattr("sys.ps1", "[Hello!]", raising=False)
# Stdin control
stdin_read, stdin_write = os.pipe()
monkeypatch.setattr("sys.stdin", open(stdin_read))
writer = NonFileStreamWriter(open(stdin_write, "w"), loop=event_loop)
writer = NonFileStreamWriter(open(stdin_write, "w"))
# Stdout control
monkeypatch.setattr(sys, "stdout", io.StringIO())
# Stderr control
stderr_read, stderr_write = os.pipe()
monkeypatch.setattr("sys.stderr", open(stderr_write, "w"))
reader = NonFileStreamReader(open(stderr_read), loop=event_loop)
reader = NonFileStreamReader(open(stderr_read))
# Yield
yield reader, writer
# Check
Expand All @@ -40,17 +40,18 @@ async def assert_stream(stream, expected, loose=False):


@pytest.fixture(params=["unix", "not-unix"])
def signaling(request, monkeypatch, event_loop):
async def signaling(request, monkeypatch):
if request.param == "not-unix":
event_loop = asyncio.get_running_loop()
m = Mock(side_effect=NotImplementedError)
monkeypatch.setattr(event_loop, "add_signal_handler", m)
monkeypatch.setattr(event_loop, "remove_signal_handler", m)
yield request.param


@pytest.mark.asyncio
async def test_interact_simple(event_loop, monkeypatch):
with stdcontrol(event_loop, monkeypatch) as (reader, writer):
async def test_interact_simple(monkeypatch):
with stdcontrol(monkeypatch) as (reader, writer):
banner = "A BANNER"
writer.write("1+1\n")
await writer.drain()
Expand All @@ -62,8 +63,8 @@ async def test_interact_simple(event_loop, monkeypatch):


@pytest.mark.asyncio
async def test_interact_traceback(event_loop, monkeypatch):
with stdcontrol(event_loop, monkeypatch) as (reader, writer):
async def test_interact_traceback(monkeypatch):
with stdcontrol(monkeypatch) as (reader, writer):
banner = "A BANNER"
writer.write("1/0\n")
await writer.drain()
Expand All @@ -81,8 +82,8 @@ async def test_interact_traceback(event_loop, monkeypatch):


@pytest.mark.asyncio
async def test_interact_syntax_error(event_loop, monkeypatch):
with stdcontrol(event_loop, monkeypatch) as (reader, writer):
async def test_interact_syntax_error(monkeypatch):
with stdcontrol(monkeypatch) as (reader, writer):
writer.write("a b\n")
await writer.drain()
writer.stream.close()
Expand Down Expand Up @@ -114,8 +115,8 @@ async def test_interact_syntax_error(event_loop, monkeypatch):


@pytest.mark.asyncio
async def test_interact_keyboard_interrupt(event_loop, monkeypatch, signaling):
with stdcontrol(event_loop, monkeypatch) as (reader, writer):
async def test_interact_keyboard_interrupt(monkeypatch, signaling):
with stdcontrol(monkeypatch) as (reader, writer):
# Start interaction
banner = "A BANNER"
task = asyncio.ensure_future(interact(banner=banner, stop=False))
Expand All @@ -139,8 +140,8 @@ async def test_interact_keyboard_interrupt(event_loop, monkeypatch, signaling):


@pytest.mark.asyncio
async def test_broken_pipe(event_loop, monkeypatch, signaling):
with stdcontrol(event_loop, monkeypatch) as (reader, writer):
async def test_broken_pipe(monkeypatch, signaling):
with stdcontrol(monkeypatch) as (reader, writer):
# Start interaction
banner = "A BANNER"
task = asyncio.ensure_future(interact(banner=banner, stop=False))
Expand All @@ -156,8 +157,8 @@ async def test_broken_pipe(event_loop, monkeypatch, signaling):


@pytest.mark.asyncio
async def test_interact_multiple_indented_lines(event_loop, monkeypatch):
with stdcontrol(event_loop, monkeypatch) as (reader, writer):
async def test_interact_multiple_indented_lines(monkeypatch):
with stdcontrol(monkeypatch) as (reader, writer):
banner = "A BANNER"
writer.write("def x():\n print(1)\n print(2)\n\nx()\n")
await writer.drain()
Expand All @@ -168,8 +169,8 @@ async def test_interact_multiple_indented_lines(event_loop, monkeypatch):


@pytest.mark.asyncio
async def test_interact_cancellation(event_loop, monkeypatch):
with stdcontrol(event_loop, monkeypatch) as (reader, writer):
async def test_interact_cancellation(monkeypatch):
with stdcontrol(monkeypatch) as (reader, writer):
banner = "A BANNER"
task = asyncio.ensure_future(interact(banner=banner, stop=False))
# Wait for banner
Expand Down
6 changes: 3 additions & 3 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


@pytest.mark.asyncio
async def test_server(event_loop):
async def test_server():
server = await start_console_server(host="127.0.0.1", port=0, banner="test")
address = server.sockets[0].getsockname()

Expand All @@ -30,7 +30,7 @@ async def test_server(event_loop):


@pytest.mark.asyncio
async def test_uds_server(event_loop, tmpdir_factory):
async def test_uds_server(tmpdir_factory):
path = str(tmpdir_factory.mktemp("uds") / "my_uds")

# Not available on windows
Expand Down Expand Up @@ -60,7 +60,7 @@ async def test_uds_server(event_loop, tmpdir_factory):


@pytest.mark.asyncio
async def test_invalid_server(event_loop):
async def test_invalid_server():
with pytest.raises(ValueError):
await start_console_server()
with pytest.raises(ValueError):
Expand Down
Loading

0 comments on commit bc84ece

Please sign in to comment.