Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions doc/model_parameter.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ INTO sqlflow_models.my_xgb_regression_model;
<tr>
<td>max_bin</td>
<td>%!s(<nil>)</td>
<td>used if tree_method is set to hist, Maximum number of discrete bins to bucket continuous features.</td>
<td>Only used if tree_method is set to hist, Maximum number of discrete bins to bucket continuous features.</td>
</tr>
<tr>
<td>max_delta_step</td>
Expand Down Expand Up @@ -153,7 +153,7 @@ INTO sqlflow_models.my_xgb_regression_model;
</tr>
<tr>
<td>silent</td>
<td>%!s(<nil>)</td>
<td>bool</td>
<td>Whether to print messages while running boosting. Deprecated. Use verbosity instead.</td>
</tr>
<tr>
Expand Down
183 changes: 115 additions & 68 deletions pkg/attribute/attribute.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"log"
"os/exec"
"reflect"
"regexp"
"sort"
"strings"
"sync"
Expand All @@ -30,29 +31,29 @@ const (
)

var (
// Bool indicates that the corresponding attribute is a boolean
Bool = reflect.TypeOf(true)
// Int indicates that the corresponding attribute is an integer
Int = reflect.TypeOf(0)
// Float indicates that the corresponding attribute is a float32
Float = reflect.TypeOf(float32(0.))
// String indicates the corresponding attribute is a string
String = reflect.TypeOf("")
// IntList indicates the corresponding attribute is a list of integers
IntList = reflect.TypeOf([]int{})
// Unknown type indicates that the attribute type is dynamically determined.
Unknown = reflect.Type(nil)
// boolType indicates that the corresponding attribute is a boolean
boolType = reflect.TypeOf(true)
// intType indicates that the corresponding attribute is an integer
intType = reflect.TypeOf(0)
// floatType indicates that the corresponding attribute is a float32
floatType = reflect.TypeOf(float32(0.))
// stringType indicates the corresponding attribute is a string
stringType = reflect.TypeOf("")
// intListType indicates the corresponding attribute is a list of integers
intListType = reflect.TypeOf([]int{})
// unknownType indicates that the attribute type is dynamically determined.
unknownType = reflect.Type(nil)
)

// Dictionary contains the description of all attributes
type Dictionary map[string]*Description

// Description describes a requirement for a particular attribute
type Description struct {
Type reflect.Type
Default interface{}
Doc string
Checker func(i interface{}) error
type Dictionary map[string]*description

// description describes a requirement for a particular attribute
type description struct {
typ reflect.Type
defaultValue interface{}
doc string
checker func(i interface{}) error
}

// Int declares an attribute of int-typed in Dictionary d.
Expand All @@ -74,25 +75,31 @@ func (d Dictionary) Int(name string, value interface{}, doc string, checker func
}
}

d[name] = &Description{
Type: Int,
Default: value,
Doc: doc,
Checker: interfaceChecker,
d[name] = &description{
typ: intType,
defaultValue: value,
doc: doc,
checker: interfaceChecker,
}
return d
}

// Float declares an attribute of float32-typed in Dictionary d.
func (d Dictionary) Float(name string, value interface{}, doc string, checker func(float32) error) Dictionary {
interfaceChecker := func(v interface{}) error {
var fValue float32
if floatValue, ok := v.(float32); ok {
if checker != nil {
return checker(floatValue)
}
return nil
fValue = floatValue
} else if intValue, ok := v.(int); ok { // implicit type conversion from int to float
fValue = float32(intValue)
} else {
return fmt.Errorf("attribute %s must be of type float, but got %T", name, v)
}

if checker != nil {
return checker(fValue)
}
return fmt.Errorf("attribute %s must be of type float, but got %T", name, v)
return nil
}

if value != nil {
Expand All @@ -102,11 +109,20 @@ func (d Dictionary) Float(name string, value interface{}, doc string, checker fu
}
}

d[name] = &Description{
Type: Float,
Default: value,
Doc: doc,
Checker: interfaceChecker,
var fInterfaceValue interface{}
if value == nil {
fInterfaceValue = nil
} else if floatValue, ok := value.(float32); ok {
fInterfaceValue = floatValue
} else if intValue, ok := value.(int); ok { // implicit type conversion from int to float
fInterfaceValue = float32(intValue)
}

d[name] = &description{
typ: floatType,
defaultValue: fInterfaceValue,
doc: doc,
checker: interfaceChecker,
}
return d
}
Expand All @@ -130,11 +146,11 @@ func (d Dictionary) Bool(name string, value interface{}, doc string, checker fun
}
}

