diff --git a/doc/model_parameter.md b/doc/model_parameter.md index 53c97feb36..bfb97a2e2e 100644 --- a/doc/model_parameter.md +++ b/doc/model_parameter.md @@ -79,7 +79,7 @@ INTO sqlflow_models.my_xgb_regression_model;
| Name | diff --git a/pkg/attribute/checker.go b/pkg/attribute/checker.go index 07f08b5d9c..7229a5da74 100644 --- a/pkg/attribute/checker.go +++ b/pkg/attribute/checker.go @@ -22,135 +22,89 @@ var equalSign = map[bool]string{true: "=", false: ""} // Float32RangeChecker is a helper function to generate range checkers on attributes. // lower/upper indicates the lower bound and upper bound of the attribute value. // includeLower/includeUpper indicates the inclusion of the bound. -func Float32RangeChecker(lower, upper float32, includeLower, includeUpper bool) func(interface{}) error { - return func(attr interface{}) error { - if f, ok := attr.(float32); ok { - e := Float32LowerBoundChecker(lower, includeLower)(f) - if e == nil { - e = Float32UpperBoundChecker(upper, includeUpper)(f) - } - return e +func Float32RangeChecker(lower, upper float32, includeLower, includeUpper bool) func(float32) error { + return func(f float32) error { + e := Float32LowerBoundChecker(lower, includeLower)(f) + if e == nil { + e = Float32UpperBoundChecker(upper, includeUpper)(f) } - return fmt.Errorf("expected type float32, received %T", attr) + return e } } // Float32LowerBoundChecker returns a range checker that only checks the lower bound. -func Float32LowerBoundChecker(lower float32, includeLower bool) func(interface{}) error { - return func(attr interface{}) error { - if f, ok := attr.(float32); ok { - if (!includeLower && f > lower) || (includeLower && f >= lower) { - return nil - } - return fmt.Errorf("range check %v <%v %v failed", lower, equalSign[includeLower], f) +func Float32LowerBoundChecker(lower float32, includeLower bool) func(float32) error { + return func(f float32) error { + if (!includeLower && f > lower) || (includeLower && f >= lower) { + return nil } - return fmt.Errorf("expected type float32, received %T", attr) + return fmt.Errorf("range check %v <%v %v failed", lower, equalSign[includeLower], f) } } // Float32UpperBoundChecker returns a range checker that only checks the upper bound. -func Float32UpperBoundChecker(upper float32, includeUpper bool) func(interface{}) error { - return func(attr interface{}) error { - if f, ok := attr.(float32); ok { - if (!includeUpper && f < upper) || (includeUpper && f <= upper) { - return nil - } - return fmt.Errorf("range check %v >%v %v failed", upper, equalSign[includeUpper], f) +func Float32UpperBoundChecker(upper float32, includeUpper bool) func(float32) error { + return func(f float32) error { + if (!includeUpper && f < upper) || (includeUpper && f <= upper) { + return nil } - return fmt.Errorf("expected type float32, received %T", attr) + return fmt.Errorf("range check %v >%v %v failed", upper, equalSign[includeUpper], f) } } // IntRangeChecker is a helper function to generate range checkers on attributes. // lower/upper indicates the lower bound and upper bound of the attribute value. // includeLower/includeUpper indicates the inclusion of the bound. -func IntRangeChecker(lower, upper int, includeLower, includeUpper bool) func(interface{}) error { - return func(attr interface{}) error { - if f, ok := attr.(int); ok { - e := IntLowerBoundChecker(lower, includeLower)(f) - if e == nil { - e = IntUpperBoundChecker(upper, includeUpper)(f) - } - return e +func IntRangeChecker(lower, upper int, includeLower, includeUpper bool) func(int) error { + return func(i int) error { + e := IntLowerBoundChecker(lower, includeLower)(i) + if e == nil { + e = IntUpperBoundChecker(upper, includeUpper)(i) } - return fmt.Errorf("expected type int, received %T", attr) + return e } } // IntLowerBoundChecker returns a range checker that only checks the lower bound. -func IntLowerBoundChecker(lower int, includeLower bool) func(interface{}) error { - return func(attr interface{}) error { - if f, ok := attr.(int); ok { - if f > lower || includeLower && f == lower { - return nil - } - return fmt.Errorf("range check %v <%v %v failed", lower, equalSign[includeLower], f) +func IntLowerBoundChecker(lower int, includeLower bool) func(int) error { + return func(i int) error { + if i > lower || includeLower && i == lower { + return nil } - return fmt.Errorf("expected type int, received %T", attr) + return fmt.Errorf("range check %v <%v %v failed", lower, equalSign[includeLower], i) } } // IntUpperBoundChecker returns a range checker that only checks the upper bound. -func IntUpperBoundChecker(upper int, includeUpper bool) func(interface{}) error { - return func(attr interface{}) error { - if f, ok := attr.(int); ok { - if f < upper || includeUpper && f == upper { - return nil - } - return fmt.Errorf("range check %v >%v %v failed", upper, equalSign[includeUpper], f) +func IntUpperBoundChecker(upper int, includeUpper bool) func(int) error { + return func(i int) error { + if i < upper || includeUpper && i == upper { + return nil } - return fmt.Errorf("expected type int, received %T", attr) + return fmt.Errorf("range check %v >%v %v failed", upper, equalSign[includeUpper], i) } } // IntChoicesChecker verifies the attribute value is in a list of choices. -func IntChoicesChecker(choices ...int) func(interface{}) error { - checker := func(e interface{}) error { - i, ok := e.(int) - if !ok { - return fmt.Errorf("expected type int, received %T", e) - } - found := false +func IntChoicesChecker(choices ...int) func(int) error { + return func(i int) error { for _, possibleValue := range choices { if i == possibleValue { - found = true - break + return nil } } - if found == false { - return fmt.Errorf("expected value in %v, actual: %v", choices, i) - } - return nil + return fmt.Errorf("expected value in %v, actual: %v", choices, i) } - return checker } // StringChoicesChecker verifies the attribute value is in a list of choices. -func StringChoicesChecker(choices ...string) func(interface{}) error { - checker := func(e interface{}) error { - s, ok := e.(string) - if !ok { - return fmt.Errorf("expected type string, received %T", e) - } - found := false +func StringChoicesChecker(choices ...string) func(string) error { + return func(s string) error { for _, possibleValue := range choices { if s == possibleValue { - found = true - break + return nil } } - if found == false { - return fmt.Errorf("expected value in %v, actual: %v", choices, s) - } - return nil - } - return checker -} - -// EmptyChecker returns a checker function that do **not** check the input. -func EmptyChecker() func(interface{}) error { - checker := func(e interface{}) error { - return nil + return fmt.Errorf("expected value in %v, actual: %v", choices, s) } - return checker } diff --git a/pkg/attribute/checker_test.go b/pkg/attribute/checker_test.go index 476d241427..1c546a58f6 100644 --- a/pkg/attribute/checker_test.go +++ b/pkg/attribute/checker_test.go @@ -23,7 +23,7 @@ func TestFloat32RangeChecker(t *testing.T) { a := assert.New(t) checker := Float32RangeChecker(0.0, 1.0, true, true) - a.Error(checker(1)) + a.NoError(checker(1)) a.Error(checker(float32(-1))) a.NoError(checker(float32(0))) a.NoError(checker(float32(0.5))) @@ -31,7 +31,7 @@ func TestFloat32RangeChecker(t *testing.T) { a.Error(checker(float32(2))) checker2 := Float32RangeChecker(0.0, 1.0, false, false) - a.Error(checker(1)) + a.NoError(checker(1)) a.Error(checker2(float32(-1))) a.Error(checker2(float32(0))) a.NoError(checker2(float32(0.5))) @@ -43,7 +43,7 @@ func TestIntRangeChecker(t *testing.T) { a := assert.New(t) checker := IntRangeChecker(0, 2, true, true) - a.Error(checker(1.0)) + a.NoError(checker(1.0)) a.Error(checker(int(-1))) a.NoError(checker(int(0))) a.NoError(checker(int(1))) @@ -51,7 +51,7 @@ func TestIntRangeChecker(t *testing.T) { a.Error(checker(int(3))) checker2 := IntRangeChecker(0, 2, false, false) - a.Error(checker(1.0)) + a.NoError(checker(1.0)) a.Error(checker2(int(-1))) a.Error(checker2(int(0))) a.NoError(checker2(int(1))) @@ -63,7 +63,6 @@ func TestIntChoicesChecker(t *testing.T) { a := assert.New(t) checker := IntChoicesChecker(0, 1, 2) - a.Error(checker(1.0)) a.Error(checker(-1)) a.NoError(checker(0)) a.NoError(checker(1)) @@ -75,8 +74,6 @@ func TestStringChoicesChecker(t *testing.T) { a := assert.New(t) checker := StringChoicesChecker("0", "1", "2") - a.Error(checker(1.0)) - a.Error(checker(-1)) a.NoError(checker("0")) a.NoError(checker("1")) a.NoError(checker("2")) diff --git a/pkg/codegen/optimize/codegen.go b/pkg/codegen/optimize/codegen.go index b2ae87fe37..d9cd9daf8e 100644 --- a/pkg/codegen/optimize/codegen.go +++ b/pkg/codegen/optimize/codegen.go @@ -24,32 +24,17 @@ import ( "text/template" ) -func checkIsPositiveInteger(i interface{}, name string) error { - if v, ok := i.(int); !ok || v <= 0 { - return fmt.Errorf("%s should be positive integer", name) - } - return nil -} - -// TODO(sneaxiy): polish attribute codes -var attributeDictionary = attribute.Dictionary{ - "data.enable_slice": {attribute.Bool, false, "Whether to enable data slicing", nil}, - "data.batch_size": {attribute.Int, -1, "Batch size when training", nil}, - "worker.num": {attribute.Int, 1, "Worker number", func(i interface{}) error { - return checkIsPositiveInteger(i, "worker.num") - }}, - "worker.core": {attribute.Int, 8, "Worker core number", func(i interface{}) error { - return checkIsPositiveInteger(i, "worker.core") - }}, - "worker.memory": {attribute.Int, 4096, "Worker memory", func(i interface{}) error { - return checkIsPositiveInteger(i, "worker.memory") - }}, - "solver.*": {attribute.Unknown, nil, "Solver options", nil}, -} +var attributeDictionary = attribute.Dictionary{}. + Bool("data.enable_slice", false, "Whether to enable data slicing", nil). + Int("data.batch_size", -1, "Batch size when training", nil). + Int("worker.num", 1, "Worker number", attribute.IntLowerBoundChecker(1, true)). + Int("worker.core", 8, "Worker core number", attribute.IntLowerBoundChecker(1, true)). + Int("worker.memory", 4096, "Worker memory", attribute.IntLowerBoundChecker(1, true)). + Unknown("solver.*", nil, "Solver options", nil) // InitializeAttributes initialize attributes in optimize clause IR func InitializeAttributes(stmt *ir.OptimizeStmt) error { - attributeDictionary.FillDefaults(stmt.Attributes) + attributeDictionary.ExportDefaults(stmt.Attributes) err := attributeDictionary.Validate(stmt.Attributes) return err } diff --git a/pkg/codegen/pai/kmeans.go b/pkg/codegen/pai/kmeans.go index 18469c9e4e..25163c1e57 100644 --- a/pkg/codegen/pai/kmeans.go +++ b/pkg/codegen/pai/kmeans.go @@ -23,22 +23,21 @@ import ( pb "sqlflow.org/sqlflow/pkg/proto" ) -var kmeansAttributes = attribute.Dictionary{ - "center_count": {attribute.Int, 3, `[default=3] +var kmeansAttributes = attribute.Dictionary{}. + Int("center_count", 3, `[default=3] The cluster count. range: [1, Infinity] -`, attribute.IntLowerBoundChecker(1, true)}, - "idx_table_name": {attribute.String, "", ` +`, attribute.IntLowerBoundChecker(1, true)). + String("idx_table_name", "", ` The output table name which includes cluster_index column indicates the cluster result, distance column indicates the distance from the center and -all the columns of input table.`, nil}, - "excluded_columns": {attribute.String, "", `[default=""] -excluded the special feature columns from the SELECT statement.`, nil}, -} +all the columns of input table.`, nil). + String("excluded_columns", "", `[default=""] +excluded the special feature columns from the SELECT statement.`, nil) // InitializeKMeansAttributes initializes the attributes of KMeans and does type checking for them func InitializeKMeansAttributes(trainStmt *ir.TrainStmt) error { - kmeansAttributes.FillDefaults(trainStmt.Attributes) + kmeansAttributes.ExportDefaults(trainStmt.Attributes) return kmeansAttributes.Validate(trainStmt.Attributes) } @@ -55,7 +54,7 @@ func parseExcludedColsMap(attrs map[string]interface{}) map[string]int { } func getTrainKMeansPAICmd(ir *ir.TrainStmt, session *pb.Session) (string, error) { - kmeansAttributes.FillDefaults(ir.Attributes) + kmeansAttributes.ExportDefaults(ir.Attributes) if e := kmeansAttributes.Validate(ir.Attributes); e != nil { return "", e } diff --git a/pkg/codegen/tensorflow/codegen.go b/pkg/codegen/tensorflow/codegen.go index a7721d3f27..aa888a340c 100644 --- a/pkg/codegen/tensorflow/codegen.go +++ b/pkg/codegen/tensorflow/codegen.go @@ -28,46 +28,45 @@ import ( pb "sqlflow.org/sqlflow/pkg/proto" ) -var commonAttributes = attribute.Dictionary{ - "train.batch_size": {attribute.Int, 1, `[default=1] +var commonAttributes = attribute.Dictionary{}. + Int("train.batch_size", 1, `[default=1] The training batch size. -range: [1,Infinity]`, attribute.IntLowerBoundChecker(1, true)}, - "train.epoch": {attribute.Int, 1, `[default=1] +range: [1,Infinity]`, attribute.IntLowerBoundChecker(1, true)). + Int("train.epoch", 1, `[default=1] Number of epochs the training will run. -range: [1, Infinity]`, attribute.IntLowerBoundChecker(1, true)}, - "train.verbose": {attribute.Int, 0, `[default=0] +range: [1, Infinity]`, attribute.IntLowerBoundChecker(1, true)). + Int("train.verbose", 0, `[default=0] Show verbose logs when training. -possible values: 0, 1, 2`, attribute.IntChoicesChecker(0, 1, 2)}, - "train.max_steps": {attribute.Int, 0, `[default=0] -Max steps to run training.`, attribute.IntLowerBoundChecker(0, true)}, - "train.save_checkpoints_steps": {attribute.Int, 100, `[default=100] -Steps to run between saving checkpoints.`, attribute.IntLowerBoundChecker(1, true)}, - "train.log_every_n_iter": {attribute.Int, 10, `[default=10] -Print logs every n iterations`, attribute.IntLowerBoundChecker(1, true)}, - "validation.start_delay_secs": {attribute.Int, 0, `[default=0] -Seconds to wait before starting validation.`, attribute.IntLowerBoundChecker(0, true)}, - "validation.throttle_secs": {attribute.Int, 0, `[default=0] -Seconds to wait when need to run validation again.`, attribute.IntLowerBoundChecker(0, true)}, - "validation.metrics": {attribute.String, "Accuracy", `[default=""] +possible values: 0, 1, 2`, attribute.IntChoicesChecker(0, 1, 2)). + Int("train.max_steps", 0, `[default=0] +Max steps to run training.`, attribute.IntLowerBoundChecker(0, true)). + Int("train.save_checkpoints_steps", 100, `[default=100] +Steps to run between saving checkpoints.`, attribute.IntLowerBoundChecker(1, true)). + Int("train.log_every_n_iter", 10, `[default=10] +Print logs every n iterations`, attribute.IntLowerBoundChecker(1, true)). + Int("validation.start_delay_secs", 0, `[default=0] +Seconds to wait before starting validation.`, attribute.IntLowerBoundChecker(0, true)). + Int("validation.throttle_secs", 0, `[default=0] +Seconds to wait when need to run validation again.`, attribute.IntLowerBoundChecker(0, true)). + String("validation.metrics", "Accuracy", `[default=""] Specify metrics when training and evaluating. -example: "Accuracy,AUC"`, nil}, - "validation.select": {attribute.String, "", `[default=""] +example: "Accuracy,AUC"`, nil). + String("validation.select", "", `[default=""] Specify the dataset for validation. -example: "SELECT * FROM iris.train LIMIT 100"`, nil}, - "validation.steps": {attribute.Int, 1, `[default=1] -Specify steps for validation.`, attribute.IntLowerBoundChecker(1, true)}, -} -var distributedTrainingAttributes = attribute.Dictionary{ - "train.num_ps": {attribute.Int, 0, "", nil}, - "train.num_workers": {attribute.Int, 1, "", nil}, - "train.worker_cpu": {attribute.Int, 400, "", nil}, - "train.worker_gpu": {attribute.Int, 0, "", nil}, - "train.ps_cpu": {attribute.Int, 200, "", nil}, - "train.ps_gpu": {attribute.Int, 0, "", nil}, - "train.num_evaluator": {attribute.Int, 0, "", nil}, - "train.evaluator_cpu": {attribute.Int, 200, "", nil}, - "train.evaluator_gpu": {attribute.Int, 0, "", nil}, -} +example: "SELECT * FROM iris.train LIMIT 100"`, nil). + Int("validation.steps", 1, `[default=1] +Specify steps for validation.`, attribute.IntLowerBoundChecker(1, true)) + +var distributedTrainingAttributes = attribute.Dictionary{}. + Int("train.num_ps", 0, "", nil). + Int("train.num_workers", 1, "", nil). + Int("train.worker_cpu", 400, "", nil). + Int("train.worker_gpu", 0, "", nil). + Int("train.ps_cpu", 200, "", nil). + Int("train.ps_gpu", 0, "", nil). + Int("train.num_evaluator", 0, "", nil). + Int("train.evaluator_cpu", 200, "", nil). + Int("train.evaluator_gpu", 0, "", nil) func attrToPythonValue(attr interface{}) string { switch attr.(type) { @@ -232,7 +231,7 @@ func constructLosses(trainStmt *ir.TrainStmt) { // InitializeAttributes initializes the attributes of TensorFlow and does type checking for them func InitializeAttributes(trainStmt *ir.TrainStmt) error { attribute.ExtractSymbolOnce() - commonAttributes.FillDefaults(trainStmt.Attributes) + commonAttributes.ExportDefaults(trainStmt.Attributes) modelAttr := attribute.NewDictionaryFromModelDefinition(trainStmt.Estimator, "model.") // TODO(shendiaomo): Restrict optimizer parameters to the available set @@ -240,16 +239,17 @@ func InitializeAttributes(trainStmt *ir.TrainStmt) error { constructLosses(trainStmt) if len(modelAttr) == 0 { // TODO(shendiaomo): Use the same mechanism as `sqlflow_models` to extract parameters automatically - // Unknown custom models - modelAttr.Update(attribute.Dictionary{"model.*": {attribute.Unknown, nil, "Any model parameters defined in custom models", nil}}) + // unknownType custom models + modelAttr.Update(attribute.Dictionary{}. + Unknown("model.*", nil, "Any model parameters defined in custom models", nil)) } attrValidator := modelAttr.Update(commonAttributes) if strings.HasPrefix(trainStmt.Estimator, "sqlflow_models.") { // Special attributes defined as global variables in `sqlflow_models` - modelAttr.Update(attribute.Dictionary{ - "model.optimizer": {attribute.Unknown, nil, "Specify optimizer", nil}, - "model.loss": {attribute.Unknown, nil, "Specify loss", nil}, - "model.*": {attribute.Unknown, nil, "Any model parameters defined in custom models", nil}}) + modelAttr.Update(attribute.Dictionary{}. + Unknown("model.optimizer", nil, "Specify optimizer", nil). + Unknown("model.loss", nil, "Specify loss", nil). + Unknown("model.*", nil, "Any model parameters defined in custom models", nil)) } if IsPAI() { modelAttr.Update(distributedTrainingAttributes) diff --git a/pkg/codegen/xgboost/codegen.go b/pkg/codegen/xgboost/codegen.go index 4b300f9406..5081dcdf8b 100644 --- a/pkg/codegen/xgboost/codegen.go +++ b/pkg/codegen/xgboost/codegen.go @@ -17,12 +17,10 @@ import ( "bytes" "encoding/json" "fmt" - "regexp" "strings" - "sqlflow.org/sqlflow/pkg/codegen" - "sqlflow.org/sqlflow/pkg/attribute" + "sqlflow.org/sqlflow/pkg/codegen" tf "sqlflow.org/sqlflow/pkg/codegen/tensorflow" "sqlflow.org/sqlflow/pkg/ir" pb "sqlflow.org/sqlflow/pkg/proto" @@ -38,31 +36,31 @@ func getXGBoostObjectives() (ret []string) { // TODO(tony): complete model parameter and training parameter list // model parameter list: https://xgboost.readthedocs.io/en/latest/parameter.html#general-parameters // training parameter list: https://github.com/dmlc/xgboost/blob/b61d53447203ca7a321d72f6bdd3f553a3aa06c4/python-package/xgboost/training.py#L115-L117 -var attributeDictionary = attribute.Dictionary{ - "eta": {attribute.Float, float32(0.3), `[default=0.3, alias: learning_rate] +var attributeDictionary = attribute.Dictionary{}. + Float("eta", float32(0.3), `[default=0.3, alias: learning_rate] Step size shrinkage used in update to prevents overfitting. After each boosting step, we can directly get the weights of new features, and eta shrinks the feature weights to make the boosting process more conservative. -range: [0,1]`, attribute.Float32RangeChecker(0, 1, true, true)}, - "num_class": {attribute.Int, nil, `Number of classes. -range: [2, Infinity]`, attribute.IntLowerBoundChecker(2, true)}, - "objective": {attribute.String, nil, `Learning objective`, attribute.StringChoicesChecker(getXGBoostObjectives()...)}, - "eval_metric": {attribute.String, nil, `eval metric`, nil}, - "train.disk_cache": {attribute.Bool, false, `whether use external memory to cache train data`, nil}, - "train.num_boost_round": {attribute.Int, 10, `[default=10] +range: [0,1]`, attribute.Float32RangeChecker(0, 1, true, true)). + Int("num_class", nil, `Number of classes. +range: [2, Infinity]`, attribute.IntLowerBoundChecker(2, true)). + String("objective", nil, `Learning objective`, attribute.StringChoicesChecker(getXGBoostObjectives()...)). + String("eval_metric", nil, `eval metric`, nil). + Bool("train.disk_cache", false, `whether use external memory to cache train data`, nil). + Int("train.num_boost_round", 10, `[default=10] The number of rounds for boosting. -range: [1, Infinity]`, attribute.IntLowerBoundChecker(1, true)}, - "train.batch_size": {attribute.Int, -1, `[default=-1] +range: [1, Infinity]`, attribute.IntLowerBoundChecker(1, true)). + Int("train.batch_size", -1, `[default=-1] Batch size for each iteration, -1 means use all data at once. -range: [-1, Infinity]`, attribute.IntLowerBoundChecker(-1, true)}, - "train.epoch": {attribute.Int, 1, `[default=1] +range: [-1, Infinity]`, attribute.IntLowerBoundChecker(-1, true)). + Int("train.epoch", 1, `[default=1] Number of rounds to run the training. -range: [1, Infinity]`, attribute.IntLowerBoundChecker(1, true)}, - "validation.select": {attribute.String, "", `[default=""] +range: [1, Infinity]`, attribute.IntLowerBoundChecker(1, true)). + String("validation.select", "", `[default=""] Specify the dataset for validation. -example: "SELECT * FROM boston.train LIMIT 8"`, nil}, - "train.num_workers": {attribute.Int, 1, `[default=1] +example: "SELECT * FROM boston.train LIMIT 8"`, nil). + Int("train.num_workers", 1, `[default=1] Number of workers for distributed train, 1 means stand-alone mode. -range: [1, 128]`, attribute.IntRangeChecker(1, 128, true, true)}, -} +range: [1, 128]`, attribute.IntRangeChecker(1, 128, true, true)) + var fullAttrValidator = attribute.Dictionary{} func objectiveChecker(obj interface{}) error { @@ -135,7 +133,7 @@ func resolveModelParams(ir *ir.TrainStmt) error { // InitializeAttributes initializes the attributes of XGBoost and does type checking for them func InitializeAttributes(trainStmt *ir.TrainStmt) error { - attributeDictionary.FillDefaults(trainStmt.Attributes) + attributeDictionary.ExportDefaults(trainStmt.Attributes) return fullAttrValidator.Validate(trainStmt.Attributes) } @@ -446,23 +444,7 @@ func Evaluate(evalStmt *ir.EvaluateStmt, session *pb.Session) (string, error) { } func init() { - re := regexp.MustCompile("[^a-z]") // xgboost.gbtree, xgboost.dart, xgboost.gblinear share the same parameter set fullAttrValidator = attribute.NewDictionaryFromModelDefinition("xgboost.gbtree", "") - for _, v := range fullAttrValidator { - pieces := strings.SplitN(v.Doc, " ", 2) - maybeType := re.ReplaceAllString(pieces[0], "") - if maybeType == strings.ToLower(maybeType) { - switch maybeType { - case "float": - v.Type = attribute.Float - case "int": - v.Type = attribute.Int - case "string": - v.Type = attribute.String - } - v.Doc = pieces[1] - } - } fullAttrValidator.Update(attributeDictionary) } diff --git a/pkg/workflow/couler/katib.go b/pkg/workflow/couler/katib.go index 9a44dbe35b..f14dc1c6ff 100644 --- a/pkg/workflow/couler/katib.go +++ b/pkg/workflow/couler/katib.go @@ -22,19 +22,18 @@ import ( "sqlflow.org/sqlflow/pkg/ir" ) -var attributeDictionary = attribute.Dictionary{ - "eta": {attribute.Float, float32(0.3), `[default=0.3, alias: learning_rate] +var attributeDictionary = attribute.Dictionary{}. + Float("eta", float32(0.3), `[default=0.3, alias: learning_rate] Step size shrinkage used in update to prevents overfitting. After each boosting step, we can directly get the weights of new features, and eta shrinks the feature weights to make the boosting process more conservative. -range: [0,1]`, attribute.Float32RangeChecker(0, 1, true, true)}, - "num_class": {attribute.Int, nil, `Number of classes. -range: [2, Infinity]`, attribute.IntLowerBoundChecker(2, true)}, - "objective": {attribute.String, nil, `Learning objective`, nil}, - "range.num_round": {attribute.IntList, nil, `[ default=[50, 100] ] The range of number of rounds for boosting.`, nil}, - "range.max_depth": {attribute.IntList, nil, `[ default=[2, 8] ] The range of max depth during training.`, nil}, - "validation.select": {attribute.String, nil, `[default=""] +range: [0,1]`, attribute.Float32RangeChecker(0, 1, true, true)). + Int("num_class", nil, `Number of classes. +range: [2, Infinity]`, attribute.IntLowerBoundChecker(2, true)). + String("objective", nil, `Learning objective`, nil). + IntList("range.num_round", nil, `[ default=[50, 100] ] The range of number of rounds for boosting.`, nil). + IntList("range.max_depth", nil, `[ default=[2, 8] ] The range of max depth during training.`, nil). + String("validation.select", nil, `[default=""] Specify the dataset for validation. -example: "SELECT * FROM boston.train LIMIT 8"`, nil}, -} +example: "SELECT * FROM boston.train LIMIT 8"`, nil) func resolveModelType(estimator string) (string, string, error) { switch strings.ToUpper(estimator) {