@@ -880,45 +880,6 @@ def _prepare_outputs(self) -> None:
             else:
                 raise ValueError("No features available for concatenation")
 
-            # Add transformer blocks if specified
-            if self.transfo_nr_blocks:
-                if self.transfo_placement == TransformerBlockPlacementOptions.CATEGORICAL and concat_cat is not None:
-                    logger.info(f"Adding transformer blocks to categorical features: #{self.transfo_nr_blocks}")
-                    transformed = concat_cat
-                    for block_idx in range(self.transfo_nr_blocks):
-                        transformed = PreprocessorLayerFactory.transformer_block_layer(
-                            dim_model=transformed.shape[-1],
-                            num_heads=self.transfo_nr_heads,
-                            ff_units=self.transfo_ff_units,
-                            dropout_rate=self.transfo_dropout_rate,
-                            name=f"transformer_block_{block_idx}_{self.transfo_nr_heads}heads",
-                        )(transformed)
-                    # Reshape transformer output to remove the extra dimension
-                    transformed = tf.keras.layers.Reshape(
-                        target_shape=(-1,),  # Flatten to match numeric shape
-                        name="reshape_transformer_output",
-                    )(transformed)
-
-                    # Recombine with numeric features if they exist
-                    if concat_num is not None:
-                        self.concat_all = tf.keras.layers.Concatenate(
-                            name="ConcatenateTransformed",
-                            axis=-1,
-                        )([concat_num, transformed])
-                    else:
-                        self.concat_all = transformed
-
-                elif self.transfo_placement == TransformerBlockPlacementOptions.ALL_FEATURES:
-                    logger.info(f"Adding transformer blocks to all features: #{self.transfo_nr_blocks}")
-                    for block_idx in range(self.transfo_nr_blocks):
-                        self.concat_all = PreprocessorLayerFactory.transformer_block_layer(
-                            dim_model=self.concat_all.shape[-1],
-                            num_heads=self.transfo_nr_heads,
-                            ff_units=self.transfo_ff_units,
-                            dropout_rate=self.transfo_dropout_rate,
-                            name=f"transformer_block_{block_idx}_{self.transfo_nr_heads}heads",
-                        )(self.concat_all)
-
             # Add tabular attention if specified
             if self.tabular_attention:
                 if self.tabular_attention_placement == TabularAttentionPlacementOptions.MULTI_RESOLUTION:
@@ -1047,6 +1008,45 @@ def _prepare_outputs(self) -> None:
                 else:
                     self.concat_all = concat_cat
 
+            # Add transformer blocks if specified
+            if self.transfo_nr_blocks:
+                if self.transfo_placement == TransformerBlockPlacementOptions.CATEGORICAL and concat_cat is not None:
+                    logger.info(f"Adding transformer blocks to categorical features: #{self.transfo_nr_blocks}")
+                    transformed = concat_cat
+                    for block_idx in range(self.transfo_nr_blocks):
+                        transformed = PreprocessorLayerFactory.transformer_block_layer(
+                            dim_model=transformed.shape[-1],
+                            num_heads=self.transfo_nr_heads,
+                            ff_units=self.transfo_ff_units,
+                            dropout_rate=self.transfo_dropout_rate,
+                            name=f"transformer_block_{block_idx}_{self.transfo_nr_heads}heads",
+                        )(transformed)
+                    # Reshape transformer output to remove the extra dimension
+                    transformed = tf.keras.layers.Reshape(
+                        target_shape=(-1,),  # Flatten to match numeric shape
+                        name="reshape_transformer_output",
+                    )(transformed)
+
+                    # Recombine with numeric features if they exist
+                    if concat_num is not None:
+                        self.concat_all = tf.keras.layers.Concatenate(
+                            name="ConcatenateTransformed",
+                            axis=-1,
+                        )([concat_num, transformed])
+                    else:
+                        self.concat_all = transformed
+
+                elif self.transfo_placement == TransformerBlockPlacementOptions.ALL_FEATURES:
+                    logger.info(f"Adding transformer blocks to all features: #{self.transfo_nr_blocks}")
+                    for block_idx in range(self.transfo_nr_blocks):
+                        self.concat_all = PreprocessorLayerFactory.transformer_block_layer(
+                            dim_model=self.concat_all.shape[-1],
+                            num_heads=self.transfo_nr_heads,
+                            ff_units=self.transfo_ff_units,
+                            dropout_rate=self.transfo_dropout_rate,
+                            name=f"transformer_block_{block_idx}_{self.transfo_nr_heads}heads",
+                        )(self.concat_all)
+
             logger.info("Concatenating outputs mode enabled")
         else:
             # Dictionary mode
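
For context on what the relocated block stacks: PreprocessorLayerFactory.transformer_block_layer itself is not part of this diff, but its call signature (dim_model, num_heads, ff_units, dropout_rate, name) suggests a standard Keras transformer encoder block. The sketch below is a hypothetical stand-in under that assumption; the helper name transformer_block, the 3-D (batch, sequence, dim_model) input shape, and the concrete sizes are illustrative, not the factory's actual API.

import tensorflow as tf


def transformer_block(dim_model: int, num_heads: int, ff_units: int,
                      dropout_rate: float, name: str) -> tf.keras.Model:
    # Hypothetical stand-in for PreprocessorLayerFactory.transformer_block_layer:
    # a post-norm encoder block (self-attention + feed-forward, each followed by
    # a residual connection and layer normalization).
    inputs = tf.keras.Input(shape=(None, dim_model), name=f"{name}_input")

    attn = tf.keras.layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=dim_model, dropout=dropout_rate
    )(inputs, inputs)
    x = tf.keras.layers.LayerNormalization()(tf.keras.layers.Add()([inputs, attn]))

    ff = tf.keras.layers.Dense(ff_units, activation="relu")(x)
    ff = tf.keras.layers.Dense(dim_model)(ff)
    ff = tf.keras.layers.Dropout(dropout_rate)(ff)
    outputs = tf.keras.layers.LayerNormalization()(tf.keras.layers.Add()([x, ff]))

    return tf.keras.Model(inputs=inputs, outputs=outputs, name=name)


# Stacking three such blocks the same way the moved loop does (shapes illustrative):
features = tf.keras.Input(shape=(8, 32), name="categorical_embeddings")
x = features
for block_idx in range(3):
    x = transformer_block(
        dim_model=x.shape[-1],
        num_heads=4,
        ff_units=64,
        dropout_rate=0.1,
        name=f"transformer_block_{block_idx}_4heads",
    )(x)
# As in the diff, the 3-D output is then flattened back to (batch, -1)
# before being concatenated with the numeric features.
x = tf.keras.layers.Reshape(target_shape=(-1,), name="reshape_transformer_output")(x)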