
Commit 448f63f

feat(KDP): smart processing for custom pipelines
1 parent c18a59b commit 448f63f

2 files changed: +92 -12 lines changed

kdp/dynamic_pipeline.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+class DynamicPreprocessingPipeline:
+    """
+    Dynamically initializes a sequence of Keras preprocessing layers based on the output
+    from each previous layer, allowing each layer to access the outputs of all prior layers where relevant.
+    """
+
+    def __init__(self, layers):
+        """
+        Initializes the DynamicPreprocessingPipeline with a list of layers.
+
+        Args:
+            layers (list): A list of Keras preprocessing layers, each potentially named for reference.
+        """
+        self.layers = layers
+
+    def initialize_and_transform(self, init_data):
+        """
+        Sequentially processes each layer, applying transformations selectively based on each
+        layer's input requirements and ensuring efficient data usage and processing. Each layer
+        can access the outputs of all previous layers.
+
+        Args:
+            init_data (dict): A dictionary with initialization data, dynamically keyed.
+
+        Returns:
+            dict: The dictionary containing selectively transformed data for each layer.
+        """
+        current_data = init_data
+
+        for i, layer in enumerate(self.layers):
+            # For many layers we may not have a formal input_spec, so assume the layer uses all current data.
+            required_keys = current_data.keys()
+
+            # Prepare input for the current layer based on the determined keys.
+            # Here, we assume that each layer accepts a dictionary of inputs.
+            current_input = {k: current_data[k] for k in required_keys}
+
+            # Apply transformation: if the layer returns a tensor, wrap it in a dict using the layer name.
+            transformed_output = layer(current_input)
+            if not isinstance(transformed_output, dict):
+                transformed_output = {layer.name: transformed_output}
+
+            # Update the current data with the transformed output so that subsequent layers can reuse it.
+            current_data.update(transformed_output)
+
+        return current_data
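
For context, a minimal usage sketch of the new class (not part of the commit, and ScaleLayer is a hypothetical layer): initialize_and_transform passes each layer a dict of every output produced so far, and a plain-tensor result gets wrapped under the layer's name, so a dict-aware layer is assumed here.

import tensorflow as tf

from kdp.dynamic_pipeline import DynamicPreprocessingPipeline


class ScaleLayer(tf.keras.layers.Layer):
    """Hypothetical dict-aware layer: scales the running "input" entry."""

    def __init__(self, factor: float, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, inputs: dict) -> tf.Tensor:
        # Returns a plain tensor; the pipeline stores it as {"scaled": ...}.
        return inputs["input"] * self.factor


pipeline = DynamicPreprocessingPipeline([ScaleLayer(2.0, name="scaled")])
result = pipeline.initialize_and_transform({"input": tf.constant([1.0, 2.0, 3.0])})
print(result["scaled"])  # [2. 4. 6.], alongside the untouched "input" entry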

kdp/pipeline.py

Lines changed: 46 additions & 12 deletions
@@ -4,6 +4,7 @@
 from loguru import logger
 
 from kdp.layers_factory import PreprocessorLayerFactory
+from kdp.dynamic_pipeline import DynamicPreprocessingPipeline
 
 
 class ProcessingStep:
@@ -87,44 +88,77 @@ def transform(self, input_data: tf.Tensor) -> tf.Tensor:
 
 
 class FeaturePreprocessor:
-    def __init__(self, name: str) -> None:
-        """Initialize a feature preprocessor.
+    def __init__(self, name: str, use_dynamic: bool = False) -> None:
+        """
+        Initializes a feature preprocessor.
 
         Args:
             name (str): The name of the feature preprocessor.
+            use_dynamic (bool): Whether to use the dynamic preprocessing pipeline.
         """
         self.name = name
-        self.pipeline = Pipeline(name=name)
+        self.use_dynamic = use_dynamic
+        if not self.use_dynamic:
+            self.pipeline = Pipeline(name=name)
+        else:
+            self.layers = []  # for dynamic pipeline
 
     def add_processing_step(
         self, layer_creator: Callable[..., tf.keras.layers.Layer] = None, **layer_kwargs
     ) -> None:
-        """Add a processing step to the feature preprocessor.
+        """
+        Add a preprocessing layer to the feature preprocessor pipeline.
+        If using the standard pipeline, a ProcessingStep is added.
+        Otherwise, the layer is added to a list for dynamic handling.
 
         Args:
             layer_creator (Callable[..., tf.keras.layers.Layer]): A callable that creates a layer.
                 If not provided, the default layer creator is used.
             **layer_kwargs: Additional keyword arguments for the layer creator.
         """
         layer_creator = layer_creator or PreprocessorLayerFactory.create_layer
-        step = ProcessingStep(layer_creator=layer_creator, **layer_kwargs)
-        self.pipeline.add_step(step=step)
+        if self.use_dynamic:
+            layer = layer_creator(**layer_kwargs)
+            logger.info(f"Adding {layer.name} to dynamic preprocessing pipeline")
+            self.layers.append(layer)
+        else:
+            step = ProcessingStep(layer_creator=layer_creator, **layer_kwargs)
+            self.pipeline.add_step(step=step)
 
     def chain(self, input_layer) -> tf.keras.layers.Layer:
-        """Chain the preprocessor's pipeline steps starting from the input layer.
+        """
+        Chains the processing steps starting from the given input_layer.
 
-        Args:
-            input_layer: The input layer to start the chain from.
+        For a static pipeline, this delegates to the internal Pipeline's chain() method.
+        For the dynamic pipeline, it constructs the dynamic pipeline on the fly.
         """
-        return self.pipeline.chain(input_layer)
+        if not self.use_dynamic:
+            return self.pipeline.chain(input_layer)
+        else:
+            dynamic_pipeline = DynamicPreprocessingPipeline(self.layers)
+            # In the dynamic case, we use a dict for the input.
+            output_dict = dynamic_pipeline.initialize_and_transform(
+                {"input": input_layer}
+            )
+            # Return the transformed data at key "input" (or adjust as needed).
+            return output_dict.get("input", input_layer)
 
     def transform(self, input_data: tf.Tensor) -> tf.Tensor:
-        """Apply the feature preprocessor to the input data.
+        """
+        Process the input data through the pipeline.
+        For the dynamic pipeline, wrap input in a dictionary and extract final output.
 
         Args:
             input_data: The input data to process.
 
         Returns:
             tf.Tensor: The processed data.
         """
-        return self.pipeline.transform(input_data)
+        if not self.use_dynamic:
+            return self.pipeline.transform(input_data)
+        else:
+            dynamic_pipeline = DynamicPreprocessingPipeline(self.layers)
+            output_dict = dynamic_pipeline.initialize_and_transform(
+                {"input": input_data}
+            )
+            return output_dict.get("input", input_data)