/
stream.py
217 lines (177 loc) · 5.91 KB
/
stream.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
"""Provide asynchronous equivalents to *input* and *print*."""
import os
import sys
import stat
import asyncio
from . import compat
def is_pipe_transport_compatible(pipe):
    """Return True if *pipe* can be wrapped by an asyncio pipe transport.

    Pipe transports are never used on Windows.  Elsewhere, *pipe* must have a
    working ``fileno()`` whose descriptor refers to a character device, a
    FIFO or a socket; any other kind of file is rejected.

    Objects without a usable descriptor (``None``, in-memory streams such as
    ``io.StringIO``, closed files, stale descriptors) are reported as not
    compatible instead of raising, so callers can fall back to the
    executor-based stream wrappers.
    """
    if compat.platform == 'win32':
        return False
    try:
        fileno = pipe.fileno()
    except (AttributeError, ValueError, OSError):
        # AttributeError: no fileno() at all (e.g. the stream is None).
        # ValueError: the file object is closed.
        # OSError: io.UnsupportedOperation raised by in-memory streams.
        return False
    try:
        mode = os.fstat(fileno).st_mode
    except OSError:
        # The descriptor itself is not valid.
        return False
    return stat.S_ISCHR(mode) or stat.S_ISFIFO(mode) or stat.S_ISSOCK(mode)
def protect_standard_streams(stream):
    """Detach a standard file descriptor from *stream*'s transport.

    When the pipe wrapped by the stream's transport is descriptor 0, 1 or 2
    (stdin, stdout or stderr), the transport's ``_pipe`` reference is set to
    None.  Presumably this keeps transport cleanup from closing the
    process-wide standard stream.  Other descriptors are left untouched.

    Streams with no transport, transports that expose no pipe, and pipes
    whose ``fileno()`` fails are all ignored silently.  This function is used
    as a finalizer, where an exception would only produce an
    "Exception ignored" warning.
    """
    transport = getattr(stream, '_transport', None)
    if transport is None:
        return
    pipe = transport.get_extra_info('pipe')
    if pipe is None:
        return
    try:
        fileno = pipe.fileno()
    except (ValueError, OSError):
        return
    if fileno < 3:
        transport._pipe = None
class StandardStreamReaderProtocol(asyncio.StreamReaderProtocol):
    """Stream reader protocol that tolerates being shared by several transports.

    ``open_stantard_pipe_connection`` hands one instance to the stdin read
    transport and to the stdout/stderr write transports.
    """

    def connection_made(self, transport):
        # Only the first transport is attached.  Once the reader has a
        # transport, later connections leave it alone.
        if self._stream_reader._transport is None:
            super().connection_made(transport)

    def connection_lost(self, exc):
        # Snapshot the instance attributes and let the base class run.  Then
        # restore every attribute that existed beforehand, undoing whatever
        # the base implementation changed.
        snapshot = dict(vars(self))
        super().connection_lost(exc)
        vars(self).update(snapshot)
class StandardStreamReader(asyncio.StreamReader):
    """StreamReader used for the pipe-based standard input stream."""

    # The function object is bound at class-creation time, so no global-name
    # lookup is needed when the finalizer runs.  protect_standard_streams
    # clears the transport's pipe when it wraps descriptor 0, 1 or 2.
    __del__ = protect_standard_streams
class StandardStreamWriter(asyncio.StreamWriter):
    """StreamWriter for the pipe-based stdout/stderr that also accepts text.

    A ``str`` payload is encoded with the default codec before being passed
    to ``asyncio.StreamWriter.write``.  Any other payload is forwarded as is.
    """

    # Bound directly, like the reader class, so finalization needs no
    # global lookup.
    __del__ = protect_standard_streams

    def write(self, data):
        payload = data.encode() if isinstance(data, str) else data
        super().write(payload)
class NonFileStreamReader:
    """Fallback reader used when stdin cannot be wrapped in a pipe transport.

    The blocking ``readline``/``read`` methods of *stream* run in the loop's
    default executor.  Text results are encoded so callers always receive
    ``bytes``.
    """

    def __init__(self, stream, *, loop=None):
        if loop is None:
            loop = asyncio.get_event_loop()
        self.loop = loop
        self.stream = stream
        # Becomes True once a read returns an empty result.
        self.eof = False

    def at_eof(self):
        """Return True if the most recent read returned no data."""
        return self.eof

    @asyncio.coroutine
    def readline(self):
        """Read one line from the wrapped stream and return it as bytes."""
        data = yield from self.loop.run_in_executor(None, self.stream.readline)
        if isinstance(data, str):
            data = data.encode()
        self.eof = not data
        return data

    @asyncio.coroutine
    def read(self, n=-1):
        """Read up to *n* units (all remaining data if *n* is negative) as bytes."""
        data = yield from self.loop.run_in_executor(None, self.stream.read, n)
        if isinstance(data, str):
            data = data.encode()
        self.eof = not data
        return data

    if compat.PY35:
        # __aiter__ must return the asynchronous iterator itself, not an
        # awaitable.  Returning an awaitable is deprecated since Python 3.5.2,
        # and from 3.7 `async for` raises TypeError.  So this is a plain method.
        def __aiter__(self):
            return self

        @asyncio.coroutine
        def __anext__(self):
            val = yield from self.readline()
            if val == b'':
                raise StopAsyncIteration
            return val
