|
8 | 8 | from functools import wraps |
9 | 9 | from typing import Any |
10 | 10 |
|
| 11 | +import pandas as pd |
11 | 12 | import tensorflow as tf |
12 | 13 | from loguru import logger |
13 | 14 |
|
@@ -1435,3 +1436,43 @@ def get_feature_statistics(self) -> dict: |
1435 | 1436 | "feature_crosses": self.feature_crosses, |
1436 | 1437 | "output_mode": self.output_mode, |
1437 | 1438 | } |
| 1439 | + |
| 1440 | + def transform(self, data: tf.data.Dataset | pd.DataFrame | dict) -> dict[str, Any]: |
| 1441 | + """Transform input data using the built preprocessor model. |
| 1442 | +
|
| 1443 | + Args: |
| 1444 | + data: Input data to transform. Can be a DataFrame, Dataset, or dict. |
| 1445 | +
|
| 1446 | + Returns: |
| 1447 | + dict[str, Any]: Dictionary containing: |
| 1448 | + - transformed_data: The transformed data output |
| 1449 | + - {feature_name}_weights: Weight for each feature from feature selection |
| 1450 | +
|
| 1451 | + Raises: |
| 1452 | + ValueError: If preprocessor hasn't been built yet. |
| 1453 | + """ |
| 1454 | + # Convert input data to TensorFlow dataset if needed |
| 1455 | + if isinstance(data, pd.DataFrame): |
| 1456 | + dataset = tf.data.Dataset.from_tensor_slices(dict(data)).batch(32) |
| 1457 | + elif isinstance(data, dict): |
| 1458 | + dataset = tf.data.Dataset.from_tensor_slices(data).batch(32) |
| 1459 | + elif isinstance(data, tf.data.Dataset): |
| 1460 | + dataset = data |
| 1461 | + else: |
| 1462 | + raise ValueError("Input data must be a DataFrame, dict, or TensorFlow Dataset") |
| 1463 | + |
| 1464 | + # Transform the data using the model |
| 1465 | + transformed = self.model.predict(dataset) |
| 1466 | + |
| 1467 | + # Initialize return dictionary with transformed data |
| 1468 | + result = {"transformed_data": transformed} |
| 1469 | + |
| 1470 | + # Get feature importance from the feature selection layer if it exists |
| 1471 | + for layer in self.model.layers: |
| 1472 | + if "feature_selection" in layer.name: |
| 1473 | + weights = layer.get_weights() |
| 1474 | + for i, feature_name in enumerate(self.features_specs.keys()): |
| 1475 | + # Add weights for each feature with the expected key format |
| 1476 | + result[f"{feature_name}_weights"] = weights[0][:, i] |
| 1477 | + |
| 1478 | + return result |
0 commit comments