From 590852ae5fb2416b5b93f62eb5d4bd8a8d6ec4f9 Mon Sep 17 00:00:00 2001 From: shendiaomo Date: Wed, 15 Jan 2020 18:09:57 +0800 Subject: [PATCH 1/4] Add doc string for models in sqlflow_models --- pkg/sql/codegen/attribute/attribute.go | 15 +++++++++++++++ python/extract_docstring.py | 10 ++++++++++ 2 files changed, 25 insertions(+) diff --git a/pkg/sql/codegen/attribute/attribute.go b/pkg/sql/codegen/attribute/attribute.go index 93a90a9b6f..353f043100 100644 --- a/pkg/sql/codegen/attribute/attribute.go +++ b/pkg/sql/codegen/attribute/attribute.go @@ -16,6 +16,8 @@ package attribute import ( "encoding/json" "fmt" + "log" + "os/exec" "reflect" "sort" "strings" @@ -158,6 +160,18 @@ func NewDictionaryFromModelDefinition(estimator, prefix string) Dictionary { // PremadeModelParamsDocs stores parameters and documents of all known models var PremadeModelParamsDocs map[string]map[string]string +// ExtractDocString extracts parameter documents from python doc strings +func ExtractDocString(module ...string) { + cmd := exec.Command("python", "-uc", fmt.Sprintf("__import__('extract_docstring').print_param_doc('%s')", strings.Join(module, "', '"))) + output, e := cmd.CombinedOutput() + if e != nil { + log.Println("ExtractDocString failed: ", e, string(output)) + } + if e := json.Unmarshal(output, &PremadeModelParamsDocs); e != nil { + log.Println("ExtractDocString failed:", e, string(output)) + } +} + func removeUnnecessaryParams() { // The following parameters of canned estimators are already supported in the COLUMN clause. 
for _, v := range PremadeModelParamsDocs { @@ -171,5 +185,6 @@ func init() { if err := json.Unmarshal([]byte(ModelParameterJSON), &PremadeModelParamsDocs); err != nil { panic(err) // assertion } + ExtractDocString("sqlflow_models") removeUnnecessaryParams() } diff --git a/python/extract_docstring.py b/python/extract_docstring.py index bbe50077fa..305acd3d43 100644 --- a/python/extract_docstring.py +++ b/python/extract_docstring.py @@ -74,6 +74,16 @@ def parse_ctor_args(f, prefix=''): [' '.join(doc.split()).replace("`", "'") for doc in total[2::2]])) +def print_param_doc(*modules): + param_doc = {} # { "class_names": {"parameters": "splitted docstrings"} } + for module in modules: + models = filter(lambda m: inspect.isclass(m[1]), + inspect.getmembers(__import__(module))) + for name, cls in models: + param_doc[f'{module}.{name}'] = parse_ctor_args(cls, ':param') + print(json.dumps(param_doc)) + + if __name__ == "__main__": param_doc = {} # { "class_names": {"parameters": "splitted docstrings"} } From 31b5c93280628b55d9eb9262612c38c1217cc115 Mon Sep 17 00:00:00 2001 From: shendiaomo Date: Wed, 15 Jan 2020 19:19:43 +0800 Subject: [PATCH 2/4] Fix unit tests --- cmd/repl/repl_test.go | 4 ++-- pkg/sql/codegen/attribute/attribute_test.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cmd/repl/repl_test.go b/cmd/repl/repl_test.go index 7b56ec3cd9..4fa81fb170 100644 --- a/cmd/repl/repl_test.go +++ b/cmd/repl/repl_test.go @@ -49,7 +49,7 @@ func testMainFastFail(t *testing.T, interactive bool) { done := make(chan error) go func() { done <- cmd.Wait() }() - timeout := time.After(2 * time.Second) // 2s are enough for **fast** fail + timeout := time.After(4 * time.Second) // 4s are enough for **fast** fail select { case <-timeout: @@ -137,7 +137,7 @@ func TestComplete(t *testing.T) { p.InsertText(`RAIN `, false, true) c = s.completer(*p.Document()) - a.Equal(11, len(c)) + a.Equal(18, len(c)) p.InsertText(`DNN`, false, true) c = s.completer(*p.Document()) 
diff --git a/pkg/sql/codegen/attribute/attribute_test.go b/pkg/sql/codegen/attribute/attribute_test.go index 4a2f12e58e..37d84abc54 100644 --- a/pkg/sql/codegen/attribute/attribute_test.go +++ b/pkg/sql/codegen/attribute/attribute_test.go @@ -43,7 +43,7 @@ func TestDictionaryValidate(t *testing.T) { func TestPremadeModelParamsDocs(t *testing.T) { a := assert.New(t) - a.Equal(11, len(PremadeModelParamsDocs)) + a.Equal(18, len(PremadeModelParamsDocs)) a.Equal(len(PremadeModelParamsDocs["DNNClassifier"]), 12) a.NotContains(PremadeModelParamsDocs["DNNClassifier"], "feature_columns") a.Contains(PremadeModelParamsDocs["DNNClassifier"], "optimizer") From 8ec5ea52fb89eb33b0853cb68256c82c25fb1075 Mon Sep 17 00:00:00 2001 From: shendiaomo Date: Wed, 15 Jan 2020 21:30:41 +0800 Subject: [PATCH 3/4] Wait until TensorFlow is imported to fix unit test --- scripts/test/ipython.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/test/ipython.sh b/scripts/test/ipython.sh index 617cf0c603..e4794d496f 100644 --- a/scripts/test/ipython.sh +++ b/scripts/test/ipython.sh @@ -45,6 +45,7 @@ DATASOURCE="mysql://root:root@tcp(127.0.0.1:3306)/?maxAllowedPacket=0" export PYTHONPATH=$GOPATH/src/sqlflow.org/sqlflow/python sqlflowserver & +sleep 4 # e2e test for standard SQL SQLFLOW_DATASOURCE=${DATASOURCE} SQLFLOW_SERVER=localhost:50051 ipython python/test_magic.py # TODO(yi): Re-enable the end-to-end test of Ant XGBoost after accelerating Travis CI.
From d7edb92ac018412c3b8f7c4f3b939fa8acc17ee5 Mon Sep 17 00:00:00 2001 From: shendiaomo Date: Wed, 15 Jan 2020 22:46:54 +0800 Subject: [PATCH 4/4] Wait longer --- scripts/test/ipython.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/test/ipython.sh b/scripts/test/ipython.sh index e4794d496f..af7164b2c4 100644 --- a/scripts/test/ipython.sh +++ b/scripts/test/ipython.sh @@ -45,7 +45,7 @@ DATASOURCE="mysql://root:root@tcp(127.0.0.1:3306)/?maxAllowedPacket=0" export PYTHONPATH=$GOPATH/src/sqlflow.org/sqlflow/python sqlflowserver & -sleep 4 +sleep 10 # e2e test for standard SQL SQLFLOW_DATASOURCE=${DATASOURCE} SQLFLOW_SERVER=localhost:50051 ipython python/test_magic.py # TODO(yi): Re-enable the end-to-end test of Ant XGBoost after accelerating Travis CI.