Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions sql/codegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package sql

import (
"encoding/json"
"fmt"
"io"
"strings"
Expand Down Expand Up @@ -163,19 +164,21 @@ func newFiller(pr *extendedSelect, ds *trainAndValDataset, fts fieldTypes, db *D
_, ok = col.(*columns.EmbeddingColumn).CategoryColumn.(*columns.SequenceCategoryIDColumn)
}
}
if !ok && col.GetDelimiter() != "" {
if _, ok := col.(*columns.NumericColumn); !ok {
isSparse = true
}
}
fm := &FeatureMeta{
FeatureName: col.GetKey(),
Dtype: col.GetDtype(),
Delimiter: col.GetDelimiter(),
InputShape: col.GetInputShape(),
IsSparse: isSparse,

fieldMetas := col.GetFieldMetas()
if len(fieldMetas) == 0 {
return nil, fmt.Errorf("no fieldmeta found for current column: %s", col.GetKey())
}
r.X = append(r.X, fm)
fm := fieldMetas[0]
jsonShape, _ := json.Marshal(fm.Shape)
r.X = append(r.X, &FeatureMeta{
FeatureName: fm.ColumnName,
Dtype: fm.DType,
Delimiter: fm.Delimiter,
InputShape: string(jsonShape),
IsSparse: fm.IsSparse,
})

featureColumnsCode[target] = append(
featureColumnsCode[target],
feaColCode[0])
Expand Down
41 changes: 18 additions & 23 deletions sql/codegen_alps.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,6 @@ type alpsFiller struct {
OSSEndpoint string
}

// type alpsFeatureColumn interface {
// columns.FeatureColumn
// GenerateAlpsCode(metadata *metadata) ([]string, error)
// }

type alpsBucketCol struct {
columns.BucketColumn
}
Expand Down Expand Up @@ -184,7 +179,7 @@ func newALPSTrainFiller(pr *extendedSelect, db *DB, session *pb.Session, ds *tra
}

var odpsConfig = &gomaxcompute.Config{}
var columnInfo map[string]*columns.ColumnSpec
var columnInfo map[string]*columns.FieldMeta

// TODO(joyyoj) read feature mapping table's name from table attributes.
// TODO(joyyoj) pr may contains partition.
Expand All @@ -205,7 +200,7 @@ func newALPSTrainFiller(pr *extendedSelect, db *DB, session *pb.Session, ds *tra
meta.columnInfo = &columnInfo
} else {
meta = metadata{odpsConfig, pr.tables[0], nil, nil}
columnInfo = map[string]*columns.ColumnSpec{}
columnInfo = map[string]*columns.FieldMeta{}
for _, css := range resolved.ColumnSpecs {
for _, cs := range css {
columnInfo[cs.ColumnName] = cs
Expand All @@ -223,7 +218,7 @@ func newALPSTrainFiller(pr *extendedSelect, db *DB, session *pb.Session, ds *tra
for _, cs := range columnInfo {
csCode = append(csCode, cs.ToString())
}
y := &columns.ColumnSpec{
y := &columns.FieldMeta{
ColumnName: pr.label,
IsSparse: false,
Shape: []int{1},
Expand Down Expand Up @@ -432,7 +427,7 @@ func alpsPred(w *PipeWriter, pr *extendedSelect, db *DB, cwd string, session *pb
}

// GenerateCode overrides the member function defined in `category_id_column.go`
func (cc *alpsCategoryIDCol) GenerateCode(cs *columns.ColumnSpec) ([]string, error) {
func (cc *alpsCategoryIDCol) GenerateCode(cs *columns.FieldMeta) ([]string, error) {
output := make([]string, 0)
// columnInfo, present := (*metadata.columnInfo)[cc.Key]
columnInfo := cs
Expand All @@ -455,7 +450,7 @@ func (cc *alpsCategoryIDCol) GenerateCode(cs *columns.ColumnSpec) ([]string, err
return output, err
}

func (cc *alpsSeqCategoryIDCol) GenerateCode(cs *columns.ColumnSpec) ([]string, error) {
func (cc *alpsSeqCategoryIDCol) GenerateCode(cs *columns.FieldMeta) ([]string, error) {
output := make([]string, 0)
// columnInfo, present := (*metadata.columnInfo)[cc.Key]
columnInfo := cs
Expand All @@ -477,7 +472,7 @@ func (cc *alpsSeqCategoryIDCol) GenerateCode(cs *columns.ColumnSpec) ([]string,
return output, err
}

func (ec *alpsEmbeddingCol) GenerateCode(cs *columns.ColumnSpec) ([]string, error) {
func (ec *alpsEmbeddingCol) GenerateCode(cs *columns.FieldMeta) ([]string, error) {
var output []string
catColumn := &alpsCategoryIDCol{*(ec.CategoryColumn.(*columns.CategoryIDColumn))}
// if !ok {
Expand Down Expand Up @@ -539,11 +534,11 @@ type metadata struct {
odpsConfig *gomaxcompute.Config
table string
featureMap *columns.FeatureMap
columnInfo *map[string]*columns.ColumnSpec
columnInfo *map[string]*columns.FieldMeta
}

func flattenColumnSpec(columnSpecs map[string][]*columns.ColumnSpec) map[string]*columns.ColumnSpec {
output := map[string]*columns.ColumnSpec{}
func flattenColumnSpec(columnSpecs map[string][]*columns.FieldMeta) map[string]*columns.FieldMeta {
output := map[string]*columns.FieldMeta{}
for _, cols := range columnSpecs {
for _, col := range cols {
output[col.ColumnName] = col
Expand All @@ -552,8 +547,8 @@ func flattenColumnSpec(columnSpecs map[string][]*columns.ColumnSpec) map[string]
return output
}

func (meta *metadata) getColumnInfo(resolved *resolvedTrainClause, fields []string) (map[string]*columns.ColumnSpec, error) {
columns := map[string]*columns.ColumnSpec{}
func (meta *metadata) getColumnInfo(resolved *resolvedTrainClause, fields []string) (map[string]*columns.FieldMeta, error) {
columns := map[string]*columns.FieldMeta{}
refColumns := flattenColumnSpec(resolved.ColumnSpecs)

sparseColumns, _ := meta.getSparseColumnInfo()
Expand Down Expand Up @@ -638,8 +633,8 @@ func getFields(meta *metadata, pr *extendedSelect) ([]string, error) {
return fields, nil
}

func (meta *metadata) getDenseColumnInfo(keys []string, refColumns map[string]*columns.ColumnSpec) (map[string]*columns.ColumnSpec, error) {
output := map[string]*columns.ColumnSpec{}
func (meta *metadata) getDenseColumnInfo(keys []string, refColumns map[string]*columns.FieldMeta) (map[string]*columns.FieldMeta, error) {
output := map[string]*columns.FieldMeta{}
fields := strings.Join(keys, ",")
query := fmt.Sprintf("SELECT %s FROM %s LIMIT 1", fields, meta.table)
sqlDB, _ := sql.Open("maxcompute", meta.odpsConfig.FormatDSN())
Expand Down Expand Up @@ -669,7 +664,7 @@ func (meta *metadata) getDenseColumnInfo(keys []string, refColumns map[string]*c
shape := make([]int, 1)
shape[0] = len(fields)
if userSpec, ok := refColumns[ct.Name()]; ok {
output[ct.Name()] = &columns.ColumnSpec{
output[ct.Name()] = &columns.FieldMeta{
ct.Name(),
false,
shape,
Expand All @@ -678,7 +673,7 @@ func (meta *metadata) getDenseColumnInfo(keys []string, refColumns map[string]*c
nil,
*meta.featureMap}
} else {
output[ct.Name()] = &columns.ColumnSpec{
output[ct.Name()] = &columns.FieldMeta{
ct.Name(),
false,
shape,
Expand All @@ -692,8 +687,8 @@ func (meta *metadata) getDenseColumnInfo(keys []string, refColumns map[string]*c
return output, nil
}

func (meta *metadata) getSparseColumnInfo() (map[string]*columns.ColumnSpec, error) {
output := map[string]*columns.ColumnSpec{}
func (meta *metadata) getSparseColumnInfo() (map[string]*columns.FieldMeta, error) {
output := map[string]*columns.FieldMeta{}

sqlDB, _ := sql.Open("maxcompute", meta.odpsConfig.FormatDSN())
filter := "feature_type != '' "
Expand Down Expand Up @@ -732,7 +727,7 @@ func (meta *metadata) getSparseColumnInfo() (map[string]*columns.ColumnSpec, err
column, present := output[*name]
if !present {
shape := make([]int, 0, 1000)
column := &columns.ColumnSpec{*name, true, shape, "int64", "", nil, *meta.featureMap}
column := &columns.FieldMeta{*name, true, shape, "int64", "", nil, *meta.featureMap}
column.DType = "int64"
output[*name] = column
}
Expand Down
18 changes: 2 additions & 16 deletions sql/columns/bucket_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ import (

// BucketColumn is the wrapper of `tf.feature_column.bucketized_column`
type BucketColumn struct {
FeatureColumnMetasImpl
SourceColumn *NumericColumn
Boundaries []int
}

// GenerateCode implements the FeatureColumn interface.
func (bc *BucketColumn) GenerateCode(cs *ColumnSpec) ([]string, error) {
func (bc *BucketColumn) GenerateCode(cs *FieldMeta) ([]string, error) {
sourceCode, _ := bc.SourceColumn.GenerateCode(cs)
if len(sourceCode) > 1 {
return []string{}, fmt.Errorf("does not support grouped column: %v", sourceCode)
Expand All @@ -41,21 +42,6 @@ func (bc *BucketColumn) GetKey() string {
return bc.SourceColumn.Key
}

// GetDelimiter implements the FeatureColumn interface.
func (bc *BucketColumn) GetDelimiter() string {
return ""
}

// GetDtype implements the FeatureColumn interface.
func (bc *BucketColumn) GetDtype() string {
return ""
}

// GetInputShape implements the FeatureColumn interface.
func (bc *BucketColumn) GetInputShape() string {
return bc.SourceColumn.GetInputShape()
}

// GetColumnType implements the FeatureColumn interface.
func (bc *BucketColumn) GetColumnType() int {
return ColumnTypeBucket
Expand Down
46 changes: 4 additions & 42 deletions sql/columns/category_id_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (

// CategoryIDColumn is the wrapper of `tf.feature_column.categorical_column_with_identity`
type CategoryIDColumn struct {
FeatureColumnMetasImpl
Key string
BucketSize int
Delimiter string
Expand All @@ -28,14 +29,15 @@ type CategoryIDColumn struct {
// SequenceCategoryIDColumn is the wrapper of `tf.feature_column.sequence_categorical_column_with_identity`
// NOTE: only used in tf >= 2.0 versions.
type SequenceCategoryIDColumn struct {
FeatureColumnMetasImpl
Key string
BucketSize int
Delimiter string
Dtype string
}

// GenerateCode implements the FeatureColumn interface.
func (cc *CategoryIDColumn) GenerateCode(cs *ColumnSpec) ([]string, error) {
func (cc *CategoryIDColumn) GenerateCode(cs *FieldMeta) ([]string, error) {
return []string{fmt.Sprintf("tf.feature_column.categorical_column_with_identity(key=\"%s\", num_buckets=%d)",
cc.Key, cc.BucketSize)}, nil
}
Expand All @@ -45,63 +47,23 @@ func (cc *CategoryIDColumn) GetKey() string {
return cc.Key
}

// GetDelimiter implements the FeatureColumn interface.
func (cc *CategoryIDColumn) GetDelimiter() string {
return cc.Delimiter
}

// GetDtype implements the FeatureColumn interface.
func (cc *CategoryIDColumn) GetDtype() string {
return cc.Dtype
}

// GetInputShape implements the FeatureColumn interface.
func (cc *CategoryIDColumn) GetInputShape() string {
return fmt.Sprintf("[%d]", cc.BucketSize)
}

// GetColumnType implements the FeatureColumn interface.
func (cc *CategoryIDColumn) GetColumnType() int {
return ColumnTypeCategoryID
}

// GenerateCode implements the FeatureColumn interface.
func (cc *SequenceCategoryIDColumn) GenerateCode(cs *ColumnSpec) ([]string, error) {
func (cc *SequenceCategoryIDColumn) GenerateCode(cs *FieldMeta) ([]string, error) {
return []string{fmt.Sprintf("tf.feature_column.sequence_categorical_column_with_identity(key=\"%s\", num_buckets=%d)",
cc.Key, cc.BucketSize)}, nil
}

// GetDelimiter implements the FeatureColumn interface.
func (cc *SequenceCategoryIDColumn) GetDelimiter() string {
return cc.Delimiter
}

// GetDtype implements the FeatureColumn interface.
func (cc *SequenceCategoryIDColumn) GetDtype() string {
return cc.Dtype
}

// GetKey implements the FeatureColumn interface.
func (cc *SequenceCategoryIDColumn) GetKey() string {
return cc.Key
}

// GetInputShape implements the FeatureColumn interface.
func (cc *SequenceCategoryIDColumn) GetInputShape() string {
return fmt.Sprintf("[%d]", cc.BucketSize)
}

// GetColumnType implements the FeatureColumn interface.
func (cc *SequenceCategoryIDColumn) GetColumnType() int {
return ColumnTypeSeqCategoryID
}

// func parseCategoryColumnKey(el *exprlist) (*columnSpec, error) {
// if (*el)[1].typ == 0 {
// // explist, maybe DENSE/SPARSE expressions
// subExprList := (*el)[1].sexp
// isSparse := subExprList[0].val == sparse
// return resolveColumnSpec(&subExprList, isSparse)
// }
// return nil, nil
// }
10 changes: 5 additions & 5 deletions sql/columns/column_spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,19 @@ type FeatureMap struct {
Partition string
}

// ColumnSpec defines how to generate codes to parse column data to tensor/sparsetensor
type ColumnSpec struct {
// FieldMeta defines how to generate codes to parse column data to tensor/sparsetensor
type FieldMeta struct {
ColumnName string
IsSparse bool
Shape []int
DType string
Delimiter string
Vocabulary map[string]string // use a map to generate a list without duplication
FeatureMap FeatureMap
FeatureMap FeatureMap // FeatureMap is a table describes how to parse the columns data, only used in codegen_alps
}

// ToString generates the debug string of ColumnSpec
func (cs *ColumnSpec) ToString() string {
// ToString generates the debug string of FieldMeta
func (cs *FieldMeta) ToString() string {
if cs.IsSparse {
shape := strings.Join(strings.Split(fmt.Sprint(cs.Shape), " "), ",")
if len(cs.Shape) > 1 {
Expand Down
20 changes: 2 additions & 18 deletions sql/columns/cross_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@ import (
// CrossColumn is the wapper of `tf.feature_column.crossed_column`
// TODO(uuleon) specify the hash_key if needed
type CrossColumn struct {
FeatureColumnMetasImpl
Keys []interface{}
HashBucketSize int
}

// GenerateCode implements the FeatureColumn interface.
func (cc *CrossColumn) GenerateCode(cs *ColumnSpec) ([]string, error) {
func (cc *CrossColumn) GenerateCode(cs *FieldMeta) ([]string, error) {
var keysGenerated = make([]string, len(cc.Keys))
for idx, key := range cc.Keys {
if c, ok := key.(FeatureColumn); ok {
Expand Down Expand Up @@ -57,23 +58,6 @@ func (cc *CrossColumn) GetKey() string {
return ""
}

// GetDelimiter implements the FeatureColumn interface.
func (cc *CrossColumn) GetDelimiter() string {
return ""
}

// GetDtype implements the FeatureColumn interface.
func (cc *CrossColumn) GetDtype() string {
return ""
}

// GetInputShape implements the FeatureColumn interface.
func (cc *CrossColumn) GetInputShape() string {
// NOTE: return empty since crossed column input shape is already determined
// by the two crossed columns.
return ""
}

// GetColumnType implements the FeatureColumn interface.
func (cc *CrossColumn) GetColumnType() int {
return ColumnTypeCross
Expand Down
Loading