Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add incremental updating of open streams count and closed_streams state #1185

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
14 changes: 13 additions & 1 deletion h2/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
Objects for controlling the configuration of the HTTP/2 stack.
"""

import logging


class _BooleanConfigOption(object):
"""
Expand Down Expand Up @@ -34,7 +36,17 @@ class DummyLogger(object):
logging functions when no logger is passed into the corresponding object.
"""
def __init__(self, *vargs):
pass
# Disable all logging
self.lvl = logging.CRITICAL + 1

def isEnabledFor(self, lvl):
"""
Dummy logger, so nothing is enabled.
"""
return lvl >= self.lvl

def setLevel(self, lvl):
self.lvl = lvl

def debug(self, *vargs, **kwargs):
"""
Expand Down
43 changes: 26 additions & 17 deletions h2/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
An implementation of a HTTP/2 connection.
"""
import base64
import logging

from enum import Enum, IntEnum

Expand Down Expand Up @@ -292,6 +293,7 @@ def __init__(self, config=None):
self.encoder = Encoder()
self.decoder = Decoder()

self._open_stream_counts = {0: 0, 1: 0}
# This won't always actually do anything: for versions of HPACK older
# than 2.3.0 it does nothing. However, we have to try!
self.decoder.max_header_list_size = self.DEFAULT_MAX_HEADER_LIST_SIZE
Expand Down Expand Up @@ -362,6 +364,8 @@ def __init__(self, config=None):
size_limit=self.MAX_CLOSED_STREAMS
)

self._streams_to_close = list()

# The flow control window manager for the connection.
self._inbound_flow_control_window_manager = WindowManager(
max_window_size=self.local_settings.initial_window_size
Expand All @@ -383,6 +387,13 @@ def __init__(self, config=None):
ExtensionFrame: self._receive_unknown_frame
}

def _increment_open_streams(self, stream_id, incr):
remainder = stream_id % 2
self._open_stream_counts[remainder] += incr

def _close_stream(self, stream_id):
self._streams_to_close.append(stream_id)

def _prepare_for_sending(self, frames):
if not frames:
return
Expand All @@ -393,22 +404,15 @@ def _open_streams(self, remainder):
"""
A common method of counting number of open streams. Returns the number
of streams that are open *and* that have (stream ID % 2) == remainder.
While it iterates, also deletes any closed streams.
Also cleans up closed streams.
"""
count = 0
to_delete = []

for stream_id, stream in self.streams.items():
if stream.open and (stream_id % 2 == remainder):
count += 1
elif stream.closed:
to_delete.append(stream_id)

for stream_id in to_delete:
for stream_id in self._streams_to_close:
stream = self.streams.pop(stream_id)
assert stream.closed
self._closed_streams[stream_id] = stream.closed_by
self._streams_to_close = list()

return count
return self._open_stream_counts[remainder]

@property
def open_outbound_streams(self):
Expand Down Expand Up @@ -467,14 +471,20 @@ def _begin_new_stream(self, stream_id, allowed_ids):
stream_id,
config=self.config,
inbound_window_size=self.local_settings.initial_window_size,
outbound_window_size=self.remote_settings.initial_window_size
outbound_window_size=self.remote_settings.initial_window_size,
increment_open_stream_count_callback=self._increment_open_streams,
close_stream_callback=self._close_stream,
)
self.config.logger.debug("Stream ID %d created", stream_id)
s.max_inbound_frame_size = self.max_inbound_frame_size
s.max_outbound_frame_size = self.max_outbound_frame_size

self.streams[stream_id] = s
self.config.logger.debug("Current streams: %s", self.streams.keys())
# Disable this log if we're not in debug mode, as it can be expensive
# when there are many concurrently open streams
if self.config.logger.isEnabledFor(logging.DEBUG):
self.config.logger.debug(
"Current streams: %s", self.streams.keys())

if outbound:
self.highest_outbound_stream_id = stream_id
Expand Down Expand Up @@ -1025,7 +1035,6 @@ def reset_stream(self, stream_id, error_code=0):

def close_connection(self, error_code=0, additional_data=None,
last_stream_id=None):

"""
Close a connection, emitting a GOAWAY frame.

