
Commit d387826

fix(KDP): changed the order of the transformers and the tabularAttention applications
1 parent da24b7b commit d387826

File tree (2 files changed: +50 −52 lines)

  kdp/processor.py
  test/test_processor.py

kdp/processor.py

Lines changed: 39 additions & 39 deletions
@@ -880,45 +880,6 @@ def _prepare_outputs(self) -> None:
             else:
                 raise ValueError("No features available for concatenation")

-            # Add transformer blocks if specified
-            if self.transfo_nr_blocks:
-                if self.transfo_placement == TransformerBlockPlacementOptions.CATEGORICAL and concat_cat is not None:
-                    logger.info(f"Adding transformer blocks to categorical features: #{self.transfo_nr_blocks}")
-                    transformed = concat_cat
-                    for block_idx in range(self.transfo_nr_blocks):
-                        transformed = PreprocessorLayerFactory.transformer_block_layer(
-                            dim_model=transformed.shape[-1],
-                            num_heads=self.transfo_nr_heads,
-                            ff_units=self.transfo_ff_units,
-                            dropout_rate=self.transfo_dropout_rate,
-                            name=f"transformer_block_{block_idx}_{self.transfo_nr_heads}heads",
-                        )(transformed)
-                    # Reshape transformer output to remove the extra dimension
-                    transformed = tf.keras.layers.Reshape(
-                        target_shape=(-1,),  # Flatten to match numeric shape
-                        name="reshape_transformer_output",
-                    )(transformed)
-
-                    # Recombine with numeric features if they exist
-                    if concat_num is not None:
-                        self.concat_all = tf.keras.layers.Concatenate(
-                            name="ConcatenateTransformed",
-                            axis=-1,
-                        )([concat_num, transformed])
-                    else:
-                        self.concat_all = transformed
-
-                elif self.transfo_placement == TransformerBlockPlacementOptions.ALL_FEATURES:
-                    logger.info(f"Adding transformer blocks to all features: #{self.transfo_nr_blocks}")
-                    for block_idx in range(self.transfo_nr_blocks):
-                        self.concat_all = PreprocessorLayerFactory.transformer_block_layer(
-                            dim_model=self.concat_all.shape[-1],
-                            num_heads=self.transfo_nr_heads,
-                            ff_units=self.transfo_ff_units,
-                            dropout_rate=self.transfo_dropout_rate,
-                            name=f"transformer_block_{block_idx}_{self.transfo_nr_heads}heads",
-                        )(self.concat_all)
-
             # Add tabular attention if specified
             if self.tabular_attention:
                 if self.tabular_attention_placement == TabularAttentionPlacementOptions.MULTI_RESOLUTION:
@@ -1047,6 +1008,45 @@ def _prepare_outputs(self) -> None:
                 else:
                     self.concat_all = concat_cat

+            # Add transformer blocks if specified
+            if self.transfo_nr_blocks:
+                if self.transfo_placement == TransformerBlockPlacementOptions.CATEGORICAL and concat_cat is not None:
+                    logger.info(f"Adding transformer blocks to categorical features: #{self.transfo_nr_blocks}")
+                    transformed = concat_cat
+                    for block_idx in range(self.transfo_nr_blocks):
+                        transformed = PreprocessorLayerFactory.transformer_block_layer(
+                            dim_model=transformed.shape[-1],
+                            num_heads=self.transfo_nr_heads,
+                            ff_units=self.transfo_ff_units,
+                            dropout_rate=self.transfo_dropout_rate,
+                            name=f"transformer_block_{block_idx}_{self.transfo_nr_heads}heads",
+                        )(transformed)
+                    # Reshape transformer output to remove the extra dimension
+                    transformed = tf.keras.layers.Reshape(
+                        target_shape=(-1,),  # Flatten to match numeric shape
+                        name="reshape_transformer_output",
+                    )(transformed)
+
+                    # Recombine with numeric features if they exist
+                    if concat_num is not None:
+                        self.concat_all = tf.keras.layers.Concatenate(
+                            name="ConcatenateTransformed",
+                            axis=-1,
+                        )([concat_num, transformed])
+                    else:
+                        self.concat_all = transformed
+
+                elif self.transfo_placement == TransformerBlockPlacementOptions.ALL_FEATURES:
+                    logger.info(f"Adding transformer blocks to all features: #{self.transfo_nr_blocks}")
+                    for block_idx in range(self.transfo_nr_blocks):
+                        self.concat_all = PreprocessorLayerFactory.transformer_block_layer(
+                            dim_model=self.concat_all.shape[-1],
+                            num_heads=self.transfo_nr_heads,
+                            ff_units=self.transfo_ff_units,
+                            dropout_rate=self.transfo_dropout_rate,
+                            name=f"transformer_block_{block_idx}_{self.transfo_nr_heads}heads",
+                        )(self.concat_all)
+
             logger.info("Concatenating outputs mode enabled")
         else:
             # Dictionary mode
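
For orientation, below is a minimal, self-contained Keras sketch (not KDP code) of the ordering this commit switches to for the ALL_FEATURES case: attention over the concatenated features runs first, and the transformer blocks are applied afterwards. MultiHeadAttention, Dense, and LayerNormalization stand in for KDP's tabular attention layers and PreprocessorLayerFactory.transformer_block_layer; the shapes and the block count of 2 are illustrative, chosen to match the updated tests.

import tensorflow as tf

# Stand-in pipeline: attention over the concatenated features first,
# transformer-style blocks afterwards (the order this commit switches to).
features_in = tf.keras.Input(shape=(6, 16), name="concat_all")  # (batch, features, dim)

# 1) tabular-attention stand-in (previously this step ran after the transformers)
attended = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=16)(features_in, features_in)

# 2) transformer-block stand-ins applied after the attention step
x = attended
for block_idx in range(2):  # transfo_nr_blocks=2, as in the updated tests
    attn = tf.keras.layers.MultiHeadAttention(
        num_heads=4, key_dim=16, name=f"block_{block_idx}_mha"
    )(x, x)
    x = tf.keras.layers.LayerNormalization()(x + attn)
    ff = tf.keras.layers.Dense(64, activation="relu")(x)  # ff_units stand-in
    x = tf.keras.layers.LayerNormalization()(x + tf.keras.layers.Dense(16)(ff))

model = tf.keras.Model(features_in, x)
model.summary()  # output shape stays (None, 6, 16)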

test/test_processor.py

Lines changed: 11 additions & 13 deletions
@@ -1004,6 +1004,7 @@ def test_preprocessor_all_features_with_transformer_and_attention(self):
             overwrite_stats=True,
             output_mode=OutputModeOptions.CONCAT,
             tabular_attention=True,
+            tabular_attention_placement="all_features",
             tabular_attention_heads=4,
             tabular_attention_dim=64,
             transfo_nr_blocks=2,
@@ -1025,7 +1026,7 @@ def test_preprocessor_all_features_with_transformer_and_attention(self):

         # Check output dimensions
         self.assertEqual(len(preprocessed.shape), 2)  # (batch_size, d_model)
-        self.assertEqual(preprocessed.shape[-1], 64)  # Example dimension
+        self.assertEqual(preprocessed.shape[-1], 1 + 4 + 12)  # The dimensions for num1, cat1, date1

     def test_preprocessor_all_features_with_transformer_and_attention_v2(self):
         """Test all feature types with both transformer and attention."""
@@ -1050,6 +1051,7 @@ def test_preprocessor_all_features_with_transformer_and_attention_v2(self):
             features_stats_path=self.features_stats_path,
             overwrite_stats=True,
             output_mode=OutputModeOptions.CONCAT,
+            tabular_attention_placement="all_features",
             tabular_attention=True,
             tabular_attention_heads=4,
             tabular_attention_dim=64,
@@ -1072,7 +1074,7 @@ def test_preprocessor_all_features_with_transformer_and_attention_v2(self):

         # Check output dimensions
         self.assertEqual(len(preprocessed.shape), 2)  # (batch_size, d_model)
-        self.assertEqual(preprocessed.shape[-1], 64)  # Example dimension
+        self.assertEqual(preprocessed.shape[-1], 65)  # Example dimension

     def test_preprocessor_all_features_with_transformer_and_attention_v3(self):
         """Test all feature types with both transformer and attention."""
@@ -1122,7 +1124,7 @@ def test_preprocessor_all_features_with_transformer_and_attention_v3(self):
         self.assertIsNotNone(preprocessed)

         # Check output dimensions
-        self.assertEqual(len(preprocessed.shape), 2)  # (batch_size, d_model)
+        self.assertEqual(len(preprocessed.shape), 3)  # (batch_size, d_model)
         self.assertEqual(preprocessed.shape[-1], 23)  # Example dimension

     def test_preprocessor_all_features_with_transformer_and_attention_v4(self):
@@ -1288,7 +1290,7 @@ def test_preprocessor_parameter_combinations(self):
            },
            {
                "tabular_attention": True,
-               "tabular_attention_placement": "categorical",
+               "tabular_attention_placement": "all_features",
                "tabular_attention_heads": 2,
                "tabular_attention_dim": 32,
                "tabular_attention_dropout": 0.1,
@@ -1363,17 +1365,13 @@ def test_preprocessor_parameter_combinations(self):

            if test_case["output_mode"] == OutputModeOptions.CONCAT:
                if test_case["tabular_attention"] == True:
-                   # Check output dimensions for concatenated output
-                   self.assertEqual(len(preprocessed.shape), 2)  # (batch_size, d_model)
-                   self.assertEqual(
-                       preprocessed.shape[-1], test_case["tabular_attention_dim"]
-                   )  # Example dimension
+                   # Check output dimensions for concatenated output with attention
+                   self.assertEqual(len(preprocessed.shape), 3)  # (batch_size, d_model)
+                   self.assertEqual(preprocessed.shape[-1], test_case["tabular_attention_dim"])
                else:
-                   # Check output dimensions for concatenated output
+                   # Check output dimensions for concatenated output without attention
                    self.assertEqual(len(preprocessed.shape), 2)  # (batch_size, d_model)
-                   self.assertEqual(
-                       preprocessed.shape[-1], 65
-                   )  # The dimension of these features at the end should be 65
+                   self.assertEqual(preprocessed.shape[-1], 65)  # Base feature dimension
            else:
                # Check output dimensions for dictionary output
                for key, tensor in preprocessed.items():
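
The updated assertions encode one rule: in CONCAT mode with tabular attention enabled, the preprocessor output is rank 3 with tabular_attention_dim on the last axis, while without attention it stays rank 2 at the plain concatenated width (65 for these test cases). A hypothetical helper capturing that rule (check_concat_output is not part of the test suite, just a sketch of the expectation) might look like:

def check_concat_output(preprocessed, tabular_attention: bool, attention_dim: int, base_width: int = 65):
    """Hypothetical shape check mirroring the updated CONCAT-mode assertions."""
    if tabular_attention:
        # With attention: (batch_size, num_features, tabular_attention_dim)
        assert len(preprocessed.shape) == 3
        assert preprocessed.shape[-1] == attention_dim
    else:
        # Without attention: (batch_size, concatenated_feature_width)
        assert len(preprocessed.shape) == 2
        assert preprocessed.shape[-1] == base_width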
