Permalink
Browse files

Update DRAGNN, fix some macOS issues

  • Loading branch information...
1 parent b7523ee commit ea3fa4a338b76ad2c2754862bd24bb8951060e65 @bogatyy bogatyy committed Mar 23, 2017
Showing with 3,527 additions and 192 deletions.
  1. +34 −0 syntaxnet/dragnn/components/stateless/BUILD
  2. +131 −0 syntaxnet/dragnn/components/stateless/stateless_component.cc
  3. +171 −0 syntaxnet/dragnn/components/stateless/stateless_component_test.cc
  4. +8 −9 syntaxnet/dragnn/components/syntaxnet/BUILD
  5. +15 −0 syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.cc
  6. +15 −0 syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.h
  7. +104 −5 syntaxnet/dragnn/components/syntaxnet/syntaxnet_component_test.cc
  8. +15 −0 syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.cc
  9. +15 −0 syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.h
  10. +15 −0 syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor_test.cc
  11. +15 −0 syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state.cc
  12. +15 −0 syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state.h
  13. +15 −0 syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state_test.cc
  14. +5 −2 syntaxnet/dragnn/components/util/BUILD
  15. +15 −0 syntaxnet/dragnn/components/util/bulk_feature_extractor.h
  16. +23 −31 syntaxnet/dragnn/core/BUILD
  17. +18 −2 syntaxnet/dragnn/core/beam.h
  18. +15 −0 syntaxnet/dragnn/core/beam_test.cc
  19. +15 −0 syntaxnet/dragnn/core/component_registry.cc
  20. +15 −0 syntaxnet/dragnn/core/component_registry.h
  21. +15 −0 syntaxnet/dragnn/core/compute_session.h
  22. +15 −0 syntaxnet/dragnn/core/compute_session_impl.cc
  23. +15 −0 syntaxnet/dragnn/core/compute_session_impl.h
  24. +15 −0 syntaxnet/dragnn/core/compute_session_impl_test.cc
  25. +15 −0 syntaxnet/dragnn/core/compute_session_pool.cc
  26. +15 −0 syntaxnet/dragnn/core/compute_session_pool.h
  27. +15 −0 syntaxnet/dragnn/core/compute_session_pool_test.cc
  28. +15 −0 syntaxnet/dragnn/core/index_translator.cc
  29. +15 −0 syntaxnet/dragnn/core/index_translator.h
  30. +15 −0 syntaxnet/dragnn/core/index_translator_test.cc
  31. +20 −5 syntaxnet/dragnn/core/input_batch_cache.h
  32. +15 −0 syntaxnet/dragnn/core/input_batch_cache_test.cc
  33. +4 −1 syntaxnet/dragnn/core/interfaces/BUILD
  34. +15 −0 syntaxnet/dragnn/core/interfaces/cloneable_transition_state.h
  35. +15 −0 syntaxnet/dragnn/core/interfaces/component.h
  36. +15 −0 syntaxnet/dragnn/core/interfaces/input_batch.h
  37. +15 −0 syntaxnet/dragnn/core/interfaces/transition_state.h
  38. +15 −0 syntaxnet/dragnn/core/interfaces/transition_state_starter_test.cc
  39. +15 −0 syntaxnet/dragnn/core/ops/compute_session_op.cc
  40. +15 −0 syntaxnet/dragnn/core/ops/compute_session_op.h
  41. +15 −0 syntaxnet/dragnn/core/ops/dragnn_bulk_op_kernels.cc
  42. +15 −0 syntaxnet/dragnn/core/ops/dragnn_bulk_op_kernels_test.cc
  43. +22 −3 syntaxnet/dragnn/core/ops/dragnn_bulk_ops.cc
  44. +15 −0 syntaxnet/dragnn/core/ops/dragnn_op_kernels.cc
  45. +15 −0 syntaxnet/dragnn/core/ops/dragnn_op_kernels_test.cc
  46. +21 −0 syntaxnet/dragnn/core/ops/dragnn_ops.cc
  47. +15 −0 syntaxnet/dragnn/core/resource_container.h
  48. +15 −0 syntaxnet/dragnn/core/resource_container_test.cc
  49. +6 −7 syntaxnet/dragnn/core/test/BUILD
  50. +15 −0 syntaxnet/dragnn/core/test/generic.cc
  51. +15 −0 syntaxnet/dragnn/core/test/generic.h
  52. +15 −0 syntaxnet/dragnn/core/test/mock_component.h
  53. +15 −0 syntaxnet/dragnn/core/test/mock_compute_session.h
  54. +15 −0 syntaxnet/dragnn/core/test/mock_transition_state.h
  55. +15 −0 syntaxnet/dragnn/io/sentence_input_batch.cc
  56. +15 −0 syntaxnet/dragnn/io/sentence_input_batch.h
  57. +15 −0 syntaxnet/dragnn/io/sentence_input_batch_test.cc
  58. +15 −0 syntaxnet/dragnn/io/syntaxnet_sentence.h
  59. +1 −0 syntaxnet/dragnn/python/BUILD
  60. +22 −7 syntaxnet/dragnn/python/biaffine_units.py
  61. +19 −2 syntaxnet/dragnn/python/bulk_component.py
  62. +15 −0 syntaxnet/dragnn/python/bulk_component_test.py
  63. +43 −3 syntaxnet/dragnn/python/component.py
  64. +15 −0 syntaxnet/dragnn/python/composite_optimizer.py
  65. +16 −2 syntaxnet/dragnn/python/composite_optimizer_test.py
  66. +15 −0 syntaxnet/dragnn/python/digraph_ops.py
  67. +15 −0 syntaxnet/dragnn/python/digraph_ops_test.py
  68. +15 −0 syntaxnet/dragnn/python/dragnn_ops.py
  69. +49 −27 syntaxnet/dragnn/python/graph_builder.py
  70. +32 −0 syntaxnet/dragnn/python/graph_builder_test.py
  71. +41 −40 syntaxnet/dragnn/python/network_units.py
  72. +15 −0 syntaxnet/dragnn/python/network_units_test.py
  73. +15 −0 syntaxnet/dragnn/python/render_parse_tree_graphviz.py
  74. +15 −0 syntaxnet/dragnn/python/render_parse_tree_graphviz_test.py
  75. +15 −0 syntaxnet/dragnn/python/render_spec_with_graphviz.py
  76. +15 −0 syntaxnet/dragnn/python/render_spec_with_graphviz_test.py
  77. +15 −0 syntaxnet/dragnn/python/sentence_io.py
  78. +15 −0 syntaxnet/dragnn/python/sentence_io_test.py
  79. +6 −2 syntaxnet/dragnn/python/spec_builder.py
  80. +17 −2 syntaxnet/dragnn/python/trainer_lib.py
  81. +15 −0 syntaxnet/dragnn/python/visualization.py
  82. +15 −0 syntaxnet/dragnn/python/visualization_test.py
  83. +41 −17 syntaxnet/dragnn/python/wrapped_units.py
  84. +38 −1 syntaxnet/dragnn/tools/BUILD
  85. +15 −0 syntaxnet/dragnn/tools/build_pip_package.py
  86. +21 −0 syntaxnet/dragnn/tools/evaluator.py
  87. +197 −0 syntaxnet/dragnn/tools/model_trainer.py
  88. +54 −0 syntaxnet/dragnn/tools/model_trainer_test.sh
  89. +15 −0 syntaxnet/dragnn/tools/oss_notebook_launcher.py
  90. +15 −0 syntaxnet/dragnn/tools/parse-to-conll.py
  91. +0 −1 syntaxnet/dragnn/tools/parser_trainer.py
  92. +15 −0 syntaxnet/dragnn/tools/segmenter-evaluator.py
  93. +4 −0 syntaxnet/dragnn/tools/testdata/biaffine.model/config.txt
  94. +18 −0 syntaxnet/dragnn/tools/testdata/biaffine.model/hyperparameters.pbtxt
  95. +1,135 −0 syntaxnet/dragnn/tools/testdata/biaffine.model/master.pbtxt
  96. +7 −0 syntaxnet/dragnn/tools/testdata/biaffine.model/resources/category-map
  97. +18 −0 syntaxnet/dragnn/tools/testdata/biaffine.model/resources/char-map
  98. +46 −0 syntaxnet/dragnn/tools/testdata/biaffine.model/resources/char-ngram-map
  99. +8 −0 syntaxnet/dragnn/tools/testdata/biaffine.model/resources/label-map
  100. +11 −0 syntaxnet/dragnn/tools/testdata/biaffine.model/resources/lcword-map
  101. BIN syntaxnet/dragnn/tools/testdata/biaffine.model/resources/prefix-table
  102. BIN syntaxnet/dragnn/tools/testdata/biaffine.model/resources/suffix-table
  103. +8 −0 syntaxnet/dragnn/tools/testdata/biaffine.model/resources/tag-map
  104. +7 −0 syntaxnet/dragnn/tools/testdata/biaffine.model/resources/tag-to-category
  105. +11 −0 syntaxnet/dragnn/tools/testdata/biaffine.model/resources/word-map
  106. +17 −0 syntaxnet/dragnn/tools/testdata/biaffine.model/targets.pbtxt
  107. +29 −0 syntaxnet/dragnn/tools/testdata/small.conll
  108. +14 −0 syntaxnet/dragnn/viz/compile-minified.sh
  109. +14 −0 syntaxnet/dragnn/viz/develop.sh
  110. +4 −6 syntaxnet/syntaxnet/BUILD
  111. +5 −1 syntaxnet/syntaxnet/graph_builder.py
  112. +9 −0 syntaxnet/syntaxnet/ops/parser_ops.cc
  113. +14 −5 syntaxnet/syntaxnet/reader_ops.cc
  114. +38 −10 syntaxnet/syntaxnet/reader_ops_test.py
  115. +1 −1 syntaxnet/tensorflow
