Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 50 additions & 35 deletions sql/featurederivation.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ import (

const featureDerivationRows = 1000

// FeatureColumnMap is a mapping from column name to FeatureColumn struct
type FeatureColumnMap map[string]columns.FeatureColumn
// FeatureColumnMap is like: target -> key -> FeatureColumn
type FeatureColumnMap map[string]map[string]columns.FeatureColumn

// ColumnSpecMap is a mapping from column name to ColumnSpec struct
type ColumnSpecMap map[string]*columns.ColumnSpec
Expand All @@ -35,9 +35,10 @@ type ColumnSpecMap map[string]*columns.ColumnSpec
// NOTE that the target is not important for analyzing feature derivation.
func makeFeatureColumnMap(parsedFeatureColumns map[string][]columns.FeatureColumn) FeatureColumnMap {
fcMap := make(FeatureColumnMap)
for _, fcList := range parsedFeatureColumns {
for target, fcList := range parsedFeatureColumns {
fcMap[target] = make(map[string]columns.FeatureColumn)
for _, fc := range fcList {
fcMap[fc.GetKey()] = fc
fcMap[target][fc.GetKey()] = fc
}
}
return fcMap
Expand Down Expand Up @@ -245,44 +246,58 @@ func InferFeatureColumns(slct *standardSelect,

// 1. Infer omitted category_id_column for embedding_columns
// 2. Add derived feature column.
for slctKey := range selectFieldTypeMap {
if fc, ok := fcMap[slctKey]; ok {
if fc.GetColumnType() == columns.ColumnTypeEmbedding {
if fc.(*columns.EmbeddingColumn).CategoryColumn == nil {
cs, ok := csMap[fc.GetKey()]
if !ok {
return nil, nil, fmt.Errorf("column not found or infered: %s", fc.GetKey())
//
// need to store FeatureColumn under its target in case
// the same column is used for different targets, e.g.
// COLUMN EMBEDDING(c1) for deep
// EMBEDDING(c2) for deep
// EMBEDDING(c1) for wide
for target := range parsedFeatureColumns {
for slctKey := range selectFieldTypeMap {
fcTargetMap, ok := fcMap[target]
if !ok {
// create map for current target
fcMap[target] = make(map[string]columns.FeatureColumn)
fcTargetMap = fcMap[target]
}
if fc, ok := fcTargetMap[slctKey]; ok {
if fc.GetColumnType() == columns.ColumnTypeEmbedding {
if fc.(*columns.EmbeddingColumn).CategoryColumn == nil {
cs, ok := csMap[fc.GetKey()]
if !ok {
return nil, nil, fmt.Errorf("column not found or infered: %s", fc.GetKey())
}
// FIXME(typhoonzero): when to use sequence_category_id_column?
fc.(*columns.EmbeddingColumn).CategoryColumn = &columns.CategoryIDColumn{
Key: cs.ColumnName,
BucketSize: cs.Shape[0],
Delimiter: cs.Delimiter,
Dtype: cs.DType,
}
}
}
} else {
cs, ok := csMap[slctKey]
if !ok {
return nil, nil, fmt.Errorf("column not found or infered: %s", slctKey)
}
if cs.DType != "string" {
fcMap[target][slctKey] = &columns.NumericColumn{
Key: cs.ColumnName,
Shape: cs.Shape,
Dtype: cs.DType,
Delimiter: cs.Delimiter,
}
// FIXME(typhoonzero): when to use sequence_category_id_column?
fc.(*columns.EmbeddingColumn).CategoryColumn = &columns.CategoryIDColumn{
} else {
// FIXME(typhoonzero): need full test case for string numeric columns
fcMap[target][slctKey] = &columns.CategoryIDColumn{
Key: cs.ColumnName,
BucketSize: cs.Shape[0],
BucketSize: len(cs.Vocabulary),
Delimiter: cs.Delimiter,
Dtype: cs.DType,
}
}
}
} else {
cs, ok := csMap[slctKey]
if !ok {
return nil, nil, fmt.Errorf("column not found or infered: %s", slctKey)
}
if cs.DType != "string" {
fcMap[slctKey] = &columns.NumericColumn{
Key: cs.ColumnName,
Shape: cs.Shape,
Dtype: cs.DType,
Delimiter: cs.Delimiter,
}
} else {
// FIXME(typhoonzero): need full test case for string numeric columns
fcMap[slctKey] = &columns.CategoryIDColumn{
Key: cs.ColumnName,
BucketSize: len(cs.Vocabulary),
Delimiter: cs.Delimiter,
Dtype: cs.DType,
}
}
}
}

Expand Down
8 changes: 5 additions & 3 deletions sql/featurederivation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package sql

import (
"fmt"
"regexp"
"testing"

Expand Down Expand Up @@ -103,17 +104,18 @@ func TestFeatureDerivation(t *testing.T) {
a.Equal("int", cs.DType)
a.True(cs.IsSparse)

fc := res.FeatureColumnInfered["c1"]
fmt.Printf("fc inferred: %v\n", res.FeatureColumnInfered)
fc := res.FeatureColumnInfered["feature_columns"]["c1"]
a.Equal(columns.ColumnTypeNumeric, fc.GetColumnType())

fc = res.FeatureColumnInfered["c3"]
fc = res.FeatureColumnInfered["feature_columns"]["c3"]
a.Equal(columns.ColumnTypeEmbedding, fc.GetColumnType())
emb, ok := fc.(*columns.EmbeddingColumn)
a.True(ok)
a.NotNil(emb.CategoryColumn)
a.Equal("c3", emb.CategoryColumn.(*columns.CategoryIDColumn).GetKey())

fc = res.FeatureColumnInfered["c5"]
fc = res.FeatureColumnInfered["feature_columns"]["c5"]
a.Equal(columns.ColumnTypeEmbedding, fc.GetColumnType())
emb, ok = fc.(*columns.EmbeddingColumn)
a.Equal(10000, emb.CategoryColumn.(*columns.CategoryIDColumn).BucketSize)
Expand Down