Skip to content

Commit

Permalink
Add tests and fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
jcarreira committed Nov 2, 2018
1 parent 3deab41 commit e33b8a9
Show file tree
Hide file tree
Showing 13 changed files with 136 additions and 11 deletions.
2 changes: 2 additions & 0 deletions .travis.yml
Expand Up @@ -27,6 +27,8 @@ script:
- for i in "./tests/test_travis_lr/test.sh & sleep 1" "./tests/test_travis_lr/error"; do eval ${i}; done
- ./tests/test_data/jester.sh
- for i in "./tests/test_travis_mf/test.sh & sleep 1" "./tests/test_travis_mf/error"; do eval ${i}; done
- ./tests/test_travis/test_register.sh
- ./tests/test_travis/test_keyvalue.sh

env:
global:
Expand Down
1 change: 1 addition & 0 deletions configure.ac
Expand Up @@ -39,6 +39,7 @@ AC_CONFIG_FILES([Makefile
tests/Makefile
tests/iterator/Makefile
tests/test_s3/Makefile
tests/test_travis/Makefile
tests/test_travis_lr/Makefile
tests/test_travis_mf/Makefile])

Expand Down
15 changes: 12 additions & 3 deletions src/PSSparseServerInterface.cpp
Expand Up @@ -329,6 +329,10 @@ void PSSparseServerInterface::set_value(const std::string& key,
char key_char[KEY_SIZE] = {0};
std::copy(key.data(), key.data() + key.size(), key_char);

uint32_t operation = SET_VALUE;
if (send_all(sock, &operation, sizeof(operation)) != sizeof(operation)) {
throw std::runtime_error("Error sending operation");
}
if (send_all(sock, key_char, KEY_SIZE) != KEY_SIZE) {
throw std::runtime_error("Error sending key name");
}
Expand All @@ -340,11 +344,16 @@ void PSSparseServerInterface::set_value(const std::string& key,
}
}