@@ -0,0 +1,34 @@
+package(
+ default_visibility = ["//visibility:public"],
+ features = ["-layering_check"],
+)
+
+cc_library(
+ name = "stateless_component",
+ srcs = ["stateless_component.cc"],
+ deps = [
+ "//dragnn/core:component_registry",
+ "//dragnn/core/interfaces:component",
+ "//dragnn/core/interfaces:transition_state",
+ "//dragnn/io:sentence_input_batch",
+ "//dragnn/protos:data_proto",
+ "//syntaxnet:base",
+ ],
+ alwayslink = 1,
+)
+
+cc_test(
+ name = "stateless_component_test",
+ srcs = ["stateless_component_test.cc"],
+ deps = [
+ ":stateless_component",
+ "//dragnn/core:component_registry",
+ "//dragnn/core:input_batch_cache",
+ "//dragnn/core/test:generic",
+ "//dragnn/core/test:mock_transition_state",
+ "//dragnn/io:sentence_input_batch",
+ "//syntaxnet:base",
+ "//syntaxnet:sentence_proto",
+ "//syntaxnet:test_main",
+ ],
+)
@@ -0,0 +1,131 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "dragnn/core/component_registry.h"
+#include "dragnn/core/interfaces/component.h"
+#include "dragnn/core/interfaces/transition_state.h"
+#include "dragnn/io/sentence_input_batch.h"
+#include "dragnn/protos/data.pb.h"
+#include "syntaxnet/base.h"
+
+namespace syntaxnet {
+namespace dragnn {
+namespace {
+
+// A component that does not create its own transition states; instead, it
+// simply forwards the states of the previous component. Does not support all
+// methods. Intended for "compute-only" bulk components that only use linked
+// features, which use only a small subset of DRAGNN functionality.
+class StatelessComponent : public Component {
+ public:
+ void InitializeComponent(const ComponentSpec &spec) override {
+ name_ = spec.name();
+ }
+
+ // Stores the |parent_states| for forwarding to downstream components.
+ void InitializeData(
+ const std::vector<std::vector<const TransitionState *>> &parent_states,
+ int max_beam_size, InputBatchCache *input_data) override {
+ // Must use SentenceInputBatch to match SyntaxNetComponent.
+ batch_size_ = input_data->GetAs<SentenceInputBatch>()->data()->size();
+ beam_size_ = max_beam_size;
+ parent_states_ = parent_states;
+
+ // The beam should be wide enough for the previous component.
+ for (const auto &beam : parent_states) {
+ CHECK_LE(beam.size(), beam_size_);
+ }
+ }
+
+ // Forwards the states of the previous component.
+ std::vector<std::vector<const TransitionState *>> GetBeam() override {
+ return parent_states_;
+ }
+
+ // Forwards the |current_index| to the previous component.
+ int GetSourceBeamIndex(int current_index, int batch) const override {
+ return current_index;
+ }
+
+ string Name() const override { return name_; }
+ int BeamSize() const override { return beam_size_; }
+ int BatchSize() const override { return batch_size_; }
+ int StepsTaken(int batch_index) const override { return 0; }
+ bool IsReady() const override { return true; }
+ bool IsTerminal() const override { return true; }
+ void FinalizeData() override {}
+ void ResetComponent() override {}
+ void InitializeTracing() override {}
+ void DisableTracing() override {}
+ std::vector<std::vector<ComponentTrace>> GetTraceProtos() const override {
+ return {};
+ }
+
+ // Unsupported methods.
+ int GetBeamIndexAtStep(int step, int current_index,
+ int batch) const override {
+ LOG(FATAL) << "[" << name_ << "] Method not supported";
+ return 0;
+ }
+ std::function<int(int, int, int)> GetStepLookupFunction(
+ const string &method) override {
+ LOG(FATAL) << "[" << name_ << "] Method not supported";
+ return nullptr;
+ }
+ void AdvanceFromPrediction(const float transition_matrix[],
+ int matrix_length) override {
+ LOG(FATAL) << "[" << name_ << "] Method not supported";
+ }
+ void AdvanceFromOracle() override {
+ LOG(FATAL) << "[" << name_ << "] Method not supported";
+ }
+ std::vector<std::vector<int>> GetOracleLabels() const override {
+ LOG(FATAL) << "[" << name_ << "] Method not supported";
+ return {};
+ }
+ int GetFixedFeatures(std::function<int32 *(int)> allocate_indices,
+ std::function<int64 *(int)> allocate_ids,
+ std::function<float *(int)> allocate_weights,
+ int channel_id) const override {
+ LOG(FATAL) << "[" << name_ << "] Method not supported";
+ return 0;
+ }
+ int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override {
+ LOG(FATAL) << "[" << name_ << "] Method not supported";
+ return 0;
+ }
+ std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override {
+ LOG(FATAL) << "[" << name_ << "] Method not supported";
+ return {};
+ }
+ void AddTranslatedLinkFeaturesToTrace(
+ const std::vector<LinkFeatures> &features, int channel_id) override {
+ LOG(FATAL) << "[" << name_ << "] Method not supported";
+ }
+
+ private:
+ string name_; // component name
+ int batch_size_ = 1; // number of sentences in current batch
+ int beam_size_ = 1; // maximum beam size
+
+ // Parent states passed to InitializeData(), and passed along in GetBeam().
+ std::vector<std::vector<const TransitionState *>> parent_states_;
+};
+
+REGISTER_DRAGNN_COMPONENT(StatelessComponent);
+
+} // namespace
+} // namespace dragnn
+} // namespace syntaxnet
@@ -0,0 +1,171 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "dragnn/core/component_registry.h"
+#include "dragnn/core/input_batch_cache.h"
+#include "dragnn/core/test/generic.h"
+#include "dragnn/core/test/mock_transition_state.h"
+#include "dragnn/io/sentence_input_batch.h"
+#include "syntaxnet/base.h"
+#include "syntaxnet/sentence.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace syntaxnet {
+namespace dragnn {
+namespace {
+
+const char kSentence0[] = R"(
+token {
+ word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
+ break_level: NO_BREAK
+}
+token {
+ word: "0" start: 9 end: 9 head: 0 tag: "CD" category: "NUM" label: "num"
+ break_level: SPACE_BREAK
+}
+token {
+ word: "." start: 10 end: 10 head: 0 tag: "." category: "." label: "punct"
+ break_level: NO_BREAK
+}
+)";
+
+const char kSentence1[] = R"(
+token {
+ word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
+ break_level: NO_BREAK
+}
+token {
+ word: "1" start: 9 end: 9 head: 0 tag: "CD" category: "NUM" label: "num"
+ break_level: SPACE_BREAK
+}
+token {
+ word: "." start: 10 end: 10 head: 0 tag: "." category: "." label: "punct"
+ break_level: NO_BREAK
+}
+)";
+
+const char kLongSentence[] = R"(
+token {
+ word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
+ break_level: NO_BREAK
+}
+token {
+ word: "1" start: 9 end: 9 head: 0 tag: "CD" category: "NUM" label: "num"
+ break_level: SPACE_BREAK
+}
+token {
+ word: "2" start: 10 end: 10 head: 0 tag: "CD" category: "NUM" label: "num"
+ break_level: SPACE_BREAK
+}
+token {
+ word: "3" start: 11 end: 11 head: 0 tag: "CD" category: "NUM" label: "num"
+ break_level: SPACE_BREAK
+}
+token {
+ word: "." start: 12 end: 12 head: 0 tag: "." category: "." label: "punct"
+ break_level: NO_BREAK
+}
+)";
+
+const char kMasterSpec[] = R"(
+component {
+ name: "test"
+ transition_system {
+ registered_name: "shift-only"
+ }
+ linked_feature {
+ name: "prev"
+ fml: "input.focus"
+ embedding_dim: 32
+ size: 1
+ source_component: "prev"
+ source_translator: "identity"
+ source_layer: "last_layer"
+ }
+ backend {
+ registered_name: "StatelessComponent"
+ }
+}
+)";
+
+} // namespace
+
+using testing::Return;
+
+class StatelessComponentTest : public ::testing::Test {
+ public:
+ std::unique_ptr<Component> CreateParser(
+ int beam_size,
+ const std::vector<std::vector<const TransitionState *>> &states,
+ const std::vector<string> &data) {
+ MasterSpec master_spec;
+ CHECK(TextFormat::ParseFromString(kMasterSpec, &master_spec));
+ data_.reset(new InputBatchCache(data));
+
+ // Create a parser component with the specified beam size.
+ std::unique_ptr<Component> parser_component(
+ Component::Create("StatelessComponent"));
+ parser_component->InitializeComponent(master_spec.component(0));
+ parser_component->InitializeData(states, beam_size, data_.get());
+ return parser_component;
+ }
+
+ std::unique_ptr<InputBatchCache> data_;
+};
+
+TEST_F(StatelessComponentTest, ForwardsTransitionStates) {
+ const MockTransitionState mock_state_1, mock_state_2, mock_state_3;
+ const std::vector<std::vector<const TransitionState *>> parent_states = {
+ {}, {&mock_state_1}, {&mock_state_2, &mock_state_3}};
+
+ std::vector<string> data;
+ for (const string &textproto : {kSentence0, kSentence1, kLongSentence}) {
+ Sentence sentence;
+ CHECK(TextFormat::ParseFromString(textproto, &sentence));
+ data.emplace_back();
+ CHECK(sentence.SerializeToString(&data.back()));
+ }
+ CHECK_EQ(parent_states.size(), data.size());
+
+ const int kBeamSize = 2;
+ auto test_parser = CreateParser(kBeamSize, parent_states, data);
+
+ EXPECT_TRUE(test_parser->IsReady());
+ EXPECT_TRUE(test_parser->IsTerminal());
+ EXPECT_EQ(kBeamSize, test_parser->BeamSize());
+ EXPECT_EQ(data.size(), test_parser->BatchSize());
+ EXPECT_TRUE(test_parser->GetTraceProtos().empty());
+
+ for (int batch_index = 0; batch_index < parent_states.size(); ++batch_index) {
+ EXPECT_EQ(0, test_parser->StepsTaken(batch_index));
+ const auto &beam = parent_states[batch_index];
+ for (int beam_index = 0; beam_index < beam.size(); ++beam_index) {
+ // Expect an identity mapping.
+ EXPECT_EQ(beam_index,
+ test_parser->GetSourceBeamIndex(beam_index, batch_index));
+ }
+ }
+
+ const auto forwarded_states = test_parser->GetBeam();
+ EXPECT_EQ(parent_states, forwarded_states);
+}
+
+} // namespace dragnn
+} // namespace syntaxnet
Oops, something went wrong.

0 comments on commit ea3fa4a

Please sign in to comment.