6 | 6 | import numpy as np |
7 | 7 | import tensorflow as tf |
8 | 8 | import tensorflow_probability as tfp |
| 9 | +from tensorflow.keras import layers |
| 10 | + |
| 11 | +from loguru import logger |
9 | 12 | |
10 | 13 | |
11 | 14 | class TextPreprocessingLayer(tf.keras.layers.Layer): |
@@ -1945,3 +1948,188 @@ def from_config(cls, config: dict) -> "VariableSelection": |
1945 | 1948 | VariableSelection: A new instance of the layer. |
1946 | 1949 | """ |
1947 | 1950 | return cls(**config) |
| 1951 | + |
| 1952 | + |
| 1953 | +class AdvancedNumericalEmbedding(layers.Layer): |
| 1954 | + """Advanced numerical embedding layer for continuous features. |
| 1955 | + |
| 1956 | + This layer embeds each continuous numerical feature into a higher-dimensional space by |
| 1957 | + combining two branches: |
| 1958 | + |
| 1959 | + 1. Continuous Branch: Each feature is processed via a small MLP (using TimeDistributed layers). |
| 1960 | + 2. Discrete Branch: Each feature is discretized into bins using learnable min/max boundaries |
| 1961 | + and then an embedding is looked up for its bin. |
| 1962 | + |
| 1963 | + A learnable gate (of shape (num_features, embedding_dim)) combines the two branch outputs |
| 1964 | + per feature and per embedding dimension. Additionally, the continuous branch uses a residual |
| 1965 | + connection and optional batch normalization to improve training stability. |
| 1966 | + |
| 1967 | + The layer supports inputs of shape (batch, num_features) for any number of features and returns |
| 1968 | + outputs of shape (batch, num_features, embedding_dim). |
| 1969 | + |
| 1970 | + Args: |
| 1971 | + embedding_dim (int): Output embedding dimension per feature. |
| 1972 | + mlp_hidden_units (int): Hidden units for the continuous branch MLP. |
| 1973 | + num_bins (int): Number of bins for discretization. |
| 1974 | + init_min (float or list): Initial minimum values for discretization boundaries. If a scalar is |
| 1975 | + provided, it is applied to all features. |
| 1976 | + init_max (float or list): Initial maximum values for discretization boundaries; scalars are broadcast to all features, like init_min. |
| 1977 | + dropout_rate (float): Dropout rate applied to the continuous branch. |
| 1978 | + use_batch_norm (bool): Whether to apply batch normalization to the continuous branch. |
| 1979 | + |
| 1980 | + """ |
| 1981 | + |
| 1982 | + def __init__( |
| 1983 | + self, |
| 1984 | + embedding_dim: int, |
| 1985 | + mlp_hidden_units: int, |
| 1986 | + num_bins: int, |
| 1987 | + init_min, |
| 1988 | + init_max, |
| 1989 | + dropout_rate: float = 0.0, |
| 1990 | + use_batch_norm: bool = False, |
| 1991 | + **kwargs, |
| 1992 | + ): |
| 1993 | + super().__init__(**kwargs) |
| 1994 | + self.embedding_dim = embedding_dim |
| 1995 | + self.mlp_hidden_units = mlp_hidden_units |
| 1996 | + self.num_bins = num_bins |
| 1997 | + self.dropout_rate = dropout_rate |
| 1998 | + self.use_batch_norm = use_batch_norm |
| 1999 | + self.init_min = init_min |
| 2000 | + self.init_max = init_max |
| 2001 | + |
| 2002 | + if self.num_bins is None: |
| 2003 | + raise ValueError( |
| 2004 | + "num_bins must be provided to activate the discrete branch." |
| 2005 | + ) |
| 2006 | + |
| 2007 | + def build(self, input_shape): |
| 2008 | + # input_shape: (batch, num_features) |
| 2009 | + self.num_features = input_shape[-1] |
| 2010 | + # Continuous branch: process each feature independently using TimeDistributed MLP. |
| 2011 | + self.cont_mlp = tf.keras.Sequential( |
| 2012 | + [ |
| 2013 | + layers.TimeDistributed( |
| 2014 | + layers.Dense(self.mlp_hidden_units, activation="relu") |
| 2015 | + ), |
| 2016 | + layers.TimeDistributed(layers.Dense(self.embedding_dim)), |
| 2017 | + ], |
| 2018 | + name="cont_mlp", |
| 2019 | + ) |
| 2020 | + self.dropout = ( |
| 2021 | + layers.Dropout(self.dropout_rate) |
| 2022 | + if self.dropout_rate > 0 |
| 2023 | + else (lambda x, training=None: x)  # identity no-op when dropout is disabled |
| 2024 | + ) |
| 2025 | + if self.use_batch_norm: |
| 2026 | + self.batch_norm = layers.TimeDistributed( |
| 2027 | + layers.BatchNormalization(), name="cont_batch_norm" |
| 2028 | + ) |
| 2029 | + # Residual projection to match embedding_dim. |
| 2030 | + self.residual_proj = layers.TimeDistributed( |
| 2031 | + layers.Dense(self.embedding_dim, activation=None), name="residual_proj" |
| 2032 | + ) |
| 2033 | + # Discrete branch: Create one Embedding layer per feature. |
| 2034 | + self.bin_embeddings = [] |
| 2035 | + for i in range(self.num_features): |
| 2036 | + embed_layer = layers.Embedding( |
| 2037 | + input_dim=self.num_bins, |
| 2038 | + output_dim=self.embedding_dim, |
| 2039 | + name=f"bin_embed_{i}", |
| 2040 | + ) |
| 2041 | + self.bin_embeddings.append(embed_layer) |
| 2042 | + # Learnable per-feature min/max range used to compute bin indices, shape: (num_features,) |
| 2043 | + init_min_tensor = tf.convert_to_tensor(self.init_min, dtype=tf.float32) |
| 2044 | + init_max_tensor = tf.convert_to_tensor(self.init_max, dtype=tf.float32) |
| 2045 | + if init_min_tensor.shape.ndims == 0: |
| 2046 | + init_min_tensor = tf.fill([self.num_features], init_min_tensor) |
| 2047 | + if init_max_tensor.shape.ndims == 0: |
| 2048 | + init_max_tensor = tf.fill([self.num_features], init_max_tensor) |
| 2049 | + # Convert tensors to NumPy arrays, which tf.constant_initializer accepts. |
| 2050 | + init_min_value = ( |
| 2051 | + init_min_tensor.numpy() |
| 2052 | + if hasattr(init_min_tensor, "numpy") |
| 2053 | + else init_min_tensor |
| 2054 | + ) |
| 2055 | + init_max_value = ( |
| 2056 | + init_max_tensor.numpy() |
| 2057 | + if hasattr(init_max_tensor, "numpy") |
| 2058 | + else init_max_tensor |
| 2059 | + ) |
| 2060 | + |
| 2061 | + self.learned_min = self.add_weight( |
| 2062 | + name="learned_min", |
| 2063 | + shape=(self.num_features,), |
| 2064 | + initializer=tf.constant_initializer(init_min_value), |
| 2065 | + trainable=True, |
| 2066 | + ) |
| 2067 | + self.learned_max = self.add_weight( |
| 2068 | + name="learned_max", |
| 2069 | + shape=(self.num_features,), |
| 2070 | + initializer=tf.constant_initializer(init_max_value), |
| 2071 | + trainable=True, |
| 2072 | + ) |
| 2073 | + # Gate to combine continuous and discrete branches, shape: (num_features, embedding_dim) |
| 2074 | + self.gate = self.add_weight( |
| 2075 | + name="gate", |
| 2076 | + shape=(self.num_features, self.embedding_dim), |
| 2077 | + initializer="zeros", |
| 2078 | + trainable=True, |
| 2079 | + ) |
| 2080 | + logger.debug( |
| 2081 | + "AdvancedNumericalEmbedding built for {} features with embedding_dim={}", |
| 2082 | + self.num_features, |
| 2083 | + self.embedding_dim, |
| 2084 | + ) |
| 2085 | + super().build(input_shape) |
| 2086 | + |
| 2087 | + def call(self, inputs: tf.Tensor, training: bool = False) -> tf.Tensor: |
| 2088 | + # Continuous branch. |
| 2089 | + inputs_expanded = tf.expand_dims(inputs, axis=-1) # (batch, num_features, 1) |
| 2090 | + cont = self.cont_mlp(inputs_expanded) |
| 2091 | + cont = self.dropout(cont, training=training) |
| 2092 | + if self.use_batch_norm: |
| 2093 | + cont = self.batch_norm(cont, training=training) |
| 2094 | + # Residual connection. |
| 2095 | + cont_res = self.residual_proj(inputs_expanded) |
| 2096 | + cont = cont + cont_res # (batch, num_features, embedding_dim) |
| 2097 | + |
| 2098 | + # Discrete branch. |
| 2099 | + inputs_float = tf.cast(inputs, tf.float32) |
| 2100 | + # Use learned min and max for scaling. |
| 2101 | + scaled = (inputs_float - self.learned_min) / ( |
| 2102 | + self.learned_max - self.learned_min + 1e-6 |
| 2103 | + ) |
| 2104 | + # Compute bin indices. |
| 2105 | + bin_indices = tf.floor(scaled * self.num_bins) |
| 2106 | + bin_indices = tf.cast(bin_indices, tf.int32) |
| 2107 | + bin_indices = tf.clip_by_value(bin_indices, 0, self.num_bins - 1) |
| 2108 | + disc_embeddings = [] |
| 2109 | + for i in range(self.num_features): |
| 2110 | + feat_bins = bin_indices[:, i] # (batch,) |
| 2111 | + feat_embed = self.bin_embeddings[i]( |
| 2112 | + feat_bins |
| 2113 | + ) # i is a Python integer here. |
| 2114 | + disc_embeddings.append(feat_embed) |
| 2115 | + disc = tf.stack(disc_embeddings, axis=1) # (batch, num_features, embedding_dim) |
| 2116 | + |
| 2117 | + # Combine branches via a per-feature, per-dimension gate. |
| 2118 | + gate = tf.nn.sigmoid(self.gate) # (num_features, embedding_dim) |
| 2119 | + output = gate * cont + (1 - gate) * disc # (batch, num_features, embedding_dim) |
| 2120 | + return output |
| 2121 | + |
| 2122 | + def get_config(self): |
| 2123 | + config = super().get_config() |
| 2124 | + config.update( |
| 2125 | + { |
| 2126 | + "embedding_dim": self.embedding_dim, |
| 2127 | + "mlp_hidden_units": self.mlp_hidden_units, |
| 2128 | + "num_bins": self.num_bins, |
| 2129 | + "init_min": self.init_min, |
| 2130 | + "init_max": self.init_max, |
| 2131 | + "dropout_rate": self.dropout_rate, |
| 2132 | + "use_batch_norm": self.use_batch_norm, |
| 2133 | + } |
| 2134 | + ) |
| 2135 | + return config |
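
To make the two-branch combination concrete, here is a small numeric walk-through of the discrete branch and the gate (the feature range, bin count, and embedding dimension are illustrative assumptions, not values from the commit):

```python
import tensorflow as tf

# Assume one feature with learned_min = 0.0, learned_max = 10.0 and num_bins = 10.
x = tf.constant([[7.3]])                      # (batch=1, num_features=1)
scaled = (x - 0.0) / (10.0 - 0.0 + 1e-6)      # ~0.73
bin_idx = tf.clip_by_value(
    tf.cast(tf.floor(scaled * 10), tf.int32), 0, 9
)                                             # value 7.3 lands in bin 7 of 10

# The gate weight is zero-initialized, so sigmoid(gate) = 0.5 and the layer starts
# with an even blend of the two branches, learned per feature and per dimension:
gate = tf.nn.sigmoid(tf.zeros((1, 8)))        # (num_features, embedding_dim)
# output = gate * cont + (1 - gate) * disc
```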
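A minimal usage sketch of the new layer (not part of the commit; the feature count, constructor values, and batch size below are illustrative assumptions):

```python
import tensorflow as tf

# Hypothetical setup: 3 numerical features, each embedded into 8 dimensions.
# init_min/init_max are assumed per-feature ranges; a scalar would be broadcast.
embed = AdvancedNumericalEmbedding(
    embedding_dim=8,
    mlp_hidden_units=16,
    num_bins=10,
    init_min=[0.0, -1.0, 5.0],
    init_max=[1.0, 1.0, 100.0],
    dropout_rate=0.1,
    use_batch_norm=True,
)

x = tf.random.uniform((32, 3))   # (batch, num_features)
y = embed(x, training=False)     # (batch, num_features, embedding_dim)
assert y.shape == (32, 3, 8)

# Config round-trip: get_config() carries every constructor argument,
# so the base Layer.from_config (cls(**config)) can rebuild the layer.
clone = AdvancedNumericalEmbedding.from_config(embed.get_config())
```

Note that build() calls .numpy() on the init_min/init_max tensors, so the layer is expected to be built eagerly (e.g. on the first eager call), which is the default outside tf.function.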