Skip to content

Commit

Permalink
Merge pull request #73 from wildfoundry/bytearray
Browse files Browse the repository at this point in the history
Bytearray
  • Loading branch information
willmcgugan committed Jul 5, 2018
2 parents f7d0476 + 8e0f71e commit 2e3235f
Show file tree
Hide file tree
Showing 16 changed files with 120 additions and 350 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/)
and this project adheres to [Semantic Versioning](http://semver.org/).


## [0.3.2] - 2018-07-04

### Changed

- Use bytearray internaly to reduce memcpys

## [0.3.1] - 2018-06-27

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion lomond/_version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from __future__ import unicode_literals

__version__ = "0.3.1"
__version__ = "0.3.2"
19 changes: 15 additions & 4 deletions lomond/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import zlib

from six import PY2

from .errors import CompressionParameterError


Expand Down Expand Up @@ -66,10 +68,17 @@ def get_wbits(cls, options, key):

def decompress(self, frames):
"""Decompress payload, returned decompressed data."""
data = [
self._decompressobj.decompress(frame.payload)
for frame in frames
]
if PY2:
data = [
self._decompressobj.decompress(bytes(frame.payload))
for frame in frames
]
else:
data = [
self._decompressobj.decompress(frame.payload)
for frame in frames
]

data.append(self._decompressobj.decompress(b"\x00\x00\xff\xff"))
payload = b''.join(data)
if self.reset_decompress:
Expand All @@ -78,6 +87,8 @@ def decompress(self, frames):

def compress(self, payload):
"""Compress payload, return compressed data."""
if PY2:
payload = bytes(payload)
data = (
self._compressobj.compress(payload)
+ self._compressobj.flush(zlib.Z_SYNC_FLUSH)
Expand Down
1 change: 0 additions & 1 deletion lomond/examples/btcticker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from lomond import WebSocket


websocket = WebSocket('wss://ws-feed.gdax.com')

for event in websocket:
Expand Down
7 changes: 4 additions & 3 deletions lomond/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def build(cls, opcode, payload=b'',
mask=True, masking_key=None):
"""Build a WS frame."""
# https://tools.ietf.org/html/rfc6455#section-5.2

payload = bytearray(payload) if isinstance(payload, bytes) else payload
mask_bit = 1 << 7 if mask else 0
byte0 = fin << 7 | rsv1 << 6 | rsv2 << 5 | rsv3 << 4 | opcode
length = len(payload)
Expand All @@ -88,13 +88,14 @@ def build(cls, opcode, payload=b'',
if masking_key is None
else masking_key
)
mask_payload(masking_key, payload)
frame_bytes = b''.join((
header_bytes,
cls._pack_mask(masking_key),
mask_payload(masking_key, payload)
bytes(payload)
))
else:
frame_bytes = header_bytes + payload
frame_bytes = header_bytes + bytes(payload)
return frame_bytes

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion lomond/frame_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def parse(self):
yield header_data

while True:
byte1, byte2 = bytearray((yield self.read(2)))
byte1, byte2 = yield self.read(2)

fin = byte1 >> 7
rsv1 = (byte1 >> 6) & 1
Expand Down
83 changes: 15 additions & 68 deletions lomond/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,77 +10,24 @@
import six


try:
from wsaccel.xormask import XorMaskerSimple
except ImportError:
XorMaskerSimple = None
make_masking_key = partial(os.urandom, 4)


make_masking_key = partial(os.urandom, 4)
if six.PY2:
_XOR_TABLE = [b''.join(chr(a ^ b) for a in range(256)) for b in range(256)]
else:
_XOR_TABLE = [bytes(a ^ b for a in range(256)) for b in range(256)]


if XorMaskerSimple is not None:
# Fast C version (works with Py2 and Py3)
def mask_payload(masking_key, data):
return XorMaskerSimple(masking_key).process(data)
def mask_payload(masking_key, data):
"""XOR mask bytes.
else:
# This is about 60 (!) times faster than a simple loop
# -------------------------------------------------------------------------
# here's a brief explanation of how this works.
# we're creating an array of 256 translation tables. This will become of
# significance couple of lines later. In detail, this looks like this:
# [
# ''.join([ 0x00 ^ 0x00, 0x00 ^ 0x01, 0x00 ^ 0x02, ...],
# ''.join([ 0x01 ^ 0x00, 0x01 ^ 0x01, 0x01 ^ 0x02, ...],
# ''.join([ 0x02 ^ 0x00, 0x02 ^ 0x01, 0x02 ^ 0x02, ...]
# ...
# ''.join([ 0xff ^ 0x00, 0xff ^ 0x01, 0xff ^ 0x02, ...]
# ]
# there are a total of 256 rows in this table, because there are a total of
# 256 possible bytes.
if six.PY2:
_XOR_TABLE = [b''.join(chr(a ^ b) for a in range(256)) for b in range(256)]
else:
_XOR_TABLE = [bytes(a ^ b for a in range(256)) for b in range(256)]
`masking_key` should be bytes.
`data` should be a bytearray, and is mutated.
def mask_payload(masking_key, data):
"""XOR mask bytes."""
a, b, c, d = (_XOR_TABLE[n] for n in bytearray(masking_key))
# there are 4 bytes in our masking key, that's why we are picking 4
# variables from the table. Now, here comes the fun part. We are
# converting `masking_key` to bytearray, which, when iterated over,
# will covert a byte to the corresponding uint8_t, so if, for instance
# the `masking_key` will be given as:
# 'x07\x03\x01\x00'
# then, n will have the value of 7, 3, 1, 0 with each pass of the
# for-loop.
# why is this so significant? Because, as you remember from above,
# _XOR_TABLE has 256 rows, and the first byte in the xor operation
# was changing from 0 to 256. Therefore, key from masking_key converted
# to uint8_t can point us to a translation table for n-th byte from
# `masking_key`.
data_bytes = bytearray(data)
data_bytes[::4] = data_bytes[::4].translate(a)
data_bytes[1::4] = data_bytes[1::4].translate(b)
data_bytes[2::4] = data_bytes[2::4].translate(c)
data_bytes[3::4] = data_bytes[3::4].translate(d)
# great. The rest is quite easy. array.translate expects a bytearray of
# length=256. How convenient! It's exactly what we have. The way it
# works is that it takes an input byte and looks for a byte in the
# replacement table. So in our case, the replacement table will contain
# XOR'ed value of this byte by the masking key. Now you may wonder -
# why are there 4 iterations of this? Well, because there are 4
# different translation tables for 4 bytes of our masking key - if we
# wouldn't do this, then we would mess up our input data. So, for the
# first byte of masking key, we do the following (O = leave original
# byte, R = replace byte):
# 0 1 2 3 4 5 6 7 8 ( byte no. )
# R O O O R O O O R
# then, for the second byte of the mask, we do:
# 0 1 2 3 4 5 6 7 8 ( byte no. )
# O R O O O R O O O
# please note, that even though byte 0 is marked as 'O', it has already
# been replaced in the previous step
return bytes(data_bytes)
# and voila!
"""
a, b, c, d = (_XOR_TABLE[n] for n in bytearray(masking_key))
data[::4] = data[::4].translate(a)
data[1::4] = data[1::4].translate(b)
data[2::4] = data[2::4].translate(c)
data[3::4] = data[3::4].translate(d)
2 changes: 1 addition & 1 deletion lomond/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def build(cls, frames, decompress=None):
if first_frame.rsv1 and decompress:
payload = cls.decompress_frames(frames, decompress)
else:
payload = b''.join(frame.payload for frame in frames)
payload = b''.join(bytes(frame.payload) for frame in frames)
if opcode == Opcode.BINARY:
return Binary(payload)
elif opcode == Opcode.TEXT:
Expand Down
31 changes: 14 additions & 17 deletions lomond/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class ParseEOF(ParseError):
class _Awaitable(object):
"""An operation that effectively suspends the coroutine."""
# Analogous to Python3 asyncio concept
__slots__ = []

def validate(self, chunk):
"""Raise any ParseErrors"""
Expand All @@ -43,7 +44,7 @@ def __init__(self, count, utf8_validator):
self.utf8_validator = utf8_validator

def validate(self, data):
valid, _, _, _ = self.utf8_validator.validate(data)
valid, _, _, _ = self.utf8_validator.validate(bytes(data))
if not valid:
raise ParseError('invalid utf8')

Expand Down Expand Up @@ -86,8 +87,7 @@ class Parser(object):
def __init__(self):
self._gen = None
self._awaiting = None
self._buffer = [] # Buffer for reads
self._until = b'' # Buffer for read untils
self._buffer = bytearray() # Buffer for reads
self._eof = False
self.reset()

Expand Down Expand Up @@ -121,7 +121,6 @@ def feed(self, data):
:param bytes data: Data to parse.
"""

def _check_length(pos):
try:
self._awaiting.check_length(pos)
Expand All @@ -136,6 +135,7 @@ def _check_length(pos):
ParseError('unexpected eof of file')
)

_buffer = self._buffer
pos = 0
while pos < len(data):
# Awaiting a read of a fixed number of bytes
Expand All @@ -153,42 +153,39 @@ def _check_length(pos):
# Raises an exception in parse()
self._awaiting = self._gen.throw(error)
# Add to buffer
self._buffer.append(chunk)
_buffer.extend(chunk)
remaining -= chunk_size
if remaining:
# Await more bytes
self._awaiting.remaining = remaining
else:
# Got all the bytes we need in buffer
send_bytes = b''.join(self._buffer)
del self._buffer[:]
# Send to coroutine, get new 'awaitable'
self._awaiting = self._gen.send(send_bytes)
self._awaiting = self._gen.send(_buffer[:])
del _buffer[:]

# Awaiting a read until a terminator
elif isinstance(self._awaiting, _ReadUntil):
# Reading to separator
chunk = data[pos:]
self._until += chunk
_buffer.extend(chunk)
sep = self._awaiting.sep
sep_index = self._until.find(sep)
sep_index = _buffer.find(sep)

if sep_index == -1:
# Separator not found, advance position
pos += len(chunk)
_check_length(len(self._until))
_check_length(len(_buffer))
else:
# Found separator
# Get data prior to and including separator
sep_index += len(sep)
_check_length(sep_index)
send_bytes = self._until[:sep_index]
# Reset data, to continue parsing
data = self._until[sep_index:]
self._until = b''
data = _buffer[sep_index:]
pos = 0
# Send bytes to coroutine, get new 'awaitable'
self._awaiting = self._gen.send(send_bytes)
self._awaiting = self._gen.send(_buffer[:sep_index])
del _buffer[:]

# Yield any non-awaitables...
while not isinstance(self._awaiting, _Awaitable):
Expand Down Expand Up @@ -226,6 +223,6 @@ def parse(self):
data = yield self.read(2)
yield data
parser = TestParser()
for b in (b'head', b'ers: example', b'\r\n', b'\r\n', b'12', b'34', b'5', b'678', b'', b'90'):
for b in (b'head', b'ers: example', b'\r\n', b'\r\n', b'12', b'34', b'5', b'678', b'90'):
for frame in parser.feed(b):
print(repr(frame))
10 changes: 10 additions & 0 deletions lomond/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@ def __init__(self, socket):
"""Construct with an open socket."""
self._socket = socket

def wait(self, max_bytes, timeout=0.0):
"""Block until socket is readable or a timeout occurs. Return
a tuple of <readable>, <max bytes>.
"""
if hasattr(self._socket, 'pending') and self._socket.pending():
return True, self._socket.pending()
readable = self.wait_readable(timeout=timeout)
return readable, max_bytes

def wait_readable(self, timeout=0.0):
"""Block until socket is readable or a timeout occurs, return
`True` if the socket is readable, or `False` if the timeout
Expand Down
26 changes: 10 additions & 16 deletions lomond/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class WebsocketSession(object):
"""Manages the mechanics of running the websocket."""
_selector_cls = selectors.PlatformSelector

BUFFER_SIZE = 64 * 1024

def __init__(self, websocket):
self.websocket = websocket
self._address = (websocket.host, websocket.port)
Expand All @@ -50,6 +52,7 @@ def __init__(self, websocket):
self._last_pong = None
self._start_time = None
self._ready = False
self._buffer = bytearray(self.BUFFER_SIZE)

def __repr__(self):
return "<ws-session '{}'>".format(self.websocket.url)
Expand Down Expand Up @@ -99,13 +102,13 @@ def write(self, data):

def send(self, opcode, data):
"""Send a WS Frame."""
frame = Frame(opcode, payload=data)
frame = Frame(opcode, payload=bytearray(data))
self.write(frame.to_bytes())
log.debug(' SRV <- CLI : %r', frame)

def send_compressed(self, opcode, data):
"""Send a compressed WS Frame."""
frame = Frame(opcode, payload=data, rsv1=1)
frame = Frame(opcode, payload=bytearray(data), rsv1=1)
self.write(frame.to_bytes())
log.debug(' SRV <- CLI : %r', frame)

Expand Down Expand Up @@ -290,19 +293,10 @@ def _check_close_timeout(self, close_timeout, session_time):
def _recv(self, count):
"""Receive and return pending data from the socket."""
if self._sock is None:
return b''
return bytearray(b'')
try:
if hasattr(self._sock, 'pending'):
# exhaust ssl buffer
recv_bytes = []
while count:
data = self._sock.recv(count)
recv_bytes.append(data)
count = self._sock.pending()
return b''.join(recv_bytes)
else:
# Plain socket recv
return self._sock.recv(count)
_recv_count = self._sock.recv_into(self._buffer, count)
return memoryview(self._buffer)[:_recv_count]
except socket.error as error:
log.debug('error in _recv', exc_info=True)
self._socket_fail('recv fail; {}', error)
Expand Down Expand Up @@ -398,11 +392,11 @@ def _regular():

try:
while not websocket.is_closed:
readable = selector.wait_readable(poll)
readable, max_bytes = selector.wait(self.BUFFER_SIZE, poll)
for event in _regular():
yield event
if readable:
data = self._recv(64 * 1024)
data = self._recv(max_bytes)
if data:
for event in self.websocket.feed(data):
self._on_event(event, auto_pong)
Expand Down
Loading

0 comments on commit 2e3235f

Please sign in to comment.