class NonFileStreamWriter:
    """Fallback writer used when stdout/stderr cannot use a pipe transport.

    ``write`` is synchronous and decodes ``bytes`` to ``str`` before writing
    to *stream*.  ``drain`` runs the stream's ``flush`` in the loop's default
    executor when the stream has one.
    """

    def __init__(self, stream, *, loop=None):
        self.loop = asyncio.get_event_loop() if loop is None else loop
        self.stream = stream

    def write(self, data):
        text = data.decode() if isinstance(data, bytes) else data
        self.stream.write(text)

    @asyncio.coroutine
    def drain(self):
        # A private sentinel means only a missing attribute skips the flush.
        missing = object()
        flush = getattr(self.stream, 'flush', missing)
        if flush is not missing:
            yield from self.loop.run_in_executor(None, flush)
@asyncio.coroutine
def open_stantard_pipe_connection(pipe_in, pipe_out, pipe_err, *, loop=None):
    """Wrap three pipe-compatible file objects in asyncio streams.

    Returns ``(reader, out_writer, err_writer)``.  A single
    ``StandardStreamReaderProtocol`` instance serves the read transport and
    both write transports, and both writers are bound to the same reader.
    """
    if loop is None:
        loop = asyncio.get_event_loop()

    reader = StandardStreamReader(loop=loop)
    protocol = StandardStreamReaderProtocol(reader, loop=loop)

    def factory():
        return protocol

    yield from loop.connect_read_pipe(factory, pipe_in)

    # Connect stdout first, then stderr, building each writer right after
    # its transport is available.
    writers = []
    for pipe in (pipe_out, pipe_err):
        transport, _ = yield from loop.connect_write_pipe(factory, pipe)
        writers.append(StandardStreamWriter(transport, protocol, reader, loop))
    out_writer, err_writer = writers
    return reader, out_writer, err_writer
@asyncio.coroutine
def create_standard_streams(stdin, stdout, stderr, *, loop=None):
    """Return ``(reader, out_writer, err_writer)`` for the given file objects.

    Pipe transports are used only when all three objects are compatible with
    them.  Otherwise the executor-based wrappers are used for every stream.
    """
    use_pipes = all(
        is_pipe_transport_compatible(f) for f in (stdin, stdout, stderr))
    if not use_pipes:
        return (
            NonFileStreamReader(stdin, loop=loop),
            NonFileStreamWriter(stdout, loop=loop),
            NonFileStreamWriter(stderr, loop=loop))
    result = yield from open_stantard_pipe_connection(
        stdin, stdout, stderr, loop=loop)
    return result
@asyncio.coroutine
def get_standard_streams(*, cache={}, use_stderr=False, loop=None):
    """Return a ``(reader, writer)`` pair for the process's standard streams.

    The streams are created on first use and memoized in *cache*, which is
    deliberately a shared mutable default.  The key is the current
    ``sys.stdin``/``sys.stdout``/``sys.stderr`` objects together with the
    event loop.  *writer* is the stderr writer when *use_stderr* is true,
    otherwise the stdout writer.
    """
    if loop is None:
        loop = asyncio.get_event_loop()
    std = (sys.stdin, sys.stdout, sys.stderr)
    key = (std, loop)
    cached = cache.get(key)
    if cached is None:
        cached = yield from create_standard_streams(*std, loop=loop)
        cache[key] = cached
    reader, stdout_writer, stderr_writer = cached
    return reader, (stderr_writer if use_stderr else stdout_writer)
@asyncio.coroutine
def ainput(prompt='', *, streams=None, use_stderr=False, loop=None):
    """Asynchronous equivalent to *input*.

    Writes *prompt* (converted with ``str``, as the built-in does) and waits
    for the writer to drain.  Then reads one line and returns it without its
    trailing newline.

    *streams* may be a ``(reader, writer)`` pair.  By default the process's
    standard streams are used, with the prompt going to stderr when
    *use_stderr* is true.

    Raises EOFError if end of file is reached before anything is read.  As
    with the built-in *input*, a final line that lacks a newline is returned
    rather than discarded.
    """
    # Get standard streams
    if streams is None:
        streams = yield from get_standard_streams(
            use_stderr=use_stderr, loop=loop)
    reader, writer = streams
    # Write prompt
    writer.write(str(prompt).encode())
    yield from writer.drain()
    # Get and decode data
    data = yield from reader.readline()
    data = data.decode()
    # An empty read is the only EOF condition; a partial last line is data.
    if not data:
        raise EOFError
    return data.rstrip('\n')
@asyncio.coroutine
def aprint(
    *values,
    sep=None,
    end='\n',
    flush=False,
    streams=None,
    use_stderr=False,
    loop=None
):
    """Asynchronous equivalent to *print*.

    Formats *values* with *sep* and *end* exactly like the built-in *print*,
    writes the result to the output stream and then awaits the writer's
    ``drain()``.

    *streams* may be a ``(reader, writer)`` pair.  By default the process's
    standard streams are used, writing to stderr when *use_stderr* is true.

    *flush* is forwarded to *print* only when the writer actually has a
    ``flush`` method.  The stream writers created by this module do not,
    and ``print(flush=True)`` would raise AttributeError on them.
    """
    # Get standard streams
    if streams is None:
        streams = yield from get_standard_streams(
            use_stderr=use_stderr, loop=loop)
    _, writer = streams
    can_flush = bool(flush) and callable(getattr(writer, 'flush', None))
    print(*values, sep=sep, end=end, flush=can_flush, file=writer)
    yield from writer.drain()