Expand Down Expand Up @@ -1542,8 +1551,8 @@ def _receive_headers_frame(self, frame):
max_open_streams = self.local_settings.max_concurrent_streams
if (self.open_inbound_streams + 1) > max_open_streams:
raise TooManyStreamsError(
"Max outbound streams is %d, %d open" %
(max_open_streams, self.open_outbound_streams)
"Max inbound streams is %d, %d open" %
(max_open_streams, self.open_inbound_streams)
)

# Let's decode the headers. We handle headers as bytes internally up
Expand Down
86 changes: 84 additions & 2 deletions h2/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class H2StreamStateMachine(object):
:param stream_id: The stream ID of this stream. This is stored primarily
for logging purposes.
"""

def __init__(self, stream_id):
self.state = StreamState.IDLE
self.stream_id = stream_id
Expand Down Expand Up @@ -767,6 +768,55 @@ def send_alt_svc(self, previous_state):
(H2StreamStateMachine.send_on_closed_stream, StreamState.CLOSED),
}

"""
Wraps a stream state change function to ensure that we keep
the parent H2Connection's state in sync
"""


def sync_state_change(func):
def wrapper(self, *args, **kwargs):
# Collect state at the beginning.
start_state = self.state_machine.state
started_open = self.open
started_closed = not started_open

# Do the state change (if any).
result = func(self, *args, **kwargs)

# Collect state at the end.
end_state = self.state_machine.state
ended_open = self.open
ended_closed = not ended_open

# If at any point we've tranwsitioned to the CLOSED state
# from any other state, close our stream.
if end_state == StreamState.CLOSED and start_state != end_state:
if self._close_stream_callback:
self._close_stream_callback(self.stream_id)
# Clear callback so we only call this once per stream
self._close_stream_callback = None

# If we were open, but are now closed, decrement
# the open stream count, and call the close callback.
if started_open and ended_closed:
if self._decrement_open_stream_count_callback:
self._decrement_open_stream_count_callback(self.stream_id,
-1,)
# Clear callback so we only call this once per stream
self._decrement_open_stream_count_callback = None

# If we were closed, but are now open, increment
# the open stream count.
elif started_closed and ended_open:
if self._increment_open_stream_count_callback:
self._increment_open_stream_count_callback(self.stream_id,
1,)
# Clear callback so we only call this once per stream
self._increment_open_stream_count_callback = None
return result
return wrapper


class H2Stream(object):
"""
Expand All @@ -778,22 +828,36 @@ class H2Stream(object):
Attempts to create frames that cannot be sent will raise a
``ProtocolError``.
"""

def __init__(self,
stream_id,
config,
inbound_window_size,
outbound_window_size):
outbound_window_size,
increment_open_stream_count_callback,
close_stream_callback,):
self.state_machine = H2StreamStateMachine(stream_id)
self.stream_id = stream_id
self.max_outbound_frame_size = None
self.request_method = None

# The current value of the outbound stream flow control window
# The current value of the outbound stream flow control window.
self.outbound_flow_control_window = outbound_window_size

# The flow control manager.
self._inbound_window_manager = WindowManager(inbound_window_size)

# Callback to increment open stream count for the H2Connection.
self._increment_open_stream_count_callback = \
increment_open_stream_count_callback

# Callback to decrement open stream count for the H2Connection.
self._decrement_open_stream_count_callback = \
increment_open_stream_count_callback

# Callback to clean up state for the H2Connection once we're closed.
self._close_stream_callback = close_stream_callback

# The expected content length, if any.
self._expected_content_length = None

Expand Down Expand Up @@ -850,6 +914,7 @@ def closed_by(self):
"""
return self.state_machine.stream_closed_by

@sync_state_change
def upgrade(self, client_side):
"""
Called by the connection to indicate that this stream is the initial
Expand All @@ -868,6 +933,7 @@ def upgrade(self, client_side):
self.state_machine.process_input(input_)
return

@sync_state_change
def send_headers(self, headers, encoder, end_stream=False):
"""
Returns a list of HEADERS/CONTINUATION frames to emit as either headers
Expand Down Expand Up @@ -917,6 +983,7 @@ def send_headers(self, headers, encoder, end_stream=False):

return frames

