Permalink
Browse files

Big refactor from FAIR team

This commit includes a *ton* of refactoring goodness and various
stability fixes.
  • Loading branch information...
jma127 committed Jul 20, 2018
1 parent 1b6859f commit d5f320d8c687b9a406b861d387a2053a54e20ac3
Showing with 2,699 additions and 1,673 deletions.
  1. +10 −9 .gitmodules
  2. +0 −22 scripts/elfgames/go/analysis.sh
  3. +6 −6 src_cpp/elf/CMakeLists.txt
  4. +2 −2 src_cpp/elf/ai/ai.h
  5. +12 −11 src_cpp/elf/ai/tree_search/tree_search.h
  6. +7 −3 src_cpp/elf/ai/tree_search/tree_search_base.h
  7. +8 −0 src_cpp/elf/ai/tree_search/tree_search_options.h
  8. +11 −5 src_cpp/elf/base/context.h
  9. +113 −29 src_cpp/elf/base/ctrl.h
  10. +147 −0 src_cpp/elf/base/dispatcher.h
  11. +9 −2 src_cpp/elf/base/extractor.h
  12. +38 −0 src_cpp/elf/concurrency/ConcurrentQueue.h
  13. +72 −76 src_cpp/elf/distributed/shared_reader.h
  14. +26 −61 src_cpp/elf/distributed/shared_rw_buffer2.h
  15. +8 −0 src_cpp/elf/logging/IndexedLoggerFactory.h
  16. +22 −0 src_cpp/elf/utils/utils.h
  17. +40 −11 src_cpp/elfgames/go/CMakeLists.txt
  18. +12 −12 src_cpp/elfgames/go/base/board.cc
  19. +5 −1 src_cpp/elfgames/go/base/board_feature.cc
  20. 0 src_cpp/elfgames/go/base/{ → test}/board_feature_test.cc
  21. 0 src_cpp/elfgames/go/base/{ → test}/coord_test.cc
  22. +1 −1 src_cpp/elfgames/go/base/{ → test}/go_test.cc
  23. 0 src_cpp/elfgames/go/base/{ → test}/symmetry_test.cc
  24. 0 src_cpp/elfgames/go/base/{ → test}/test_utils.h
  25. +97 −0 src_cpp/elfgames/go/common/dispatcher_callback.h
  26. 0 src_cpp/elfgames/go/{ → common}/game_base.h
  27. +17 −4 src_cpp/elfgames/go/{ → common}/game_feature.h
  28. +90 −79 src_cpp/elfgames/go/{ → common}/game_selfplay.cc
  29. +11 −11 src_cpp/elfgames/go/{ → common}/game_selfplay.h
  30. +1 −1 src_cpp/elfgames/go/{ → common}/game_stats.h
  31. 0 src_cpp/elfgames/go/{ → common}/game_utils.h
  32. +8 −3 src_cpp/elfgames/go/{ → common}/go_game_specific.h
  33. +3 −1 src_cpp/elfgames/go/{ → common}/go_state_ext.cc
  34. +12 −3 src_cpp/elfgames/go/{ → common}/go_state_ext.h
  35. +75 −0 src_cpp/elfgames/go/common/model_pair.h
  36. +14 −0 src_cpp/elfgames/go/common/notifier.h
  37. +36 −85 src_cpp/elfgames/go/{ → common}/record.h
  38. +0 −78 src_cpp/elfgames/go/data_loader.h
  39. +0 −325 src_cpp/elfgames/go/game_context.h
  40. +0 −702 src_cpp/elfgames/go/game_ctrl.h
  41. +0 −58 src_cpp/elfgames/go/game_train.cc
  42. +47 −0 src_cpp/elfgames/go/inference/Pybind.cc
  43. 0 src_cpp/elfgames/go/{ → inference}/Pybind.h
  44. +123 −0 src_cpp/elfgames/go/inference/game_context.h
  45. +15 −0 src_cpp/elfgames/go/inference/pybind_module.cc
  46. +1 −1 src_cpp/elfgames/go/mcts/mcts.h
  47. +1 −1 src_cpp/elfgames/go/mcts/mcts_test.cc
  48. +1 −1 src_cpp/elfgames/go/sgf/sgf_test.cc
  49. +13 −7 src_cpp/elfgames/go/{ → train}/Pybind.cc
  50. +19 −0 src_cpp/elfgames/go/train/Pybind.h
  51. 0 src_cpp/elfgames/go/{ → train}/client_manager.cc
  52. +62 −36 src_cpp/elfgames/go/{ → train}/client_manager.h
  53. 0 src_cpp/elfgames/go/{ → train}/ctrl_eval.h
  54. +17 −5 src_cpp/elfgames/go/{ → train}/ctrl_selfplay.h
  55. +44 −6 src_cpp/elfgames/go/{ → train}/ctrl_utils.h
  56. +126 −0 src_cpp/elfgames/go/train/data_loader.h
  57. +26 −0 src_cpp/elfgames/go/train/distri_base.h
  58. +334 −0 src_cpp/elfgames/go/train/distri_client.h
  59. +123 −0 src_cpp/elfgames/go/train/distri_server.h
  60. +1 −1 src_cpp/elfgames/go/{ → train}/fair_pick.h
  61. +132 −0 src_cpp/elfgames/go/train/game_context.h
  62. +367 −0 src_cpp/elfgames/go/train/game_ctrl.h
  63. +70 −0 src_cpp/elfgames/go/train/game_train.cc
  64. +4 −5 src_cpp/elfgames/go/{ → train}/game_train.h
  65. 0 src_cpp/elfgames/go/{ → train}/pybind_module.cc
  66. +24 −7 src_py/elfgames/go/df_model3.py
  67. +223 −0 src_py/elfgames/go/game_inference.py
  68. +9 −1 src_py/rlpytorch/model_base.py
  69. +4 −1 src_py/rlpytorch/model_loader.py
