Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 24 additions & 17 deletions go/cmd/sqlflowserver/e2e_pai_maxcompute_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,16 +378,16 @@ func CasePAIMaxComputeTrainXGBoost(t *testing.T) {
a := assert.New(t)

trainSQL := fmt.Sprintf(`SELECT * FROM %s
TO TRAIN xgboost.gbtree
WITH
objective="multi:softprob",
train.num_boost_round = 30,
eta = 0.4,
num_class = 3,
train.batch_size=10,
validation.select="select * from %s"
LABEL class
INTO e2etest_xgb_classi_model;`, caseTrainTable, caseTrainTable)
TO TRAIN xgboost.gbtree
WITH
objective="multi:softprob",
train.num_boost_round = 30,
eta = 0.4,
num_class = 3,
train.batch_size=10,
validation.select="select * from %s"
LABEL class
INTO e2etest_xgb_classi_model;`, caseTrainTable, caseTrainTable)
_, _, _, err := connectAndRunSQL(trainSQL)
a.NoError(err, "Run trainSQL error.")

Expand All @@ -405,13 +405,20 @@ INTO %s.e2etest_xgb_evaluate_result;`, caseTestTable, caseDB)
_, _, _, err = connectAndRunSQL(evalSQL)
a.NoError(err, "Run evalSQL error.")

explainSQL := fmt.Sprintf(`SELECT * FROM %s
TO EXPLAIN e2etest_xgb_classi_model
WITH label_col=class
USING TreeExplainer
INTO %s.e2etest_xgb_explain_result;`, caseTrainTable, caseDB)
_, _, _, err = connectAndRunSQL(explainSQL)
a.NoError(err, "Run explainSQL error.")
titanicTrain := fmt.Sprintf(`SELECT * FROM %s.sqlflow_titanic_train
TO TRAIN xgboost.gbtree
WITH objective="binary:logistic"
LABEL survived
INTO e2etest_xgb_titanic;`, caseDB)
_, _, _, err = connectAndRunSQL(titanicTrain)
a.NoError(err, "Run titanicTrain error.")

titanicExplain := fmt.Sprintf(`SELECT * FROM %s.sqlflow_titanic_train
TO EXPLAIN e2etest_xgb_titanic
WITH label_col=survived
INTO %s.e2etest_titanic_explain_result;`, caseDB, caseDB)
_, _, _, err = connectAndRunSQL(titanicExplain)
a.NoError(err, "Run titanicExplain error.")
}

