Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cmd/repl/repl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func testMainFastFail(t *testing.T, interactive bool) {

done := make(chan error)
go func() { done <- cmd.Wait() }()
timeout := time.After(2 * time.Second) // 2s are enough for **fast** fail
timeout := time.After(4 * time.Second) // 4s are enough for **fast** fail

select {
case <-timeout:
Expand Down Expand Up @@ -137,7 +137,7 @@ func TestComplete(t *testing.T) {

p.InsertText(`RAIN `, false, true)
c = s.completer(*p.Document())
a.Equal(11, len(c))
a.Equal(18, len(c))

p.InsertText(`DNN`, false, true)
c = s.completer(*p.Document())
Expand Down
15 changes: 15 additions & 0 deletions pkg/sql/codegen/attribute/attribute.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ package attribute
import (
"encoding/json"
"fmt"
"log"
"os/exec"
"reflect"
"sort"
"strings"
Expand Down Expand Up @@ -158,6 +160,18 @@ func NewDictionaryFromModelDefinition(estimator, prefix string) Dictionary {
// PremadeModelParamsDocs stores parameters and documents of all known models
var PremadeModelParamsDocs map[string]map[string]string

// ExtractDocString extracts parameter documents from python doc strings
func ExtractDocString(module ...string) {
cmd := exec.Command("python", "-uc", fmt.Sprintf("__import__('extract_docstring').print_param_doc('%s')", strings.Join(module, "', '")))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we make a local call here to print the docstring, other than start a process?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's somewhat hard to implement because we have to import the classes to get the docstrings. I have investigated several python bindings of golang but they are either too primitive (go-python) or not as reliable.

output, e := cmd.CombinedOutput()
if e != nil {
log.Println("ExtractDocString failed: ", e, string(output))
}
if e := json.Unmarshal(output, &PremadeModelParamsDocs); e != nil {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a comment that unmarshal can append entries rather than overwrite it: https://golang.org/pkg/encoding/json/#Unmarshal

log.Println("ExtractDocString failed:", e, string(output))
}
}

func removeUnnecessaryParams() {
// The following parameters of canned estimators are already supported in the COLUMN clause.
for _, v := range PremadeModelParamsDocs {
Expand All @@ -171,5 +185,6 @@ func init() {
if err := json.Unmarshal([]byte(ModelParameterJSON), &PremadeModelParamsDocs); err != nil {
panic(err) // assertion
}
ExtractDocString("sqlflow_models")
removeUnnecessaryParams()
}
2 changes: 1 addition & 1 deletion pkg/sql/codegen/attribute/attribute_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func TestDictionaryValidate(t *testing.T) {
func TestPremadeModelParamsDocs(t *testing.T) {
a := assert.New(t)

a.Equal(11, len(PremadeModelParamsDocs))
a.Equal(18, len(PremadeModelParamsDocs))
a.Equal(len(PremadeModelParamsDocs["DNNClassifier"]), 12)
a.NotContains(PremadeModelParamsDocs["DNNClassifier"], "feature_columns")
a.Contains(PremadeModelParamsDocs["DNNClassifier"], "optimizer")
Expand Down
10 changes: 10 additions & 0 deletions python/extract_docstring.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,16 @@ def parse_ctor_args(f, prefix=''):
[' '.join(doc.split()).replace("`", "'") for doc in total[2::2]]))


def print_param_doc(*modules):
param_doc = {} # { "class_names": {"parameters": "splitted docstrings"} }
for module in modules:
models = filter(lambda m: inspect.isclass(m[1]),
inspect.getmembers(__import__(module)))
for name, cls in models:
param_doc[f'{module}.{name}'] = parse_ctor_args(cls, ':param')
print(json.dumps(param_doc))


if __name__ == "__main__":
param_doc = {} # { "class_names": {"parameters": "splitted docstrings"} }

Expand Down
1 change: 1 addition & 0 deletions scripts/test/ipython.sh
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ DATASOURCE="mysql://root:root@tcp(127.0.0.1:3306)/?maxAllowedPacket=0"
export PYTHONPATH=$GOPATH/src/sqlflow.org/sqlflow/python

sqlflowserver &
sleep 10
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to change this sleep to a while loop that pings the SQLFlow server?

# e2e test for standard SQL
SQLFLOW_DATASOURCE=${DATASOURCE} SQLFLOW_SERVER=localhost:50051 ipython python/test_magic.py
# TODO(yi): Re-enable the end-to-end test of Ant XGBoost after accelerating Travis CI.
Expand Down