Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…nk_for_doc
  • Loading branch information
lhw362950217 committed Oct 12, 2020
2 parents fa80433 + 67622e4 commit 0625398
Show file tree
Hide file tree
Showing 24 changed files with 430 additions and 89 deletions.
9 changes: 9 additions & 0 deletions docker/dev/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,15 @@ echo "Generate Python protobuf ..."
protoc --python_out=$SQLFLOWPATH/python/runtime/dbapi/table_writer \
-I $SQLFLOWPATH/go/proto/ sqlflow.proto

python -m grpc_tools.protoc \
--python_out=$SQLFLOWPATH/python/runtime/model/ \
--grpc_python_out=$SQLFLOWPATH/python/runtime/model/ \
-I $SQLFLOWPATH/go/proto modelzooserver.proto

# A workaround for the issue: https://github.com/protocolbuffers/protobuf/issues/1491
sed -i 's/import modelzooserver_pb2/from . import modelzooserver_pb2/g' \
$SQLFLOWPATH/python/runtime/model/modelzooserver_pb2_grpc.py

echo "Build model zoo ..."
cd $SQLFLOW_BIN
if [[ ! -d models ]]; then
Expand Down
4 changes: 3 additions & 1 deletion docker/dev/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ python -m pip install --quiet --upgrade pip setuptools six
echo "Install pip packages ..."
PRE_COMMIT="pre-commit==1.18.3"
PY_TEST="pytest==5.3.0 pytest-cov"
GRPC_PACKAGES="grpcio==1.28.1 grpcio-tools==1.28.1"
JS_LINTER=jsbeautifier
PYTHON_LINTER="yapf isort<5,>=4.2.5 pylint>=2.5.3 flake8"
WHEEL="wheel"
Expand All @@ -69,7 +70,8 @@ python -m pip install --quiet \
$PRE_COMMIT \
$PY_TEST \
$JS_LINTER \
$PYTHON_LINTER
$PYTHON_LINTER \
$GRPC_PACKAGES
rm -rf "$HOME"/.cache/pip/*


Expand Down
3 changes: 2 additions & 1 deletion docker/step/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ RUN bash -c 'pip install --no-cache-dir --prefix=/install \
sklearn2pmml==0.56.0 \
shap==0.30.1 \
PyUtilib==5.8.0 \
pyomo==5.6.9'
pyomo==5.6.9 \
grpcio==1.28.1'

RUN py3clean /install /usr/lib/python3.6

Expand Down
60 changes: 54 additions & 6 deletions go/codegen/experimental/codegen_step.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
package experimental

import (
"context"
"fmt"
"google.golang.org/grpc"
"net/url"
"os"
"strconv"
Expand Down Expand Up @@ -142,21 +144,67 @@ func isXGBoostEstimator(estimator string) bool {
return strings.HasPrefix(strings.ToUpper(estimator), "XGBOOST.")
}

type metadata simplejson.Json
// Metadata represents the metadata of a trained model
type Metadata simplejson.Json

func (m *metadata) imageName() string {
func (m *Metadata) imageName() string {
return (*simplejson.Json)(m).Get("model_repo_image").MustString()
}

func getModelMetadata(session *pb.Session, table string) (*metadata, error) {
func getModelMetadata(session *pb.Session, table string) (*Metadata, error) {
submitter := getSubmitter(session)
if submitter == "local" {
return getModelMetadataFromDB(session.DbConnStr, table)
modelZooAddr, table, tag := decomposeModelName(table)
if modelZooAddr != "" {
return getModelMetadataFromModelZoo(modelZooAddr, table, tag)
}
return GetModelMetadataFromDB(session.DbConnStr, table)
}
return nil, fmt.Errorf("not supported submitter %s", submitter)
}

func getModelMetadataFromDB(dbConnStr, table string) (*metadata, error) {
func decomposeModelName(modelName string) (string, string, string) {
idx := strings.LastIndex(modelName, "/")
if idx < 0 {
return "", modelName, ""
}

address := modelName[0:idx]
modelName = modelName[idx+1:]
idx = strings.LastIndex(modelName, ":")
tag := ""
if idx >= 0 {
tag = modelName[idx+1:]
modelName = modelName[0:idx]
}
return address, modelName, tag
}

func getModelMetadataFromModelZoo(addr, table, tag string) (*Metadata, error) {
conn, err := grpc.Dial(addr, grpc.WithInsecure())
if err != nil {
return nil, err
}
defer conn.Close()

client := pb.NewModelZooServerClient(conn)
req := &pb.ReleaseModelRequest{
Name: table,
Tag: tag,
}
resp, err := client.GetModelMeta(context.Background(), req)
if err != nil {
return nil, fmt.Errorf("error is from: %v %s", err, req.Name)
}
json, err := simplejson.NewJson([]byte(resp.Meta))
if err != nil {
return nil, err
}
return (*Metadata)(json), nil
}

// GetModelMetadataFromDB gets model Metadata from DBMS
func GetModelMetadataFromDB(dbConnStr, table string) (*Metadata, error) {
db, err := database.OpenAndConnectDB(dbConnStr)
if err != nil {
return nil, err
Expand Down Expand Up @@ -191,7 +239,7 @@ func getModelMetadataFromDB(dbConnStr, table string) (*metadata, error) {
if err != nil {
return nil, err
}
return (*metadata)(json), nil
return (*Metadata)(json), nil
}

func initializeAndCheckAttributes(stmt ir.SQLFlowStmt) error {
Expand Down
2 changes: 1 addition & 1 deletion go/codegen/experimental/codegen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func TestExperimentalXGBCodegen(t *testing.T) {
if err != nil {
t.Errorf("error %s", err)
}
expected := `feature_column_map = {"feature_columns":[runtime.feature.column.NumericColumn(runtime.feature.field_desc.FieldDesc(name="petal_length", dtype=runtime.feature.field_desc.DataType.FLOAT32, delimiter="", format="", shape=[1], is_sparse=False, vocabulary=[]))]}`
expected := `feature_column_map = {"feature_columns":[runtime.feature.column.NumericColumn(runtime.feature.field_desc.FieldDesc(name="petal_length", dtype=runtime.feature.field_desc.DataType.FLOAT32, dtype_weight=runtime.feature.field_desc.DataType.INT64, delimiter="", delimiter_kv="", format="", shape=[1], is_sparse=False, vocabulary=[]))]}`
a.True(strings.Contains(coulerCode, expected))
}

Expand Down
49 changes: 28 additions & 21 deletions go/codegen/experimental/parse_to_ir.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,36 @@ import (
pb "sqlflow.org/sqlflow/go/proto"
)

// GenerateIRStatement generates IR statement from parser.SQLFlowStmt
func GenerateIRStatement(sql *parser.SQLFlowStmt, session *pb.Session) (ir.SQLFlowStmt, error) {
var r ir.SQLFlowStmt
var err error
if sql.IsExtendedSyntax() {
if sql.Train {
r, err = ir.GenerateTrainStmt(sql.SQLFlowSelectStmt)
} else if sql.ShowTrain {
r, err = ir.GenerateShowTrainStmt(sql.SQLFlowSelectStmt)
} else if sql.Explain {
r, err = ir.GenerateExplainStmt(sql.SQLFlowSelectStmt, session.DbConnStr, "", "", false)
} else if sql.Predict {
r, err = ir.GeneratePredictStmt(sql.SQLFlowSelectStmt, session.DbConnStr, "", "", false)
} else if sql.Evaluate {
r, err = ir.GenerateEvaluateStmt(sql.SQLFlowSelectStmt, session.DbConnStr, "", "", false)
} else if sql.Optimize {
r, err = ir.GenerateOptimizeStmt(sql.SQLFlowSelectStmt)
} else if sql.Run {
r, err = ir.GenerateRunStmt(sql.SQLFlowSelectStmt)
}
} else {
standardSQL := ir.NormalStmt(sql.Original)
r = &standardSQL
}
return r, err
}

// parseToIR parse the sql program to generate a list of IR.
func parseToIR(sqlProgram string, session *pb.Session) ([]ir.SQLFlowStmt, error) {
var dbDriver string
var r ir.SQLFlowStmt
var result []ir.SQLFlowStmt

sqlProgram, err := parser.RemoveCommentInSQLStatement(sqlProgram)
Expand All @@ -45,26 +71,7 @@ func parseToIR(sqlProgram string, session *pb.Session) ([]ir.SQLFlowStmt, error)
}
sqls := rewriteStatementsWithHints(stmts, dbDriver)
for _, sql := range sqls {
if sql.IsExtendedSyntax() {
if sql.Train {
r, err = ir.GenerateTrainStmt(sql.SQLFlowSelectStmt)
} else if sql.ShowTrain {
r, err = ir.GenerateShowTrainStmt(sql.SQLFlowSelectStmt)
} else if sql.Explain {
r, err = ir.GenerateExplainStmt(sql.SQLFlowSelectStmt, session.DbConnStr, "", "", false)
} else if sql.Predict {
r, err = ir.GeneratePredictStmt(sql.SQLFlowSelectStmt, session.DbConnStr, "", "", false)
} else if sql.Evaluate {
r, err = ir.GenerateEvaluateStmt(sql.SQLFlowSelectStmt, session.DbConnStr, "", "", false)
} else if sql.Optimize {
r, err = ir.GenerateOptimizeStmt(sql.SQLFlowSelectStmt)
} else if sql.Run {
r, err = ir.GenerateRunStmt(sql.SQLFlowSelectStmt)
}
} else {
standardSQL := ir.NormalStmt(sql.Original)
r = &standardSQL
}
r, err := GenerateIRStatement(sql, session)
if err != nil {
return nil, err
}
Expand Down
28 changes: 26 additions & 2 deletions go/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ type pythonExecutor struct {
Session *pb.Session
}

func useExperimentalExecutor(exec Executor) (bool, error) {
// UseExperimentalExecutor returns whether to use the experimental codegen
func UseExperimentalExecutor(exec Executor) (bool, error) {
if os.Getenv("SQLFLOW_USE_EXPERIMENTAL_CODEGEN") != "true" {
return false, nil
}
Expand All @@ -188,7 +189,7 @@ func useExperimentalExecutor(exec Executor) (bool, error) {
}

func (s *pythonExecutor) tryExperimentalExecute(sqlStmt ir.SQLFlowStmt, logStderr bool) (bool, error) {
ok, err := useExperimentalExecutor(s)
ok, err := UseExperimentalExecutor(s)
if err != nil {
return true, err
}
Expand All @@ -198,6 +199,10 @@ func (s *pythonExecutor) tryExperimentalExecute(sqlStmt ir.SQLFlowStmt, logStder

// NOTE(sneaxiy): should use the image here
stepCode, _, err := experimental.GenerateStepCodeAndImage(sqlStmt, 0, s.Session, nil)
if err != nil {
return true, err
}

stepFuncCode, err := experimental.GetPyFuncBody(stepCode, "step_entry_0")
if err != nil {
return true, err
Expand Down Expand Up @@ -283,6 +288,9 @@ func (s *pythonExecutor) ExecuteQuery(stmt *ir.NormalStmt) error {
}

func (s *pythonExecutor) ExecuteTrain(cl *ir.TrainStmt) (e error) {
if ok, err := s.tryExperimentalExecute(cl, false); ok {
return err
}
var code string
if cl.GetModelKind() == ir.XGBoost {
if code, e = xgboost.Train(cl, s.Session); e != nil {
Expand All @@ -300,6 +308,9 @@ func (s *pythonExecutor) ExecuteTrain(cl *ir.TrainStmt) (e error) {
}

func (s *pythonExecutor) ExecutePredict(cl *ir.PredictStmt) (e error) {
if ok, err := s.tryExperimentalExecute(cl, false); ok {
return err
}
// NOTE(typhoonzero): model is already loaded under s.Cwd
if e = createPredictionResultTable(cl, s.Db, s.Session); e != nil {
return e
Expand All @@ -319,6 +330,9 @@ func (s *pythonExecutor) ExecutePredict(cl *ir.PredictStmt) (e error) {
}

func (s *pythonExecutor) ExecuteExplain(cl *ir.ExplainStmt) error {
if ok, err := s.tryExperimentalExecute(cl, false); ok {
return err
}
// NOTE(typhoonzero): model is already loaded under s.Cwd
var code string
var err error
Expand Down Expand Up @@ -365,6 +379,9 @@ func (s *pythonExecutor) ExecuteExplain(cl *ir.ExplainStmt) error {
}

func (s *pythonExecutor) ExecuteEvaluate(cl *ir.EvaluateStmt) error {
if ok, err := s.tryExperimentalExecute(cl, false); ok {
return err
}
// NOTE(typhoonzero): model is already loaded under s.Cwd
var code string
var err error
Expand Down Expand Up @@ -406,6 +423,9 @@ func (s *pythonExecutor) ExecuteEvaluate(cl *ir.EvaluateStmt) error {
}

func (s *pythonExecutor) ExecuteOptimize(stmt *ir.OptimizeStmt) error {
if ok, err := s.tryExperimentalExecute(stmt, false); ok {
return err
}
db, err := database.OpenAndConnectDB(s.Session.DbConnStr)
if err != nil {
return err
Expand Down Expand Up @@ -571,6 +591,10 @@ func readExplainResult(target string) (string, error) {
func (s *pythonExecutor) GetTrainStmtFromModel() bool { return true }

func (s *pythonExecutor) ExecuteShowTrain(showTrain *ir.ShowTrainStmt) error {
if ok, err := s.tryExperimentalExecute(showTrain, false); ok {
return err
}

model, err := model.Load(showTrain.ModelName, "", s.Db)
if err != nil {
s.Writer.Write("Load model meta " + showTrain.ModelName + " failed.")
Expand Down
4 changes: 3 additions & 1 deletion go/ir/feature_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,12 @@ func (fd *FieldDesc) GenPythonCode() string {
}

// pass format = "" to let runtime feature derivation to fill it in.
return fmt.Sprintf(`runtime.feature.field_desc.FieldDesc(name="%s", dtype=runtime.feature.field_desc.DataType.%s, delimiter="%s", format="", shape=%s, is_sparse=%s, vocabulary=%s)`,
return fmt.Sprintf(`runtime.feature.field_desc.FieldDesc(name="%s", dtype=runtime.feature.field_desc.DataType.%s, dtype_weight=runtime.feature.field_desc.DataType.%s, delimiter="%s", delimiter_kv="%s", format="", shape=%s, is_sparse=%s, vocabulary=%s)`,
fd.Name,
strings.ToUpper(DTypeToString(fd.DType)),
strings.ToUpper(DTypeToString(fd.DTypeWeight)),
fd.Delimiter,
fd.DelimiterKV,
shapeStr,
isSparseStr,
AttrToPythonValue(vocabList),
Expand Down
2 changes: 1 addition & 1 deletion go/ir/feature_column_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestFeatureColumnGenPythonCode(t *testing.T) {
DType: 0,
},
}
a.Equal("runtime.feature.column.NumericColumn(runtime.feature.field_desc.FieldDesc(name=\"testcol\", dtype=runtime.feature.field_desc.DataType.INT64, delimiter=\"\", format=\"\", shape=[10], is_sparse=False, vocabulary=[]))",
a.Equal(`runtime.feature.column.NumericColumn(runtime.feature.field_desc.FieldDesc(name="testcol", dtype=runtime.feature.field_desc.DataType.INT64, dtype_weight=runtime.feature.field_desc.DataType.INT64, delimiter="", delimiter_kv="", format="", shape=[10], is_sparse=False, vocabulary=[]))`,
nc.GenPythonCode())

emd := &EmbeddingColumn{
Expand Down
39 changes: 39 additions & 0 deletions go/model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,45 @@ func (m *Model) saveDB(connStr, table string, session *pb.Session) (e error) {
return nil
}

// SaveDBExperimental save the model to database with metadata using the refactored format.
func (m *Model) SaveDBExperimental(connStr, table string, session *pb.Session) (e error) {
db, err := database.OpenAndConnectDB(connStr)
if err != nil {
return err
}
defer db.Close()

sqlf, e := sqlfs.Create(db, table, session)
if e != nil {
return fmt.Errorf("cannot create sqlfs file %s: %v", table, e)
}
defer sqlf.Close()

metaJSONStr, err := m.Meta.Encode()
if err != nil {
return err
}
metaLen := len(metaJSONStr)
metaLenHex := fmt.Sprintf("0x%08x", metaLen)
sqlf.Write([]byte(metaLenHex))
sqlf.Write([]byte(metaJSONStr))

// model and its metadata are both zipped into a tarball
cmd := exec.Command("tar", "czf", "-", "-C", m.workDir, ".")
cmd.Stdout = sqlf
var errBuf bytes.Buffer
cmd.Stderr = &errBuf

if e := cmd.Run(); e != nil {
return fmt.Errorf("tar stderr: %v\ntar cmd %v", errBuf.String(), e)
}

if e := sqlf.Close(); e != nil {
return fmt.Errorf("close sqlfs error: %v", e)
}
return nil
}

func (m *Model) saveTar(modelDir, save string) (string, error) {
save = strings.TrimSuffix(save, ".tar.gz")
modelFile := filepath.Join(modelDir, save+".tar.gz")
Expand Down
Loading

0 comments on commit 0625398

Please sign in to comment.