Skip to content

Commit

Permalink
Add WebSocket::receiveFrame() that appends to a Poco::Buffer<char>
Browse files Browse the repository at this point in the history
  • Loading branch information
Tor Lillqvist committed Mar 7, 2015
1 parent 4336528 commit 08a748d
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 50 deletions.
16 changes: 16 additions & 0 deletions Net/include/Poco/Net/WebSocket.h
Expand Up @@ -23,6 +23,7 @@
#include "Poco/Net/Net.h"
#include "Poco/Net/StreamSocket.h"
#include "Poco/Net/HTTPCredentials.h"
#include "Poco/Buffer.h"


namespace Poco {
Expand Down Expand Up @@ -221,6 +222,21 @@ class Net_API WebSocket: public StreamSocket
/// The frame flags and opcode (FrameFlags and FrameOpcodes)
/// is stored in flags.

int receiveFrame(Poco::Buffer<char>& buffer, int& flags);
/// Receives a frame from the socket and stores it
/// after any previous content in buffer.
///
/// Returns the number of bytes received.
/// A return value of 0 means that the peer has
/// shut down or closed the connection.
///
/// Throws a TimeoutException if a receive timeout has
/// been set and nothing is received within that interval.
/// Throws a NetException (or a subclass) in case of other errors.
///
/// The frame flags and opcode (FrameFlags and FrameOpcodes)
/// is stored in flags.

Mode mode() const;
/// Returns WS_SERVER if the WebSocket is a server-side
/// WebSocket, or WS_CLIENT otherwise.
Expand Down
7 changes: 7 additions & 0 deletions Net/include/Poco/Net/WebSocketImpl.h
Expand Up @@ -21,6 +21,7 @@


#include "Poco/Net/StreamSocketImpl.h"
#include "Poco/Buffer.h"
#include "Poco/Random.h"


Expand All @@ -43,6 +44,9 @@ class Net_API WebSocketImpl: public StreamSocketImpl
virtual int receiveBytes(void* buffer, int length, int flags);
/// Receives a WebSocket protocol frame.

virtual int receiveBytes(Poco::Buffer<char>& buffer, int flags);
/// Receives a WebSocket protocol frame.

virtual SocketImpl* acceptConnection(SocketAddress& clientAddr);
virtual void connect(const SocketAddress& address);
virtual void connect(const SocketAddress& address, const Poco::Timespan& timeout);
Expand Down Expand Up @@ -77,6 +81,9 @@ class Net_API WebSocketImpl: public StreamSocketImpl
MAX_HEADER_LENGTH = 14
};

int receiveHeader(char mask[4], bool& useMask);
int receivePayload(char *buffer, int payloadLength, char mask[4], bool useMask);

int receiveNBytes(void* buffer, int bytes);
virtual ~WebSocketImpl();

Expand Down
9 changes: 9 additions & 0 deletions Net/src/WebSocket.cpp
Expand Up @@ -29,6 +29,7 @@
#include "Poco/String.h"
#include "Poco/Random.h"
#include "Poco/StreamCopier.h"
#include "Poco/Buffer.h"
#include <sstream>


Expand Down Expand Up @@ -114,6 +115,14 @@ int WebSocket::receiveFrame(void* buffer, int length, int& flags)
}


int WebSocket::receiveFrame(Poco::Buffer<char>& buffer, int& flags)
{
int n = static_cast<WebSocketImpl*>(impl())->receiveBytes(buffer, 0);
flags = static_cast<WebSocketImpl*>(impl())->frameFlags();
return n;
}


WebSocket::Mode WebSocket::mode() const
{
return static_cast<WebSocketImpl*>(impl())->mustMaskPayload() ? WS_CLIENT : WS_SERVER;
Expand Down
118 changes: 68 additions & 50 deletions Net/src/WebSocketImpl.cpp
Expand Up @@ -104,7 +104,7 @@ int WebSocketImpl::sendBytes(const void* buffer, int length, int flags)
}


