Commit 8fa90e7

feat(KDP): add the AdvancedNumericalEmbedding feature

1 parent bd90f11

File tree

kdp/custom_layers.py
kdp/features.py
kdp/processor.py

3 files changed: +269 -6

kdp/custom_layers.py

Lines changed: 220 additions & 0 deletions
@@ -6,6 +6,9 @@
 import numpy as np
 import tensorflow as tf
 import tensorflow_probability as tfp
+from tensorflow.keras import layers
+
+from loguru import logger
 
 
 class TextPreprocessingLayer(tf.keras.layers.Layer):
@@ -1945,3 +1948,220 @@ def from_config(cls, config: dict) -> "VariableSelection":
             VariableSelection: A new instance of the layer.
         """
         return cls(**config)
+
+
+class AdvancedNumericalEmbedding(layers.Layer):
+    """Advanced numerical embedding layer for continuous features.
+
+    This layer embeds each continuous numerical feature into a higher-dimensional space by
+    combining two branches:
+
+    1. Continuous Branch: Each feature is processed via a small MLP (using TimeDistributed layers).
+    2. Discrete Branch: Each feature is discretized into bins using learnable min/max boundaries
+       and then an embedding is looked up for its bin.
+
+    A learnable gate (of shape (num_features, embedding_dim)) combines the two branch outputs
+    per feature and per embedding dimension. Additionally, the continuous branch uses a residual
+    connection and optional batch normalization to improve training stability.
+
+    The layer supports inputs of shape (batch, num_features) for any number of features and returns
+    outputs of shape (batch, num_features, embedding_dim).
+
+    Args:
+        embedding_dim (int): Output embedding dimension per feature.
+        mlp_hidden_units (int): Hidden units for the continuous branch MLP.
+        num_bins (int): Number of bins for discretization.
+        init_min (float or list): Initial minimum values for discretization boundaries. If a scalar is
+            provided, it is applied to all features.
+        init_max (float or list): Initial maximum values for discretization boundaries.
+        dropout_rate (float): Dropout rate applied to the continuous branch.
+        use_batch_norm (bool): Whether to apply batch normalization to the continuous branch.
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        mlp_hidden_units: int,
+        num_bins: int,
+        init_min,
+        init_max,
+        dropout_rate: float = 0.0,
+        use_batch_norm: bool = False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.embedding_dim = embedding_dim
+        self.mlp_hidden_units = mlp_hidden_units
+        self.num_bins = num_bins
+        self.dropout_rate = dropout_rate
+        self.use_batch_norm = use_batch_norm
+        self.init_min = init_min
+        self.init_max = init_max
+
+        if self.num_bins is None:
+            raise ValueError(
+                "num_bins must be provided to activate the discrete branch."
+            )
+
+    def build(self, input_shape):
+        # input_shape: (batch, num_features)
+        self.num_features = input_shape[-1]
+        # Continuous branch: process each feature independently using TimeDistributed MLP.
+        self.cont_mlp = tf.keras.Sequential(
+            [
+                layers.TimeDistributed(
+                    layers.Dense(self.mlp_hidden_units, activation="relu")
+                ),
+                layers.TimeDistributed(layers.Dense(self.embedding_dim)),
+            ],
+            name="cont_mlp",
+        )
+        self.dropout = (
+            layers.Dropout(self.dropout_rate)
+            if self.dropout_rate > 0
+            else lambda x, training: x
+        )
+        if self.use_batch_norm:
+            self.batch_norm = layers.TimeDistributed(
+                layers.BatchNormalization(), name="cont_batch_norm"
+            )
+        # Residual projection to match embedding_dim.
+        self.residual_proj = layers.TimeDistributed(
+            layers.Dense(self.embedding_dim, activation=None), name="residual_proj"
+        )
+        # Discrete branch: Create one Embedding layer per feature.
+        self.bin_embeddings = []
+        for i in range(self.num_features):
+            embed_layer = layers.Embedding(
+                input_dim=self.num_bins,
+                output_dim=self.embedding_dim,
+                name=f"bin_embed_{i}",
+            )
+            self.bin_embeddings.append(embed_layer)
+        # Learned bin boundaries for each feature, shape: (num_features,)
+        init_min_tensor = tf.convert_to_tensor(self.init_min, dtype=tf.float32)
+        init_max_tensor = tf.convert_to_tensor(self.init_max, dtype=tf.float32)
+        if init_min_tensor.shape.ndims == 0:
+            init_min_tensor = tf.fill([self.num_features], init_min_tensor)
+        if init_max_tensor.shape.ndims == 0:
+            init_max_tensor = tf.fill([self.num_features], init_max_tensor)
+        # Convert tensors to numpy arrays, which are accepted by tf.constant_initializer.
+        init_min_value = (
+            init_min_tensor.numpy()
+            if hasattr(init_min_tensor, "numpy")
+            else init_min_tensor
+        )
+        init_max_value = (
+            init_max_tensor.numpy()
+            if hasattr(init_max_tensor, "numpy")
+            else init_max_tensor
+        )
+
+        self.learned_min = self.add_weight(
+            name="learned_min",
+            shape=(self.num_features,),
+            initializer=tf.constant_initializer(init_min_value),
+            trainable=True,
+        )
+        self.learned_max = self.add_weight(
+            name="learned_max",
+            shape=(self.num_features,),
+            initializer=tf.constant_initializer(init_max_value),
+            trainable=True,
+        )
+        # Gate to combine continuous and discrete branches, shape: (num_features, embedding_dim)
+        self.gate = self.add_weight(
+            name="gate",
+            shape=(self.num_features, self.embedding_dim),
+            initializer="zeros",
+            trainable=True,
+        )
+        logger.debug(
+            "AdvancedNumericalEmbedding built for {} features with embedding_dim={}",
+            self.num_features,
+            self.embedding_dim,
+        )
+        super().build(input_shape)
+
+    def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor:
+        # Continuous branch.
+        inputs_expanded = tf.expand_dims(inputs, axis=-1)  # (batch, num_features, 1)
+        cont = self.cont_mlp(inputs_expanded)
+        cont = self.dropout(cont, training=training)
+        if self.use_batch_norm:
+            cont = self.batch_norm(cont, training=training)
+        # Residual connection.
+        cont_res = self.residual_proj(inputs_expanded)
+        cont = cont + cont_res  # (batch, num_features, embedding_dim)
+
+        # Discrete branch.
+        inputs_float = tf.cast(inputs, tf.float32)
+        # Use learned min and max for scaling.
+        scaled = (inputs_float - self.learned_min) / (
+            self.learned_max - self.learned_min + 1e-6
+        )
+        # Compute bin indices.
+        bin_indices = tf.floor(scaled * self.num_bins)
+        bin_indices = tf.cast(bin_indices, tf.int32)
+        bin_indices = tf.clip_by_value(bin_indices, 0, self.num_bins - 1)
+        disc_embeddings = []
+        for i in range(self.num_features):
+            feat_bins = bin_indices[:, i]  # (batch,)
+            feat_embed = self.bin_embeddings[i](
+                feat_bins
+            )  # i is a Python integer here.
+            disc_embeddings.append(feat_embed)
+        disc = tf.stack(disc_embeddings, axis=1)  # (batch, num_features, embedding_dim)
+
+        # Combine branches via a per-feature, per-dimension gate.
+        gate = tf.nn.sigmoid(self.gate)  # (num_features, embedding_dim)
+        output = gate * cont + (1 - gate) * disc  # (batch, num_features, embedding_dim)
+        return output
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "embedding_dim": self.embedding_dim,
+                "mlp_hidden_units": self.mlp_hidden_units,
+                "num_bins": self.num_bins,
+                "init_min": self.init_min,
+                "init_max": self.init_max,
+                "dropout_rate": self.dropout_rate,
+                "use_batch_norm": self.use_batch_norm,
+            }
+        )
+        return config
+
+
+if __name__ == "__main__":
+    tf.random.set_seed(42)
+    logger.info("Testing AdvancedNumericalEmbedding with multi-feature input.")
+    # Multi-feature test: 32 samples, 3 features.
+    x_multi = tf.random.normal((32, 3))
+    layer_multi = AdvancedNumericalEmbedding(
+        embedding_dim=8,
+        mlp_hidden_units=16,
+        num_bins=10,
+        init_min=[-3.0, -2.0, -4.0],
+        init_max=[3.0, 2.0, 4.0],
+        dropout_rate=0.1,
+        use_batch_norm=True,
+    )
+    y_multi = layer_multi(x_multi)
+    logger.info("Multi-feature output shape: {}", y_multi.shape)
+
+    # Single-feature test: 32 samples, 1 feature.
+    x_single = tf.random.normal((32, 1))
+    layer_single = AdvancedNumericalEmbedding(
+        embedding_dim=8,
+        mlp_hidden_units=16,
+        num_bins=10,
+        init_min=-3.0,
+        init_max=3.0,
+        dropout_rate=0.1,
+        use_batch_norm=True,
+    )
+    y_single = layer_single(x_single)
+    logger.info("Single-feature output shape: {}", y_single.shape)

kdp/features.py

Lines changed: 28 additions & 2 deletions
@@ -117,26 +117,52 @@ def from_string(type_str: str) -> "FeatureType":
 
 
 class NumericalFeature(Feature):
-    """NumericalFeature with dynamic kwargs passing."""
+    """NumericalFeature with dynamic kwargs passing and embedding support."""
 
     def __init__(
         self,
         name: str,
         feature_type: FeatureType = FeatureType.FLOAT_NORMALIZED,
         preferred_distribution: DistributionType | None = None,
+        use_embedding: bool = False,
+        embedding_dim: int = 8,
+        num_bins: int = 10,
         **kwargs,
     ) -> None:
        """Initializes a NumericalFeature instance.
 
         Args:
             name (str): The name of the feature.
             feature_type (FeatureType): The type of the feature.
-            preferred_distribution (DistributionType | None): The preferred distribution type for the feature.
+            preferred_distribution (DistributionType | None): The preferred distribution type.
+            use_embedding (bool): Whether to use advanced numerical embedding.
+            embedding_dim (int): Dimension of the embedding space.
+            num_bins (int): Number of bins for discretization.
             **kwargs: Additional keyword arguments for the feature.
         """
         super().__init__(name, feature_type, **kwargs)
         self.dtype = tf.float32
         self.preferred_distribution = preferred_distribution
+        self.use_embedding = use_embedding
+        self.embedding_dim = embedding_dim
+        self.num_bins = num_bins
+
+    def get_embedding_layer(self, input_shape: tuple) -> tf.keras.layers.Layer:
+        """Creates and returns an AdvancedNumericalEmbedding layer configured for this feature."""
+        from kdp.custom_layers import (
+            AdvancedNumericalEmbedding,
+        )  # Avoid circular import
+
+        return AdvancedNumericalEmbedding(
+            embedding_dim=self.embedding_dim,
+            mlp_hidden_units=max(16, self.embedding_dim * 2),
+            num_bins=self.num_bins,
+            init_min=self.kwargs.get("init_min", -3.0),
+            init_max=self.kwargs.get("init_max", 3.0),
+            dropout_rate=self.kwargs.get("dropout_rate", 0.1),
+            use_batch_norm=self.kwargs.get("use_batch_norm", True),
+            name=f"{self.name}_embedding",
+        )
 
 
 class CategoricalFeature(Feature):
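
As a usage sketch (not part of the commit), the new NumericalFeature options could be exercised roughly as follows. The feature name and boundary values are hypothetical, and the example assumes the base Feature class stores extra keyword arguments on self.kwargs, as the self.kwargs.get(...) calls above imply.

from kdp.features import FeatureType, NumericalFeature

# Hypothetical feature definition; "age" and the kwargs values are illustrative.
age = NumericalFeature(
    name="age",
    feature_type=FeatureType.FLOAT_NORMALIZED,
    use_embedding=True,
    embedding_dim=8,
    num_bins=10,
    init_min=0.0,    # forwarded via **kwargs, read back with self.kwargs.get("init_min", -3.0)
    init_max=100.0,
)

# Builds an AdvancedNumericalEmbedding named "age_embedding" with
# mlp_hidden_units = max(16, 8 * 2) = 16.
embedding_layer = age.get_embedding_layer(input_shape=(None, 1))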

kdp/processor.py

Lines changed: 21 additions & 4 deletions
@@ -192,6 +192,7 @@ def __init__(
         distribution_aware_bins: int = 1000,
         feature_selection_units: int = 32,
         feature_selection_dropout: float = 0.2,
+        use_advanced_numerical_embedding: bool = False,
     ) -> None:
         """Initialize a preprocessing model.
 
@@ -258,6 +259,9 @@ def __init__(
         self.distribution_aware_bins = distribution_aware_bins
         self.feature_selection_dropout = feature_selection_dropout
 
+        # advanced numerical embedding control
+        self.use_advanced_numerical_embedding = use_advanced_numerical_embedding
+
         # PLACEHOLDERS
         self.preprocessors = {}
         self.inputs = {}
@@ -576,13 +580,13 @@ def _add_pipeline_numeric(
             stats (dict): A dictionary containing the metadata of the feature, including
                 the mean and variance of the feature.
         """
-        # getting feature object
+        # Get the feature specifications
         _feature = self.features_specs[feature_name]
 
-        # initializing preprocessor
+        # Initialize preprocessor
         preprocessor = FeaturePreprocessor(name=feature_name)
 
-        # Add cast to float32 first for all numeric features
+        # First, cast all numeric features to float32.
         preprocessor.add_processing_step(
             layer_creator=PreprocessorLayerFactory.cast_to_float32_layer,
             name=f"cast_to_float_{feature_name}",
@@ -676,10 +680,23 @@ def _add_pipeline_numeric(
             name=f"norm_{feature_name}",
         )
 
+        # Check for advanced numerical embedding.
+        if self.use_advanced_numerical_embedding:
+            logger.info(f"Using AdvancedNumericalEmbedding for {feature_name}")
+            # Obtain the embedding layer.
+            embedding_layer = _feature.get_embedding_layer(
+                input_shape=input_layer.shape
+            )
+            preprocessor.add_processing_step(
+                layer_creator=lambda **kwargs: embedding_layer,
+                layer_class="AdvancedNumericalEmbedding",
+                name=f"advanced_embedding_{feature_name}",
+            )
+
         # Process the feature
         _output_pipeline = preprocessor.chain(input_layer=input_layer)
 
-        # Apply feature selection if enabled for numeric features
+        # Optionally, apply feature selection for numeric features.
         if (
             self.feature_selection_placement == FeatureSelectionPlacementOptions.NUMERIC
             or self.feature_selection_placement