Skip to content
This repository has been archived by the owner on Jan 13, 2021. It is now read-only.

Commit

Permalink
Remove all headers referred to by 'Connection'
Browse files Browse the repository at this point in the history
  • Loading branch information
Lukasa committed Feb 21, 2015
1 parent ab37185 commit e6c47c5
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 25 deletions.
7 changes: 7 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
Release History
===============

Upcoming
--------

*Minor Changes*

- We not only remove the Connection header but all headers it refers to.

0.2.0 (2015-02-07)
------------------

Expand Down
8 changes: 0 additions & 8 deletions hyper/http20/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,14 +277,6 @@ def putheader(self, header, argument, stream_id=None):
:returns: Nothing.
"""
stream = self._get_stream(stream_id)

# Initially, strip the Connection header. Note that we do this after
# the call to `_get_stream` to ensure that we don't accidentally hide
# bugs just because the user sent a connection header.
if header.lower() == 'connection':
log.debug('Ignoring connection header with value %s', argument)
return

stream.add_header(header, argument)

return
Expand Down
6 changes: 4 additions & 2 deletions hyper/http20/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
FRAME_MAX_LEN, FRAMES, HeadersFrame, DataFrame, PushPromiseFrame,
WindowUpdateFrame, ContinuationFrame, BlockedFrame
)
from .util import get_from_key_value_set
from .util import get_from_key_value_set, h2_safe_headers
import collections
import logging
import zlib
Expand Down Expand Up @@ -250,8 +250,10 @@ def open(self, end):
The `end` flag controls whether this will be the end of the stream, or
whether data will follow.
"""
# Strip any headers invalid in H2.
headers = h2_safe_headers(self.headers)
# Encode the headers.
encoded_headers = self._encoder.encode(self.headers)
encoded_headers = self._encoder.encode(headers)

# It's possible that there is a substantial amount of data here. The
# data needs to go into one HEADERS frame, followed by a number of
Expand Down
14 changes: 14 additions & 0 deletions hyper/http20/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,17 @@ def split_repeated_headers(kvset):
headers[key] = value.split(b'\x00')

return dict(headers)


def h2_safe_headers(headers):
"""
This method takes a set of headers that are provided by the user and
transforms them into a form that is safe for emitting over HTTP/2.
Currently, this strips the Connection header and any header it refers to.
"""
stripped = {i.lower().strip() for k, v in headers if k == 'connection'
for i in v.split(',')}
stripped.add('connection')

return [header for header in headers if header[0] not in stripped]
44 changes: 29 additions & 15 deletions test/test_hyper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
HPACKDecodingError, HPACKEncodingError, ProtocolError
)
from hyper.http20.window import FlowControlManager
from hyper.http20.util import combine_repeated_headers, split_repeated_headers
from hyper.http20.util import (
combine_repeated_headers, split_repeated_headers, h2_safe_headers
)
from hyper.compat import zlib_compressobj
from hyper.contrib import HTTP20Adapter
import errno
Expand Down Expand Up @@ -1011,20 +1013,6 @@ def test_putheader_puts_headers(self):
('name', 'value'),
]

def test_putheader_ignores_connection(self):
c = HTTP20Connection("www.google.com")

c.putrequest('GET', '/')
c.putheader('Connection', 'keep-alive')
s = c.recent_stream

assert s.headers == [
(':method', 'GET'),
(':scheme', 'https'),
(':authority', 'www.google.com'),
(':path', '/'),
]

def test_endheaders_sends_data(self):
frames = []

Expand Down Expand Up @@ -2048,6 +2036,32 @@ def test_nghttp2_installs_correctly(self):

assert True

def test_stripping_connection_header(self):
headers = [('one', 'two'), ('connection', 'close')]
stripped = [('one', 'two')]

assert h2_safe_headers(headers) == stripped

def test_stripping_related_headers(self):
headers = [
('one', 'two'), ('three', 'four'), ('five', 'six'),
('connection', 'close, three, five')
]
stripped = [('one', 'two')]

assert h2_safe_headers(headers) == stripped

def test_stripping_multiple_connection_headers(self):
headers = [
('one', 'two'), ('three', 'four'), ('five', 'six'),
('connection', 'close'),
('connection', 'three, five')
]
stripped = [('one', 'two')]

assert h2_safe_headers(headers) == stripped


# Some utility classes for the tests.
class NullEncoder(object):
@staticmethod
Expand Down

0 comments on commit e6c47c5

Please sign in to comment.