Permalink
Browse files

Slight logger refactor and memory pressure reduction

  • Loading branch information...
jma127 committed Oct 15, 2018
1 parent a39e7dc commit 7b50f00ea061c2c1593b13b9e58f30524153aa47
Showing with 112 additions and 84 deletions.
  1. +5 −5 src_cpp/elf/CMakeLists.txt
  2. +3 −2 src_cpp/elf/ai/tree_search/mcts.h
  3. +4 −3 src_cpp/elf/ai/tree_search/tree_search.h
  4. +1 −6 src_cpp/elf/ai/tree_search/tree_search_base.h
  5. +2 −1 src_cpp/elf/base/context.h
  6. +3 −2 src_cpp/elf/base/dispatcher.h
  7. +4 −2 src_cpp/elf/base/extractor.h
  8. +1 −1 src_cpp/elf/base/sharedmem.h
  9. +1 −1 src_cpp/elf/comm/comm.h
  10. +3 −2 src_cpp/elf/distributed/shared_reader.h
  11. +4 −2 src_cpp/elf/distributed/shared_rw_buffer2.h
  12. +10 −5 src_cpp/elf/distributed/zmq_util.h
  13. +3 −1 src_cpp/elf/legacy/python_options_utils_cpp.h
  14. +16 −0 src_cpp/elf/logging/IndexedLoggerFactory.cc
  15. +6 −6 src_cpp/elf/logging/IndexedLoggerFactory.h
  16. +3 −2 src_cpp/elfgames/go/base/board_feature.h
  17. +1 −1 src_cpp/elfgames/go/common/dispatcher_callback.h
  18. +3 −2 src_cpp/elfgames/go/common/game_base.h
  19. +1 −1 src_cpp/elfgames/go/common/game_selfplay.cc
  20. +3 −2 src_cpp/elfgames/go/common/game_stats.h
  21. +4 −3 src_cpp/elfgames/go/common/go_state_ext.h
  22. +1 −1 src_cpp/elfgames/go/inference/game_context.h
  23. +3 −1 src_cpp/elfgames/go/mcts/mcts.h
  24. +2 −1 src_cpp/elfgames/go/sgf/sgf.h
  25. +1 −1 src_cpp/elfgames/go/train/client_manager.h
  26. +6 −4 src_cpp/elfgames/go/train/ctrl_eval.h
  27. +2 −2 src_cpp/elfgames/go/train/ctrl_selfplay.h
  28. +1 −1 src_cpp/elfgames/go/train/data_loader.h
  29. +3 −3 src_cpp/elfgames/go/train/distri_client.h
  30. +1 −1 src_cpp/elfgames/go/train/distri_server.h
  31. +2 −1 src_cpp/elfgames/go/train/game_context.h
  32. +6 −4 src_cpp/elfgames/go/train/game_ctrl.h
  33. +1 −5 src_py/elfgames/go/mcts_prediction.py
  34. +2 −9 src_py/rlpytorch/model_loader.py
@@ -14,10 +14,10 @@ set(ELF_SOURCES
options/Pybind.cc
)

# set(ELF_TEST_SOURCES
# options/OptionMapTest.cc
# options/OptionSpecTest.cc
# )
set(ELF_TEST_SOURCES
options/OptionMapTest.cc
options/OptionSpecTest.cc
)

# Main ELF library

