diff --git a/sql/python/sqlflow_submitter/db.py b/sql/python/sqlflow_submitter/db.py
index 511138015f..756587f16a 100644
--- a/sql/python/sqlflow_submitter/db.py
+++ b/sql/python/sqlflow_submitter/db.py
@@ -17,7 +17,7 @@ import tensorflow as tf
 import sqlflow_submitter.db_writer as db_writer
 
 
-def connect(driver, database, user, password, host, port, auth=""):
+def connect(driver, database, user, password, host, port, session_cfg=None, auth=""):
     if driver == "mysql":
         # NOTE: use MySQLdb to avoid bugs like infinite reading:
         # https://bugs.mysql.com/bug.php?id=91971
@@ -32,19 +32,21 @@ def connect(driver, database, user, password, host, port, auth=""):
         return connect(database)
     elif driver == "hive":
         from impala.dbapi import connect
-        return connect(user=user,
+        conn = connect(user=user,
                        password=password,
                        database=database,
                        host=host,
                        port=int(port),
                        auth_mechanism=auth)
+        conn.session_cfg = session_cfg or {}
+        return conn
     elif driver == "maxcompute":
         from sqlflow_submitter.maxcompute import MaxCompute
         return MaxCompute.connect(database, user, password, host)
     raise ValueError("unrecognized database driver: %s" % driver)
 
 
-def db_generator(driver, conn, session_cfg, statement,
+def db_generator(driver, conn, statement,
                  feature_column_names, label_column_name,
                  feature_specs, fetch_size=128):
     def read_feature(raw_val, feature_spec, feature_name):
@@ -69,7 +71,7 @@ def read_feature(raw_val, feature_spec, feature_name):
 
     def reader():
         if driver == "hive":
-            cursor = conn.cursor(configuration=session_cfg)
+            cursor = conn.cursor(configuration=conn.session_cfg)
         else:
             cursor = conn.cursor()
         cursor.execute(statement)
diff --git a/sql/python/sqlflow_submitter/db_test.py b/sql/python/sqlflow_submitter/db_test.py
index f3b9861cd4..0f142320c0 100644
--- a/sql/python/sqlflow_submitter/db_test.py
+++ b/sql/python/sqlflow_submitter/db_test.py
@@ -132,7 +132,7 @@ def test_generator(self):
                 "is_sparse": False,
                 "shape": []
             }}
-        gen = db_generator(driver, conn, {}, "SELECT * FROM test_table_float_fea",
+        gen = db_generator(driver, conn, "SELECT * FROM test_table_float_fea",
                            ["features"], "label", column_name_to_type)
         idx = 0
         for d in gen():
@@ -159,6 +159,6 @@ def test_generate_fetch_size(self):
             }}