std::shared_ptr<char> PSSparseServerInterface::get_value(
std::pair<std::shared_ptr<char>, uint32_t> PSSparseServerInterface::get_value(
const std::string& key) {
char key_char[KEY_SIZE] = {0};
std::copy(key.data(), key.data() + key.size(), key_char);

uint32_t operation = GET_VALUE;
if (send_all(sock, &operation, sizeof(operation)) != sizeof(operation)) {
throw std::runtime_error("Error sending operation");
}

if (send_all(sock, key_char, KEY_SIZE) != KEY_SIZE) {
throw std::runtime_error("Error sending key name");
}
Expand All @@ -356,7 +365,7 @@ std::shared_ptr<char> PSSparseServerInterface::get_value(

if (size == 0) {
// object not found
return std::shared_ptr<char>(nullptr);
return std::make_pair(std::shared_ptr<char>(nullptr), 0);
}

std::shared_ptr<char> value_data =
Expand All @@ -366,7 +375,7 @@ std::shared_ptr<char> PSSparseServerInterface::get_value(
throw std::runtime_error("Error receiving value data");
}

return value_data;
return std::make_pair(value_data, size);
}

} // namespace cirrus
Expand Down
2 changes: 1 addition & 1 deletion src/PSSparseServerInterface.h
Expand Up @@ -51,7 +51,7 @@ class PSSparseServerInterface {
* @param key Key name
* @return Returns pointer to raw value
*/
std::shared_ptr<char> get_value(const std::string& key);
std::pair<std::shared_ptr<char>, uint32_t> get_value(const std::string& key);

/*
* Marks task as running on the parameter server
Expand Down
13 changes: 10 additions & 3 deletions src/PSSparseServerTask.cpp
Expand Up @@ -44,6 +44,7 @@ PSSparseServerTask::PSSparseServerTask(uint64_t model_size,
std::atomic_init(&gradientUpdatesCount, 0UL);
std::atomic_init(&thread_count, 0);

set_operation_maps();

for (int i = 0; i < NUM_PS_WORK_THREADS; i++) {
thread_msg_buffer[i].reset(new char[THREAD_MSG_BUFFER_SIZE]);
Expand Down Expand Up @@ -523,12 +524,18 @@ bool PSSparseServerTask::process_set_value(int sock,
std::vector<char>& thread_buffer,
int) {
struct {
char key[KEY_SIZE + 1];
char key[KEY_SIZE];
uint32_t value_size;
} msg;

memset(&msg, 0, sizeof(msg));

// read the key (KEY_SIZE bytes)
if (read_all(sock, &msg, sizeof(msg)) == 0) {
if (read_all(sock, msg.key, KEY_SIZE) == 0) {
handle_failed_read(&req.poll_fd);
return false;
}
if (read_all(sock, &msg.value_size, sizeof(uint32_t)) == 0) {
handle_failed_read(&req.poll_fd);
return false;
}
Expand All @@ -537,7 +544,7 @@ bool PSSparseServerTask::process_set_value(int sock,
new char[msg.value_size], std::default_delete<char[]>());

// read the key value
if (read_all(sock, value_data.get(), sizeof(msg.value_size)) == 0) {
if (read_all(sock, value_data.get(), msg.value_size) == 0) {
handle_failed_read(&req.poll_fd);
return false;
}
Expand Down
2 changes: 1 addition & 1 deletion tests/Makefile.am
@@ -1,3 +1,3 @@
AUTOMAKE_OPTIONS = foreign
SUBDIRS = test_travis_lr test_travis_mf iterator
SUBDIRS = test_travis_lr test_travis_mf test_travis iterator

2 changes: 1 addition & 1 deletion tests/iterator/Makefile.am
@@ -1,7 +1,7 @@
AUTOMAKE_OPTIONS = foreign

CXX=g++
CXXFLAGS=-Wall -ansi -O3 -std=c++17 -ggdb
CXXFLAGS=-Wall -ansi -O3 -std=c++17 -ggdb -Werror

TOP_DIR=../../../..
THIRD_PARTY_DIR=../../third_party
Expand Down
2 changes: 0 additions & 2 deletions tests/iterator/test_libsvm_iterator.cpp
Expand Up @@ -30,12 +30,10 @@ int main() {
double total_loss = ret.first;
double total_accuracy = ret.second;
double total_num_samples = test_data->num_samples();
double total_num_features = test_data->num_features();

std::cout << "[ERROR_TASK] Loss (Total/Avg): " << total_loss << "/"
<< (total_loss / total_num_samples)
<< " Accuracy: " << (total_accuracy) << std::endl;
double avg_loss = (total_loss / total_num_samples);
}

return 0;
Expand Down
65 changes: 65 additions & 0 deletions tests/test_travis/Makefile.am
@@ -0,0 +1,65 @@
AUTOMAKE_OPTIONS = foreign

CXX=g++
CXXFLAGS=-Wall -ansi -O3 -std=c++17 -ggdb

bin_PROGRAMS = test_register_worker test_keyvalue

TOP_DIR=../../../..
THIRD_PARTY_DIR=../../third_party
CIRRUS_SRC_DIR=../../src
CIRRUS_SRC_FILES=$(CIRRUS_SRC_DIR)/Configuration.cpp \
$(CIRRUS_SRC_DIR)/Dataset.cpp \
$(CIRRUS_SRC_DIR)/Matrix.cpp \
$(CIRRUS_SRC_DIR)/LRModel.cpp \
$(CIRRUS_SRC_DIR)/SparseLRModel.cpp \
$(CIRRUS_SRC_DIR)/S3SparseIterator.cpp \
$(CIRRUS_SRC_DIR)/Utils.cpp \
$(CIRRUS_SRC_DIR)/MlUtils.cpp \
$(CIRRUS_SRC_DIR)/S3.cpp \
$(CIRRUS_SRC_DIR)/OptimizationMethod.cpp \
$(CIRRUS_SRC_DIR)/SGD.cpp \
$(CIRRUS_SRC_DIR)/Checksum.cpp \
$(CIRRUS_SRC_DIR)/MurmurHash3.cpp \
$(CIRRUS_SRC_DIR)/ModelGradient.cpp \
$(CIRRUS_SRC_DIR)/InputReader.cpp \
$(CIRRUS_SRC_DIR)/PSSparseServerInterface.cpp \
$(CIRRUS_SRC_DIR)/PSSparseServerTask.cpp \
$(CIRRUS_SRC_DIR)/SparseMFModel.cpp \
$(CIRRUS_SRC_DIR)/MFModel.cpp \
$(CIRRUS_SRC_DIR)/Nesterov.cpp \
$(CIRRUS_SRC_DIR)/AdaGrad.cpp \
$(CIRRUS_SRC_DIR)/Momentum.cpp \
$(CIRRUS_SRC_DIR)/S3Iterator.cpp \
$(CIRRUS_SRC_DIR)/S3IteratorLibsvm.cpp \
$(CIRRUS_SRC_DIR)/S3Client.cpp \
$(CIRRUS_SRC_DIR)/SparseDataset.cpp

LINCLUDES = -L$(THIRD_PARTY_DIR)/kerberos/src/lib \
-L$(THIRD_PARTY_DIR)/keyutils/ \
-L$(THIRD_PARTY_DIR)/gflags/lib/ \
-L$(THIRD_PARTY_DIR)/curl/curl/lib/.libs/ \
-L$(THIRD_PARTY_DIR)/aws-sdk-cpp/build/aws-cpp-sdk-core/ \
-L$(THIRD_PARTY_DIR)/aws-sdk-cpp/build/aws-cpp-sdk-s3 \
-L/home/ec2-user/kerberos/krb5-1.15.2/src/lib \
-L/home/ec2-user/keyutils/keyutils-1.5.10

LIBS= -laws-cpp-sdk-s3 -laws-cpp-sdk-core \
-lcurl -lssl -lcrypto -lz -ldl -lkrb5 -lk5crypto \
-lall -lkeyutils -lresolv -lgflags

LDFLAGS =-static-libgcc -static \
-Wl,--whole-archive -lpthread -Wl,--no-whole-archive

LDADD = $(LINCLUDES) $(LIBS)

AM_CPPFLAGS=-I$(CIRRUS_SRC_DIR) \
-I$(THIRD_PARTY_DIR)/aws-sdk-cpp/aws-cpp-sdk-s3/include/ \
-I$(THIRD_PARTY_DIR)/aws-sdk-cpp/aws-cpp-sdk-core/include/ \
-I$(THIRD_PARTY_DIR)/eigen_source/

test_register_worker_SOURCES = test_register_worker.cpp $(CIRRUS_SRC_FILES)
test_keyvalue_SOURCES = test_keyvalue.cpp $(CIRRUS_SRC_FILES)

clean:
rm -rf a.out test_iterator
29 changes: 29 additions & 0 deletions tests/test_travis/test_keyvalue.cpp
@@ -0,0 +1,29 @@
#include <PSSparseServerInterface.h>

using namespace cirrus;

#define VALUE_SIZE (1000)

int main() {
std::unique_ptr<PSSparseServerInterface> psi =
std::make_unique<PSSparseServerInterface>("127.0.0.1", 1337);
psi->connect();

char value[VALUE_SIZE];
for (int i = 0; i < VALUE_SIZE; ++i) {
value[i] = i;
}

psi->set_value("key", value, sizeof(value));

std::pair<std::shared_ptr<char>, uint32_t> ret = psi->get_value("key");

if (ret.second != sizeof(value)) {
throw std::runtime_error("Wrong size");
}
if (memcmp(&value, ret.first.get(), sizeof(value))) {
throw std::runtime_error("Wrong value");
}

return 0;
}
7 changes: 7 additions & 0 deletions tests/test_travis/test_keyvalue.sh
@@ -0,0 +1,7 @@
#!/bin/bash

timeout 60 ./tests/test_travis_lr/test_ps&
sleep 1

timeout 50 ./tests/test_travis/test_keyvalue&
sleep 1
7 changes: 7 additions & 0 deletions tests/test_travis/test_register.sh
@@ -0,0 +1,7 @@
#!/bin/bash

timeout 60 ./tests/test_travis_lr/test_ps&
sleep 1

timeout 50 ./tests/test_travis/test_register_worker&
sleep 1
File renamed without changes.

0 comments on commit e33b8a9

Please sign in to comment.