1818
1919class PreprocessorLayerFactory :
2020 @staticmethod
21- def create_layer (layer_class : str | object , name : str = None , ** kwargs ) -> tf .keras .layers .Layer :
21+ def create_layer (
22+ layer_class : str | object , name : str = None , ** kwargs
23+ ) -> tf .keras .layers .Layer :
2224 """Create a layer using the layer class name, automatically filtering kwargs based on the layer class.
2325
2426 Args:
@@ -39,7 +41,9 @@ def create_layer(layer_class: str | object, name: str = None, **kwargs) -> tf.ke
3941 constructor_params = inspect .signature (layer_class .__init__ ).parameters
4042
4143 # Filter kwargs to only include those that the constructor can accept
42- filtered_kwargs = {key : value for key , value in kwargs .items () if key in constructor_params }
44+ filtered_kwargs = {
45+ key : value for key , value in kwargs .items () if key in constructor_params
46+ }
4347
4448 # Add the 'name' argument if provided else default the class name lowercase option
4549 if name :
@@ -86,7 +90,9 @@ def distribution_aware_encoder(
8690 )
8791
8892 @staticmethod
89- def text_preprocessing_layer (name : str = "text_preprocessing" , ** kwargs : dict ) -> tf .keras .layers .Layer :
93+ def text_preprocessing_layer (
94+ name : str = "text_preprocessing" , ** kwargs : dict
95+ ) -> tf .keras .layers .Layer :
9096 """Create a TextPreprocessingLayer layer.
9197
9298 Args:
@@ -103,7 +109,9 @@ def text_preprocessing_layer(name: str = "text_preprocessing", **kwargs: dict) -
103109 )
104110
105111 @staticmethod
106- def cast_to_float32_layer (name : str = "cast_to_float32" , ** kwargs : dict ) -> tf .keras .layers .Layer :
112+ def cast_to_float32_layer (
113+ name : str = "cast_to_float32" , ** kwargs : dict
114+ ) -> tf .keras .layers .Layer :
107115 """Create a CastToFloat32Layer layer.
108116
109117 Args:
@@ -120,7 +128,9 @@ def cast_to_float32_layer(name: str = "cast_to_float32", **kwargs: dict) -> tf.k
120128 )
121129
122130 @staticmethod
123- def date_parsing_layer (name : str = "date_parsing_layer" , ** kwargs : dict ) -> tf .keras .layers .Layer :
131+ def date_parsing_layer (
132+ name : str = "date_parsing_layer" , ** kwargs : dict
133+ ) -> tf .keras .layers .Layer :
124134 """Create a DateParsingLayer layer.
125135
126136 Args:
@@ -137,7 +147,9 @@ def date_parsing_layer(name: str = "date_parsing_layer", **kwargs: dict) -> tf.k
137147 )
138148
139149 @staticmethod
140- def date_encoding_layer (name : str = "date_encoding_layer" , ** kwargs : dict ) -> tf .keras .layers .Layer :
150+ def date_encoding_layer (
151+ name : str = "date_encoding_layer" , ** kwargs : dict
152+ ) -> tf .keras .layers .Layer :
141153 """Create a DateEncodingLayer layer.
142154
143155 Args:
@@ -154,7 +166,9 @@ def date_encoding_layer(name: str = "date_encoding_layer", **kwargs: dict) -> tf
154166 )
155167
156168 @staticmethod
157- def date_season_layer (name : str = "date_season_layer" , ** kwargs : dict ) -> tf .keras .layers .Layer :
169+ def date_season_layer (
170+ name : str = "date_season_layer" , ** kwargs : dict
171+ ) -> tf .keras .layers .Layer :
158172 """Create a SeasonLayer layer.
159173
160174 Args:
@@ -171,7 +185,9 @@ def date_season_layer(name: str = "date_season_layer", **kwargs: dict) -> tf.ker
171185 )
172186
173187 @staticmethod
174- def transformer_block_layer (name : str = "transformer" , ** kwargs : dict ) -> tf .keras .layers .Layer :
188+ def transformer_block_layer (
189+ name : str = "transformer" , ** kwargs : dict
190+ ) -> tf .keras .layers .Layer :
175191 """Create a TransformerBlock layer.
176192
177193 Args:
@@ -241,7 +257,9 @@ def multi_resolution_attention_layer(
241257 )
242258
243259 @staticmethod
244- def variable_selection_layer (name : str = "variable_selection" , ** kwargs : dict ) -> tf .keras .layers .Layer :
260+ def variable_selection_layer (
261+ name : str = "variable_selection" , ** kwargs : dict
262+ ) -> tf .keras .layers .Layer :
245263 """Create a VariableSelection layer.
246264
247265 Args:
0 commit comments