Skip to content

Commit

Permalink
Merge pull request #51 from wildfoundry/fix-graceful-clos
Browse files Browse the repository at this point in the history
fixes for #49 and #50
  • Loading branch information
willmcgugan committed May 9, 2018
2 parents 6fbd5d9 + 909c003 commit 2e9924e
Show file tree
Hide file tree
Showing 12 changed files with 108 additions and 105 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/)
and this project adheres to [Semantic Versioning](http://semver.org/).

## [0.2.2] - 2018-05-09

### Fixed
- Fixed handling of non-ws URLs on Windows
- Fixed broken close timeout

## [0.2.1] - 2018-04-03

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion lomond/_version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from __future__ import unicode_literals

__version__ = "0.2.1"
__version__ = "0.2.2"
2 changes: 1 addition & 1 deletion lomond/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def validate(self):
"""Check the frame and raise any errors."""
if self.is_control and len(self.payload) > 125:
raise errors.ProtocolError(
"control frames must <= 125 bytes in length"
"control frames must be <= 125 bytes in length"
)
if self.rsv1 or self.rsv2 or self.rsv3:
raise errors.ProtocolError(
Expand Down
5 changes: 5 additions & 0 deletions lomond/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class Binary(Message):
"""
__slots__ = ['data']

def __init__(self, data):
self.data = data
super(Binary, self).__init__(Opcode.BINARY)
Expand All @@ -96,6 +97,7 @@ class Text(Message):
"""
__slots__ = ['text']

def __init__(self, text):
self.text = text
super(Text, self).__init__(Opcode.TEXT)
Expand Down Expand Up @@ -123,6 +125,7 @@ class Close(Message):
"""
__slots__ = ['code', 'reason']

def __init__(self, code, reason):
self.code = code
self.reason = reason
Expand Down Expand Up @@ -165,6 +168,7 @@ class Ping(Message):
"""
__slots__ = ['data']

def __init__(self, data):
self.data = data
super(Ping, self).__init__(Opcode.PING)
Expand All @@ -180,6 +184,7 @@ class Pong(Message):
"""
__slots__ = ['data']

def __init__(self, data):
self.data = data
super(Pong, self).__init__(Opcode.PONG)
Expand Down
3 changes: 3 additions & 0 deletions lomond/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@ def validate(self, chunk):
class _ReadBytes(_Awaitable):
"""Reads a fixed number of bytes."""
__slots__ = ['remaining']

def __init__(self, count):
self.remaining = count


class _ReadUtf8(_ReadBytes):
"""Reads a fixed number of bytes, validates utf-8."""
__slots__ = ['utf8_validator']

def __init__(self, count, utf8_validator):
self.remaining = count
self.utf8_validator = utf8_validator
Expand All @@ -49,6 +51,7 @@ def validate(self, data):
class _ReadUntil(_Awaitable):
"""Read until a separator."""
__slots__ = ['sep', 'max_bytes']

def __init__(self, sep, max_bytes=None):
self.sep = sep
self.max_bytes = max_bytes
Expand Down
2 changes: 1 addition & 1 deletion lomond/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,4 @@ def parse(self):
response.status_code,
response.status
)
yield response
yield response
92 changes: 35 additions & 57 deletions lomond/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
class _SocketFail(Exception):
"""Used internally to respond to socket fails."""


class _ForceDisconnect(Exception):
"""Used internally when the close timeout is tripped."""

Expand All @@ -54,8 +55,8 @@ def __repr__(self):
return "<ws-session '{}'>".format(self.websocket.url)

@property
def _time(self):
"""Get the time since the socket started."""
def session_time(self):
"""Get the time (in seconds) since the WebSocket session started."""
return time.time() - self._start_time

def close(self):
Expand Down Expand Up @@ -226,50 +227,47 @@ def _send_request(self):
"""Send the request over the wire."""
self.write(self.websocket.build_request())

def _check_poll(self, poll):
def _check_poll(self, poll, session_time):
"""Check if it is time for a poll."""
_time = self._time
_time = session_time
if self._poll_start is None or _time - self._poll_start >= poll:
self._poll_start = _time
return True
else:
return False

def _check_auto_ping(self, ping_rate):
def _check_auto_ping(self, ping_rate, session_time):
"""Check if a ping is required."""
if not ping_rate:
return
current_time = self._time
if current_time > self._next_ping:
if ping_rate and session_time > self._next_ping:
# Calculate next ping time that is in the future.
self._next_ping = (
math.ceil(current_time / ping_rate) * ping_rate
math.ceil(session_time / ping_rate) * ping_rate
)
try:
self.websocket.send_ping()
except errors.WebSocketError:
pass # If the websocket has gone away

def _check_ping_timeout(self, ping_timeout):
def _check_ping_timeout(self, ping_timeout, session_time):
"""Check if the server is not responding to pings."""
if ping_timeout:
time_since_last_pong = self._time - self._last_pong
time_since_last_pong = session_time - self._last_pong
if time_since_last_pong > ping_timeout:
log.debug('ping_timeout time exceeded')
return True
return False

def _check_close_timeout(self, close_timeout):
def _check_close_timeout(self, close_timeout, session_time):
"""Check if the close timeout was tripped."""
if not close_timeout:
return False
sent_close_time = self.websocket.sent_close_time
if (sent_close_time is not None and
self._time >= sent_close_time + close_timeout):
raise _ForceDisconnect(
"server didn't respond to close packet "
"within {}s".format(close_timeout)
)
if close_timeout:
sent_close_time = self.websocket.sent_close_time
if sent_close_time is None:
return
if session_time >= sent_close_time + close_timeout:
raise _ForceDisconnect(
"server didn't respond to close packet "
"within {}s".format(close_timeout)
)

def _recv(self, count):
"""Receive and return pending data from the socket."""
Expand All @@ -293,15 +291,15 @@ def _recv(self, count):
def _regular(self, poll, ping_rate, ping_timeout, close_timeout):
"""Run regularly to do polling / pings."""
# Check for regularly running actions.
if self._check_poll(poll):
if self._check_poll(poll, self.session_time):
yield events.Poll()
self._check_auto_ping(ping_rate)
if self._check_ping_timeout(ping_timeout):
self._check_auto_ping(ping_rate, self.session_time)
if self._check_ping_timeout(ping_timeout, self.session_time):
yield events.Unresponsive()
raise _ForceDisconnect(
'exceeded {:.0f}s ping timeout'.format(ping_timeout)
)
self._check_close_timeout(close_timeout)
self._check_close_timeout(close_timeout, self.session_time)

def _send_pong(self, event):
"""Send a pong message in response to ping event."""
Expand All @@ -313,7 +311,7 @@ def _send_pong(self, event):

def _on_pong(self, event):
"""Record last pong time."""
self._last_pong = self._time
self._last_pong = self.session_time

def _on_ready(self):
"""Called when a ready event is received."""
Expand Down Expand Up @@ -347,7 +345,7 @@ def run(self,
# Create socket and connect to remote server
try:
sock, proxy = self._connect()
self._sock = sock
self._sock = sock
except _SocketFail as error:
yield events.ConnectFail('{}'.format(error))
return
Expand Down Expand Up @@ -380,24 +378,20 @@ def _regular():
return ()

try:
while True:
while not websocket.is_closed:
readable = selector.wait_readable(poll)
for event in _regular():
yield event
if not readable:
continue
data = self._recv(64 * 1024)
if data:
for event in self.websocket.feed(data):
self._on_event(event, auto_pong)
yield event
for event in _regular():
if readable:
data = self._recv(64 * 1024)
if data:
for event in self.websocket.feed(data):
self._on_event(event, auto_pong)
yield event
else:
if websocket.is_active:
for event in _regular():
yield event
else:
self._socket_fail('connection lost')
break

except _ForceDisconnect as error:
self._close_socket()
yield events.Disconnected('disconnected; {}'.format(error))
Expand All @@ -418,19 +412,3 @@ def _regular():
yield events.Disconnected(graceful=True)
finally:
selector.close()


if __name__ == "__main__": # pragma: no cover

# Test with wstest -m echoserver -w ws://127.0.0.1:9001 -d
# Get wstest app from http://autobahn.ws/testsuite/

from .websocket import WebSocket

#ws = WebSocket('wss://echo.websocket.org')
ws = WebSocket('ws://127.0.0.1:9001/')
for event in ws.connect(poll=5):
print(event)
if isinstance(event, events.Poll):
ws.send_text('Hello, World')
ws.send_binary(b'hello world in binary')
20 changes: 7 additions & 13 deletions lomond/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,8 @@
import json
import logging
import os
import time

import six
from six import text_type
from six.moves.urllib.parse import urlparse

from . import constants
Expand Down Expand Up @@ -117,8 +115,8 @@ def is_active(self):

@property
def sent_close_time(self):
"""The epoch time a close packet was sent (or None if no close
packet has been sent).
"""The time (seconds since session start) when a close packet
was sent (or None if no close packet has been sent).
"""
return self.state.sent_close_time
Expand Down Expand Up @@ -202,14 +200,13 @@ def connect(self,
)
return run_generator


def reset(self):
"""Reset the state."""
self.state = self.State()

__iter__ = connect

def close(self, code=None, reason=None):
def close(self, code=Status.NORMAL, reason=b'goodbye'):
"""Close the websocket.
:param int code: A closing code, which should probably be one of
Expand All @@ -231,13 +228,10 @@ def close(self, code=None, reason=None):
if self.is_closed:
log.debug('%r already closed', self)
else:
if code is None:
code = Status.NORMAL
if reason is None:
reason = b'goodbye'
self._send_close(code, reason)
self.state.closing = True
self.state.sent_close_time = time.time()
if not self.is_closing:
self._send_close(code, reason)
self.state.closing = True
self.state.sent_close_time = self.session.session_time

def _on_close(self, message):
"""Close logic generator."""
Expand Down
7 changes: 2 additions & 5 deletions tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,6 @@ def inner(opcode=Opcode.TEXT, payload=b'', fin=1):
return inner


def test_frame_constructor(frame_factory):
assert isinstance(frame_factory(), object)


def test_length_of_frame(frame_factory):
frame = frame_factory(Opcode.TEXT, b'\x00' * 137)
assert len(frame) == 137
Expand All @@ -33,6 +29,7 @@ def test_masking_key():
expected = b'\x81\x8c\xaa\xf7\x7f\x00\xe2\x92\x13l\xc5\xdb_W\xc5\x85\x13d'
assert frame_bytes == expected


def test_repr_of_frame(frame_factory):
assert repr(frame_factory()) == '<frame TEXT (0 bytes) fin=1>'
assert repr(
Expand Down Expand Up @@ -193,7 +190,7 @@ def test_calling_build_close_payload_requires_status():
@pytest.mark.parametrize('init_params, expected_error', [
(
{'opcode': Opcode.PING, 'payload': b'A' * 126},
"control frames must <= 125 bytes in length"
"control frames must be <= 125 bytes in length"
),
(
{'opcode': Opcode.TEXT, 'payload': b'A', 'rsv1': 1},
Expand Down

0 comments on commit 2e9924e

Please sign in to comment.