diff --git a/Net/include/Poco/Net/WebSocket.h b/Net/include/Poco/Net/WebSocket.h index 173eae6839..dc372931a4 100644 --- a/Net/include/Poco/Net/WebSocket.h +++ b/Net/include/Poco/Net/WebSocket.h @@ -23,6 +23,7 @@ #include "Poco/Net/Net.h" #include "Poco/Net/StreamSocket.h" #include "Poco/Net/HTTPCredentials.h" +#include "Poco/Buffer.h" namespace Poco { @@ -221,6 +222,21 @@ class Net_API WebSocket: public StreamSocket /// The frame flags and opcode (FrameFlags and FrameOpcodes) /// is stored in flags. + int receiveFrame(Poco::Buffer& buffer, int& flags); + /// Receives a frame from the socket and stores it + /// after any previous content in buffer. + /// + /// Returns the number of bytes received. + /// A return value of 0 means that the peer has + /// shut down or closed the connection. + /// + /// Throws a TimeoutException if a receive timeout has + /// been set and nothing is received within that interval. + /// Throws a NetException (or a subclass) in case of other errors. + /// + /// The frame flags and opcode (FrameFlags and FrameOpcodes) + /// is stored in flags. + Mode mode() const; /// Returns WS_SERVER if the WebSocket is a server-side /// WebSocket, or WS_CLIENT otherwise. diff --git a/Net/include/Poco/Net/WebSocketImpl.h b/Net/include/Poco/Net/WebSocketImpl.h index b631a5d2e4..648dac9eec 100644 --- a/Net/include/Poco/Net/WebSocketImpl.h +++ b/Net/include/Poco/Net/WebSocketImpl.h @@ -21,6 +21,7 @@ #include "Poco/Net/StreamSocketImpl.h" +#include "Poco/Buffer.h" #include "Poco/Random.h" @@ -43,6 +44,9 @@ class Net_API WebSocketImpl: public StreamSocketImpl virtual int receiveBytes(void* buffer, int length, int flags); /// Receives a WebSocket protocol frame. + virtual int receiveBytes(Poco::Buffer& buffer, int flags); + /// Receives a WebSocket protocol frame. + virtual SocketImpl* acceptConnection(SocketAddress& clientAddr); virtual void connect(const SocketAddress& address); virtual void connect(const SocketAddress& address, const Poco::Timespan& timeout); @@ -77,6 +81,9 @@ class Net_API WebSocketImpl: public StreamSocketImpl MAX_HEADER_LENGTH = 14 }; + int receiveHeader(char mask[4], bool& useMask); + int receivePayload(char *buffer, int payloadLength, char mask[4], bool useMask); + int receiveNBytes(void* buffer, int bytes); virtual ~WebSocketImpl(); diff --git a/Net/src/WebSocket.cpp b/Net/src/WebSocket.cpp index 5950a6a9d9..ca09c59991 100644 --- a/Net/src/WebSocket.cpp +++ b/Net/src/WebSocket.cpp @@ -29,6 +29,7 @@ #include "Poco/String.h" #include "Poco/Random.h" #include "Poco/StreamCopier.h" +#include "Poco/Buffer.h" #include @@ -114,6 +115,14 @@ int WebSocket::receiveFrame(void* buffer, int length, int& flags) } +int WebSocket::receiveFrame(Poco::Buffer& buffer, int& flags) +{ + int n = static_cast(impl())->receiveBytes(buffer, 0); + flags = static_cast(impl())->frameFlags(); + return n; +} + + WebSocket::Mode WebSocket::mode() const { return static_cast(impl())->mustMaskPayload() ? WS_CLIENT : WS_SERVER; diff --git a/Net/src/WebSocketImpl.cpp b/Net/src/WebSocketImpl.cpp index 4ed6354d62..92a56aebf5 100644 --- a/Net/src/WebSocketImpl.cpp +++ b/Net/src/WebSocketImpl.cpp @@ -104,7 +104,7 @@ int WebSocketImpl::sendBytes(const void* buffer, int length, int flags) } -int WebSocketImpl::receiveBytes(void* buffer, int length, int) +int WebSocketImpl::receiveHeader(char mask[4], bool& useMask) { char header[MAX_HEADER_LENGTH]; int n = receiveNBytes(header, 2); @@ -114,82 +114,100 @@ int WebSocketImpl::receiveBytes(void* buffer, int length, int) return n; } poco_assert (n == 2); + Poco::UInt8 flags = static_cast(header[0]); + _frameFlags = flags; Poco::UInt8 lengthByte = static_cast(header[1]); - int maskOffset = 0; - if (lengthByte & FRAME_FLAG_MASK) maskOffset += 4; + useMask = ((lengthByte & FRAME_FLAG_MASK) != 0); + int payloadLength; lengthByte &= 0x7f; - if (lengthByte > 0 || maskOffset > 0) + if (lengthByte == 127) { - if (lengthByte + 2 + maskOffset < MAX_HEADER_LENGTH) + n = receiveNBytes(header + 2, 8); + if (n <= 0) { - n = receiveNBytes(header + 2, lengthByte + maskOffset); + _frameFlags = 0; + return n; } - else - { - n = receiveNBytes(header + 2, MAX_HEADER_LENGTH - 2); - } - if (n <= 0) throw WebSocketException("Incomplete header received", WebSocket::WS_ERR_INCOMPLETE_FRAME); - n += 2; - } - Poco::MemoryInputStream istr(header, n); - Poco::BinaryReader reader(istr, Poco::BinaryReader::NETWORK_BYTE_ORDER); - Poco::UInt8 flags; - char mask[4]; - reader >> flags >> lengthByte; - _frameFlags = flags; - int payloadLength = 0; - int payloadOffset = 2; - if ((lengthByte & 0x7f) == 127) - { + Poco::MemoryInputStream istr(header + 2, 8); + Poco::BinaryReader reader(istr, Poco::BinaryReader::NETWORK_BYTE_ORDER); Poco::UInt64 l; reader >> l; - if (l > length) throw WebSocketException(Poco::format("Insufficient buffer for payload size %Lu", l), WebSocket::WS_ERR_PAYLOAD_TOO_BIG); payloadLength = static_cast(l); - payloadOffset += 8; - } - else if ((lengthByte & 0x7f) == 126) + } else if (lengthByte == 126) { + n = receiveNBytes(header + 2, 2); + if (n <= 0) + { + _frameFlags = 0; + return n; + } + Poco::MemoryInputStream istr(header + 2, 2); + Poco::BinaryReader reader(istr, Poco::BinaryReader::NETWORK_BYTE_ORDER); Poco::UInt16 l; reader >> l; - if (l > length) throw WebSocketException(Poco::format("Insufficient buffer for payload size %hu", l), WebSocket::WS_ERR_PAYLOAD_TOO_BIG); payloadLength = static_cast(l); - payloadOffset += 2; } else { - Poco::UInt8 l = lengthByte & 0x7f; - if (l > length) throw WebSocketException(Poco::format("Insufficient buffer for payload size %u", unsigned(l)), WebSocket::WS_ERR_PAYLOAD_TOO_BIG); - payloadLength = static_cast(l); - } - if (lengthByte & FRAME_FLAG_MASK) - { - reader.readRaw(mask, 4); - payloadOffset += 4; + payloadLength = lengthByte; } - int received = 0; - if (payloadOffset < n) - { - std::memcpy(buffer, header + payloadOffset, n - payloadOffset); - received = n - payloadOffset; - } - if (received < payloadLength) + + if (useMask) { - n = receiveNBytes(reinterpret_cast(buffer) + received, payloadLength - received); - if (n <= 0) throw WebSocketException("Incomplete frame received", WebSocket::WS_ERR_INCOMPLETE_FRAME); - received += n; + n = receiveNBytes(mask, 4); + if (n <= 0) + { + _frameFlags = 0; + return n; + } } - if (lengthByte & FRAME_FLAG_MASK) + + return payloadLength; +} + + +int WebSocketImpl::receivePayload(char *buffer, int payloadLength, char mask[4], bool useMask) +{ + int received = receiveNBytes(reinterpret_cast(buffer), payloadLength); + if (received <= 0) throw WebSocketException("Incomplete frame received", WebSocket::WS_ERR_INCOMPLETE_FRAME); + + if (useMask) { - char* p = reinterpret_cast(buffer); for (int i = 0; i < received; i++) { - p[i] ^= mask[i % 4]; + buffer[i] ^= mask[i % 4]; } } return received; } +int WebSocketImpl::receiveBytes(void* buffer, int length, int) +{ + char mask[4]; + bool useMask; + int payloadLength = receiveHeader(mask, useMask); + if (payloadLength <= 0) + return payloadLength; + if (payloadLength > length) + throw WebSocketException(Poco::format("Insufficient buffer for payload size %hu", payloadLength), WebSocket::WS_ERR_PAYLOAD_TOO_BIG); + return receivePayload(reinterpret_cast(buffer), payloadLength, mask, useMask); +} + + +int WebSocketImpl::receiveBytes(Poco::Buffer& buffer, int) +{ + char mask[4]; + bool useMask; + int payloadLength = receiveHeader(mask, useMask); + if (payloadLength <= 0) + return payloadLength; + int oldSize = buffer.size(); + buffer.resize(oldSize + payloadLength); + return receivePayload(buffer.begin() + oldSize, payloadLength, mask, useMask); +} + + int WebSocketImpl::receiveNBytes(void* buffer, int bytes) { int received = _pStreamSocketImpl->receiveBytes(reinterpret_cast(buffer), bytes); diff --git a/Net/testsuite/src/WebSocketTest.cpp b/Net/testsuite/src/WebSocketTest.cpp index 3cb7016672..e1a1a7e898 100644 --- a/Net/testsuite/src/WebSocketTest.cpp +++ b/Net/testsuite/src/WebSocketTest.cpp @@ -141,6 +141,13 @@ void WebSocketTest::testWebSocket() assert (n == payload.size()); assert (payload.compare(0, payload.size(), buffer, 0, n) == 0); assert (flags == WebSocket::FRAME_TEXT); + + ws.sendFrame(payload.data(), (int) payload.size()); + Poco::Buffer pocobuffer(0); + n = ws.receiveFrame(pocobuffer, flags); + assert (n == payload.size()); + assert (payload.compare(0, payload.size(), pocobuffer.begin(), 0, n) == 0); + assert (flags == WebSocket::FRAME_TEXT); } for (int i = 125; i < 129; i++) @@ -151,6 +158,13 @@ void WebSocketTest::testWebSocket() assert (n == payload.size()); assert (payload.compare(0, payload.size(), buffer, 0, n) == 0); assert (flags == WebSocket::FRAME_TEXT); + + ws.sendFrame(payload.data(), (int) payload.size()); + Poco::Buffer pocobuffer(0); + n = ws.receiveFrame(pocobuffer, flags); + assert (n == payload.size()); + assert (payload.compare(0, payload.size(), pocobuffer.begin(), 0, n) == 0); + assert (flags == WebSocket::FRAME_TEXT); } payload = "Hello, world!"; @@ -210,6 +224,49 @@ void WebSocketTest::testWebSocketLarge() } +void WebSocketTest::testOneLargeFrame(int msgSize) +{ + Poco::Net::ServerSocket ss(0); + Poco::Net::HTTPServer server(new WebSocketRequestHandlerFactory(msgSize), ss, new Poco::Net::HTTPServerParams); + server.start(); + + Poco::Thread::sleep(200); + + HTTPClientSession cs("localhost", ss.address().port()); + HTTPRequest request(HTTPRequest::HTTP_GET, "/ws"); + HTTPResponse response; + WebSocket ws(cs, request, response); + ws.setSendBufferSize(msgSize); + ws.setReceiveBufferSize(msgSize); + std::string payload(msgSize, 'x'); + + ws.sendFrame(payload.data(), msgSize); + + char buffer[msgSize]; + int flags; + int n; + + n = ws.receiveFrame(buffer, sizeof(buffer), flags); + assert (n == payload.size()); + assert (payload.compare(0, payload.size(), buffer, 0, n) == 0); + + ws.sendFrame(payload.data(), msgSize); + + Poco::Buffer pocobuffer(0); + + n = ws.receiveFrame(pocobuffer, flags); + assert (n == payload.size()); + assert (payload.compare(0, payload.size(), pocobuffer.begin(), 0, n) == 0); +} + + +void WebSocketTest::testWebSocketLargeInOneFrame() +{ + testOneLargeFrame(64000); + testOneLargeFrame(70000); +} + + void WebSocketTest::setUp() { } @@ -226,6 +283,7 @@ CppUnit::Test* WebSocketTest::suite() CppUnit_addTest(pSuite, WebSocketTest, testWebSocket); CppUnit_addTest(pSuite, WebSocketTest, testWebSocketLarge); + CppUnit_addTest(pSuite, WebSocketTest, testWebSocketLargeInOneFrame); return pSuite; } diff --git a/Net/testsuite/src/WebSocketTest.h b/Net/testsuite/src/WebSocketTest.h index 939acf630c..07f18a857d 100644 --- a/Net/testsuite/src/WebSocketTest.h +++ b/Net/testsuite/src/WebSocketTest.h @@ -28,6 +28,7 @@ class WebSocketTest: public CppUnit::TestCase void testWebSocket(); void testWebSocketLarge(); + void testWebSocketLargeInOneFrame(); void setUp(); void tearDown(); @@ -35,6 +36,7 @@ class WebSocketTest: public CppUnit::TestCase static CppUnit::Test* suite(); private: + void testOneLargeFrame(int msgSize); };