@@ -1,21 +1,22 @@
[submodule "third_party/concurrentqueue"]
path = third_party/concurrentqueue
url = https://github.com/cameron314/concurrentqueue.git
[submodule "third_party/cppzmq"]
path = third_party/cppzmq
url = https://github.com/zeromq/cppzmq.git
[submodule "third_party/googletest"]
path = third_party/googletest
url = https://github.com/google/googletest.git
[submodule "third_party/json"]
path = third_party/json
url = https://github.com/nlohmann/json.git
[submodule "third_party/pybind11"]
path = third_party/pybind11
url = https://github.com/pybind/pybind11.git
[submodule "third_party/spdlog"]
path = third_party/spdlog
url = https://github.com/gabime/spdlog.git
[submodule "third_party/json"]
path = third_party/json
url = https://github.com/nlohmann/json.git
[submodule "third_party/googletest"]
path = third_party/googletest
url = https://github.com/google/googletest.git
[submodule "third_party/tbb"]
path = third_party/tbb
url = https://github.com/01org/tbb.git
ignore = untracked
[submodule "third_party/cppzmq"]
path = third_party/cppzmq
url = https://github.com/zeromq/cppzmq.git

This file was deleted.

Oops, something went wrong.
@@ -14,10 +14,10 @@ set(ELF_SOURCES
options/Pybind.cc
)

set(ELF_TEST_SOURCES
options/OptionMapTest.cc
options/OptionSpecTest.cc
)
# set(ELF_TEST_SOURCES
# options/OptionMapTest.cc
# options/OptionSpecTest.cc
# )

# Main ELF library

@@ -26,10 +26,10 @@ target_compile_definitions(elf PUBLIC
GIT_COMMIT_HASH=${GIT_COMMIT_HASH}
GIT_STAGED=${GIT_STAGED_STRING}
)

