diff --git a/sql/featurederivation.go b/sql/featurederivation.go index dcfc013b5d..e6eeee80e2 100644 --- a/sql/featurederivation.go +++ b/sql/featurederivation.go @@ -25,8 +25,8 @@ import ( const featureDerivationRows = 1000 -// FeatureColumnMap is a mapping from column name to FeatureColumn struct -type FeatureColumnMap map[string]columns.FeatureColumn +// FeatureColumnMap is like: target -> key -> FeatureColumn +type FeatureColumnMap map[string]map[string]columns.FeatureColumn // ColumnSpecMap is a mappign from column name to ColumnSpec struct type ColumnSpecMap map[string]*columns.ColumnSpec @@ -35,9 +35,10 @@ type ColumnSpecMap map[string]*columns.ColumnSpec // NOTE that the target is not important for analyzing feature derivation. func makeFeatureColumnMap(parsedFeatureColumns map[string][]columns.FeatureColumn) FeatureColumnMap { fcMap := make(FeatureColumnMap) - for _, fcList := range parsedFeatureColumns { + for target, fcList := range parsedFeatureColumns { + fcMap[target] = make(map[string]columns.FeatureColumn) for _, fc := range fcList { - fcMap[fc.GetKey()] = fc + fcMap[target][fc.GetKey()] = fc } } return fcMap @@ -245,44 +246,58 @@ func InferFeatureColumns(slct *standardSelect, // 1. Infer omited category_id_column for embedding_columns // 2. Add derivated feature column. - for slctKey := range selectFieldTypeMap { - if fc, ok := fcMap[slctKey]; ok { - if fc.GetColumnType() == columns.ColumnTypeEmbedding { - if fc.(*columns.EmbeddingColumn).CategoryColumn == nil { - cs, ok := csMap[fc.GetKey()] - if !ok { - return nil, nil, fmt.Errorf("column not found or infered: %s", fc.GetKey()) + // + // need to store FeatureColumn under it's target in case of + // the same column used for different target, e.g. + // COLUMN EMBEDDING(c1) for deep + // EMBEDDING(c2) for deep + // EMBEDDING(c1) for wide + for target := range parsedFeatureColumns { + for slctKey := range selectFieldTypeMap { + fcTargetMap, ok := fcMap[target] + if !ok { + // create map for current target + fcMap[target] = make(map[string]columns.FeatureColumn) + fcTargetMap = fcMap[target] + } + if fc, ok := fcTargetMap[slctKey]; ok { + if fc.GetColumnType() == columns.ColumnTypeEmbedding { + if fc.(*columns.EmbeddingColumn).CategoryColumn == nil { + cs, ok := csMap[fc.GetKey()] + if !ok { + return nil, nil, fmt.Errorf("column not found or infered: %s", fc.GetKey()) + } + // FIXME(typhoonzero): when to use sequence_category_id_column? + fc.(*columns.EmbeddingColumn).CategoryColumn = &columns.CategoryIDColumn{ + Key: cs.ColumnName, + BucketSize: cs.Shape[0], + Delimiter: cs.Delimiter, + Dtype: cs.DType, + } + } + } + } else { + cs, ok := csMap[slctKey] + if !ok { + return nil, nil, fmt.Errorf("column not found or infered: %s", slctKey) + } + if cs.DType != "string" { + fcMap[target][slctKey] = &columns.NumericColumn{ + Key: cs.ColumnName, + Shape: cs.Shape, + Dtype: cs.DType, + Delimiter: cs.Delimiter, } - // FIXME(typhoonzero): when to use sequence_category_id_column? - fc.(*columns.EmbeddingColumn).CategoryColumn = &columns.CategoryIDColumn{ + } else { + // FIXME(typhoonzero): need full test case for string numeric columns + fcMap[target][slctKey] = &columns.CategoryIDColumn{ Key: cs.ColumnName, - BucketSize: cs.Shape[0], + BucketSize: len(cs.Vocabulary), Delimiter: cs.Delimiter, Dtype: cs.DType, } } } - } else { - cs, ok := csMap[slctKey] - if !ok { - return nil, nil, fmt.Errorf("column not found or infered: %s", slctKey) - } - if cs.DType != "string" { - fcMap[slctKey] = &columns.NumericColumn{ - Key: cs.ColumnName, - Shape: cs.Shape, - Dtype: cs.DType, - Delimiter: cs.Delimiter, - } - } else { - // FIXME(typhoonzero): need full test case for string numeric columns - fcMap[slctKey] = &columns.CategoryIDColumn{ - Key: cs.ColumnName, - BucketSize: len(cs.Vocabulary), - Delimiter: cs.Delimiter, - Dtype: cs.DType, - } - } } } diff --git a/sql/featurederivation_test.go b/sql/featurederivation_test.go index 672b18d964..542eaa6a35 100644 --- a/sql/featurederivation_test.go +++ b/sql/featurederivation_test.go @@ -14,6 +14,7 @@ package sql import ( + "fmt" "regexp" "testing" @@ -103,17 +104,18 @@ func TestFeatureDerivation(t *testing.T) { a.Equal("int", cs.DType) a.True(cs.IsSparse) - fc := res.FeatureColumnInfered["c1"] + fmt.Printf("fc inferred: %v\n", res.FeatureColumnInfered) + fc := res.FeatureColumnInfered["feature_columns"]["c1"] a.Equal(columns.ColumnTypeNumeric, fc.GetColumnType()) - fc = res.FeatureColumnInfered["c3"] + fc = res.FeatureColumnInfered["feature_columns"]["c3"] a.Equal(columns.ColumnTypeEmbedding, fc.GetColumnType()) emb, ok := fc.(*columns.EmbeddingColumn) a.True(ok) a.NotNil(emb.CategoryColumn) a.Equal("c3", emb.CategoryColumn.(*columns.CategoryIDColumn).GetKey()) - fc = res.FeatureColumnInfered["c5"] + fc = res.FeatureColumnInfered["feature_columns"]["c5"] a.Equal(columns.ColumnTypeEmbedding, fc.GetColumnType()) emb, ok = fc.(*columns.EmbeddingColumn) a.Equal(10000, emb.CategoryColumn.(*columns.CategoryIDColumn).BucketSize)