-        gen = db_generator(driver, conn, {}, 'SELECT * FROM iris.train limit 10',
+        gen = db_generator(driver, conn, 'SELECT * FROM iris.train limit 10',
                            ["sepal_length"], "class", column_name_to_type,
                            fetch_size=4)
         self.assertEqual(len([g for g in gen()]), 10)
 
diff --git a/sql/template_analyze.go b/sql/template_analyze.go
index 752fb84599..29746c13d0 100644
--- a/sql/template_analyze.go
+++ b/sql/template_analyze.go
@@ -51,10 +51,10 @@ session_cfg = {}
 session_cfg["{{$k}}"] = "{{$v}}"
 {{end}}
 
-conn = connect(driver, database, user="{{.User}}", password="{{.Password}}", host="{{.Host}}", port={{.Port}}, auth="{{.Auth}}")
+conn = connect(driver, database, user="{{.User}}", password="{{.Password}}", host="{{.Host}}", port={{.Port}}, auth="{{.Auth}}", session_cfg=session_cfg)
 
 def analyzer_dataset():
-    stream = db_generator(driver, conn, session_cfg, """{{.AnalyzeDatasetSQL}}""", feature_names, label_name, feature_metas)
+    stream = db_generator(driver, conn, """{{.AnalyzeDatasetSQL}}""", feature_names, label_name, feature_metas)
     xs = pd.DataFrame(columns=feature_names)
     ys = pd.DataFrame(columns=[label_name])
     i = 0
diff --git a/sql/template_tf.go b/sql/template_tf.go
index 990acdfa0e..a7b36118ab 100644
--- a/sql/template_tf.go
+++ b/sql/template_tf.go
@@ -42,7 +42,12 @@ database=""
 database="{{.Database}}"
 {{end}}
 
-conn = connect(driver, database, user="{{.User}}", password="{{.Password}}", host="{{.Host}}", port={{.Port}}, auth="{{.Auth}}")
+session_cfg = {}
+{{ range $k, $v := .Session }}
+session_cfg["{{$k}}"] = "{{$v}}"
+{{end}}
+
+conn = connect(driver, database, user="{{.User}}", password="{{.Password}}", host="{{.Host}}", port={{.Port}}, auth="{{.Auth}}", session_cfg=session_cfg)
 
 feature_column_names = [{{range .X}}
 "{{.FeatureName}}",
@@ -70,11 +75,6 @@ feature_metas["{{$value.FeatureName}}"] = {
 }
 {{end}}
 
-session_cfg = {}
-{{ range $k, $v := .Session }}
-session_cfg["{{$k}}"] = "{{$v}}"
-{{end}}
-
 def get_dtype(type_str):
     if type_str == "float32":
         return tf.float32
@@ -104,7 +104,7 @@ def input_fn(datasetStr):
         else:
             feature_types.append(get_dtype(feature_metas[name]["dtype"]))
 
-    gen = db_generator(driver, conn, session_cfg, datasetStr, feature_column_names, "{{.Y.FeatureName}}", feature_metas)
+    gen = db_generator(driver, conn, datasetStr, feature_column_names, "{{.Y.FeatureName}}", feature_metas)
     dataset = tf.data.Dataset.from_generator(gen, (tuple(feature_types), tf.{{.Y.Dtype}}))
     ds_mapper = functools.partial(_parse_sparse_feature, feature_metas=feature_metas)
     return dataset.map(ds_mapper)
@@ -169,7 +169,12 @@ database="{{.Database}}"
 database=""
 {{end}}
 
-conn = connect(driver, database, user="{{.User}}", password="{{.Password}}", host="{{.Host}}", port={{.Port}}, auth="{{.Auth}}")
+session_cfg = {}
+{{ range $k, $v := .Session }}
+session_cfg["{{$k}}"] = "{{$v}}"
+{{end}}
+
+conn = connect(driver, database, user="{{.User}}", password="{{.Password}}", host="{{.Host}}", port={{.Port}}, auth="{{.Auth}}", session_cfg=session_cfg)
 
 feature_column_names = [{{range .X}}
 "{{.FeatureName}}",
@@ -197,11 +202,6 @@ feature_metas["{{$value.FeatureName}}"] = {
 }
 {{end}}
 
-session_cfg = {}
-{{ range $k, $v := .Session }}
-session_cfg["{{$k}}"] = "{{$v}}"
-{{end}}
-
 def get_dtype(type_str):
     if type_str == "float32":
         return tf.float32
@@ -232,7 +232,7 @@ def eval_input_fn(batch_size):
         else:
             feature_types.append(get_dtype(feature_metas[name]["dtype"]))
 
-    gen = db_generator(driver, conn, session_cfg, """{{.PredictionDatasetSQL}}""",
+    gen = db_generator(driver, conn, """{{.PredictionDatasetSQL}}""",
                        feature_column_names, "{{.Y.FeatureName}}", feature_metas)
     dataset = tf.data.Dataset.from_generator(gen, (tuple(feature_types), tf.{{.Y.Dtype}}))
     ds_mapper = functools.partial(_parse_sparse_feature, feature_metas=feature_metas)
@@ -322,7 +322,7 @@ class FastPredict:
 column_names = feature_column_names[:]
 column_names.append("{{.Y.FeatureName}}")
 
-pred_gen = db_generator(driver, conn, session_cfg, """{{.PredictionDatasetSQL}}""", feature_column_names, "{{.Y.FeatureName}}", feature_metas)()
+pred_gen = db_generator(driver, conn, """{{.PredictionDatasetSQL}}""", feature_column_names, "{{.Y.FeatureName}}", feature_metas)()
 fast_predictor = FastPredict(classifier, fast_input_fn)
 
 with buffered_db_writer(driver, conn, "{{.TableName}}", column_names, 100) as w:
diff --git a/sql/template_xgboost.go b/sql/template_xgboost.go
index 8ad022ada2..dc4ceee25a 100644
--- a/sql/template_xgboost.go
+++ b/sql/template_xgboost.go
@@ -58,10 +58,10 @@ feature_specs["{{$value.FeatureName}}"] = {
 }
 {{end}}
 
-conn = connect(driver, database, user="{{.User}}", password="{{.Password}}", host="{{.Host}}", port={{.Port}}, auth="{{.Auth}}")
+conn = connect(driver, database, user="{{.User}}", password="{{.Password}}", host="{{.Host}}", port={{.Port}}, auth="{{.Auth}}", session_cfg=session_cfg)
 
 def xgb_dataset(fn, dataset_sql):
-    gen = db_generator(driver, conn, session_cfg, dataset_sql, feature_column_names, "{{.Y.FeatureName}}", feature_specs)
+    gen = db_generator(driver, conn, dataset_sql, feature_column_names, "{{.Y.FeatureName}}", feature_specs)
     with open(fn, 'w') as f:
         for item in gen():
             features, label = item
@@ -117,10 +117,10 @@ feature_specs["{{$value.FeatureName}}"] = {
 }
 {{end}}
 
-conn = connect(driver, database, user="{{.User}}", password="{{.Password}}", host="{{.Host}}", port={{.Port}}, auth="{{.Auth}}")
+conn = connect(driver, database, user="{{.User}}", password="{{.Password}}", host="{{.Host}}", port={{.Port}}, auth="{{.Auth}}", session_cfg=session_cfg)
 
 def xgb_dataset(fn, dataset_sql):
-    gen = db_generator(driver, conn, session_cfg, dataset_sql, feature_column_names, "", feature_specs)
+    gen = db_generator(driver, conn, dataset_sql, feature_column_names, "", feature_specs)
     with open(fn, 'w') as f:
         for item in gen():
             features, label = item