func CasePAIMaxComputeTrainCustomModel(t *testing.T) {
Expand Down
1 change: 1 addition & 0 deletions go/codegen/xgboost/codegen_explain.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ func Explain(explainStmt *ir.ExplainStmt, session *pb.Session) (string, error) {
FeatureColumnNames: fs,
FeatureColumnCode: featureColumnCode,
LabelJSON: string(l),
ResultTable: explainStmt.Into,
IsPAI: tf.IsPAI(),
PAIExplainTable: explainStmt.TmpExplainTable,
}
Expand Down
6 changes: 4 additions & 2 deletions go/codegen/xgboost/template_explain.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type explainFiller struct {
FeatureColumnNames []string
FeatureColumnCode string
LabelJSON string
ResultTable string
IsPAI bool
PAIExplainTable string
}
Expand All @@ -48,10 +49,11 @@ transform_fn = xgboost_extended.feature_column.ComposedColumnTransformer(feature
explain(
datasource='''{{.DataSource}}''',
select='''{{.DatasetSQL}}''',
feature_field_meta=feature_field_meta,
feature_column_names=feature_column_names,
feature_field_meta=feature_field_meta,
feature_column_names=feature_column_names,
label_meta=label_meta,
summary_params=summary_params,
result_table="{{.ResultTable}}",
is_pai="{{.IsPAI}}" == "true",
pai_explain_table="{{.PAIExplainTable}}",
transform_fn=transform_fn,
Expand Down
16 changes: 10 additions & 6 deletions go/executor/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,16 +272,20 @@ func (s *pythonExecutor) ExecuteExplain(cl *ir.ExplainStmt) error {
return err
}
defer db.Close()

var modelType int
if cl.TrainStmt.GetModelKind() == ir.XGBoost {
code, err = xgboost.Explain(cl, s.Session)
// TODO(typhoonzero): deal with XGBoost model explain result table creation.
modelType = pai.ModelTypeXGBoost
} else {
code, err = tensorflow.Explain(cl, s.Session)
if cl.Into != "" {
err := createExplainResultTable(db, cl, cl.Into, pai.ModelTypeTF, cl.TrainStmt.Estimator)
if err != nil {
return err
}
modelType = pai.ModelTypeTF
}

if cl.Into != "" {
err := createExplainResultTable(db, cl, cl.Into, modelType, cl.TrainStmt.Estimator)
if err != nil {
return err
}
}

Expand Down
9 changes: 7 additions & 2 deletions go/executor/pai.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ import (
"path"
"path/filepath"
"regexp"
"sqlflow.org/sqlflow/go/verifier"
"strings"

"sqlflow.org/sqlflow/go/verifier"

"sqlflow.org/sqlflow/go/codegen/optimize"

"github.com/aliyun/aliyun-oss-go-sdk/oss"
Expand Down Expand Up @@ -636,9 +637,13 @@ func getCreateShapResultSQL(db *database.DB, tableName string, selectStmt string
return "", err
}
columnDefList := []string{}
columnType := "STRING"
if db.DriverName == "mysql" {
columnType = "VARCHAR(255)"
}
for _, fieldName := range flds {
if fieldName != labelCol {
columnDefList = append(columnDefList, fmt.Sprintf("%s STRING", fieldName))
columnDefList = append(columnDefList, fmt.Sprintf("%s %s", fieldName, columnType))
}
}
createStmt := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s (%s);`, tableName, strings.Join(columnDefList, ","))
Expand Down
7 changes: 1 addition & 6 deletions python/runtime/xgboost/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,20 +252,15 @@ def pai_dataset(filename,
from subprocess import Popen, PIPE
from multiprocessing.dummy import Pool # ThreadPool
import queue

dname = filename
if single_file:
dname = filename + '.dir'

if os.path.exists(dname):
shutil.rmtree(dname, ignore_errors=True)

os.mkdir(dname)

slice_count = get_pai_table_slice_count(pai_table, nworkers, batch_size)

thread_num = min(int(slice_count / nworkers), 128)

pool = Pool(thread_num)
complete_queue = queue.Queue()

Expand Down Expand Up @@ -337,7 +332,7 @@ def pai_download_table_data_worker(dname, feature_metas, feature_column_names,
feature_column_names, *feature_column_transformers)

conn = PaiIOConnection.from_table(pai_table, slice_id, slice_count)
gen = db.db_generator(conn, None)()
gen = db.db_generator(conn, None, label_meta=label_meta)()
selected_cols = db.selected_cols(conn, None)
filename = "{}/{}.txt".format(dname, slice_id)
dump_dmatrix(filename,
Expand Down
20 changes: 10 additions & 10 deletions python/runtime/xgboost/explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,23 +173,23 @@ def explain(datasource,
pai_explain_table,
transform_fn=transform_fn,
feature_column_code=feature_column_code)

shap_values, shap_interaction_values, expected_value = xgb_shap_values(x)

if result_table != "":
if is_pai:
from runtime.dbapi.paiio import PaiIOConnection
conn = PaiIOConnection.from_table(result_table)
# TODO(typhoonzero): the shape of shap_values is
# (3, num_samples, num_features), use the first
# dimension here, should find out how to use
# the other two.
else:
conn = db.connect_with_data_source(datasource)

write_shap_values(shap_values[0], conn, result_table,
feature_column_names)
return
# TODO(typhoonzero): the shap_values is may be a
# list of shape [3, num_samples, num_features],
# use the first dimension here, should find out
# when to use the other two. When shap_values is
# not a list it can be directly used.
if isinstance(shap_values, list):
to_write = shap_values[0]
else:
to_write = shap_values
write_shap_values(to_write, conn, result_table, feature_column_names)

if summary_params.get("plot_type") == "decision":
explainer.plot_and_save(
Expand Down