Skip to content
Browse files

Extract MessageChannel functionality into free functions in Utils/Mes…

…sageIO.h.

These functions also have consistent support for timeouts.
Also simplify MessageChannel by making use of those functions.
  • Loading branch information...
1 parent e3885d5 commit e094601498450a1269cf42ef8687d969f3285f37 @FooBarWidget FooBarWidget committed
View
4 build/cxx_tests.rb
@@ -187,6 +187,10 @@
test/cxx/BufferedIOTest.cpp
ext/common/Utils/BufferedIO.h
ext/common/Utils/Timer.h),
+ 'test/cxx/MessageIOTest.o' => %w(
+ test/cxx/MessageIOTest.cpp
+ ext/common/Utils/MessageIO.h
+ ext/common/Utils/IOUtils.h),
'test/cxx/VariantMapTest.o' => %w(
test/cxx/VariantMapTest.cpp
ext/common/MessageChannel.h
View
153 ext/common/MessageChannel.h
@@ -53,6 +53,7 @@
#include "Utils/Timer.h"
#include "Utils/MemZeroGuard.h"
#include "Utils/IOUtils.h"
+#include "Utils/MessageIO.h"
namespace Passenger {
@@ -194,52 +195,9 @@ class MessageChannel {
* @pre None of the message elements may contain a NUL character (<tt>'\\0'</tt>).
* @see read(), write(const char *, ...)
*/
- template<typename StringArrayType, typename StringArrayConstIteratorType>
+ template<typename StringArrayType>
void write(const StringArrayType &args) {
- StringArrayConstIteratorType it;
- string data;
- uint16_t dataSize = 0;
-
- for (it = args.begin(); it != args.end(); it++) {
- dataSize += it->size() + 1;
- }
- data.reserve(dataSize + sizeof(dataSize));
- dataSize = htons(dataSize);
- data.append((const char *) &dataSize, sizeof(dataSize));
- for (it = args.begin(); it != args.end(); it++) {
- data.append(*it);
- data.append(1, DELIMITER);
- }
-
- writeExact(fd, data);
- }
-
- /**
- * Send an array message, which consists of the given elements, over the underlying
- * file descriptor.
- *
- * @param args The message elements.
- * @throws SystemException An error occured while writing the data to the file descriptor.
- * @throws boost::thread_interrupted
- * @pre None of the message elements may contain a NUL character (<tt>'\\0'</tt>).
- * @see read(), write(const char *, ...)
- */
- void write(const list<string> &args) {
- write<list<string>, list<string>::const_iterator>(args);
- }
-
- /**
- * Send an array message, which consists of the given elements, over the underlying
- * file descriptor.
- *
- * @param args The message elements.
- * @throws SystemException An error occured while writing the data to the file descriptor.
- * @throws boost::thread_interrupted
- * @pre None of the message elements may contain a NUL character (<tt>'\\0'</tt>).
- * @see read(), write(const char *, ...)
- */
- void write(const vector<string> &args) {
- write<vector<string>, vector<string>::const_iterator>(args);
+ writeArrayMessage(fd, args);
}
/**
@@ -252,18 +210,7 @@ class MessageChannel {
* @pre None of the message elements may contain a NUL character (<tt>'\\0'</tt>).
*/
void write(const char *name, va_list &ap) {
- list<string> args;
- args.push_back(name);
-
- while (true) {
- const char *arg = va_arg(ap, const char *);
- if (arg == NULL) {
- break;
- } else {
- args.push_back(arg);
- }
- }
- write(args);
+ writeArrayMessage(fd, name, ap);
}
/**
@@ -297,8 +244,7 @@ class MessageChannel {
* @throws boost::thread_interrupted
*/
void writeUint32(unsigned int value) {
- uint32_t l = htonl(value);
- writeExact(fd, &l, sizeof(uint32_t));
+ Passenger::writeUint32(fd, value);
}
/**
@@ -314,7 +260,7 @@ class MessageChannel {
* @see readScalar(), writeScalar(const char *, unsigned int)
*/
void writeScalar(const string &str) {
- writeScalar(str.c_str(), str.size());
+ writeScalarMessage(fd, str);
}
/**
@@ -332,8 +278,7 @@ class MessageChannel {
* @see readScalar(), writeScalar(const string &)
*/
void writeScalar(const char *data, unsigned int size) {
- writeUint32(size);
- writeExact(fd, data, size);
+ writeScalarMessage(fd, data, size);
}
/**
@@ -348,27 +293,10 @@ class MessageChannel {
* @see readFileDescriptor()
*/
void writeFileDescriptor(int fileDescriptor, bool negotiate = true) {
- // See message_channel.rb for more info about negotiation.
if (negotiate) {
- vector<string> args;
-
- if (!read(args)) {
- throw IOException("Unexpected end of stream encountered while pre-negotiating a file descriptor");
- } else if (args.size() != 1 || args[0] != "pass IO") {
- throw IOException("FD passing pre-negotiation message expected.");
- }
- }
-
- Passenger::writeFileDescriptor(fd, fileDescriptor);
-
- if (negotiate) {
- vector<string> args;
-
- if (!read(args)) {
- throw IOException("Unexpected end of stream encountered while post-negotiating a file descriptor");
- } else if (args.size() != 1 || args[0] != "got IO") {
- throw IOException("FD passing post-negotiation message expected.");
- }
+ Passenger::writeFileDescriptorWithNegotiation(fd, fileDescriptor);
+ } else {
+ Passenger::writeFileDescriptor(fd, fileDescriptor);
}
}
@@ -383,44 +311,12 @@ class MessageChannel {
* @see write()
*/
bool read(vector<string> &args) {
- uint16_t size;
- int ret;
- unsigned int alreadyRead = 0;
-
- do {
- ret = syscalls::read(fd, (char *) &size + alreadyRead, sizeof(size) - alreadyRead);
- if (ret == -1) {
- throw SystemException("read() failed", errno);
- } else if (ret == 0) {
- return false;
- }
- alreadyRead += ret;
- } while (alreadyRead < sizeof(size));
- size = ntohs(size);
-
- string buffer;
- args.clear();
- buffer.reserve(size);
- while (buffer.size() < size) {
- char tmp[1024 * 8];
- ret = syscalls::read(fd, tmp, min(size - buffer.size(), sizeof(tmp)));
- if (ret == -1) {
- throw SystemException("read() failed", errno);
- } else if (ret == 0) {
- return false;
- }
- buffer.append(tmp, ret);
- }
-
- if (!buffer.empty()) {
- string::size_type start = 0, pos;
- const string &const_buffer(buffer);
- while ((pos = const_buffer.find('\0', start)) != string::npos) {
- args.push_back(const_buffer.substr(start, pos - start));
- start = pos + 1;
- }
+ try {
+ args = readArrayMessage(fd);
+ return true;
+ } catch (const EOFException &) {
+ return false;
}
- return true;
}
/**
@@ -568,24 +464,11 @@ class MessageChannel {
* @throws boost::thread_interrupted
*/
int readFileDescriptor(bool negotiate = true) {
- // See message_channel.rb for more info about negotiation.
- if (negotiate) {
- write("pass IO", NULL);
- }
-
- int fd = Passenger::readFileDescriptor(this->fd);
-
if (negotiate) {
- try {
- write("got IO", NULL);
- } catch (...) {
- this_thread::disable_syscall_interruption dsi;
- syscalls::close(fd);
- throw;
- }
+ Passenger::readFileDescriptorWithNegotiation(fd);
+ } else {
+ Passenger::readFileDescriptor(fd);
}
-
- return fd;
}
/**
View
10 ext/common/Utils/IOUtils.h
@@ -367,6 +367,11 @@ void setWritevFunction(WritevFunction func);
/**
* Receive a file descriptor over the given Unix domain socket.
+ * This is a low-level function that directly wraps the Unix file
+ * descriptor passing system calls. You should not use this directly;
+ * instead you should use readFileDescriptorWithNegotiation() from MessageIO.h
+ * which is safer. See MessageIO.h for more information about the
+ * negotiation protocol for file descriptor passing.
*
* @param timeout A pointer to an integer, which specifies the maximum number of
* microseconds that may be spent on receiving the file descriptor.
@@ -387,6 +392,11 @@ int readFileDescriptor(int fd, unsigned long long *timeout = NULL);
/**
* Pass the file descriptor 'fdToSend' over the Unix socket 'fd'.
+ * This is a low-level function that directly wraps the Unix file
+ * descriptor passing system calls. You should not use this directly;
+ * instead you should use writeFileDescriptorWithNegotiation() from MessageIO.h
+ * which is safer. See MessageIO.h for more information about the
+ * negotiation protocol for file descriptor passing.
*
* @param timeout A pointer to an integer, which specifies the maximum number of
* microseconds that may be spent on trying to pass the file descriptor.
View
576 ext/common/Utils/MessageIO.h
@@ -0,0 +1,576 @@
+/*
+ * Phusion Passenger - http://www.modrails.com/
+ * Copyright (c) 2011 Phusion
+ *
+ * "Phusion Passenger" is a trademark of Hongli Lai & Ninh Bui.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+#ifndef _PASSENGER_MESSAGE_IO_H_
+#define _PASSENGER_MESSAGE_IO_H_
+
+/**
+ * This file contains functions for reading and writing structured messages over
+ * I/O channels. Supported message types are as follows.
+ *
+ * == 16-bit and 32-bit integers
+ * Their raw formats are binary, in big endian.
+ *
+ * == Array of strings (array messages)
+ * Each string may contain arbitrary data except for the NUL byte.
+ * Its raw format consists of a 16-bit big endian size header
+ * and a body containing all the strings in the array, each terminated
+ * by a NUL byte. The size header specifies the raw size of the body.
+ *
+ * == Arbitary binary strings (scalar messages)
+ * Its raw format consists of a 32-bit big endian size header
+ * followed by the raw string data.
+ *
+ * == File descriptor passing and negotiation
+ * Unix socket file descriptor passing is not safe without some kind
+ * of negotiation protocol. If one side passes a file descriptor, and
+ * the other side accidentally read()s past the normal data then it
+ * will read away the passed file descriptor too without actually
+ * receiving it.
+ *
+ * For example suppose that side A looks like this:
+ *
+ * read(fd, buf, 1024)
+ * read_io(fd)
+ *
+ * and side B:
+ *
+ * write(fd, buf, 100)
+ * send_io(fd_to_pass)
+ *
+ * If B completes both write() and send_io(), then A's read() call
+ * reads past the 100 bytes that B sent. On some platforms, like
+ * Linux, this will cause read_io() to fail. And it just so happens
+ * that Ruby's IO#read method slurps more than just the given amount
+ * of bytes.
+ *
+ * In order to solve this problem, we wrap the actual file descriptor
+ * passing/reading code into a negotiation protocol to ensure that
+ * this situation can never happen.
+ */
+
+// For ntohl/htonl/ntohs/htons.
+#include <sys/types.h>
+#include <arpa/inet.h>
+#include <netinet/in.h>
+
+#include <algorithm>
+#include <string>
+#include <vector>
+#include <cstring>
+#include <cstdarg>
+
+#include <boost/cstdint.hpp>
+#include <boost/bind.hpp>
+#include <boost/scoped_array.hpp>
+
+#include <oxt/macros.hpp>
+
+#include <StaticString.h>
+#include <Exceptions.h>
+#include <Utils/MemZeroGuard.h>
+#include <Utils/ScopeGuard.h>
+#include <Utils/IOUtils.h>
+#include <Utils/StrIntUtils.h>
+
+
+namespace Passenger {
+
+using namespace std;
+using namespace boost;
+
+/**
+ * Reads a 16-bit unsigned integer from the given file descriptor. The result
+ * is put into 'output'.
+ *
+ * @param timeout A pointer to an integer, which specifies the maximum number of
+ * microseconds that may be spent on reading the necessary data.
+ * If the timeout expired then TimeoutException will be thrown.
+ * If this function returns without throwing an exception, then the
+ * total number of microseconds spent on reading will be deducted
+ * from <tt>timeout</tt>.
+ * Pass NULL if you do not want to enforce a timeout.
+ * @return True if reading was successful, false if end-of-file was prematurely reached.
+ * @throws SystemException Something went wrong.
+ * @throws TimeoutException Unable to read the necessary data within
+ * <tt>timeout</tt> microseconds.
+ * @throws boost::thread_interrupted
+ */
+inline bool
+readUint16(int fd, uint16_t &output, unsigned long long *timeout = NULL) {
+ uint16_t temp;
+
+ if (readExact(fd, &temp, sizeof(uint16_t), timeout) == sizeof(uint16_t)) {
+ output = ntohs(temp);
+ return true;
+ } else {
+ return false;
+ }
+}
+
+/**
+ * Reads a 16-bit unsigned integer from the given file descriptor.
+ *
+ * @param timeout A pointer to an integer, which specifies the maximum number of
+ * microseconds that may be spent on reading the necessary data.
+ * If the timeout expired then TimeoutException will be thrown.
+ * If this function returns without throwing an exception, then the
+ * total number of microseconds spent on reading will be deducted
+ * from <tt>timeout</tt>.
+ * Pass NULL if you do not want to enforce a timeout.
+ * @throws EOFException End-of-file was reached before a full integer could be read.
+ * @throws SystemException Something went wrong.
+ * @throws TimeoutException Unable to read the necessary data within
+ * <tt>timeout</tt> microseconds.
+ * @throws boost::thread_interrupted
+ */
+inline uint16_t
+readUint16(int fd, unsigned long long *timeout = NULL) {
+ uint16_t temp;
+
+ if (readUint16(fd, temp, timeout)) {
+ return temp;
+ } else {
+ throw EOFException("EOF encountered before a full 16-bit integer could be read");
+ }
+}
+
+/**
+ * Reads a 32-bit unsigned integer from the given file descriptor. The result
+ * is put into 'output'.
+ *
+ * @param timeout A pointer to an integer, which specifies the maximum number of
+ * microseconds that may be spent on reading the necessary data.
+ * If the timeout expired then TimeoutException will be thrown.
+ * If this function returns without throwing an exception, then the
+ * total number of microseconds spent on reading will be deducted
+ * from <tt>timeout</tt>.
+ * Pass NULL if you do not want to enforce a timeout.
+ * @return True if reading was successful, false if end-of-file was prematurely reached.
+ * @throws SystemException Something went wrong.
+ * @throws TimeoutException Unable to read the necessary data within
+ * <tt>timeout</tt> microseconds.
+ * @throws boost::thread_interrupted
+ */
+inline bool
+readUint32(int fd, uint32_t &output, unsigned long long *timeout = NULL) {
+ uint32_t temp;
+
+ if (readExact(fd, &temp, sizeof(uint32_t), timeout) == sizeof(uint32_t)) {
+ output = ntohl(temp);
+ return true;
+ } else {
+ return false;
+ }
+}
+
+/**
+ * Reads a 32-bit unsigned integer from the given file descriptor.
+ *
+ * @param timeout A pointer to an integer, which specifies the maximum number of
+ * microseconds that may be spent on reading the necessary data.
+ * If the timeout expired then TimeoutException will be thrown.
+ * If this function returns without throwing an exception, then the
+ * total number of microseconds spent on reading will be deducted
+ * from <tt>timeout</tt>.
+ * Pass NULL if you do not want to enforce a timeout.
+ * @throws EOFException End-of-file was reached before a full integer could be read.
+ * @throws SystemException Something went wrong.
+ * @throws TimeoutException Unable to read the necessary data within
+ * <tt>timeout</tt> microseconds.
+ * @throws boost::thread_interrupted
+ */
+inline uint32_t
+readUint32(int fd, unsigned long long *timeout = NULL) {
+ uint32_t temp;
+
+ if (readUint32(fd, temp, timeout)) {
+ return temp;
+ } else {
+ throw EOFException("EOF encountered before a full 32-bit integer could be read");
+ }
+}
+
+
+/**
+ * Reads an array message from the given file descriptor. This version
+ * puts the result into the given collection instead of returning a
+ * new collection.
+ *
+ * @param timeout A pointer to an integer, which specifies the maximum number of
+ * microseconds that may be spent on reading the necessary data.
+ * If the timeout expired then TimeoutException will be thrown.
+ * If this function returns without throwing an exception, then the
+ * total number of microseconds spent on reading will be deducted
+ * from <tt>timeout</tt>.
+ * Pass NULL if you do not want to enforce a timeout.
+ * @return True if an array message was read, false if end-of-file was reached
+ * before a full array message could be read.
+ * @throws SystemException Something went wrong.
+ * @throws TimeoutException Unable to read the necessary data within
+ * <tt>timeout</tt> microseconds.
+ * @throws boost::thread_interrupted
+ */
+template<typename Collection>
+inline bool
+readArrayMessage(int fd, Collection &output, unsigned long long *timeout = NULL) {
+ uint16_t size;
+ if (!readUint16(fd, size, timeout)) {
+ return false;
+ }
+
+ scoped_array<char> buffer(new char[size]);
+ MemZeroGuard g(buffer.get(), size);
+ if (readExact(fd, buffer.get(), size, timeout) != size) {
+ return false;
+ }
+
+ output.clear();
+ if (size != 0) {
+ string::size_type start = 0, pos;
+ StaticString buffer_str(buffer.get(), size);
+ while ((pos = buffer_str.find('\0', start)) != string::npos) {
+ output.push_back(buffer_str.substr(start, pos - start));
+ start = pos + 1;
+ }
+ }
+ return true;
+}
+
+/**
+ * Reads an array message from the given file descriptor. This version returns
+ * the result immediately as a string vector.
+ *
+ * @param timeout A pointer to an integer, which specifies the maximum number of
+ * microseconds that may be spent on reading the necessary data.
+ * If the timeout expired then TimeoutException will be thrown.
+ * If this function returns without throwing an exception, then the
+ * total number of microseconds spent on reading will be deducted
+ * from <tt>timeout</tt>.
+ * Pass NULL if you do not want to enforce a timeout.
+ * @throws EOFException End-of-file was reached before a full integer could be read.
+ * @throws SystemException Something went wrong.
+ * @throws TimeoutException Unable to read the necessary data within
+ * <tt>timeout</tt> microseconds.
+ * @throws boost::thread_interrupted
+ */
+inline vector<string>
+readArrayMessage(int fd, unsigned long long *timeout = NULL) {
+ vector<string> output;
+
+ if (readArrayMessage(fd, output, timeout)) {
+ return output;
+ } else {
+ throw EOFException("EOF encountered before the full array message could be read");
+ }
+}
+
+
+/**
+ * Reads a scalar message from the given file descriptor.
+ *
+ * @param maxSize The maximum number of bytes that may be read. If the
+ * scalar to read is larger than this, then a SecurityException
+ * will be thrown. Set to 0 for no size limit.
+ * @param timeout A pointer to an integer, which specifies the maximum number of
+ * microseconds that may be spent on reading the necessary data.
+ * If the timeout expired then TimeoutException will be thrown.
+ * If this function returns without throwing an exception, then the
+ * total number of microseconds spent on reading will be deducted
+ * from <tt>timeout</tt>.
+ * Pass NULL if you do not want to enforce a timeout.
+ * @throws EOFException End-of-file was reached before a full integer could be read.
+ * @throws SystemException Something went wrong.
+ * @throws SecurityException The message body is larger than allowed by maxSize.
+ * @throws TimeoutException Unable to read the necessary data within
+ * <tt>timeout</tt> microseconds.
+ * @throws boost::thread_interrupted
+ */
+inline string
+readScalarMessage(int fd, unsigned int maxSize = 0, unsigned long long *timeout = NULL) {
+ uint32_t size;
+ if (!readUint32(fd, size, timeout)) {
+ throw EOFException("EOF encountered before a 32-bit scalar message header could be read");
+ }
+
+ if (maxSize != 0 && size > (uint32_t) maxSize) {
+ throw SecurityException("The scalar message body is larger than the size limit");
+ }
+
+ string output;
+ unsigned int remaining = size;
+ output.reserve(size);
+ if (OXT_LIKELY(remaining > 0)) {
+ char buf[1024 * 32];
+ MemZeroGuard g(buf, sizeof(buf));
+
+ while (remaining > 0) {
+ unsigned int blockSize = min((unsigned int) sizeof(buf), remaining);
+
+ if (readExact(fd, buf, blockSize, timeout) != blockSize) {
+ throw EOFException("EOF encountered before the full scalar message body could be read");
+ }
+ output.append(buf, blockSize);
+ remaining -= blockSize;
+ }
+ }
+ return output;
+}
+
+
+/**
+ * Writes a 16-bit unsigned integer to the given file descriptor.
+ *
+ * @param timeout A pointer to an integer, which specifies the maximum number of
+ * microseconds that may be spent on writing the necessary data.
+ * If the timeout expired then TimeoutException will be thrown.
+ * If this function returns without throwing an exception, then the
+ * total number of microseconds spent on writing will be deducted
+ * from <tt>timeout</tt>.
+ * Pass NULL if you do not want to enforce a timeout.
+ * @throws SystemException Something went wrong.
+ * @throws TimeoutException Unable to write the necessary data within
+ * <tt>timeout</tt> microseconds.
+ * @throws boost::thread_interrupted
+ */
+inline void
+writeUint16(int fd, uint16_t value, unsigned long long *timeout = NULL) {
+ uint16_t l = htons(value);
+ writeExact(fd, &l, sizeof(uint16_t), timeout);
+}
+
+/**
+ * Writes a 32-bit unsigned integer to the given file descriptor.
+ *
+ * @param timeout A pointer to an integer, which specifies the maximum number of
+ * microseconds that may be spent on writing the necessary data.
+ * If the timeout expired then TimeoutException will be thrown.
+ * If this function returns without throwing an exception, then the
+ * total number of microseconds spent on writing will be deducted
+ * from <tt>timeout</tt>.
+ * Pass NULL if you do not want to enforce a timeout.
+ * @throws SystemException Something went wrong.
+ * @throws TimeoutException Unable to write the necessary data within
+ * <tt>timeout</tt> microseconds.
+ * @throws boost::thread_interrupted
+ */
+inline void
+writeUint32(int fd, uint32_t value, unsigned long long *timeout = NULL) {
+ uint32_t l = htonl(value);
+ writeExact(fd, &l, sizeof(uint32_t), timeout);
+}
+
+
+/**
+ * Writes an array message to the given file descriptor.
+ *
+ * @param args A collection of strings containing the array message's elements.
+ * The collection must have an STL container-like interface and
+ * the strings must have an STL string-like interface.
+ * @param timeout A pointer to an integer, which specifies the maximum number of
+ * microseconds that may be spent on writing the necessary data.
+ * If the timeout expired then TimeoutException will be thrown.
+ * If this function returns without throwing an exception, then the
+ * total number of microseconds spent on writing will be deducted
+ * from <tt>timeout</tt>.
+ * Pass NULL if you do not want to enforce a timeout.
+ * @throws SystemException Something went wrong.
+ * @throws TimeoutException Unable to write the necessary data within
+ * <tt>timeout</tt> microseconds.
+ * @throws boost::thread_interrupted
+ */
+template<typename Collection>
+inline void
+writeArrayMessage(int fd, const Collection &args, unsigned long long *timeout = NULL) {
+ typename Collection::const_iterator it, end = args.end();
+ uint16_t bodySize = 0;
+
+ for (it = args.begin(); it != end; it++) {
+ bodySize += it->size() + 1;
+ }
+
+ scoped_array<char> data(new char[sizeof(uint16_t) + bodySize]);
+ uint16_t header = htons(bodySize);
+ memcpy(data.get(), &header, sizeof(uint16_t));
+
+ char *dataEnd = data.get() + sizeof(uint16_t);
+ for (it = args.begin(); it != end; it++) {
+ memcpy(dataEnd, it->data(), it->size());
+ dataEnd += it->size();
+ *dataEnd = '\0';
+ dataEnd++;
+ }
+
+ writeExact(fd, data.get(), sizeof(uint16_t) + bodySize, timeout);
+}
+
+inline void
+writeArrayMessage(int fd, const StaticString &name, va_list &ap, unsigned long long *timeout = NULL) {
+ vector<StaticString> args;
+
+ args.push_back(name);
+ while (true) {
+ const char *arg = va_arg(ap, const char *);
+ if (arg == NULL) {
+ break;
+ } else {
+ args.push_back(arg);
+ }
+ }
+ writeArrayMessage(fd, args, timeout);
+}
+
+struct _VaGuard {
+ va_list &ap;
+
+ _VaGuard(va_list &_ap)
+ : ap(_ap)
+ { }
+
+ ~_VaGuard() {
+ va_end(ap);
+ }
+};
+
+/** Version of writeArrayMessage() that accepts a variadic list of 'const char *'
+ * arguments as message elements. The list must be terminated with a NULL.
+ */
+inline void
+writeArrayMessage(int fd, const StaticString &name, ...) {
+ va_list ap;
+ va_start(ap, name);
+ _VaGuard guard(ap);
+ writeArrayMessage(fd, name, ap);
+}
+
+/** Version of writeArrayMessage() that accepts a variadic list of 'const char *'
+ * arguments as message elements, with timeout support. The list must be terminated
+ * with a NULL.
+ */
+inline void
+writeArrayMessage(int fd, unsigned long long *timeout, const StaticString &name, ...) {
+ va_list ap;
+ va_start(ap, name);
+ _VaGuard guard(ap);
+ writeArrayMessage(fd, name, ap, timeout);
+}
+
+/**
+ * Writes a scalar message to the given file descriptor.
+ *
+ * @param timeout A pointer to an integer, which specifies the maximum number of
+ * microseconds that may be spent on writing the necessary data.
+ * If the timeout expired then TimeoutException will be thrown.
+ * If this function returns without throwing an exception, then the
+ * total number of microseconds spent on writing will be deducted
+ * from <tt>timeout</tt>.
+ * Pass NULL if you do not want to enforce a timeout.
+ * @throws SystemException Something went wrong.
+ * @throws TimeoutException Unable to write the necessary data within
+ * <tt>timeout</tt> microseconds.
+ * @throws boost::thread_interrupted
+ */
+inline void
+writeScalarMessage(int fd, const StaticString &data, unsigned long long *timeout = NULL) {
+ uint32_t header = htonl(data.size());
+ StaticString buffers[2] = {
+ StaticString((const char *) &header, sizeof(uint32_t)),
+ data
+ };
+ gatheredWrite(fd, buffers, 2, timeout);
+}
+
+inline void
+writeScalarMessage(int fd, const char *data, size_t size, unsigned long long *timeout = NULL) {
+ writeScalarMessage(fd, StaticString(data, size), timeout);
+}
+
+
+/**
+ * Receive a file descriptor over the given Unix domain socket,
+ * involving a negotiation protocol.
+ *
+ * @param timeout A pointer to an integer, which specifies the maximum number of
+ * microseconds that may be spent on receiving the file descriptor.
+ * If the timeout expired then TimeoutException will be thrown.
+ * If this function returns without throwing an exception, then the
+ * total number of microseconds spent on receiving will be deducted
+ * from <tt>timeout</tt>.
+ * Pass NULL if you do not want to enforce a timeout.
+ * @return The received file descriptor.
+ * @throws SystemException Something went wrong.
+ * @throws IOException Whatever was received doesn't seem to be a
+ * file descriptor.
+ * @throws TimeoutException Unable to receive a file descriptor within
+ * <tt>timeout</tt> microseconds.
+ * @throws boost::thread_interrupted
+ */
+inline int
+readFileDescriptorWithNegotiation(int fd, unsigned long long *timeout = NULL) {
+ writeArrayMessage(fd, timeout, "pass IO", NULL);
+ int result = readFileDescriptor(fd, timeout);
+ ScopeGuard guard(boost::bind(safelyClose, result));
+ writeArrayMessage(fd, timeout, "got IO", NULL);
+ guard.clear();
+ return result;
+}
+
+
+/**
+ * Pass the file descriptor 'fdToSend' over the Unix socket 'fd',
+ * involving a negotiation protocol.
+ *
+ * @param timeout A pointer to an integer, which specifies the maximum number of
+ * microseconds that may be spent on trying to pass the file descriptor.
+ * If the timeout expired then TimeoutException will be thrown.
+ * If this function returns without throwing an exception, then the
+ * total number of microseconds spent on writing will be deducted
+ * from <tt>timeout</tt>.
+ * Pass NULL if you do not want to enforce a timeout.
+ * @throws SystemException Something went wrong.
+ * @throws TimeoutException Unable to pass the file descriptor within
+ * <tt>timeout</tt> microseconds.
+ * @throws boost::thread_interrupted
+ */
+inline void
+writeFileDescriptorWithNegotiation(int fd, int fdToPass, unsigned long long *timeout = NULL) {
+ vector<string> args;
+
+ args = readArrayMessage(fd, timeout);
+ if (args.size() != 1 || args[0] != "pass IO") {
+ throw IOException("FD passing pre-negotiation message expected");
+ }
+
+ writeFileDescriptor(fd, fdToPass, timeout);
+
+ args = readArrayMessage(fd, timeout);
+ if (args.size() != 1 || args[0] != "got IO") {
+ throw IOException("FD passing post-negotiation message expected.");
+ }
+}
+
+
+} // namespace Passenger
+
+#endif /* _PASSENGER_MESSAGE_IO_H_ */
View
2 test/cxx/MessageChannelTest.cpp
@@ -458,7 +458,7 @@ namespace tut {
fail("TimeoutException expected");
} catch (const TimeoutException &) {
unsigned long long elapsed = timer.elapsed();
- ensure("Spent at least 35 msec waiting", elapsed >= 35);
+ ensure("Spent at least 35 msec waiting", elapsed >= 34);
ensure("Spent at most 60 msec waiting", elapsed <= 60);
ensure("The passed time is deducted from timeout", timeout < 5);
}
View
310 test/cxx/MessageIOTest.cpp
@@ -0,0 +1,310 @@
+#include "TestSupport.h"
+#include <Utils/IOUtils.h>
+#include <Utils/MessageIO.h>
+#include <Utils/SystemTime.h>
+
+using namespace Passenger;
+using namespace std;
+using namespace boost;
+
+namespace tut {
+ struct MessageIOTest {
+ Pipe pipes;
+
+ MessageIOTest() {
+ pipes = createPipe();
+ }
+ };
+
+ DEFINE_TEST_GROUP(MessageIOTest);
+
+ /***** Test readUint16() and writeUint16() *****/
+
+ TEST_METHOD(1) {
+ // They work.
+ writeUint16(pipes[1], 0x3F56);
+ writeUint16(pipes[1], 0x3F57);
+ writeUint16(pipes[1], 0x3F58);
+
+ unsigned char buf[2];
+ ensure_equals(readExact(pipes[0], buf, 2), 2u);
+ ensure_equals(buf[0], 0x3F);
+ ensure_equals(buf[1], 0x56);
+
+ ensure_equals(readUint16(pipes[0]), 0x3F57u);
+
+ uint16_t out;
+ ensure(readUint16(pipes[0], out));
+ ensure_equals(out, 0x3F58);
+ }
+
+ TEST_METHOD(2) {
+ // readUint16() throws EOFException on premature EOF.
+ writeExact(pipes[1], "x", 1);
+ pipes[1].close();
+ try {
+ readUint16(pipes[0]);
+ fail("EOFException expected");
+ } catch (const EOFException &) {
+ }
+ }
+
+ TEST_METHOD(3) {
+ // readUint16(uint32_t &) returns false EOFException on premature EOF.
+ writeExact(pipes[1], "x", 1);
+ pipes[1].close();
+ uint16_t out;
+ ensure(!readUint16(pipes[0], out));
+ }
+
+ TEST_METHOD(4) {
+ // Test timeout.
+ unsigned long long timeout = 30000;
+ unsigned long long startTime = SystemTime::getUsec();
+ try {
+ readUint16(pipes[0], &timeout);
+ fail("TimeoutException expected");
+ } catch (const TimeoutException &) {
+ unsigned long long elapsed = SystemTime::getUsec() - startTime;
+ ensure("About 30 ms elapsed (1)", elapsed >= 29000 && elapsed <= 50000);
+ ensure("Time is correctly deducted from 'timeout' (1)", timeout <= 2000);
+ }
+
+ writeUntilFull(pipes[1]);
+
+ timeout = 30000;
+ startTime = SystemTime::getUsec();
+ try {
+ writeUint16(pipes[1], 0x12, &timeout);
+ fail("TimeoutException expected");
+ } catch (const TimeoutException &) {
+ unsigned long long elapsed = SystemTime::getUsec() - startTime;
+ ensure("About 30 ms elapsed (3)", elapsed >= 29000 && elapsed <= 50000);
+ ensure("Time is correctly deducted from 'timeout' (4)", timeout <= 2000);
+ }
+ }
+
+ /***** Test readUint32() and writeUint32() *****/
+
+ TEST_METHOD(10) {
+ // They work.
+ writeUint32(pipes[1], 0x12343F56);
+ writeUint32(pipes[1], 0x12343F57);
+ writeUint32(pipes[1], 0x12343F58);
+
+ unsigned char buf[4];
+ ensure_equals(readExact(pipes[0], buf, 4), 4u);
+ ensure_equals(buf[0], 0x12);
+ ensure_equals(buf[1], 0x34);
+ ensure_equals(buf[2], 0x3F);
+ ensure_equals(buf[3], 0x56);
+
+ ensure_equals(readUint32(pipes[0]), 0x12343F57u);
+
+ uint32_t out;
+ ensure(readUint32(pipes[0], out));
+ ensure_equals(out, 0x12343F58u);
+ }
+
+ TEST_METHOD(11) {
+ // readUint32() throws EOFException on premature EOF.
+ writeExact(pipes[1], "xyz", 3);
+ pipes[1].close();
+ try {
+ readUint32(pipes[0]);
+ fail("EOFException expected");
+ } catch (const EOFException &) {
+ }
+ }
+
+ TEST_METHOD(12) {
+ // readUint16(uint32_t &) returns false EOFException on premature EOF.
+ writeExact(pipes[1], "xyz", 3);
+ pipes[1].close();
+ uint32_t out;
+ ensure(!readUint32(pipes[0], out));
+ }
+
+ TEST_METHOD(13) {
+ // Test timeout.
+ unsigned long long timeout = 30000;
+ unsigned long long startTime = SystemTime::getUsec();
+ try {
+ readUint32(pipes[0], &timeout);
+ fail("TimeoutException expected");
+ } catch (const TimeoutException &) {
+ unsigned long long elapsed = SystemTime::getUsec() - startTime;
+ ensure(elapsed >= 29000 && elapsed <= 50000);
+ ensure(timeout <= 2000);
+ }
+
+ writeUntilFull(pipes[1]);
+
+ timeout = 30000;
+ startTime = SystemTime::getUsec();
+ try {
+ writeUint32(pipes[1], 0x1234, &timeout);
+ fail("TimeoutException expected");
+ } catch (const TimeoutException &) {
+ unsigned long long elapsed = SystemTime::getUsec() - startTime;
+ ensure(elapsed >= 29000 && elapsed <= 50000);
+ ensure(timeout <= 2000);
+ }
+ }
+
+ /***** Test readArrayMessage() and writeArrayMessage() *****/
+
+ TEST_METHOD(20) {
+ // They work.
+ writeArrayMessage(pipes[1], "ab", "cd", "efg", NULL);
+ writeArrayMessage(pipes[1], "ab", "cd", "efh", NULL);
+
+ unsigned char buf[12];
+ readExact(pipes[0], buf, 12);
+ ensure_equals(buf[0], 0u);
+ ensure_equals(buf[1], 10u);
+ ensure_equals(buf[2], 'a');
+ ensure_equals(buf[3], 'b');
+ ensure_equals(buf[4], '\0');
+ ensure_equals(buf[5], 'c');
+ ensure_equals(buf[6], 'd');
+ ensure_equals(buf[7], '\0');
+ ensure_equals(buf[8], 'e');
+ ensure_equals(buf[9], 'f');
+ ensure_equals(buf[10], 'g');
+ ensure_equals(buf[11], '\0');
+
+ vector<string> args = readArrayMessage(pipes[0]);
+ ensure_equals(args.size(), 3u);
+ ensure_equals(args[0], "ab");
+ ensure_equals(args[1], "cd");
+ ensure_equals(args[2], "efh");
+ }
+
+ TEST_METHOD(21) {
+ // readArrayMessage() throws EOFException on premature EOF.
+ writeExact(pipes[1], "\x00");
+ pipes[1].close();
+ try {
+ readArrayMessage(pipes[0]);
+ fail("EOFException expected (1)");
+ } catch (const EOFException &) {
+ }
+
+ pipes = createPipe();
+ writeExact(pipes[1], "\x00\x04a\x00b");
+ pipes[1].close();
+ try {
+ readArrayMessage(pipes[0]);
+ fail("EOFException expected (2)");
+ } catch (const EOFException &) {
+ }
+ }
+
+ TEST_METHOD(22) {
+ // Test timeout.
+ unsigned long long timeout = 30000;
+ unsigned long long startTime = SystemTime::getUsec();
+ try {
+ readArrayMessage(pipes[0], &timeout);
+ fail("TimeoutException expected (1)");
+ } catch (const TimeoutException &) {
+ unsigned long long elapsed = SystemTime::getUsec() - startTime;
+ ensure(elapsed >= 29000 && elapsed <= 50000);
+ ensure(timeout <= 2000);
+ }
+
+ writeUntilFull(pipes[1]);
+
+ timeout = 30000;
+ startTime = SystemTime::getUsec();
+ try {
+ writeArrayMessage(pipes[1], &timeout, "hi", "ho", NULL);
+ fail("TimeoutException expected (2)");
+ } catch (const TimeoutException &) {
+ unsigned long long elapsed = SystemTime::getUsec() - startTime;
+ ensure(elapsed >= 29000 && elapsed <= 50000);
+ ensure(timeout <= 2000);
+ }
+ }
+
+ /***** Test readArrayMessage() and writeArrayMessage() *****/
+
+ TEST_METHOD(30) {
+ // They work.
+ writeScalarMessage(pipes[1], "hello");
+ writeScalarMessage(pipes[1], "world");
+
+ unsigned char buf[4 + 5];
+ readExact(pipes[0], buf, 4 + 5);
+ ensure_equals(buf[0], 0u);
+ ensure_equals(buf[1], 0u);
+ ensure_equals(buf[2], 0u);
+ ensure_equals(buf[3], 5u);
+ ensure_equals(buf[4], 'h');
+ ensure_equals(buf[5], 'e');
+ ensure_equals(buf[6], 'l');
+ ensure_equals(buf[7], 'l');
+ ensure_equals(buf[8], 'o');
+
+ ensure_equals(readScalarMessage(pipes[0]), "world");
+ }
+
+ TEST_METHOD(31) {
+ // readScalarMessage() throws EOFException on premature EOF.
+ writeExact(pipes[1], StaticString("\x00", 1));
+ pipes[1].close();
+ try {
+ readScalarMessage(pipes[0]);
+ fail("EOFException expected (1)");
+ } catch (const EOFException &) {
+ }
+
+ pipes = createPipe();
+ writeExact(pipes[1], StaticString("\x00\x00\x00\x04" "abc", 4 + 3));
+ pipes[1].close();
+ try {
+ readScalarMessage(pipes[0]);
+ fail("EOFException expected (2)");
+ } catch (const EOFException &) {
+ }
+ }
+
+ TEST_METHOD(32) {
+ // readScalarMessage() throws SecurityException if the
+ // body larger than the limit
+ writeExact(pipes[1], StaticString("\x00\x00\x00\x05", 4));
+ try {
+ readScalarMessage(pipes[0], 4);
+ fail("SecurityException expected (1)");
+ } catch (const SecurityException &) {
+ }
+ }
+
+ TEST_METHOD(33) {
+ // Test timeout.
+ unsigned long long timeout = 30000;
+ unsigned long long startTime = SystemTime::getUsec();
+ try {
+ readScalarMessage(pipes[0], 0, &timeout);
+ fail("TimeoutException expected (1)");
+ } catch (const TimeoutException &) {
+ unsigned long long elapsed = SystemTime::getUsec() - startTime;
+ ensure(elapsed >= 29000 && elapsed <= 50000);
+ ensure(timeout <= 2000);
+ }
+
+ writeUntilFull(pipes[1]);
+
+ timeout = 30000;
+ startTime = SystemTime::getUsec();
+ try {
+ writeScalarMessage(pipes[1], "hello", &timeout);
+ fail("TimeoutException expected (2)");
+ } catch (const TimeoutException &) {
+ unsigned long long elapsed = SystemTime::getUsec() - startTime;
+ ensure(elapsed >= 29000 && elapsed <= 50000);
+ ensure(timeout <= 2000);
+ }
+ }
+}

0 comments on commit e094601

Please sign in to comment.
Something went wrong with that request. Please try again.