d[name] = &Description{
Type: Bool,
Default: value,
Doc: doc,
Checker: interfaceChecker,
d[name] = &description{
typ: boolType,
defaultValue: value,
doc: doc,
checker: interfaceChecker,
}
return d
}
Expand All @@ -158,11 +174,11 @@ func (d Dictionary) String(name string, value interface{}, doc string, checker f
}
}

d[name] = &Description{
Type: String,
Default: value,
Doc: doc,
Checker: interfaceChecker,
d[name] = &description{
typ: stringType,
defaultValue: value,
doc: doc,
checker: interfaceChecker,
}
return d
}
Expand All @@ -186,11 +202,11 @@ func (d Dictionary) IntList(name string, value interface{}, doc string, checker
}
}

d[name] = &Description{
Type: IntList,
Default: value,
Doc: doc,
Checker: interfaceChecker,
d[name] = &description{
typ: intListType,
defaultValue: value,
doc: doc,
checker: interfaceChecker,
}
return d
}
Expand All @@ -204,28 +220,28 @@ func (d Dictionary) Unknown(name string, value interface{}, doc string, checker
}
}

d[name] = &Description{
Type: Unknown,
Default: value,
Doc: doc,
Checker: checker,
d[name] = &description{
typ: unknownType,
defaultValue: value,
doc: doc,
checker: checker,
}
return d
}

