diff --git a/sql/codegen.go b/sql/codegen.go
index bdbb7e4228..707894cd7c 100644
--- a/sql/codegen.go
+++ b/sql/codegen.go
@@ -14,6 +14,7 @@ package sql
 
 import (
+	"encoding/json"
 	"fmt"
 	"io"
 	"strings"
@@ -163,19 +164,21 @@ func newFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *D
 				_, ok = col.(*columns.EmbeddingColumn).CategoryColumn.(*columns.SequenceCategoryIDColumn)
 			}
 		}
-		if !ok && col.GetDelimiter() != "" {
-			if _, ok := col.(*columns.NumericColumn); !ok {
-				isSparse = true
-			}
-		}
-		fm := &FeatureMeta{
-			FeatureName: col.GetKey(),
-			Dtype:       col.GetDtype(),
-			Delimiter:   col.GetDelimiter(),
-			InputShape:  col.GetInputShape(),
-			IsSparse:    isSparse,
+
+		fieldMetas := col.GetFieldMetas()
+		if len(fieldMetas) == 0 {
+			return nil, fmt.Errorf("no fieldmeta found for current column: %s", col.GetKey())
 		}
-		r.X = append(r.X, fm)
+		fm := fieldMetas[0]
+		jsonShape, _ := json.Marshal(fm.Shape)
+		r.X = append(r.X, &FeatureMeta{
+			FeatureName: fm.ColumnName,
+			Dtype:       fm.DType,
+			Delimiter:   fm.Delimiter,
+			InputShape:  string(jsonShape),
+			IsSparse:    fm.IsSparse,
+		})
+
 		featureColumnsCode[target] = append(
 			featureColumnsCode[target], feaColCode[0])
diff --git a/sql/codegen_alps.go b/sql/codegen_alps.go
index 05152ee0f1..53a7cd0e27 100644
--- a/sql/codegen_alps.go
+++ b/sql/codegen_alps.go
@@ -77,11 +77,6 @@ type alpsFiller struct {
 	OSSEndpoint string
 }
 
-// type alpsFeatureColumn interface {
-// 	columns.FeatureColumn
-// 	GenerateAlpsCode(metadata *metadata) ([]string, error)
-// }
-
 type alpsBucketCol struct {
 	columns.BucketColumn
 }
@@ -184,7 +179,7 @@ func newALPSTrainFiller(pr *extendedSelect, db *DB, session *pb.Session, ds *tra
 	}
 
 	var odpsConfig = &gomaxcompute.Config{}
-	var columnInfo map[string]*columns.ColumnSpec
+	var columnInfo map[string]*columns.FieldMeta
 
 	// TODO(joyyoj) read feature mapping table's name from table attributes.
 	// TODO(joyyoj) pr may contains partition.
@@ -205,7 +200,7 @@ func newALPSTrainFiller(pr *extendedSelect, db *DB, session *pb.Session, ds *tra
 		meta.columnInfo = &columnInfo
 	} else {
 		meta = metadata{odpsConfig, pr.tables[0], nil, nil}
-		columnInfo = map[string]*columns.ColumnSpec{}
+		columnInfo = map[string]*columns.FieldMeta{}
 		for _, css := range resolved.ColumnSpecs {
 			for _, cs := range css {
 				columnInfo[cs.ColumnName] = cs
@@ -223,7 +218,7 @@ func newALPSTrainFiller(pr *extendedSelect, db *DB, session *pb.Session, ds *tra
 	for _, cs := range columnInfo {
 		csCode = append(csCode, cs.ToString())
 	}
-	y := &columns.ColumnSpec{
+	y := &columns.FieldMeta{
 		ColumnName: pr.label,
 		IsSparse:   false,
 		Shape:      []int{1},
@@ -432,7 +427,7 @@ func alpsPred(w *PipeWriter, pr *extendedSelect, db *DB, cwd string, session *pb
 }
 
 // GenerateCode overrides the member function defined in `category_id_column.go`
-func (cc *alpsCategoryIDCol) GenerateCode(cs *columns.ColumnSpec) ([]string, error) {
+func (cc *alpsCategoryIDCol) GenerateCode(cs *columns.FieldMeta) ([]string, error) {
 	output := make([]string, 0)
 	// columnInfo, present := (*metadata.columnInfo)[cc.Key]
 	columnInfo := cs
@@ -455,7 +450,7 @@ func (cc *alpsCategoryIDCol) GenerateCode(cs *columns.ColumnSpec) ([]string, err
 	return output, err
 }
 
-func (cc *alpsSeqCategoryIDCol) GenerateCode(cs *columns.ColumnSpec) ([]string, error) {
+func (cc *alpsSeqCategoryIDCol) GenerateCode(cs *columns.FieldMeta) ([]string, error) {
 	output := make([]string, 0)
 	// columnInfo, present := (*metadata.columnInfo)[cc.Key]
 	columnInfo := cs
@@ -477,7 +472,7 @@ func (cc *alpsSeqCategoryIDCol) GenerateCode(cs *columns.ColumnSpec) ([]string,
 	return output, err
 }
 
-func (ec *alpsEmbeddingCol) GenerateCode(cs *columns.ColumnSpec) ([]string, error) {
+func (ec *alpsEmbeddingCol) GenerateCode(cs *columns.FieldMeta) ([]string, error) {
 	var output []string
 	catColumn := &alpsCategoryIDCol{*(ec.CategoryColumn.(*columns.CategoryIDColumn))}
 	// if !ok {
@@ -539,11 +534,11 @@ type metadata struct {
 	odpsConfig *gomaxcompute.Config
 	table      string
 	featureMap *columns.FeatureMap
-	columnInfo *map[string]*columns.ColumnSpec
+	columnInfo *map[string]*columns.FieldMeta
 }
 
-func flattenColumnSpec(columnSpecs map[string][]*columns.ColumnSpec) map[string]*columns.ColumnSpec {
-	output := map[string]*columns.ColumnSpec{}
+func flattenColumnSpec(columnSpecs map[string][]*columns.FieldMeta) map[string]*columns.FieldMeta {
+	output := map[string]*columns.FieldMeta{}
 	for _, cols := range columnSpecs {
 		for _, col := range cols {
 			output[col.ColumnName] = col
@@ -552,8 +547,8 @@ func flattenColumnSpec(columnSpecs map[string][]*columns.ColumnSpec) map[string]
 	return output
 }
 
-func (meta *metadata) getColumnInfo(resolved *resolvedTrainClause, fields []string) (map[string]*columns.ColumnSpec, error) {
-	columns := map[string]*columns.ColumnSpec{}
+func (meta *metadata) getColumnInfo(resolved *resolvedTrainClause, fields []string) (map[string]*columns.FieldMeta, error) {
+	columns := map[string]*columns.FieldMeta{}
 	refColumns := flattenColumnSpec(resolved.ColumnSpecs)
 
 	sparseColumns, _ := meta.getSparseColumnInfo()
@@ -638,8 +633,8 @@ func getFields(meta *metadata, pr *extendedSelect) ([]string, error) {
 	return fields, nil
 }
 
-func (meta *metadata) getDenseColumnInfo(keys []string, refColumns map[string]*columns.ColumnSpec) (map[string]*columns.ColumnSpec, error) {
-	output := map[string]*columns.ColumnSpec{}
+func (meta *metadata) getDenseColumnInfo(keys []string, refColumns map[string]*columns.FieldMeta) (map[string]*columns.FieldMeta, error) {
+	output := map[string]*columns.FieldMeta{}
 	fields := strings.Join(keys, ",")
 	query := fmt.Sprintf("SELECT %s FROM %s LIMIT 1", fields, meta.table)
 	sqlDB, _ := sql.Open("maxcompute", meta.odpsConfig.FormatDSN())
@@ -669,7 +664,7 @@ func (meta *metadata) getDenseColumnInfo(keys []string, refColumns map[string]*c
 	shape := make([]int, 1)
 	shape[0] = len(fields)
 	if userSpec, ok := refColumns[ct.Name()]; ok {
-		output[ct.Name()] = &columns.ColumnSpec{
+		output[ct.Name()] = &columns.FieldMeta{
 			ct.Name(),
 			false,
 			shape,
@@ -678,7 +673,7 @@ func (meta *metadata) getDenseColumnInfo(keys []string, refColumns map[string]*c
 			nil,
 			*meta.featureMap}
 	} else {
-		output[ct.Name()] = &columns.ColumnSpec{
+		output[ct.Name()] = &columns.FieldMeta{
 			ct.Name(),
 			false,
 			shape,
@@ -692,8 +687,8 @@ func (meta *metadata) getDenseColumnInfo(keys []string, refColumns map[string]*c
 	return output, nil
 }
 
-func (meta *metadata) getSparseColumnInfo() (map[string]*columns.ColumnSpec, error) {
-	output := map[string]*columns.ColumnSpec{}
+func (meta *metadata) getSparseColumnInfo() (map[string]*columns.FieldMeta, error) {
+	output := map[string]*columns.FieldMeta{}
 
 	sqlDB, _ := sql.Open("maxcompute", meta.odpsConfig.FormatDSN())
 	filter := "feature_type != '' "
@@ -732,7 +727,7 @@ func (meta *metadata) getSparseColumnInfo() (map[string]*columns.ColumnSpec, err
 		column, present := output[*name]
 		if !present {
 			shape := make([]int, 0, 1000)
-			column := &columns.ColumnSpec{*name, true, shape, "int64", "", nil, *meta.featureMap}
+			column := &columns.FieldMeta{*name, true, shape, "int64", "", nil, *meta.featureMap}
 			column.DType = "int64"
 			output[*name] = column
 		}
diff --git a/sql/columns/bucket_column.go b/sql/columns/bucket_column.go
index 053c1dfb07..d97724f524 100644
--- a/sql/columns/bucket_column.go
+++ b/sql/columns/bucket_column.go
@@ -20,12 +20,13 @@ import (
 
 // BucketColumn is the wrapper of `tf.feature_column.bucketized_column`
 type BucketColumn struct {
+	FeatureColumnMetasImpl
 	SourceColumn *NumericColumn
 	Boundaries   []int
 }
 
 // GenerateCode implements the FeatureColumn interface.
-func (bc *BucketColumn) GenerateCode(cs *ColumnSpec) ([]string, error) {
+func (bc *BucketColumn) GenerateCode(cs *FieldMeta) ([]string, error) {
 	sourceCode, _ := bc.SourceColumn.GenerateCode(cs)
 	if len(sourceCode) > 1 {
 		return []string{}, fmt.Errorf("does not support grouped column: %v", sourceCode)
@@ -41,21 +42,6 @@ func (bc *BucketColumn) GetKey() string {
 	return bc.SourceColumn.Key
 }
 
-// GetDelimiter implements the FeatureColumn interface.
-func (bc *BucketColumn) GetDelimiter() string {
-	return ""
-}
-
-// GetDtype implements the FeatureColumn interface.
-func (bc *BucketColumn) GetDtype() string {
-	return ""
-}
-
-// GetInputShape implements the FeatureColumn interface.
-func (bc *BucketColumn) GetInputShape() string {
-	return bc.SourceColumn.GetInputShape()
-}
-
 // GetColumnType implements the FeatureColumn interface.
 func (bc *BucketColumn) GetColumnType() int {
 	return ColumnTypeBucket
diff --git a/sql/columns/category_id_column.go b/sql/columns/category_id_column.go
index 403695cdd8..512094225a 100644
--- a/sql/columns/category_id_column.go
+++ b/sql/columns/category_id_column.go
@@ -19,6 +19,7 @@ import (
 
 // CategoryIDColumn is the wrapper of `tf.feature_column.categorical_column_with_identity`
 type CategoryIDColumn struct {
+	FeatureColumnMetasImpl
 	Key        string
 	BucketSize int
 	Delimiter  string
@@ -28,6 +29,7 @@ type CategoryIDColumn struct {
 // SequenceCategoryIDColumn is the wrapper of `tf.feature_column.sequence_categorical_column_with_identity`
 // NOTE: only used in tf >= 2.0 versions.
 type SequenceCategoryIDColumn struct {
+	FeatureColumnMetasImpl
 	Key        string
 	BucketSize int
 	Delimiter  string
@@ -35,7 +37,7 @@ type SequenceCategoryIDColumn struct {
 }
 
 // GenerateCode implements the FeatureColumn interface.
-func (cc *CategoryIDColumn) GenerateCode(cs *ColumnSpec) ([]string, error) {
+func (cc *CategoryIDColumn) GenerateCode(cs *FieldMeta) ([]string, error) {
 	return []string{fmt.Sprintf("tf.feature_column.categorical_column_with_identity(key=\"%s\", num_buckets=%d)", cc.Key, cc.BucketSize)}, nil
 }
 
@@ -45,63 +47,23 @@ func (cc *CategoryIDColumn) GetKey() string {
 	return cc.Key
 }
 
-// GetDelimiter implements the FeatureColumn interface.
-func (cc *CategoryIDColumn) GetDelimiter() string {
-	return cc.Delimiter
-}
-
-// GetDtype implements the FeatureColumn interface.
-func (cc *CategoryIDColumn) GetDtype() string {
-	return cc.Dtype
-}
-
-// GetInputShape implements the FeatureColumn interface.
-func (cc *CategoryIDColumn) GetInputShape() string {
-	return fmt.Sprintf("[%d]", cc.BucketSize)
-}
-
 // GetColumnType implements the FeatureColumn interface.
 func (cc *CategoryIDColumn) GetColumnType() int {
 	return ColumnTypeCategoryID
 }
 
 // GenerateCode implements the FeatureColumn interface.
-func (cc *SequenceCategoryIDColumn) GenerateCode(cs *ColumnSpec) ([]string, error) {
+func (cc *SequenceCategoryIDColumn) GenerateCode(cs *FieldMeta) ([]string, error) {
 	return []string{fmt.Sprintf("tf.feature_column.sequence_categorical_column_with_identity(key=\"%s\", num_buckets=%d)", cc.Key, cc.BucketSize)}, nil
 }
 
-// GetDelimiter implements the FeatureColumn interface.
-func (cc *SequenceCategoryIDColumn) GetDelimiter() string {
-	return cc.Delimiter
-}
-
-// GetDtype implements the FeatureColumn interface.
-func (cc *SequenceCategoryIDColumn) GetDtype() string {
-	return cc.Dtype
-}
-
 // GetKey implements the FeatureColumn interface.
 func (cc *SequenceCategoryIDColumn) GetKey() string {
 	return cc.Key
 }
 
-// GetInputShape implements the FeatureColumn interface.
-func (cc *SequenceCategoryIDColumn) GetInputShape() string {
-	return fmt.Sprintf("[%d]", cc.BucketSize)
-}
-
 // GetColumnType implements the FeatureColumn interface.
 func (cc *SequenceCategoryIDColumn) GetColumnType() int {
 	return ColumnTypeSeqCategoryID
 }
-
-// func parseCategoryColumnKey(el *exprlist) (*columnSpec, error) {
-// 	if (*el)[1].typ == 0 {
-// 		// explist, maybe DENSE/SPARSE expressions
-// 		subExprList := (*el)[1].sexp
-// 		isSparse := subExprList[0].val == sparse
-// 		return resolveColumnSpec(&subExprList, isSparse)
-// 	}
-// 	return nil, nil
-// }
diff --git a/sql/columns/column_spec.go b/sql/columns/column_spec.go
index acf5145d60..5c1175ef47 100644
--- a/sql/columns/column_spec.go
+++ b/sql/columns/column_spec.go
@@ -25,19 +25,19 @@ type FeatureMap struct {
 	Partition string
 }
 
-// ColumnSpec defines how to generate codes to parse column data to tensor/sparsetensor
-type ColumnSpec struct {
+// FieldMeta defines how to generate code to parse column data into a tensor/sparse tensor
+type FieldMeta struct {
 	ColumnName string
 	IsSparse   bool
 	Shape      []int
 	DType      string
 	Delimiter  string
 	Vocabulary map[string]string // use a map to generate a list without duplication
-	FeatureMap FeatureMap
+	FeatureMap FeatureMap // FeatureMap is a table that describes how to parse the column data; only used in codegen_alps
 }
 
-// ToString generates the debug string of ColumnSpec
-func (cs *ColumnSpec) ToString() string {
+// ToString generates the debug string of FieldMeta
+func (cs *FieldMeta) ToString() string {
 	if cs.IsSparse {
 		shape := strings.Join(strings.Split(fmt.Sprint(cs.Shape), " "), ",")
 		if len(cs.Shape) > 1 {
diff --git a/sql/columns/cross_column.go b/sql/columns/cross_column.go
index cb5ece31d7..a47ba63afb 100644
--- a/sql/columns/cross_column.go
+++ b/sql/columns/cross_column.go
@@ -21,12 +21,13 @@ import (
 
 // CrossColumn is the wapper of `tf.feature_column.crossed_column`
 // TODO(uuleon) specify the hash_key if needed
 type CrossColumn struct {
+	FeatureColumnMetasImpl
 	Keys           []interface{}
 	HashBucketSize int
 }
 
 // GenerateCode implements the FeatureColumn interface.
-func (cc *CrossColumn) GenerateCode(cs *ColumnSpec) ([]string, error) {
+func (cc *CrossColumn) GenerateCode(cs *FieldMeta) ([]string, error) {
 	var keysGenerated = make([]string, len(cc.Keys))
 	for idx, key := range cc.Keys {
 		if c, ok := key.(FeatureColumn); ok {
@@ -57,23 +58,6 @@ func (cc *CrossColumn) GetKey() string {
 	return ""
 }
 
-// GetDelimiter implements the FeatureColumn interface.
-func (cc *CrossColumn) GetDelimiter() string {
-	return ""
-}
-
-// GetDtype implements the FeatureColumn interface.
-func (cc *CrossColumn) GetDtype() string {
-	return ""
-}
-
-// GetInputShape implements the FeatureColumn interface.
-func (cc *CrossColumn) GetInputShape() string {
-	// NOTE: return empty since crossed column input shape is already determined
-	// by the two crossed columns.
-	return ""
-}
-
 // GetColumnType implements the FeatureColumn interface.
 func (cc *CrossColumn) GetColumnType() int {
 	return ColumnTypeCross
diff --git a/sql/columns/embedding_column.go b/sql/columns/embedding_column.go
index d5cd48213c..8b0eb8fd6a 100644
--- a/sql/columns/embedding_column.go
+++ b/sql/columns/embedding_column.go
@@ -19,6 +19,7 @@ import (
 
 // EmbeddingColumn is the wrapper of `tf.feature_column.embedding_column`
 type EmbeddingColumn struct {
+	FeatureColumnMetasImpl
 	Key            string // only used when CategoryColumn = nil, feature derivation will fill up the details
 	CategoryColumn interface{}
 	Dimension      int
@@ -34,28 +35,13 @@ func (ec *EmbeddingColumn) GetKey() string {
 	return ec.Key
 }
 
-// GetDelimiter implements the FeatureColumn interface.
-func (ec *EmbeddingColumn) GetDelimiter() string {
-	return ec.CategoryColumn.(FeatureColumn).GetDelimiter()
-}
-
-// GetDtype implements the FeatureColumn interface.
-func (ec *EmbeddingColumn) GetDtype() string {
-	return ec.CategoryColumn.(FeatureColumn).GetDtype()
-}
-
-// GetInputShape implements the FeatureColumn interface.
-func (ec *EmbeddingColumn) GetInputShape() string {
-	return ec.CategoryColumn.(FeatureColumn).GetInputShape()
-}
-
 // GetColumnType implements the FeatureColumn interface.
 func (ec *EmbeddingColumn) GetColumnType() int {
 	return ColumnTypeEmbedding
 }
 
 // GenerateCode implements the FeatureColumn interface.
-func (ec *EmbeddingColumn) GenerateCode(cs *ColumnSpec) ([]string, error) {
+func (ec *EmbeddingColumn) GenerateCode(cs *FieldMeta) ([]string, error) {
 	catColumn, ok := ec.CategoryColumn.(FeatureColumn)
 	if !ok {
 		return []string{}, fmt.Errorf("embedding generate code error, input is not featureColumn: %s", ec.CategoryColumn)
diff --git a/sql/columns/feature_column.go b/sql/columns/feature_column.go
index 3dc9b50be7..b418786c93 100644
--- a/sql/columns/feature_column.go
+++ b/sql/columns/feature_column.go
@@ -31,17 +31,40 @@ const (
 // FeatureColumn is an interface that all types of feature columns
 // should follow. featureColumn is used to generate feature column code.
 type FeatureColumn interface {
-	// NOTE: submitters need to know the columnSpec when generating
+	FeatureColumnMetas
+	// NOTE: submitters need to know the FieldMeta when generating
 	// feature_column code. And we maybe use one compound column's data to generate
 	// multiple feature columns, so return a list of strings.
-	GenerateCode(cs *ColumnSpec) ([]string, error)
+	GenerateCode(cs *FieldMeta) ([]string, error)
 	GetKey() string
 	// FIXME(typhoonzero): remove delimiter, dtype shape from feature column
 	// get these from column spec claused or by feature derivation.
-	GetDelimiter() string
-	GetDtype() string
-	GetInputShape() string
+	// GetDelimiter() string
+	// GetDtype() string
+	// GetInputShape() string
 	GetColumnType() int
 }
+
+// FeatureColumnMetas is the interface of the FieldMeta list, embedded in the FeatureColumn interface.
+type FeatureColumnMetas interface {
+	GetFieldMetas() []*FieldMeta
+	AppendFieldMetas(fm *FieldMeta)
+}
+
+// FeatureColumnMetasImpl is a general struct that all feature column structs should embed.
+type FeatureColumnMetasImpl struct {
+	// one feature column may use more than one field as input data, like cross_column
+	FieldMetas []*FieldMeta
+}
+
+// GetFieldMetas returns the FieldMeta list
+func (fcm *FeatureColumnMetasImpl) GetFieldMetas() []*FieldMeta {
+	return fcm.FieldMetas
+}
+
+// AppendFieldMetas appends a new FieldMeta
+func (fcm *FeatureColumnMetasImpl) AppendFieldMetas(fm *FieldMeta) {
+	fcm.FieldMetas = append(fcm.FieldMetas, fm)
+}
diff --git a/sql/columns/numeric_column.go b/sql/columns/numeric_column.go
index 399e5712fd..718546386d 100644
--- a/sql/columns/numeric_column.go
+++ b/sql/columns/numeric_column.go
@@ -14,33 +14,26 @@ package columns
 
 import (
-	"encoding/json"
 	"fmt"
 	"strings"
 )
 
 // NumericColumn is the wrapper of `tf.feature_column.numeric_column`
 type NumericColumn struct {
-	Key       string
-	Shape     []int
-	Delimiter string
-	Dtype     string
+	FeatureColumnMetasImpl
+	Key string
 }
 
 // GenerateCode implements FeatureColumn interface.
-func (nc *NumericColumn) GenerateCode(cs *ColumnSpec) ([]string, error) {
+func (nc *NumericColumn) GenerateCode(cs *FieldMeta) ([]string, error) {
+	var shape []int
+	if len(nc.FieldMetas) > 0 {
+		shape = nc.FieldMetas[0].Shape
+	} else {
+		shape = []int{1}
+	}
 	return []string{fmt.Sprintf("tf.feature_column.numeric_column(\"%s\", shape=%s)", nc.Key,
-		strings.Join(strings.Split(fmt.Sprint(nc.Shape), " "), ","))}, nil
-}
-
-// GetDelimiter implements FeatureColumn interface.
-func (nc *NumericColumn) GetDelimiter() string {
-	return nc.Delimiter
-}
-
-// GetDtype implements FeatureColumn interface.
-func (nc *NumericColumn) GetDtype() string {
-	return nc.Dtype
+		strings.Join(strings.Split(fmt.Sprint(shape), " "), ","))}, nil
 }
 
 // GetKey implements FeatureColumn interface.
@@ -48,15 +41,6 @@ func (nc *NumericColumn) GetKey() string {
 	return nc.Key
 }
 
-// GetInputShape implements FeatureColumn interface.
-func (nc *NumericColumn) GetInputShape() string {
-	jsonBytes, err := json.Marshal(nc.Shape)
-	if err != nil {
-		return ""
-	}
-	return string(jsonBytes)
-}
-
 // GetColumnType implements FeatureColumn interface.
 func (nc *NumericColumn) GetColumnType() int {
 	return ColumnTypeNumeric
diff --git a/sql/expression_resolver.go b/sql/expression_resolver.go
index e2146f1bda..b84274f361 100644
--- a/sql/expression_resolver.go
+++ b/sql/expression_resolver.go
@@ -64,11 +64,9 @@ type resolvedTrainClause struct {
 	EvalThrottle                  int
 	EvalCheckpointFilenameForInit string
 	FeatureColumns                map[string][]columns.FeatureColumn
-	ColumnSpecs                   map[string][]*columns.ColumnSpec
 	EngineParams                  engineSpec
 	CustomModule                  *gitLabModule
 	FeatureColumnInfered          FeatureColumnMap
-	ColumnSpecInfered             ColumnSpecMap
 }
 
 type resolvedPredictClause struct {
@@ -212,7 +210,7 @@ func resolveTrainClause(tc *trainClause, slct *standardSelect, connConfig *conne
 	}
 
 	fcMap := map[string][]columns.FeatureColumn{}
-	csMap := map[string][]*columns.ColumnSpec{}
+	csMap := map[string][]*columns.FieldMeta{}
 	for target, columns := range tc.columns {
 		fcs, css, err := resolveTrainColumns(&columns)
 		if err != nil {
@@ -251,11 +249,9 @@ func resolveTrainClause(tc *trainClause, slct *standardSelect, connConfig *conne
 		EvalThrottle:                  evalThrottleSecs,
 		EvalCheckpointFilenameForInit: evalCheckpointFilenameForInit,
 		FeatureColumns:                fcMap,
-		ColumnSpecs:                   csMap,
 		EngineParams:                  getEngineSpec(engineParams),
 		CustomModule:                  customModel,
 		FeatureColumnInfered:          fcInfered,
-		ColumnSpecInfered:             csInfered,
 	}, nil
 }
 
@@ -288,28 +284,32 @@ func resolvePredictClause(pc *predictClause) (*resolvedPredictClause, error) {
 
 // resolveTrainColumns resolve columns from SQL statement,
 // returns featureColumn list and featureSpecs
-func resolveTrainColumns(columnExprs *exprlist) ([]columns.FeatureColumn, []*columns.ColumnSpec, error) {
+func resolveTrainColumns(columnExprs *exprlist) ([]columns.FeatureColumn, []*columns.FieldMeta, error) {
 	var fcs = make([]columns.FeatureColumn, 0)
-	var css = make([]*columns.ColumnSpec, 0)
+	var css = make([]*columns.FieldMeta, 0)
 	for _, expr := range *columnExprs {
 		if expr.typ != 0 {
 			// Column identifier like "COLUMN a1,b1"
 			c := &columns.NumericColumn{
-				Key:   expr.val,
-				Shape: []int{1},
-				Dtype: "float32",
+				Key: expr.val,
 			}
+			fm := &columns.FieldMeta{
+				ColumnName: expr.val,
+				Shape:      []int{1},
+				DType:      "float32",
+			}
+			c.FieldMetas = append(c.FieldMetas, fm)
 			fcs = append(fcs, c)
 		} else {
-			result, cs, err := resolveColumn(&expr.sexp)
+			result, err := resolveColumn(&expr.sexp)
 			if err != nil {
 				return nil, nil, err
 			}
-			if cs != nil {
-				css = append(css, cs)
+			if fm, ok := result.(*columns.FieldMeta); ok {
+				css = append(css, fm)
 			}
-			if result != nil {
-				fcs = append(fcs, result)
+			if fc, ok := result.(columns.FeatureColumn); ok {
+				fcs = append(fcs, fc)
 			}
 		}
 	}
@@ -320,11 +320,14 @@ func getExpressionFieldName(expr *expr) (string, error) {
 	if expr.typ != 0 {
 		return expr.val, nil
 	}
-	fc, _, err := resolveColumn(&expr.sexp)
+	result, err := resolveColumn(&expr.sexp)
 	if err != nil {
 		return "", err
 	}
-	return fc.GetKey(), nil
+	if fc, ok := result.(columns.FeatureColumn); ok {
+		return fc.GetKey(), nil
+	}
+	return "", fmt.Errorf("expression not a feature column")
 }
 
 // resolveExpression parse the expression recursively and
@@ -334,16 +337,16 @@
 // column_1 -> "column_1", nil, nil
 // [1,2,3,4] -> [1,2,3,4], nil, nil
 // [NUMERIC(col1), col2] -> [*numericColumn, "col2"], nil, nil
-func resolveExpression(e interface{}) (interface{}, interface{}, error) {
+func resolveExpression(e interface{}) (interface{}, error) {
 	if expr, ok := e.(*expr); ok {
 		if expr.typ != 0 {
-			return expr.val, nil, nil
+			return expr.val, nil
 		}
 		return resolveExpression(&expr.sexp)
 	}
 	el, ok := e.(*exprlist)
 	if !ok {
-		return nil, nil, fmt.Errorf("input of resolveExpression must be `expr` or `exprlist` given %s", e)
+		return nil, fmt.Errorf("input of resolveExpression must be `expr` or `exprlist` given %s", e)
 	}
 	headTyp := (*el)[0].typ
 	if headTyp == IDENT {
@@ -351,7 +354,6 @@
 		return resolveColumn(el)
 	} else if headTyp == '[' {
 		var list []interface{}
-		var columnSpecList []interface{}
 		for idx, expr := range *el {
 			if idx > 0 {
 				if expr.sexp == nil {
@@ -363,22 +365,21 @@
 					list = append(list, intVal)
 				}
 			} else {
-				value, cs, err := resolveExpression(&expr.sexp)
+				value, err := resolveExpression(&expr.sexp)
 				if err != nil {
-					return nil, nil, err
+					return nil, err
 				}
 				list = append(list, value)
-				columnSpecList = append(columnSpecList, cs)
 			}
 		}
 	}
-		return list, columnSpecList, nil
+		return list, nil
 	}
-	return nil, nil, fmt.Errorf("not supported expr: %v", el)
+	return nil, fmt.Errorf("not supported expr: %v", el)
 }
 
 func expression2string(e interface{}) (string, error) {
-	resolved, _, err := resolveExpression(e)
+	resolved, err := resolveExpression(e)
 	if err != nil {
 		return "", err
 	}
@@ -417,7 +418,7 @@ func resolveAttribute(attrs *attrs) (map[string]*attribute, error) {
 		if len(subs) == 2 {
 			prefix = subs[0]
 		}
-		r, _, err := resolveExpression(v)
+		r, err := resolveExpression(v)
 		if err != nil {
 			return nil, err
 		}
@@ -440,14 +441,17 @@ func resolveBucketColumn(el *exprlist) (*columns.BucketColumn, error) {
 	if sourceExprList.typ != 0 {
 		return nil, fmt.Errorf("key of BUCKET must be NUMERIC, which is %v", sourceExprList)
 	}
-	source, _, err := resolveColumn(&sourceExprList.sexp)
+	source, err := resolveColumn(&sourceExprList.sexp)
 	if err != nil {
 		return nil, err
 	}
-	if source.GetColumnType() != columns.ColumnTypeNumeric {
-		return nil, fmt.Errorf("key of BUCKET must be NUMERIC, which is %s", source)
+	fc, ok := source.(columns.FeatureColumn)
+	if !ok || fc.GetColumnType() != columns.ColumnTypeNumeric {
+		return nil, fmt.Errorf("key of BUCKET must be NUMERIC, which is %s", source)
 	}
-	boundaries, _, err := resolveExpression(boundariesExprList)
+
+	boundaries, err := resolveExpression(boundariesExprList)
 	if err != nil {
 		return nil, err
 	}
@@ -463,37 +467,41 @@ func resolveBucketColumn(el *exprlist) (*columns.BucketColumn, error) {
 		Boundaries: b}, nil
 }
 
-func resolveSeqCategoryIDColumn(el *exprlist) (*columns.SequenceCategoryIDColumn, *columns.ColumnSpec, error) {
-	key, bucketSize, delimiter, cs, err := parseCategoryIDColumnExpr(el)
+func resolveSeqCategoryIDColumn(el *exprlist) (*columns.SequenceCategoryIDColumn, error) {
+	key, bucketSize, delimiter, fm, err := parseCategoryIDColumnExpr(el)
 	if err != nil {
-		return nil, nil, err
+		return nil, err
 	}
-	return &columns.SequenceCategoryIDColumn{
+	fc := &columns.SequenceCategoryIDColumn{
 		Key:        key,
 		BucketSize: bucketSize,
 		Delimiter:  delimiter,
 		// TODO(typhoonzero): support config dtype
-		Dtype: "int64"}, cs, nil
+		Dtype: "int64"}
+	fc.AppendFieldMetas(fm)
+	return fc, nil
 }
 
-func resolveCategoryIDColumn(el *exprlist) (*columns.CategoryIDColumn, *columns.ColumnSpec, error) {
-	key, bucketSize, delimiter, cs, err := parseCategoryIDColumnExpr(el)
+func resolveCategoryIDColumn(el *exprlist) (*columns.CategoryIDColumn, error) {
+	key, bucketSize, delimiter, fm, err := parseCategoryIDColumnExpr(el)
 	if err != nil {
-		return nil, nil, err
+		return nil, err
 	}
-	return &columns.CategoryIDColumn{
+	fc := &columns.CategoryIDColumn{
 		Key:        key,
 		BucketSize: bucketSize,
 		Delimiter:  delimiter,
 		// TODO(typhoonzero): support config dtype
-		Dtype: "int64"}, cs, nil
+		Dtype: "int64"}
+	fc.AppendFieldMetas(fm)
+	return fc, nil
 }
 
-func parseCategoryIDColumnExpr(el *exprlist) (string, int, string, *columns.ColumnSpec, error) {
+func parseCategoryIDColumnExpr(el *exprlist) (string, int, string, *columns.FieldMeta, error) {
 	if len(*el) != 3 && len(*el) != 4 {
 		return "", 0, "", nil, fmt.Errorf("bad CATEGORY_ID expression format: %s", *el)
 	}
-	var cs *columns.ColumnSpec
+	var cs *columns.FieldMeta
 	key := ""
 	var err error
 	if (*el)[1].typ == 0 {
@@ -530,7 +538,7 @@ func resolveCrossColumn(el *exprlist) (*columns.CrossColumn, error) {
 		return nil, fmt.Errorf("bad CROSS expression format: %s", *el)
 	}
 	keysExpr := (*el)[1]
-	key, _, err := resolveExpression(keysExpr)
+	key, err := resolveExpression(keysExpr)
 	if err != nil {
 		return nil, err
 	}
@@ -547,45 +555,45 @@
 		HashBucketSize: bucketSize}, nil
 }
 
-func resolveEmbeddingColumn(el *exprlist) (*columns.EmbeddingColumn, *columns.ColumnSpec, error) {
+func resolveEmbeddingColumn(el *exprlist) (*columns.EmbeddingColumn, error) {
 	if len(*el) != 4 && len(*el) != 5 {
-		return nil, nil, fmt.Errorf("bad EMBEDDING expression format: %s", *el)
+		return nil, fmt.Errorf("bad EMBEDDING expression format: %s", *el)
 	}
 	sourceExprList := (*el)[1]
-	var source columns.FeatureColumn
-	var cs *columns.ColumnSpec
+	var source interface{}
 	var err error
 	var innerCategoryColumnKey string
 	var catColumnResult interface{}
 	if sourceExprList.typ == 0 {
-		source, cs, err = resolveColumn(&sourceExprList.sexp)
+		source, err = resolveColumn(&sourceExprList.sexp)
 		if err != nil {
-			return nil, nil, err
+			return nil, err
 		}
 		// user may write EMBEDDING(SPARSE(...)) or EMBEDDING(DENSE(...))
-		if cs != nil {
-			innerCategoryColumnKey = cs.ColumnName
+		fm, ok := source.(*columns.FieldMeta)
+		if ok {
+			innerCategoryColumnKey = fm.ColumnName
 			catColumnResult = &columns.CategoryIDColumn{
-				Key:        cs.ColumnName,
-				BucketSize: cs.Shape[0],
-				Delimiter:  cs.Delimiter,
-				Dtype:      cs.DType,
+				Key:        fm.ColumnName,
+				BucketSize: fm.Shape[0],
+				Delimiter:  fm.Delimiter,
+				Dtype:      fm.DType,
 			}
+			catColumnResult.(*columns.CategoryIDColumn).AppendFieldMetas(fm)
 		} else {
 			// TODO(uuleon) support other kinds of categorical column in the future
 			var catColumn interface{}
-			catColumn, ok := source.(*columns.CategoryIDColumn)
+			catColumn, ok = source.(*columns.CategoryIDColumn)
 			if !ok {
 				catColumn, ok = source.(*columns.SequenceCategoryIDColumn)
 				if !ok {
-					return nil, nil, fmt.Errorf("key of EMBEDDING must be categorical column")
+					return nil, fmt.Errorf("key of EMBEDDING must be categorical column")
 				}
 			}
-			// NOTE: to avoid golang multiple assignment compiler restrictions
 			catColumnResult = catColumn
-			innerCategoryColumnKey = source.GetKey()
+			innerCategoryColumnKey = source.(columns.FeatureColumn).GetKey()
 		}
 	} else {
 		// generate a default CategoryIDColumn for later feature derivation.
@@ -595,17 +603,17 @@
 	dimension, err := strconv.Atoi((*el)[2].val)
 	if err != nil {
-		return nil, nil, fmt.Errorf("bad EMBEDDING dimension: %s, err: %s", (*el)[2].val, err)
+		return nil, fmt.Errorf("bad EMBEDDING dimension: %s, err: %s", (*el)[2].val, err)
 	}
 	combiner, err := expression2string((*el)[3])
 	if err != nil {
-		return nil, nil, fmt.Errorf("bad EMBEDDING combiner: %s, err: %s", (*el)[3], err)
+		return nil, fmt.Errorf("bad EMBEDDING combiner: %s, err: %s", (*el)[3], err)
 	}
 	initializer := ""
 	if len(*el) == 5 {
 		initializer, err = expression2string((*el)[4])
 		if err != nil {
-			return nil, nil, fmt.Errorf("bad EMBEDDING initializer: %s, err: %s", (*el)[4], err)
+			return nil, fmt.Errorf("bad EMBEDDING initializer: %s, err: %s", (*el)[4], err)
 		}
 	}
 	return &columns.EmbeddingColumn{
@@ -613,7 +621,7 @@ func resolveEmbeddingColumn(el *exprlist) (*columns.EmbeddingColumn, *columns.Co
 		CategoryColumn: catColumnResult,
 		Dimension:      dimension,
 		Combiner:       combiner,
-		Initializer:    initializer}, cs, nil
+		Initializer:    initializer}, nil
 }
 
 func resolveNumericColumn(el *exprlist) (*columns.NumericColumn, error) {
@@ -627,7 +635,7 @@
 	var shape []int
 	intVal, err := strconv.Atoi((*el)[2].val)
 	if err != nil {
-		list, _, err := resolveExpression((*el)[2])
+		list, err := resolveExpression((*el)[2])
 		if err != nil {
 			return nil, err
 		}
@@ -643,14 +651,10 @@
 		shape = append(shape, intVal)
 	}
 	return &columns.NumericColumn{
-		Key:   key,
-		Shape: shape,
-		// FIXME(typhoonzero, tony): support config Delimiter and Dtype
-		Delimiter: ",",
-		Dtype:     "float32"}, nil
+		Key: key}, nil
 }
 
-func resolveColumnSpec(el *exprlist, isSparse bool) (*columns.ColumnSpec, error) {
+func resolveColumnSpec(el *exprlist, isSparse bool) (*columns.FieldMeta, error) {
 	if len(*el) < 4 {
 		return nil, fmt.Errorf("bad FeatureSpec expression format: %s", *el)
 	}
@@ -690,7 +694,7 @@
 	if len(*el) >= 5 {
 		dtype, err = expression2string((*el)[4])
 	}
-	return &columns.ColumnSpec{
+	return &columns.FieldMeta{
 		ColumnName: name,
 		IsSparse:   isSparse,
 		Shape:      shape,
@@ -699,31 +703,34 @@ func resolveColumnSpec(el *exprlist, isSparse bool) (*columns.ColumnSpec, error)
 		FeatureMap: fm}, nil
 }
 
-// resolveFeatureColumn returns the acutal feature column typed struct
-// as well as the columnSpec infomation.
-func resolveColumn(el *exprlist) (columns.FeatureColumn, *columns.ColumnSpec, error) {
+// resolveColumn returns the actual feature column typed struct
+// as well as the FieldMeta information.
+func resolveColumn(el *exprlist) (interface{}, error) {
 	head := (*el)[0].val
 	if head == "" {
-		return nil, nil, fmt.Errorf("column description expects format like NUMERIC(key) etc, got %v", el)
+		return nil, fmt.Errorf("column description expects format like NUMERIC(key) etc, got %v", el)
 	}
 	switch strings.ToUpper(head) {
 	case dense:
-		cs, err := resolveColumnSpec(el, false)
-		return nil, cs, err
+		fm, err := resolveColumnSpec(el, false)
+		if err != nil {
+			return nil, err
+		}
+		return fm, err
 	case sparse:
-		cs, err := resolveColumnSpec(el, true)
-		return nil, cs, err
+		fm, err := resolveColumnSpec(el, true)
+		if err != nil {
+			return nil, err
+		}
+		return fm, err
 	case numeric:
 		// TODO(typhoonzero): support NUMERIC(DENSE(col)) and NUMERIC(SPARSE(col))
-		fc, err := resolveNumericColumn(el)
-		return fc, nil, err
+		return resolveNumericColumn(el)
 	case bucket:
-		fc, err := resolveBucketColumn(el)
-		return fc, nil, err
+		return resolveBucketColumn(el)
 	case cross:
-		fc, err := resolveCrossColumn(el)
-		return fc, nil, err
+		return resolveCrossColumn(el)
 	case categoryID:
 		return resolveCategoryIDColumn(el)
 	case seqCategoryID:
@@ -731,6 +738,6 @@ func resolveColumn(el *exprlist) (columns.FeatureColumn, *columns.ColumnSpec, er
 	case embedding:
 		return resolveEmbeddingColumn(el)
 	default:
-		return nil, nil, fmt.Errorf("not supported expr: %s", head)
+		return nil, fmt.Errorf("not supported expr: %s", head)
 	}
 }
diff --git a/sql/featurederivation.go b/sql/featurederivation.go
index dcfc013b5d..590e6efafa 100644
--- a/sql/featurederivation.go
+++ b/sql/featurederivation.go
@@ -28,8 +28,8 @@ const featureDerivationRows = 1000
 // FeatureColumnMap is a mapping from column name to FeatureColumn struct
 type FeatureColumnMap map[string]columns.FeatureColumn
 
-// ColumnSpecMap is a mappign from column name to ColumnSpec struct
-type ColumnSpecMap map[string]*columns.ColumnSpec
+// ColumnSpecMap is a mapping from column name to FieldMeta struct
+type ColumnSpecMap map[string]*columns.FieldMeta
 
 // makeFeatureColumnMap returns a map from column key to FeatureColumn
 // NOTE that the target is not important for analyzing feature derivation.
@@ -43,9 +43,9 @@
 	return fcMap
 }
 
-// makeColumnSpecMap returns a map from column key to ColumnSpec
+// makeColumnSpecMap returns a map from column key to FieldMeta
 // NOTE that the target is not important for analyzing feature derivation.
-func makeColumnSpecMap(parsedColumnSpecs map[string][]*columns.ColumnSpec) ColumnSpecMap {
+func makeColumnSpecMap(parsedColumnSpecs map[string][]*columns.FieldMeta) ColumnSpecMap {
 	csMap := make(ColumnSpecMap)
 	for _, fcList := range parsedColumnSpecs {
 		for _, cs := range fcList {
@@ -84,9 +84,9 @@ func fillColumnSpec(columnTypeList []*sql.ColumnType, rowdata []interface{}, csm
 	}
 	for idx, ct := range columnTypeList {
 		_, fld := decomp(ct.Name())
-		// add a default ColumnSpec for updating.
+		// add a default FieldMeta for updating.
 		if _, ok := csmap[fld]; !ok {
-			csmap[fld] = &columns.ColumnSpec{
+			csmap[fld] = &columns.FieldMeta{
 				ColumnName: fld,
 				IsSparse:   false,
 				Shape:      nil,
@@ -172,11 +172,11 @@ func fillColumnSpec(columnTypeList []*sql.ColumnType, rowdata []interface{}, csm
 	return nil
 }
 
-// InferFeatureColumns fill up featureColumn and columnSpec structs
+// InferFeatureColumns fills up featureColumn and FieldMeta structs
 // for all fields.
 func InferFeatureColumns(slct *standardSelect, parsedFeatureColumns map[string][]columns.FeatureColumn,
-	parsedColumnSpecs map[string][]*columns.ColumnSpec,
+	parsedColumnSpecs map[string][]*columns.FieldMeta,
 	connConfig *connectionConfig) (FeatureColumnMap, ColumnSpecMap, error) {
 	if connConfig == nil {
 		return nil, nil, fmt.Errorf("no connectionConfig provided")
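
Usage sketch (editor's note, not part of the patch): after this refactor a feature column no longer carries its own Shape/Dtype/Delimiter accessors; that information lives in one or more FieldMeta values attached through the embedded FeatureColumnMetasImpl. The Go snippet below illustrates how a caller might build a NumericColumn under the new API; the import path and the concrete field values are assumptions for illustration only.

package main

import (
	"fmt"

	"sqlflow.org/sqlflow/sql/columns" // import path assumed
)

func main() {
	// Build a NumericColumn; shape/dtype/delimiter now live in an attached FieldMeta.
	nc := &columns.NumericColumn{Key: "sepal_length"}
	nc.AppendFieldMetas(&columns.FieldMeta{
		ColumnName: "sepal_length",
		IsSparse:   false,
		Shape:      []int{1},
		DType:      "float32",
		Delimiter:  ",",
	})

	// Code generators read the metadata back through the shared FeatureColumnMetas interface.
	fm := nc.GetFieldMetas()[0]
	code, err := nc.GenerateCode(fm)
	if err != nil {
		fmt.Println(err)
		return
	}
	// Prints: tf.feature_column.numeric_column("sepal_length", shape=[1])
	fmt.Println(code[0])
}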