Skip to content

Commit

Permalink
Improve dispatcher tests (#2358)
Browse files Browse the repository at this point in the history
This improves dispatcher tests by allowing units to act like component
tests and use embedded std::thread-based osquery APIs. A unit may force
a 'service' to run by joining the Dispatcher before deconstructing.
  • Loading branch information
Teddy Reed committed Aug 14, 2016
1 parent 89e1854 commit 58fd284
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 61 deletions.
24 changes: 23 additions & 1 deletion include/osquery/dispatcher.h
Expand Up @@ -79,6 +79,10 @@ class InterruptableRunnable {
/// Put the runnable into an interruptible sleep.
virtual void pauseMilli(std::chrono::milliseconds milli);

private:
/// Testing only, the interruptible will bypass initial interruption check.
void mustRun() { bypass_check_ = true; }

private:
/**
* @brief Protect interruption checking and resource tear down.
Expand All @@ -94,6 +98,19 @@ class InterruptableRunnable {

/// Use an interruption point to exit a pause if the thread was interrupted.
RunnerInterruptPoint point_;

private:
/// Testing only, track the interruptible check for interruption.
bool checked_{false};

/// Testing only, require that the interruptible bypass the first check.
std::atomic<bool> bypass_check_{false};

private:
FRIEND_TEST(DispatcherTests, test_run);
FRIEND_TEST(DispatcherTests, test_independent_run);
FRIEND_TEST(DispatcherTests, test_interruption);
FRIEND_TEST(BufferedLogForwarderTests, test_async);
};

class InternalRunnable : private boost::noncopyable,
Expand Down Expand Up @@ -181,6 +198,10 @@ class Dispatcher : private boost::noncopyable {
/// When a service ends, it will remove itself from the dispatcher.
static void removeService(const InternalRunnable* service);

private:
/// For testing only, reset the stopping status for unittests.
void resetStopping() { stopping_ = false; }

private:
/// The set of shared osquery service threads.
std::vector<InternalThreadRef> service_threads_;
Expand Down Expand Up @@ -209,6 +230,7 @@ class Dispatcher : private boost::noncopyable {

private:
friend class InternalRunnable;
friend class ExtensionsTest;
friend class ExtensionsTests;
friend class DispatcherTests;
};
}
32 changes: 20 additions & 12 deletions osquery/dispatcher/dispatcher.cpp
Expand Up @@ -51,6 +51,11 @@ void InterruptableRunnable::interrupt() {

bool InterruptableRunnable::interrupted() {
WriteLock lock(stopping_);
// A small conditional to force-skip an interruption check, used in testing.
if (bypass_check_ && !checked_) {
checked_ = true;
return false;
}
return interrupted_;
}

Expand Down Expand Up @@ -106,14 +111,26 @@ void Dispatcher::removeService(const InternalRunnable* service) {
self.services_.end());
}

inline static void assureRun(const InternalRunnableRef& service) {
while (true) {
// Wait for each thread's entry point (start) meaning the thread context
// was allocated and (run) was called by std::thread started.
if (service->hasRun()) {
break;
}
// We only need to check if std::terminate is called very quickly after
// the std::thread is created.
sleepFor(20);
}
}

void Dispatcher::joinServices() {
auto& self = instance();
DLOG(INFO) << "Thread: " << std::this_thread::get_id()
<< " requesting a join";
WriteLock join_lock(self.join_mutex_);

for (auto& thread : self.service_threads_) {
// Boost threads would have been interrupted, and joined using the
// provided thread instance.
thread->join();
DLOG(INFO) << "Service thread: " << thread.get() << " has joined";
}
Expand All @@ -133,16 +150,7 @@ void Dispatcher::stopServices() {
DLOG(INFO) << "Thread: " << std::this_thread::get_id()
<< " requesting a stop";
for (const auto& service : self.services_) {
while (true) {
// Wait for each thread's entry point (start) meaning the thread context
// was allocated and (run) was called by std::thread started.
if (service->hasRun()) {
break;
}
// We only need to check if std::terminate is called very quickly after
// the std::thread is created.
sleepFor(20);
}
assureRun(service);
service->interrupt();
DLOG(INFO) << "Service: " << service.get() << " has been interrupted";
}
Expand Down
108 changes: 104 additions & 4 deletions osquery/dispatcher/tests/dispatcher_tests.cpp
Expand Up @@ -15,7 +15,7 @@
namespace osquery {

class DispatcherTests : public testing::Test {
void TearDown() override {}
void TearDown() override { Dispatcher::instance().resetStopping(); }
};

TEST_F(DispatcherTests, test_singleton) {
Expand All @@ -26,8 +26,108 @@ TEST_F(DispatcherTests, test_singleton) {

class TestRunnable : public InternalRunnable {
public:
int* i;
explicit TestRunnable(int* i) : i(i) {}
virtual void start() { ++*i; }
explicit TestRunnable() {}

virtual void start() override {
WriteLock lock(mutex_);
++i;
}

void reset() {
WriteLock lock(mutex_);
i = 0;
}

size_t count() {
WriteLock lock(mutex_);
return i;
}

private:
static size_t i;

private:
Mutex mutex_;
};

size_t TestRunnable::i{0};

TEST_F(DispatcherTests, test_service_count) {
auto runnable = std::make_shared<TestRunnable>();

auto service_count = Dispatcher::instance().serviceCount();
// The service exits after incrementing.
auto s = Dispatcher::addService(runnable);
EXPECT_TRUE(s);

// Wait for the service to stop.
Dispatcher::joinServices();

// Make sure the service is removed.
EXPECT_EQ(service_count, Dispatcher::instance().serviceCount());
}

TEST_F(DispatcherTests, test_run) {
auto runnable = std::make_shared<TestRunnable>();
runnable->mustRun();
runnable->reset();

// The service exits after incrementing.
Dispatcher::addService(runnable);
Dispatcher::joinServices();
EXPECT_EQ(1U, runnable->count());
EXPECT_TRUE(runnable->hasRun());

// This runnable cannot be executed again.
auto s = Dispatcher::addService(runnable);
EXPECT_FALSE(s);

Dispatcher::joinServices();
EXPECT_EQ(1U, runnable->count());
}

TEST_F(DispatcherTests, test_independent_run) {
// Nothing stops two instances of the same service from running.
auto r1 = std::make_shared<TestRunnable>();
auto r2 = std::make_shared<TestRunnable>();
r1->mustRun();
r2->mustRun();
r1->reset();

Dispatcher::addService(r1);
Dispatcher::addService(r2);
Dispatcher::joinServices();

EXPECT_EQ(2U, r1->count());
}

class BlockingTestRunnable : public InternalRunnable {
public:
explicit BlockingTestRunnable() {}

virtual void start() override {
// Wow that's a long sleep!
pauseMilli(100 * 1000);
}
};

TEST_F(DispatcherTests, test_interruption) {
auto r1 = std::make_shared<BlockingTestRunnable>();
r1->mustRun();
Dispatcher::addService(r1);

// This service would normally wait for 100 seconds.
r1->interrupt();

Dispatcher::joinServices();
EXPECT_TRUE(r1->hasRun());
}

TEST_F(DispatcherTests, test_stop_dispatcher) {
Dispatcher::stopServices();

auto r1 = std::make_shared<TestRunnable>();
auto s = Dispatcher::addService(r1);
EXPECT_FALSE(s);
}
}
33 changes: 16 additions & 17 deletions osquery/logger/plugins/buffered.cpp
Expand Up @@ -57,12 +57,12 @@ void BufferedLogForwarder::check() {
// For each index, accumulate the log line into the result or status set.
std::vector<std::string> results, statuses;
iterate(indexes, ([&results, &statuses, this](std::string& index) {
std::string value;
auto& target = isResultIndex(index) ? results : statuses;
if (getDatabaseValue(kLogs, index, value)) {
target.push_back(std::move(value));
}
}));
std::string value;
auto& target = isResultIndex(index) ? results : statuses;
if (getDatabaseValue(kLogs, index, value)) {
target.push_back(std::move(value));
}
}));

// If any results/statuses were found in the flushed buffer, send.
if (results.size() > 0) {
Expand All @@ -72,11 +72,11 @@ void BufferedLogForwarder::check() {
} else {
// Clear the results logs once they were sent.
iterate(indexes, ([this](std::string& index) {
if (!isResultIndex(index)) {
return;
}
deleteValueWithCount(kLogs, index);
}));
if (!isResultIndex(index)) {
return;
}
deleteValueWithCount(kLogs, index);
}));
}
}

Expand All @@ -87,11 +87,11 @@ void BufferedLogForwarder::check() {
} else {
// Clear the status logs once they were sent.
iterate(indexes, ([this](std::string& index) {
if (!isStatusIndex(index)) {
return;
}
deleteValueWithCount(kLogs, index);
}));
if (!isStatusIndex(index)) {
return;
}
deleteValueWithCount(kLogs, index);
}));
}
}

Expand Down Expand Up @@ -156,7 +156,6 @@ void BufferedLogForwarder::purge() {
LOG(ERROR) << "Error deleting value during buffered log purge";
}
});

}

void BufferedLogForwarder::start() {
Expand Down
4 changes: 4 additions & 0 deletions osquery/logger/plugins/buffered.h
Expand Up @@ -146,6 +146,7 @@ class BufferedLogForwarder : public InternalRunnable {
protected:
/// Return whether the string is a result index
bool isResultIndex(const std::string& index);

/// Return whether the string is a status index
bool isStatusIndex(const std::string& index);

Expand All @@ -156,11 +157,13 @@ class BufferedLogForwarder : public InternalRunnable {
protected:
/// Generate a result index string to use with the backing store
std::string genResultIndex(size_t time = 0);

/// Generate a status index string to use with the backing store
std::string genStatusIndex(size_t time = 0);

private:
std::string genIndexPrefix(bool results);

std::string genIndex(bool results, size_t time = 0);

/**
Expand All @@ -170,6 +173,7 @@ class BufferedLogForwarder : public InternalRunnable {
Status addValueWithCount(const std::string& domain,
const std::string& key,
const std::string& value);

/**
* @brief Delete a database value while maintaining count
*
Expand Down
42 changes: 15 additions & 27 deletions osquery/logger/plugins/tests/buffered_tests.cpp
Expand Up @@ -114,10 +114,9 @@ TEST_F(BufferedLogForwarderTests, test_basic) {
runner.logString("baz");
EXPECT_CALL(runner, send(ElementsAre("bar", "baz"), "result"))
.WillOnce(Return(Status(0)));
EXPECT_CALL(
runner,
send(ElementsAre(MatchesStatus(log1), MatchesStatus(log2)), "status"))
.WillOnce(Return(Status(0)));
EXPECT_CALL(runner,
send(ElementsAre(MatchesStatus(log1), MatchesStatus(log2)),
"status")).WillOnce(Return(Status(0)));
runner.check();
// This call should not result in sending again
runner.check();
Expand All @@ -143,16 +142,14 @@ TEST_F(BufferedLogForwarderTests, test_retry) {
runner.logString("bar");
EXPECT_CALL(runner, send(ElementsAre("foo", "bar"), "result"))
.WillOnce(Return(Status(0)));
EXPECT_CALL(
runner,
send(ElementsAre(MatchesStatus(log1), MatchesStatus(log2)), "status"))
.WillOnce(Return(Status(1, "fail")));
EXPECT_CALL(runner,
send(ElementsAre(MatchesStatus(log1), MatchesStatus(log2)),
"status")).WillOnce(Return(Status(1, "fail")));
runner.check();

EXPECT_CALL(
runner,
send(ElementsAre(MatchesStatus(log1), MatchesStatus(log2)), "status"))
.WillOnce(Return(Status(0)));
EXPECT_CALL(runner,
send(ElementsAre(MatchesStatus(log1), MatchesStatus(log2)),
"status")).WillOnce(Return(Status(0)));
runner.check();

// This call should not send again because the previous was successful
Expand Down Expand Up @@ -215,22 +212,14 @@ TEST_F(BufferedLogForwarderTests, test_multiple) {
TEST_F(BufferedLogForwarderTests, test_async) {
auto runner = std::make_shared<StrictMock<MockBufferedLogForwarder>>(
"mock", kLogPeriod);
Dispatcher::addService(runner);
runner->mustRun();

EXPECT_CALL(*runner, send(ElementsAre("foo"), "result"))
.WillOnce(Return(Status(0)));
runner->logString("foo");
std::this_thread::sleep_for(5 * kLogPeriod);

EXPECT_CALL(*runner, send(ElementsAre("bar"), "result"))
.Times(3)
.WillOnce(Return(Status(1, "fail")))
.WillOnce(Return(Status(1, "fail again")))
.WillOnce(Return(Status(0)));
runner->logString("bar");
std::this_thread::sleep_for(15 * kLogPeriod);

Dispatcher::stopServices();
Dispatcher::addService(runner);
runner->interrupt();
Dispatcher::joinServices();
}

Expand Down Expand Up @@ -325,10 +314,9 @@ TEST_F(BufferedLogForwarderTests, test_purge_max) {

EXPECT_CALL(runner, send(ElementsAre("foo", "bar", "baz"), "result"))
.WillOnce(Return(Status(1, "fail")));
EXPECT_CALL(
runner,
send(ElementsAre(MatchesStatus(log1), MatchesStatus(log2)), "status"))
.WillOnce(Return(Status(1, "fail")));
EXPECT_CALL(runner,
send(ElementsAre(MatchesStatus(log1), MatchesStatus(log2)),
"status")).WillOnce(Return(Status(1, "fail")));
runner.check();

EXPECT_CALL(runner, send(ElementsAre("baz"), "result"))
Expand Down

0 comments on commit 58fd284

Please sign in to comment.