@@ -40,7 +40,7 @@ target_link_libraries(elf PUBLIC
# Tests

enable_testing()
# add_cpp_tests(test_cpp_elf_ elf ${ELF_TEST_SOURCES})
add_cpp_tests(test_cpp_elf_ elf ${ELF_TEST_SOURCES})

# Python bindings

@@ -42,8 +42,9 @@ class MCTSAI_T : public AI_T<typename Actor::State, typename Actor::Action> {
const elf::ai::tree_search::TSOptions& options,
std::function<Actor*(int)> gen)
: options_(options),
logger_(
elf::logging::getLogger("elf::ai::tree_search::MCTSAI_T-", "")) {
logger_(elf::logging::getIndexedLogger(
"elf::ai::tree_search::MCTSAI_T-",
"")) {
ts_.reset(new TreeSearch(options_, gen));
}

@@ -70,7 +70,7 @@ class TreeSearchSingleThreadT {
TreeSearchSingleThreadT(int thread_id, const TSOptions& options)
: threadId_(thread_id),
options_(options),
logger_(elf::logging::getLogger(
logger_(elf::logging::getIndexedLogger(
"elf::ai::tree_search::TreeSearchSingleThreadT-",
"")) {
if (options_.verbose) {
@@ -334,8 +334,9 @@ class TreeSearchT {
TreeSearchT(const TSOptions& options, std::function<Actor*(int)> actor_gen)
: options_(options),
stopSearch_(false),
logger_(
elf::logging::getLogger("elf::ai::tree_search::TreeSearchT-", "")) {
logger_(elf::logging::getIndexedLogger(
"elf::ai::tree_search::TreeSearchT-",
"")) {
for (int i = 0; i < options.num_threads; ++i) {
treeSearches_.emplace_back(new TreeSearchSingleThread(i, options_));
actors_.emplace_back(actor_gen(i));
@@ -116,9 +116,7 @@ struct EdgeInfo {
child_node(InvalidNodeId),
reward(0),
num_visits(0),
virtual_loss(0),
logger_(
elf::logging::getLogger("elf::ai::tree_search::EdgeInfo-", "")) {}
virtual_loss(0) {}

float getQSA() const {
return reward / num_visits;
@@ -127,9 +125,6 @@ struct EdgeInfo {
// TODO: What is this function doing (ssengupta@fb.com)
void checkValid() const {
if (virtual_loss != 0) {
// TODO: This should be a Google log (ssengupta@fb)
logger_->info(
"Virtual loss is not zero[{}]\n{}", virtual_loss, info(true));
assert(virtual_loss == 0);
}
}
@@ -200,7 +200,8 @@ class Context {
public:
using GameCallback = std::function<void(int game_idx, GameClient*)>;

Context() : logger_(elf::logging::getLogger("elf::base::Context-", "")) {
Context()
: logger_(elf::logging::getIndexedLogger("elf::base::Context-", "")) {
// Wait for the derived class to add entries to extractor_.
server_ = comm_.getServer();
client_.reset(new GameClient(&comm_, this));
@@ -21,8 +21,9 @@ class ThreadedDispatcherT : public ThreadedCtrlBase {
ThreadedDispatcherT(Ctrl& ctrl, int num_games)
: ThreadedCtrlBase(ctrl, 500),
num_games_(num_games),
logger_(
elf::logging::getLogger("elf::base::ThreadedDispatcherT-", "")) {}
logger_(elf::logging::getIndexedLogger(
"elf::base::ThreadedDispatcherT-",
"")) {}

void Start(ServerReply replier, ServerFirstSend first_send = nullptr) {
server_replier_ = replier;
@@ -471,7 +471,8 @@ class ClassFieldT;
//
class Extractor {
public:
Extractor() : logger_(elf::logging::getLogger("elf::base::Extractor-", "")) {}
Extractor()
: logger_(elf::logging::getIndexedLogger("elf::base::Extractor-", "")) {}

template <typename T>
FuncMapT<T>& addField(const std::string& key) {
@@ -579,7 +580,8 @@ class ClassFieldT {

ClassFieldT(Extractor* ext)
: ext_(ext),
logger_(elf::logging::getLogger("elf::base::ClassFieldT-", "")) {}
logger_(elf::logging::getIndexedLogger("elf::base::ClassFieldT-", "")) {
}

template <typename T>
ClassField& addFunction(
@@ -133,7 +133,7 @@ class SharedMem {
const std::unordered_map<std::string, AnyP>& mem)
: opts_(smem_opts),
mem_(mem),
logger_(elf::logging::getLogger("elf::base::SharedMem-", "")) {
logger_(elf::logging::getIndexedLogger("elf::base::SharedMem-", "")) {
opts_.setIdx(idx);
}

@@ -327,7 +327,7 @@ class CommT : public CommInternalT<
: CommInternal::Client(pp),
pp_(pp),
rng_(time(NULL)),
logger_(elf::logging::getLogger("elf::comm::Client-", "")) {}
logger_(elf::logging::getIndexedLogger("elf::comm::Client-", "")) {}

ReplyStatus sendWait(Data data, const std::vector<std::string>& labels) {
return CommInternal::Client::sendWait(
@@ -164,8 +164,9 @@ class ReaderQueuesT {
ReaderQueuesT(const RQCtrl& reader_ctrl)
: min_size_satisfied_(false),
parity_sizes_(2, 0),
logger_(
elf::logging::getLogger("elf::distributed::ReaderQueuesT-", "")) {
logger_(elf::logging::getIndexedLogger(
"elf::distributed::ReaderQueuesT-",
"")) {
// Make sure this is an even number.
assert(reader_ctrl.num_reader % 2 == 0);
min_size_per_queue_ = reader_ctrl.ctrl.queue_min_size;
@@ -59,7 +59,8 @@ class Writer {
Writer(const Options& opt)
: rng_(time(NULL)),
options_(opt),
logger_(elf::logging::getLogger("elf::distributed::Writer-", "")) {
logger_(
elf::logging::getIndexedLogger("elf::distributed::Writer-", "")) {
identity_ = options_.identity + "-" + get_id(rng_);
sender_.reset(new elf::distri::ZMQSender(
identity_, options_.addr, options_.port, options_.use_ipv6));
@@ -151,7 +152,8 @@ class Reader {
db_name_(filename),
rng_(time(NULL)),
done_(false),
logger_(elf::logging::getLogger("elf::distributed::Reader-", "")) {}
logger_(
elf::logging::getIndexedLogger("elf::distributed::Reader-", "")) {}

void startReceiving(
ProcessFunc proc_func,
@@ -80,8 +80,9 @@ class SegmentedRecv {
public:
SegmentedRecv(zmq::socket_t& socket)
: socket_(socket),
logger_(
elf::logging::getLogger("elf::distributed::SegmentedRecv-", "")) {}
logger_(elf::logging::getIndexedLogger(
"elf::distributed::SegmentedRecv-",
"")) {}

/*
void recvBlocked(size_t n, std::vector<std::string>* p_msgs) {
@@ -179,7 +180,7 @@ class SegmentedRecv {
class SameThreadChecker {
public:
SameThreadChecker()
: logger_(elf::logging::getLogger(
: logger_(elf::logging::getIndexedLogger(
"elf::distributed::SameThreadChecker-",
"")) {
id_ = std::this_thread::get_id();
@@ -210,7 +211,9 @@ class ZMQReceiver : public SameThreadChecker {
public:
ZMQReceiver(int port, bool use_ipv6)
: context_(1),
logger_(elf::logging::getLogger("elf::distributed::ZMQReceiver-", "")) {
logger_(elf::logging::getIndexedLogger(
"elf::distributed::ZMQReceiver-",
"")) {
broker_.reset(new zmq::socket_t(context_, ZMQ_ROUTER));
if (use_ipv6) {
int ipv6 = 1;
@@ -286,7 +289,9 @@ class ZMQSender : public SameThreadChecker {
int port,
bool use_ipv6)
: context_(1),
logger_(elf::logging::getLogger("elf::distributed::ZMQSender-", "")) {
logger_(elf::logging::getIndexedLogger(
"elf::distributed::ZMQSender-",
"")) {
sender_.reset(new zmq::socket_t(context_, ZMQ_DEALER));
if (use_ipv6) {
int ipv6 = 1;
@@ -32,7 +32,9 @@ struct ContextOptions {
std::shared_ptr<spdlog::logger> _logger;

ContextOptions()
: _logger(elf::logging::getLogger("elf::legacy::ContextOptions-", "")) {}
: _logger(elf::logging::getIndexedLogger(
"elf::legacy::ContextOptions-",
"")) {}

void print() const {
_logger->info("JobId: {}", job_id);
@@ -12,6 +12,8 @@

#include "IndexedLoggerFactory.h"

#include <iostream>

namespace elf {
namespace logging {

@@ -22,15 +24,29 @@ void IndexedLoggerFactory::registerPy(pybind11::module& m) {
.def(py::init<CreatorT>())
.def(py::init<CreatorT, size_t>())
.def("makeLogger", &IndexedLoggerFactory::makeLogger);

m.def("getIndexedLogger", getIndexedLogger);
}

std::shared_ptr<spdlog::logger> IndexedLoggerFactory::makeLogger(
const std::string& prefix,
const std::string& suffix) {
size_t curCount = counter_++;
std::string loggerName = prefix + std::to_string(curCount) + suffix;

return creator_(loggerName);
}

std::shared_ptr<spdlog::logger> getIndexedLogger(
const std::string& prefix,
const std::string& suffix) {
static IndexedLoggerFactory factory([](const std::string& name) {
auto ptr = spdlog::stderr_color_mt(name);
spdlog::drop(name);
return ptr;
});
return factory.makeLogger(prefix, suffix);
}

} // namespace logging
} // namespace elf
@@ -34,6 +34,10 @@
* };
*
* }
*
* WARNING: Use this *only* when you are able to guarantee a bounded number of
* object instantiations. This class will automatically enforce an upper bound
* of a few thousand.
*/

#pragma once
@@ -69,13 +73,9 @@ class IndexedLoggerFactory {
std::atomic_size_t counter_;
};

inline std::shared_ptr<spdlog::logger> getLogger(
std::shared_ptr<spdlog::logger> getIndexedLogger(
const std::string& prefix,
const std::string& suffix) {
static IndexedLoggerFactory factory(
[](const std::string& name) { return spdlog::stderr_color_mt(name); });
return factory.makeLogger(prefix, suffix);
}
const std::string& suffix);

} // namespace logging
} // namespace elf
@@ -66,8 +66,9 @@ class BoardFeature {
: s_(s),
_rot(rot),
_flip(flip),
logger_(
elf::logging::getLogger("elfgames::go::base::BoardFeature-", "")) {}
logger_(elf::logging::getIndexedLogger(
"elfgames::go::base::BoardFeature-",
"")) {}
BoardFeature(const GoState& s) : s_(s), _rot(NONE), _flip(false) {}

static BoardFeature RandomShuffle(const GoState& s, std::mt19937* rng) {
@@ -13,7 +13,7 @@ class DispatcherCallback {
public:
DispatcherCallback(ThreadedDispatcher* dispatcher, elf::GameClient* client)
: client_(client),
logger_(elf::logging::getLogger(
logger_(elf::logging::getIndexedLogger(
"elfgames::go::common::DispatcherCallback-",
"")) {
using std::placeholders::_1;
@@ -26,8 +26,9 @@ class GoGameBase {
_game_idx(game_idx),
_options(options),
_context_options(context_options),
_logger(
elf::logging::getLogger("elfgames::go::common::GoGameBase-", "")) {
_logger(elf::logging::getIndexedLogger(
"elfgames::go::common::GoGameBase-",
"")) {
if (options.seed == 0) {
_seed = elf_utils::get_seed(
game_idx ^ std::hash<std::string>{}(context_options.job_id));
@@ -23,7 +23,7 @@ GoGameSelfPlay::GoGameSelfPlay(
dispatcher_(dispatcher),
notifier_(notifier),
_state_ext(game_idx, options),
logger_(elf::logging::getLogger(
logger_(elf::logging::getIndexedLogger(
"elfgames::go::GoGameSelfPlay-" + std::to_string(game_idx) + "-",
"")) {}

@@ -21,8 +21,9 @@
class GameStats {
public:
GameStats()
: _logger(
elf::logging::getLogger("elfgames::go::common::GameStats-", "")) {}
: _logger(elf::logging::getIndexedLogger(
"elfgames::go::common::GameStats-",
"")) {}

void feedMoveRanking(int ranking) {
std::lock_guard<std::mutex> lock(_mutex);
@@ -39,8 +39,9 @@ struct GoStateExt {
_last_value(0.0),
_resign_check(options.resign_thres, options.resign_prob_never),
_options(options),
_logger(
elf::logging::getLogger("elfgames::go::common::GoStateExt-", "")) {
_logger(elf::logging::getIndexedLogger(
"elfgames::go::common::GoStateExt-",
"")) {
restart();
}

@@ -263,7 +264,7 @@ class GoStateExtOffline {
: _game_idx(game_idx),
_bf(_state),
_options(options),
_logger(elf::logging::getLogger(
_logger(elf::logging::getIndexedLogger(
"elfgames::go::common::GoStateExtOffline-",
"")) {}

@@ -31,7 +31,7 @@ class GameContext {
GameContext(const ContextOptions& contextOptions, const GameOptions& options)
: contextOptions_(contextOptions),
goFeature_(options),
logger_(elf::logging::getLogger("GameContext-", "")) {
logger_(elf::logging::getIndexedLogger("GameContext-", "")) {
context_.reset(new elf::Context);

// Only works for online setting.
@@ -47,7 +47,9 @@ class MCTSActor {
MCTSActor(elf::GameClient* client, const MCTSActorParams& params)
: params_(params),
rng_(params.seed),
logger_(elf::logging::getLogger("elfgames::go::mcts::MCTSActor-", "")) {
logger_(elf::logging::getIndexedLogger(
"elfgames::go::mcts::MCTSActor-",
"")) {
ai_.reset(new AI(client, {params_.actor_name}));
}

@@ -260,7 +260,8 @@ class Sgf {

Sgf()
: _num_moves(0),
_logger(elf::logging::getLogger("elfgames::go::sgf::Sgf-", "")) {}
_logger(elf::logging::getIndexedLogger("elfgames::go::sgf::Sgf-", "")) {
}
bool load(const std::string& filename);
bool load(const std::string& filename, const std::string& game);

Oops, something went wrong.

0 comments on commit 7b50f00

Please sign in to comment.