diff --git a/go/cmd/sqlflowserver/e2e_pai_maxcompute_test.go b/go/cmd/sqlflowserver/e2e_pai_maxcompute_test.go index 5b35b280a5..5cf41d85f9 100644 --- a/go/cmd/sqlflowserver/e2e_pai_maxcompute_test.go +++ b/go/cmd/sqlflowserver/e2e_pai_maxcompute_test.go @@ -378,16 +378,16 @@ func CasePAIMaxComputeTrainXGBoost(t *testing.T) { a := assert.New(t) trainSQL := fmt.Sprintf(`SELECT * FROM %s - TO TRAIN xgboost.gbtree - WITH - objective="multi:softprob", - train.num_boost_round = 30, - eta = 0.4, - num_class = 3, - train.batch_size=10, - validation.select="select * from %s" - LABEL class - INTO e2etest_xgb_classi_model;`, caseTrainTable, caseTrainTable) +TO TRAIN xgboost.gbtree +WITH + objective="multi:softprob", + train.num_boost_round = 30, + eta = 0.4, + num_class = 3, + train.batch_size=10, + validation.select="select * from %s" +LABEL class +INTO e2etest_xgb_classi_model;`, caseTrainTable, caseTrainTable) _, _, _, err := connectAndRunSQL(trainSQL) a.NoError(err, "Run trainSQL error.") @@ -405,13 +405,20 @@ INTO %s.e2etest_xgb_evaluate_result;`, caseTestTable, caseDB) _, _, _, err = connectAndRunSQL(evalSQL) a.NoError(err, "Run evalSQL error.") - explainSQL := fmt.Sprintf(`SELECT * FROM %s -TO EXPLAIN e2etest_xgb_classi_model -WITH label_col=class -USING TreeExplainer -INTO %s.e2etest_xgb_explain_result;`, caseTrainTable, caseDB) - _, _, _, err = connectAndRunSQL(explainSQL) - a.NoError(err, "Run explainSQL error.") + titanicTrain := fmt.Sprintf(`SELECT * FROM %s.sqlflow_titanic_train +TO TRAIN xgboost.gbtree +WITH objective="binary:logistic" +LABEL survived +INTO e2etest_xgb_titanic;`, caseDB) + _, _, _, err = connectAndRunSQL(titanicTrain) + a.NoError(err, "Run titanicTrain error.") + + titanicExplain := fmt.Sprintf(`SELECT * FROM %s.sqlflow_titanic_train +TO EXPLAIN e2etest_xgb_titanic +WITH label_col=survived +INTO %s.e2etest_titanic_explain_result;`, caseDB, caseDB) + _, _, _, err = connectAndRunSQL(titanicExplain) + a.NoError(err, "Run titanicExplain error.") } func CasePAIMaxComputeTrainCustomModel(t *testing.T) { diff --git a/go/codegen/xgboost/codegen_explain.go b/go/codegen/xgboost/codegen_explain.go index e94bbaf6c3..cb35a1a723 100644 --- a/go/codegen/xgboost/codegen_explain.go +++ b/go/codegen/xgboost/codegen_explain.go @@ -57,6 +57,7 @@ func Explain(explainStmt *ir.ExplainStmt, session *pb.Session) (string, error) { FeatureColumnNames: fs, FeatureColumnCode: featureColumnCode, LabelJSON: string(l), + ResultTable: explainStmt.Into, IsPAI: tf.IsPAI(), PAIExplainTable: explainStmt.TmpExplainTable, } diff --git a/go/codegen/xgboost/template_explain.go b/go/codegen/xgboost/template_explain.go index c63b58be7d..c5f7b62631 100644 --- a/go/codegen/xgboost/template_explain.go +++ b/go/codegen/xgboost/template_explain.go @@ -25,6 +25,7 @@ type explainFiller struct { FeatureColumnNames []string FeatureColumnCode string LabelJSON string + ResultTable string IsPAI bool PAIExplainTable string } @@ -48,10 +49,11 @@ transform_fn = xgboost_extended.feature_column.ComposedColumnTransformer(feature explain( datasource='''{{.DataSource}}''', select='''{{.DatasetSQL}}''', - feature_field_meta=feature_field_meta, - feature_column_names=feature_column_names, + feature_field_meta=feature_field_meta, + feature_column_names=feature_column_names, label_meta=label_meta, summary_params=summary_params, + result_table="{{.ResultTable}}", is_pai="{{.IsPAI}}" == "true", pai_explain_table="{{.PAIExplainTable}}", transform_fn=transform_fn, diff --git a/go/executor/executor.go b/go/executor/executor.go index a155e4e259..beb3ed9dda 100644 --- a/go/executor/executor.go +++ b/go/executor/executor.go @@ -272,16 +272,20 @@ func (s *pythonExecutor) ExecuteExplain(cl *ir.ExplainStmt) error { return err } defer db.Close() + + var modelType int if cl.TrainStmt.GetModelKind() == ir.XGBoost { code, err = xgboost.Explain(cl, s.Session) - // TODO(typhoonzero): deal with XGBoost model explain result table creation. + modelType = pai.ModelTypeXGBoost } else { code, err = tensorflow.Explain(cl, s.Session) - if cl.Into != "" { - err := createExplainResultTable(db, cl, cl.Into, pai.ModelTypeTF, cl.TrainStmt.Estimator) - if err != nil { - return err - } + modelType = pai.ModelTypeTF + } + + if cl.Into != "" { + err := createExplainResultTable(db, cl, cl.Into, modelType, cl.TrainStmt.Estimator) + if err != nil { + return err } } diff --git a/go/executor/pai.go b/go/executor/pai.go index 99b64030c8..fda8bc300a 100644 --- a/go/executor/pai.go +++ b/go/executor/pai.go @@ -23,9 +23,10 @@ import ( "path" "path/filepath" "regexp" - "sqlflow.org/sqlflow/go/verifier" "strings" + "sqlflow.org/sqlflow/go/verifier" + "sqlflow.org/sqlflow/go/codegen/optimize" "github.com/aliyun/aliyun-oss-go-sdk/oss" @@ -636,9 +637,13 @@ func getCreateShapResultSQL(db *database.DB, tableName string, selectStmt string return "", err } columnDefList := []string{} + columnType := "STRING" + if db.DriverName == "mysql" { + columnType = "VARCHAR(255)" + } for _, fieldName := range flds { if fieldName != labelCol { - columnDefList = append(columnDefList, fmt.Sprintf("%s STRING", fieldName)) + columnDefList = append(columnDefList, fmt.Sprintf("%s %s", fieldName, columnType)) } } createStmt := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (%s);`, tableName, strings.Join(columnDefList, ",")) diff --git a/python/runtime/xgboost/dataset.py b/python/runtime/xgboost/dataset.py index 401beece7b..9b271a5520 100644 --- a/python/runtime/xgboost/dataset.py +++ b/python/runtime/xgboost/dataset.py @@ -252,20 +252,15 @@ def pai_dataset(filename, from subprocess import Popen, PIPE from multiprocessing.dummy import Pool # ThreadPool import queue - dname = filename if single_file: dname = filename + '.dir' - if os.path.exists(dname): shutil.rmtree(dname, ignore_errors=True) os.mkdir(dname) - slice_count = get_pai_table_slice_count(pai_table, nworkers, batch_size) - thread_num = min(int(slice_count / nworkers), 128) - pool = Pool(thread_num) complete_queue = queue.Queue() @@ -337,7 +332,7 @@ def pai_download_table_data_worker(dname, feature_metas, feature_column_names, feature_column_names, *feature_column_transformers) conn = PaiIOConnection.from_table(pai_table, slice_id, slice_count) - gen = db.db_generator(conn, None)() + gen = db.db_generator(conn, None, label_meta=label_meta)() selected_cols = db.selected_cols(conn, None) filename = "{}/{}.txt".format(dname, slice_id) dump_dmatrix(filename, diff --git a/python/runtime/xgboost/explain.py b/python/runtime/xgboost/explain.py index 609e0d63c9..f6a4ba5332 100644 --- a/python/runtime/xgboost/explain.py +++ b/python/runtime/xgboost/explain.py @@ -173,23 +173,23 @@ def explain(datasource, pai_explain_table, transform_fn=transform_fn, feature_column_code=feature_column_code) - shap_values, shap_interaction_values, expected_value = xgb_shap_values(x) - if result_table != "": if is_pai: from runtime.dbapi.paiio import PaiIOConnection conn = PaiIOConnection.from_table(result_table) - # TODO(typhoonzero): the shape of shap_values is - # (3, num_samples, num_features), use the first - # dimension here, should find out how to use - # the other two. else: conn = db.connect_with_data_source(datasource) - - write_shap_values(shap_values[0], conn, result_table, - feature_column_names) - return + # TODO(typhoonzero): the shap_values is may be a + # list of shape [3, num_samples, num_features], + # use the first dimension here, should find out + # when to use the other two. When shap_values is + # not a list it can be directly used. + if isinstance(shap_values, list): + to_write = shap_values[0] + else: + to_write = shap_values + write_shap_values(to_write, conn, result_table, feature_column_names) if summary_params.get("plot_type") == "decision": explainer.plot_and_save(