target_link_libraries(elf PUBLIC
#${Boost_LIBRARIES}
concurrentqueue
cppzmq
nlohmann_json
pybind11
$<BUILD_INTERFACE:${PYTHON_LIBRARIES}>
@@ -40,7 +40,7 @@ target_link_libraries(elf PUBLIC
# Tests

enable_testing()
add_cpp_tests(test_cpp_elf_ elf ${ELF_TEST_SOURCES})
# add_cpp_tests(test_cpp_elf_ elf ${ELF_TEST_SOURCES})

# Python bindings

@@ -88,9 +88,9 @@ class AIClientT : public AI_T<S, A> {
status == comm::ReplyStatus::UNKNOWN;
}

virtual bool act_batch(
bool act_batch(
const std::vector<const S*>& batch_s,
const std::vector<A*>& batch_a) {
const std::vector<A*>& batch_a) override {
std::vector<elf::FuncsWithState> funcs_s =
client_->BindStateToFunctions(targets_, batch_s);
std::vector<elf::FuncsWithState> funcs_a =
@@ -70,7 +70,7 @@ class TreeSearchSingleThreadT {
: threadId_(thread_id), options_(options) {
if (options_.verbose) {
std::string log_file =
"tree_search_" + std::to_string(thread_id) + ".txt";
options_.log_prefix + std::to_string(thread_id) + ".txt";
output_.reset(new std::ofstream(log_file));
}
}
@@ -374,10 +374,7 @@ class TreeSearchT {
return searchTree_.printTree();
}

MCTSResult runPolicyOnly(const State& /*root_state*/) {
// TODO Policy only doesn't work.
assert(false);
/*
MCTSResult runPolicyOnly(const State& root_state) {
if (actors_.empty() || treeSearches_.empty()) {
throw std::range_error(
"TreeSearch::runPolicyOnly works when there is at least one thread");
@@ -386,15 +383,17 @@ class TreeSearchT {

// Some hack here.
Node* root = searchTree_.getRootNode();
treeSearches_[0]->visit(*actors_[0], root);

// return StrongestPrior(root->getStateActions());
*/
if (!root->isVisited()) {
NodeResponseT<Action> resp;
actors_[0]->evaluate(*root->getStatePtr(), &resp);
root->setEvaluation(resp);
}

MCTSResult result;
// result.action_rank_method = MCTSResult::PRIOR;
// result.addActions(root->getStateActions());

result.action_rank_method = MCTSResult::PRIOR;
result.addActions(root->getStateActions());
result.root_value = root->getValue();
return result;
}

@@ -490,6 +489,8 @@ class TreeSearchT {

// Pick the best solution.
MCTSResult result;
result.root_value = root->getValue();

// MCTSResult result2;
if (options_.pick_method == "strongest_prior") {
result.action_rank_method = MCTSResult::PRIOR;
@@ -61,8 +61,10 @@ struct StateTrait {
return s1 == s2;
}

static bool
moves_since(const S& s, size_t* next_move_number, std::vector<A>* moves) {
static bool moves_since(
const S& /*s*/,
size_t* /*next_move_number*/,
std::vector<A>* /*moves*/) {
// By default it is not provided.
return false;
}
@@ -84,7 +86,7 @@ struct ActionTrait {
template <typename Actor>
struct ActorTrait {
public:
static std::string to_string(const Actor& a) {
static std::string to_string(const Actor&) {
return "";
}
};
@@ -213,6 +215,7 @@ struct MCTSResultT {
enum RankCriterion { MOST_VISITED = 0, PRIOR = 1, UNIFORM_RANDOM };

Action best_action;
float root_value;
float max_score;
EdgeInfo best_edge_info;
MCTSPolicy<Action> mcts_policy;
@@ -224,6 +227,7 @@ struct MCTSResultT {
// action_edges ssengupta@fb.com
MCTSResultT()
: best_action(ActionTrait<Action>::default_value()),
root_value(0.0),
max_score(std::numeric_limits<float>::lowest()),
best_edge_info(0),
total_visits(0),
@@ -85,6 +85,7 @@ struct TSOptions {
bool persistent_tree = false;
float root_epsilon = 0.0;
float root_alpha = 0.0;
std::string log_prefix = "";

// [TODO] Not a good design.
// string pick_method = "strongest_prior";
@@ -102,6 +103,7 @@ struct TSOptions {
ss << "Maximal #moves (0 = no constraint): " << max_num_moves
<< std::endl;
ss << "Seed: " << seed << std::endl;
ss << "Log Prefix: " << log_prefix << std::endl;
ss << "#Threads: " << num_threads << std::endl;
ss << "#Rollout per thread: " << num_rollouts_per_thread
<< ", #rollouts per batch: " << num_rollouts_per_batch << std::endl;
@@ -156,6 +158,9 @@ struct TSOptions {
if (t1.pick_method != t2.pick_method) {
return false;
}
if (t1.log_prefix != t2.log_prefix) {
return false;
}
if (t1.root_epsilon != t2.root_epsilon) {
return false;
}
@@ -181,6 +186,7 @@ struct TSOptions {
JSON_SAVE(j, seed);
JSON_SAVE(j, persistent_tree);
JSON_SAVE(j, pick_method);
JSON_SAVE(j, log_prefix);
JSON_SAVE(j, root_epsilon);
JSON_SAVE(j, root_alpha);
JSON_SAVE(j, virtual_loss);
@@ -198,6 +204,7 @@ struct TSOptions {
JSON_LOAD(opt, j, seed);
JSON_LOAD(opt, j, persistent_tree);
JSON_LOAD(opt, j, pick_method);
JSON_LOAD(opt, j, log_prefix);
JSON_LOAD(opt, j, root_epsilon);
JSON_LOAD(opt, j, root_alpha);
JSON_LOAD(opt, j, virtual_loss);
@@ -213,6 +220,7 @@ struct TSOptions {
verbose,
persistent_tree,
pick_method,
log_prefix,
virtual_loss,
verbose_time,
alg_opt,
@@ -124,7 +124,7 @@ class Context {

void start() {
th_.reset(new std::thread([&]() {
assert(nice(10) == 10);
// assert(nice(10) == 10);
collectAndSendBatch();
}));
}
@@ -183,11 +183,17 @@ class Context {
}
}
smem_->waitBatchFillMem(server_);
// LOG(INFO) << "Receiver: Batch received. #batch = "
// << batch.size() << std::endl;
// std::cout << "Receiver[" << smem_opts.getLabel() << "] Batch
// received. #batch = "
// << smem_->getEffectiveBatchSize() << std::endl;

comm::ReplyStatus batch_status =
batchClient_->sendWait(smem_.get(), {""});

// std::cout << "Receiver[" << smem_opts.getLabel() << "] Batch
// releasing. #batch = "
// << smem_->getEffectiveBatchSize() << std::endl;

// LOG(INFO) << "Receiver: Release batch" << std::endl;
smem_->waitReplyReleaseBatch(server_, batch_status);
}
@@ -280,7 +286,7 @@ class Context {
auto* client = getClient();
for (int i = 0; i < num_games_; ++i) {
game_threads_.emplace_back([i, client, this]() {
assert(nice(19) == 19);
// assert(nice(19) == 19);
client->start();
game_cb_(i, client);
client->End();
@@ -329,7 +335,7 @@ class Context {
std::atomic<bool> tmp_thread_done(false);

std::thread tmp_thread([&]() {
assert(nice(10) == 10);
// assert(nice(10) == 10);

std::cout << "Prepare to stop ..." << std::endl;
client_->prepareToStop();
Oops, something went wrong.

0 comments on commit d5f320d

Please sign in to comment.