@sync_state_change
def push_stream_in_band(self, related_stream_id, headers, encoder):
"""
Returns a list of PUSH_PROMISE/CONTINUATION frames to emit as a pushed
Expand All @@ -941,6 +1008,7 @@ def push_stream_in_band(self, related_stream_id, headers, encoder):

return frames

@sync_state_change
def locally_pushed(self):
"""
Mark this stream as one that was pushed by this peer. Must be called
Expand All @@ -954,6 +1022,7 @@ def locally_pushed(self):
assert not events
return []

@sync_state_change
def send_data(self, data, end_stream=False, pad_length=None):
"""
Prepare some data frames. Optionally end the stream.
Expand Down Expand Up @@ -981,6 +1050,7 @@ def send_data(self, data, end_stream=False, pad_length=None):

return [df]

@sync_state_change
def end_stream(self):
"""
End a stream without sending data.
Expand All @@ -992,6 +1062,7 @@ def end_stream(self):
df.flags.add('END_STREAM')
return [df]

@sync_state_change
def advertise_alternative_service(self, field_value):
"""
Advertise an RFC 7838 alternative service. The semantics of this are
Expand All @@ -1005,6 +1076,7 @@ def advertise_alternative_service(self, field_value):
asf.field = field_value
return [asf]

@sync_state_change
def increase_flow_control_window(self, increment):
"""
Increase the size of the flow control window for the remote side.
Expand All @@ -1020,6 +1092,7 @@ def increase_flow_control_window(self, increment):
wuf.window_increment = increment
return [wuf]

@sync_state_change
def receive_push_promise_in_band(self,
promised_stream_id,
headers,
Expand All @@ -1044,6 +1117,7 @@ def receive_push_promise_in_band(self,
)
return [], events

@sync_state_change
def remotely_pushed(self, pushed_headers):
"""
Mark this stream as one that was pushed by the remote peer. Must be
Expand All @@ -1057,6 +1131,7 @@ def remotely_pushed(self, pushed_headers):
self._authority = authority_from_headers(pushed_headers)
return [], events

@sync_state_change
def receive_headers(self, headers, end_stream, header_encoding):
"""
Receive a set of headers (or trailers).
Expand Down Expand Up @@ -1091,6 +1166,7 @@ def receive_headers(self, headers, end_stream, header_encoding):
)
return [], events

@sync_state_change
def receive_data(self, data, end_stream, flow_control_len):
"""
Receive some data.
Expand All @@ -1114,6 +1190,7 @@ def receive_data(self, data, end_stream, flow_control_len):
events[0].flow_controlled_length = flow_control_len
return [], events

@sync_state_change
def receive_window_update(self, increment):
"""
Handle a WINDOW_UPDATE increment.
Expand Down Expand Up @@ -1150,6 +1227,7 @@ def receive_window_update(self, increment):

return frames, events

@sync_state_change
def receive_continuation(self):
"""
A naked CONTINUATION frame has been received. This is always an error,
Expand All @@ -1162,6 +1240,7 @@ def receive_continuation(self):
)
assert False, "Should not be reachable"

@sync_state_change
def receive_alt_svc(self, frame):
"""
An Alternative Service frame was received on the stream. This frame
Expand Down Expand Up @@ -1189,6 +1268,7 @@ def receive_alt_svc(self, frame):

return [], events

@sync_state_change
def reset_stream(self, error_code=0):
"""
Close the stream locally. Reset the stream with an error code.
Expand All @@ -1202,6 +1282,7 @@ def reset_stream(self, error_code=0):
rsf.error_code = error_code
return [rsf]

@sync_state_change
def stream_reset(self, frame):
"""
Handle a stream being reset remotely.
Expand All @@ -1217,6 +1298,7 @@ def stream_reset(self, frame):

return [], events

@sync_state_change
def acknowledge_received_data(self, acknowledged_size):
"""
The user has informed us that they've processed some amount of data
Expand Down
2 changes: 1 addition & 1 deletion test/test_basic_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1851,7 +1851,7 @@ def test_stream_repr(self):
"""
Ensure stream string representation is appropriate.
"""
s = h2.stream.H2Stream(4, None, 12, 14)
s = h2.stream.H2Stream(4, None, 12, 14, None, None)
assert repr(s) == "<H2Stream id:4 state:<StreamState.IDLE: 0>>"


Expand Down
Loading