From 258e7a05754f6b32ec97685848ea36c64ae339c5 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Thu, 20 Aug 2020 08:13:40 +0800 Subject: [PATCH 1/6] fix explain into when shap values is not list --- python/runtime/xgboost/explain.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/python/runtime/xgboost/explain.py b/python/runtime/xgboost/explain.py index 609e0d63c9..9d2dcc234c 100644 --- a/python/runtime/xgboost/explain.py +++ b/python/runtime/xgboost/explain.py @@ -180,15 +180,18 @@ def explain(datasource, if is_pai: from runtime.dbapi.paiio import PaiIOConnection conn = PaiIOConnection.from_table(result_table) - # TODO(typhoonzero): the shape of shap_values is - # (3, num_samples, num_features), use the first - # dimension here, should find out how to use - # the other two. else: conn = db.connect_with_data_source(datasource) - - write_shap_values(shap_values[0], conn, result_table, - feature_column_names) + # TODO(typhoonzero): the shap_values is may be a + # list of shape [3, num_samples, num_features], + # use the first dimension here, should find out + # when to use the other two. When shap_values is + # not a list it can be directly used. + if isinstance(shap_values, list): + to_write = shap_values[0] + else: + to_write = shap_values + write_shap_values(to_write, conn, result_table, feature_column_names) return if summary_params.get("plot_type") == "decision": From f3a3c9dbf2162aeaeeb9de7f98b73f4e00be67d2 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Thu, 20 Aug 2020 11:23:30 +0800 Subject: [PATCH 2/6] update --- doc/datasets/popularize_titanic.sql | 8 +++--- go/cmd/sqlflowserver/e2e_common_cases.go | 24 ++++++++++++++++ go/cmd/sqlflowserver/e2e_mysql_test.go | 1 + .../sqlflowserver/e2e_pai_maxcompute_test.go | 28 +++++++------------ go/codegen/xgboost/codegen_explain.go | 1 + go/codegen/xgboost/template_explain.go | 6 ++-- go/executor/executor.go | 16 +++++++---- go/executor/pai.go | 9 ++++-- python/runtime/xgboost/explain.py | 3 -- 9 files changed, 61 insertions(+), 35 deletions(-) diff --git a/doc/datasets/popularize_titanic.sql b/doc/datasets/popularize_titanic.sql index fd90a4a1af..b9c8f29ee8 100644 --- a/doc/datasets/popularize_titanic.sql +++ b/doc/datasets/popularize_titanic.sql @@ -34,8 +34,8 @@ cabinalpha int, family int, isalone int, ismother int, -age float, -realfare float, +age VARCHAR(255), +realfare VARCHAR(255), survived int ); @@ -954,8 +954,8 @@ cabinalpha int, family int, isalone int, ismother int, -age float, -realfare float, +age VARCHAR(255), +realfare VARCHAR(255), survived int ); diff --git a/go/cmd/sqlflowserver/e2e_common_cases.go b/go/cmd/sqlflowserver/e2e_common_cases.go index bf22edf78f..177bfdab94 100644 --- a/go/cmd/sqlflowserver/e2e_common_cases.go +++ b/go/cmd/sqlflowserver/e2e_common_cases.go @@ -286,6 +286,30 @@ INTO sqlflow_models.my_xgb_regression_model_eval_result; } } +func casePredictXGBoostExplain(t *testing.T) { + a := assert.New(t) + trainSQL := fmt.Sprintf(`SELECT * FROM titanic.train +TO TRAIN xgboost.gbtree +WITH + objective="binary:logistic", + train.num_boost_round = 30, + eta = 0.4 +LABEL survived +INTO sqlflow_models.titanic_explain;`) + _, _, _, err := connectAndRunSQL(trainSQL) + if err != nil { + a.Fail("run predSQL error: %v", err) + } + explainSQL := fmt.Sprintf(`SELECT * FROM titanic.train +TO EXPLAIN sqlflow_models.titanic_explain +WITH label_col=survived, summary.sort=True +INTO titanic.titanic_explain_result;`) + _, _, _, err = connectAndRunSQL(explainSQL) + if err != nil { + a.Fail("run predSQL error: %v", err) + } +} + func casePredictXGBoostRegression(t *testing.T) { a := assert.New(t) predSQL := fmt.Sprintf(`SELECT * diff --git a/go/cmd/sqlflowserver/e2e_mysql_test.go b/go/cmd/sqlflowserver/e2e_mysql_test.go index e497d0c9e3..9b3f55d567 100644 --- a/go/cmd/sqlflowserver/e2e_mysql_test.go +++ b/go/cmd/sqlflowserver/e2e_mysql_test.go @@ -90,6 +90,7 @@ func TestEnd2EndMySQL(t *testing.T) { t.Run("CaseFeatureDerivation", CaseFeatureDerivation) // xgboost cases + t.Run("casePredictXGBoostExplain", casePredictXGBoostExplain) t.Run("caseTrainXGBoostRegressionConvergence", caseTrainXGBoostRegressionConvergence) t.Run("CasePredictXGBoostRegression", casePredictXGBoostRegression) diff --git a/go/cmd/sqlflowserver/e2e_pai_maxcompute_test.go b/go/cmd/sqlflowserver/e2e_pai_maxcompute_test.go index 5b35b280a5..50693427b4 100644 --- a/go/cmd/sqlflowserver/e2e_pai_maxcompute_test.go +++ b/go/cmd/sqlflowserver/e2e_pai_maxcompute_test.go @@ -378,16 +378,16 @@ func CasePAIMaxComputeTrainXGBoost(t *testing.T) { a := assert.New(t) trainSQL := fmt.Sprintf(`SELECT * FROM %s - TO TRAIN xgboost.gbtree - WITH - objective="multi:softprob", - train.num_boost_round = 30, - eta = 0.4, - num_class = 3, - train.batch_size=10, - validation.select="select * from %s" - LABEL class - INTO e2etest_xgb_classi_model;`, caseTrainTable, caseTrainTable) +TO TRAIN xgboost.gbtree +WITH + objective="multi:softprob", + train.num_boost_round = 30, + eta = 0.4, + num_class = 3, + train.batch_size=10, + validation.select="select * from %s" +LABEL class +INTO e2etest_xgb_classi_model;`, caseTrainTable, caseTrainTable) _, _, _, err := connectAndRunSQL(trainSQL) a.NoError(err, "Run trainSQL error.") @@ -404,14 +404,6 @@ LABEL class INTO %s.e2etest_xgb_evaluate_result;`, caseTestTable, caseDB) _, _, _, err = connectAndRunSQL(evalSQL) a.NoError(err, "Run evalSQL error.") - - explainSQL := fmt.Sprintf(`SELECT * FROM %s -TO EXPLAIN e2etest_xgb_classi_model -WITH label_col=class -USING TreeExplainer -INTO %s.e2etest_xgb_explain_result;`, caseTrainTable, caseDB) - _, _, _, err = connectAndRunSQL(explainSQL) - a.NoError(err, "Run explainSQL error.") } func CasePAIMaxComputeTrainCustomModel(t *testing.T) { diff --git a/go/codegen/xgboost/codegen_explain.go b/go/codegen/xgboost/codegen_explain.go index e94bbaf6c3..cb35a1a723 100644 --- a/go/codegen/xgboost/codegen_explain.go +++ b/go/codegen/xgboost/codegen_explain.go @@ -57,6 +57,7 @@ func Explain(explainStmt *ir.ExplainStmt, session *pb.Session) (string, error) { FeatureColumnNames: fs, FeatureColumnCode: featureColumnCode, LabelJSON: string(l), + ResultTable: explainStmt.Into, IsPAI: tf.IsPAI(), PAIExplainTable: explainStmt.TmpExplainTable, } diff --git a/go/codegen/xgboost/template_explain.go b/go/codegen/xgboost/template_explain.go index c63b58be7d..c5f7b62631 100644 --- a/go/codegen/xgboost/template_explain.go +++ b/go/codegen/xgboost/template_explain.go @@ -25,6 +25,7 @@ type explainFiller struct { FeatureColumnNames []string FeatureColumnCode string LabelJSON string + ResultTable string IsPAI bool PAIExplainTable string } @@ -48,10 +49,11 @@ transform_fn = xgboost_extended.feature_column.ComposedColumnTransformer(feature explain( datasource='''{{.DataSource}}''', select='''{{.DatasetSQL}}''', - feature_field_meta=feature_field_meta, - feature_column_names=feature_column_names, + feature_field_meta=feature_field_meta, + feature_column_names=feature_column_names, label_meta=label_meta, summary_params=summary_params, + result_table="{{.ResultTable}}", is_pai="{{.IsPAI}}" == "true", pai_explain_table="{{.PAIExplainTable}}", transform_fn=transform_fn, diff --git a/go/executor/executor.go b/go/executor/executor.go index a155e4e259..beb3ed9dda 100644 --- a/go/executor/executor.go +++ b/go/executor/executor.go @@ -272,16 +272,20 @@ func (s *pythonExecutor) ExecuteExplain(cl *ir.ExplainStmt) error { return err } defer db.Close() + + var modelType int if cl.TrainStmt.GetModelKind() == ir.XGBoost { code, err = xgboost.Explain(cl, s.Session) - // TODO(typhoonzero): deal with XGBoost model explain result table creation. + modelType = pai.ModelTypeXGBoost } else { code, err = tensorflow.Explain(cl, s.Session) - if cl.Into != "" { - err := createExplainResultTable(db, cl, cl.Into, pai.ModelTypeTF, cl.TrainStmt.Estimator) - if err != nil { - return err - } + modelType = pai.ModelTypeTF + } + + if cl.Into != "" { + err := createExplainResultTable(db, cl, cl.Into, modelType, cl.TrainStmt.Estimator) + if err != nil { + return err } } diff --git a/go/executor/pai.go b/go/executor/pai.go index 99b64030c8..fda8bc300a 100644 --- a/go/executor/pai.go +++ b/go/executor/pai.go @@ -23,9 +23,10 @@ import ( "path" "path/filepath" "regexp" - "sqlflow.org/sqlflow/go/verifier" "strings" + "sqlflow.org/sqlflow/go/verifier" + "sqlflow.org/sqlflow/go/codegen/optimize" "github.com/aliyun/aliyun-oss-go-sdk/oss" @@ -636,9 +637,13 @@ func getCreateShapResultSQL(db *database.DB, tableName string, selectStmt string return "", err } columnDefList := []string{} + columnType := "STRING" + if db.DriverName == "mysql" { + columnType = "VARCHAR(255)" + } for _, fieldName := range flds { if fieldName != labelCol { - columnDefList = append(columnDefList, fmt.Sprintf("%s STRING", fieldName)) + columnDefList = append(columnDefList, fmt.Sprintf("%s %s", fieldName, columnType)) } } createStmt := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (%s);`, tableName, strings.Join(columnDefList, ",")) diff --git a/python/runtime/xgboost/explain.py b/python/runtime/xgboost/explain.py index 9d2dcc234c..f6a4ba5332 100644 --- a/python/runtime/xgboost/explain.py +++ b/python/runtime/xgboost/explain.py @@ -173,9 +173,7 @@ def explain(datasource, pai_explain_table, transform_fn=transform_fn, feature_column_code=feature_column_code) - shap_values, shap_interaction_values, expected_value = xgb_shap_values(x) - if result_table != "": if is_pai: from runtime.dbapi.paiio import PaiIOConnection @@ -192,7 +190,6 @@ def explain(datasource, else: to_write = shap_values write_shap_values(to_write, conn, result_table, feature_column_names) - return if summary_params.get("plot_type") == "decision": explainer.plot_and_save( From 7433d4c3ce43b006742812ddec0fbea4416537ac Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Thu, 20 Aug 2020 11:40:16 +0800 Subject: [PATCH 3/6] add e2e test --- go/cmd/sqlflowserver/e2e_pai_maxcompute_test.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/go/cmd/sqlflowserver/e2e_pai_maxcompute_test.go b/go/cmd/sqlflowserver/e2e_pai_maxcompute_test.go index 50693427b4..02e0a50908 100644 --- a/go/cmd/sqlflowserver/e2e_pai_maxcompute_test.go +++ b/go/cmd/sqlflowserver/e2e_pai_maxcompute_test.go @@ -404,6 +404,21 @@ LABEL class INTO %s.e2etest_xgb_evaluate_result;`, caseTestTable, caseDB) _, _, _, err = connectAndRunSQL(evalSQL) a.NoError(err, "Run evalSQL error.") + + titanicTrain := fmt.Sprintf(`SELECT * FROM %s.sqlflow_titanic_train +TO TRAIN xgboost.gbtree +WITH objective="binary:logistic" +LABEL survived +INTO e2etest_xgb_titanic;`, caseDB) + _, _, _, err = connectAndRunSQL(titanicTrain) + a.NoError(err, "Run titanicTrain error.") + + titanicExplain := fmt.Sprintf(`SELECT * FROM %s.sqlflow_titanic_train +TO EXPLAIN e2etest_xgb_titanic +WITH label_col=survived +INTO %s.e2etest_xgb_explain_result_wuyi;`, caseDB, caseDB) + _, _, _, err = connectAndRunSQL(titanicExplain) + a.NoError(err, "Run titanicExplain error.") } func CasePAIMaxComputeTrainCustomModel(t *testing.T) { From 7d37f6ceeb562d5b31298663eb2ce3bf8919590a Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Thu, 20 Aug 2020 13:29:23 +0800 Subject: [PATCH 4/6] update --- go/cmd/sqlflowserver/e2e_common_cases.go | 24 ------------------- go/cmd/sqlflowserver/e2e_mysql_test.go | 1 - .../sqlflowserver/e2e_pai_maxcompute_test.go | 4 ++-- python/runtime/xgboost/dataset.py | 8 ++----- 4 files changed, 4 insertions(+), 33 deletions(-) diff --git a/go/cmd/sqlflowserver/e2e_common_cases.go b/go/cmd/sqlflowserver/e2e_common_cases.go index 177bfdab94..bf22edf78f 100644 --- a/go/cmd/sqlflowserver/e2e_common_cases.go +++ b/go/cmd/sqlflowserver/e2e_common_cases.go @@ -286,30 +286,6 @@ INTO sqlflow_models.my_xgb_regression_model_eval_result; } } -func casePredictXGBoostExplain(t *testing.T) { - a := assert.New(t) - trainSQL := fmt.Sprintf(`SELECT * FROM titanic.train -TO TRAIN xgboost.gbtree -WITH - objective="binary:logistic", - train.num_boost_round = 30, - eta = 0.4 -LABEL survived -INTO sqlflow_models.titanic_explain;`) - _, _, _, err := connectAndRunSQL(trainSQL) - if err != nil { - a.Fail("run predSQL error: %v", err) - } - explainSQL := fmt.Sprintf(`SELECT * FROM titanic.train -TO EXPLAIN sqlflow_models.titanic_explain -WITH label_col=survived, summary.sort=True -INTO titanic.titanic_explain_result;`) - _, _, _, err = connectAndRunSQL(explainSQL) - if err != nil { - a.Fail("run predSQL error: %v", err) - } -} - func casePredictXGBoostRegression(t *testing.T) { a := assert.New(t) predSQL := fmt.Sprintf(`SELECT * diff --git a/go/cmd/sqlflowserver/e2e_mysql_test.go b/go/cmd/sqlflowserver/e2e_mysql_test.go index 9b3f55d567..e497d0c9e3 100644 --- a/go/cmd/sqlflowserver/e2e_mysql_test.go +++ b/go/cmd/sqlflowserver/e2e_mysql_test.go @@ -90,7 +90,6 @@ func TestEnd2EndMySQL(t *testing.T) { t.Run("CaseFeatureDerivation", CaseFeatureDerivation) // xgboost cases - t.Run("casePredictXGBoostExplain", casePredictXGBoostExplain) t.Run("caseTrainXGBoostRegressionConvergence", caseTrainXGBoostRegressionConvergence) t.Run("CasePredictXGBoostRegression", casePredictXGBoostRegression) diff --git a/go/cmd/sqlflowserver/e2e_pai_maxcompute_test.go b/go/cmd/sqlflowserver/e2e_pai_maxcompute_test.go index 02e0a50908..4e560a958c 100644 --- a/go/cmd/sqlflowserver/e2e_pai_maxcompute_test.go +++ b/go/cmd/sqlflowserver/e2e_pai_maxcompute_test.go @@ -410,13 +410,13 @@ TO TRAIN xgboost.gbtree WITH objective="binary:logistic" LABEL survived INTO e2etest_xgb_titanic;`, caseDB) - _, _, _, err = connectAndRunSQL(titanicTrain) + _, _, _, err := connectAndRunSQL(titanicTrain) a.NoError(err, "Run titanicTrain error.") titanicExplain := fmt.Sprintf(`SELECT * FROM %s.sqlflow_titanic_train TO EXPLAIN e2etest_xgb_titanic WITH label_col=survived -INTO %s.e2etest_xgb_explain_result_wuyi;`, caseDB, caseDB) +INTO %s.e2etest_titanic_explain_result;`, caseDB, caseDB) _, _, _, err = connectAndRunSQL(titanicExplain) a.NoError(err, "Run titanicExplain error.") } diff --git a/python/runtime/xgboost/dataset.py b/python/runtime/xgboost/dataset.py index 401beece7b..b3241f8b2d 100644 --- a/python/runtime/xgboost/dataset.py +++ b/python/runtime/xgboost/dataset.py @@ -125,6 +125,7 @@ def dump_dmatrix(filename, with open(filename, 'a') as f: for row, label in generator: + print("dump dmatrix row: ", row, label) features = db.read_features_from_row(row, selected_cols, feature_column_names, feature_metas) @@ -252,20 +253,15 @@ def pai_dataset(filename, from subprocess import Popen, PIPE from multiprocessing.dummy import Pool # ThreadPool import queue - dname = filename if single_file: dname = filename + '.dir' - if os.path.exists(dname): shutil.rmtree(dname, ignore_errors=True) os.mkdir(dname) - slice_count = get_pai_table_slice_count(pai_table, nworkers, batch_size) - thread_num = min(int(slice_count / nworkers), 128) - pool = Pool(thread_num) complete_queue = queue.Queue() @@ -337,7 +333,7 @@ def pai_download_table_data_worker(dname, feature_metas, feature_column_names, feature_column_names, *feature_column_transformers) conn = PaiIOConnection.from_table(pai_table, slice_id, slice_count) - gen = db.db_generator(conn, None)() + gen = db.db_generator(conn, None, label_meta=label_meta)() selected_cols = db.selected_cols(conn, None) filename = "{}/{}.txt".format(dname, slice_id) dump_dmatrix(filename, From d6d5165b952a3c666027ca5c9bf69eb723806487 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Thu, 20 Aug 2020 13:54:30 +0800 Subject: [PATCH 5/6] update --- go/cmd/sqlflowserver/e2e_pai_maxcompute_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/cmd/sqlflowserver/e2e_pai_maxcompute_test.go b/go/cmd/sqlflowserver/e2e_pai_maxcompute_test.go index 4e560a958c..5cf41d85f9 100644 --- a/go/cmd/sqlflowserver/e2e_pai_maxcompute_test.go +++ b/go/cmd/sqlflowserver/e2e_pai_maxcompute_test.go @@ -410,7 +410,7 @@ TO TRAIN xgboost.gbtree WITH objective="binary:logistic" LABEL survived INTO e2etest_xgb_titanic;`, caseDB) - _, _, _, err := connectAndRunSQL(titanicTrain) + _, _, _, err = connectAndRunSQL(titanicTrain) a.NoError(err, "Run titanicTrain error.") titanicExplain := fmt.Sprintf(`SELECT * FROM %s.sqlflow_titanic_train From aa0040162fe7eb3e335092e412e6dbf524e517a8 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Thu, 20 Aug 2020 14:15:19 +0800 Subject: [PATCH 6/6] clean --- doc/datasets/popularize_titanic.sql | 8 ++++---- python/runtime/xgboost/dataset.py | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/doc/datasets/popularize_titanic.sql b/doc/datasets/popularize_titanic.sql index b9c8f29ee8..fd90a4a1af 100644 --- a/doc/datasets/popularize_titanic.sql +++ b/doc/datasets/popularize_titanic.sql @@ -34,8 +34,8 @@ cabinalpha int, family int, isalone int, ismother int, -age VARCHAR(255), -realfare VARCHAR(255), +age float, +realfare float, survived int ); @@ -954,8 +954,8 @@ cabinalpha int, family int, isalone int, ismother int, -age VARCHAR(255), -realfare VARCHAR(255), +age float, +realfare float, survived int ); diff --git a/python/runtime/xgboost/dataset.py b/python/runtime/xgboost/dataset.py index b3241f8b2d..9b271a5520 100644 --- a/python/runtime/xgboost/dataset.py +++ b/python/runtime/xgboost/dataset.py @@ -125,7 +125,6 @@ def dump_dmatrix(filename, with open(filename, 'a') as f: for row, label in generator: - print("dump dmatrix row: ", row, label) features = db.read_features_from_row(row, selected_cols, feature_column_names, feature_metas)