# Feature 3.2 — Feature Engineering (RFM, Basket Diversity, Cross‑brand)
This notebook implements Feature 3.2: it computes RFM and behavioral features, enforces anti‑leakage (as‑of date), runs screening (correlation, mutual information, cardinality), defines preprocessing (imputation/log transforms fit on the TRAIN split only), persists versioned features, and logs metadata to MLflow.

Notes:
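- A synthetic-data fallback is included so the notebook runs without data access; set `USE_SYNTHETIC=false` to read the Spark tables (`silver.transactions`, `silver.customers`) for real data.
- Outputs: versioned features as Parquet under `artifacts/features`, plus an optional Delta table write when Spark is available.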

In [None]:
# Imports & setup
import os, json, random, warnings, io
from datetime import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_auc_score
from sklearn.feature_selection import mutual_info_regression

# Suppress all warnings to keep notebook output readable.
# Note: this also hides DeprecationWarnings, so deprecated APIs fail silently here.
warnings.filterwarnings('ignore')
sns.set(style='whitegrid')
# MLflow is optional; the logging cell checks for None and skips when unavailable.
try:
    import mlflow
except Exception:
    mlflow = None

# Run configuration, overridable through environment variables.
SEED = int(os.environ.get('SEED', 42))
np.random.seed(SEED); random.seed(SEED)
# Feature cutoff date: transactions after this date must not feed any feature.
AS_OF_DATE = pd.to_datetime(os.environ.get('AS_OF_DATE', '2024-12-31'))
USE_SYNTHETIC = os.environ.get('USE_SYNTHETIC', 'true').lower() in ('1','true','yes')
FEATURE_VERSION = os.environ.get('FEATURE_VERSION', 'v1')
# Default snapshot label is today's UTC date. `datetime.utcnow()` is deprecated
# since Python 3.12; a tz-aware pandas timestamp yields the same calendar date.
SOURCE_SNAPSHOT = os.environ.get('SOURCE_SNAPSHOT', pd.Timestamp.now(tz='UTC').strftime('%Y-%m-%d'))
ARTIFACT_DIR = 'artifacts/features'
os.makedirs(ARTIFACT_DIR, exist_ok=True)
print({'seed': SEED, 'as_of_date': str(AS_OF_DATE.date()), 'use_synthetic': USE_SYNTHETIC, 'feature_version': FEATURE_VERSION})

## Load data (synthetic fallback or Spark tables)
Expected minimum columns:
- customers: customer_id, brand
- transactions: customer_id, tx_date, amount, cost, brand, category

In [None]:
def make_synthetic(n_customers=1200, start='2023-01-01', end='2024-12-31', seed=None):
    """Generate a synthetic ``(customers, transactions)`` pair for offline runs.

    Each customer gets a brand (Contoso with p=0.6, EuroStyle with p=0.4) and a
    Poisson(12) number of transactions on distinct days in ``[start, end]``.
    The count is capped at the number of days in that range. Amounts are drawn
    around a brand-specific mean and floored at 5.0. Costs are 50-80% of the
    amount. Customers whose draw is 0 have no transactions.

    Args:
        n_customers: Number of customers; ids are ``1..n_customers``.
        start: First day (inclusive) transactions may fall on.
        end: Last day (inclusive) transactions may fall on.
        seed: Seed for ``np.random.default_rng``; when ``None`` the
            notebook-level ``SEED`` is used (original behavior).

    Returns:
        ``(customers, tx)``: ``customers`` has columns ``customer_id, brand``;
        ``tx`` has ``customer_id, tx_date, amount, cost, brand, category``.
    """
    rng = np.random.default_rng(SEED if seed is None else seed)
    customers = pd.DataFrame({
        'customer_id': np.arange(1, n_customers + 1),
        'brand': rng.choice(['Contoso', 'EuroStyle'], size=n_customers, p=[0.6, 0.4]),
    })
    dates = pd.date_range(start=start, end=end, freq='D')
    cats = ['Shoes', 'Apparel', 'Accessories', 'Home', 'Beauty']
    rows = []
    for cid, brand in customers[['customer_id', 'brand']].itertuples(index=False):
        # Days are sampled without replacement (at most one transaction per day),
        # so the draw cannot exceed the number of available days; otherwise
        # rng.choice raises ValueError for short date ranges.
        k = min(int(rng.poisson(12)), len(dates))
        if k == 0:
            continue
        base = 60 if brand == 'Contoso' else 55
        tx_days = np.sort(rng.choice(dates, size=k, replace=False))
        for d in tx_days:
            cat = rng.choice(cats)
            amount = max(5.0, float(np.round(rng.normal(base, 22), 2)))
            cost = float(np.round(amount * rng.uniform(0.5, 0.8), 2))
            rows.append((cid, pd.Timestamp(d), amount, cost, brand, cat))
    tx = pd.DataFrame(rows, columns=['customer_id', 'tx_date', 'amount', 'cost', 'brand', 'category'])
    return customers, tx

# Pick the data source: Spark silver tables when USE_SYNTHETIC is off,
# otherwise the in-memory generator make_synthetic().
if not USE_SYNTHETIC:
    try:
        from pyspark.sql import SparkSession
        spark = SparkSession.builder.getOrCreate()
        read_table = spark.table
        tx_df = read_table('silver.transactions').toPandas()
        customers_df = read_table('silver.customers').toPandas()
    except Exception as load_err:
        # Any Spark/table failure surfaces as one actionable error, chained to the cause.
        raise RuntimeError('Implement real data loading or set USE_SYNTHETIC=True') from load_err
else:
    customers_df, tx_df = make_synthetic()

# Preview both frames (shown as the cell output in a notebook).
(customers_df.head(3), tx_df.head(3))

## Anti‑leakage: filter transactions to events on/before AS_OF_DATE

In [None]:
tx_df['tx_date'] = pd.to_datetime(tx_df['tx_date'])
# Day-granular cutoff: keep every event dated on or before AS_OF_DATE.
# A plain `<= AS_OF_DATE` against a midnight timestamp would silently drop
# same-day events that carry a time-of-day component.
cutoff_exclusive = AS_OF_DATE.normalize() + pd.Timedelta(days=1)
tx_preT = tx_df[tx_df['tx_date'] < cutoff_exclusive].copy()
# Features below are day-level; normalizing keeps recency non-negative for
# same-day events that have a time component.
tx_preT['tx_date'] = tx_preT['tx_date'].dt.normalize()
post_count = (tx_df['tx_date'] >= cutoff_exclusive).sum()
print({'post_as_of_transactions_dropped': int(post_count)})

## RFM features anchored at AS_OF_DATE
- Recency: days since the last purchase on or before AS_OF_DATE
- Frequency: number of purchases in the 365-day window ending on AS_OF_DATE (inclusive)
- Monetary: sum and average of `amount` over the same 365-day window

In [None]:
# Trailing window (win_start, AS_OF_DATE] used by the frequency and monetary features.
win_start = AS_OF_DATE - pd.Timedelta(days=365)
tx_12m = tx_preT[(tx_preT['tx_date'] > win_start)]
# Recency uses all pre-cutoff history, not just the 12-month window.
last_tx = tx_preT.groupby('customer_id')['tx_date'].max().rename('last_tx')
recency_days = (AS_OF_DATE - last_tx).dt.days.rename('recency_days')
freq_12m = tx_12m.groupby('customer_id').size().rename('freq_12m')
monetary_sum_12m = tx_12m.groupby('customer_id')['amount'].sum().rename('monetary_sum_12m')
monetary_avg_12m = tx_12m.groupby('customer_id')['amount'].mean().rename('monetary_avg_12m')

# Customers with no purchase at all get a bounded sentinel: one day longer than
# the observed history. It is strictly larger than any real recency. A huge
# constant (e.g. 1e9) would dominate the correlation screen and the TRAIN-mean
# imputation downstream.
if tx_preT.empty:
    no_purchase_recency = 365 + 1
else:
    no_purchase_recency = int((AS_OF_DATE - tx_preT['tx_date'].min()).days) + 1

rfm = customers_df[['customer_id','brand']].set_index('customer_id')
rfm = rfm.join([recency_days, freq_12m, monetary_sum_12m, monetary_avg_12m]).fillna({
    'recency_days': no_purchase_recency, 'freq_12m': 0, 'monetary_sum_12m': 0.0, 'monetary_avg_12m': 0.0
}).astype({'recency_days': int, 'freq_12m': int}).reset_index()
rfm.head()

## Basket diversity (last 12 months) and cross‑brand features (all history up to AS_OF_DATE)

In [None]:
# Distinct product categories bought in the trailing 12-month window.
div_12m = tx_12m.groupby('customer_id')['category'].nunique().rename('basket_diversity_12m')
# Cross-brand flag uses all pre-cutoff history (tx_preT), not only the 12-month window.
brand_counts = tx_preT.groupby('customer_id')['brand'].nunique().rename('brand_count_all')
has_both_brands = (brand_counts >= 2).astype(int).rename('has_both_brands')
features = (
    rfm.set_index('customer_id')
    .join([div_12m, has_both_brands])
    .fillna({'basket_diversity_12m': 0, 'has_both_brands': 0})
    # The left join introduces NaN for customers without matching transactions,
    # which upcasts these count/flag columns to float; restore integers after filling.
    .astype({'basket_diversity_12m': int, 'has_both_brands': int})
    .reset_index()
)
features.head()

## Screening: correlation, mutual information, and cardinality

In [None]:
# Numeric feature columns that are screened here and imputed in the preprocessing step.
num_cols = ['recency_days','freq_12m','monetary_sum_12m','monetary_avg_12m','basket_diversity_12m','has_both_brands']
corr = features[num_cols].corr()
# Diverging colour map on a fixed [-1, 1] scale, so negative correlations are visible
# and heatmaps are comparable across runs.
plt.figure(figsize=(6,4))
sns.heatmap(corr, annot=True, fmt='.2f', cmap='coolwarm', vmin=-1, vmax=1)
plt.title('Correlation heatmap')
plt.tight_layout()
plt.savefig(os.path.join(ARTIFACT_DIR, 'corr_heatmap.png'))

# Placeholder MI screen: freq_12m stands in for a real target (replace with the real task later).
# The proxy target is excluded from the predictors, because its MI with itself is trivially maximal.
# random_state is pinned: with None, sklearn draws from (and advances) NumPy's global RNG,
# which makes the estimate depend on what ran earlier.
mi_target = 'freq_12m'
mi_cols = [c for c in num_cols if c != mi_target]
mi = mutual_info_regression(
    features[mi_cols].fillna(0),
    features[mi_target].fillna(0),
    random_state=SEED,
)
mi_series = pd.Series(mi, index=mi_cols).sort_values(ascending=False)
mi_series.to_csv(os.path.join(ARTIFACT_DIR, 'mutual_information.csv'))
# Distinct-value count per column (including id/brand), for the cardinality screen.
card = features.nunique().sort_values(ascending=False)
card.to_csv(os.path.join(ARTIFACT_DIR, 'cardinality.csv'))
print('Saved heatmap and screening metrics to', ARTIFACT_DIR)

## Train‑only preprocessing (spec)
Use the splits from Feature 3.1 if available; otherwise, randomly assign customers to train (~70%) or test. Fit the imputation parameters on the TRAIN split only and save them.

In [None]:
split_path = 'artifacts/splits/feature_3_1_splits.csv'
# A split column left by a previous run of this cell would make the merge below
# produce split_x/split_y and break train_mask, so drop it first.
features = features.drop(columns='split', errors='ignore')
if os.path.exists(split_path):
    splits = pd.read_csv(split_path)
    # validate='many_to_one' raises if a customer appears more than once in the
    # split file, instead of silently duplicating feature rows.
    features = features.merge(splits, on='customer_id', how='left', validate='many_to_one')
else:
    # Reproducible ~70/30 fallback from a dedicated generator. The legacy global
    # RNG (np.random.rand) depends on every draw made earlier in the session.
    split_rng = np.random.default_rng(SEED)
    features['split'] = np.where(split_rng.random(len(features)) < 0.7, 'train', 'test')

# Customers without an assigned split are treated as TRAIN.
train_mask = features['split'].fillna('train').eq('train')
impute_params = {}
for c in num_cols:
    # Mean of the observed TRAIN values (pandas skips NaN). Filling NaN with 0
    # first would bias the mean toward zero. Fall back to 0.0 when nothing is
    # observed, so the JSON never contains NaN.
    mean_val = features.loc[train_mask, c].mean()
    impute_params[c] = {'strategy': 'mean', 'value': float(mean_val) if pd.notna(mean_val) else 0.0}
with open(os.path.join(ARTIFACT_DIR, 'impute_params.json'), 'w', encoding='utf-8') as f:
    json.dump(impute_params, f, indent=2)
print('Saved impute params for TRAIN split')

## Persist versioned features (Parquet) and attempt a Delta table write if Spark is available

In [None]:
# Lineage columns stamped on every row.
features['version'] = FEATURE_VERSION
# pd.Timestamp.utcnow() is deprecated (pandas 2.2); now(tz='UTC') returns the
# same tz-aware UTC timestamp.
features['created_ts'] = pd.Timestamp.now(tz='UTC')
features['source_snapshot'] = SOURCE_SNAPSHOT
features['as_of_date'] = AS_OF_DATE
out_parquet = os.path.join(ARTIFACT_DIR, f'customer_features_{FEATURE_VERSION}.parquet')
features.to_parquet(out_parquet, index=False)
print('Wrote', out_parquet)

# Best-effort Spark + Delta write: any failure (no Spark, no Delta support,
# missing permissions) is reported and skipped; the Parquet file above remains.
try:
    from pyspark.sql import SparkSession
    spark = SparkSession.builder.getOrCreate()
    sdf = spark.createDataFrame(features)
    # Example managed table write (adjust catalog/schema as needed)
    target_table = f'silver.customer_features_{FEATURE_VERSION}'
    sdf.write.mode('overwrite').format('delta').saveAsTable(target_table)
    print('Saved Delta table', target_table)
except Exception as e:
    print('Delta write skipped:', e)

## Log metadata to MLflow (optional)

In [None]:
# Optional MLflow logging: skipped when mlflow is not importable. Failures are
# reported but do not stop the notebook.
if mlflow is not None:
    try:
        mlflow.set_experiment('/Feature_3_2_Features')
        with mlflow.start_run(run_name=f'features_{FEATURE_VERSION}'):
            mlflow.log_params({
                'version': FEATURE_VERSION,
                'as_of_date': str(AS_OF_DATE.date()),
                'source_snapshot': SOURCE_SNAPSHOT,
                # Engineered numeric features only. The table also carries the id, brand,
                # split and lineage columns, which are reported separately below.
                'feature_count': len(num_cols),
                'column_count': int(features.shape[1]),
            })
            for artifact_name in ('corr_heatmap.png', 'mutual_information.csv',
                                  'cardinality.csv', 'impute_params.json'):
                mlflow.log_artifact(os.path.join(ARTIFACT_DIR, artifact_name))
            mlflow.log_artifact(out_parquet)
        print('Logged to MLflow')
    except Exception as e:
        print('MLflow logging skipped:', e)
else:
    print('MLflow not available')