
Commit 16340f2

fix(KDP): add new examples for tabular attention cases and more complex Mixed Transformers and TabularAttention
1 parent 34476c6 commit 16340f2

File tree

6 files changed (+166, -0 lines)

docs/complex_example.md

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
# 📚 Complex Example 🌟

This example shows how to create a compound model with both transformer blocks and tabular attention mechanisms.

```python
import pandas as pd
import tensorflow as tf
from kdp.features import (
    NumericalFeature,
    CategoricalFeature,
    TextFeature,
    DateFeature,
    FeatureType
)
from kdp.processor import PreprocessingModel, OutputModeOptions

# Define features
features = {
    # Numerical features
    "price": NumericalFeature(
        name="price",
        feature_type=FeatureType.FLOAT_NORMALIZED
    ),
    "quantity": NumericalFeature(
        name="quantity",
        feature_type=FeatureType.FLOAT_RESCALED
    ),

    # Categorical features
    "category": CategoricalFeature(
        name="category",
        feature_type=FeatureType.STRING_CATEGORICAL,
        embedding_size=32
    ),
    "brand": CategoricalFeature(
        name="brand",
        feature_type=FeatureType.STRING_CATEGORICAL,
        embedding_size=16
    ),

    # Text features
    "description": TextFeature(
        name="description",
        feature_type=FeatureType.TEXT,
        max_tokens=100
    ),
    "title": TextFeature(
        name="title",
        feature_type=FeatureType.TEXT,
        max_tokens=50,  # max number of tokens to keep
    ),

    # Date features
    "sale_date": DateFeature(
        name="sale_date",
        feature_type=FeatureType.DATE,
        add_season=True,  # adds a one-hot season indicator (summer, winter, etc.); defaults to False
    )
}

# Create sample data
df = pd.DataFrame({
    "price": [10.5, 20.0, 15.75, 30.25, 25.50] * 20,
    "quantity": [5, 10, 3, 8, 12] * 20,
    "category": ["electronics", "books", "clothing", "food", "toys"] * 20,
    "brand": ["brandA", "brandB", "brandC", "brandD", "brandE"] * 20,
    "description": [
        "High quality product with great features",
        "Must-read book for enthusiasts",
        "Comfortable and stylish clothing",
        "Fresh and organic produce",
        "Educational toy for children"
    ] * 20,
    "title": [
        "Premium Device",
        "Best Seller Book",
        "Fashion Item",
        "Organic Food",
        "Kids Toy"
    ] * 20,
    "sale_date": [
        "2023-01-15",
        "2023-02-20",
        "2023-03-25",
        "2023-04-30",
        "2023-05-05"
    ] * 20
})

# Save to CSV
df.to_csv("sample_data.csv", index=False)

# Create preprocessor with both transformer blocks and attention
ppr = PreprocessingModel(
    path_data="sample_data.csv",
    features_specs=features,
    output_mode=OutputModeOptions.CONCAT,

    # Transformer block configuration
    transfo_placement="all_features",  # Choose between (categorical|all_features)
    transfo_nr_blocks=2,  # Number of transformer blocks
    transfo_nr_heads=4,  # Number of attention heads in each transformer block
    transfo_ff_units=64,  # Feed-forward units in each transformer block
    transfo_dropout_rate=0.1,  # Dropout rate for the transformer blocks

    # Tabular attention configuration
    tabular_attention=True,
    tabular_attention_placement="all_features",  # Choose between (none|numeric|categorical|all_features|multi_resolution)
    tabular_attention_heads=3,  # Number of attention heads
    tabular_attention_dim=32,  # Attention dimension
    tabular_attention_dropout=0.1,  # Attention dropout rate
    tabular_attention_embedding_dim=16,  # Embedding dimension

    # Other parameters
    overwrite_stats=True,  # Force stats regeneration; recommended to be set to True
)

# Build the preprocessor
result = ppr.build_preprocessor()
```
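
`build_preprocessor()` returns a dictionary whose `"model"` key holds the built Keras preprocessing model (it is the key used for predictions below; no other keys are assumed here). A quick, hedged way to sanity-check the inputs the preprocessor expects:

```python
# Sketch: only the "model" key of `result` is assumed, as used elsewhere in this example.
preprocessor = result["model"]

# Standard Keras introspection: layer summary and the per-feature input layers.
preprocessor.summary()
for inp in preprocessor.inputs:
    print(inp.name, inp.shape, inp.dtype)
```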

To plot the model architecture, run predictions with the built model, or inspect the computed feature statistics, use the following:

```python
# Plot the model architecture
ppr.plot_model("complex_model.png")

# Get predictions with an example test batch from the sample data
test_batch = tf.data.Dataset.from_tensor_slices(dict(df.head(3))).batch(3)
predictions = result["model"].predict(test_batch)
print("Output shape:", predictions.shape)

# Print feature statistics
print("\nFeature Statistics:")
for feature_type, features in ppr.get_feature_statistics().items():
    if isinstance(features, dict):
        print(f"\n{feature_type}:")
        for feature_name, stats in features.items():
            print(f"  {feature_name}: {list(stats.keys())}")
```
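
Because the returned preprocessor is a regular Keras model, it can also be chained into a downstream network. The sketch below is illustrative and not part of the KDP API: the Dense head, the randomly generated `labels`, and the training settings are assumptions added for demonstration only.

```python
import numpy as np
from tensorflow import keras

preprocessor = result["model"]  # the preprocessing model built above

# Hypothetical classification head stacked on the preprocessor's concatenated output
# (OutputModeOptions.CONCAT yields a single output tensor, as shown by the predict call above).
x = keras.layers.Dense(32, activation="relu")(preprocessor.output)
out = keras.layers.Dense(1, activation="sigmoid")(x)
full_model = keras.Model(inputs=preprocessor.inputs, outputs=out)

full_model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# Dummy binary labels, only to make the sketch runnable end to end.
labels = np.random.randint(0, 2, size=len(df))
train_ds = tf.data.Dataset.from_tensor_slices((dict(df), labels)).batch(16)

full_model.fit(train_ds, epochs=2)
```

If the preprocessing layers should stay frozen during training, set `preprocessor.trainable = False` before compiling.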

Here is the plot of the model:

![Complex Model](imgs/complex_model.png)
Three binary image files added (209 KB, 214 KB, 200 KB)

docs/imgs/complex_model.png

275 KB

docs/tabular_attention.md

Lines changed: 22 additions & 0 deletions
@@ -34,6 +34,26 @@ model = PreprocessingModel(
 )
 ```
 
+![Standard TabularAttention](imgs/attention_example_standard.png)
+
+### Categorical Tabular Attention
+
+```python
+from kdp.processor import PreprocessingModel, TabularAttentionPlacementOptions
+
+model = PreprocessingModel(
+    # ... other parameters ...
+    tabular_attention=True,
+    tabular_attention_heads=4,
+    tabular_attention_dim=64,
+    tabular_attention_dropout=0.1,
+    tabular_attention_embedding_dim=32,  # Dimension for categorical embeddings
+    tabular_attention_placement=TabularAttentionPlacementOptions.CATEGORICAL.value,
+)
+```
+
+![Categorical TabularAttention](imgs/attention_example_categorical.png)
+
 ### Multi-Resolution TabularAttention
 
 ```python
@@ -50,6 +70,8 @@ model = PreprocessingModel(
 )
 ```
 
+![Multi-Resolution TabularAttention](imgs/attention_example_multi_resolution.png)
+
 ## Configuration Options
 
 ### Common Options
