
## 📒 Notebook Summary: ML Training Data Preparation

This notebook covers the following steps for preparing ML training and validation datasets for product recommendation:

1. **Read Feature Store Data**
   - Loaded candidate features for user-product pairs from the feature store.

2. **Extract Candidate Time**
   - Computed the latest interaction timestamp per user for time-based splitting.

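3. **Define Training Cutoff**
   - Established a global cutoff time (e.g., 30 days before the latest event) to separate past and future data. The code below pins this cutoff as `FEATURE_CUTOFF` and adds a second one, `VALID_CUTOFF`, that marks the start of the validation window.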

4. **Generate Labels**
   - Identified future purchases after the cutoff as positive labels.
   - Left-joined features with labels, assigning 1 to future purchases and 0 otherwise.

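5. **Time-Based Train/Validation Split**
   - Split data into train and validation sets based on candidate interaction time (e.g., the last 7 days for validation). In the code below, train labels come from purchases in `[FEATURE_CUTOFF, VALID_CUTOFF)` and validation labels from purchases on or after `VALID_CUTOFF`.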

6. **Filter Train Users**
   - Removed users with no positive labels from the training set, since they contribute no positive examples for the ranker to learn from.

7. **Create Final Datasets**
   - Dropped unnecessary columns and prepared ranking-ready train and validation DataFrames.

8. **Snapshot Data**
   - Saved train and validation datasets as Delta tables with unique snapshot IDs and timestamps for reproducibility.

9. **Sanity Checks**
   - Displayed label distributions for both train and validation sets to verify data balance.
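
Steps 2, 3, 4, and 6 can be sketched in plain Python on a small in-memory example. This is a minimal illustration under stated assumptions, not the notebook's Spark code: events are assumed to be `(customer_id, product_id, interaction_type, event_time)` tuples, candidates are assumed to be `(customer_id, product_id)` pairs, feature computation is omitted, and all function names and data below are hypothetical.

```python
from datetime import datetime, timedelta

# Hypothetical toy events: (customer_id, product_id, interaction_type, event_time).
events = [
    ("c1", "p1", "view",     datetime(2025, 11, 20)),
    ("c1", "p1", "purchase", datetime(2025, 12, 5)),
    ("c2", "p2", "view",     datetime(2025, 11, 25)),
    ("c2", "p3", "purchase", datetime(2025, 11, 28)),
    ("c3", "p4", "view",     datetime(2025, 12, 30)),
]

# Hypothetical candidate pairs, standing in for rows of the feature store.
candidates = [("c1", "p1"), ("c1", "p2"), ("c2", "p2"), ("c2", "p3"), ("c3", "p4")]


def latest_event_per_user(events):
    """Step 2: latest interaction timestamp for each customer."""
    latest = {}
    for customer, _, _, ts in events:
        if customer not in latest or ts > latest[customer]:
            latest[customer] = ts
    return latest


def global_cutoff(latest_by_user, days_before_latest=30):
    """Step 3: one global cutoff, N days before the latest event overall."""
    return max(latest_by_user.values()) - timedelta(days=days_before_latest)


def label_candidates(candidates, events, cutoff):
    """Step 4: label a pair 1 if it was purchased at or after the cutoff, else 0."""
    future_purchases = {
        (customer, product)
        for customer, product, kind, ts in events
        if kind.lower() == "purchase" and ts >= cutoff
    }
    return [(c, p, int((c, p) in future_purchases)) for c, p in candidates]


def drop_users_without_positives(rows):
    """Step 6: keep only customers with at least one positive label."""
    users_with_positive = {c for c, _, label in rows if label == 1}
    return [row for row in rows if row[0] in users_with_positive]


latest = latest_event_per_user(events)
cutoff = global_cutoff(latest)
labeled = label_candidates(candidates, events, cutoff)
train_rows = drop_users_without_positives(labeled)
print(cutoff)      # 2025-11-30 00:00:00
print(train_rows)  # [('c1', 'p1', 1), ('c1', 'p2', 0)]
```

Filtering runs after labeling, so a customer is dropped only when none of their candidate pairs is purchased after the cutoff; here `c2` bought `p3` before the cutoff, which would count as history rather than as a label.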

---

**Outcome:**  
You now have reproducible, leakage-free train and validation datasets for ML model development, with all steps documented and data snapshots saved for future reference.
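
The "leakage-free" property can be checked mechanically: every event used for features should fall strictly before the feature cutoff, and every event used for labels should fall at or after it. The sketch below assumes the relevant timestamps have already been collected into Python lists; `leakage_violations` and the sample timestamps are hypothetical, and only the cutoff string is taken from the notebook.

```python
from datetime import datetime

# Same string as the notebook's FEATURE_CUTOFF, parsed into a datetime.
FEATURE_CUTOFF = datetime.fromisoformat("2025-12-01 00:00:00")


def leakage_violations(feature_event_times, label_event_times, cutoff):
    """Return the timestamps that break the time split.

    Feature events must be strictly before the cutoff; label events must be
    at or after it. An empty result means no leakage was detected.
    """
    bad_features = [ts for ts in feature_event_times if ts >= cutoff]
    bad_labels = [ts for ts in label_event_times if ts < cutoff]
    return bad_features + bad_labels


# Hypothetical timestamps for illustration only.
feature_times = [datetime(2025, 11, 10), datetime(2025, 11, 30, 23, 59)]
label_times = [datetime(2025, 12, 1), datetime(2025, 12, 15)]

print(leakage_violations(feature_times, label_times, FEATURE_CUTOFF))  # []

# A feature event dated after the cutoff is reported as leakage.
print(leakage_violations(feature_times + [datetime(2025, 12, 2)], label_times, FEATURE_CUTOFF))
```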

In [0]:
from pyspark.sql import functions as F
from pyspark.sql.functions import current_timestamp, lit
import uuid

spark.conf.set("spark.databricks.remoteFiltering.blockSelfJoins", "false")

gold = "kusha_solutions.product_recomendation"

# ============================================================
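# 🔒 FIXED TIME CONFIG (MUST MATCH THE FAKER DATA + FEATURE ENGINEERING)
# ============================================================
# FEATURE_CUTOFF: features are frozen as of this time; train labels start here.
# VALID_CUTOFF:   train labels stop here; validation labels start here.
FEATURE_CUTOFF = "2025-12-01 00:00:00"
VALID_CUTOFF   = "2025-12-20 00:00:00"

print("📌 Feature cutoff    :", FEATURE_CUTOFF)
print("📌 Validation cutoff :", VALID_CUTOFF)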

# ============================================================
# 1️⃣ READ FROZEN FEATURES (LEAKAGE-FREE)
# ============================================================
features = spark.table(f"{gold}.fs_canddiate_features")

# ============================================================
# 2️⃣ TRAIN LABELS (PURCHASES IN [FEATURE_CUTOFF, VALID_CUTOFF))
# ============================================================
train_purchases = (
    spark.table(f"{gold}.gold_sales_enriched")
         .filter(F.col("EventTime") >= FEATURE_CUTOFF)
         .filter(F.col("EventTime") < VALID_CUTOFF)
         .filter(F.lower(F.col("InteractionType")) == "purchase")
         .select("CustomerID", "ProductID")
         .distinct()
         .withColumn("label", F.lit(1))
)

train_df = (
    features
    .join(train_purchases, ["CustomerID", "ProductID"], "left")
    .withColumn("label", F.coalesce("label", F.lit(0)))
)

# ============================================================
# 3️⃣ VALIDATION LABELS (PURCHASES ON OR AFTER VALID_CUTOFF)
# ============================================================
valid_purchases = (
    spark.table(f"{gold}.gold_sales_enriched")
         .filter(F.col("EventTime") >= VALID_CUTOFF)
         .filter(F.lower(F.col("InteractionType")) == "purchase")
         .select("CustomerID", "ProductID")
         .distinct()
         .withColumn("label", F.lit(1))
)

valid_df = (
    features
    .join(valid_purchases, ["CustomerID", "ProductID"], "left")
    .withColumn("label", F.coalesce("label", F.lit(0)))
)

# ============================================================
# 4️⃣ SANITY CHECK
# ============================================================
print("📊 TRAIN label distribution")
train_df.groupBy("label").count().show()

print("📊 VALID label distribution")
valid_df.groupBy("label").count().show()

# ============================================================
# 5️⃣ SNAPSHOT TRAIN & VALID (REPRODUCIBLE)
# ============================================================
snapshot_id = str(uuid.uuid4())

def write_snapshot(df, table_name):
    """Tag df with the shared snapshot metadata and append it to a Delta table in the gold schema."""
    (
        df.withColumn("snapshot_id", lit(snapshot_id))
          .withColumn("snapshot_ts", current_timestamp())
          .withColumn("feature_cutoff", lit(FEATURE_CUTOFF))
          .withColumn("valid_cutoff", lit(VALID_CUTOFF))
          .write
          .mode("append")
          .format("delta")
          .saveAsTable(f"{gold}.{table_name}")
    )


write_snapshot(train_df, "ml_train_snapshot")
write_snapshot(valid_df, "ml_valid_snapshot")

print("✅ Time-based train & validation snapshots saved")
print("📌 Snapshot ID:", snapshot_id)

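
For a quick offline check of the windowing logic in the cell above, the two-cutoff labeling can be mirrored in plain Python. This is a hypothetical sketch on toy data, not a replacement for the Spark job: train positives are purchases in `[FEATURE_CUTOFF, VALID_CUTOFF)`, validation positives are purchases on or after `VALID_CUTOFF`, both are left-joined onto the same candidate pairs with 0 as the default label, and the interaction type is lower-cased before comparison, as `F.lower` does in the notebook. `purchased_pairs` and `label_pairs` are illustrative names.

```python
from collections import Counter
from datetime import datetime

# Same cutoff strings as the notebook, parsed into datetimes.
FEATURE_CUTOFF = datetime.fromisoformat("2025-12-01 00:00:00")
VALID_CUTOFF = datetime.fromisoformat("2025-12-20 00:00:00")


def purchased_pairs(events, start, end=None):
    """Distinct (customer, product) purchases with start <= ts < end; end=None means open-ended."""
    return {
        (c, p)
        for c, p, kind, ts in events
        if kind.lower() == "purchase" and ts >= start and (end is None or ts < end)
    }


def label_pairs(candidates, positives):
    """Left-join candidates to positives: 1 if the pair was purchased, else 0."""
    return [(c, p, int((c, p) in positives)) for c, p in candidates]


# Hypothetical toy data.
candidates = [("c1", "p1"), ("c1", "p2"), ("c2", "p3")]
events = [
    ("c1", "p1", "Purchase", datetime(2025, 12, 5)),   # train window
    ("c1", "p2", "purchase", datetime(2025, 12, 22)),  # validation window
    ("c2", "p3", "view",     datetime(2025, 12, 10)),  # not a purchase
    ("c2", "p3", "purchase", datetime(2025, 12, 1)),   # exactly at FEATURE_CUTOFF: train
]

train = label_pairs(candidates, purchased_pairs(events, FEATURE_CUTOFF, VALID_CUTOFF))
valid = label_pairs(candidates, purchased_pairs(events, VALID_CUTOFF))

# Mirrors the groupBy("label").count() sanity check.
print(sorted(Counter(row[2] for row in train).items()))  # [(0, 1), (1, 2)]
print(sorted(Counter(row[2] for row in valid).items()))  # [(0, 2), (1, 1)]
```

Because both splits share the same frozen candidate features, the only thing that differs between them is the label window.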