int WebSocketImpl::receiveBytes(void* buffer, int length, int)
int WebSocketImpl::receiveHeader(char mask[4], bool& useMask)
{
char header[MAX_HEADER_LENGTH];
int n = receiveNBytes(header, 2);
Expand All @@ -114,82 +114,100 @@ int WebSocketImpl::receiveBytes(void* buffer, int length, int)
return n;
}
poco_assert (n == 2);
Poco::UInt8 flags = static_cast<Poco::UInt8>(header[0]);
_frameFlags = flags;
Poco::UInt8 lengthByte = static_cast<Poco::UInt8>(header[1]);
int maskOffset = 0;
if (lengthByte & FRAME_FLAG_MASK) maskOffset += 4;
useMask = ((lengthByte & FRAME_FLAG_MASK) != 0);
int payloadLength;
lengthByte &= 0x7f;
if (lengthByte > 0 || maskOffset > 0)
if (lengthByte == 127)
{
if (lengthByte + 2 + maskOffset < MAX_HEADER_LENGTH)
n = receiveNBytes(header + 2, 8);
if (n <= 0)
{
n = receiveNBytes(header + 2, lengthByte + maskOffset);
_frameFlags = 0;
return n;
}
else
{
n = receiveNBytes(header + 2, MAX_HEADER_LENGTH - 2);
}
if (n <= 0) throw WebSocketException("Incomplete header received", WebSocket::WS_ERR_INCOMPLETE_FRAME);
n += 2;
}
Poco::MemoryInputStream istr(header, n);
Poco::BinaryReader reader(istr, Poco::BinaryReader::NETWORK_BYTE_ORDER);
Poco::UInt8 flags;
char mask[4];
reader >> flags >> lengthByte;
_frameFlags = flags;
int payloadLength = 0;
int payloadOffset = 2;
if ((lengthByte & 0x7f) == 127)
{
Poco::MemoryInputStream istr(header + 2, 8);
Poco::BinaryReader reader(istr, Poco::BinaryReader::NETWORK_BYTE_ORDER);
Poco::UInt64 l;
reader >> l;
if (l > length) throw WebSocketException(Poco::format("Insufficient buffer for payload size %Lu", l), WebSocket::WS_ERR_PAYLOAD_TOO_BIG);
payloadLength = static_cast<int>(l);
payloadOffset += 8;
}
else if ((lengthByte & 0x7f) == 126)
} else if (lengthByte == 126)
{
n = receiveNBytes(header + 2, 2);
if (n <= 0)
{
_frameFlags = 0;
return n;
}
Poco::MemoryInputStream istr(header + 2, 2);
Poco::BinaryReader reader(istr, Poco::BinaryReader::NETWORK_BYTE_ORDER);
Poco::UInt16 l;
reader >> l;
if (l > length) throw WebSocketException(Poco::format("Insufficient buffer for payload size %hu", l), WebSocket::WS_ERR_PAYLOAD_TOO_BIG);
payloadLength = static_cast<int>(l);
payloadOffset += 2;
}
else
{
Poco::UInt8 l = lengthByte & 0x7f;
if (l > length) throw WebSocketException(Poco::format("Insufficient buffer for payload size %u", unsigned(l)), WebSocket::WS_ERR_PAYLOAD_TOO_BIG);
payloadLength = static_cast<int>(l);
}
if (lengthByte & FRAME_FLAG_MASK)
{
reader.readRaw(mask, 4);
payloadOffset += 4;
payloadLength = lengthByte;
}
int received = 0;
if (payloadOffset < n)
{
std::memcpy(buffer, header + payloadOffset, n - payloadOffset);
received = n - payloadOffset;
}
if (received < payloadLength)

if (useMask)
{
n = receiveNBytes(reinterpret_cast<char*>(buffer) + received, payloadLength - received);
if (n <= 0) throw WebSocketException("Incomplete frame received", WebSocket::WS_ERR_INCOMPLETE_FRAME);
received += n;
n = receiveNBytes(mask, 4);
if (n <= 0)
{
_frameFlags = 0;
return n;
}
}
if (lengthByte & FRAME_FLAG_MASK)

return payloadLength;
}


int WebSocketImpl::receivePayload(char *buffer, int payloadLength, char mask[4], bool useMask)
{
int received = receiveNBytes(reinterpret_cast<char*>(buffer), payloadLength);
if (received <= 0) throw WebSocketException("Incomplete frame received", WebSocket::WS_ERR_INCOMPLETE_FRAME);

if (useMask)
{
char* p = reinterpret_cast<char*>(buffer);
for (int i = 0; i < received; i++)
{
p[i] ^= mask[i % 4];
buffer[i] ^= mask[i % 4];
}
}
return received;
}


int WebSocketImpl::receiveBytes(void* buffer, int length, int)
{
char mask[4];
bool useMask;
int payloadLength = receiveHeader(mask, useMask);
if (payloadLength <= 0)
return payloadLength;
if (payloadLength > length)
throw WebSocketException(Poco::format("Insufficient buffer for payload size %hu", payloadLength), WebSocket::WS_ERR_PAYLOAD_TOO_BIG);
return receivePayload(reinterpret_cast<char*>(buffer), payloadLength, mask, useMask);
}


