Skip to content

Commit 664023f

Browse files
committed
fix(KDP): Added get_feature_importances() method and fixed the docs.
1 parent 11f258d commit 664023f

File tree

4 files changed

+24
-18
lines changed

4 files changed

+24
-18
lines changed

.gitignore

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,4 @@ kdp/data/fake_data.csv
166166
my_tests/*
167167

168168
# derivative files
169-
data.csv
170-
sample_data.csv
171-
stats.json
169+
*.csv

complex_model.png

-311 KB
Binary file not shown.

docs/complex_example.md

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -130,15 +130,16 @@ Now if one wants to plot a block diagram of the model or get the output of t
130130
# Plot the model architecture
131131
ppr.plot_model("complex_model.png")
132132

133-
# Get predictions with an example test batch from the example data
134-
processed_data = ppr.transform(test_batch) # this returns a dict with "transformed_data" and "feature_weights"
135-
print("Output shape:", processed_data["transformed_data"].shape)
136-
137-
# Analyze feature importance if feature selection is enabled
138-
if "feature_weights" in processed_data:
139-
for feature_name in features:
140-
weights = processed_data[f"{feature_name}_weights"]
141-
print(f"Feature {feature_name} importance: {weights.mean()}")
133+
# Transform data using direct model prediction
134+
transformed_data = ppr.model.predict(test_batch)
135+
136+
# Transform data using batch_predict
137+
transformed_data = ppr.batch_predict(test_batch)
138+
transformed_batches = list(transformed_data) # For better visualization
139+
140+
# Get feature importances
141+
feature_importances = ppr.get_feature_importances()
142+
print("Feature importances:", feature_importances)
142143
```
143144

144145

kdp/processor.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1459,16 +1459,23 @@ def _convert_to_dataset(self, data: tf.data.Dataset | pd.DataFrame | dict) -> tf
14591459
else:
14601460
raise ValueError("Input data must be a DataFrame, dict, or TensorFlow Dataset")
14611461

1462-
def _extract_feature_weights(self) -> dict[str, np.ndarray]:
1463-
"""Extract feature importance weights from feature selection layers.
1462+
def get_feature_importances(self) -> dict[str, float]:
1463+
"""Get feature importance scores from feature selection layers.
14641464
14651465
Returns:
1466-
dict[str, np.ndarray]: Dictionary mapping feature names to their importance weights.
1466+
dict[str, float]: Dictionary mapping feature names to their importance scores,
1467+
where scores are averaged across all dimensions.
14671468
"""
1468-
weights = {}
1469+
feature_importances = {}
1470+
14691471
for layer in self.model.layers:
14701472
if "feature_selection" in layer.name:
14711473
layer_weights = layer.get_weights()
14721474
for i, feature_name in enumerate(self.features_specs.keys()):
1473-
weights[f"{feature_name}_weights"] = layer_weights[0][:, i]
1474-
return weights
1475+
weights = layer_weights[0][:, i]
1476+
feature_importances[feature_name] = float(np.mean(weights))
1477+
1478+
if not feature_importances:
1479+
logger.warning("No feature selection layers found in the model")
1480+
1481+
return feature_importances

0 commit comments

Comments
 (0)