Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WebSocket unmask speedup #579

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions setup.py
Expand Up @@ -33,6 +33,10 @@
extensions.append(distutils.core.Extension(
"tornado.epoll", ["tornado/epoll.c"]))

# Build the WebSocket unmask optimization
extensions.append(distutils.core.Extension(
"tornado._websocket_unmask", ["tornado/_websocket_unmask.c"]))

version = "2.3.post1"

if major >= 3:
Expand Down
100 changes: 100 additions & 0 deletions tornado/_websocket_unmask.c
@@ -0,0 +1,100 @@
#include <Python.h>

#if PY_MAJOR_VERSION >= 3
#define IS_PY3K
#endif

/**
* Mask/unmask WebSocket data frames.
*
* http://tools.ietf.org/html/rfc6455#section-5.3
*/
static PyObject *
Module_unmask_frame(PyObject *self, PyObject *args)
{
unsigned char *input;
unsigned int *input_multi;
int input_length;
unsigned int mask;
unsigned char *mask_str;
int mask_length;
int i;
PyObject *pyoutput;
unsigned char *output;
unsigned int *output_multi;

if (!PyArg_ParseTuple(args, "s#s#", &input_multi, &input_length, &mask_str, &mask_length)) {
return NULL;
}

if (mask_length != 4) {
PyErr_Format(PyExc_TypeError, "the mask must be a string of length 4, not %d", mask_length);
return NULL;
} else if (input_length == 0) {
#ifdef IS_PY3K
return PyByteArray_FromStringAndSize("", 0);
#else
return PyString_FromString("");
#endif
}

#ifdef IS_PY3K
pyoutput = PyByteArray_FromStringAndSize(NULL, input_length);
#else
pyoutput = PyString_FromStringAndSize(NULL, input_length);
#endif
if (pyoutput == NULL) {
return NULL;
}

mask = mask_str[3] << 24 | mask_str[2] << 16 | mask_str[1] << 8 | mask_str[0];
#ifdef IS_PY3K
output_multi = (unsigned int *) PyByteArray_AS_STRING(pyoutput);
#else
output_multi = (unsigned int *) PyString_AS_STRING(pyoutput);
#endif
if (input_length >= 4) {
// process 4 bytes at once
for (i=0; i<input_length/4; i++) {
*output_multi++ = *input_multi++ ^ mask;
}
}

// process remaining bytes
i = input_length & 3;
if (i) {
input = (unsigned char *) input_multi;
output = (unsigned char *) output_multi;
while (i--) {
*output++ = *input++ ^ *mask_str++;
}
}
return pyoutput;
}

PyMethodDef
methods[] = {
{"unmask_frame", (PyCFunction)Module_unmask_frame, METH_VARARGS, "Unmask WebSocket frame data."},
{NULL, NULL, 0, NULL},
};

#ifdef IS_PY3K
static struct PyModuleDef
_websocket_unmask_module = {
PyModuleDef_HEAD_INIT,
"_websocket_unmask",
NULL,
-1,
methods
};
#endif

PyMODINIT_FUNC
init_websocket_unmask(void)
{
#ifdef IS_PY3K
return PyModule_Create(&_websocket_unmask_module);
#else
Py_InitModule("_websocket_unmask", methods);
#endif
}
1 change: 1 addition & 0 deletions tornado/test/runtests.py
Expand Up @@ -27,6 +27,7 @@
'tornado.test.twisted_test',
'tornado.test.util_test',
'tornado.test.web_test',
'tornado.test.websocket_test',
'tornado.test.wsgi_test',
]

Expand Down
28 changes: 28 additions & 0 deletions tornado/test/websocket_test.py
@@ -0,0 +1,28 @@
from __future__ import absolute_import, division, with_statement
import sys
import unittest

from tornado.websocket import unmask_frame_python
try:
from tornado._websocket_unmask import unmask_frame as unmask_frame_c
except ImportError:
unmask_frame_c = None

class MaskingTests(unittest.TestCase):
def test_masking(self):
if unmask_frame_c is None:
raise TypeError('The optimized mask/unmask method is not available, skipping test')

# make sure the C and Python versions produce the same result
TEST_DATA = (
('1234567890', '1234'),
)
for (data, mask) in TEST_DATA:
encoded1 = unmask_frame_python(data, mask).tostring()
encoded2 = unmask_frame_c(data, mask)
self.assertEquals(encoded1, encoded2)

decoded1 = unmask_frame_python(encoded1, mask).tostring()
decoded2 = unmask_frame_c(encoded2, mask)
self.assertEquals(decoded1, data)
self.assertEquals(decoded1, decoded2)
19 changes: 14 additions & 5 deletions tornado/websocket.py
Expand Up @@ -32,6 +32,18 @@

from tornado.util import bytes_type, b

def unmask_frame_python(data, mask):
mask = array.array("B", mask)
unmasked = array.array("B", data)
for i in xrange(len(data)):
unmasked[i] = unmasked[i] ^ mask[i % 4]
return unmasked

try:
from tornado._websocket_unmask import unmask_frame
except ImportError:
# Optimized version is not available, use (slower) Python version
unmask_frame = unmask_frame_python

class WebSocketHandler(tornado.web.RequestHandler):
"""Subclass this class to create a basic WebSocket handler.
Expand Down Expand Up @@ -564,14 +576,11 @@ def _on_frame_length_64(self, data):
self.stream.read_bytes(4, self._on_masking_key)

def _on_masking_key(self, data):
self._frame_mask = array.array("B", data)
self._frame_mask = data
self.stream.read_bytes(self._frame_length, self._on_frame_data)

def _on_frame_data(self, data):
unmasked = array.array("B", data)
for i in xrange(len(data)):
unmasked[i] = unmasked[i] ^ self._frame_mask[i % 4]

unmasked = unmask_frame(data, self._frame_mask)
if self._frame_opcode_is_control:
# control frames may be interleaved with a series of fragmented
# data frames, so control frames must not interact with
Expand Down