int WebSocketImpl::receiveBytes(Poco::Buffer<char>& buffer, int)
{
char mask[4];
bool useMask;
int payloadLength = receiveHeader(mask, useMask);
if (payloadLength <= 0)
return payloadLength;
int oldSize = buffer.size();
buffer.resize(oldSize + payloadLength);
return receivePayload(buffer.begin() + oldSize, payloadLength, mask, useMask);
}


int WebSocketImpl::receiveNBytes(void* buffer, int bytes)
{
int received = _pStreamSocketImpl->receiveBytes(reinterpret_cast<char*>(buffer), bytes);
Expand Down
58 changes: 58 additions & 0 deletions Net/testsuite/src/WebSocketTest.cpp
Expand Up @@ -141,6 +141,13 @@ void WebSocketTest::testWebSocket()
assert (n == payload.size());
assert (payload.compare(0, payload.size(), buffer, 0, n) == 0);
assert (flags == WebSocket::FRAME_TEXT);

ws.sendFrame(payload.data(), (int) payload.size());
Poco::Buffer<char> pocobuffer(0);
n = ws.receiveFrame(pocobuffer, flags);
assert (n == payload.size());
assert (payload.compare(0, payload.size(), pocobuffer.begin(), 0, n) == 0);
assert (flags == WebSocket::FRAME_TEXT);
}

for (int i = 125; i < 129; i++)
Expand All @@ -151,6 +158,13 @@ void WebSocketTest::testWebSocket()
assert (n == payload.size());
assert (payload.compare(0, payload.size(), buffer, 0, n) == 0);
assert (flags == WebSocket::FRAME_TEXT);

ws.sendFrame(payload.data(), (int) payload.size());
Poco::Buffer<char> pocobuffer(0);
n = ws.receiveFrame(pocobuffer, flags);
assert (n == payload.size());
assert (payload.compare(0, payload.size(), pocobuffer.begin(), 0, n) == 0);
assert (flags == WebSocket::FRAME_TEXT);
}

payload = "Hello, world!";
Expand Down Expand Up @@ -210,6 +224,49 @@ void WebSocketTest::testWebSocketLarge()
}


void WebSocketTest::testOneLargeFrame(int msgSize)
{
Poco::Net::ServerSocket ss(0);
Poco::Net::HTTPServer server(new WebSocketRequestHandlerFactory(msgSize), ss, new Poco::Net::HTTPServerParams);
server.start();

Poco::Thread::sleep(200);

HTTPClientSession cs("localhost", ss.address().port());
HTTPRequest request(HTTPRequest::HTTP_GET, "/ws");
HTTPResponse response;
WebSocket ws(cs, request, response);
ws.setSendBufferSize(msgSize);
ws.setReceiveBufferSize(msgSize);
std::string payload(msgSize, 'x');

ws.sendFrame(payload.data(), msgSize);

char buffer[msgSize];
int flags;
int n;

n = ws.receiveFrame(buffer, sizeof(buffer), flags);
assert (n == payload.size());
assert (payload.compare(0, payload.size(), buffer, 0, n) == 0);

ws.sendFrame(payload.data(), msgSize);

Poco::Buffer<char> pocobuffer(0);

n = ws.receiveFrame(pocobuffer, flags);
assert (n == payload.size());
assert (payload.compare(0, payload.size(), pocobuffer.begin(), 0, n) == 0);
}


void WebSocketTest::testWebSocketLargeInOneFrame()
{
testOneLargeFrame(64000);
testOneLargeFrame(70000);
}


void WebSocketTest::setUp()
{
}
Expand All @@ -226,6 +283,7 @@ CppUnit::Test* WebSocketTest::suite()

CppUnit_addTest(pSuite, WebSocketTest, testWebSocket);
CppUnit_addTest(pSuite, WebSocketTest, testWebSocketLarge);
CppUnit_addTest(pSuite, WebSocketTest, testWebSocketLargeInOneFrame);

return pSuite;
}
2 changes: 2 additions & 0 deletions Net/testsuite/src/WebSocketTest.h
Expand Up @@ -28,13 +28,15 @@ class WebSocketTest: public CppUnit::TestCase

void testWebSocket();
void testWebSocketLarge();
void testWebSocketLargeInOneFrame();

void setUp();
void tearDown();

static CppUnit::Test* suite();

private:
void testOneLargeFrame(int msgSize);
};


Expand Down

0 comments on commit 08a748d

Please sign in to comment.