// FillDefaults fills default values defined in Dictionary to attrs.
func (d Dictionary) FillDefaults(attrs map[string]interface{}) {
// ExportDefaults exports default values defined in Dictionary to attrs.
func (d Dictionary) ExportDefaults(attrs map[string]interface{}) {
for k, v := range d {
// Do not fill default value for unknown type, and with nil default values.
if v.Type == Unknown {
if v.typ == unknownType {
continue
}
if v.Default == nil {
if v.defaultValue == nil {
continue
}
_, ok := attrs[k]
if !ok {
attrs[k] = v.Default
attrs[k] = v.defaultValue
}
}
}
Expand All @@ -235,7 +251,7 @@ func (d Dictionary) FillDefaults(attrs map[string]interface{}) {
// 2. Customer checker
func (d Dictionary) Validate(attrs map[string]interface{}) error {
for k, v := range attrs {
var desc *Description
var desc *description
desc, ok := d[k]
if !ok {
// Support attribute definition like "model.*" to match
Expand All @@ -254,15 +270,15 @@ func (d Dictionary) Validate(attrs map[string]interface{}) error {
}
}

if desc.Type != Unknown && desc.Type != reflect.TypeOf(v) {
if desc.typ != unknownType && desc.typ != reflect.TypeOf(v) {
// Allow implicit conversion from int to float to ease typing
if !(desc.Type == Float && reflect.TypeOf(v) == Int) {
return fmt.Errorf(errUnexpectedType, k, desc.Type, v)
if !(desc.typ == floatType && reflect.TypeOf(v) == intType) {
return fmt.Errorf(errUnexpectedType, k, desc.typ, v)
}
}

if desc.Checker != nil {
if err := desc.Checker(v); err != nil {
if desc.checker != nil {
if err := desc.checker(v); err != nil {
return err
}
}
Expand Down Expand Up @@ -293,7 +309,7 @@ func (d Dictionary) GenerateTableInHTML() string {
<td>%s</td>
</tr>`
// NOTE(tony): if the doc string has multiple lines, need to replace \n with <br>
s := fmt.Sprintf(t, k, desc.Type, strings.Replace(desc.Doc, "\n", `<br>`, -1))
s := fmt.Sprintf(t, k, desc.typ, strings.Replace(desc.doc, "\n", `<br>`, -1))
l = append(l, s)
}

Expand All @@ -311,9 +327,40 @@ func (d Dictionary) Update(other Dictionary) Dictionary {

// NewDictionaryFromModelDefinition create a new Dictionary according to pre-made estimators or XGBoost model types.
func NewDictionaryFromModelDefinition(estimator, prefix string) Dictionary {
isXGBoostModel := strings.HasPrefix(estimator, "xgboost")
re := regexp.MustCompile("[^a-z]")

var d = Dictionary{}
for param, doc := range PremadeModelParamsDocs[estimator] {
d[prefix+param] = &Description{Unknown, nil, doc, nil}
desc := &description{unknownType, nil, doc, nil}
d[prefix+param] = desc

if !isXGBoostModel {
continue
}

// Fill typ field according to the model parameter doc
// The doc would be like: "int Maximum tree depth for base learners"
pieces := strings.SplitN(strings.TrimSpace(desc.doc), " ", 2)
if len(pieces) != 2 {
continue
}

maybeType := re.ReplaceAllString(pieces[0], "")
switch strings.ToLower(maybeType) {
case "float":
desc.typ = floatType
desc.doc = pieces[1]
case "int":
desc.typ = intType
desc.doc = pieces[1]
case "string":
desc.typ = stringType
desc.doc = pieces[1]
case "boolean":
desc.typ = boolType
desc.doc = pieces[1]
}
}
return d
}
Expand Down
22 changes: 9 additions & 13 deletions pkg/attribute/attribute_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,17 +123,13 @@ func TestDictionaryNamedTypeChecker(t *testing.T) {
func TestDictionaryValidate(t *testing.T) {
a := assert.New(t)

checker := func(i interface{}) error {
ii, ok := i.(int)
if !ok {
return fmt.Errorf("%T %v should of type integer", i, i)
}
if ii < 0 {
checker := func(i int) error {
if i < 0 {
return fmt.Errorf("some error")
}
return nil
}
tb := Dictionary{"a": {Int, 1, "attribute a", checker}, "b": {Float, 1, "attribute b", nil}}
tb := Dictionary{}.Int("a", 1, "attribute a", checker).Float("b", float32(1), "attribute b", nil)
a.NoError(tb.Validate(map[string]interface{}{"a": 1}))
a.EqualError(tb.Validate(map[string]interface{}{"a": -1}), "some error")
a.EqualError(tb.Validate(map[string]interface{}{"_a": -1}), fmt.Sprintf(errUnsupportedAttribute, "_a"))
Expand Down Expand Up @@ -165,7 +161,7 @@ func TestParamsDocs(t *testing.T) {
func TestNewAndUpdateDictionary(t *testing.T) {
a := assert.New(t)

commonAttrs := Dictionary{"a": {Int, 1, "attribute a", nil}}
commonAttrs := Dictionary{}.Int("a", 1, "attribute a", nil)
specificAttrs := NewDictionaryFromModelDefinition("DNNClassifier", "model.")
a.Equal(len(specificAttrs), 12)
a.Equal(len(specificAttrs.Update(specificAttrs)), 12)
Expand All @@ -180,12 +176,12 @@ func TestNewAndUpdateDictionary(t *testing.T) {

func TestDictionary_GenerateTableInHTML(t *testing.T) {
a := assert.New(t)
tb := Dictionary{
"a": {Int, 1, `this is a
tb := Dictionary{}.
Int("a", 1, `this is a
multiple line
doc string.`, nil},
"世界": {String, "", `42`, nil},
}
doc string.`, nil).
String("世界", "", `42`, nil)

expected := `<table>
<tr>
<td>Name</td>
Expand Down
Loading