Skip to content

Commit 0c6c65c

Browse files
committed
fix(KDP): add transform() method to ProcessingModel
1 parent a4d536c commit 0c6c65c

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,4 @@ my_tests/*
167167

168168
# derivative files
169169
data.csv
170+
sample_data.csv

kdp/processor.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from functools import wraps
99
from typing import Any
1010

11+
import pandas as pd
1112
import tensorflow as tf
1213
from loguru import logger
1314

@@ -1435,3 +1436,43 @@ def get_feature_statistics(self) -> dict:
14351436
"feature_crosses": self.feature_crosses,
14361437
"output_mode": self.output_mode,
14371438
}
1439+
1440+
def transform(
    self,
    data: tf.data.Dataset | pd.DataFrame | dict,
    *,
    batch_size: int = 32,
) -> dict[str, Any]:
    """Transform input data using the built preprocessor model.

    Args:
        data: Input data to transform. A DataFrame or dict is converted to a
            batched ``tf.data.Dataset``; a ``tf.data.Dataset`` is passed to
            the model unchanged (it is NOT re-batched here).
        batch_size: Batch size used when a DataFrame or dict has to be
            converted to a dataset. Ignored for ``tf.data.Dataset`` input.
            Defaults to 32.

    Returns:
        dict[str, Any]: Dictionary containing:
            - transformed_data: The value returned by ``self.model.predict``.
            - {feature_name}_weights: For every layer whose name contains
              "feature_selection" and that has weights, column ``i`` of the
              layer's first weight array, where ``i`` is the position of
              ``feature_name`` in ``self.features_specs``. If several such
              layers exist, the last one wins.

    Raises:
        ValueError: If the preprocessor model hasn't been built yet, or if
            ``data`` is not a DataFrame, dict, or TensorFlow Dataset.
    """
    # Fail early with the documented error instead of an AttributeError on
    # `None.predict` (or on a missing `model` attribute).
    model = getattr(self, "model", None)
    if model is None:
        raise ValueError(
            "Preprocessor model has not been built yet; build it before calling transform()."
        )

    # Convert input data to a TensorFlow dataset if needed. `dict(...)` turns a
    # DataFrame into a {column: Series} mapping and is a shallow copy for a dict.
    if isinstance(data, (pd.DataFrame, dict)):
        dataset = tf.data.Dataset.from_tensor_slices(dict(data)).batch(batch_size)
    elif isinstance(data, tf.data.Dataset):
        dataset = data
    else:
        raise ValueError("Input data must be a DataFrame, dict, or TensorFlow Dataset")

    # Transform the data using the model.
    result: dict[str, Any] = {"transformed_data": model.predict(dataset)}

    # Expose per-feature weights from the feature selection layer(s), if any.
    for layer in model.layers:
        if "feature_selection" not in layer.name:
            continue
        weights = layer.get_weights()
        if not weights:
            # A matching layer without weights has nothing to report.
            continue
        selection_matrix = weights[0]
        # NOTE(review): assumes the columns of the first weight array follow
        # the iteration order of `self.features_specs` -- confirm against the
        # feature selection layer's implementation.
        for i, feature_name in enumerate(self.features_specs):
            result[f"{feature_name}_weights"] = selection_matrix[:, i]

    return result

0 commit comments

Comments
 (0)